mirror of
https://github.com/apache/superset.git
synced 2026-04-20 08:34:37 +00:00
chore: 100% test coverage for SQL parsing (#33568)
This commit is contained in:
@@ -48,7 +48,7 @@ class Firebolt(Dialect):
|
||||
self,
|
||||
this: exp.Expression | None = None,
|
||||
) -> exp.Expression | None:
|
||||
if not this:
|
||||
if not this: # pragma: no cover
|
||||
return this
|
||||
|
||||
return self.expression(exp.Not, this=self.expression(exp.Paren, this=this))
|
||||
@@ -109,42 +109,15 @@ class FireboltOld(Firebolt):
|
||||
expressions = self._parse_wrapped_csv(self._parse_expression)
|
||||
offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)
|
||||
|
||||
alias = self._parse_table_alias() if with_alias else None
|
||||
|
||||
if alias:
|
||||
if self.dialect.UNNEST_COLUMN_ONLY:
|
||||
if alias.args.get("columns"):
|
||||
self.raise_error("Unexpected extra column alias in unnest.")
|
||||
|
||||
alias.set("columns", [alias.this])
|
||||
alias.set("this", None)
|
||||
|
||||
columns = alias.args.get("columns") or []
|
||||
if offset and len(expressions) < len(columns):
|
||||
offset = columns.pop()
|
||||
|
||||
if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
|
||||
self._match(TokenType.ALIAS)
|
||||
offset = self._parse_id_var(
|
||||
any_token=False, tokens=self.UNNEST_OFFSET_ALIAS_TOKENS
|
||||
) or exp.to_identifier("offset")
|
||||
|
||||
return self.expression(
|
||||
exp.Unnest,
|
||||
expressions=expressions,
|
||||
alias=alias,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
class Generator(Firebolt.Generator):
|
||||
def join_sql(self, expression: exp.Join) -> str:
|
||||
if not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in (
|
||||
"SEMI",
|
||||
"ANTI",
|
||||
):
|
||||
side = None
|
||||
else:
|
||||
side = expression.side
|
||||
side = expression.side
|
||||
|
||||
op_sql = " ".join(
|
||||
op
|
||||
@@ -168,9 +141,6 @@ class FireboltOld(Firebolt):
|
||||
this = expression.this
|
||||
this_sql = self.sql(this)
|
||||
|
||||
if exprs := self.expressions(expression):
|
||||
this_sql = f"{this_sql},{self.seg(exprs)}"
|
||||
|
||||
if on_sql:
|
||||
on_sql = self.indent(on_sql, skip_first=True)
|
||||
space = self.seg(" " * self.pad) if self.pretty else " "
|
||||
@@ -189,7 +159,6 @@ class FireboltOld(Firebolt):
|
||||
|
||||
return f", {this_sql}"
|
||||
|
||||
if op_sql != "STRAIGHT_JOIN":
|
||||
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
|
||||
op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"
|
||||
|
||||
return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
|
||||
|
||||
@@ -551,7 +551,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
last_statement = statements.pop()
|
||||
target = statements[-1]
|
||||
for node in statements[-1].walk():
|
||||
if hasattr(node, "comments"):
|
||||
if hasattr(node, "comments"): # pragma: no cover
|
||||
target = node
|
||||
|
||||
target.comments = target.comments or []
|
||||
@@ -565,47 +565,9 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
script: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
if dialect := SQLGLOT_DIALECTS.get(engine):
|
||||
try:
|
||||
return [
|
||||
cls(ast.sql(), engine, ast)
|
||||
for ast in cls._parse(script, engine)
|
||||
if ast
|
||||
]
|
||||
except ValueError:
|
||||
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
|
||||
# FROM`). In this case, we rely on the tokenizer to generate the
|
||||
# statements.
|
||||
pass
|
||||
|
||||
# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
|
||||
# generate the SQL of each statement, so we tokenize the script and split it
|
||||
# based on the location of semi-colons.
|
||||
statements = []
|
||||
start = 0
|
||||
remainder = script
|
||||
|
||||
try:
|
||||
tokens = sqlglot.tokenize(script, dialect)
|
||||
except sqlglot.errors.TokenError as ex:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
engine,
|
||||
message="Unable to tokenize script",
|
||||
) from ex
|
||||
|
||||
for token in tokens:
|
||||
if token.token_type == sqlglot.TokenType.SEMICOLON:
|
||||
statement, start = script[start : token.start], token.end + 1
|
||||
ast = cls._parse(statement, engine)[0]
|
||||
statements.append(cls(statement.strip(), engine, ast))
|
||||
remainder = script[start:]
|
||||
|
||||
if remainder.strip():
|
||||
ast = cls._parse(remainder, engine)[0]
|
||||
statements.append(cls(remainder.strip(), engine, ast))
|
||||
|
||||
return statements
|
||||
return [
|
||||
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
@@ -618,7 +580,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
statements = cls.split_script(statement, engine)
|
||||
if len(statements) != 1:
|
||||
raise SupersetParseError("SQLStatement should have exactly one statement")
|
||||
raise SupersetParseError(
|
||||
statement,
|
||||
engine,
|
||||
message="SQLStatement should have exactly one statement",
|
||||
)
|
||||
|
||||
return statements[0]._parsed # pylint: disable=protected-access
|
||||
|
||||
@@ -657,10 +623,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
exp.Create,
|
||||
exp.Drop,
|
||||
exp.TruncateTable,
|
||||
exp.Alter,
|
||||
),
|
||||
):
|
||||
return True
|
||||
|
||||
# depending on the dialect (Oracle, MS SQL) the `ALTER` is parsed as a
|
||||
# command, not an expression
|
||||
if isinstance(node, exp.Command) and node.name == "ALTER":
|
||||
return True
|
||||
|
||||
@@ -821,9 +790,16 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Check if the statement has a subquery.
|
||||
|
||||
:return: True if the statement has a subquery at the top level.
|
||||
:return: True if the statement has a subquery.
|
||||
"""
|
||||
return bool(self._parsed.find(exp.Subquery))
|
||||
return bool(self._parsed.find(exp.Subquery)) or (
|
||||
isinstance(self._parsed, exp.Select)
|
||||
and any(
|
||||
isinstance(expression, exp.Select)
|
||||
for expression in self._parsed.walk()
|
||||
if expression != self._parsed
|
||||
)
|
||||
)
|
||||
|
||||
def parse_predicate(self, predicate: str) -> exp.Expression:
|
||||
"""
|
||||
@@ -933,11 +909,8 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
|
||||
)
|
||||
buffer = ch
|
||||
elif ch == "`" and script[i - 2 : i] == "``":
|
||||
if buffer:
|
||||
tokens.extend(classify_non_string_kql(buffer))
|
||||
buffer = ""
|
||||
state = KQLSplitState.INSIDE_MULTILINE_STRING
|
||||
buffer = "`"
|
||||
buffer = "```"
|
||||
else:
|
||||
buffer += ch
|
||||
else:
|
||||
@@ -1042,11 +1015,19 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
engine: str,
|
||||
) -> str:
|
||||
if engine != "kustokql":
|
||||
raise SupersetParseError(f"Invalid engine: {engine}")
|
||||
raise SupersetParseError(
|
||||
statement,
|
||||
engine,
|
||||
message=f"Invalid engine: {engine}",
|
||||
)
|
||||
|
||||
statements = split_kql(statement)
|
||||
if len(statements) != 1:
|
||||
raise SupersetParseError("SQLStatement should have exactly one statement")
|
||||
raise SupersetParseError(
|
||||
statement,
|
||||
engine,
|
||||
message="KustoKQLStatement should have exactly one statement",
|
||||
)
|
||||
|
||||
return statements[0].strip()
|
||||
|
||||
@@ -1122,7 +1103,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
:return: True if any of the functions are present
|
||||
"""
|
||||
logger.warning("Kusto KQL doesn't support checking for functions present.")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_limit_value(self) -> int | None:
|
||||
"""
|
||||
@@ -1150,7 +1131,11 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
Add a limit to the statement.
|
||||
"""
|
||||
if method != LimitMethod.FORCE_LIMIT:
|
||||
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
|
||||
raise SupersetParseError(
|
||||
self._parsed,
|
||||
self.engine,
|
||||
message="Kusto KQL only supports the FORCE_LIMIT method.",
|
||||
)
|
||||
|
||||
tokens = tokenize_kql(self._parsed)
|
||||
found_limit_token = False
|
||||
|
||||
Reference in New Issue
Block a user