mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix: sqlparse fallback for formatting queries (#30578)
This commit is contained in:
@@ -26,6 +26,8 @@ from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from deprecation import deprecated
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
@@ -138,9 +140,9 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
Base class for SQL statements.
|
||||
|
||||
The class can be instantiated with a string representation of the script or, for
|
||||
efficiency reasons, with a pre-parsed AST. This is useful with `sqlglot.parse`,
|
||||
which will split a script in multiple already parsed statements.
|
||||
The class should be instantiated with a string representation of the script and, for
|
||||
efficiency reasons, optionally with a pre-parsed AST. This is useful with
|
||||
`sqlglot.parse`, which will split a script in multiple already parsed statements.
|
||||
|
||||
The `engine` parameters comes from the `engine` attribute in a Superset DB engine
|
||||
spec.
|
||||
@@ -148,14 +150,12 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str | InternalRepresentation,
|
||||
statement: str,
|
||||
engine: str,
|
||||
ast: InternalRepresentation | None = None,
|
||||
):
|
||||
self._parsed: InternalRepresentation = (
|
||||
self._parse_statement(statement, engine)
|
||||
if isinstance(statement, str)
|
||||
else statement
|
||||
)
|
||||
self._sql = statement
|
||||
self._parsed = ast or self._parse_statement(statement, engine)
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
|
||||
@@ -239,11 +239,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
statement: str | exp.Expression,
|
||||
statement: str,
|
||||
engine: str,
|
||||
ast: exp.Expression | None = None,
|
||||
):
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
super().__init__(statement, engine)
|
||||
super().__init__(statement, engine, ast)
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
|
||||
@@ -275,11 +276,47 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
script: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
return [
|
||||
cls(statement, engine)
|
||||
for statement in cls._parse(script, engine)
|
||||
if statement
|
||||
]
|
||||
if engine in SQLGLOT_DIALECTS:
|
||||
try:
|
||||
return [
|
||||
cls(ast.sql(), engine, ast)
|
||||
for ast in cls._parse(script, engine)
|
||||
if ast
|
||||
]
|
||||
except ValueError:
|
||||
# `ast.sql()` might raise an error on some cases (eg, `SHOW TABLES
|
||||
# FROM`). In this case, we rely on the tokenizer to generate the
|
||||
# statements.
|
||||
pass
|
||||
|
||||
# When we don't have a sqlglot dialect we can't rely on `ast.sql()` to correctly
|
||||
# generate the SQL of each statement, so we tokenize the script and split it
|
||||
# based on the location of semi-colons.
|
||||
statements = []
|
||||
start = 0
|
||||
remainder = script
|
||||
|
||||
try:
|
||||
tokens = sqlglot.tokenize(script)
|
||||
except sqlglot.errors.TokenError as ex:
|
||||
raise SupersetParseError(
|
||||
script,
|
||||
engine,
|
||||
message="Unable to tokenize script",
|
||||
) from ex
|
||||
|
||||
for token in tokens:
|
||||
if token.token_type == sqlglot.TokenType.SEMICOLON:
|
||||
statement, start = script[start : token.start], token.end + 1
|
||||
ast = cls._parse(statement, engine)[0]
|
||||
statements.append(cls(statement.strip(), engine, ast))
|
||||
remainder = script[start:]
|
||||
|
||||
if remainder.strip():
|
||||
ast = cls._parse(remainder, engine)[0]
|
||||
statements.append(cls(remainder.strip(), engine, ast))
|
||||
|
||||
return statements
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
@@ -349,8 +386,34 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(self._parsed, copy=False, comments=comments, pretty=True)
|
||||
if self._dialect:
|
||||
try:
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(
|
||||
self._parsed,
|
||||
copy=False,
|
||||
comments=comments,
|
||||
pretty=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return self._fallback_formatting()
|
||||
|
||||
@deprecated(deprecated_in="4.0", removed_in="5.0")
|
||||
def _fallback_formatting(self) -> str:
|
||||
"""
|
||||
Format SQL without a specific dialect.
|
||||
|
||||
Reformatting SQL using the generic sqlglot dialect is known to break queries.
|
||||
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
|
||||
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
|
||||
when the dialect is not known.
|
||||
|
||||
In 5.0 we should remove `sqlparse`, and the method should return the query
|
||||
unmodified.
|
||||
"""
|
||||
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -456,7 +519,9 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
https://learn.microsoft.com/en-us/azure/data-explorer/kusto/query/scalar-data-types/string
|
||||
for more information.
|
||||
"""
|
||||
return [cls(statement, engine) for statement in split_kql(script)]
|
||||
return [
|
||||
cls(statement, engine, statement.strip()) for statement in split_kql(script)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
@@ -498,7 +563,7 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
return self._parsed
|
||||
return self._sql.strip()
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -548,6 +613,9 @@ class SQLScript:
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL script.
|
||||
|
||||
Note that even though KQL is very different from SQL, multiple statements are
|
||||
still separated by semi-colons.
|
||||
"""
|
||||
return ";\n".join(statement.format(comments) for statement in self.statements)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user