mirror of
https://github.com/apache/superset.git
synced 2026-04-30 21:44:40 +00:00
WIP
This commit is contained in:
@@ -20,6 +20,7 @@ from typing import Optional
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
import sqlparse
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import text
|
||||
@@ -41,6 +42,8 @@ from superset.sql_parse import (
|
||||
insert_rls_in_predicate,
|
||||
KustoKQLStatement,
|
||||
ParsedQuery,
|
||||
RLSAsPredicate,
|
||||
RLSAsSubquery,
|
||||
sanitize_clause,
|
||||
split_kql,
|
||||
SQLScript,
|
||||
@@ -119,8 +122,9 @@ def test_extract_tables_subselect() -> None:
|
||||
"""
|
||||
Test that tables inside subselects are parsed correctly.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT sub.*
|
||||
FROM (
|
||||
SELECT *
|
||||
@@ -129,10 +133,13 @@ FROM (
|
||||
) sub, s2.t2
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
) == {Table("t1", "s1"), Table("t2", "s2")}
|
||||
)
|
||||
== {Table("t1", "s1"), Table("t2", "s2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT sub.*
|
||||
FROM (
|
||||
SELECT *
|
||||
@@ -141,10 +148,13 @@ FROM (
|
||||
) sub
|
||||
WHERE sub.resolution = 'NONE'
|
||||
"""
|
||||
) == {Table("t1", "s1")}
|
||||
)
|
||||
== {Table("t1", "s1")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT * FROM t1
|
||||
WHERE s11 > ANY (
|
||||
SELECT COUNT(*) /* no hint */ FROM t2
|
||||
@@ -156,7 +166,9 @@ WHERE s11 > ANY (
|
||||
)
|
||||
)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
||||
)
|
||||
== {Table("t1"), Table("t2"), Table("t3"), Table("t4")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_select_in_expression() -> None:
|
||||
@@ -227,24 +239,30 @@ def test_extract_tables_select_array() -> None:
|
||||
"""
|
||||
Test that queries selecting arrays work as expected.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT ARRAY[1, 2, 3] AS my_array
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
) == {Table("t1")}
|
||||
)
|
||||
== {Table("t1")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_select_if() -> None:
|
||||
"""
|
||||
Test that queries with an ``IF`` work as expected.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT IF(CARDINALITY(my_array) >= 3, my_array[3], NULL)
|
||||
FROM t1 LIMIT 10
|
||||
"""
|
||||
) == {Table("t1")}
|
||||
)
|
||||
== {Table("t1")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_with_catalog() -> None:
|
||||
@@ -312,29 +330,38 @@ def test_extract_tables_where_subquery() -> None:
|
||||
"""
|
||||
Test that tables in a ``WHERE`` subquery are parsed correctly.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey = (SELECT max(regionkey) FROM t2)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE regionkey IN (SELECT regionkey FROM t2)
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT name
|
||||
FROM t1
|
||||
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
|
||||
"""
|
||||
) == {Table("t1"), Table("t2")}
|
||||
)
|
||||
== {Table("t1"), Table("t2")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_describe() -> None:
|
||||
@@ -348,12 +375,15 @@ def test_extract_tables_show_partitions() -> None:
|
||||
"""
|
||||
Test ``SHOW PARTITIONS``.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SHOW PARTITIONS FROM orders
|
||||
WHERE ds >= '2013-01-01' ORDER BY ds DESC
|
||||
"""
|
||||
) == {Table("orders")}
|
||||
)
|
||||
== {Table("orders")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_join() -> None:
|
||||
@@ -365,8 +395,9 @@ def test_extract_tables_join() -> None:
|
||||
Table("t2"),
|
||||
}
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
JOIN (
|
||||
@@ -377,10 +408,13 @@ JOIN (
|
||||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
) == {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
== {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
LEFT INNER JOIN (
|
||||
@@ -391,10 +425,13 @@ LEFT INNER JOIN (
|
||||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
) == {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
== {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
RIGHT OUTER JOIN (
|
||||
@@ -405,10 +442,13 @@ RIGHT OUTER JOIN (
|
||||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
) == {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
== {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
FULL OUTER JOIN (
|
||||
@@ -419,15 +459,18 @@ FULL OUTER JOIN (
|
||||
) b
|
||||
ON a.date = b.date
|
||||
"""
|
||||
) == {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
== {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_semi_join() -> None:
|
||||
"""
|
||||
Test ``LEFT SEMI JOIN``.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT a.date, b.name
|
||||
FROM left_table a
|
||||
LEFT SEMI JOIN (
|
||||
@@ -438,15 +481,18 @@ LEFT SEMI JOIN (
|
||||
) b
|
||||
ON a.data = b.date
|
||||
"""
|
||||
) == {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
== {Table("left_table"), Table("right_table")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_combinations() -> None:
|
||||
"""
|
||||
Test a complex case with nested queries.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT * FROM t1
|
||||
WHERE s11 > ANY (
|
||||
SELECT * FROM t1 UNION ALL SELECT * FROM (
|
||||
@@ -460,10 +506,13 @@ WHERE s11 > ANY (
|
||||
)
|
||||
)
|
||||
"""
|
||||
) == {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
|
||||
)
|
||||
== {Table("t1"), Table("t3"), Table("t4"), Table("t6")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT * FROM (
|
||||
SELECT * FROM (
|
||||
SELECT * FROM (
|
||||
@@ -472,45 +521,56 @@ SELECT * FROM (
|
||||
) AS S2
|
||||
) AS S3
|
||||
"""
|
||||
) == {Table("EmployeeS")}
|
||||
)
|
||||
== {Table("EmployeeS")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_with() -> None:
|
||||
"""
|
||||
Test ``WITH``.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
WITH
|
||||
x AS (SELECT a FROM t1),
|
||||
y AS (SELECT a AS b FROM t2),
|
||||
z AS (SELECT b AS c FROM t3)
|
||||
SELECT c FROM z
|
||||
"""
|
||||
) == {Table("t1"), Table("t2"), Table("t3")}
|
||||
)
|
||||
== {Table("t1"), Table("t2"), Table("t3")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
WITH
|
||||
x AS (SELECT a FROM t1),
|
||||
y AS (SELECT a AS b FROM x),
|
||||
z AS (SELECT b AS c FROM y)
|
||||
SELECT c FROM z
|
||||
"""
|
||||
) == {Table("t1")}
|
||||
)
|
||||
== {Table("t1")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_reusing_aliases() -> None:
|
||||
"""
|
||||
Test that the parser follows aliases.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
with q1 as ( select key from q2 where key = '5'),
|
||||
q2 as ( select key from src where key = '5')
|
||||
select * from (select key from q1) a
|
||||
"""
|
||||
) == {Table("src")}
|
||||
)
|
||||
== {Table("src")}
|
||||
)
|
||||
|
||||
# weird query with circular dependency
|
||||
assert (
|
||||
@@ -547,8 +607,9 @@ def test_extract_tables_complex() -> None:
|
||||
"""
|
||||
Test a few complex queries.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT sum(m_examples) AS "sum__m_example"
|
||||
FROM (
|
||||
SELECT
|
||||
@@ -569,23 +630,29 @@ FROM (
|
||||
ORDER BY "sum__m_example" DESC
|
||||
LIMIT 10;
|
||||
"""
|
||||
) == {
|
||||
Table("my_l_table"),
|
||||
Table("my_b_table"),
|
||||
Table("my_t_table"),
|
||||
Table("inner_table"),
|
||||
}
|
||||
)
|
||||
== {
|
||||
Table("my_l_table"),
|
||||
Table("my_b_table"),
|
||||
Table("my_t_table"),
|
||||
Table("inner_table"),
|
||||
}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT *
|
||||
FROM table_a AS a, table_b AS b, table_c as c
|
||||
WHERE a.id = b.id and b.id = c.id
|
||||
"""
|
||||
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||
)
|
||||
== {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT somecol AS somecol
|
||||
FROM (
|
||||
WITH bla AS (
|
||||
@@ -629,51 +696,63 @@ FROM (
|
||||
LIMIT 50000
|
||||
)
|
||||
"""
|
||||
) == {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
||||
)
|
||||
== {Table("a"), Table("b"), Table("c"), Table("d"), Table("e"), Table("f")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_mixed_from_clause() -> None:
|
||||
"""
|
||||
Test that the parser handles a ``FROM`` clause with table and subselect.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
SELECT *
|
||||
FROM table_a AS a, (select * from table_b) AS b, table_c as c
|
||||
WHERE a.id = b.id and b.id = c.id
|
||||
"""
|
||||
) == {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||
)
|
||||
== {Table("table_a"), Table("table_b"), Table("table_c")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_nested_select() -> None:
|
||||
"""
|
||||
Test that the parser handles selects inside functions.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
|
||||
""",
|
||||
"mysql",
|
||||
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
"mysql",
|
||||
)
|
||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
)
|
||||
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
|
||||
from INFORMATION_SCHEMA.COLUMNS
|
||||
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
|
||||
""",
|
||||
"mysql",
|
||||
) == {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
"mysql",
|
||||
)
|
||||
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_complex_cte_with_prefix() -> None:
|
||||
"""
|
||||
Test that the parser handles CTEs with prefixes.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
WITH CTE__test (SalesPersonID, SalesOrderID, SalesYear)
|
||||
AS (
|
||||
SELECT SalesPersonID, SalesOrderID, YEAR(OrderDate) AS SalesYear
|
||||
@@ -685,21 +764,26 @@ FROM CTE__test
|
||||
GROUP BY SalesYear, SalesPersonID
|
||||
ORDER BY SalesPersonID, SalesYear;
|
||||
"""
|
||||
) == {Table("SalesOrderHeader")}
|
||||
)
|
||||
== {Table("SalesOrderHeader")}
|
||||
)
|
||||
|
||||
|
||||
def test_extract_tables_identifier_list_with_keyword_as_alias() -> None:
|
||||
"""
|
||||
Test that aliases that are keywords are parsed correctly.
|
||||
"""
|
||||
assert extract_tables(
|
||||
"""
|
||||
assert (
|
||||
extract_tables(
|
||||
"""
|
||||
WITH
|
||||
f AS (SELECT * FROM foo),
|
||||
match AS (SELECT * FROM f)
|
||||
SELECT * FROM match
|
||||
"""
|
||||
) == {Table("foo")}
|
||||
)
|
||||
== {Table("foo")}
|
||||
)
|
||||
|
||||
|
||||
def test_update() -> None:
|
||||
@@ -1841,7 +1925,7 @@ def test_sqlquery() -> None:
|
||||
script = SQLScript("SELECT 1; SELECT 2;", "sqlite")
|
||||
|
||||
assert len(script.statements) == 2
|
||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2"
|
||||
assert script.format() == "SELECT\n 1;\nSELECT\n 2;"
|
||||
assert script.statements[0].format() == "SELECT\n 1"
|
||||
|
||||
script = SQLScript("SET a=1; SET a=2; SELECT 3;", "sqlite")
|
||||
@@ -2058,3 +2142,120 @@ on $left.Day1 == $right.Day
|
||||
| project Day1, Day2, Percentage = count_*100.0/count_1
|
||||
""",
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,rules,expected",
|
||||
[
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{Table("some_table"): "id = 42"},
|
||||
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||
{Table("some_table"): "id = 42"},
|
||||
(
|
||||
"SELECT t.foo FROM (SELECT * FROM some_table WHERE id = 42) AS t "
|
||||
"WHERE bar = 'baz'"
|
||||
),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM schema1.some_table AS t",
|
||||
{Table("some_table", "schema1"): "id = 42"},
|
||||
"SELECT t.foo FROM (SELECT * FROM schema1.some_table WHERE id = 42) AS t",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM schema1.some_table AS t",
|
||||
{Table("some_table", "schema2"): "id = 42"},
|
||||
"SELECT t.foo FROM schema1.some_table AS t",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"SELECT t.foo FROM (SELECT * FROM catalog1.schema1.some_table WHERE id = 42) AS t",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog2"): "id = 42"},
|
||||
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_RLSAsSubquery(sql: str, rules: dict[Table, str], expected: str) -> None:
|
||||
"""
|
||||
Test the `RLSAsSubquery` transformer.
|
||||
"""
|
||||
statement = sqlglot.parse_one(sql)
|
||||
transformer = RLSAsSubquery(rules)
|
||||
assert str(statement.transform(transformer)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,rules,expected",
|
||||
[
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{Table("some_table"): "id = 42"},
|
||||
"SELECT t.foo FROM some_table AS t WHERE id = 42",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||
{Table("some_table"): "id = 42"},
|
||||
"SELECT t.foo FROM some_table AS t WHERE id = 42 AND bar = 'baz'",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_RLSAsPredicate(sql: str, rules: dict[Table, str], expected: str) -> None:
|
||||
"""
|
||||
Test the `RLSAsPredicate` transformer.
|
||||
"""
|
||||
statement = sqlglot.parse_one(sql)
|
||||
transformer = RLSAsPredicate(rules)
|
||||
assert str(statement.transform(transformer)) == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql,engine,limit,force,expected",
|
||||
[
|
||||
(
|
||||
"SELECT TOP 10 * FROM Customers",
|
||||
"teradatasql",
|
||||
5,
|
||||
False,
|
||||
"SELECT\nTOP 5\n *\nFROM Customers",
|
||||
),
|
||||
(
|
||||
"SELECT TOP 10 * FROM Customers",
|
||||
"teradatasql",
|
||||
15,
|
||||
False,
|
||||
"SELECT\nTOP 10\n *\nFROM Customers",
|
||||
),
|
||||
(
|
||||
"SELECT TOP 10 * FROM Customers",
|
||||
"teradatasql",
|
||||
15,
|
||||
True,
|
||||
"SELECT\nTOP 15\n *\nFROM Customers",
|
||||
),
|
||||
(
|
||||
"SELECT TOP 10 * FROM Customers",
|
||||
"mssql",
|
||||
15,
|
||||
True,
|
||||
"SELECT\nTOP 15\n *\nFROM Customers",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_apply_limit(
|
||||
sql: str,
|
||||
engine: str,
|
||||
limit: int,
|
||||
force: bool,
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""
|
||||
Test the `apply_limit` function.
|
||||
"""
|
||||
assert SQLStatement(sql, engine).apply_limit(limit, force).format() == expected
|
||||
|
||||
Reference in New Issue
Block a user