chore: 100% test coverage for SQL parsing (#33568)

This commit is contained in:
Beto Dealmeida
2025-06-04 22:18:09 -04:00
committed by GitHub
parent c9518485ba
commit edc60914f6
8 changed files with 853 additions and 113 deletions

View File

@@ -19,15 +19,18 @@
import pytest
from pytest_mock import MockerFixture
from sqlglot import Dialects, parse_one
from sqlglot import Dialects, exp, parse_one
from superset.exceptions import QueryClauseValidationException, SupersetParseError
from superset.jinja_context import JinjaTemplateProcessor
from superset.sql.parse import (
CTASMethod,
extract_tables_from_jinja_sql,
extract_tables_from_statement,
KQLTokenType,
KustoKQLStatement,
LimitMethod,
remove_quotes,
RLSMethod,
sanitize_clause,
split_kql,
@@ -35,6 +38,7 @@ from superset.sql.parse import (
SQLScript,
SQLStatement,
Table,
tokenize_kql,
)
from tests.integration_tests.conftest import with_feature_flags
@@ -776,6 +780,20 @@ Events | take 100""",
"postgresql",
["SELECT\n 1 /* extraneous comment */"],
),
(
"SHOW TABLES FROM s1 like '%order%';",
"mysql",
["SHOW TABLES FROM s1 LIKE '%order%'"],
),
(
"SELECT 1; SELECT 2; SELECT 3;",
"unknown-engine",
[
"SELECT\n 1",
"SELECT\n 2",
"SELECT\n 3",
],
),
],
)
def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None:
@@ -795,18 +813,59 @@ def test_sqlstatement() -> None:
"sqlite",
)
assert statement.tables == {
Table(table="table1", schema=None, catalog=None),
Table(table="table2", schema=None, catalog=None),
}
assert (
statement.format()
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
)
assert str(statement) == statement.format()
assert statement.tables == {
Table(table="table1", schema=None, catalog=None),
Table(table="table2", schema=None, catalog=None),
}
assert statement.parse_predicate("a > 1") == exp.GT(
this=exp.Column(this=exp.Identifier(this="a", quoted=False)),
expression=exp.Literal(this="1", is_string=False),
)
statement = SQLStatement("SET a=1", "sqlite")
assert statement.get_settings() == {"a": "1"}
with pytest.raises(
ValueError,
match="Either statement or ast must be provided",
):
SQLStatement()
def test_kustokqlstatement() -> None:
    """
    Exercise the basic API of the `KustoKQLStatement` class.

    Covers formatting, string conversion, table extraction, `optimize`,
    `parse_predicate`, and the errors expected for an invalid engine and for
    input that holds more than one statement.
    """
    kql = "foo | take 100"
    statement = KustoKQLStatement(kql, "kustokql")

    assert statement.format() == kql
    assert str(statement) == statement.format()

    # table extraction is not supported, so no tables are reported
    assert statement.tables == set()

    # `optimize` is expected to leave the statement as it was
    optimized = statement.optimize()
    assert optimized.format() == kql

    # `parse_predicate` is expected to hand the predicate back untouched
    predicate = "a > 1"
    assert statement.parse_predicate(predicate) == predicate

    with pytest.raises(SupersetParseError, match="Invalid engine: invalid-engine"):
        KustoKQLStatement(kql, "invalid-engine")

    two_statements = "foo | take 1; bar | take 2"
    with pytest.raises(
        SupersetParseError,
        match="KustoKQLStatement should have exactly one statement",
    ):
        KustoKQLStatement(two_statements, "kustokql")
def test_kustokqlstatement_split_script() -> None:
"""
@@ -887,11 +946,13 @@ def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements
def test_split_kql() -> None:
"""
Test the `split_kql` function.
"""
kql = """
@pytest.mark.parametrize(
"kql, expected",
[
(";Table | take 5", ["Table | take 5"]),
(";Table | take 5;", ["Table | take 5"]),
(
"""
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day;
@@ -912,18 +973,18 @@ on Page
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
"""
assert split_kql(kql) == [
"""
""",
[
"""
let totalPagesPerDay = PageViews
| summarize by Page, Day = startofday(Timestamp)
| summarize count() by Day""",
"""
"""
let materializedScope = PageViews
| summarize by Page, Day = startofday(Timestamp)""",
"""
"""
let cachedResult = materialize(materializedScope)""",
"""
"""
cachedResult
| project Page, Day1 = Day
| join kind = inner
@@ -938,8 +999,16 @@ on Page
totalPagesPerDay
on $left.Day1 == $right.Day
| project Day1, Day2, Percentage = count_*100.0/count_1
""",
]
""",
],
),
],
)
def test_split_kql(kql: str, expected: list[str]) -> None:
"""
Test the `split_kql` function.
"""
assert split_kql(kql) == expected
@pytest.mark.parametrize(
@@ -1106,14 +1175,19 @@ def test_custom_dialect(app: None) -> None:
"vertica",
],
)
def test_is_mutating(engine: str) -> None:
@pytest.mark.parametrize(
"sql, expected",
[
("SELECT 1", False),
("with source as ( select 1 as one ) select * from source", False),
("ALTER TABLE foo ADD COLUMN bar INT", True),
],
)
def test_is_mutating(sql: str, engine: str, expected: bool) -> None:
"""
Global tests for `is_mutating`, covering all supported engines.
"""
assert not SQLStatement(
"with source as ( select 1 as one ) select * from source",
engine=engine,
).is_mutating()
assert SQLStatement(sql, engine).is_mutating() == expected
def test_optimize() -> None:
@@ -1164,6 +1238,9 @@ WHERE
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
# also works for scripts
assert SQLScript(sql, "sqlite").optimize().format() == optimized
def test_firebolt() -> None:
"""
@@ -1285,6 +1362,8 @@ LATERAL generate_series(1, value) AS i;
"postgresql",
None,
),
# not really valid SQL, but let's roll with it
("SELECT * FROM my_table LIMIT invalid", "postgresql", None),
],
)
def test_get_limit_value(sql: str, engine: str, expected: str) -> None:
@@ -1307,6 +1386,7 @@ def test_get_limit_value(sql: str, engine: str, expected: str) -> None:
""",
5,
),
("table | take five", None),
],
)
def test_get_kql_limit_value(kql: str, expected: str) -> None:
@@ -1492,6 +1572,13 @@ LIMIT 1000
LimitMethod.FORCE_LIMIT,
"SELECT\n *\nFROM birth_names\nLIMIT 1000",
),
(
"SELECT * FROM birth_names LIMIT 555",
"postgresql",
1000,
LimitMethod.FETCH_MANY,
"SELECT\n *\nFROM birth_names\nLIMIT 555",
),
],
)
def test_set_limit_value(
@@ -1539,11 +1626,28 @@ def test_set_limit_value(
],
)
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
    """
    Check that `KustoKQLStatement.set_limit_value` produces the expected KQL.

    The statement is built from `kql`, the limit is applied in place, and the
    formatted result is compared with `expected`.
    """
    kql_statement = KustoKQLStatement(kql, "kustokql")
    kql_statement.set_limit_value(limit)

    formatted = kql_statement.format()
    assert formatted == expected
@pytest.mark.parametrize("method", [LimitMethod.WRAP_SQL, LimitMethod.FETCH_MANY])
def test_set_kql_limit_value_invalid_method(method: LimitMethod) -> None:
    """
    Test that `KustoKQLStatement.set_limit_value` rejects unsupported methods.

    Every parametrized `LimitMethod` is expected to raise a
    `SupersetParseError` stating that only `FORCE_LIMIT` is supported.
    """
    statement = KustoKQLStatement("foo", "kustokql")

    with pytest.raises(
        SupersetParseError,
        # `match` is a regular expression searched in the error message, so the
        # trailing period is escaped to require a literal "." rather than any
        # character.
        match=r"Kusto KQL only supports the FORCE_LIMIT method\.",
    ):
        statement.set_limit_value(10, method)
@pytest.mark.parametrize(
"sql, engine, expected",
[
@@ -1670,6 +1774,15 @@ FROM (
) AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t",
{},
"""
SELECT
t.foo
FROM some_table AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
@@ -1947,6 +2060,17 @@ def test_rls_subquery_transformer(
assert statement.format() == expected
def test_rls_invalid_method(mocker: MockerFixture) -> None:
    """
    Check that `SQLStatement.apply_rls` raises `ValueError` for an unknown method.

    The predicates are a mock object; the call is expected to fail on the
    method argument (`"invalid"`).
    """
    predicates = mocker.MagicMock()
    statement = SQLStatement("SELECT 1", "postgresql")

    expected_error = pytest.raises(ValueError, match="Invalid RLS method: invalid")
    with expected_error:
        statement.apply_rls("catalog1", "schema1", predicates, "invalid")  # type: ignore
@pytest.mark.parametrize(
"sql, rules, expected",
[
@@ -2171,6 +2295,17 @@ JOIN other_table
)
""".strip(),
),
(
"SELECT * FROM table JOIN other_table",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42
""".strip(),
),
(
"""
SELECT *
@@ -2237,6 +2372,18 @@ WHERE
other_table.id = 42
""".strip(),
),
(
"INSERT INTO some_table (col1, col2) VALUES (1, 2)",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
INSERT INTO some_table (
col1,
col2
)
VALUES
(1, 2)
""".strip(),
),
],
)
def test_rls_predicate_transformer(
@@ -2427,7 +2574,7 @@ def test_sanitize_clause(sql: str, expected: str | Exception, engine: str) -> No
],
)
@pytest.mark.parametrize(
"macro,expected",
"macro, expected",
[
(
"latest_partition('foo.bar')",
@@ -2464,7 +2611,7 @@ def test_extract_tables_from_jinja_sql(
assert (
extract_tables_from_jinja_sql(
sql=f"'{{{{ {engine}.{macro} }}}}'",
database=mocker.Mock(),
database=mocker.MagicMock(backend=engine),
)
== expected
)
@@ -2475,10 +2622,154 @@ def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
"""
Test the function when the feature flag is disabled.
"""
database = mocker.Mock()
database = mocker.MagicMock()
database.db_engine_spec.engine = "mssql"
assert extract_tables_from_jinja_sql(
sql="SELECT 1 FROM t",
database=database,
) == {Table("t")}
def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None:
    """
    Test table extraction when the template calls a custom Jinja function.

    A `my_table` global returning ``"t"`` is registered on the template
    processor, which is then patched in as the processor returned by
    `get_template_processor`. The extracted tables are expected to be
    ``{Table("t")}``.
    """
    database = mocker.MagicMock(backend="postgresql")

    processor = JinjaTemplateProcessor(database)
    processor.env.globals["my_table"] = lambda: "t"
    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=processor,
    )

    tables = extract_tables_from_jinja_sql(
        sql="SELECT * FROM {{ my_table() }}",
        database=database,
    )
    assert tables == {Table("t")}
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM users", "postgresql", True),
        ("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "postgresql", True),
        ("CREATE TABLE users AS SELECT * FROM users", "postgresql", False),
        ("ALTER TABLE users ADD COLUMN age INT", "postgresql", False),
        ("SET @value = 42", "postgresql", False),
    ],
)
def test_sqlstatement_is_select(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLStatement.is_select()` against the expected flag for each query.
    """
    statement = SQLStatement(sql, engine)
    assert statement.is_select() == expected
@pytest.mark.parametrize(
    "kql, expected",
    [
        ("StormEvents | take 10", True),
        ("StormEvents | limit 20", True),
        ("StormEvents | where State == 'FL' | summarize count()", True),
        ("StormEvents | where name has 'limit 10'", True),
        ("AnotherTable | take 5", True),
        ("datatable(x:int) [1, 2, 3] | take 100", True),
        (".create table StormEvents (x:int)", False),
        (".ingest inline into table StormEvents <| StormEvents | take 10", False),
    ],
)
def test_kqlstatement_is_select(kql: str, expected: bool) -> None:
    """
    Check `KustoKQLStatement.is_select()` against the expected flag for each query.
    """
    statement = KustoKQLStatement(kql, "kustokql")
    assert statement.is_select() == expected
def test_remove_quotes() -> None:
    """
    Check the `remove_quotes` helper.

    `None` is expected to pass through, matching double quotes, single quotes
    and backticks are expected to be stripped, and a value with mismatched
    delimiters is expected to come back unchanged.
    """
    assert remove_quotes(None) is None

    cases = (
        ('"foo"', "foo"),
        ("'foo'", "foo"),
        ("`foo`", "foo"),
        # mismatched opening/closing delimiters
        ("'foo`", "'foo`"),
    )
    for value, expected in cases:
        assert remove_quotes(value) == expected
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT VERSION()", "postgresql", True),
        ("SELECT query_to_xml()", "postgresql", True),
        ("WITH cte AS (SELECT * FROM table) SELECT * FROM cte", "postgresql", False),
        (
            """
SELECT *
FROM query_to_xml('SELECT * from some_table WHERE id = 42')
""",
            "postgresql",
            True,
        ),
        ("Table | limit 10", "kustokql", False),
    ],
)
def test_check_functions_present(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLScript.check_functions_present` for a fixed set of function names.

    Each script is checked for the presence of `version` and `query_to_xml`.
    """
    functions = {"version", "query_to_xml"}
    script = SQLScript(sql, engine)

    found = script.check_functions_present(functions)
    assert found == expected
@pytest.mark.parametrize(
    "kql, expected",
    [
        (
            "StormEvents | take 10",
            [
                (KQLTokenType.WORD, "StormEvents"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.OTHER, "|"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.WORD, "take"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.NUMBER, "10"),
            ],
        ),
        ("'test'", [(KQLTokenType.STRING, "'test'")]),
        ("```test```", [(KQLTokenType.STRING, "```test```")]),
    ],
)
def test_tokenize_kql(kql: str, expected: list[tuple[KQLTokenType, str]]) -> None:
    """
    Check that `tokenize_kql` yields the expected `(token type, text)` pairs.
    """
    tokens = tokenize_kql(kql)
    assert tokens == expected
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        ("a = 1", "postgresql", False),
        ("(SELECT * FROM table)", "postgresql", True),
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT * FROM (SELECT 1)", "postgresql", True),
        ("SELECT * FROM (SELECT 1) AS subquery", "postgresql", True),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        ("SELECT * FROM table WHERE EXISTS (SELECT 1)", "postgresql", True),
        ("SELECT * FROM table WHERE NOT EXISTS (SELECT 1)", "postgresql", True),
        (
            "SELECT * FROM table WHERE id IN (SELECT id FROM other_table)",
            "postgresql",
            True,
        ),
    ],
)
def test_has_subquery(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLStatement.has_subquery()` against the expected flag for each input.
    """
    statement = SQLStatement(sql, engine)
    assert statement.has_subquery() == expected