feat: use sqlglot to set limit (#33473)

This commit is contained in:
Beto Dealmeida
2025-05-27 15:20:02 -04:00
committed by GitHub
parent cc8ab2c556
commit 8de58b9848
34 changed files with 573 additions and 557 deletions

View File

@@ -206,9 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None:
"""
from superset.db_engine_specs.base import BaseEngineSpec
class NoLimitDBEngineSpec(BaseEngineSpec):
allow_limit_clause = False
cols: list[ResultSetColumnType] = [
{
"column_name": "a",
@@ -243,19 +240,7 @@ def test_select_star(mocker: MockerFixture) -> None:
latest_partition=False,
cols=cols,
)
assert sql == "SELECT a\nFROM my_table\nLIMIT ?\nOFFSET ?"
sql = NoLimitDBEngineSpec.select_star(
database=database,
table=Table("my_table"),
engine=engine,
limit=100,
show_cols=True,
indent=True,
latest_partition=False,
cols=cols,
)
assert sql == "SELECT a\nFROM my_table"
assert sql == "SELECT\n a\nFROM my_table\nLIMIT ?\nOFFSET ?"
def test_extra_table_metadata(mocker: MockerFixture) -> None:

View File

@@ -254,36 +254,6 @@ def test_cte_query_parsing(original: TypeEngine, expected: str) -> None:
assert actual == expected
@pytest.mark.parametrize(
"original,expected,top",
[
("SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table", 100),
("SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table", 100),
("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 10000),
("SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table", 1000),
(
"""with abc as (select * from test union select * from test1)
select TOP 100 * from currency""",
"""WITH abc as (select * from test union select * from test1)
select TOP 100 * from currency""",
1000,
),
("SELECT DISTINCT x from tbl", "SELECT DISTINCT TOP 100 x from tbl", 100),
("SELECT 1 as cnt", "SELECT TOP 10 1 as cnt", 10),
(
"select TOP 1000 * from abc where id=1",
"select TOP 10 * from abc where id=1",
10,
),
],
)
def test_top_query_parsing(original: TypeEngine, expected: str, top: int) -> None:
from superset.db_engine_specs.mssql import MssqlEngineSpec
actual = MssqlEngineSpec.apply_top_to_sql(original, top)
assert actual == expected
def test_extract_errors() -> None:
"""
Test that custom error messages are extracted correctly.

View File

@@ -1,43 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument, import-outside-toplevel, protected-access
import pytest
@pytest.mark.parametrize(
"limit,original,expected",
[
(100, "SEL TOP 1000 * FROM My_table", "SEL TOP 100 * FROM My_table"),
(100, "SEL TOP 1000 * FROM My_table;", "SEL TOP 100 * FROM My_table"),
(10000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"),
(1000, "SEL TOP 1000 * FROM My_table;", "SEL TOP 1000 * FROM My_table"),
(100, "SELECT TOP 1000 * FROM My_table", "SELECT TOP 100 * FROM My_table"),
(100, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 100 * FROM My_table"),
(10000, "SEL SAMPLE 1000 * FROM My_table", "SEL SAMPLE 1000 * FROM My_table"),
],
)
def test_apply_top_to_sql_limit(
limit: int,
original: str,
expected: str,
) -> None:
"""
Ensure limits are applied to the query correctly
"""
from superset.db_engine_specs.teradata import TeradataEngineSpec
assert TeradataEngineSpec.apply_top_to_sql(original, limit) == expected

View File

@@ -38,6 +38,7 @@ from superset.connectors.sqla.models import SqlaTable, TableColumn
from superset.errors import SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.models.core import Database
from superset.sql.parse import LimitMethod
from superset.sql_parse import Table
from superset.utils import json
from tests.unit_tests.conftest import with_feature_flags
@@ -910,3 +911,144 @@ def test_get_all_view_names_in_schema(mocker: MockerFixture) -> None:
("third_view", "public", "examples"),
}
)
@pytest.mark.parametrize(
"sql, limit, force, method, expected",
[
(
"SELECT * FROM table",
100,
False,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM table\nLIMIT 100",
),
(
"SELECT * FROM table LIMIT 100",
10,
False,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM table\nLIMIT 10",
),
(
"SELECT * FROM table LIMIT 10",
100,
False,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM table\nLIMIT 10",
),
(
"SELECT * FROM table LIMIT 10",
100,
True,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM table\nLIMIT 100",
),
(
"SELECT * FROM a \t \n ; \t \n ",
1000,
False,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM a\nLIMIT 1000",
),
(
"SELECT 'LIMIT 777'",
1000,
False,
LimitMethod.FORCE_LIMIT,
"SELECT\n 'LIMIT 777'\nLIMIT 1000",
),
(
"SELECT * FROM table",
1000,
False,
LimitMethod.FETCH_MANY,
"SELECT\n *\nFROM table",
),
(
"SELECT * FROM (SELECT * FROM a LIMIT 10) LIMIT 9999",
1000,
False,
LimitMethod.FORCE_LIMIT,
"""SELECT
*
FROM (
SELECT
*
FROM a
LIMIT 10
)
LIMIT 1000""",
),
(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990""",
1000,
None,
LimitMethod.FORCE_LIMIT,
"SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000",
),
(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990 ;""",
1000,
None,
LimitMethod.FORCE_LIMIT,
"SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000",
),
(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990, 999999""",
1000,
None,
LimitMethod.FORCE_LIMIT,
"SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 99990",
),
(
"""
SELECT
'LIMIT 777' AS a
, b
FROM
table
LIMIT 99990
OFFSET 999999""",
1000,
None,
LimitMethod.FORCE_LIMIT,
"SELECT\n 'LIMIT 777' AS a,\n b\nFROM table\nLIMIT 1000\nOFFSET 999999",
),
],
)
def test_apply_limit_to_sql(
sql: str,
limit: int,
force: bool,
method: LimitMethod,
expected: str,
mocker: MockerFixture,
) -> None:
"""
Test the `apply_limit_to_sql` method.
"""
db = Database(database_name="test_database", sqlalchemy_uri="sqlite://")
db_engine_spec = mocker.MagicMock(limit_method=method)
db.get_db_engine_spec = mocker.MagicMock(return_value=db_engine_spec)
limited = db.apply_limit_to_sql(sql, limit, force)
assert limited == expected

View File

@@ -24,6 +24,7 @@ from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
LimitMethod,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
@@ -302,7 +303,11 @@ def test_format_no_dialect() -> None:
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
== "SELECT col\nFROM t\nWHERE col NOT IN (1,\n 2)"
== """SELECT
col
FROM t
WHERE
NOT col IN (1, 2)"""
)
@@ -1100,16 +1105,18 @@ FROM (
WHERE
TRUE AND TRUE"""
not_optimized = """
SELECT anon_1.a,
anon_1.b
FROM
(SELECT some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table) AS anon_1
WHERE anon_1.a > 1
AND anon_1.b = 2"""
not_optimized = """SELECT
anon_1.a,
anon_1.b
FROM (
SELECT
some_table.a AS a,
some_table.b AS b,
some_table.c AS c
FROM some_table
) AS anon_1
WHERE
anon_1.a > 1 AND anon_1.b = 2"""
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
@@ -1191,6 +1198,18 @@ def test_firebolt_old_escape_string() -> None:
"sql, engine, expected",
[
("SELECT * FROM users LIMIT 10", "postgresql", 10),
(
"""
WITH cte_example AS (
SELECT * FROM my_table
LIMIT 100
)
SELECT * FROM cte_example
LIMIT 10;
""",
"postgresql",
10,
),
("SELECT * FROM users ORDER BY id DESC LIMIT 25", "postgresql", 25),
("SELECT * FROM users", "postgresql", None),
("SELECT TOP 5 name FROM employees", "teradatasql", 5),
@@ -1221,7 +1240,7 @@ LATERAL generate_series(1, value) AS i;
),
],
)
def test_get_limit_value(sql, engine, expected):
def test_get_limit_value(sql: str, engine: str, expected: str) -> None:
assert SQLStatement(sql, engine).get_limit_value() == expected
@@ -1243,5 +1262,232 @@ def test_get_limit_value(sql, engine, expected):
),
],
)
def test_get_kql_limit_value(kql, expected):
def test_get_kql_limit_value(kql: str, expected: str) -> None:
assert KustoKQLStatement(kql, "kustokql").get_limit_value() == expected
@pytest.mark.parametrize(
"sql, engine, limit, method, expected",
[
(
"SELECT * FROM t",
"postgresql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM t\nLIMIT 10",
),
(
"SELECT * FROM t LIMIT 1000",
"postgresql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM t\nLIMIT 10",
),
(
"SELECT * FROM t",
"mssql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10\n *\nFROM t",
),
(
"SELECT * FROM t",
"teradatasql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10\n *\nFROM t",
),
(
"SELECT * FROM t",
"oracle",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM t\nFETCH FIRST 10 ROWS ONLY",
),
(
"SELECT * FROM t",
"db2",
10,
LimitMethod.WRAP_SQL,
"SELECT\n *\nFROM (\n SELECT\n *\n FROM t\n)\nLIMIT 10",
),
(
"SEL TOP 1000 * FROM My_table",
"teradatasql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 100\n *\nFROM My_table",
),
(
"SEL TOP 1000 * FROM My_table;",
"teradatasql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 100\n *\nFROM My_table",
),
(
"SEL TOP 1000 * FROM My_table;",
"teradatasql",
1000,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 1000\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table;",
"teradatasql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 100\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table;",
"teradatasql",
10000,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10000\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table",
"mssql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 100\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table;",
"mssql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 100\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table;",
"mssql",
10000,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10000\n *\nFROM My_table",
),
(
"SELECT TOP 1000 * FROM My_table;",
"mssql",
1000,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 1000\n *\nFROM My_table",
),
(
"""
with abc as (select * from test union select * from test1)
select TOP 100 * from currency
""",
"mssql",
1000,
LimitMethod.FORCE_LIMIT,
"""WITH abc AS (
SELECT
*
FROM test
UNION
SELECT
*
FROM test1
)
SELECT
TOP 1000
*
FROM currency""",
),
(
"SELECT DISTINCT x from tbl",
"mssql",
100,
LimitMethod.FORCE_LIMIT,
"SELECT DISTINCT\nTOP 100\n x\nFROM tbl",
),
(
"SELECT 1 as cnt",
"mssql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10\n 1 AS cnt",
),
(
"select TOP 1000 * from abc where id=1",
"mssql",
10,
LimitMethod.FORCE_LIMIT,
"SELECT\nTOP 10\n *\nFROM abc\nWHERE\n id = 1",
),
(
"SELECT * FROM birth_names -- SOME COMMENT",
"postgresql",
1000,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM birth_names /* SOME COMMENT */\nLIMIT 1000",
),
(
"SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555",
"postgresql",
1000,
LimitMethod.FORCE_LIMIT,
"""SELECT
*
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000""",
),
(
"SELECT * FROM birth_names LIMIT 555",
"postgresql",
1000,
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM birth_names\nLIMIT 1000",
),
],
)
def test_set_limit_value(
sql: str,
engine: str,
limit: int,
method: LimitMethod,
expected: str,
) -> None:
statement = SQLStatement(sql, engine)
statement.set_limit_value(limit, method)
assert statement.format() == expected
@pytest.mark.parametrize(
"kql, limit, expected",
[
("StormEvents | take 10", 100, "StormEvents | take 100"),
("StormEvents | limit 20", 10, "StormEvents | limit 10"),
(
"StormEvents | where State == 'FL' | summarize count()",
10,
"StormEvents | where State == 'FL' | summarize count() | take 10",
),
(
"StormEvents | where name has 'limit 10'",
10,
"StormEvents | where name has 'limit 10' | take 10",
),
("AnotherTable | take 5", 50, "AnotherTable | take 50"),
(
"datatable(x:int) [1, 2, 3] | take 100",
10,
"datatable(x:int) [1, 2, 3] | take 10",
),
(
"""
Table1 | where msg contains 'abc;xyz'
| limit 5
""",
10,
"""Table1 | where msg contains 'abc;xyz'
| limit 10""",
),
],
)
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
statement = KustoKQLStatement(kql, "kustokql")
statement.set_limit_value(limit)
assert statement.format() == expected

View File

@@ -297,7 +297,7 @@ def test_sql_lab_insert_rls_as_subquery(
| 3 | 3 |
| 4 | 4 |""".strip()
)
assert query.executed_sql == "SELECT c FROM t\nLIMIT 6"
assert query.executed_sql == "SELECT\n c\nFROM t\nLIMIT 6"
# now with RLS
rls = RowLevelSecurityFilter(
@@ -333,7 +333,18 @@ def test_sql_lab_insert_rls_as_subquery(
)
assert (
query.executed_sql
== "SELECT c FROM (SELECT * FROM t WHERE (t.c > 5)) AS t\nLIMIT 6"
== """SELECT
c
FROM (
SELECT
*
FROM t
WHERE
(
t.c > 5
)
) AS t
LIMIT 6"""
)

View File

@@ -1104,46 +1104,6 @@ def test_unknown_select() -> None:
assert not ParsedQuery(sql).is_select()
def test_get_query_with_new_limit_comment() -> None:
"""
Test that limit is applied correctly.
"""
query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT")
assert query.set_or_update_query_limit(1000) == (
"SELECT * FROM birth_names -- SOME COMMENT\nLIMIT 1000"
)
def test_get_query_with_new_limit_comment_with_limit() -> None:
"""
Test that limits in comments are ignored.
"""
query = ParsedQuery("SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555")
assert query.set_or_update_query_limit(1000) == (
"SELECT * FROM birth_names -- SOME COMMENT WITH LIMIT 555\nLIMIT 1000"
)
def test_get_query_with_new_limit_lower() -> None:
"""
Test that lower limits are not replaced.
"""
query = ParsedQuery("SELECT * FROM birth_names LIMIT 555")
assert query.set_or_update_query_limit(1000) == (
"SELECT * FROM birth_names LIMIT 555"
)
def test_get_query_with_new_limit_upper() -> None:
"""
Test that higher limits are replaced.
"""
query = ParsedQuery("SELECT * FROM birth_names LIMIT 2000")
assert query.set_or_update_query_limit(1000) == (
"SELECT * FROM birth_names LIMIT 1000"
)
def test_basic_breakdown_statements() -> None:
"""
Test that multiple statements are parsed correctly.