feat: implement limit extraction in sqlglot (#33456)

This commit is contained in:
Beto Dealmeida
2025-05-22 20:09:36 -04:00
committed by GitHub
parent 546945e7a6
commit adeed60fe0
3 changed files with 222 additions and 43 deletions

View File

@@ -32,7 +32,7 @@ from deprecation import deprecated
from sqlglot import exp
from sqlglot.dialects.dialect import Dialect, Dialects
from sqlglot.errors import ParseError
from sqlglot.expressions import Func
from sqlglot.expressions import Func, Limit
from sqlglot.optimizer.pushdown_predicates import pushdown_predicates
from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope
@@ -237,6 +237,21 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def check_functions_present(self, functions: set[str]) -> bool:
    """
    Report whether the script uses any of the given functions.

    This base implementation only defines the contract; it always raises.

    :param functions: Set of function names to look for
    :return: True if at least one of the functions is present
    :raises NotImplementedError: always, in this base implementation
    """
    raise NotImplementedError()
def get_limit_value(self) -> int | None:
"""
Get the limit value of the statement.
"""
raise NotImplementedError()
def __str__(self) -> str:
    """
    Return the string form of the statement, as produced by `format()`.
    """
    return self.format()
@@ -471,6 +486,24 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
}
return any(function.upper() in present for function in functions)
def get_limit_value(self) -> int | None:
    """
    Return the integer row limit of the parsed statement, if present.

    The limit is read from the parsed tree: either the tree itself is a
    `Limit` node, or the node is looked up under the tree's `limit` arg.
    NOTE(review): the original docstring says this covers both `LIMIT` and
    `TOP`; that relies on sqlglot exposing `TOP n` through the same arg.

    :return: The limit when it is an integer literal, None otherwise (no
        limit, or a limit that is not a plain integer literal)
    """
    node = self._parsed
    if not isinstance(node, Limit):
        node = node.args.get("limit")
    if not isinstance(node, exp.Limit):
        return None

    # the value is read from `expression` first, falling back to `this`
    value = node.args.get("expression") or getattr(node, "this", None)
    is_int_literal = isinstance(value, exp.Literal) and value.is_int
    return int(value.name) if is_int_literal else None
