feat: use sqlglot to validate adhoc subquery (#33560)

This commit is contained in:
Beto Dealmeida
2025-05-30 18:09:19 -04:00
committed by GitHub
parent cf315388f2
commit 401ce56fa1
10 changed files with 123 additions and 92 deletions

View File

@@ -661,7 +661,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
]
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
assert rv.status_code == 400
assert rv.status_code == 422
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_with_invalid_having_parameter_closing_and_comment__400(self):
@@ -709,7 +709,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
result = rv.json["result"][0]["query"]
if get_example_database().backend != "presto":
assert "('boy' = 'boy')" in result
assert "(\n 'boy' = 'boy'\n )" in result
@unittest.skip("Extremely flaky test on MySQL")
@with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
@@ -840,7 +840,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
unique_names = {row["name"] for row in data}
self.maxDiff = None
assert len(unique_names) == SERIES_LIMIT
assert {column for column in data[0].keys()} == {"state", "name", "sum__num"} # noqa: C416
assert set(data[0]) == {"state", "name", "sum__num"}
@pytest.mark.usefixtures(
"create_annotation_layers", "load_birth_names_dashboard_with_slices"
@@ -931,7 +931,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
assert rv.status_code == 200
result = rv.json["result"][0]
data = result["data"]
assert {col for col in data[0].keys()} == {"foo", "bar", "state", "count"} # noqa: C416
assert set(data[0]) == {"foo", "bar", "state", "count"}
# make sure results and query parameters are unescaped
assert {row["foo"] for row in data} == {":foo"}
assert {row["bar"] for row in data} == {":bar:"}
@@ -1251,7 +1251,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
data = result["data"]
assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416
assert set(data[0]) == {"male_or_female", "sum__num"}
unique_genders = {row["male_or_female"] for row in data}
assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}]
@@ -1271,7 +1271,7 @@ class TestGetChartDataApi(BaseTestChartDataApi):
response_payload = json.loads(rv.data.decode("utf-8"))
result = response_payload["result"][0]
data = result["data"]
assert {column for column in data[0].keys()} == {"male_or_female", "sum__num"} # noqa: C416
assert set(data[0]) == {"male_or_female", "sum__num"}
unique_genders = {row["male_or_female"] for row in data}
assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}]

View File

@@ -568,6 +568,9 @@ def test_get_samples(test_client, login_as_admin, virtual_dataset):
def test_get_samples_with_incorrect_cc(test_client, login_as_admin, virtual_dataset):
if get_example_database().backend == "sqlite":
return
TableColumn(
column_name="DUMMY CC",
type="VARCHAR(255)",
@@ -702,7 +705,7 @@ def test_get_samples_with_multiple_filters(
assert "2000-01-02" in rv.json["result"]["query"]
assert "2000-01-04" in rv.json["result"]["query"]
assert "col3 = 1.2" in rv.json["result"]["query"]
assert "col4 is null" in rv.json["result"]["query"]
assert "col4 IS NULL" in rv.json["result"]["query"]
assert "col2 = 'c'" in rv.json["result"]["query"]

View File

@@ -447,7 +447,11 @@ class TestSqlaTableModel(SupersetTestCase):
return None
old_inner_join = spec.allows_joins
spec.allows_joins = inner_join
arbitrary_gby = "state || gender || '_test'"
arbitrary_gby = (
"state OR gender OR '_test'"
if get_example_database().backend == "mysql"
else "state || gender || '_test'"
)
arbitrary_metric = dict( # noqa: C408
label="arbitrary", expressionType="SQL", sqlExpression="SUM(num_boys)"
)

View File

@@ -876,12 +876,6 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
"columns": [
"col1",
"time column with spaces",
{
"label": "I_AM_A_TRUNC_COLUMN",
"sqlExpression": "time column with spaces",
"columnType": "BASE_AXIS",
"timeGrain": "P1Y",
},
],
"metrics": ["count"],
"orderby": [["col1", True]],
@@ -897,10 +891,8 @@ def test_special_chars_in_column_name(app_context, physical_dataset):
if query_object.datasource.database.backend == "sqlite":
# sqlite returns string as timestamp column
assert df["time column with spaces"][0] == "2002-01-03 00:00:00"
assert df["I_AM_A_TRUNC_COLUMN"][0] == "2002-01-01 00:00:00"
else:
assert df["time column with spaces"][0].strftime("%Y-%m-%d") == "2002-01-03"
assert df["I_AM_A_TRUNC_COLUMN"][0].strftime("%Y-%m-%d") == "2002-01-01"
@only_postgresql

View File

@@ -198,7 +198,7 @@ class TestDatabaseModel(SupersetTestCase):
# assert dataset saved metric
assert "count('bar_P1D')" in query
# assert adhoc metric
assert "SUM(case when user = 'user_abc' then 1 else 0 end)" in query
assert "SUM(CASE WHEN user = 'user_abc' THEN 1 ELSE 0 END)" in query
# Cleanup
db.session.delete(table)
db.session.commit()