feat: implement CVAS/CTAS in sqlglot (#33525)

This commit is contained in:
Beto Dealmeida
2025-05-28 09:45:59 -04:00
committed by GitHub
parent 0abe6eed89
commit ea5a609d0b
2 changed files with 220 additions and 1 deletions

View File

@@ -108,6 +108,11 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class CTASMethod(enum.Enum):
TABLE = enum.auto()
VIEW = enum.auto()
class RLSMethod(enum.Enum):
"""
Methods for enforcing RLS.
@@ -381,6 +386,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
raise NotImplementedError()
def is_mutating(self) -> bool:
"""
Check if the statement mutates data (DDL/DML).
@@ -437,6 +448,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
"""
Rewrite the statement as a `CREATE TABLE AS` statement.
:param table: The table to create.
:param method: The method to use for creating the table.
:return: A new SQLStatement with the CTE.
"""
raise NotImplementedError()
def apply_rls(
self,
catalog: str | None,
@@ -480,7 +501,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
dialect = SQLGLOT_DIALECTS.get(engine)
try:
return sqlglot.parse(script, dialect=dialect)
statements = sqlglot.parse(script, dialect=dialect)
except sqlglot.errors.ParseError as ex:
error = ex.errors[0]
raise SupersetParseError(
@@ -497,6 +518,20 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
message="Unable to parse script",
) from ex
# `sqlglot` will parse comments after the last semicolon as a separate
# statement; move them back to the last token in the last real statement
if len(statements) > 1 and isinstance(statements[-1], exp.Semicolon):
last_statement = statements.pop()
target = statements[-1]
for node in statements[-1].walk():
if hasattr(node, "comments"):
target = node
target.comments = target.comments or []
target.comments.extend(last_statement.comments)
return statements
@classmethod
def split_script(
cls,
@@ -572,6 +607,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
dialect = SQLGLOT_DIALECTS.get(engine)
return extract_tables_from_statement(parsed, dialect)
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
return isinstance(self._parsed, exp.Select)
def is_mutating(self) -> bool:
"""
Check if the statement mutates data (DDL/DML).
@@ -733,6 +774,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
engine=self.engine,
)
def as_create_table(self, table: Table, method: CTASMethod) -> SQLStatement:
"""
Rewrite the statement as a `CREATE TABLE AS` statement.
:param table: The table to create.
:param method: The method to use for creating the table.
:return: A new SQLStatement with the create table statement.
"""
create_table = exp.Create(
this=sqlglot.parse_one(str(table), into=exp.Table),
kind=method.name,
expression=self._parsed.copy(),
)
return SQLStatement(ast=create_table, engine=self.engine)
def apply_rls(
self,
catalog: str | None,
@@ -988,6 +1045,12 @@ class KustoKQLStatement(BaseSQLStatement[str]):
return {}
def is_select(self) -> bool:
"""
Check if the statement is a `SELECT` statement.
"""
return not self._parsed.startswith(".")
def is_mutating(self) -> bool:
"""
Check if the statement mutates data (DDL/DML).
@@ -1142,6 +1205,24 @@ class SQLScript:
for statement in self.statements
)
def is_valid_ctas(self) -> bool:
"""
Check if the script contains a valid CTAS statement.
CTAS (`CREATE TABLE AS SELECT`) can only be run with scripts where the last
statement is a `SELECT`.
"""
return self.statements[-1].is_select()
def is_valid_cvas(self) -> bool:
"""
Check if the script contains a valid CVAS statement.
CVAS (`CREATE VIEW AS SELECT`) can only be run with scripts with a single
`SELECT` statement.
"""
return len(self.statements) == 1 and self.statements[0].is_select()
def extract_tables_from_statement(
statement: exp.Expression,