chore: sql/parse cleanup (#33515)

This commit is contained in:
Beto Dealmeida
2025-05-27 16:42:04 -04:00
committed by GitHub
parent b7ba50033a
commit 1393f7d3d2
4 changed files with 38 additions and 20 deletions

View File

@@ -164,12 +164,17 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
def __init__(
self,
statement: str,
engine: str,
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
):
self._sql = statement
self._parsed = ast or self._parse_statement(statement, engine)
if ast:
self._parsed = ast
elif statement:
self._parsed = self._parse_statement(statement, engine)
else:
raise SupersetParseError("Either statement or ast must be provided")
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@@ -284,8 +289,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def __init__(
self,
statement: str,
engine: str,
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
@@ -423,7 +428,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
and self._parsed.expression.name.upper().startswith("ANALYZE ")
):
analyzed_sql = self._parsed.expression.name[len("ANALYZE ") :]
return SQLStatement(analyzed_sql, self.engine).is_mutating()
return SQLStatement(
statement=analyzed_sql,
engine=self.engine,
).is_mutating()
return False
@@ -459,12 +467,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
# only optimize statements that have a custom dialect
if not self._dialect:
return SQLStatement(self._sql, self.engine, self._parsed.copy())
return SQLStatement(ast=self._parsed.copy(), engine=self.engine)
optimized = pushdown_predicates(self._parsed, dialect=self._dialect)
sql = optimized.sql(dialect=self._dialect)
return SQLStatement(sql, self.engine, optimized)
return SQLStatement(ast=optimized, engine=self.engine)
def check_functions_present(self, functions: set[str]) -> bool:
"""
@@ -668,6 +675,14 @@ class KustoKQLStatement(BaseSQLStatement[str]):
details about it.
"""
def __init__(
self,
statement: str | None = None,
engine: str = "kustokql",
ast: str | None = None,
):
super().__init__(statement, engine, ast)
@classmethod
def split_script(
cls,
@@ -725,7 +740,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
Pretty-format the SQL statement.
"""
return self._sql.strip()
return self._parsed.strip()
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -756,7 +771,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
Kusto KQL doesn't support optimization, so this method is a no-op.
"""
return KustoKQLStatement(self._sql, self.engine, self._parsed)
return KustoKQLStatement(ast=self._parsed, engine=self.engine)
def check_functions_present(self, functions: set[str]) -> bool:
"""
@@ -774,7 +789,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
"""
tokens = [
token
for token in tokenize_kql(self._sql)
for token in tokenize_kql(self._parsed)
if token[0] != KQLTokenType.WHITESPACE
]
for idx, (ttype, val) in enumerate(tokens):
@@ -796,7 +811,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
if method != LimitMethod.FORCE_LIMIT:
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
tokens = tokenize_kql(self._sql)
tokens = tokenize_kql(self._parsed)
found_limit_token = False
for idx, (ttype, val) in enumerate(tokens):
if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}:
@@ -817,7 +832,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
]
)
self._parsed = self._sql = "".join(val for _, val in tokens)
self._parsed = "".join(val for _, val in tokens)
class SQLScript: