feat: use sqlglot to set limit (#33473)

This commit is contained in:
Beto Dealmeida
2025-05-27 15:20:02 -04:00
committed by GitHub
parent cc8ab2c556
commit 8de58b9848
34 changed files with 573 additions and 557 deletions

View File

@@ -27,12 +27,9 @@ from dataclasses import dataclass
from typing import Any, Generic, TypeVar
import sqlglot
import sqlparse
from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
from sqlglot.expressions import Func, Limit
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -99,6 +96,18 @@ SQLGLOT_DIALECTS = {
}
class LimitMethod(enum.Enum):
"""
Limit methods.
This is used to determine how to add a limit to a SQL statement.
"""
FORCE_LIMIT = enum.auto()
WRAP_SQL = enum.auto()
FETCH_MANY = enum.auto()
@dataclass(eq=True, frozen=True)
class Table:
"""
@@ -252,6 +261,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def set_limit_value(
self,
limit: int,
method: LimitMethod = LimitMethod.FORCE_LIMIT,
) -> None:
"""
Add a limit to the statement.
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
@@ -412,34 +431,12 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Pretty-format the SQL statement.
"""
if self._dialect:
try:
write = Dialect.get_or_raise(self._dialect)
return write.generate(
self._parsed,
copy=False,
comments=comments,
pretty=True,
)
except ValueError:
pass
return self._fallback_formatting()
@deprecated(deprecated_in="4.0")
def _fallback_formatting(self) -> str:
"""
Format SQL without a specific dialect.
Reformatting SQL using the generic sqlglot dialect is known to break queries.
For example, it will change `foo NOT IN (1, 2)` to `NOT foo IN (1,2)`, which
breaks the query for Firebolt. To avoid this, we use sqlparse for formatting
when the dialect is not known.
In 5.0 we should remove `sqlparse`, and the method should return the query
unmodified.
"""
return sqlparse.format(self._sql, reindent=True, keyword_case="upper")
return Dialect.get_or_raise(self._dialect).generate(
self._parsed,
copy=True,
comments=comments,
pretty=True,
)
def get_settings(self) -> dict[str, str | bool]:
"""
@@ -482,7 +479,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
if function.sql_name() != "ANONYMOUS"
else function.name.upper()
)
for function in self._parsed.find_all(Func)
for function in self._parsed.find_all(exp.Func)
}
return any(function.upper() in present for function in functions)
@@ -490,20 +487,38 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Parse a SQL query and return the `LIMIT` or `TOP` value, if present.
"""
limit_node = (
self._parsed
if isinstance(self._parsed, Limit)
else self._parsed.args.get("limit")
)
if not isinstance(limit_node, exp.Limit):
return None
literal = limit_node.args.get("expression") or getattr(limit_node, "this", None)
if isinstance(literal, exp.Literal) and literal.is_int:
return int(literal.name)
if limit_node := self._parsed.args.get("limit"):
literal = limit_node.args.get("expression") or getattr(
limit_node, "this", None
)
if isinstance(literal, exp.Literal) and literal.is_int:
return int(literal.name)
return None
def set_limit_value(
self,
limit: int,
method: LimitMethod = LimitMethod.FORCE_LIMIT,
) -> None:
"""
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
"""
if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
)
elif method == LimitMethod.WRAP_SQL:
self._parsed = exp.Select(
expressions=[exp.Star()],
limit=exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
),
**{"from": exp.From(this=exp.Subquery(this=self._parsed.copy()))},
)
else: # method == LimitMethod.FETCH_MANY
pass
class KQLSplitState(enum.Enum):
"""
@@ -561,7 +576,7 @@ def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
state = KQLSplitState.OUTSIDE_STRING
tokens: list[tuple[KQLTokenType, str]] = []
buffer = ""
script = kql if kql.endswith(";") else kql + ";"
script = kql
for i, ch in enumerate(script):
if state == KQLSplitState.OUTSIDE_STRING:
@@ -630,6 +645,9 @@ def split_kql(kql: str) -> list[str]:
else:
current.append((ttype, val))
if current:
stmts_tokens.append(current)
return ["".join(val for _, val in stmt) for stmt in stmts_tokens]
@@ -767,6 +785,40 @@ class KustoKQLStatement(BaseSQLStatement[str]):
return None
def set_limit_value(
self,
limit: int,
method: LimitMethod = LimitMethod.FORCE_LIMIT,
) -> None:
"""
Add a limit to the statement.
"""
if method != LimitMethod.FORCE_LIMIT:
raise SupersetParseError("Kusto KQL only supports the FORCE_LIMIT method.")
tokens = tokenize_kql(self._sql)
found_limit_token = False
for idx, (ttype, val) in enumerate(tokens):
if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}:
found_limit_token = True
if found_limit_token and ttype == KQLTokenType.NUMBER:
tokens[idx] = (KQLTokenType.NUMBER, str(limit))
break
else:
tokens.extend(
[
(KQLTokenType.WHITESPACE, " "),
(KQLTokenType.WORD, "|"),
(KQLTokenType.WHITESPACE, " "),
(KQLTokenType.WORD, "take"),
(KQLTokenType.WHITESPACE, " "),
(KQLTokenType.NUMBER, str(limit)),
]
)
self._parsed = self._sql = "".join(val for _, val in tokens)
class SQLScript:
"""