mirror of
https://github.com/apache/superset.git
synced 2026-04-22 09:35:23 +00:00
chore: 100% test coverage for SQL parsing (#33568)
This commit is contained in:
@@ -19,15 +19,18 @@
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlglot import Dialects, parse_one
|
||||
from sqlglot import Dialects, exp, parse_one
|
||||
|
||||
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
||||
from superset.jinja_context import JinjaTemplateProcessor
|
||||
from superset.sql.parse import (
|
||||
CTASMethod,
|
||||
extract_tables_from_jinja_sql,
|
||||
extract_tables_from_statement,
|
||||
KQLTokenType,
|
||||
KustoKQLStatement,
|
||||
LimitMethod,
|
||||
remove_quotes,
|
||||
RLSMethod,
|
||||
sanitize_clause,
|
||||
split_kql,
|
||||
@@ -35,6 +38,7 @@ from superset.sql.parse import (
|
||||
SQLScript,
|
||||
SQLStatement,
|
||||
Table,
|
||||
tokenize_kql,
|
||||
)
|
||||
from tests.integration_tests.conftest import with_feature_flags
|
||||
|
||||
@@ -776,6 +780,20 @@ Events | take 100""",
|
||||
"postgresql",
|
||||
["SELECT\n 1 /* extraneous comment */"],
|
||||
),
|
||||
(
|
||||
"SHOW TABLES FROM s1 like '%order%';",
|
||||
"mysql",
|
||||
["SHOW TABLES FROM s1 LIKE '%order%'"],
|
||||
),
|
||||
(
|
||||
"SELECT 1; SELECT 2; SELECT 3;",
|
||||
"unknown-engine",
|
||||
[
|
||||
"SELECT\n 1",
|
||||
"SELECT\n 2",
|
||||
"SELECT\n 3",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_sqlscript_split(sql: str, engine: str, expected: list[str]) -> None:
|
||||
@@ -795,18 +813,59 @@ def test_sqlstatement() -> None:
|
||||
"sqlite",
|
||||
)
|
||||
|
||||
assert statement.tables == {
|
||||
Table(table="table1", schema=None, catalog=None),
|
||||
Table(table="table2", schema=None, catalog=None),
|
||||
}
|
||||
assert (
|
||||
statement.format()
|
||||
== "SELECT\n *\nFROM table1\nUNION ALL\nSELECT\n *\nFROM table2"
|
||||
)
|
||||
assert str(statement) == statement.format()
|
||||
|
||||
assert statement.tables == {
|
||||
Table(table="table1", schema=None, catalog=None),
|
||||
Table(table="table2", schema=None, catalog=None),
|
||||
}
|
||||
|
||||
assert statement.parse_predicate("a > 1") == exp.GT(
|
||||
this=exp.Column(this=exp.Identifier(this="a", quoted=False)),
|
||||
expression=exp.Literal(this="1", is_string=False),
|
||||
)
|
||||
|
||||
statement = SQLStatement("SET a=1", "sqlite")
|
||||
assert statement.get_settings() == {"a": "1"}
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Either statement or ast must be provided",
|
||||
):
|
||||
SQLStatement()
|
||||
|
||||
|
||||
def test_kustokqlstatement() -> None:
    """
    Exercise the basic `KustoKQLStatement` API on a single KQL query.
    """
    query = "foo | take 100"
    kql_statement = KustoKQLStatement(query, "kustokql")

    # formatting returns the query unchanged, and `str()` agrees with `format()`
    assert kql_statement.format() == query
    assert str(kql_statement) == kql_statement.format()

    # no tables are reported for KQL
    assert kql_statement.tables == set()

    # optimizing does not alter the formatted output
    optimized = kql_statement.optimize()
    assert optimized.format() == query

    # predicates come back exactly as they were passed in
    predicate = "a > 1"
    assert kql_statement.parse_predicate(predicate) == predicate

    # an unknown engine is rejected
    with pytest.raises(SupersetParseError, match="Invalid engine: invalid-engine"):
        KustoKQLStatement(query, "invalid-engine")

    # a script with two statements cannot be wrapped in a single statement
    with pytest.raises(
        SupersetParseError,
        match="KustoKQLStatement should have exactly one statement",
    ):
        KustoKQLStatement("foo | take 1; bar | take 2", "kustokql")
|
||||
|
||||
|
||||
def test_kustokqlstatement_split_script() -> None:
|
||||
"""
|
||||
@@ -887,11 +946,13 @@ def test_kustokql_statement_split_special(kql: str, statements: int) -> None:
|
||||
assert len(KustoKQLStatement.split_script(kql, "kustokql")) == statements
|
||||
|
||||
|
||||
def test_split_kql() -> None:
|
||||
"""
|
||||
Test the `split_kql` function.
|
||||
"""
|
||||
kql = """
|
||||
@pytest.mark.parametrize(
|
||||
"kql, expected",
|
||||
[
|
||||
(";Table | take 5", ["Table | take 5"]),
|
||||
(";Table | take 5;", ["Table | take 5"]),
|
||||
(
|
||||
"""
|
||||
let totalPagesPerDay = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)
|
||||
| summarize count() by Day;
|
||||
@@ -912,18 +973,18 @@ on Page
|
||||
totalPagesPerDay
|
||||
on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
"""
|
||||
assert split_kql(kql) == [
|
||||
"""
|
||||
""",
|
||||
[
|
||||
"""
|
||||
let totalPagesPerDay = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)
|
||||
| summarize count() by Day""",
|
||||
"""
|
||||
"""
|
||||
let materializedScope = PageViews
|
||||
| summarize by Page, Day = startofday(Timestamp)""",
|
||||
"""
|
||||
"""
|
||||
let cachedResult = materialize(materializedScope)""",
|
||||
"""
|
||||
"""
|
||||
cachedResult
|
||||
| project Page, Day1 = Day
|
||||
| join kind = inner
|
||||
@@ -938,8 +999,16 @@ on Page
|
||||
totalPagesPerDay
|
||||
on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
""",
|
||||
]
|
||||
""",
|
||||
],
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_split_kql(kql: str, expected: list[str]) -> None:
|
||||
"""
|
||||
Test the `split_kql` function.
|
||||
"""
|
||||
assert split_kql(kql) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
@@ -1106,14 +1175,19 @@ def test_custom_dialect(app: None) -> None:
|
||||
"vertica",
|
||||
],
|
||||
)
|
||||
def test_is_mutating(engine: str) -> None:
|
||||
@pytest.mark.parametrize(
|
||||
"sql, expected",
|
||||
[
|
||||
("SELECT 1", False),
|
||||
("with source as ( select 1 as one ) select * from source", False),
|
||||
("ALTER TABLE foo ADD COLUMN bar INT", True),
|
||||
],
|
||||
)
|
||||
def test_is_mutating(sql: str, engine: str, expected: bool) -> None:
|
||||
"""
|
||||
Global tests for `is_mutating`, covering all supported engines.
|
||||
"""
|
||||
assert not SQLStatement(
|
||||
"with source as ( select 1 as one ) select * from source",
|
||||
engine=engine,
|
||||
).is_mutating()
|
||||
assert SQLStatement(sql, engine).is_mutating() == expected
|
||||
|
||||
|
||||
def test_optimize() -> None:
|
||||
@@ -1164,6 +1238,9 @@ WHERE
|
||||
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
|
||||
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
|
||||
|
||||
# also works for scripts
|
||||
assert SQLScript(sql, "sqlite").optimize().format() == optimized
|
||||
|
||||
|
||||
def test_firebolt() -> None:
|
||||
"""
|
||||
@@ -1285,6 +1362,8 @@ LATERAL generate_series(1, value) AS i;
|
||||
"postgresql",
|
||||
None,
|
||||
),
|
||||
# not really valid SQL, but let's roll with it
|
||||
("SELECT * FROM my_table LIMIT invalid", "postgresql", None),
|
||||
],
|
||||
)
|
||||
def test_get_limit_value(sql: str, engine: str, expected: str) -> None:
|
||||
@@ -1307,6 +1386,7 @@ def test_get_limit_value(sql: str, engine: str, expected: str) -> None:
|
||||
""",
|
||||
5,
|
||||
),
|
||||
("table | take five", None),
|
||||
],
|
||||
)
|
||||
def test_get_kql_limit_value(kql: str, expected: str) -> None:
|
||||
@@ -1492,6 +1572,13 @@ LIMIT 1000
|
||||
LimitMethod.FORCE_LIMIT,
|
||||
"SELECT\n *\nFROM birth_names\nLIMIT 1000",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM birth_names LIMIT 555",
|
||||
"postgresql",
|
||||
1000,
|
||||
LimitMethod.FETCH_MANY,
|
||||
"SELECT\n *\nFROM birth_names\nLIMIT 555",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_set_limit_value(
|
||||
@@ -1539,11 +1626,28 @@ def test_set_limit_value(
|
||||
],
|
||||
)
|
||||
def test_set_kql_limit_value(kql: str, limit: int, expected: str) -> None:
|
||||
"""
|
||||
Test the `set_limit_value` method for KustoKQLStatement.
|
||||
"""
|
||||
statement = KustoKQLStatement(kql, "kustokql")
|
||||
statement.set_limit_value(limit)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize("method", [LimitMethod.WRAP_SQL, LimitMethod.FETCH_MANY])
def test_set_kql_limit_value_invalid_method(method: LimitMethod) -> None:
    """
    Test that setting a limit value with an invalid method raises an error.

    Each parametrized `method` is passed to `set_limit_value` on a KQL statement,
    and the call is expected to raise `SupersetParseError`.
    """
    statement = KustoKQLStatement("foo", "kustokql")

    # `match` is a regular expression, so the trailing period is escaped in a raw
    # string; unescaped it would match any character instead of a literal ".".
    with pytest.raises(
        SupersetParseError,
        match=r"Kusto KQL only supports the FORCE_LIMIT method\.",
    ):
        statement.set_limit_value(10, method)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, engine, expected",
|
||||
[
|
||||
@@ -1670,6 +1774,15 @@ FROM (
|
||||
) AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM some_table AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
@@ -1947,6 +2060,17 @@ def test_rls_subquery_transformer(
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
def test_rls_invalid_method(mocker: MockerFixture) -> None:
    """
    Check that `apply_rls` raises `ValueError` when given an unknown method.
    """
    sql_statement = SQLStatement("SELECT 1", "postgresql")
    fake_predicates = mocker.MagicMock()
    bogus_method = "invalid"

    with pytest.raises(ValueError, match="Invalid RLS method: invalid"):
        sql_statement.apply_rls(
            "catalog1",
            "schema1",
            fake_predicates,
            bogus_method,  # type: ignore
        )
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
@@ -2171,6 +2295,17 @@ JOIN other_table
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
JOIN other_table
|
||||
ON other_table.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"""
|
||||
SELECT *
|
||||
@@ -2237,6 +2372,18 @@ WHERE
|
||||
other_table.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"INSERT INTO some_table (col1, col2) VALUES (1, 2)",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
INSERT INTO some_table (
|
||||
col1,
|
||||
col2
|
||||
)
|
||||
VALUES
|
||||
(1, 2)
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_transformer(
|
||||
@@ -2427,7 +2574,7 @@ def test_sanitize_clause(sql: str, expected: str | Exception, engine: str) -> No
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize(
|
||||
"macro,expected",
|
||||
"macro, expected",
|
||||
[
|
||||
(
|
||||
"latest_partition('foo.bar')",
|
||||
@@ -2464,7 +2611,7 @@ def test_extract_tables_from_jinja_sql(
|
||||
assert (
|
||||
extract_tables_from_jinja_sql(
|
||||
sql=f"'{{{{ {engine}.{macro} }}}}'",
|
||||
database=mocker.Mock(),
|
||||
database=mocker.MagicMock(backend=engine),
|
||||
)
|
||||
== expected
|
||||
)
|
||||
@@ -2475,10 +2622,154 @@ def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test the function when the feature flag is disabled.
|
||||
"""
|
||||
database = mocker.Mock()
|
||||
database = mocker.MagicMock()
|
||||
database.db_engine_spec.engine = "mssql"
|
||||
|
||||
assert extract_tables_from_jinja_sql(
|
||||
sql="SELECT 1 FROM t",
|
||||
database=database,
|
||||
) == {Table("t")}
|
||||
|
||||
|
||||
def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None:
    """
    Test table extraction when the template calls a custom Jinja global.
    """

    def my_table() -> str:
        return "t"

    database = mocker.MagicMock(backend="postgresql")

    # register the custom global on a real processor and make the factory return it
    template_processor = JinjaTemplateProcessor(database)
    template_processor.env.globals["my_table"] = my_table
    mocker.patch(
        "superset.jinja_context.get_template_processor",
        return_value=template_processor,
    )

    tables = extract_tables_from_jinja_sql(
        sql="SELECT * FROM {{ my_table() }}",
        database=database,
    )
    assert tables == {Table("t")}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("SELECT * FROM users", "postgresql", True),
        ("WITH cte AS (SELECT * FROM users) SELECT * FROM cte", "postgresql", True),
        ("CREATE TABLE users AS SELECT * FROM users", "postgresql", False),
        ("ALTER TABLE users ADD COLUMN age INT", "postgresql", False),
        ("SET @value = 42", "postgresql", False),
    ],
)
def test_sqlstatement_is_select(sql: str, engine: str, expected: bool) -> None:
    """
    Compare `SQLStatement.is_select()` with the expected flag for each statement.
    """
    statement = SQLStatement(sql, engine)
    assert statement.is_select() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("kql", "expected"),
    [
        ("StormEvents | take 10", True),
        ("StormEvents | limit 20", True),
        ("StormEvents | where State == 'FL' | summarize count()", True),
        ("StormEvents | where name has 'limit 10'", True),
        ("AnotherTable | take 5", True),
        ("datatable(x:int) [1, 2, 3] | take 100", True),
        (".create table StormEvents (x:int)", False),
        (".ingest inline into table StormEvents <| StormEvents | take 10", False),
    ],
)
def test_kqlstatement_is_select(kql: str, expected: bool) -> None:
    """
    Compare `KustoKQLStatement.is_select()` with the expected flag for each query.
    """
    statement = KustoKQLStatement(kql, "kustokql")
    assert statement.is_select() == expected
|
||||
|
||||
|
||||
def test_remove_quotes() -> None:
    """
    Check `remove_quotes` on `None`, matched quote pairs, and a mismatched pair.
    """
    # `None` is passed through as-is
    assert remove_quotes(None) is None

    # (input, expected output); the last entry has mismatched delimiters
    cases = [
        ('"foo"', "foo"),
        ("'foo'", "foo"),
        ("`foo`", "foo"),
        ("'foo`", "'foo`"),
    ]
    for quoted, expected in cases:
        assert remove_quotes(quoted) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT VERSION()", "postgresql", True),
        ("SELECT query_to_xml()", "postgresql", True),
        ("WITH cte AS (SELECT * FROM table) SELECT * FROM cte", "postgresql", False),
        (
            """
SELECT *
FROM query_to_xml('SELECT * from some_table WHERE id = 42')
""",
            "postgresql",
            True,
        ),
        ("Table | limit 10", "kustokql", False),
    ],
)
def test_check_functions_present(sql: str, engine: str, expected: bool) -> None:
    """
    Check `SQLScript.check_functions_present` against a fixed set of function names.
    """
    function_names = {"version", "query_to_xml"}
    script = SQLScript(sql, engine)
    assert script.check_functions_present(function_names) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("kql", "expected"),
    [
        (
            "StormEvents | take 10",
            [
                (KQLTokenType.WORD, "StormEvents"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.OTHER, "|"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.WORD, "take"),
                (KQLTokenType.WHITESPACE, " "),
                (KQLTokenType.NUMBER, "10"),
            ],
        ),
        ("'test'", [(KQLTokenType.STRING, "'test'")]),
        ("```test```", [(KQLTokenType.STRING, "```test```")]),
    ],
)
def test_tokenize_kql(kql: str, expected: list[tuple[KQLTokenType, str]]) -> None:
    """
    Check that `tokenize_kql` yields the expected (token type, text) pairs.
    """
    tokens = tokenize_kql(kql)
    assert tokens == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("sql", "engine", "expected"),
    [
        ("a = 1", "postgresql", False),
        ("(SELECT * FROM table)", "postgresql", True),
        ("SELECT * FROM table", "postgresql", False),
        ("SELECT * FROM (SELECT 1)", "postgresql", True),
        ("SELECT * FROM (SELECT 1) AS subquery", "postgresql", True),
        ("WITH cte AS (SELECT 1) SELECT * FROM cte", "postgresql", True),
        ("SELECT * FROM table WHERE EXISTS (SELECT 1)", "postgresql", True),
        ("SELECT * FROM table WHERE NOT EXISTS (SELECT 1)", "postgresql", True),
        (
            "SELECT * FROM table WHERE id IN (SELECT id FROM other_table)",
            "postgresql",
            True,
        ),
    ],
)
def test_has_subquery(sql: str, engine: str, expected: bool) -> None:
    """
    Compare `SQLStatement.has_subquery()` with the expected flag for each input.
    """
    statement = SQLStatement(sql, engine)
    assert statement.has_subquery() == expected
|
||||
|
||||
Reference in New Issue
Block a user