diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py
index d806a4be9fa..94f54097e23 100644
--- a/superset/common/query_context_processor.py
+++ b/superset/common/query_context_processor.py
@@ -129,6 +129,11 @@ class QueryContextProcessor:
         self, query_obj: QueryObject, force_cached: bool | None = False
     ) -> dict[str, Any]:
         """Handles caching around the df payload retrieval"""
+        if query_obj:
+            # Always validate the query object before generating cache key
+            # This ensures sanitize_clause() is called and extras are normalized
+            query_obj.validate()
+
         cache_key = self.query_cache_key(query_obj)
         timeout = self.get_cache_timeout()
         force_query = self._query_context.force or timeout == CACHE_DISABLED_TIMEOUT
@@ -139,10 +144,6 @@ class QueryContextProcessor:
             force_cached=force_cached,
         )
 
-        if query_obj:
-            # Always validate the query object before processing
-            query_obj.validate()
-
         if query_obj and cache_key and not cache.is_loaded:
             try:
                 if invalid_columns := [
diff --git a/superset/sql/parse.py b/superset/sql/parse.py
index d5392115f0e..055aa3caaac 100644
--- a/superset/sql/parse.py
+++ b/superset/sql/parse.py
@@ -1482,6 +1482,15 @@ def sanitize_clause(clause: str, engine: str) -> str:
     Make sure the SQL clause is valid.
     """
     try:
-        return SQLStatement(clause, engine).format()
+        statement = SQLStatement(clause, engine)
+        dialect = SQLGLOT_DIALECTS.get(engine)
+        from sqlglot.dialects.dialect import Dialect
+
+        return Dialect.get_or_raise(dialect).generate(
+            statement._parsed,  # pylint: disable=protected-access
+            copy=True,
+            comments=False,
+            pretty=False,
+        )
     except SupersetParseError as ex:
         raise QueryClauseValidationException(f"Invalid SQL clause: {clause}") from ex
diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py
index 7e76a93cec9..ddf3ef44dbc 100644
--- a/tests/integration_tests/charts/data/api_tests.py
+++ b/tests/integration_tests/charts/data/api_tests.py
@@ -729,7 +729,7 @@ class TestPostChartDataApi(BaseTestChartDataApi):
         rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
         result = rv.json["result"][0]["query"]
         if get_example_database().backend != "presto":
-            assert "(\n  'boy' = 'boy'\n)" in result
+            assert "('boy' = 'boy')" in result
 
     @unittest.skip("Extremely flaky test on MySQL")
     @with_feature_flags(GLOBAL_ASYNC_QUERIES=True)
diff --git a/tests/integration_tests/query_context_tests.py b/tests/integration_tests/query_context_tests.py
index 751f96631ea..21694580aa9 100644
--- a/tests/integration_tests/query_context_tests.py
+++ b/tests/integration_tests/query_context_tests.py
@@ -270,6 +270,60 @@ class TestQueryContext(SupersetTestCase):
         cache_key = query_context.query_cache_key(query_object)
         assert cache_key_original != cache_key
 
+    def test_query_cache_key_consistent_with_different_sql_formatting(self):
+        """
+        Test that cache keys are consistent regardless of SQL clause formatting.
+
+        This test verifies the fix for the cache key mismatch issue where different
+        whitespace formatting in WHERE/HAVING clauses caused different cache keys
+        to be generated between server and worker processes.
+        """
+        # Create payload with compact WHERE clause
+        payload1 = get_query_context("birth_names")
+        payload1["queries"][0]["extras"] = {"where": "(name = 'Amy')"}
+
+        query_context1 = ChartDataQueryContextSchema().load(payload1)
+        # Use get_df_payload which is the actual code path, not query_cache_key directly
+        result1 = query_context1.get_df_payload(
+            query_context1.queries[0], force_cached=False
+        )
+        cache_key1 = result1.get("cache_key")
+
+        # Create same payload but with pretty-formatted WHERE clause (with newlines)
+        payload2 = get_query_context("birth_names")
+        payload2["queries"][0]["extras"] = {"where": "(\n name = 'Amy'\n)"}
+
+        query_context2 = ChartDataQueryContextSchema().load(payload2)
+        result2 = query_context2.get_df_payload(
+            query_context2.queries[0], force_cached=False
+        )
+        cache_key2 = result2.get("cache_key")
+
+        # Cache keys should be identical after sanitization
+        assert cache_key1 == cache_key2
+
+        # Also verify with HAVING clause
+        payload3 = get_query_context("birth_names")
+        payload3["queries"][0]["extras"] = {"having": "(sum__num > 100)"}
+
+        query_context3 = ChartDataQueryContextSchema().load(payload3)
+        result3 = query_context3.get_df_payload(
+            query_context3.queries[0], force_cached=False
+        )
+        cache_key3 = result3.get("cache_key")
+
+        payload4 = get_query_context("birth_names")
+        payload4["queries"][0]["extras"] = {"having": "(\n sum__num > 100\n)"}
+
+        query_context4 = ChartDataQueryContextSchema().load(payload4)
+        result4 = query_context4.get_df_payload(
+            query_context4.queries[0], force_cached=False
+        )
+        cache_key4 = result4.get("cache_key")
+
+        # Cache keys should be identical after sanitization
+        assert cache_key3 == cache_key4
+
     def test_handle_metrics_field(self):
         """
         Should support both predefined and adhoc metrics.
diff --git a/tests/unit_tests/common/test_query_context_processor.py b/tests/unit_tests/common/test_query_context_processor.py
index dc55f8174bf..68c1e4b7ed1 100644
--- a/tests/unit_tests/common/test_query_context_processor.py
+++ b/tests/unit_tests/common/test_query_context_processor.py
@@ -624,3 +624,80 @@ def test_processing_time_offsets_date_range_enabled(processor):
     assert isinstance(result["df"], pd.DataFrame)
     assert isinstance(result["queries"], list)
     assert isinstance(result["cache_keys"], list)
+
+
+def test_get_df_payload_validates_before_cache_key_generation():
+    """
+    Test that get_df_payload calls validate() before generating cache key.
+    """
+    from superset.common.query_object import QueryObject
+
+    # Create a mock query context
+    mock_query_context = MagicMock()
+    mock_query_context.force = False
+    mock_query_context.result_type = "full"
+
+    # Create a mock datasource
+    mock_datasource = MagicMock()
+    mock_datasource.id = 123
+    mock_datasource.uid = "test_datasource"
+    mock_datasource.cache_timeout = None
+    mock_datasource.database.db_engine_spec.engine = "postgresql"
+    mock_datasource.database.extra = "{}"
+    mock_datasource.get_extra_cache_keys.return_value = []
+    mock_datasource.changed_on = None
+
+    # Create processor
+    processor = QueryContextProcessor(mock_query_context)
+    processor._qc_datasource = mock_datasource
+
+    # Create a query object with unsanitized where clause
+    query_obj = QueryObject(
+        datasource=mock_datasource,
+        columns=["col1"],
+        metrics=[],
+        extras={"where": "(\n col1 > 0\n)"},  # Unsanitized with newlines
+    )
+
+    # Track the order of calls
+    call_order = []
+
+    original_validate = query_obj.validate
+
+    def mock_validate(*args, **kwargs):
+        call_order.append("validate")
+        # Update extras to simulate sanitization
+        query_obj.extras["where"] = "(col1 > 0)"  # Sanitized, compact format
+        return original_validate(*args, **kwargs)
+
+    original_cache_key = query_obj.cache_key
+
+    def mock_cache_key(*args, **kwargs):
+        call_order.append("cache_key")
+        # Verify that extras have been sanitized at this point
+        assert query_obj.extras["where"] == "(col1 > 0)", (
+            f"Expected sanitized clause in cache_key, got: {query_obj.extras['where']}"
+        )
+        return original_cache_key(*args, **kwargs)
+
+    with patch.object(query_obj, "validate", side_effect=mock_validate):
+        with patch.object(query_obj, "cache_key", side_effect=mock_cache_key):
+            with patch(
+                "superset.common.query_context_processor.QueryCacheManager"
+            ) as mock_cache_manager:
+                mock_cache = MagicMock()
+                mock_cache.is_loaded = True
+                mock_cache.df = pd.DataFrame({"col1": [1, 2, 3]})
+                mock_cache.query = "SELECT * FROM table"
+                mock_cache.error_message = None
+                mock_cache.status = "success"
+                mock_cache_manager.get.return_value = mock_cache
+
+                # Call get_df_payload
+                processor.get_df_payload(query_obj, force_cached=False)
+
+    # Verify validate was called before cache_key
+    assert call_order == ["validate", "cache_key"], (
+        f"Expected validate to be called before cache_key, "
+        f"but got call order: {call_order}"
+    )
diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py
index 406df77934f..8678c25ffa3 100644
--- a/tests/unit_tests/sql/parse_tests.py
+++ b/tests/unit_tests/sql/parse_tests.py
@@ -2593,9 +2593,17 @@ def test_is_valid_cvas(sql: str, engine: str, expected: bool) -> None:
     [
         ("col = 1", "col = 1", "base"),
         ("1=\t\n1", "1 = 1", "base"),
-        ("(col = 1)", "(\n  col = 1\n)", "base"),
-        ("(col1 = 1) AND (col2 = 2)", "(\n  col1 = 1\n) AND (\n  col2 = 2\n)", "base"),
-        ("col = 'abc' -- comment", "col = 'abc' /* comment */", "base"),
+        ("(col = 1)", "(col = 1)", "base"),  # Compact format without newlines
+        (
+            "(col1 = 1) AND (col2 = 2)",
+            "(col1 = 1) AND (col2 = 2)",
+            "base",
+        ),  # Compact format
+        (
+            "col = 'abc' -- comment",
+            "col = 'abc'",
+            "base",
+        ),  # Comments removed for compact format
         ("col = 'col1 = 1) AND (col2 = 2'", "col = 'col1 = 1) AND (col2 = 2'", "base"),
         ("col = 'select 1; select 2'", "col = 'select 1; select 2'", "base"),
         ("col = 'abc -- comment'", "col = 'abc -- comment'", "base"),