# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from __future__ import annotations import contextlib import logging from datetime import datetime from typing import Any, Callable, TYPE_CHECKING from flask import current_app as app, g, make_response, request, Response from flask_appbuilder.api import expose, protect from flask_babel import gettext as _ from marshmallow import ValidationError from werkzeug.utils import secure_filename from superset import is_feature_enabled, security_manager from superset.async_events.async_query_manager import AsyncQueryTokenException from superset.charts.api import ChartRestApi from superset.charts.client_processing import apply_client_processing from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader from superset.charts.schemas import ChartDataQueryContextSchema from superset.commands.chart.data.create_async_job_command import ( CreateAsyncChartDataJobCommand, ) from superset.commands.chart.data.get_data_command import ChartDataCommand from superset.commands.chart.data.streaming_export_command import ( StreamingCSVExportCommand, ) from superset.commands.chart.exceptions import ( ChartDataCacheLoadError, ChartDataQueryFailedError, ) from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType from superset.connectors.sqla.models import BaseDatasource from superset.daos.exceptions import DatasourceNotFound from superset.exceptions import QueryObjectValidationError from superset.extensions import event_logger from superset.models.sql_lab import Query from superset.utils import json from superset.utils.core import ( create_zip, DatasourceType, get_user_id, ) from superset.utils.decorators import logs_context from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse from superset.views.base_api import statsd_metrics if TYPE_CHECKING: from superset.common.query_context import QueryContext logger = logging.getLogger(__name__) class ChartDataRestApi(ChartRestApi): include_route_methods = {"get_data", "data", "data_from_cache"} @expose("//data/", methods=("GET",)) @protect() @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", log_to_statsd=False, allow_extra_payload=True, ) def get_data( # noqa: C901 self, pk: int, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None, ) -> Response: """ Take a chart ID and uses the query context stored when the chart was saved to return payload data response. --- get: summary: Return payload data response for a chart description: >- Takes a chart ID and uses the query context stored when the chart was saved to return payload data response. parameters: - in: path schema: type: integer name: pk description: The chart ID - in: query name: format description: The format in which the data should be returned schema: type: string - in: query name: type description: The type in which the data should be returned schema: type: string - in: query name: force description: Should the queries be forced to load from the source schema: type: boolean responses: 200: description: Query result content: application/json: schema: $ref: "#/components/schemas/ChartDataResponseSchema" 202: description: Async job details content: application/json: schema: $ref: "#/components/schemas/ChartDataAsyncResponseSchema" 400: $ref: '#/components/responses/400' 401: $ref: '#/components/responses/401' 500: $ref: '#/components/responses/500' """ chart = self.datamodel.get(pk, self._base_filters) if not chart: return self.response_404() try: json_body = json.loads(chart.query_context) except (TypeError, json.JSONDecodeError): json_body = None if json_body is None: return self.response_400( message=_( "Chart has no query context saved. Please save the chart again." ) ) # override saved query context json_body["result_format"] = request.args.get( "format", ChartDataResultFormat.JSON ) json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL) json_body["force"] = request.args.get("force") try: query_context = self._create_query_context_from_form(json_body) command = ChartDataCommand(query_context) command.validate() except DatasourceNotFound: return self.response_404() except QueryObjectValidationError as error: return self.response_400(message=error.message) except ValidationError as error: return self.response_400( message=_( "Request is incorrect: %(error)s", error=error.normalized_messages() ) ) # TODO: support CSV, SQL query and other non-JSON types if ( is_feature_enabled("GLOBAL_ASYNC_QUERIES") and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): return self._run_async(json_body, command, add_extra_log_payload) try: form_data = json.loads(chart.params) except (TypeError, json.JSONDecodeError): form_data = {} return self._get_data_response( command=command, form_data=form_data, datasource=query_context.datasource, add_extra_log_payload=add_extra_log_payload, ) @expose("/data", methods=("POST",)) @protect() @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data", log_to_statsd=False, allow_extra_payload=True, ) def data( # noqa: C901 self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None ) -> Response: """ Take a query context constructed in the client and return payload data response for the given query --- post: summary: Return payload data response for the given query description: >- Takes a query context constructed in the client and returns payload data response for the given query. requestBody: description: >- A query context consists of a datasource from which to fetch data and one or many query objects. required: true content: application/json: schema: $ref: "#/components/schemas/ChartDataQueryContextSchema" responses: 200: description: Query result content: application/json: schema: $ref: "#/components/schemas/ChartDataResponseSchema" 202: description: Async job details content: application/json: schema: $ref: "#/components/schemas/ChartDataAsyncResponseSchema" 400: $ref: '#/components/responses/400' 401: $ref: '#/components/responses/401' 500: $ref: '#/components/responses/500' """ json_body = None if request.is_json: json_body = request.json elif request.form.get("form_data"): # CSV export submits regular form data with contextlib.suppress(TypeError, json.JSONDecodeError): json_body = json.loads(request.form["form_data"]) if json_body is None: return self.response_400(message=_("Request is not JSON")) try: query_context = self._create_query_context_from_form(json_body) command = ChartDataCommand(query_context) command.validate() except DatasourceNotFound: return self.response_404() except QueryObjectValidationError as error: return self.response_400(message=error.message) except ValidationError as error: return self.response_400( message=_( "Request is incorrect: %(error)s", error=error.normalized_messages() ) ) # TODO: support CSV, SQL query and other non-JSON types if ( is_feature_enabled("GLOBAL_ASYNC_QUERIES") and query_context.result_format == ChartDataResultFormat.JSON and query_context.result_type == ChartDataResultType.FULL ): return self._run_async(json_body, command, add_extra_log_payload) form_data = json_body.get("form_data") filename, expected_rows = self._extract_export_params_from_request() return self._get_data_response( command, form_data=form_data, datasource=query_context.datasource, add_extra_log_payload=add_extra_log_payload, filename=filename, expected_rows=expected_rows, ) @expose("/data/", methods=("GET",)) @protect() @statsd_metrics @event_logger.log_this_with_context( action=lambda self, *args, **kwargs: f"{self.__class__.__name__}" f".data_from_cache", log_to_statsd=False, ) def data_from_cache(self, cache_key: str) -> Response: """ Take a query context cache key and return payload data response for the given query. --- get: summary: Return payload data response for the given query description: >- Takes a query context cache key and returns payload data response for the given query. parameters: - in: path schema: type: string name: cache_key responses: 200: description: Query result content: application/json: schema: $ref: "#/components/schemas/ChartDataResponseSchema" 400: $ref: '#/components/responses/400' 401: $ref: '#/components/responses/401' 404: $ref: '#/components/responses/404' 422: $ref: '#/components/responses/422' 500: $ref: '#/components/responses/500' """ try: cached_data = self._load_query_context_form_from_cache(cache_key) # Set form_data in Flask Global as it is used as a fallback # for async queries with jinja context g.form_data = cached_data query_context = self._create_query_context_from_form(cached_data) command = ChartDataCommand(query_context) command.validate() except ChartDataCacheLoadError: return self.response_404() except ValidationError as error: return self.response_400( message=_("Request is incorrect: %(error)s", error=error.messages) ) return self._get_data_response(command, True) def _run_async( self, form_data: dict[str, Any], command: ChartDataCommand, add_extra_log_payload: Callable[..., None] | None = None, ) -> Response: """ Execute command as an async query. """ # First, look for the chart query results in the cache. with contextlib.suppress(ChartDataCacheLoadError): result = command.run(force_cached=True) if result is not None: # Log is_cached if extra payload callback is provided. # This indicates no async job was triggered - data was already cached # and a synchronous response is being returned immediately. self._log_is_cached(result, add_extra_log_payload) return self._send_chart_response(result) # Otherwise, kick off a background job to run the chart query. # Clients will either poll or be notified of query completion, # at which point they will call the /data/ endpoint # to retrieve the results. async_command = CreateAsyncChartDataJobCommand() try: async_command.validate(request) except AsyncQueryTokenException: return self.response_401() result = async_command.run(form_data, get_user_id()) return self.response(202, **result) def _send_chart_response( # noqa: C901 self, result: dict[Any, Any], form_data: dict[str, Any] | None = None, datasource: BaseDatasource | Query | None = None, filename: str | None = None, expected_rows: int | None = None, ) -> Response: result_type = result["query_context"].result_type result_format = result["query_context"].result_format # Post-process the data so it matches the data presented in the chart. # This is needed for sending reports based on text charts that do the # post-processing of data, eg, the pivot table. if result_type == ChartDataResultType.POST_PROCESSED: result = apply_client_processing(result, form_data, datasource) if result_format in ChartDataResultFormat.table_like(): # Verify user has permission to export file if not security_manager.can_access("can_csv", "Superset"): return self.response_403() if not result["queries"]: return self.response_400(_("Empty query result")) is_csv_format = result_format == ChartDataResultFormat.CSV # Check if we should use streaming for large datasets if is_csv_format and self._should_use_streaming(result, form_data): return self._create_streaming_csv_response( result, form_data, filename=filename, expected_rows=expected_rows ) if len(result["queries"]) == 1: # return single query results data = result["queries"][0]["data"] if is_csv_format: return CsvResponse(data, headers=generate_download_headers("csv")) return XlsxResponse(data, headers=generate_download_headers("xlsx")) # return multi-query results bundled as a zip file def _process_data(query_data: Any) -> Any: if result_format == ChartDataResultFormat.CSV: encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8") return query_data.encode(encoding) return query_data files = { f"query_{idx + 1}.{result_format}": _process_data(query["data"]) for idx, query in enumerate(result["queries"]) } return Response( create_zip(files), headers=generate_download_headers("zip"), mimetype="application/zip", ) if result_format == ChartDataResultFormat.JSON: queries = result["queries"] if security_manager.is_guest_user(): for query in queries: query.pop("query", None) with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"): response_data = json.dumps( {"result": queries}, default=json.json_int_dttm_ser, ignore_nan=True, ) resp = make_response(response_data, 200) resp.headers["Content-Type"] = "application/json; charset=utf-8" return resp return self.response_400(message=f"Unsupported result_format: {result_format}") def _log_is_cached( self, result: dict[str, Any], add_extra_log_payload: Callable[..., None] | None, ) -> None: """ Log is_cached values from query results to event logger. Extracts is_cached from each query in the result and logs it. If there's a single query, logs the boolean value directly. If multiple queries, logs as a list. """ if add_extra_log_payload and result and "queries" in result: is_cached_values = [query.get("is_cached") for query in result["queries"]] if len(is_cached_values) == 1: add_extra_log_payload(is_cached=is_cached_values[0]) elif is_cached_values: add_extra_log_payload(is_cached=is_cached_values) @event_logger.log_this def _get_data_response( self, command: ChartDataCommand, force_cached: bool = False, form_data: dict[str, Any] | None = None, datasource: BaseDatasource | Query | None = None, filename: str | None = None, expected_rows: int | None = None, add_extra_log_payload: Callable[..., None] | None = None, ) -> Response: """Get data response and optionally log is_cached information.""" try: result = command.run(force_cached=force_cached) except ChartDataCacheLoadError as exc: return self.response_422(message=exc.message) except ChartDataQueryFailedError as exc: return self.response_400(message=exc.message) # Log is_cached if extra payload callback is provided if add_extra_log_payload and result and "queries" in result: is_cached_values = [query.get("is_cached") for query in result["queries"]] add_extra_log_payload(is_cached=is_cached_values) return self._send_chart_response( result, form_data, datasource, filename, expected_rows ) def _extract_export_params_from_request(self) -> tuple[str | None, int | None]: """Extract filename and expected_rows from request for streaming exports.""" filename = request.form.get("filename") if filename: logger.info("FRONTEND PROVIDED FILENAME: %s", filename) expected_rows = None if expected_rows_str := request.form.get("expected_rows"): try: expected_rows = int(expected_rows_str) logger.info("FRONTEND PROVIDED EXPECTED ROWS: %d", expected_rows) except (ValueError, TypeError): logger.warning("Invalid expected_rows value: %s", expected_rows_str) return filename, expected_rows # pylint: disable=invalid-name def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]: return QueryContextCacheLoader.load(cache_key) def _map_form_data_datasource_to_dataset_id( self, form_data: dict[str, Any] ) -> dict[str, Any]: return { "dashboard_id": form_data.get("form_data", {}).get("dashboardId"), "dataset_id": ( form_data.get("datasource", {}).get("id") if isinstance(form_data.get("datasource"), dict) and form_data.get("datasource", {}).get("type") == DatasourceType.TABLE.value else None ), "slice_id": form_data.get("form_data", {}).get("slice_id"), } @logs_context(context_func=_map_form_data_datasource_to_dataset_id) def _create_query_context_from_form( self, form_data: dict[str, Any] ) -> QueryContext: """ Create the query context from the form data. :param form_data: The chart form data :returns: The query context :raises ValidationError: If the request is incorrect """ try: return ChartDataQueryContextSchema().load(form_data) except KeyError as ex: raise ValidationError("Request is incorrect") from ex def _should_use_streaming( self, result: dict[Any, Any], form_data: dict[str, Any] | None = None ) -> bool: """Determine if streaming should be used based on actual row count threshold.""" query_context = result["query_context"] result_format = query_context.result_format # Only support CSV streaming currently if result_format.lower() != "csv": return False # Get streaming threshold from config threshold = app.config.get("CSV_STREAMING_ROW_THRESHOLD", 100000) # Extract actual row count (same logic as frontend) actual_row_count: int | None = None viz_type = form_data.get("viz_type") if form_data else None # For table viz, try to get actual row count from query results if viz_type == "table" and result.get("queries"): # Check if we have rowcount in the second query result (like frontend does) queries = result.get("queries", []) if len(queries) > 1 and queries[1].get("data"): data = queries[1]["data"] if isinstance(data, list) and len(data) > 0: rowcount = data[0].get("rowcount") actual_row_count = int(rowcount) if rowcount else None # Fallback to row_limit if actual count not available if actual_row_count is None: if form_data and "row_limit" in form_data: row_limit = form_data.get("row_limit", 0) actual_row_count = int(row_limit) if row_limit else 0 elif query_context.form_data and "row_limit" in query_context.form_data: row_limit = query_context.form_data.get("row_limit", 0) actual_row_count = int(row_limit) if row_limit else 0 # Use streaming if row count meets or exceeds threshold return actual_row_count is not None and actual_row_count >= threshold def _create_streaming_csv_response( self, result: dict[Any, Any], form_data: dict[str, Any] | None = None, filename: str | None = None, expected_rows: int | None = None, ) -> Response: """Create a streaming CSV response for large datasets.""" query_context = result["query_context"] # Use filename from frontend if provided, otherwise generate one if not filename: timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") chart_name = "export" if form_data and form_data.get("slice_name"): chart_name = form_data["slice_name"] elif form_data and form_data.get("viz_type"): chart_name = form_data["viz_type"] # Sanitize chart name for filename filename = secure_filename(f"superset_{chart_name}_{timestamp}.csv") logger.info("Creating streaming CSV response: %s", filename) if expected_rows: logger.info("Using expected_rows from frontend: %d", expected_rows) # Execute streaming command # TODO: Make chunk size configurable via SUPERSET_CONFIG chunk_size = 1024 command = StreamingCSVExportCommand(query_context, chunk_size) command.validate() # Get the callable that returns the generator csv_generator_callable = command.run() # Get encoding from config encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8") # Create response with streaming headers response = Response( csv_generator_callable(), # Call the callable to get generator mimetype=f"text/csv; charset={encoding}", headers={ "Content-Disposition": f'attachment; filename="{filename}"', "Cache-Control": "no-cache", "X-Accel-Buffering": "no", # Disable nginx buffering }, direct_passthrough=False, # Flask must iterate generator ) # Force chunked transfer encoding response.implicit_sequence_conversion = False return response