mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix: normalize totals cache keys for async hits (#36274)
This commit is contained in:
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, cast, ClassVar, TYPE_CHECKING
|
||||
from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING
|
||||
|
||||
import pandas as pd
|
||||
from flask import current_app
|
||||
@@ -251,9 +251,13 @@ class QueryContextProcessor:
|
||||
|
||||
return df.to_dict(orient="records")
|
||||
|
||||
def ensure_totals_available(self) -> None:
|
||||
queries_needing_totals = []
|
||||
totals_queries = []
|
||||
def _prepare_contribution_totals(self) -> tuple[list[int], int | None]:
|
||||
"""
|
||||
Identify contribution queries and normalize the totals query so cache keys
|
||||
align with cached results.
|
||||
"""
|
||||
queries_needing_totals: list[int] = []
|
||||
totals_idx: int | None = None
|
||||
|
||||
for i, query in enumerate(self._query_context.queries):
|
||||
needs_totals = any(
|
||||
@@ -267,17 +271,28 @@ class QueryContextProcessor:
|
||||
is_totals_query = (
|
||||
not query.columns and query.metrics and not query.post_processing
|
||||
)
|
||||
if is_totals_query:
|
||||
totals_queries.append(i)
|
||||
if is_totals_query and totals_idx is None:
|
||||
totals_idx = i
|
||||
|
||||
if not queries_needing_totals or not totals_queries:
|
||||
if queries_needing_totals and totals_idx is not None:
|
||||
totals_query = self._query_context.queries[totals_idx]
|
||||
totals_query.row_limit = None
|
||||
|
||||
return queries_needing_totals, totals_idx
|
||||
|
||||
def ensure_totals_available(
|
||||
self,
|
||||
queries_needing_totals: Sequence[int] | None = None,
|
||||
totals_idx: int | None = None,
|
||||
) -> None:
|
||||
if queries_needing_totals is None or totals_idx is None:
|
||||
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
|
||||
|
||||
if not queries_needing_totals or totals_idx is None:
|
||||
return
|
||||
|
||||
totals_idx = totals_queries[0]
|
||||
totals_query = self._query_context.queries[totals_idx]
|
||||
|
||||
totals_query.row_limit = None
|
||||
|
||||
result = self._query_context.get_query_result(totals_query)
|
||||
df = result.df
|
||||
|
||||
@@ -299,10 +314,12 @@ class QueryContextProcessor:
|
||||
) -> dict[str, Any]:
|
||||
"""Returns the query results with both metadata and data"""
|
||||
|
||||
queries_needing_totals, totals_idx = self._prepare_contribution_totals()
|
||||
|
||||
# Skip ensure_totals_available when force_cached=True
|
||||
# This prevents recalculating contribution_totals from cached results
|
||||
if not force_cached:
|
||||
self.ensure_totals_available()
|
||||
self.ensure_totals_available(queries_needing_totals, totals_idx)
|
||||
|
||||
# Update cache_values to reflect modifications made by
|
||||
# ensure_totals_available()
|
||||
|
||||
@@ -15,13 +15,15 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pytest
|
||||
|
||||
from superset.common.chart_data import ChartDataResultFormat
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.common.db_query_status import QueryStatus
|
||||
from superset.common.query_context_processor import QueryContextProcessor
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
@@ -1215,3 +1217,97 @@ def test_cache_key_non_contribution_post_processing_unchanged():
|
||||
assert query1.cache_key() != query2.cache_key(), (
|
||||
"Cache keys should differ for different non-contribution post_processing"
|
||||
)
|
||||
|
||||
|
||||
def test_force_cached_normalizes_totals_query_row_limit():
    """
    Serving a payload with ``force_cached=True`` must clear the totals query's
    ``row_limit`` before that query's cache key is computed, and must never run
    the totals query through ``get_query_result``.
    """
    from superset.common.query_object import QueryObject

    # Datasource stub: dotted keys configure nested child mocks in one call.
    mock_datasource = MagicMock()
    mock_datasource.configure_mock(
        **{
            "uid": "test_datasource",
            "column_names": ["region", "sales"],
            "cache_timeout": None,
            "changed_on": None,
            "get_extra_cache_keys.return_value": [],
            "database.extra": "{}",
            "database.impersonate_user": False,
            "database.db_engine_spec.get_impersonation_key.return_value": None,
        }
    )

    # A metrics-only query without post-processing is what the processor
    # recognises as the totals query.
    totals_query = QueryObject(
        datasource=mock_datasource,
        columns=[],
        metrics=["sales"],
        row_limit=1000,
    )
    # The contribution post-processing step makes this query need totals.
    main_query = QueryObject(
        datasource=mock_datasource,
        columns=["region"],
        metrics=["sales"],
        row_limit=1000,
        post_processing=[{"operation": "contribution", "options": {}}],
    )

    for query in (totals_query, main_query):
        query.validate = MagicMock()

    # Records the totals query's row_limit each time its cache key is requested.
    observed_row_limits: list[int | None] = []

    def record_totals_row_limit(**kwargs: Any) -> str:
        observed_row_limits.append(totals_query.row_limit)
        return "totals-cache-key"

    def main_cache_key(**kwargs: Any) -> str:
        return "main-cache-key"

    totals_query.cache_key = record_totals_row_limit
    main_query.cache_key = main_cache_key

    mock_query_context = MagicMock(
        force=False,
        datasource=mock_datasource,
        queries=[main_query, totals_query],
        result_type=ChartDataResultType.FULL,
        result_format=ChartDataResultFormat.JSON,
        cache_values={"queries": [main_query.to_dict(), totals_query.to_dict()]},
        get_query_result=MagicMock(),
    )

    processor = QueryContextProcessor(mock_query_context)
    processor._qc_datasource = mock_datasource
    # Route the context's payload helpers back to the real processor.
    mock_query_context.get_df_payload = processor.get_df_payload
    mock_query_context.get_data = processor.get_data

    def fake_cache_get(*args: Any, **kwargs: Any) -> Any:
        """Return a fresh, fully loaded cache entry for every lookup."""
        frame = pd.DataFrame({"region": ["North"], "sales": [100]})
        return MagicMock(
            is_loaded=True,
            df=frame,
            query="SELECT 1",
            error_message=None,
            status=QueryStatus.SUCCESS,
            applied_template_filters=[],
            applied_filter_columns=[],
            rejected_filter_columns=[],
            annotation_data={},
            is_cached=True,
            sql_rowcount=len(frame),
            cache_dttm="2024-01-01T00:00:00",
        )

    with patch(
        "superset.common.query_context_processor.security_manager"
    ) as mock_security_manager, patch(
        "superset.common.query_context_processor.QueryCacheManager"
    ) as mock_cache_manager:
        mock_security_manager.get_rls_cache_key.return_value = None
        mock_cache_manager.get.side_effect = fake_cache_get

        processor.get_payload(cache_query_context=False, force_cached=True)

    assert observed_row_limits == [None], "Totals query should be normalized before caching"
    mock_query_context.get_query_result.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user