feat: implement limit extraction in sqlglot (#33456)

This commit is contained in:
Beto Dealmeida
2025-05-22 20:09:36 -04:00
committed by GitHub
parent 546945e7a6
commit adeed60fe0
3 changed files with 222 additions and 43 deletions

View File

@@ -1185,3 +1185,63 @@ def test_firebolt_old_escape_string() -> None:
'foo''bar',
'foo''bar'"""
)
@pytest.mark.parametrize(
"sql, engine, expected",
[
("SELECT * FROM users LIMIT 10", "postgresql", 10),
("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25),
("SELECT * FROM users", "postgresql", None),
("SELECT TOP 5 name FROM employees", "teradatasql", 5),
("SELECT TOP (42) * FROM table_name", "teradatasql", 42),
("select * from table", "postgresql", None),
("select * from mytable limit 10", "postgresql", 10),
(
"select * from (select * from my_subquery limit 10) where col=1 limit 20",
"postgresql",
20,
),
("select * from (select * from my_subquery limit 10);", "postgresql", None),
(
"select * from (select * from my_subquery limit 10) where col=1 limit 20;",
"postgresql",
20,
),
("select * from mytable limit 20, 10", "postgresql", 10),
("select * from mytable limit 10 offset 20", "postgresql", 10),
(
"""
SELECT id, value, i
FROM (SELECT * FROM my_table LIMIT 10),
LATERAL generate_series(1, value) AS i;
""",
"postgresql",
None,
),
],
)
def test_get_limit_value(sql, engine, expected):
assert SQLStatement(sql, engine).get_limit_value() == expected
@pytest.mark.parametrize(
"kql, expected",
[
("StormEvents | take 10", 10),
("StormEvents | limit 20", 20),
("StormEvents | where State == 'FL' | summarize count()", None),
("StormEvents | where name has 'limit 10'", None),
("AnotherTable | take 5", 5),
("datatable(x:int) [1, 2, 3] | take 100", 100),
(
"""
Table1 | where msg contains 'abc;xyz'
| limit 5
""",
5,
),
],
)
def test_get_kql_limit_value(kql, expected):
assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected