mirror of
https://github.com/apache/superset.git
synced 2026-06-07 00:29:17 +00:00
feat: implement CVAS/CTAS in sqlglot (#33525)
This commit is contained in:
@@ -108,6 +108,11 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
class CTASMethod(enum.Enum):
|
||||
TABLE = enum.auto()
|
||||
VIEW = enum.auto()
|
||||
|
||||
|
||||
class RLSMethod(enum.Enum):
|
||||
"""
|
||||
Methods for enforcing RLS.
|
||||
@@ -381,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
@@ -437,6 +448,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
|
||||
"""
|
||||
Rewrite the statement as a `CREATE TABLE AS` statement.
|
||||
|
||||
:param table: The table to create.
|
||||
:param method: The method to use for creating the table.
|
||||
:return: A new SQLStatement with the CTE.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
@@ -480,7 +501,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
try:
|
||||
return sqlglot.parse(script, dialect=dialect)
|
||||
statements = sqlglot.parse(script, dialect=dialect)
|
||||
except sqlglot.errors.ParseError as ex:
|
||||
error = ex.errors[0]
|
||||
raise SupersetParseError(
|
||||
@@ -497,6 +518,20 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
message="Unable to parse script",
|
||||
) from ex
|
||||
|
||||
# `sqlglot` will parse comments after the last semicolon as a separate
|
||||
# statement; move them back to the last token in the last real statement
|
||||
if len(statements) > 1 and isinstance(statements[-1], exp.Semicolon):
|
||||
last_statement = statements.pop()
|
||||
target = statements[-1]
|
||||
for node in statements[-1].walk():
|
||||
if hasattr(node, "comments"):
|
||||
target = node
|
||||
|
||||
target.comments = target.comments or []
|
||||
target.comments.extend(last_statement.comments)
|
||||
|
||||
return statements
|
||||
|
||||
@classmethod
|
||||
def split_script(
|
||||
cls,
|
||||
@@ -572,6 +607,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
return extract_tables_from_statement(parsed, dialect)
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
return isinstance(self._parsed, exp.Select)
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
@@ -733,6 +774,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
engine=self.engine,
|
||||
)
|
||||
|
||||
def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
|
||||
"""
|
||||
Rewrite the statement as a `CREATE TABLE AS` statement.
|
||||
|
||||
:param table: The table to create.
|
||||
:param method: The method to use for creating the table.
|
||||
:return: A new SQLStatement with the create table statement.
|
||||
"""
|
||||
create_table = exp.Create(
|
||||
this=sqlglot.parse_one(str(table), into=exp.Table),
|
||||
kind=method.name,
|
||||
expression=self._parsed.copy(),
|
||||
)
|
||||
|
||||
return SQLStatement(ast=create_table, engine=self.engine)
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
@@ -988,6 +1045,12 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
|
||||
return {}
|
||||
|
||||
def is_select(self) -> bool:
|
||||
"""
|
||||
Check if the statement is a `SELECT` statement.
|
||||
"""
|
||||
return not self._parsed.startswith(".")
|
||||
|
||||
def is_mutating(self) -> bool:
|
||||
"""
|
||||
Check if the statement mutates data (DDL/DML).
|
||||
@@ -1142,6 +1205,24 @@ class SQLScript:
|
||||
for statement in self.statements
|
||||
)
|
||||
|
||||
def is_valid_ctas(self) -> bool:
|
||||
"""
|
||||
Check if the script contains a valid CTAS statement.
|
||||
|
||||
CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
|
||||
statement is a `SELECT`.
|
||||
"""
|
||||
return self.statements[-1].is_select()
|
||||
|
||||
def is_valid_cvas(self) -> bool:
|
||||
"""
|
||||
Check if the script contains a valid CVAS statement.
|
||||
|
||||
CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
|
||||
`SELECT` statement.
|
||||
"""
|
||||
return len(self.statements) == 1 and self.statements[0].is_select()
|
||||
|
||||
|
||||
def extract_tables_from_statement(
|
||||
statement: exp.Expression,
|
||||
|
||||
Reference in New Issue
Block a user