class KQLSplitState(enum.Enum):
"""
@@ -486,48 +519,118 @@ class KQLSplitState(enum.Enum):
INSIDE_MULTILINE_STRING = enum.auto()
class KQLTokenType(enum.Enum):
    """
    Categories of tokens produced when tokenizing a KQL script.
    """

    # text collected between string delimiters
    STRING = enum.auto()
    # identifier-like run: a letter or underscore, then letters/digits/underscores
    WORD = enum.auto()
    # run of digits
    NUMBER = enum.auto()
    # a single ";"
    SEMICOLON = enum.auto()
    # run of whitespace
    WHITESPACE = enum.auto()
    # any other single character
    OTHER = enum.auto()
_KQL_NON_STRING_TOKEN_RE = re.compile(
    r"(?P<word>[A-Za-z_][A-Za-z_0-9]*)|(?P<number>\d+)|(?P<space>\s+)|(?P<other>.)"
)


def classify_non_string_kql(text: str) -> list[tuple[KQLTokenType, str]]:
    """
    Classify a fragment of KQL that contains no string literals.

    The fragment is cut into identifier-like words, runs of digits, runs of
    whitespace and single leftover characters. Concatenating the returned
    values reproduces `text`.

    :param text: KQL text found outside of string literals
    :return: List of `(token type, token text)` pairs, in source order
    """
    kinds = {
        "word": KQLTokenType.WORD,
        "number": KQLTokenType.NUMBER,
        "space": KQLTokenType.WHITESPACE,
        "other": KQLTokenType.OTHER,
    }

    tokens: list[tuple[KQLTokenType, str]] = []
    for match in _KQL_NON_STRING_TOKEN_RE.finditer(text):
        value = match.group(0)
        if value == ";":
            kind = KQLTokenType.SEMICOLON
        else:
            # Trust the alternative that matched instead of re-testing the
            # text: `str.isdigit()` is true for characters such as "²" that
            # `\d` does not match and `int()` cannot parse, so they must not
            # be reported as NUMBER.
            kind = kinds[match.lastgroup or "other"]
        tokens.append((kind, value))

    return tokens
def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]:
    """
    Turn a KQL script into a flat list of tokens.

    A `;` is appended to the script when it does not already end with one.
    Text inside single-quoted, double-quoted and triple-backtick strings is
    emitted as `STRING` tokens (delimiters included); everything else is
    passed to `classify_non_string_kql`. Concatenating the token values
    reproduces the (possibly `;`-terminated) script.

    String boundaries are detected as follows:

    - in a regular quoted string a backslash escapes the next character, so
      `'a\\'b'` is one string while `'a\\\\'` ends at its last quote;
    - in a verbatim string (quote directly preceded by `@`) backslashes are
      literal; a doubled quote is tokenized as two adjacent string pieces,
      both treated as verbatim;
    - a multi-line string opens and closes on three consecutive backticks,
      and the closing fence never reuses backticks of the opening one.

    The two leading backticks of a multi-line opener are emitted as `OTHER`
    tokens and the `STRING` token starts at the third one. An unterminated
    string is classified as ordinary (non-string) text.

    :param kql: KQL script
    :return: List of `(token type, token text)` pairs, in source order
    """
    quote_states = {
        "'": KQLSplitState.INSIDE_SINGLE_QUOTED_STRING,
        '"': KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING,
    }
    closing_quotes = {state: quote for quote, state in quote_states.items()}

    state = KQLSplitState.OUTSIDE_STRING
    tokens: list[tuple[KQLTokenType, str]] = []
    buffer = ""
    # True right after an unescaped backslash inside a regular string
    escaped = False
    # True while inside an `@`-prefixed verbatim string
    verbatim = False
    # (index, state) at which a verbatim string continues after a doubled quote
    verbatim_resume: tuple[int, KQLSplitState] | None = None
    # index of the backtick that completed the current opening fence
    fence_opened_at = -1

    script = kql if kql.endswith(";") else kql + ";"

    for i, ch in enumerate(script):
        if state == KQLSplitState.OUTSIDE_STRING:
            if ch in quote_states:
                opened = quote_states[ch]
                verbatim = buffer.endswith("@") or verbatim_resume == (i, opened)
                if buffer:
                    tokens.extend(classify_non_string_kql(buffer))
                state = opened
                escaped = False
                buffer = ch
            elif ch == "`" and buffer.endswith("``"):
                # Only backticks still pending in `buffer` can open a fence;
                # backticks that closed the previous string cannot.
                tokens.extend(classify_non_string_kql(buffer))
                state = KQLSplitState.INSIDE_MULTILINE_STRING
                verbatim = False
                fence_opened_at = i
                buffer = "`"
            else:
                buffer += ch
            continue

        buffer += ch
        if state == KQLSplitState.INSIDE_MULTILINE_STRING:
            end_str = (
                ch == "`"
                and i - 2 > fence_opened_at
                and script[i - 2 : i] == "``"
            )
        elif escaped:
            # this character is consumed by the preceding backslash
            escaped = False
            end_str = False
        elif ch == "\\" and not verbatim:
            escaped = True
            end_str = False
        else:
            end_str = ch == closing_quotes[state]

        if end_str:
            tokens.append((KQLTokenType.STRING, buffer))
            buffer = ""
            if verbatim:
                # a same-type quote right after this one is a doubled quote
                verbatim_resume = (i + 1, state)
            state = KQLSplitState.OUTSIDE_STRING

    if buffer:
        tokens.extend(classify_non_string_kql(buffer))

    return tokens
def split_kql(kql: str) -> list[str]:
    """
    Split a KQL script into statements on semicolons,
    ignoring those inside strings.

    The script is tokenized with `tokenize_kql`; every `SEMICOLON` token ends
    the current statement. Segments with no tokens at all (for example
    between two consecutive semicolons) are dropped. Statement text is
    returned as-is, without stripping whitespace.

    :param kql: KQL script
    :return: List of statements, in source order, without the semicolons
    """
    statements: list[str] = []
    current: list[str] = []

    for ttype, val in tokenize_kql(kql):
        if ttype == KQLTokenType.SEMICOLON:
            if current:
                statements.append("".join(current))
                current = []
        else:
            current.append(val)

    # Defensive flush: keeps trailing text even if the token stream does not
    # end with a semicolon.
    if current:
        statements.append("".join(current))

    return statements
class KustoKQLStatement(BaseSQLStatement[str]):
@@ -647,6 +750,23 @@ class KustoKQLStatement(BaseSQLStatement[str]):
logger.warning("Kusto KQL doesn't support checking for functions present.")
return True
def get_limit_value(self) -> int | None:
    """
    Return the row limit of the KQL statement, if one can be read.

    Only the first `take` / `limit` keyword (case-insensitive, outside of
    string literals) is considered: when the token right after it is a
    number, that number is the limit; otherwise no limit is reported.

    :return: The limit as an integer, or None
    """
    significant = [
        (kind, text)
        for kind, text in tokenize_kql(self._sql)
        if kind != KQLTokenType.WHITESPACE
    ]

    for position, (kind, text) in enumerate(significant):
        if kind == KQLTokenType.STRING or text.lower() not in {"take", "limit"}:
            continue

        # empty when the keyword is the very last token
        following = significant[position + 1 : position + 2]
        if following and following[0][0] == KQLTokenType.NUMBER:
            return int(following[0][1])

        # keyword not followed by a literal number: stop looking
        return None

    return None
class SQLScript:
"""