mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: implement limit extraction in sqlglot (#33456)
This commit is contained in:
@@ -32,7 +32,7 @@ from deprecation import deprecated
|
||||
from sqlglot import exp
|
||||
from sqlglot.dialects.dialect import Dialect, Dialects
|
||||
from sqlglot.errors import ParseError
|
||||
from sqlglot.expressions import Func
|
||||
from sqlglot.expressions import Func, Limit
|
||||
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
|
||||
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
|
||||
|
||||
@@ -237,6 +237,21 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def check_functions_present(self, functions: set[str]) -> bool:
    """
    Report whether at least one of the named functions occurs in the script.

    This base implementation is only a placeholder: it always raises, so
    concrete statement classes have to supply their own version.

    :param functions: Names of the functions to look for
    :return: True if any of the functions are present
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError
|
||||
|
||||
def get_limit_value(self) -> int | None:
|
||||
"""
|
||||
Get the limit value of the statement.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
    """
    Render the statement as text by delegating to ``format()``.
    """
    rendered = self.format()
    return rendered
|
||||
|
||||
@@ -471,6 +486,24 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
}
|
||||
return any(function.upper() in present for function in functions)
|
||||
|
||||
def get_limit_value(self) -> int | None:
    """
    Return the integer `LIMIT` or `TOP` value of the parsed statement.

    :return: The limit as an integer, or None when the statement carries no
        limit node or the limit is not an integer literal
    """
    root = self._parsed

    # The parsed tree may itself be a bare Limit node; otherwise the limit
    # hangs off the statement under the "limit" argument.
    node = root if isinstance(root, Limit) else root.args.get("limit")
    if not isinstance(node, exp.Limit):
        return None

    # Prefer the node's "expression" argument, falling back to "this".
    bound = node.args.get("expression") or getattr(node, "this", None)
    if not isinstance(bound, exp.Literal):
        return None
    if not bound.is_int:
        return None

    return int(bound.name)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -486,48 +519,118 @@ class KQLSplitState(enum.Enum):
|
||||
INSIDE_MULTILINE_STRING = enum.auto()
|
||||
|
||||
|
||||
class KQLTokenType(enum.Enum):
    """
    Token types for KQL.
    """

    STRING = enum.auto()
    WORD = enum.auto()
    NUMBER = enum.auto()
    SEMICOLON = enum.auto()
    WHITESPACE = enum.auto()
    OTHER = enum.auto()


# One alternative per token class, tried left to right. Named groups let the
# classifier read the matched class straight from the match object instead of
# re-testing the token text afterwards.
_KQL_NON_STRING_TOKEN = re.compile(
    r"(?P<word>[A-Za-z_][A-Za-z_0-9]*)"
    r"|(?P<number>\d+)"
    r"|(?P<whitespace>\s+)"
    r"|(?P<other>.)"
)

_KQL_GROUP_TO_TOKEN_TYPE = {
    "word": KQLTokenType.WORD,
    "number": KQLTokenType.NUMBER,
    "whitespace": KQLTokenType.WHITESPACE,
}


def classify_non_string_kql(text: str) -> list[tuple[KQLTokenType, str]]:
    """
    Classify non-string KQL.

    Splits ``text`` into words, digit runs, whitespace runs and single
    leftover characters, and tags each piece with its token type. A lone
    ``;`` is tagged as SEMICOLON; any other leftover character is OTHER.

    Only text matched by the digit-run alternative is tagged NUMBER, so a
    single character that merely satisfies ``str.isdigit()`` (for example
    ``²``) is OTHER and never reaches an ``int()`` conversion.

    :param text: KQL text that lies outside any string literal
    :return: ``(token type, token text)`` pairs in source order
    """
    tokens: list[tuple[KQLTokenType, str]] = []
    for match in _KQL_NON_STRING_TOKEN.finditer(text):
        tok = match.group(0)
        if tok == ";":
            # ";" is matched by the catch-all branch but is a statement
            # separator, so it gets its own type.
            kind = KQLTokenType.SEMICOLON
        else:
            kind = _KQL_GROUP_TO_TOKEN_TYPE.get(
                match.lastgroup,
                KQLTokenType.OTHER,
            )
        tokens.append((kind, tok))

    return tokens
|
||||
|
||||
|
||||
def _kql_quote_is_escaped(script: str, index: int) -> bool:
    """
    Tell whether the quote character at ``index`` is backslash-escaped.

    The quote is escaped only when an odd number of consecutive backslashes
    sits directly before it; in an even run the backslashes escape each
    other and the quote is a real delimiter.
    """
    backslashes = 0
    cursor = index - 1
    while cursor >= 0 and script[cursor] == "\\":
        backslashes += 1
        cursor -= 1
    return backslashes % 2 == 1


def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
    """
    Turn a KQL script into a flat list of tokens.

    Single-quoted, double-quoted and triple-backtick strings are emitted as
    one STRING token each (delimiters included); everything else is handed
    to ``classify_non_string_kql``. A trailing ``;`` is appended to the
    script before scanning when it does not already end with one, so the
    token list normally ends with a SEMICOLON token.

    :param kql: The KQL script to tokenize
    :return: ``(token type, token text)`` pairs in source order
    """

    closing_quote = {
        KQLSplitState.INSIDE_SINGLE_QUOTED_STRING: "'",
        KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING: '"',
    }

    state = KQLSplitState.OUTSIDE_STRING
    tokens: list[tuple[KQLTokenType, str]] = []
    buffer = ""
    script = kql if kql.endswith(";") else kql + ";"

    for i, ch in enumerate(script):
        if state == KQLSplitState.OUTSIDE_STRING:
            if ch in {"'", '"'}:
                # Flush pending plain text, then start the string with its
                # opening quote.
                if buffer:
                    tokens.extend(classify_non_string_kql(buffer))
                state = (
                    KQLSplitState.INSIDE_SINGLE_QUOTED_STRING
                    if ch == "'"
                    else KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING
                )
                buffer = ch
            elif ch == "`" and i >= 2 and script[i - 2 : i] == "``":
                # Third backtick in a row opens a multi-line string. The two
                # earlier backticks were buffered as plain text, so they are
                # flushed as non-string tokens and the string buffer starts
                # with this one.
                if buffer:
                    tokens.extend(classify_non_string_kql(buffer))
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                buffer = "`"
            else:
                buffer += ch
        else:
            buffer += ch
            if state == KQLSplitState.INSIDE_MULTILINE_STRING:
                end_str = ch == "`" and script[i - 2 : i] == "``"
            else:
                # Checking only the previous character would treat the quote
                # in 'C:\\' as escaped and never close the string; the parity
                # of the backslash run decides instead.
                # NOTE(review): verbatim literals (@'...') are not
                # special-cased here - confirm whether they need to be.
                end_str = ch == closing_quote[state] and not _kql_quote_is_escaped(
                    script, i
                )
            if end_str:
                tokens.append((KQLTokenType.STRING, buffer))
                buffer = ""
                state = KQLSplitState.OUTSIDE_STRING

    # Leftover text: trailing plain text, or an unterminated string, which is
    # classified as ordinary non-string text.
    if buffer:
        tokens.extend(classify_non_string_kql(buffer))

    return tokens
|
||||
|
||||
|
||||
def split_kql(kql: str) -> list[str]:
    """
    Split a KQL script into statements on semicolons,
    ignoring those inside strings.

    The script is tokenized with ``tokenize_kql`` and the tokens are grouped
    between SEMICOLON tokens; semicolons inside STRING tokens therefore never
    split a statement. A group with no tokens at all (for example between
    two adjacent semicolons) is dropped; every other group is joined back
    into its original text, whitespace included.

    :param kql: The KQL script to split
    :return: The statements, in source order, without their semicolons
    """
    stmts_tokens: list[list[tuple[KQLTokenType, str]]] = []
    current: list[tuple[KQLTokenType, str]] = []

    for ttype, val in tokenize_kql(kql):
        if ttype == KQLTokenType.SEMICOLON:
            if current:
                stmts_tokens.append(current)
                current = []
        else:
            current.append((ttype, val))

    # Defensive flush: keeps the last statement even if the token stream
    # does not end with a semicolon token.
    if current:
        stmts_tokens.append(current)

    return ["".join(val for _, val in stmt) for stmt in stmts_tokens]
|
||||
|
||||
|
||||
class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
@@ -647,6 +750,23 @@ class KustoKQLStatement(BaseSQLStatement[str]):
|
||||
logger.warning("Kusto KQL doesn't support checking for functions present.")
|
||||
return True
|
||||
|
||||
def get_limit_value(self) -> int | None:
    """
    Return the row count given to the first `take` / `limit` in the statement.

    Whitespace tokens are ignored. Only the first non-string token spelled
    ``take`` or ``limit`` (case-insensitive) is considered: if the token
    right after it is a number, that number is returned; otherwise the
    search stops.

    :return: The limit as an integer, or None when no such value is found
    """
    significant = [
        (kind, text)
        for kind, text in tokenize_kql(self._sql)
        if kind != KQLTokenType.WHITESPACE
    ]

    for position, (kind, text) in enumerate(significant):
        if kind == KQLTokenType.STRING or text.lower() not in {"take", "limit"}:
            continue

        # The first take/limit keyword decides the outcome either way.
        follower = significant[position + 1 : position + 2]
        if follower and follower[0][0] == KQLTokenType.NUMBER:
            return int(follower[0][1])
        return None

    return None
|
||||
|
||||
|
||||
class SQLScript:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user