diff --git a/superset-frontend/src/components/Chart/chartActions.test.ts b/superset-frontend/src/components/Chart/chartActions.test.ts index f91a9d75194..b4dbd03e3e2 100644 --- a/superset-frontend/src/components/Chart/chartActions.test.ts +++ b/superset-frontend/src/components/Chart/chartActions.test.ts @@ -31,6 +31,7 @@ import { AnnotationSourceType, AnnotationStyle, } from '@superset-ui/core'; +import * as toastActions from 'src/components/MessageToasts/actions'; import { LOG_EVENT } from 'src/logger/actions'; import * as exploreUtils from 'src/explore/exploreUtils'; import * as actions from 'src/components/Chart/chartAction'; @@ -340,6 +341,56 @@ describe('chart actions', () => { ); expect(result).toEqual([1, 2, 3]); }); + + test('dispatches addWarningToast when a query response includes a warning', async () => { + const warningMessage = + 'Results truncated to 1,000 rows due to memory constraints.'; + fetchMock.removeRoute(MOCK_URL); + fetchMock.post( + `glob:*${MOCK_URL}*`, + { result: [{ warning: warningMessage }] }, + { name: MOCK_URL }, + ); + const addWarningToastSpy = jest.spyOn(toastActions, 'addWarningToast'); + + const actionThunk = actions.postChartFormData( + { viz_type: 'my_viz' } as QueryFormData, + false, + undefined, + undefined, + ); + await actionThunk( + dispatch as unknown as actions.ChartThunkDispatch, + mockGetState as unknown as () => actions.RootState, + undefined, + ); + + expect(addWarningToastSpy).toHaveBeenCalledWith(warningMessage, { + noDuplicate: true, + }); + addWarningToastSpy.mockRestore(); + fetchMock.removeRoute(MOCK_URL); + setupDefaultFetchMock(); + }); + + test('does not dispatch addWarningToast when no query response has a warning', async () => { + const addWarningToastSpy = jest.spyOn(toastActions, 'addWarningToast'); + + const actionThunk = actions.postChartFormData( + { viz_type: 'my_viz' } as QueryFormData, + false, + undefined, + undefined, + ); + await actionThunk( + dispatch as unknown as actions.ChartThunkDispatch, + mockGetState as unknown as () => actions.RootState, + undefined, + ); + + expect(addWarningToastSpy).not.toHaveBeenCalled(); + addWarningToastSpy.mockRestore(); + }); }); // eslint-disable-next-line no-restricted-globals -- TODO: Migrate from describe blocks diff --git a/superset/common/query_context_processor.py b/superset/common/query_context_processor.py index 5b39c8f95a7..56c7ac4f5c8 100644 --- a/superset/common/query_context_processor.py +++ b/superset/common/query_context_processor.py @@ -21,7 +21,7 @@ import re from typing import Any, cast, ClassVar, Sequence, TYPE_CHECKING import pandas as pd -from flask import current_app, g +from flask import current_app from flask_babel import gettext as _ from superset.common.chart_data import ChartDataResultFormat @@ -191,12 +191,8 @@ class QueryContextProcessor: cache.df.columns = [unescape_separator(col) for col in cache.df.columns.values] warning: str | None = None - if getattr(g, "bq_memory_limited", False): - row_count = getattr(g, "bq_memory_limited_row_count", len(cache.df)) - # Reset flags immediately so subsequent queries in the same request - # don't inherit this warning - g.bq_memory_limited = False - g.bq_memory_limited_row_count = 0 + if cache.bq_memory_limited: + row_count = cache.bq_memory_limited_row_count chart_id = (self._query_context.form_data or {}).get("slice_id", "") prefix = f"Chart {chart_id}: " if chart_id else "" warning = _( diff --git a/superset/common/utils/query_cache_manager.py b/superset/common/utils/query_cache_manager.py index da2d668e8c9..a7a97e2b874 100644 --- a/superset/common/utils/query_cache_manager.py +++ b/superset/common/utils/query_cache_manager.py @@ -20,7 +20,7 @@ import logging from datetime import datetime, timezone from typing import Any -from flask import current_app +from flask import current_app, g, has_request_context from flask_caching import Cache from pandas import DataFrame @@ -86,6 +86,8 @@ class QueryCacheManager: self.cache_value = cache_value self.sql_rowcount = sql_rowcount self.queried_dttm = queried_dttm + self.bq_memory_limited: bool = False + self.bq_memory_limited_row_count: int = 0 # pylint: disable=too-many-arguments def set_query_result( @@ -123,6 +125,15 @@ class QueryCacheManager: ) self.is_loaded = True + # Capture BigQuery memory-limit flag so it survives cache hits + if has_request_context(): + self.bq_memory_limited = getattr(g, "bq_memory_limited", False) + self.bq_memory_limited_row_count = getattr( + g, "bq_memory_limited_row_count", 0 + ) + g.bq_memory_limited = False + g.bq_memory_limited_row_count = 0 + value = { "df": self.df, "query": self.query, @@ -133,6 +144,8 @@ class QueryCacheManager: "sql_rowcount": self.sql_rowcount, "queried_dttm": self.queried_dttm, "dttm": self.queried_dttm, # Backwards compatibility + "bq_memory_limited": self.bq_memory_limited, + "bq_memory_limited_row_count": self.bq_memory_limited_row_count, } if self.is_loaded and key and self.status != QueryStatus.FAILED: self.set( @@ -193,6 +206,12 @@ class QueryCacheManager: "queried_dttm", cache_value.get("dttm") ) query_cache.cache_value = cache_value + query_cache.bq_memory_limited = cache_value.get( + "bq_memory_limited", False + ) + query_cache.bq_memory_limited_row_count = cache_value.get( + "bq_memory_limited_row_count", 0 + ) current_app.config["STATS_LOGGER"].incr("loaded_from_cache") except KeyError as ex: logger.exception(ex) diff --git a/tests/unit_tests/common/test_query_context_processor.py b/tests/unit_tests/common/test_query_context_processor.py index 88e7fc83867..f8aaeebdc54 100644 --- a/tests/unit_tests/common/test_query_context_processor.py +++ b/tests/unit_tests/common/test_query_context_processor.py @@ -1426,17 +1426,14 @@ def test_get_df_payload_bq_memory_limited_warning(): mock_cache.sql_rowcount = 3 mock_cache.cache_dttm = "2024-01-01T00:00:00" mock_cache.queried_dttm = "2024-01-01T00:00:00" + mock_cache.bq_memory_limited = True + mock_cache.bq_memory_limited_row_count = 5000 mock_cache_manager.get.return_value = mock_cache with patch.object(query_obj, "validate", return_value=None): with patch.object(processor, "query_cache_key", return_value="key"): with patch.object(processor, "get_cache_timeout", return_value=3600): - # Simulate BigQuery memory-limited flag being set on Flask g - with patch("superset.common.query_context_processor.g") as mock_g: - mock_g.bq_memory_limited = True - mock_g.bq_memory_limited_row_count = 5000 - - result = processor.get_df_payload(query_obj, force_cached=False) + result = processor.get_df_payload(query_obj, force_cached=False) assert result["warning"] is not None assert "Chart 42" in result["warning"] @@ -1489,15 +1486,12 @@ def test_get_df_payload_no_warning_when_not_memory_limited(): mock_cache.sql_rowcount = 2 mock_cache.cache_dttm = "2024-01-01T00:00:00" mock_cache.queried_dttm = "2024-01-01T00:00:00" + mock_cache.bq_memory_limited = False mock_cache_manager.get.return_value = mock_cache with patch.object(query_obj, "validate", return_value=None): with patch.object(processor, "query_cache_key", return_value="key"): with patch.object(processor, "get_cache_timeout", return_value=3600): - # g.bq_memory_limited is not set (default) - with patch("superset.common.query_context_processor.g") as mock_g: - mock_g.bq_memory_limited = False - - result = processor.get_df_payload(query_obj, force_cached=False) + result = processor.get_df_payload(query_obj, force_cached=False) assert result["warning"] is None