feat: use sqlglot to set limit (#33473)

This commit is contained in:
Beto Dealmeida
2025-05-27 15:20:02 -04:00
committed by GitHub
parent cc8ab2c556
commit 8de58b9848
34 changed files with 573 additions and 557 deletions

View File

@@ -24,6 +24,7 @@ from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
LimitMethod,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
@@ -302,7 +303,11 @@ def test_format_no_dialect() -> None:
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
== """SELECT
col
FROM t
WHERE
NOT col IN (1, 2)"""
)
@@ -1100,16 +1105,18 @@ FROM (
WHERE
TRUE AND TRUE"""
not_optimized = """
SELECT anon_1.a,
anon_1.b
FROM
(SELECT some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1
AND anon_1.b = 2"""
not_optimized = """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
) AS anon_1
WHERE
anon_1.a > 1 AND anon_1.b = 2"""
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
@@ -1191,6 +1198,18 @@ def test_firebolt_old_escape_string() -> None:
"sql, engine, expected",
[
("SELECT * FROM users LIMIT 10", "postgresql", 10),
(
"""
WITH cte_example AS (
SELECT * FROM my_table
LIMIT 100
)
SELECT * FROM cte_example
LIMIT 10;
""",
"postgresql",
10,
),
("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25),
("SELECT * FROM users", "postgresql", None),
("SELECT TOP 5 name FROM employees", "teradatasql", 5),
@@ -1221,7 +1240,7 @@ LATERAL generate_series(1, value) AS i;
),
],
)
def test_get_limit_value(sql, engine, expected):
def test_get_limit_value(sql: str, engine: str, expected: int | None) -> None:
    """
    Test that ``get_limit_value`` returns the limit set on the outer query.

    ``expected`` is the integer limit, or ``None`` when the query has no limit;
    a limit that only appears inside a CTE is not the one reported.
    """
    assert SQLStatement(sql, engine).get_limit_value() == expected
@@ -1243,5 +1262,232 @@ def test_get_limit_value(sql, engine, expected):
),
],
)
def test_get_kql_limit_value(kql, expected):
def test_get_kql_limit_value(kql: str, expected: str) -> None:
    """
    Test that ``get_limit_value`` returns the expected value for a KQL query.
    """
    # NOTE(review): ``expected`` is annotated as ``str``, while the SQL
    # counterpart of this test is parametrized with ints and ``None`` --
    # confirm this annotation matches the KQL cases.
    statement = KustoKQLStatement(kql, "kustokql")
    limit_value = statement.get_limit_value()
    assert limit_value == expected
@pytest.mark.parametrize(
    "sql, engine, limit, method, expected",
    [
        (
            "SELECT * FROM t",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nLIMIT 10",
        ),
        (
            "SELECT * FROM t LIMIT 1000",
            "postgresql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nLIMIT 10",
        ),
        (
            "SELECT * FROM t",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM t",
        ),
        (
            "SELECT * FROM t",
            "teradatasql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM t",
        ),
        (
            "SELECT * FROM t",
            "oracle",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM t\nFETCH FIRST 10 ROWS ONLY",
        ),
        (
            "SELECT * FROM t",
            "db2",
            10,
            LimitMethod.WRAP_SQL,
            "SELECT\n *\nFROM (\n SELECT\n *\n FROM t\n)\nLIMIT 10",
        ),
        (
            "SEL TOP 1000 * FROM My_table",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SEL TOP 1000 * FROM My_table;",
            "teradatasql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "teradatasql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 100\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            10000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10000\n *\nFROM My_table",
        ),
        (
            "SELECT TOP 1000 * FROM My_table;",
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 1000\n *\nFROM My_table",
        ),
        (
            """
with abc as (select * from test union select * from test1)
select TOP 100 * from currency
""",
            "mssql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """WITH abc AS (
SELECT
*
FROM test
UNION
SELECT
*
FROM test1
)
SELECT
TOP 1000
*
FROM currency""",
        ),
        (
            "SELECT DISTINCT x from tbl",
            "mssql",
            100,
            LimitMethod.FORCE_LIMIT,
            "SELECT DISTINCT\nTOP 100\n x\nFROM tbl",
        ),
        (
            "SELECT 1 as cnt",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n 1 AS cnt",
        ),
        (
            "select TOP 1000 * from abc where id=1",
            "mssql",
            10,
            LimitMethod.FORCE_LIMIT,
            "SELECT\nTOP 10\n *\nFROM abc\nWHERE\n id = 1",
        ),
        (
            "SELECT * FROM birth_names -- SOME COMMENT",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000",
        ),
        (
            "SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            """SELECT
*
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000""",
        ),
        (
            "SELECT * FROM birth_names LIMIT 555",
            "postgresql",
            1000,
            LimitMethod.FORCE_LIMIT,
            "SELECT\n *\nFROM birth_names\nLIMIT 1000",
        ),
    ],
)
def test_set_limit_value(
    sql: str,
    engine: str,
    limit: int,
    method: LimitMethod,
    expected: str,
) -> None:
    """
    Test that ``set_limit_value`` applies the limit using the given method.

    The statement is modified in place and then formatted; an existing limit is
    overwritten whether the new value is lower or higher than the old one.
    """
    stmt = SQLStatement(sql, engine)
    stmt.set_limit_value(limit, method)
    formatted = stmt.format()
    assert formatted == expected
@pytest.mark.parametrize(
    "kql, limit, expected",
    [
        ("StormEvents | take 10", 100, "StormEvents | take 100"),
        ("StormEvents | limit 20", 10, "StormEvents | limit 10"),
        (
            "StormEvents | where State == 'FL' | summarize count()",
            10,
            "StormEvents | where State == 'FL' | summarize count() | take 10",
        ),
        (
            "StormEvents | where name has 'limit 10'",
            10,
            "StormEvents | where name has 'limit 10' | take 10",
        ),
        ("AnotherTable | take 5", 50, "AnotherTable | take 50"),
        (
            "datatable(x:int) [1, 2, 3] | take 100",
            10,
            "datatable(x:int) [1, 2, 3] | take 10",
        ),
        (
            """
Table1 | where msg contains 'abc;xyz'
| limit 5
""",
            10,
            """Table1 | where msg contains 'abc;xyz'
| limit 10""",
        ),
    ],
)
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
    """
    Test that ``set_limit_value`` rewrites or appends the limit in a KQL query.

    An existing ``take``/``limit`` operator keeps its keyword and gets the new
    value; a query without one gets ``| take N`` appended. Text that merely
    looks like a limit inside a string literal is left alone.
    """
    stmt = KustoKQLStatement(kql, "kustokql")
    stmt.set_limit_value(limit)
    formatted = stmt.format()
    assert formatted == expected