fix: normalize totals cache keys for async hits (#36274)

This commit is contained in:
Beto Dealmeida
2025-12-01 11:11:10 -05:00
committed by GitHub
parent 9fc7a83320
commit 775d1ba061
2 changed files with 125 additions and 12 deletions

View File

@@ -18,7 +18,7 @@ from __future__ import annotations
import logging
import re
from typing import Any, cast, ClassVar, TYPE_CHECKING
from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING
import pandas as pd
from flask import current_app
@@ -251,9 +251,13 @@ class QueryContextProcessor:
return df.to_dict(orient="records")
def ensure_totals_available(self) -> None:
queries_needing_totals = []
totals_queries = []
def _prepare_contribution_totals(self) -> tuple[list[int], int | None]:
"""
Identify contribution queries and normalize the totals query so cache keys
align with cached results.
"""
queries_needing_totals: list[int] = []
totals_idx: int | None = None
for i, query in enumerate(self._query_context.queries):
needs_totals = any(
@@ -267,17 +271,28 @@ class QueryContextProcessor:
is_totals_query = (
not query.columns and query.metrics and not query.post_processing
)
if is_totals_query:
totals_queries.append(i)
if is_totals_query and totals_idx is None:
totals_idx = i
if not queries_needing_totals or not totals_queries:
if queries_needing_totals and totals_idx is not None:
totals_query = self._query_context.queries[totals_idx]
totals_query.row_limit = None
return queries_needing_totals, totals_idx
def ensure_totals_available(
self,
queries_needing_totals: Sequence[int] | None = None,
totals_idx: int | None = None,
) -> None:
if queries_needing_totals is None or totals_idx is None:
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
if not queries_needing_totals or totals_idx is None:
return
totals_idx = totals_queries[0]
totals_query = self._query_context.queries[totals_idx]
totals_query.row_limit = None
result = self._query_context.get_query_result(totals_query)
df = result.df
@@ -299,10 +314,12 @@ class QueryContextProcessor:
) -> dict[str, Any]:
"""Returns the query results with both metadata and data"""
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
# Skip ensure_totals_available when force_cached=True
# This prevents recalculating contribution_totals from cached results
if not force_cached:
self.ensure_totals_available()
self.ensure_totals_available(queries_needing_totals, totals_idx)
# Update cache_values to reflect modifications made by
# ensure_totals_available()

View File

@@ -15,13 +15,15 @@
# specific language governing permissions and limitations
# under the License.
from typing import Any
from unittest.mock import MagicMock, patch
import numpy as np
import pandas as pd
import pytest
from superset.common.chart_data import ChartDataResultFormat
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.common.query_context_processor import QueryContextProcessor
from superset.utils.core import GenericDataType
@@ -1215,3 +1217,97 @@ def test_cache_key_non_contribution_post_processing_unchanged():
assert query1.cache_key() != query2.cache_key(), (
"Cache keys should differ for different non-contribution post_processing"
)
def test_force_cached_normalizes_totals_query_row_limit():
    """
    When fetching from cache (force_cached=True), the totals query should still be
    normalized so its cache key matches the cached entry, but the totals query should
    not be executed.
    """
    # Imported locally so the module-level imports stay unchanged for other tests.
    from superset.common.query_object import QueryObject

    # Datasource stub: only the attributes touched during cache-key generation
    # and payload assembly are filled in.
    mock_datasource = MagicMock()
    mock_datasource.uid = "test_datasource"
    mock_datasource.column_names = ["region", "sales"]
    mock_datasource.cache_timeout = None
    mock_datasource.changed_on = None
    mock_datasource.get_extra_cache_keys.return_value = []
    mock_datasource.database.extra = "{}"
    mock_datasource.database.impersonate_user = False
    mock_datasource.database.db_engine_spec.get_impersonation_key.return_value = None

    # A "totals" query: no columns, metrics only, no post_processing.  Its
    # row_limit starts at 1000 so we can observe it being normalized to None.
    totals_query = QueryObject(
        datasource=mock_datasource,
        columns=[],
        metrics=["sales"],
        row_limit=1000,
    )
    # The main query needs contribution totals via its post_processing step.
    main_query = QueryObject(
        datasource=mock_datasource,
        columns=["region"],
        metrics=["sales"],
        row_limit=1000,
        post_processing=[{"operation": "contribution", "options": {}}],
    )
    # Skip real validation — the mock datasource can't satisfy it.
    totals_query.validate = MagicMock()
    main_query.validate = MagicMock()

    # Record the totals query's row_limit at the moment its cache key is
    # computed; the fix under test requires it to already be None by then.
    captured_limits: list[int | None] = []

    def totals_cache_key(**kwargs: Any) -> str:
        captured_limits.append(totals_query.row_limit)
        return "totals-cache-key"

    totals_query.cache_key = totals_cache_key
    main_query.cache_key = lambda **kwargs: "main-cache-key"

    # Query-context stub wired to the real processor under test.
    mock_query_context = MagicMock()
    mock_query_context.force = False
    mock_query_context.datasource = mock_datasource
    mock_query_context.queries = [main_query, totals_query]
    mock_query_context.result_type = ChartDataResultType.FULL
    mock_query_context.result_format = ChartDataResultFormat.JSON
    mock_query_context.cache_values = {
        "queries": [main_query.to_dict(), totals_query.to_dict()]
    }
    # Must never be called when force_cached=True — asserted at the end.
    mock_query_context.get_query_result = MagicMock()

    processor = QueryContextProcessor(mock_query_context)
    processor._qc_datasource = mock_datasource
    # Route the context's payload methods through the real processor so the
    # force_cached code path actually executes.
    mock_query_context.get_df_payload = processor.get_df_payload
    mock_query_context.get_data = processor.get_data

    with patch(
        "superset.common.query_context_processor.security_manager"
    ) as mock_security_manager:
        mock_security_manager.get_rls_cache_key.return_value = None
        with patch(
            "superset.common.query_context_processor.QueryCacheManager"
        ) as mock_cache_manager:

            # Every cache lookup reports a fully loaded, successful hit so
            # force_cached=True can be satisfied without running a query.
            def cache_get(*args: Any, **kwargs: Any) -> Any:
                df = pd.DataFrame({"region": ["North"], "sales": [100]})
                cache = MagicMock()
                cache.is_loaded = True
                cache.df = df
                cache.query = "SELECT 1"
                cache.error_message = None
                cache.status = QueryStatus.SUCCESS
                cache.applied_template_filters = []
                cache.applied_filter_columns = []
                cache.rejected_filter_columns = []
                cache.annotation_data = {}
                cache.is_cached = True
                cache.sql_rowcount = len(df)
                cache.cache_dttm = "2024-01-01T00:00:00"
                return cache

            mock_cache_manager.get.side_effect = cache_get

            processor.get_payload(cache_query_context=False, force_cached=True)

    # The totals query's row_limit must have been cleared (normalized) before
    # its cache key was computed, exactly once.
    assert captured_limits == [None], "Totals query should be normalized before caching"
    # And the totals query must not have been executed on the cached path.
    mock_query_context.get_query_result.assert_not_called()