From adeed60fe059f16c399a84ced90d88e4076869d8 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Thu, 22 May 2025 20:09:36 -0400 Subject: [PATCH] feat: implement limit extraction in sqlglot (#33456) --- superset/commands/sql_lab/export.py | 9 +- superset/sql/parse.py | 196 ++++++++++++++++++++++------ tests/unit_tests/sql/parse_tests.py | 60 +++++++++ 3 files changed, 222 insertions(+), 43 deletions(-) diff --git a/superset/commands/sql_lab/export.py b/superset/commands/sql_lab/export.py index bfa73905483..44fdafe5cdb 100644 --- a/superset/commands/sql_lab/export.py +++ b/superset/commands/sql_lab/export.py @@ -27,7 +27,7 @@ from superset.commands.base import BaseCommand from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetErrorException, SupersetSecurityException from superset.models.sql_lab import Query -from superset.sql_parse import ParsedQuery +from superset.sql.parse import SQLScript from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import core as utils, csv from superset.views.utils import _deserialize_results_payload @@ -115,10 +115,9 @@ class SqlResultExportCommand(BaseCommand): limit = None else: sql = self._query.executed_sql - limit = ParsedQuery( - sql, - engine=self._query.database.db_engine_spec.engine, - ).limit + script = SQLScript(sql, self._query.database.db_engine_spec.engine) + # when a query has multiple statements only the last one returns data + limit = script.statements[-1].get_limit_value() if limit is not None and self._query.limiting_factor in { LimitingFactor.QUERY, LimitingFactor.DROPDOWN, diff --git a/superset/sql/parse.py b/superset/sql/parse.py index dc9ca632bba..84017736cfe 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -32,7 +32,7 @@ from deprecation import deprecated from sqlglot import exp from sqlglot.dialects.dialect import Dialect, Dialects from sqlglot.errors import ParseError -from sqlglot.expressions import Func +from sqlglot.expressions import Func, Limit from sqlglot.optimizer.pushdown_predicates import pushdown_predicates from sqlglot.optimizer.scope import Scope, ScopeType, traverse_scope @@ -237,6 +237,21 @@ class BaseSQLStatement(Generic[InternalRepresentation]): """ raise NotImplementedError() + def check_functions_present(self, functions: set[str]) -> bool: + """ + Check if any of the given functions are present in the script. + + :param functions: List of functions to check for + :return: True if any of the functions are present + """ + raise NotImplementedError() + + def get_limit_value(self) -> int | None: + """ + Get the limit value of the statement. + """ + raise NotImplementedError() + def __str__(self) -> str: return self.format() @@ -471,6 +486,24 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): } return any(function.upper() in present for function in functions) + def get_limit_value(self) -> int | None: + """ + Parse a SQL query and return the `LIMIT` or `TOP` value, if present. + """ + limit_node = ( + self._parsed + if isinstance(self._parsed, Limit) + else self._parsed.args.get("limit") + ) + if not isinstance(limit_node, exp.Limit): + return None + + literal = limit_node.args.get("expression") or getattr(limit_node, "this", None) + if isinstance(literal, exp.Literal) and literal.is_int: + return int(literal.name) + + return None + class KQLSplitState(enum.Enum): """ @@ -486,48 +519,118 @@ class KQLSplitState(enum.Enum): INSIDE_MULTILINE_STRING = enum.auto() +class KQLTokenType(enum.Enum): + """ + Token types for KQL. + """ + + STRING = enum.auto() + WORD = enum.auto() + NUMBER = enum.auto() + SEMICOLON = enum.auto() + WHITESPACE = enum.auto() + OTHER = enum.auto() + + +def classify_non_string_kql(text: str) -> list[tuple[KQLTokenType, str]]: + """ + Classify non-string KQL. + """ + tokens: list[tuple[KQLTokenType, str]] = [] + for m in re.finditer(r"[A-Za-z_][A-Za-z_0-9]*|\d+|\s+|.", text): + tok = m.group(0) + if tok == ";": + tokens.append((KQLTokenType.SEMICOLON, tok)) + elif tok.isdigit(): + tokens.append((KQLTokenType.NUMBER, tok)) + elif re.match(r"[A-Za-z_][A-Za-z_0-9]*", tok): + tokens.append((KQLTokenType.WORD, tok)) + elif re.match(r"\s+", tok): + tokens.append((KQLTokenType.WHITESPACE, tok)) + else: + tokens.append((KQLTokenType.OTHER, tok)) + + return tokens + + +def tokenize_kql(kql: str) -> list[tuple[KQLTokenType, str]]: + """ + Turn a KQL script into a flat list of tokens. + """ + + state = KQLSplitState.OUTSIDE_STRING + tokens: list[tuple[KQLTokenType, str]] = [] + buffer = "" + script = kql if kql.endswith(";") else kql + ";" + + for i, ch in enumerate(script): + if state == KQLSplitState.OUTSIDE_STRING: + if ch in {"'", '"'}: + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + buffer = "" + state = ( + KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + if ch == "'" + else KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + ) + buffer = ch + elif ch == "`" and script[i - 2 : i] == "``": + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + buffer = "" + state = KQLSplitState.INSIDE_MULTILINE_STRING + buffer = "`" + else: + buffer += ch + else: + buffer += ch + end_str = ( + ( + state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING + and ch == "'" + and script[i - 1] != "\\" + ) + or ( + state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING + and ch == '"' + and script[i - 1] != "\\" + ) + or ( + state == KQLSplitState.INSIDE_MULTILINE_STRING + and ch == "`" + and script[i - 2 : i] == "``" + ) + ) + if end_str: + tokens.append((KQLTokenType.STRING, buffer)) + buffer = "" + state = KQLSplitState.OUTSIDE_STRING + + if buffer: + tokens.extend(classify_non_string_kql(buffer)) + + return tokens + + def split_kql(kql: str) -> list[str]: """ - Custom function for splitting KQL statements. + Split a KQL script into statements on semicolons, + ignoring those inside strings. """ - statements = [] - state = KQLSplitState.OUTSIDE_STRING - statement_start = 0 - script = kql if kql.endswith(";") else kql + ";" - for i, character in enumerate(script): - if state == KQLSplitState.OUTSIDE_STRING: - if character == ";": - statements.append(script[statement_start:i]) - statement_start = i + 1 - elif character == "'": - state = KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - elif character == '"': - state = KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - elif character == "`" and script[i - 2 : i] == "``": - state = KQLSplitState.INSIDE_MULTILINE_STRING + tokens = tokenize_kql(kql) + stmts_tokens: list[list[tuple[KQLTokenType, str]]] = [] + current: list[tuple[KQLTokenType, str]] = [] - elif ( - state == KQLSplitState.INSIDE_SINGLE_QUOTED_STRING - and character == "'" - and script[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING + for ttype, val in tokens: + if ttype == KQLTokenType.SEMICOLON: + if current: + stmts_tokens.append(current) + current = [] + else: + current.append((ttype, val)) - elif ( - state == KQLSplitState.INSIDE_DOUBLE_QUOTED_STRING - and character == '"' - and script[i - 1] != "\\" - ): - state = KQLSplitState.OUTSIDE_STRING - - elif ( - state == KQLSplitState.INSIDE_MULTILINE_STRING - and character == "`" - and script[i - 2 : i] == "``" - ): - state = KQLSplitState.OUTSIDE_STRING - - return statements + return ["".join(val for _, val in stmt) for stmt in stmts_tokens] class KustoKQLStatement(BaseSQLStatement[str]): @@ -647,6 +750,23 @@ class KustoKQLStatement(BaseSQLStatement[str]): logger.warning("Kusto KQL doesn't support checking for functions present.") return True + def get_limit_value(self) -> int | None: + """ + Get the limit value of the statement. + """ + tokens = [ + token + for token in tokenize_kql(self._sql) + if token[0] != KQLTokenType.WHITESPACE + ] + for idx, (ttype, val) in enumerate(tokens): + if ttype != KQLTokenType.STRING and val.lower() in {"take", "limit"}: + if idx + 1 < len(tokens) and tokens[idx + 1][0] == KQLTokenType.NUMBER: + return int(tokens[idx + 1][1]) + break + + return None + class SQLScript: """ diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 8dc06aeea39..f8da1d77b2a 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -1185,3 +1185,63 @@ def test_firebolt_old_escape_string() -> None: 'foo''bar', 'foo''bar'""" ) + + +@pytest.mark.parametrize( + "sql, engine, expected", + [ + ("SELECT * FROM users LIMIT 10", "postgresql", 10), + ("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25), + ("SELECT * FROM users", "postgresql", None), + ("SELECT TOP 5 name FROM employees", "teradatasql", 5), + ("SELECT TOP (42) * FROM table_name", "teradatasql", 42), + ("select * from table", "postgresql", None), + ("select * from mytable limit 10", "postgresql", 10), + ( + "select * from (select * from my_subquery limit 10) where col=1 limit 20", + "postgresql", + 20, + ), + ("select * from (select * from my_subquery limit 10);", "postgresql", None), + ( + "select * from (select * from my_subquery limit 10) where col=1 limit 20;", + "postgresql", + 20, + ), + ("select * from mytable limit 20, 10", "postgresql", 10), + ("select * from mytable limit 10 offset 20", "postgresql", 10), + ( + """ +SELECT id, value, i +FROM (SELECT * FROM my_table LIMIT 10), +LATERAL generate_series(1, value) AS i; + """, + "postgresql", + None, + ), + ], +) +def test_get_limit_value(sql, engine, expected): + assert SQLStatement(sql, engine).get_limit_value() == expected + + +@pytest.mark.parametrize( + "kql, expected", + [ + ("StormEvents | take 10", 10), + ("StormEvents | limit 20", 20), + ("StormEvents | where State == 'FL' | summarize count()", None), + ("StormEvents | where name has 'limit 10'", None), + ("AnotherTable | take 5", 5), + ("datatable(x:int) [1, 2, 3] | take 100", 100), + ( + """ + Table1 | where msg contains 'abc;xyz' + | limit 5 + """, + 5, + ), + ], +) +def test_get_kql_limit_value(kql, expected): + assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected