feat(sqllab): use sqlglot instead of sqlparse (#33542)

This commit is contained in:
Beto Dealmeida
2025-05-30 17:08:19 -04:00
committed by GitHub
parent f219dc1794
commit cf315388f2
10 changed files with 574 additions and 552 deletions

View File

@@ -36,7 +36,7 @@ from sqlalchemy.dialects.mysql import dialect
from tests.integration_tests.constants import ADMIN_USERNAME
from tests.integration_tests.test_app import app, login
from superset.sql_parse import CtasMethod
from superset.sql.parse import CTASMethod
from superset import db, security_manager
from superset.connectors.sqla.models import BaseDatasource, SqlaTable
from superset.models import core as models
@@ -387,7 +387,7 @@ class SupersetTestCase(TestCase):
select_as_cta=False,
tmp_table_name=None,
schema=None,
ctas_method=CtasMethod.TABLE,
ctas_method=CTASMethod.TABLE,
template_params="{}",
):
if username:
@@ -400,7 +400,7 @@ class SupersetTestCase(TestCase):
"client_id": client_id,
"queryLimit": query_limit,
"sql_editor_id": sql_editor_id,
"ctas_method": ctas_method,
"ctas_method": ctas_method.name,
"templateParams": template_params,
}
if tmp_table_name:

View File

@@ -39,7 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec
from superset.errors import ErrorLevel, SupersetErrorType
from superset.extensions import celery_app
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, CtasMethod
from superset.sql.parse import CTASMethod
from superset.utils.core import backend
from superset.utils.database import get_example_database
from tests.integration_tests.conftest import CTAS_SCHEMA_NAME
@@ -76,13 +76,19 @@ def setup_sqllab():
db.session.query(Query).delete()
db.session.commit()
for tbl in TMP_TABLES:
drop_table_if_exists(f"{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE)
drop_table_if_exists(f"{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW)
drop_table_if_exists(
f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.TABLE.lower()}", CtasMethod.TABLE
f"{tbl}_{CTASMethod.TABLE.name.lower()}", CTASMethod.TABLE
)
drop_table_if_exists(
f"{CTAS_SCHEMA_NAME}.{tbl}_{CtasMethod.VIEW.lower()}", CtasMethod.VIEW
f"{tbl}_{CTASMethod.VIEW.name.lower()}", CTASMethod.VIEW
)
drop_table_if_exists(
f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.TABLE.name.lower()}",
CTASMethod.TABLE,
)
drop_table_if_exists(
f"{CTAS_SCHEMA_NAME}.{tbl}_{CTASMethod.VIEW.name.lower()}",
CTASMethod.VIEW,
)
@@ -90,7 +96,7 @@ def run_sql(
test_client,
sql,
cta=False,
ctas_method=CtasMethod.TABLE,
ctas_method=CTASMethod.TABLE,
tmp_table="tmp",
async_=False,
):
@@ -104,14 +110,14 @@ def run_sql(
select_as_cta=cta,
tmp_table_name=tmp_table,
client_id="".join(random.choice(string.ascii_lowercase) for i in range(5)), # noqa: S311
ctas_method=ctas_method,
ctas_method=ctas_method.name,
),
).json
# NOTE(review): this text is a rendered diff with the +/- markers and the
# indentation stripped, so each changed line shows up twice: first the
# pre-commit form (CtasMethod, interpolated directly into the SQL), then its
# replacement (CTASMethod, interpolated through ``.name``). Confirm against
# the real file before treating either line as current.
#
# Builds a ``DROP <type> IF EXISTS <name>`` statement and executes it on the
# engine of the database returned by get_example_database().
def drop_table_if_exists(table_name: str, table_type: CtasMethod) -> None:
def drop_table_if_exists(table_name: str, table_type: CTASMethod) -> None:
"""Drop table if it exists, works on any DB"""
sql = f"DROP {table_type} IF EXISTS {table_name}"
sql = f"DROP {table_type.name} IF EXISTS {table_name}"
database = get_example_database()
with database.get_sqla_engine() as engine:
engine.execute(sql)
@@ -124,10 +130,10 @@ def quote_f(value: Optional[str]):
return inspector.engine.dialect.identifier_preparer.quote_identifier(value)
# NOTE(review): rendered diff, markers and indentation stripped; the ``def``
# line and the ``== ...TABLE`` comparison each appear in their old
# (CtasMethod) and new (CTASMethod) form.
#
# Returns the (data, columns) pair a CTAS/CVAS run is expected to report:
# two empty lists unless backend() is "presto"; on presto, a one-row "rows"
# BIGINT payload for TABLE and a "result" BOOLEAN payload otherwise.
def cta_result(ctas_method: CtasMethod):
def cta_result(ctas_method: CTASMethod):
if backend() != "presto":
return [], []
if ctas_method == CtasMethod.TABLE:
if ctas_method == CTASMethod.TABLE:
return [{"rows": 1}], [{"name": "rows", "type": "BIGINT", "is_dttm": False}]
return [{"result": True}], [{"name": "result", "type": "BOOLEAN", "is_dttm": False}]
@@ -143,13 +149,13 @@ def get_select_star(table: str, limit: int, schema: Optional[str] = None):
@pytest.mark.usefixtures("login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW])
def test_run_sync_query_dont_exist(test_client, ctas_method):
examples_db = get_example_database()
engine_name = examples_db.db_engine_spec.engine_name
sql_dont_exist = "SELECT name FROM table_dont_exist"
result = run_sql(test_client, sql_dont_exist, cta=True, ctas_method=ctas_method)
if backend() == "sqlite" and ctas_method == CtasMethod.VIEW:
if backend() == "sqlite" and ctas_method == CTASMethod.VIEW:
assert QueryStatus.SUCCESS == result["status"], result
elif backend() == "presto":
assert (
@@ -188,9 +194,9 @@ def test_run_sync_query_dont_exist(test_client, ctas_method):
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_sync_query_cta(test_client, ctas_method):
tmp_table_name = f"{TEST_SYNC}_{ctas_method.lower()}"
@pytest.mark.parametrize("ctas_method", [CTASMethod.TABLE, CTASMethod.VIEW])
def test_run_sync_query_cta(test_client, ctas_method: CTASMethod) -> None:
tmp_table_name = f"{TEST_SYNC}_{ctas_method.name.lower()}"
result = run_sql(
test_client, QUERY, tmp_table=tmp_table_name, cta=True, ctas_method=ctas_method
)
@@ -218,16 +224,44 @@ def test_run_sync_query_cta_no_data(test_client):
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@pytest.mark.parametrize(
"ctas_method, expected",
[
(
CTASMethod.TABLE,
"""
CREATE TABLE sqllab_test_db.test_sync_cta_table AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
(
CTASMethod.VIEW,
"""
CREATE VIEW sqllab_test_db.test_sync_cta_view AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
],
)
@mock.patch( # noqa: PT008
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_sync_query_cta_config(test_client, ctas_method):
def test_run_sync_query_cta_config(
test_client,
ctas_method: CTASMethod,
expected: str,
) -> None:
if backend() == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.lower()}"
tmp_table_name = f"{TEST_SYNC_CTA}_{ctas_method.name.lower()}"
result = run_sql(
test_client, QUERY, cta=True, ctas_method=ctas_method, tmp_table=tmp_table_name
)
@@ -235,10 +269,7 @@ def test_run_sync_query_cta_config(test_client, ctas_method):
assert cta_result(ctas_method) == (result["data"], result["columns"])
query = get_query_by_id(result["query"]["serverId"])
assert (
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
== query.executed_sql
)
assert query.executed_sql == expected
assert query.select_sql == get_select_star(
tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME
)
@@ -249,16 +280,44 @@ def test_run_sync_query_cta_config(test_client, ctas_method):
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
@pytest.mark.parametrize(
"ctas_method, expected",
[
(
CTASMethod.TABLE,
"""
CREATE TABLE sqllab_test_db.test_async_cta_config_table AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
(
CTASMethod.VIEW,
"""
CREATE VIEW sqllab_test_db.test_async_cta_config_view AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
],
)
@mock.patch( # noqa: PT008
"superset.sqllab.sqllab_execution_context.get_cta_schema_name",
lambda d, u, s, sql: CTAS_SCHEMA_NAME,
)
def test_run_async_query_cta_config(test_client, ctas_method):
def test_run_async_query_cta_config(
test_client,
ctas_method: CTASMethod,
expected: str,
) -> None:
if backend() == "sqlite":
# sqlite doesn't support schemas
return
tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.lower()}"
tmp_table_name = f"{TEST_ASYNC_CTA_CONFIG}_{ctas_method.name.lower()}"
result = run_sql(
test_client,
QUERY,
@@ -275,18 +334,43 @@ def test_run_async_query_cta_config(test_client, ctas_method):
get_select_star(tmp_table_name, limit=query.limit, schema=CTAS_SCHEMA_NAME)
== query.select_sql
)
assert (
f"CREATE {ctas_method} {CTAS_SCHEMA_NAME}.{tmp_table_name} AS \n{QUERY}"
== query.executed_sql
)
assert query.executed_sql == expected
delete_tmp_view_or_table(f"{CTAS_SCHEMA_NAME}.{tmp_table_name}", ctas_method)
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query(test_client, ctas_method):
table_name = f"{TEST_ASYNC_CTA}_{ctas_method.lower()}"
@pytest.mark.parametrize(
"ctas_method, expected",
[
(
CTASMethod.TABLE,
"""
CREATE TABLE test_async_cta_table AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
(
CTASMethod.VIEW,
"""
CREATE VIEW test_async_cta_view AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
],
)
def test_run_async_cta_query(
test_client,
ctas_method: CTASMethod,
expected: str,
) -> None:
table_name = f"{TEST_ASYNC_CTA}_{ctas_method.name.lower()}"
result = run_sql(
test_client,
QUERY,
@@ -301,7 +385,7 @@ def test_run_async_cta_query(test_client, ctas_method):
assert QueryStatus.SUCCESS == query.status
assert get_select_star(table_name, query.limit) in query.select_sql
assert f"CREATE {ctas_method} {table_name} AS \n{QUERY}" == query.executed_sql
assert query.executed_sql == expected
assert QUERY == query.sql
assert query.rows == (1 if backend() == "presto" else 0)
assert query.select_as_cta
@@ -311,9 +395,37 @@ def test_run_async_cta_query(test_client, ctas_method):
@pytest.mark.usefixtures("load_birth_names_data", "login_as_admin")
@pytest.mark.parametrize("ctas_method", [CtasMethod.TABLE, CtasMethod.VIEW])
def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.lower()}"
@pytest.mark.parametrize(
"ctas_method, expected",
[
(
CTASMethod.TABLE,
"""
CREATE TABLE test_async_lower_limit_table AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
(
CTASMethod.VIEW,
"""
CREATE VIEW test_async_lower_limit_view AS
SELECT
name
FROM birth_names
LIMIT 1
""".strip(),
),
],
)
def test_run_async_cta_query_with_lower_limit(
test_client,
ctas_method: CTASMethod,
expected: str,
) -> None:
tmp_table = f"{TEST_ASYNC_LOWER_LIMIT}_{ctas_method.name.lower()}"
result = run_sql(
test_client,
QUERY,
@@ -332,7 +444,7 @@ def test_run_async_cta_query_with_lower_limit(test_client, ctas_method):
else get_select_star(tmp_table, query.limit)
)
assert f"CREATE {ctas_method} {tmp_table} AS \n{QUERY}" == query.executed_sql
assert query.executed_sql == expected
assert QUERY == query.sql
assert query.rows == (1 if backend() == "presto" else 0)
@@ -442,28 +554,6 @@ def test_msgpack_payload_serialization():
assert isinstance(serialized, bytes)
# NOTE(review): this whole function appears to be on the removed side of the
# diff (the enclosing hunk shrinks from 28 old lines to 6 new ones, and the
# ParsedQuery import no longer appears in the replacement import line) —
# confirm against the real file.
#
# Exercises ParsedQuery.as_create_table("tmp") for a query with a trailing
# semicolon, the same query with overwrite=True (expects a leading
# "DROP TABLE IF EXISTS tmp;"), a query without a semicolon, and a
# multi-line query.
def test_create_table_as():
q = ParsedQuery("SELECT * FROM outer_space;")
assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")
assert (
"DROP TABLE IF EXISTS tmp;\nCREATE TABLE tmp AS \nSELECT * FROM outer_space"
== q.as_create_table("tmp", overwrite=True)
)
# now without a semicolon
q = ParsedQuery("SELECT * FROM outer_space")
assert "CREATE TABLE tmp AS \nSELECT * FROM outer_space" == q.as_create_table("tmp")
# now a multi-line query
multi_line_query = "SELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'"
q = ParsedQuery(multi_line_query)
assert (
"CREATE TABLE tmp AS \nSELECT * FROM planets WHERE\nLuke_Father = 'Darth Vader'"
== q.as_create_table("tmp")
)
def test_in_app_context():
@celery_app.task(bind=True)
def my_task(self):
@@ -484,8 +574,8 @@ def test_in_app_context():
)
# NOTE(review): rendered diff, markers and indentation stripped; each of the
# two statements appears in its old form (``db_object_type: str``,
# interpolated directly) followed by its new form (``ctas_method: CTASMethod``,
# interpolated through ``.name``).
#
# Executes ``DROP <type> IF EXISTS <name>`` on db.get_engine().
def delete_tmp_view_or_table(name: str, db_object_type: str):
db.get_engine().execute(f"DROP {db_object_type} IF EXISTS {name}")
def delete_tmp_view_or_table(name: str, ctas_method: CTASMethod):
db.get_engine().execute(f"DROP {ctas_method.name} IF EXISTS {name}")
def wait_for_success(result):

View File

@@ -31,16 +31,15 @@ from superset.db_engine_specs import BaseEngineSpec
from superset.db_engine_specs.hive import HiveEngineSpec
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException
from superset.exceptions import SupersetErrorException, SupersetInvalidCVASException
from superset.models.sql_lab import Query
from superset.result_set import SupersetResultSet
from superset.sqllab.limiting_factor import LimitingFactor
from superset.sql.parse import CTASMethod
from superset.sql_lab import (
cancel_query,
execute_sql_statements,
apply_limit_if_exists,
)
from superset.sql_parse import CtasMethod
from superset.utils.core import backend
from superset.utils import json
from superset.utils.json import datetime_to_epoch # noqa: F401
@@ -132,31 +131,13 @@ class TestSqlLab(SupersetTestCase):
self.login(ADMIN_USERNAME)
data = self.run_sql("DELETE FROM birth_names", "1")
assert data == {
"errors": [
{
"message": (
"This database does not allow for DDL/DML, and the query "
"could not be parsed to confirm it is a read-only query. Please " # noqa: E501
"contact your administrator for more assistance."
),
"error_type": SupersetErrorType.DML_NOT_ALLOWED_ERROR,
"level": ErrorLevel.ERROR,
"extra": {
"issue_codes": [
{
"code": 1022,
"message": "Issue 1022 - Database does not allow data manipulation.", # noqa: E501
}
]
},
}
]
}
assert (
data["errors"][0]["error_type"] == SupersetErrorType.DML_NOT_ALLOWED_ERROR
)
@parameterized.expand([CtasMethod.TABLE, CtasMethod.VIEW])
@parameterized.expand([CTASMethod.TABLE, CTASMethod.VIEW])
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_sql_json_cta_dynamic_db(self, ctas_method):
def test_sql_json_cta_dynamic_db(self, ctas_method: CTASMethod) -> None:
examples_db = get_example_database()
if examples_db.backend == "sqlite":
# sqlite doesn't support database creation
@@ -170,7 +151,7 @@ class TestSqlLab(SupersetTestCase):
examples_db.allow_ctas = True # enable cta
self.login(ADMIN_USERNAME)
tmp_table_name = f"test_target_{ctas_method.lower()}"
tmp_table_name = f"test_target_{ctas_method.name.lower()}"
self.run_sql(
"SELECT * FROM birth_names",
"1",
@@ -195,7 +176,9 @@ class TestSqlLab(SupersetTestCase):
) # SQL_MAX_ROW not applied due to the SQLLAB_CTAS_NO_LIMIT set to True
# cleanup
engine.execute(f"DROP {ctas_method} admin_database.{tmp_table_name}")
engine.execute(
f"DROP {ctas_method.name} admin_database.{tmp_table_name}"
)
examples_db.allow_ctas = old_allow_ctas
db.session.commit()
@@ -608,10 +591,10 @@ class TestSqlLab(SupersetTestCase):
@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
@mock.patch("superset.sql_lab.execute_query")
def test_execute_sql_statements(
self,
mock_execute_sql_statement,
mock_execute_query,
mock_get_query,
mock_db,
):
@@ -623,7 +606,7 @@ class TestSqlLab(SupersetTestCase):
"""
)
mock_db = mock.MagicMock() # noqa: F841
mock_query = mock.MagicMock()
mock_query = mock.MagicMock(select_as_cta=False)
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
@@ -641,30 +624,20 @@ class TestSqlLab(SupersetTestCase):
expand_data=False,
log_params=None,
)
mock_execute_sql_statement.assert_has_calls(
mock_execute_query.assert_has_calls(
[
mock.call(
"-- comment\nSET @value = 42",
mock_query,
mock_cursor,
None,
False,
),
mock.call(
"SELECT /*+ hint */ @value AS foo",
mock_query,
mock_cursor,
None,
False,
),
mock.call(mock_query, mock_cursor, None),
mock.call(mock_query, mock_cursor, None),
]
)
@mock.patch("superset.sql_lab.results_backend", None)
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
@mock.patch("superset.sql_lab.execute_query")
def test_execute_sql_statements_no_results_backend(
self, mock_execute_sql_statement, mock_get_query
self,
mock_execute_query,
mock_get_query,
):
sql = dedent(
"""
@@ -712,10 +685,10 @@ class TestSqlLab(SupersetTestCase):
@mock.patch("superset.sql_lab.db")
@mock.patch("superset.sql_lab.get_query")
@mock.patch("superset.sql_lab.execute_sql_statement")
@mock.patch("superset.sql_lab.execute_query")
def test_execute_sql_statements_ctas(
self,
mock_execute_sql_statement,
mock_execute_query,
mock_get_query,
mock_db,
):
@@ -727,7 +700,13 @@ class TestSqlLab(SupersetTestCase):
"""
)
mock_db = mock.MagicMock() # noqa: F841
mock_query = mock.MagicMock()
mock_query = mock.MagicMock(
select_as_cta=True,
ctas_method=CTASMethod.TABLE.name,
tmp_table_name="table",
tmp_schema_name="schema",
catalog="catalog",
)
mock_query.database.allow_run_async = False
mock_cursor = mock.MagicMock()
mock_query.database.get_raw_connection().__enter__().cursor.return_value = (
@@ -738,7 +717,7 @@ class TestSqlLab(SupersetTestCase):
# set the query to CTAS
mock_query.select_as_cta = True
mock_query.ctas_method = CtasMethod.TABLE
mock_query.ctas_method = CTASMethod.TABLE.name
execute_sql_statements(
query_id=1,
@@ -749,22 +728,10 @@ class TestSqlLab(SupersetTestCase):
expand_data=False,
log_params=None,
)
mock_execute_sql_statement.assert_has_calls(
mock_execute_query.assert_has_calls(
[
mock.call(
"-- comment\nSET @value = 42",
mock_query,
mock_cursor,
None,
False,
),
mock.call(
"SELECT /*+ hint */ @value AS foo",
mock_query,
mock_cursor,
None,
True, # apply_ctas
),
mock.call(mock_query, mock_cursor, None),
mock.call(mock_query, mock_cursor, None),
]
)
@@ -795,7 +762,7 @@ class TestSqlLab(SupersetTestCase):
)
# try invalid CVAS
mock_query.ctas_method = CtasMethod.VIEW
mock_query.ctas_method = CTASMethod.VIEW.name
sql = dedent(
"""
-- comment
@@ -803,7 +770,7 @@ class TestSqlLab(SupersetTestCase):
SELECT /*+ hint */ @value AS foo;
"""
)
with pytest.raises(SupersetErrorException) as excinfo:
with pytest.raises(SupersetInvalidCVASException) as excinfo:
execute_sql_statements(
query_id=1,
rendered_query=sql,
@@ -870,29 +837,6 @@ class TestSqlLab(SupersetTestCase):
]
}
# NOTE(review): appears to be on the removed side of the diff (the enclosing
# hunk shrinks from 29 old lines to 6 new ones) — confirm against the real
# file.
#
# Calls apply_limit_if_exists with ``None`` as its second argument and
# asserts the SQL comes back unchanged.
def test_apply_limit_if_exists_when_incremented_limit_is_none(self):
sql = """
SET @value = 42;
SELECT @value AS foo;
"""
database = get_example_database()
mock_query = mock.MagicMock()
mock_query.limit = 300
final_sql = apply_limit_if_exists(database, None, mock_query, sql)
assert final_sql == sql
# NOTE(review): appears to be on the removed side of the diff (same hunk as
# the preceding test) — confirm against the real file.
#
# Calls apply_limit_if_exists with 1000 as its second argument, while
# mock_query.limit is 300, and asserts "LIMIT 1000" appears in the result.
def test_apply_limit_if_exists_when_increased_limit(self):
sql = """
SET @value = 42;
SELECT @value AS foo;
"""
database = get_example_database()
mock_query = mock.MagicMock()
mock_query.limit = 300
final_sql = apply_limit_if_exists(database, 1000, mock_query, sql)
assert "LIMIT 1000" in final_sql
@pytest.mark.parametrize("spec", [HiveEngineSpec, PrestoEngineSpec])
def test_cancel_query_implicit(spec: BaseEngineSpec) -> None: