mirror of
https://github.com/apache/superset.git
synced 2026-04-12 04:37:49 +00:00
feat: use sqlglot to set limit (#33473)
This commit is contained in:
@@ -27,12 +27,9 @@ from dataclasses import dataclass
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from deprecation import deprecated
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.expressions import Func, Limit
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
|
||||
@@ -99,6 +96,18 @@ SQLGLOT_DIALECTS = {
|
||||
}
|
||||
|
||||
|
||||
class LimitMethod(enum.Enum):
|
||||
"""
|
||||
Limit methods.
|
||||
|
||||
This is used to determine how to add a limit to a SQL statement.
|
||||
"""
|
||||
|
||||
FORCE_LIMIT = enum.auto()
|
||||
WRAP_SQL = enum.auto()
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
"""
|
||||
@@ -252,6 +261,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def set_limit_value(
|
||||
self,
|
||||
limit: int,
|
||||
method: LimitMethod = LimitMethod.FORCE_LIMIT,
|
||||
) -> None:
|
||||
"""
|
||||
Add a limit to the statement.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
@@ -412,34 +431,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
"""
|
||||
if self._dialect:
|
||||
try:
|
||||
write = Dialect.get_or_raise(self._dialect)
|
||||
return write.generate(
|
||||
self._parsed,
|
||||
copy=False,
|
||||
comments=comments,
|
||||
pretty=True,
|
||||
)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return self._fallback_formatting()
|
||||
|
||||
@deprecated(deprecated_in="4.0")
|
||||
def _fallback_formatting(self) -> str:
|
||||
"""
|
||||
Format SQL without a specific dialect.
|
||||
|
||||
Reformatting SQL using the generic sqlglot dialect is known to break queries.
|
||||
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
|
||||
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
|
||||
when the dialect is not known.
|
||||
|
||||
In 5.0 we should remove `sqlparse`, and the method should return the query
|
||||
unmodified.
|
||||
"""
|
||||
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
|
||||
return Dialect.get_or_raise(self._dialect).generate(
|
||||
self._parsed,
|
||||
copy=True,
|
||||
comments=comments,
|
||||
pretty=True,
|
||||
)
|
||||
|
||||
def get_settings(self) -> dict[str, str | bool]:
|
||||
"""
|
||||
@@ -482,7 +479,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
if function.sql_name() != "ANONYMOUS"
|
||||
else function.name.upper()
|
||||
)
|
||||
for function in self._parsed.find_all(Func)
|
||||
for function in self._parsed.find_all(exp.Func)
|
||||
}
|
||||
return any(function.upper() in present for function in functions)
|
||||
|
||||
@@ -490,20 +487,38 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Parse a SQL query and return the `LIMIT` or `TOP` value, if present.
|
||||
"""
|
||||
limit_node = (
|
||||
self._parsed
|
||||
if isinstance(self._parsed, Limit)
|
||||
else self._parsed.args.get("limit")
|
||||
)
|
||||
if not isinstance(limit_node, exp.Limit):
|
||||
return None
|
||||
|
||||
literal = limit_node.args.get("expression") or getattr(limit_node, "this", None)
|
||||
if isinstance(literal, exp.Literal) and literal.is_int:
|
||||
return int(literal.name)
|
||||
if limit_node := self._parsed.args.get("limit"):
|
||||
literal = limit_node.args.get("expression") or getattr(
|
||||
limit_node, "this", None
|
||||
)
|
||||
if isinstance(literal, exp.Literal) and literal.is_int:
|
||||
return int(literal.name)
|
||||
|
||||
return None
|
||||
|
||||
def set_limit_value(
|
||||
self,
|
||||
limit: int,
|
||||
method: LimitMethod = LimitMethod.FORCE_LIMIT,
|
||||
) -> None:
|
||||
"""
|
||||
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
|
||||
"""
|
||||
if method == LimitMethod.FORCE_LIMIT:
|
||||
self._parsed.args["limit"] = exp.Limit(
|
||||
expression=exp.Literal(this=str(limit), is_string=False)
|
||||
)
|
||||
elif method == LimitMethod.WRAP_SQL:
|
||||
self._parsed = exp.Select(
|
||||
expressions=[exp.Star()],
|
||||
limit=exp.Limit(
|
||||
expression=exp.Literal(this=str(limit), is_string=False)
|
||||
),
|
||||
**{"from": exp.From(this=exp.Subquery(this=self._parsed.copy()))},
|
||||
)
|
||||
else: # method == LimitMethod.FETCH_MANY
|
||||
pass
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -561,7 +576,7 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
|
||||
state = KQLSplitState.OUTSIDE_STRING
|
||||
tokens: list[tuple[KQLTokenType, str]] = []
|
||||
buffer = ""
|
||||
script = kql if kql.endswith(";") else kql + ";"
|
||||
script = kql
|
||||
|
||||
for i, ch in enumerate(script):
|
||||
if state == KQLSplitState.OUTSIDE_STRING:
|
||||
@@ -630,6 +645,9 @@ def split_kql(kql: str) -> list[str]:
|
||||
else:
|
||||
current.append((ttype, val))
|
||||
|
||||
if current:
|
||||
stmts_tokens.append(current)
|
||||
|
||||
return ["".join(val for _, val in stmt) for stmt in stmts_tokens]
|
||||
|
||||
|
||||
@@ -767,6 +785,40 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
|
||||
return None
|
||||
|
||||
def set_limit_value(
|
||||
self,
|
||||
limit: int,
|
||||
method: LimitMethod = LimitMethod.FORCE_LIMIT,
|
||||
) -> None:
|
||||
"""
|
||||
Add a limit to the statement.
|
||||
"""
|
||||
if method != LimitMethod.FORCE_LIMIT:
|
||||
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
|
||||
|
||||
tokens = tokenize_kql(self._sql)
|
||||
found_limit_token = False
|
||||
for idx, (ttype, val) in enumerate(tokens):
|
||||
if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}:
|
||||
found_limit_token = True
|
||||
|
||||
if found_limit_token and ttype == KQLTokenType.NUMBER:
|
||||
tokens[idx] = (KQLTokenType.NUMBER, str(limit))
|
||||
break
|
||||
else:
|
||||
tokens.extend(
|
||||
[
|
||||
(KQLTokenType.WHITESPACE, " "),
|
||||
(KQLTokenType.WORD, "|"),
|
||||
(KQLTokenType.WHITESPACE, " "),
|
||||
(KQLTokenType.WORD, "take"),
|
||||
(KQLTokenType.WHITESPACE, " "),
|
||||
(KQLTokenType.NUMBER, str(limit)),
|
||||
]
|
||||
)
|
||||
|
||||
self._parsed = self._sql = "".join(val for _, val in tokens)
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user