feat(sqlparse): improve table parsing (#26476)

This commit is contained in:
Beto Dealmeida
2024-01-22 11:16:50 -05:00
committed by GitHub
parent d34874cf2b
commit c0b57bd1c3
17 changed files with 265 additions and 120 deletions

View File

@@ -40,11 +40,11 @@ from superset.sql_parse import (
)
def extract_tables(query: str) -> set[Table]:
def extract_tables(query: str, engine: Optional[str] = None) -> set[Table]:
"""
Helper function to extract tables referenced in a query.
"""
return ParsedQuery(query).tables
return ParsedQuery(query, engine=engine).tables
def test_table() -> None:
@@ -96,8 +96,13 @@ def test_extract_tables() -> None:
Table("left_table")
}
# reverse select
assert extract_tables("FROM t1 SELECT field") == {Table("t1")}
assert extract_tables(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;"
) == {Table("forbidden_table")}
assert extract_tables(
"select * from (select * from forbidden_table) forbidden_table"
) == {Table("forbidden_table")}
def test_extract_tables_subselect() -> None:
@@ -263,14 +268,16 @@ def test_extract_tables_illdefined() -> None:
assert extract_tables("SELECT * FROM schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname.schemaname.") == set()
assert extract_tables("SELECT * FROM catalogname..") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == set()
assert extract_tables("SELECT * FROM catalogname..tbname") == {
Table(table="tbname", schema=None, catalog="catalogname")
}
def test_extract_tables_show_tables_from() -> None:
"""
Test ``SHOW TABLES FROM``.
"""
assert extract_tables("SHOW TABLES FROM s1 like '%order%'") == set()
assert extract_tables("SHOW TABLES FROM s1 like '%order%'", "mysql") == set()
def test_extract_tables_show_columns_from() -> None:
@@ -311,7 +318,7 @@ WHERE regionkey IN (SELECT regionkey FROM t2)
"""
SELECT name
FROM t1
WHERE regionkey EXISTS (SELECT regionkey FROM t2)
WHERE EXISTS (SELECT 1 FROM t2 WHERE t1.regionkey = t2.regionkey);
"""
)
== {Table("t1"), Table("t2")}
@@ -526,6 +533,18 @@ select * from (select key from q1) a
== {Table("src")}
)
# weird query with circular dependency
assert (
extract_tables(
"""
with src as ( select key from q2 where key = '5'),
q2 as ( select key from src where key = '5')
select * from (select key from src) a
"""
)
== set()
)
def test_extract_tables_multistatement() -> None:
"""
@@ -665,7 +684,8 @@ def test_extract_tables_nested_select() -> None:
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(TABLE_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -676,7 +696,8 @@ WHERE TABLE_SCHEMA like "%bi%"),0x7e)));
select (extractvalue(1,concat(0x7e,(select GROUP_CONCAT(COLUMN_NAME)
from INFORMATION_SCHEMA.COLUMNS
WHERE TABLE_NAME="bi_achievement_daily"),0x7e)));
"""
""",
"mysql",
)
== {Table("COLUMNS", "INFORMATION_SCHEMA")}
)
@@ -1306,6 +1327,14 @@ def test_sqlparse_issue_652():
"(SELECT table_name FROM /**/ information_schema.tables WHERE table_name LIKE '%user%' LIMIT 1)",
True,
),
(
"SELECT FROM (SELECT FROM forbidden_table) AS forbidden_table;",
True,
),
(
"SELECT * FROM (SELECT * FROM forbidden_table) forbidden_table",
True,
),
],
)
def test_has_table_query(sql: str, expected: bool) -> None:
@@ -1790,13 +1819,17 @@ def test_extract_table_references(mocker: MockerFixture) -> None:
assert extract_table_references(
sql,
"trino",
) == {Table(table="other_table", schema=None, catalog=None)}
) == {
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_called_once()
logger = mocker.patch("superset.migrations.shared.utils.logger")
sql = "SELECT * FROM table UNION ALL SELECT * FROM other_table"
assert extract_table_references(sql, "trino", show_warning=False) == {
Table(table="other_table", schema=None, catalog=None)
Table(table="table", schema=None, catalog=None),
Table(table="other_table", schema=None, catalog=None),
}
logger.warning.assert_not_called()