chore: 100% test coverage for SQL parsing (#33568)

This commit is contained in:
Beto Dealmeida
2025-06-04 22:18:09 -04:00
committed by GitHub
parent c9518485ba
commit edc60914f6
8 changed files with 853 additions and 113 deletions

View File

@@ -551,7 +551,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
last_statement = statements.pop()
target = statements[-1]
for node in statements[-1].walk():
if hasattr(node, "comments"):
if hasattr(node, "comments"): # pragma: no cover
target = node
target.comments = target.comments or []
@@ -565,47 +565,9 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
if dialect := SQLGLOT_DIALECTS.get(engine):
try:
return [
cls(ast.sql(), engine, ast)
for ast in cls._parse(script, engine)
if ast
]
except ValueError:
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
# FROM`). In this case, we rely on the tokenizer to generate the
# statements.
pass
# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
# generate the SQL of each statement, so we tokenize the script and split it
# based on the location of semi-colons.
statements = []
start = 0
remainder = script
try:
tokens = sqlglot.tokenize(script, dialect)
except sqlglot.errors.TokenError as ex:
raise SupersetParseError(
script,
engine,
message="Unable to tokenize script",
) from ex
for token in tokens:
if token.token_type == sqlglot.TokenType.SEMICOLON:
statement, start = script[start : token.start], token.end + 1
ast = cls._parse(statement, engine)[0]
statements.append(cls(statement.strip(), engine, ast))
remainder = script[start:]
if remainder.strip():
ast = cls._parse(remainder, engine)[0]
statements.append(cls(remainder.strip(), engine, ast))
return statements
return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
]
@classmethod
def _parse_statement(
@@ -618,7 +580,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
statements = cls.split_script(statement, engine)
if len(statements) != 1:
raise SupersetParseError("SQLStatement should have exactly one statement")
raise SupersetParseError(
statement,
engine,
message="SQLStatement should have exactly one statement",
)
return statements[0]._parsed # pylint: disable=protected-access
@@ -657,10 +623,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
exp.Create,
exp.Drop,
exp.TruncateTable,
exp.Alter,
),
):
return True
# depending on the dialect (Oracle, MS SQL) the `ALTER` is parsed as a
# command, not an expression
if isinstance(node, exp.Command) and node.name == "ALTER":
return True
@@ -821,9 +790,16 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Check if the statement has a subquery.
:return: True if the statement has a subquery at the top level.
:return: True if the statement has a subquery.
"""
return bool(self._parsed.find(exp.Subquery))
return bool(self._parsed.find(exp.Subquery)) or (
isinstance(self._parsed, exp.Select)
and any(
isinstance(expression, exp.Select)
for expression in self._parsed.walk()
if expression != self._parsed
)
)
def parse_predicate(self, predicate: str) -> exp.Expression:
"""
@@ -933,11 +909,8 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
)
buffer = ch
elif ch == "`" and script[i - 2 : i] == "``":
if buffer:
tokens.extend(classify_non_string_kql(buffer))
buffer = ""
state = KQLSplitState.INSIDE_MULTILINE_STRING
buffer = "`"
buffer = "```"
else:
buffer += ch
else:
@@ -1042,11 +1015,19 @@ class KustoKQLStatement(BaseSQLStatement[str]):
engine: str,
) -> str:
if engine != "kustokql":
raise SupersetParseError(f"Invalid engine: {engine}")
raise SupersetParseError(
statement,
engine,
message=f"Invalid engine: {engine}",
)
statements = split_kql(statement)
if len(statements) != 1:
raise SupersetParseError("SQLStatement should have exactly one statement")
raise SupersetParseError(
statement,
engine,
message="KustoKQLStatement should have exactly one statement",
)
return statements[0].strip()
@@ -1122,7 +1103,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
:return: True if any of the functions are present
"""
logger.warning("Kusto KQL doesn't support checking for functions present.")
return True
return False
def get_limit_value(self) -> int | None:
"""
@@ -1150,7 +1131,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
Add a limit to the statement.
"""
if method != LimitMethod.FORCE_LIMIT:
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
raise SupersetParseError(
self._parsed,
self.engine,
message="Kusto KQL only supports the FORCE_LIMIT method.",
)
tokens = tokenize_kql(self._parsed)
found_limit_token = False