# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import contextlib
import logging
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING

from flask import current_app as app, g, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
from werkzeug.utils import secure_filename

from superset import is_feature_enabled, security_manager
from superset.async_events.async_query_manager import AsyncQueryTokenException
from superset.charts.api import ChartRestApi
from superset.charts.client_processing import apply_client_processing
from superset.charts.data.dashboard_filter_context import (
    DashboardFilterContext,
    get_dashboard_filter_context,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.chart.data.create_async_job_command import (
    CreateAsyncChartDataJobCommand,
)
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.data.streaming_export_command import (
    StreamingCSVExportCommand,
)
from superset.commands.chart.exceptions import (
    ChartDataCacheLoadError,
    ChartDataQueryFailedError,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import (
    CACHE_DISABLED_TIMEOUT,
    EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS,
    EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS,
)
from superset.daos.exceptions import DatasourceNotFound
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.extensions import event_logger
from superset.models.sql_lab import Query
from superset.utils import json
from superset.utils.core import (
    create_zip,
    DatasourceType,
    get_user_id,
)
from superset.utils.decorators import logs_context
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
from superset.views.base_api import statsd_metrics

if TYPE_CHECKING:
    from superset.common.query_context import QueryContext

logger = logging.getLogger(__name__)


class ChartDataRestApi(ChartRestApi):
    """REST endpoints that execute a chart's query context and return data."""

    include_route_methods = {"get_data", "data", "data_from_cache"}

    # NOTE(fix): the route template previously read "//data/" — the
    # "<int:pk>" converter had been stripped, so `pk` could never bind.
    @expose("/<int:pk>/data/", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
        allow_extra_payload=True,
    )
    def get_data(  # noqa: C901
        self,
        pk: int,
        add_extra_log_payload: Callable[..., None] = lambda **kwargs: None,
    ) -> Response:
        """
        Take a chart ID and uses the query context stored when the chart was saved
        to return payload data response.
        ---
        get:
          summary: Return payload data response for a chart
          description: >-
            Takes a chart ID and uses the query context stored when the chart
            was saved to return payload data response. When filters_dashboard_id
            is provided, the chart's compiled SQL includes in scope dashboard
            filter default values.
          parameters:
          - in: path
            schema:
              type: integer
            name: pk
            description: The chart ID
          - in: query
            name: format
            description: The format in which the data should be returned
            schema:
              type: string
          - in: query
            name: type
            description: The type in which the data should be returned
            schema:
              type: string
          - in: query
            name: force
            description: Should the queries be forced to load from the source
            schema:
              type: boolean
          - in: query
            name: filters_dashboard_id
            description: >-
              Dashboard ID whose filter defaults should be applied to the
              chart's query context. The chart must belong to the specified
              dashboard. Only in scope filters with static default values are
              applied; filters that require a database query (I.E.
              defaultToFirstItem) or have no default are reported in the
              dashboard_filters response metadata.
            schema:
              type: integer
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            403:
              $ref: '#/components/responses/403'
            404:
              $ref: '#/components/responses/404'
            500:
              $ref: '#/components/responses/500'
        """
        chart = self.datamodel.get(pk, self._base_filters)
        if not chart:
            return self.response_404()

        try:
            json_body = json.loads(chart.query_context)
        except (TypeError, json.JSONDecodeError):
            json_body = None

        if json_body is None:
            return self.response_400(
                message=_(
                    "Chart has no query context saved. Please save the chart again."
                )
            )

        # override saved query context
        json_body["result_format"] = request.args.get(
            "format", ChartDataResultFormat.JSON
        )
        json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
        json_body["force"] = request.args.get("force")

        # Apply dashboard filter context when filters_dashboard_id is provided
        dashboard_filter_context: DashboardFilterContext | None = None
        if "filters_dashboard_id" in request.args:
            raw = request.args.get("filters_dashboard_id")
            try:
                filters_dashboard_id = int(raw)
            except (ValueError, TypeError):
                return self.response_400(
                    message="filters_dashboard_id must be an integer"
                )
        else:
            filters_dashboard_id = None

        if filters_dashboard_id is not None:
            try:
                dashboard_filter_context = get_dashboard_filter_context(
                    dashboard_id=filters_dashboard_id,
                    chart_id=pk,
                )
            except ValueError as error:
                return self.response_400(message=str(error))
            except SupersetSecurityException:
                return self.response_403()

            if dashboard_filter_context.extra_form_data:
                efd = dashboard_filter_context.extra_form_data
                extra_filters = efd.get("filters", [])
                for query in json_body.get("queries", []):
                    if extra_filters:
                        existing = query.get("filters") or []
                        query["filters"] = existing + [
                            {**f, "isExtra": True} for f in extra_filters
                        ]
                    extras = query.get("extras") or {}
                    for key in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS:
                        if key in efd:
                            extras[key] = efd[key]
                    if extras:
                        query["extras"] = extras
                    for (
                        src_key,
                        target_key,
                    ) in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.items():
                        if src_key in efd:
                            query[target_key] = efd[src_key]
                    query["extra_form_data"] = efd

        # We need to apply the form data to the global context as jinja
        # templating pulls form data from the request globally, so this
        # fallback ensures it has the filters and extra_form_data applied
        # when used in get_sqla_query which constructs the final query.
        # Jinja macros like metric() resolve dataset context from g.form_data
        # when not given an explicit dataset_id. For GET requests there is no
        # JSON body, so we must always expose the saved query context here.
        g.form_data = json_body

        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            return self.response_404()
        except QueryObjectValidationError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        # TODO: support CSV, SQL query and other non-JSON types
        # Don't use async queries when cache is disabled (cache_timeout=-1)
        # as async queries depend on caching to retrieve results
        cache_timeout = query_context.get_cache_timeout()
        use_async = (
            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
            and query_context.result_format == ChartDataResultFormat.JSON
            and query_context.result_type == ChartDataResultType.FULL
            and cache_timeout != CACHE_DISABLED_TIMEOUT
        )
        if use_async:
            return self._run_async(json_body, command, add_extra_log_payload)

        try:
            form_data = json.loads(chart.params)
        except (TypeError, json.JSONDecodeError):
            form_data = {}

        return self._get_data_response(
            command=command,
            form_data=form_data,
            datasource=query_context.datasource,
            add_extra_log_payload=add_extra_log_payload,
            dashboard_filter_context=dashboard_filter_context,
        )

    @expose("/data", methods=("POST",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
        allow_extra_payload=True,
    )
    def data(  # noqa: C901
        self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None
    ) -> Response:
        """
        Take a query context constructed in the client and return payload
        data response for the given query
        ---
        post:
          summary: Return payload data response for the given query
          description: >-
            Takes a query context constructed in the client and returns payload
            data response for the given query.
          requestBody:
            description: >-
              A query context consists of a datasource from which to fetch data
              and one or many query objects.
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/ChartDataQueryContextSchema"
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        json_body = None
        if request.is_json:
            json_body = request.json
        elif request.form.get("form_data"):
            # CSV export submits regular form data
            with contextlib.suppress(TypeError, json.JSONDecodeError):
                json_body = json.loads(request.form["form_data"])

        if json_body is None:
            return self.response_400(message=_("Request is not JSON"))

        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            return self.response_404()
        except QueryObjectValidationError as error:
            return self.response_400(message=error.message)
        except ValidationError as error:
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        # TODO: support CSV, SQL query and other non-JSON types
        # Don't use async queries when cache is disabled (cache_timeout=-1)
        # as async queries depend on caching to retrieve results
        cache_timeout = query_context.get_cache_timeout()
        use_async = (
            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
            and query_context.result_format == ChartDataResultFormat.JSON
            and query_context.result_type == ChartDataResultType.FULL
            and cache_timeout != CACHE_DISABLED_TIMEOUT
        )
        if use_async:
            return self._run_async(json_body, command, add_extra_log_payload)

        form_data = json_body.get("form_data")
        filename, expected_rows = self._extract_export_params_from_request()
        return self._get_data_response(
            command,
            form_data=form_data,
            datasource=query_context.datasource,
            add_extra_log_payload=add_extra_log_payload,
            filename=filename,
            expected_rows=expected_rows,
        )

    # NOTE(fix): the route template previously read "/data/" — the
    # "<cache_key>" converter had been stripped, so `cache_key` could never
    # bind (and the route collided with the POST /data endpoint's path).
    @expose("/data/<cache_key>", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: (
            f"{self.__class__.__name__}.data_from_cache"
        ),
        log_to_statsd=False,
    )
    def data_from_cache(self, cache_key: str) -> Response:
        """
        Take a query context cache key and return payload
        data response for the given query.
        ---
        get:
          summary: Return payload data response for the given query
          description: >-
            Takes a query context cache key and returns payload data
            response for the given query.
          parameters:
          - in: path
            schema:
              type: string
            name: cache_key
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            cached_data = self._load_query_context_form_from_cache(cache_key)
            # Set form_data in Flask Global as it is used as a fallback
            # for async queries with jinja context
            g.form_data = cached_data
            query_context = self._create_query_context_from_form(cached_data)
            command = ChartDataCommand(query_context)
            command.validate()
        except ChartDataCacheLoadError:
            return self.response_404()
        except ValidationError as error:
            return self.response_400(
                message=_("Request is incorrect: %(error)s", error=error.messages)
            )

        return self._get_data_response(command, True)

    def _run_async(
        self,
        form_data: dict[str, Any],
        command: ChartDataCommand,
        add_extra_log_payload: Callable[..., None] | None = None,
    ) -> Response:
        """
        Execute command as an async query.
        """
        # First, look for the chart query results in the cache,
        # but only if we're not forcing a refresh.
        if not form_data.get("force"):
            with contextlib.suppress(ChartDataCacheLoadError):
                result = command.run(force_cached=True)
                if result is not None:
                    # Log is_cached if extra payload callback is provided.
                    # This indicates no async job was triggered - data was already
                    # cached and a synchronous response is being returned immediately.
                    self._log_is_cached(result, add_extra_log_payload)
                    return self._send_chart_response(result)

        # Otherwise, kick off a background job to run the chart query.
        # Clients will either poll or be notified of query completion,
        # at which point they will call the /data/<cache_key> endpoint
        # to retrieve the results.
        async_command = CreateAsyncChartDataJobCommand()
        try:
            async_command.validate(request)
        except AsyncQueryTokenException:
            return self.response_401()

        result = async_command.run(form_data, get_user_id())
        return self.response(202, **result)

    def _send_chart_response(  # noqa: C901
        self,
        result: dict[Any, Any],
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
        dashboard_filter_context: DashboardFilterContext | None = None,
    ) -> Response:
        """Render a command result as JSON, CSV/XLSX download, or zip bundle."""
        result_type = result["query_context"].result_type
        result_format = result["query_context"].result_format

        # Post-process the data so it matches the data presented in the chart.
        # This is needed for sending reports based on text charts that do the
        # post-processing of data, eg, the pivot table.
        if result_type == ChartDataResultType.POST_PROCESSED:
            result = apply_client_processing(result, form_data, datasource)

        if result_format in ChartDataResultFormat.table_like():
            # Verify user has permission to export file
            if is_feature_enabled("GRANULAR_EXPORT_CONTROLS"):
                has_export_perm = security_manager.can_access(
                    "can_export_data", "Superset"
                )
            else:
                has_export_perm = security_manager.can_access("can_csv", "Superset")
            if not has_export_perm:
                return self.response_403()

            if not result["queries"]:
                return self.response_400(_("Empty query result"))

            is_csv_format = result_format == ChartDataResultFormat.CSV

            # Check if we should use streaming for large datasets
            if is_csv_format and self._should_use_streaming(result, form_data):
                return self._create_streaming_csv_response(
                    result, form_data, filename=filename, expected_rows=expected_rows
                )

            if len(result["queries"]) == 1:
                # return single query results
                data = result["queries"][0]["data"]
                if is_csv_format:
                    return CsvResponse(data, headers=generate_download_headers("csv"))
                return XlsxResponse(data, headers=generate_download_headers("xlsx"))

            # return multi-query results bundled as a zip file
            def _process_data(query_data: Any) -> Any:
                if result_format == ChartDataResultFormat.CSV:
                    encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8")
                    return query_data.encode(encoding)
                return query_data

            files = {
                f"query_{idx + 1}.{result_format}": _process_data(query["data"])
                for idx, query in enumerate(result["queries"])
            }
            return Response(
                create_zip(files),
                headers=generate_download_headers("zip"),
                mimetype="application/zip",
            )

        if result_format == ChartDataResultFormat.JSON:
            queries = result["queries"]
            # Guest (embedded) users must not see the underlying SQL.
            if security_manager.is_guest_user():
                for query in queries:
                    query.pop("query", None)
            payload: dict[str, Any] = {"result": queries}
            if dashboard_filter_context is not None:
                payload["dashboard_filters"] = dashboard_filter_context.to_dict()
            with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"):
                response_data = json.dumps(
                    payload,
                    default=json.json_int_dttm_ser,
                    ignore_nan=True,
                )
            resp = make_response(response_data, 200)
            resp.headers["Content-Type"] = "application/json; charset=utf-8"
            return resp

        return self.response_400(message=f"Unsupported result_format: {result_format}")

    def _log_is_cached(
        self,
        result: dict[str, Any],
        add_extra_log_payload: Callable[..., None] | None,
    ) -> None:
        """
        Log is_cached values from query results to event logger.

        Extracts is_cached from each query in the result and logs it.
        If there's a single query, logs the boolean value directly.
        If multiple queries, logs as a list.
        """
        if add_extra_log_payload and result and "queries" in result:
            is_cached_values = [query.get("is_cached") for query in result["queries"]]
            if len(is_cached_values) == 1:
                add_extra_log_payload(is_cached=is_cached_values[0])
            elif is_cached_values:
                add_extra_log_payload(is_cached=is_cached_values)

    @event_logger.log_this
    def _get_data_response(
        self,
        command: ChartDataCommand,
        force_cached: bool = False,
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
        add_extra_log_payload: Callable[..., None] | None = None,
        dashboard_filter_context: DashboardFilterContext | None = None,
    ) -> Response:
        """Get data response and optionally log is_cached information."""
        try:
            result = command.run(force_cached=force_cached)
        except ChartDataCacheLoadError as exc:
            return self.response_422(message=exc.message)
        except ChartDataQueryFailedError as exc:
            return self.response_400(message=exc.message)

        # Log is_cached if extra payload callback is provided
        if add_extra_log_payload and result and "queries" in result:
            is_cached_values = [query.get("is_cached") for query in result["queries"]]
            add_extra_log_payload(is_cached=is_cached_values)

        return self._send_chart_response(
            result,
            form_data,
            datasource,
            filename,
            expected_rows,
            dashboard_filter_context=dashboard_filter_context,
        )

    def _extract_export_params_from_request(self) -> tuple[str | None, int | None]:
        """Extract filename and expected_rows from request for streaming exports."""
        filename = request.form.get("filename")
        if filename:
            logger.info("FRONTEND PROVIDED FILENAME: %s", filename)

        expected_rows = None
        if expected_rows_str := request.form.get("expected_rows"):
            try:
                expected_rows = int(expected_rows_str)
                logger.info("FRONTEND PROVIDED EXPECTED ROWS: %d", expected_rows)
            except (ValueError, TypeError):
                logger.warning("Invalid expected_rows value: %s", expected_rows_str)

        return filename, expected_rows

    # pylint: disable=invalid-name
    def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
        """Load the serialized query context form data stored under cache_key."""
        return QueryContextCacheLoader.load(cache_key)

    def _map_form_data_datasource_to_dataset_id(
        self, form_data: dict[str, Any]
    ) -> dict[str, Any]:
        """Extract dashboard/dataset/slice ids from form data for log context."""
        return {
            "dashboard_id": form_data.get("form_data", {}).get("dashboardId"),
            "dataset_id": (
                form_data.get("datasource", {}).get("id")
                if isinstance(form_data.get("datasource"), dict)
                and form_data.get("datasource", {}).get("type")
                == DatasourceType.TABLE.value
                else None
            ),
            "slice_id": form_data.get("form_data", {}).get("slice_id"),
        }

    @logs_context(context_func=_map_form_data_datasource_to_dataset_id)
    def _create_query_context_from_form(
        self, form_data: dict[str, Any]
    ) -> QueryContext:
        """
        Create the query context from the form data.

        :param form_data: The chart form data
        :returns: The query context
        :raises ValidationError: If the request is incorrect
        """
        try:
            return ChartDataQueryContextSchema().load(form_data)
        except KeyError as ex:
            raise ValidationError("Request is incorrect") from ex

    def _should_use_streaming(
        self, result: dict[Any, Any], form_data: dict[str, Any] | None = None
    ) -> bool:
        """Determine if streaming should be used based on actual row count threshold."""
        query_context = result["query_context"]
        result_format = query_context.result_format

        # Only support CSV streaming currently
        if result_format.lower() != "csv":
            return False

        # Get streaming threshold from config
        threshold = app.config.get("CSV_STREAMING_ROW_THRESHOLD", 100000)

        # Extract actual row count (same logic as frontend)
        actual_row_count: int | None = None
        viz_type = form_data.get("viz_type") if form_data else None

        # For table viz, try to get actual row count from query results
        if viz_type == "table" and result.get("queries"):
            # Check if we have rowcount in the second query result (like frontend does)
            queries = result.get("queries", [])
            if len(queries) > 1 and queries[1].get("data"):
                data = queries[1]["data"]
                if isinstance(data, list) and len(data) > 0:
                    rowcount = data[0].get("rowcount")
                    actual_row_count = int(rowcount) if rowcount else None

        # Fallback to row_limit if actual count not available
        if actual_row_count is None:
            if form_data and "row_limit" in form_data:
                row_limit = form_data.get("row_limit", 0)
                actual_row_count = int(row_limit) if row_limit else 0
            elif query_context.form_data and "row_limit" in query_context.form_data:
                row_limit = query_context.form_data.get("row_limit", 0)
                actual_row_count = int(row_limit) if row_limit else 0

        # Use streaming if row count meets or exceeds threshold
        return actual_row_count is not None and actual_row_count >= threshold

    def _create_streaming_csv_response(
        self,
        result: dict[Any, Any],
        form_data: dict[str, Any] | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
    ) -> Response:
        """Create a streaming CSV response for large datasets."""
        query_context = result["query_context"]

        # Use filename from frontend if provided, otherwise generate one
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            chart_name = "export"
            if form_data and form_data.get("slice_name"):
                chart_name = form_data["slice_name"]
            elif form_data and form_data.get("viz_type"):
                chart_name = form_data["viz_type"]
            # Sanitize chart name for filename
            filename = secure_filename(f"superset_{chart_name}_{timestamp}.csv")

        logger.info("Creating streaming CSV response: %s", filename)
        if expected_rows:
            logger.info("Using expected_rows from frontend: %d", expected_rows)

        # Execute streaming command
        # TODO: Make chunk size configurable via SUPERSET_CONFIG
        chunk_size = 1024
        command = StreamingCSVExportCommand(query_context, chunk_size)
        command.validate()

        # Get the callable that returns the generator
        csv_generator_callable = command.run()

        # Get encoding from config
        encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8")

        # Create response with streaming headers
        response = Response(
            csv_generator_callable(),  # Call the callable to get generator
            mimetype=f"text/csv; charset={encoding}",
            headers={
                # Fix: this header previously contained a corrupted literal
                # instead of interpolating the computed filename, so every
                # streamed export downloaded under a bogus fixed name.
                "Content-Disposition": f'attachment; filename="{filename}"',
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
            direct_passthrough=False,  # Flask must iterate generator
        )

        # Force chunked transfer encoding
        response.implicit_sequence_conversion = False

        return response