diff --git a/superset/charts/data/api.py b/superset/charts/data/api.py index 3591539a098..981f650df1d 100644 --- a/superset/charts/data/api.py +++ b/superset/charts/data/api.py @@ -18,7 +18,7 @@ from __future__ import annotations import contextlib import logging -from typing import Any, TYPE_CHECKING +from typing import Any, Callable, TYPE_CHECKING from flask import current_app as app, g, make_response, request, Response from flask_appbuilder.api import expose, protect @@ -70,8 +70,13 @@ class ChartDataRestApi(ChartRestApi): @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", log_to_statsd=False, + allow_extra_payload=True, ) - def get_data(self, pk: int) -> Response: + def get_data( # noqa: C901 + self, + pk: int, + add_extra_log_payload: Callable[..., None] = lambda **kwargs: None, + ) -> Response: """ Take a chart ID and uses the query context stored when the chart was saved to return payload data response. @@ -174,7 +179,10 @@ class ChartDataRestApi(ChartRestApi): form_data = {} return self._get_data_response( - command=command, form_data=form_data, datasource=query_context.datasource + command=command, + form_data=form_data, + datasource=query_context.datasource, + add_extra_log_payload=add_extra_log_payload, ) @expose("/data", methods=("POST",)) @@ -183,8 +191,11 @@ class ChartDataRestApi(ChartRestApi): @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", log_to_statsd=False, + allow_extra_payload=True, ) - def data(self) -> Response: + def data( # noqa: C901 + self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None + ) -> Response: """ Take a query context constructed in the client and return payload data response for the given query @@ -258,7 +269,10 @@ class ChartDataRestApi(ChartRestApi): form_data = json_body.get("form_data") return self._get_data_response( - command, form_data=form_data, datasource=query_context.datasource + command=command, + form_data=form_data, + datasource=query_context.datasource, + add_extra_log_payload=add_extra_log_payload, ) @expose("/data/", methods=("GET",)) @@ -417,7 +431,9 @@ class ChartDataRestApi(ChartRestApi): force_cached: bool = False, form_data: dict[str, Any] | None = None, datasource: BaseDatasource | Query | None = None, + add_extra_log_payload: Callable[..., None] | None = None, ) -> Response: + """Get data response and optionally log is_cached information.""" try: result = command.run(force_cached=force_cached) except ChartDataCacheLoadError as exc: @@ -425,6 +441,14 @@ class ChartDataRestApi(ChartRestApi): except ChartDataQueryFailedError as exc: return self.response_400(message=exc.message) + # Log is_cached if extra payload callback is provided + if add_extra_log_payload and result and "queries" in result: + is_cached_values = [query.get("is_cached") for query in result["queries"]] + if len(is_cached_values) == 1: + add_extra_log_payload(is_cached=is_cached_values[0]) + elif is_cached_values: + add_extra_log_payload(is_cached=is_cached_values) + return self._send_chart_response(result, form_data, datasource) # pylint: disable=invalid-name diff --git a/superset/utils/log.py b/superset/utils/log.py index 8e31b86b269..81ea1bab0bf 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -310,11 +310,13 @@ class AbstractEventLogger(ABC): """Decorator that uses the function name as the action""" return self._wrapper(f) - def log_this_with_context(self, **kwargs: Any) -> Callable[..., Any]: + def log_this_with_context( + self, allow_extra_payload: bool = False, **kwargs: Any + ) -> Callable[..., Any]: """Decorator that can override kwargs of log_context""" def func(f: Callable[..., Any]) -> Callable[..., Any]: - return self._wrapper(f, **kwargs) + return self._wrapper(f, allow_extra_payload=allow_extra_payload, **kwargs) return func diff --git a/tests/integration_tests/charts/data/api_tests.py b/tests/integration_tests/charts/data/api_tests.py index ddf3ef44dbc..58a6bce3356 100644 --- a/tests/integration_tests/charts/data/api_tests.py +++ b/tests/integration_tests/charts/data/api_tests.py @@ -782,6 +782,40 @@ class TestPostChartDataApi(BaseTestChartDataApi): patched_run.assert_called_once_with(force_cached=True) assert data == {"result": [{"query": "select * from foo"}]} + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @mock.patch("superset.extensions.event_logger.log") + def test_chart_data_post_is_cached_in_event_logger(self, mock_event_logger): + """ + Chart data API: Test that is_cached is logged to event logger for POST requests + """ + # First request with force=True - should not be cached + payload_with_force = copy.deepcopy(self.query_context_payload) + payload_with_force["force"] = True + self.post_assert_metric(CHART_DATA_URI, payload_with_force, "data") + + # Check that is_cached was logged as None (not from cache) + call_kwargs = mock_event_logger.call_args[1] + records = call_kwargs.get("records", []) + assert len(records) > 0 + # is_cached should be None when force=True (bypasses cache) + assert "is_cached" in records[0] + assert records[0]["is_cached"] is None + + # Reset mock for second request + mock_event_logger.reset_mock() + + # Second request without force - should be cached + payload_without_force = copy.deepcopy(self.query_context_payload) + payload_without_force["force"] = False + self.post_assert_metric(CHART_DATA_URI, payload_without_force, "data") + + # Check that is_cached was logged as True (from cache) + call_kwargs = mock_event_logger.call_args[1] + records = call_kwargs.get("records", []) + assert len(records) > 0 + # is_cached should be True when retrieved from cache + assert records[0]["is_cached"] is True + @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_chart_data_async_results_type(self): @@ -1238,6 +1272,69 @@ class TestGetChartDataApi(BaseTestChartDataApi): rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") assert rv.json["result"][0]["is_cached"] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + @mock.patch("superset.extensions.event_logger.log") + def test_chart_data_is_cached_in_event_logger(self, mock_event_logger): + """ + Chart data API: Test that is_cached is logged to event logger + """ + chart = db.session.query(Slice).filter_by(slice_name="Genders").one() + chart.query_context = json.dumps( + { + "datasource": {"id": chart.table.id, "type": "table"}, + "force": False, + "queries": [ + { + "time_range": "1900-01-01T00:00:00 : 2000-01-01T00:00:00", + "granularity": "ds", + "filters": [], + "extras": { + "having": "", + "where": "", + }, + "applied_time_extras": {}, + "columns": ["gender"], + "metrics": ["sum__num"], + "orderby": [["sum__num", False]], + "annotation_layers": [], + "row_limit": 50000, + "timeseries_limit": 0, + "order_desc": True, + "url_params": {}, + "custom_params": {}, + "custom_form_data": {}, + } + ], + "result_format": "json", + "result_type": "full", + } + ) + + # First request - should not be cached (force=true bypasses cache) + self.get_assert_metric(f"api/v1/chart/{chart.id}/data/?force=true", "get_data") + + # Check that is_cached was logged as None (not from cache) + call_kwargs = mock_event_logger.call_args[1] + records = call_kwargs.get("records", []) + assert len(records) > 0 + # is_cached should be None when force=true (bypasses cache) + # The field should exist but be None + assert "is_cached" in records[0] + assert records[0]["is_cached"] is None + + # Reset mock for second request + mock_event_logger.reset_mock() + + # Second request - should be cached + self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data") + + # Check that is_cached was logged as True (from cache) + call_kwargs = mock_event_logger.call_args[1] + records = call_kwargs.get("records", []) + assert len(records) > 0 + # is_cached should be True when retrieved from cache + assert records[0]["is_cached"] is True + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") @with_feature_flags(GLOBAL_ASYNC_QUERIES=True) @mock.patch("superset.charts.data.api.QueryContextCacheLoader") diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py index 4a9ca8dde64..f87b5a782e3 100644 --- a/tests/integration_tests/event_logger_tests.py +++ b/tests/integration_tests/event_logger_tests.py @@ -230,3 +230,27 @@ class TestEventLogger(unittest.TestCase): ) assert logger.records[0]["user_id"] == None # noqa: E711 + + @patch.object(DBEventLogger, "log") + def test_log_this_with_context_and_extra_payload(self, mock_log): + logger = DBEventLogger() + + @logger.log_this_with_context(action="test_action", allow_extra_payload=True) + def test_func(arg1, add_extra_log_payload, karg1=1): + time.sleep(0.1) + add_extra_log_payload(custom_field="custom_value") + return arg1 * karg1 + + with app.test_request_context(): + result = test_func(1, karg1=2) # pylint: disable=no-value-for-parameter + payload = mock_log.call_args[1] + assert result == 2 + assert payload["records"] == [ + { + "custom_field": "custom_value", + "path": "/", + "karg1": 2, + "object_ref": test_func.__qualname__, + } + ] + assert payload["duration_ms"] >= 100