Files
superset2/superset/charts/data/api.py
2026-04-20 19:36:03 -07:00

763 lines
30 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import contextlib
import logging
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING
from flask import current_app as app, g, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
from werkzeug.utils import secure_filename
from superset import is_feature_enabled, security_manager
from superset.async_events.async_query_manager import AsyncQueryTokenException
from superset.charts.api import ChartRestApi
from superset.charts.client_processing import apply_client_processing
from superset.charts.data.dashboard_filter_context import (
DashboardFilterContext,
get_dashboard_filter_context,
)
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.chart.data.create_async_job_command import (
CreateAsyncChartDataJobCommand,
)
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.data.streaming_export_command import (
StreamingCSVExportCommand,
)
from superset.commands.chart.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.sqla.models import BaseDatasource
from superset.constants import (
CACHE_DISABLED_TIMEOUT,
EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS,
EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS,
)
from superset.daos.exceptions import DatasourceNotFound
from superset.exceptions import QueryObjectValidationError, SupersetSecurityException
from superset.extensions import event_logger
from superset.models.sql_lab import Query
from superset.utils import json
from superset.utils.core import (
create_zip,
DatasourceType,
get_user_id,
)
from superset.utils.decorators import logs_context
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
from superset.views.base_api import statsd_metrics
if TYPE_CHECKING:
from superset.common.query_context import QueryContext
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
class ChartDataRestApi(ChartRestApi):
    """
    Chart data endpoints layered on top of ``ChartRestApi``.

    Three routes are defined in this class: ``get_data`` (GET, runs the query
    context saved with a chart), ``data`` (POST, runs a query context sent by
    the client) and ``data_from_cache`` (GET, runs a query context previously
    stored under a cache key).
    """

    # NOTE(review): presumably limits the routes exposed by this class to the
    # three handlers named here - confirm against Flask-AppBuilder.
    include_route_methods = {"get_data", "data", "data_from_cache"}
@expose("/<int:pk>/data/", methods=("GET",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
    log_to_statsd=False,
    allow_extra_payload=True,
)
def get_data(  # noqa: C901
    self,
    pk: int,
    add_extra_log_payload: Callable[..., None] = lambda **kwargs: None,
) -> Response:
    """
    Take a chart ID and use the query context stored when the chart was saved
    to return payload data response.
    ---
    get:
      summary: Return payload data response for a chart
      description: >-
        Takes a chart ID and uses the query context stored when the chart was saved
        to return payload data response. When filters_dashboard_id is provided,
        the chart's compiled SQL includes in-scope dashboard filter
        default values.
      parameters:
      - in: path
        schema:
          type: integer
        name: pk
        description: The chart ID
      - in: query
        name: format
        description: The format in which the data should be returned
        schema:
          type: string
      - in: query
        name: type
        description: The type in which the data should be returned
        schema:
          type: string
      - in: query
        name: force
        description: Should the queries be forced to load from the source
        schema:
          type: boolean
      - in: query
        name: filters_dashboard_id
        description: >-
          Dashboard ID whose filter defaults should be applied to the
          chart's query context. The chart must belong to the specified dashboard.
          Only in-scope filters with static default values are applied; filters that
          require a database query (e.g. defaultToFirstItem) or have no default are
          reported in the dashboard_filters response metadata.
        schema:
          type: integer
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        202:
          description: Async job details
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        403:
          $ref: '#/components/responses/403'
        404:
          $ref: '#/components/responses/404'
        500:
          $ref: '#/components/responses/500'
    """
    chart = self.datamodel.get(pk, self._base_filters)
    if not chart:
        return self.response_404()

    try:
        json_body = json.loads(chart.query_context)
    except (TypeError, json.JSONDecodeError):
        json_body = None

    # The overrides below assign keys on json_body, so anything other than a
    # JSON object (missing, null, list, scalar) is unusable and is rejected
    # here instead of failing later with a TypeError.
    if not isinstance(json_body, dict):
        return self.response_400(
            message=_(
                "Chart has no query context saved. Please save the chart again."
            )
        )

    # override saved query context
    json_body["result_format"] = request.args.get(
        "format", ChartDataResultFormat.JSON
    )
    json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
    json_body["force"] = request.args.get("force")

    # Apply dashboard filter context when filters_dashboard_id is provided
    dashboard_filter_context: DashboardFilterContext | None = None
    if "filters_dashboard_id" in request.args:
        raw = request.args.get("filters_dashboard_id")
        try:
            filters_dashboard_id = int(raw)
        except (ValueError, TypeError):
            return self.response_400(
                message=_("filters_dashboard_id must be an integer")
            )

        try:
            dashboard_filter_context = get_dashboard_filter_context(
                dashboard_id=filters_dashboard_id,
                chart_id=pk,
            )
        except ValueError as error:
            return self.response_400(message=str(error))
        except SupersetSecurityException:
            return self.response_403()

        if dashboard_filter_context.extra_form_data:
            self._apply_dashboard_extra_form_data(
                json_body, dashboard_filter_context.extra_form_data
            )

    # We need to apply the form data to the global context as jinja
    # templating pulls form data from the request globally, so this
    # fallback ensures it has the filters and extra_form_data applied
    # when used in get_sqla_query which constructs the final query.
    # Jinja macros like metric() resolve dataset context from g.form_data
    # when not given an explicit dataset_id. For GET requests there is no
    # JSON body, so we must always expose the saved query context here.
    g.form_data = json_body

    try:
        query_context = self._create_query_context_from_form(json_body)
        command = ChartDataCommand(query_context)
        command.validate()
    except DatasourceNotFound:
        return self.response_404()
    except QueryObjectValidationError as error:
        return self.response_400(message=error.message)
    except ValidationError as error:
        return self.response_400(
            message=_(
                "Request is incorrect: %(error)s", error=error.normalized_messages()
            )
        )

    # TODO: support CSV, SQL query and other non-JSON types
    # Don't use async queries when cache is disabled (cache_timeout=-1)
    # as async queries depend on caching to retrieve results
    cache_timeout = query_context.get_cache_timeout()
    use_async = (
        is_feature_enabled("GLOBAL_ASYNC_QUERIES")
        and query_context.result_format == ChartDataResultFormat.JSON
        and query_context.result_type == ChartDataResultType.FULL
        and cache_timeout != CACHE_DISABLED_TIMEOUT
    )
    if use_async:
        return self._run_async(json_body, command, add_extra_log_payload)

    # The chart's saved params act as the form data for post-processing;
    # unreadable params fall back to an empty form.
    try:
        form_data = json.loads(chart.params)
    except (TypeError, json.JSONDecodeError):
        form_data = {}

    return self._get_data_response(
        command=command,
        form_data=form_data,
        datasource=query_context.datasource,
        add_extra_log_payload=add_extra_log_payload,
        dashboard_filter_context=dashboard_filter_context,
    )


@staticmethod
def _apply_dashboard_extra_form_data(
    json_body: dict[str, Any],
    extra_form_data: dict[str, Any],
) -> None:
    """
    Merge dashboard filter defaults into every query of a saved query context.

    ``json_body`` is modified in place. For each entry under ``queries``:
    the dashboard filters are appended to the query's own filters (each
    tagged with ``isExtra``), keys listed in
    ``EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS`` are copied into ``extras``, keys
    listed in ``EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS`` overwrite the
    mapped query key, and the whole ``extra_form_data`` dict is attached to
    the query.

    :param json_body: The saved query context, decoded from JSON
    :param extra_form_data: The dashboard filter context's extra form data
    """
    extra_filters = extra_form_data.get("filters", [])
    # ``or []`` keeps a null ``queries`` entry from breaking the loop.
    for query in json_body.get("queries") or []:
        if extra_filters:
            existing = query.get("filters") or []
            query["filters"] = existing + [
                {**f, "isExtra": True} for f in extra_filters
            ]

        extras = query.get("extras") or {}
        for key in EXTRA_FORM_DATA_OVERRIDE_EXTRA_KEYS:
            if key in extra_form_data:
                extras[key] = extra_form_data[key]
        # Only write ``extras`` back when there is something in it, so a
        # query without extras is left without the key.
        if extras:
            query["extras"] = extras

        for (
            src_key,
            target_key,
        ) in EXTRA_FORM_DATA_OVERRIDE_REGULAR_MAPPINGS.items():
            if src_key in extra_form_data:
                query[target_key] = extra_form_data[src_key]

        query["extra_form_data"] = extra_form_data
@expose("/data", methods=("POST",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
    log_to_statsd=False,
    allow_extra_payload=True,
)
def data(  # noqa: C901
    self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None
) -> Response:
    """
    Run a client-built query context and respond with the resulting
    payload data
    ---
    post:
      summary: Return payload data response for the given query
      description: >-
        Takes a query context constructed in the client and returns payload data
        response for the given query.
      requestBody:
        description: >-
          A query context consists of a datasource from which to fetch data
          and one or many query objects.
        required: true
        content:
          application/json:
            schema:
              $ref: "#/components/schemas/ChartDataQueryContextSchema"
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        202:
          description: Async job details
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        500:
          $ref: '#/components/responses/500'
    """
    # The query context arrives either as a JSON body or, for CSV export,
    # as a JSON string inside a regular form field.
    if request.is_json:
        json_body = request.json
    else:
        json_body = None
        encoded_form = request.form.get("form_data")
        if encoded_form:
            try:
                json_body = json.loads(encoded_form)
            except (TypeError, json.JSONDecodeError):
                json_body = None

    if json_body is None:
        return self.response_400(message=_("Request is not JSON"))

    try:
        query_context = self._create_query_context_from_form(json_body)
        command = ChartDataCommand(query_context)
        command.validate()
    except DatasourceNotFound:
        return self.response_404()
    except QueryObjectValidationError as exc:
        return self.response_400(message=exc.message)
    except ValidationError as exc:
        return self.response_400(
            message=_(
                "Request is incorrect: %(error)s", error=exc.normalized_messages()
            )
        )

    # TODO: support CSV, SQL query and other non-JSON types
    # Async execution relies on the cache to hand results back, so it is
    # skipped when caching is disabled (cache_timeout=-1).
    wants_full_json = (
        query_context.result_format == ChartDataResultFormat.JSON
        and query_context.result_type == ChartDataResultType.FULL
    )
    if (
        is_feature_enabled("GLOBAL_ASYNC_QUERIES")
        and wants_full_json
        and query_context.get_cache_timeout() != CACHE_DISABLED_TIMEOUT
    ):
        return self._run_async(json_body, command, add_extra_log_payload)

    chart_form_data = json_body.get("form_data")
    export_filename, export_row_count = self._extract_export_params_from_request()
    return self._get_data_response(
        command,
        form_data=chart_form_data,
        datasource=query_context.datasource,
        add_extra_log_payload=add_extra_log_payload,
        filename=export_filename,
        expected_rows=export_row_count,
    )
@expose("/data/<cache_key>", methods=("GET",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
    action=lambda self, *args, **kwargs: (
        f"{self.__class__.__name__}.data_from_cache"
    ),
    log_to_statsd=False,
)
def data_from_cache(self, cache_key: str) -> Response:
    """
    Look up a cached query context by its key and respond with the
    payload data for that query.
    ---
    get:
      summary: Return payload data response for the given query
      description: >-
        Takes a query context cache key and returns payload data
        response for the given query.
      parameters:
      - in: path
        schema:
          type: string
        name: cache_key
      responses:
        200:
          description: Query result
          content:
            application/json:
              schema:
                $ref: "#/components/schemas/ChartDataResponseSchema"
        400:
          $ref: '#/components/responses/400'
        401:
          $ref: '#/components/responses/401'
        404:
          $ref: '#/components/responses/404'
        422:
          $ref: '#/components/responses/422'
        500:
          $ref: '#/components/responses/500'
    """
    try:
        cached_form = self._load_query_context_form_from_cache(cache_key)
        # Async queries with jinja context fall back to the form data held
        # on the Flask global, so publish it before building the context.
        g.form_data = cached_form
        command = ChartDataCommand(
            self._create_query_context_from_form(cached_form)
        )
        command.validate()
    except ChartDataCacheLoadError:
        # Nothing usable is stored under this key.
        return self.response_404()
    except ValidationError as exc:
        return self.response_400(
            message=_("Request is incorrect: %(error)s", error=exc.messages)
        )

    # Second positional argument is force_cached.
    return self._get_data_response(command, True)
def _run_async(
    self,
    form_data: dict[str, Any],
    command: ChartDataCommand,
    add_extra_log_payload: Callable[..., None] | None = None,
) -> Response:
    """
    Serve the chart query through the async pipeline.

    Unless a refresh is forced, cached results are returned synchronously
    when available. Otherwise a background job is created and its details
    are returned with a 202 status.
    """
    force_refresh = form_data.get("force")
    if not force_refresh:
        try:
            cached_result = command.run(force_cached=True)
            if cached_result is not None:
                # Data was already cached: no async job is started, the
                # response is sent right away, and the is_cached flags are
                # recorded when a logging callback was supplied.
                self._log_is_cached(cached_result, add_extra_log_payload)
                return self._send_chart_response(cached_result)
        except ChartDataCacheLoadError:
            # A cache miss simply falls through to the async job below.
            pass

    # Start a background job for the chart query. Clients poll or get
    # notified on completion and then fetch the results from the
    # /data/<cache_key> endpoint.
    job_command = CreateAsyncChartDataJobCommand()
    try:
        job_command.validate(request)
    except AsyncQueryTokenException:
        return self.response_401()

    job_details = job_command.run(form_data, get_user_id())
    return self.response(202, **job_details)
def _send_chart_response(  # noqa: C901
    self,
    result: dict[Any, Any],
    form_data: dict[str, Any] | None = None,
    datasource: BaseDatasource | Query | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
    dashboard_filter_context: DashboardFilterContext | None = None,
) -> Response:
    """
    Turn a command result into the HTTP response for its result format.

    Table-like formats produce a CSV/XLSX download (streamed CSV, a single
    file, or a zip with one file per query) after an export permission
    check; JSON produces a ``{"result": ...}`` payload, with
    ``dashboard_filters`` added when a dashboard filter context is given.
    Any other format yields a 400 response.
    """
    query_context = result["query_context"]
    result_type = query_context.result_type
    result_format = query_context.result_format

    # Reports built from text charts (e.g. the pivot table) need the same
    # post-processing the chart applies client side.
    if result_type == ChartDataResultType.POST_PROCESSED:
        result = apply_client_processing(result, form_data, datasource)

    if result_format in ChartDataResultFormat.table_like():
        # The permission that gates file export depends on the feature flag.
        export_permission = (
            "can_export_data"
            if is_feature_enabled("GRANULAR_EXPORT_CONTROLS")
            else "can_csv"
        )
        if not security_manager.can_access(export_permission, "Superset"):
            return self.response_403()

        if not result["queries"]:
            return self.response_400(_("Empty query result"))

        is_csv_format = result_format == ChartDataResultFormat.CSV

        # Large CSV exports are streamed instead of built in memory.
        if is_csv_format and self._should_use_streaming(result, form_data):
            return self._create_streaming_csv_response(
                result, form_data, filename=filename, expected_rows=expected_rows
            )

        if len(result["queries"]) == 1:
            # A single query is returned as one plain file.
            single_payload = result["queries"][0]["data"]
            if is_csv_format:
                return CsvResponse(
                    single_payload, headers=generate_download_headers("csv")
                )
            return XlsxResponse(
                single_payload, headers=generate_download_headers("xlsx")
            )

        # Several queries are bundled into a zip, one file per query.
        # CSV text is encoded with the configured export encoding; other
        # formats are stored as they are.
        files = {}
        for position, query in enumerate(result["queries"], start=1):
            content = query["data"]
            if result_format == ChartDataResultFormat.CSV:
                encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8")
                content = content.encode(encoding)
            files[f"query_{position}.{result_format}"] = content

        return Response(
            create_zip(files),
            headers=generate_download_headers("zip"),
            mimetype="application/zip",
        )

    if result_format == ChartDataResultFormat.JSON:
        queries = result["queries"]
        if security_manager.is_guest_user():
            # Guest users do not get the query text back.
            for query in queries:
                query.pop("query", None)

        payload: dict[str, Any] = {"result": queries}
        if dashboard_filter_context is not None:
            payload["dashboard_filters"] = dashboard_filter_context.to_dict()

        with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"):
            serialized = json.dumps(
                payload,
                default=json.json_int_dttm_ser,
                ignore_nan=True,
            )

        json_response = make_response(serialized, 200)
        json_response.headers["Content-Type"] = "application/json; charset=utf-8"
        return json_response

    return self.response_400(message=f"Unsupported result_format: {result_format}")
def _log_is_cached(
self,
result: dict[str, Any],
add_extra_log_payload: Callable[..., None] | None,
) -> None:
"""
Log is_cached values from query results to event logger.
Extracts is_cached from each query in the result and logs it.
If there's a single query, logs the boolean value directly.
If multiple queries, logs as a list.
"""
if add_extra_log_payload and result and "queries" in result:
is_cached_values = [query.get("is_cached") for query in result["queries"]]
if len(is_cached_values) == 1:
add_extra_log_payload(is_cached=is_cached_values[0])
elif is_cached_values:
add_extra_log_payload(is_cached=is_cached_values)
@event_logger.log_this
def _get_data_response(
    self,
    command: ChartDataCommand,
    force_cached: bool = False,
    form_data: dict[str, Any] | None = None,
    datasource: BaseDatasource | Query | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
    add_extra_log_payload: Callable[..., None] | None = None,
    dashboard_filter_context: DashboardFilterContext | None = None,
) -> Response:
    """
    Run the command and build the HTTP response from its result.

    :param command: The validated chart data command to run
    :param force_cached: Passed through to ``command.run``
    :param form_data: Chart form data forwarded to the response builder
    :param datasource: Datasource forwarded to the response builder
    :param filename: Export file name forwarded to the response builder
    :param expected_rows: Expected row count forwarded to the response builder
    :param add_extra_log_payload: Optional callback that receives the
        ``is_cached`` flags of the queries
    :param dashboard_filter_context: Dashboard filter metadata forwarded to
        the response builder
    :returns: A 422 response when the cache cannot be loaded, a 400 response
        when the query fails, otherwise the chart data response
    """
    try:
        result = command.run(force_cached=force_cached)
    except ChartDataCacheLoadError as exc:
        return self.response_422(message=exc.message)
    except ChartDataQueryFailedError as exc:
        return self.response_400(message=exc.message)

    # Use the shared helper so the logged is_cached value has the same shape
    # here as on the async path (bare value for one query, list for several).
    self._log_is_cached(result, add_extra_log_payload)

    return self._send_chart_response(
        result,
        form_data,
        datasource,
        filename,
        expected_rows,
        dashboard_filter_context=dashboard_filter_context,
    )
def _extract_export_params_from_request(self) -> tuple[str | None, int | None]:
    """
    Read the optional ``filename`` and ``expected_rows`` form fields.

    :returns: ``(filename, expected_rows)``; ``expected_rows`` is ``None``
        when the field is absent, empty or not an integer
    """
    filename = request.form.get("filename")
    if filename:
        logger.info("FRONTEND PROVIDED FILENAME: %s", filename)

    expected_rows = None
    raw_expected_rows = request.form.get("expected_rows")
    if raw_expected_rows:
        try:
            expected_rows = int(raw_expected_rows)
        except (ValueError, TypeError):
            # A malformed value is reported and otherwise ignored.
            logger.warning("Invalid expected_rows value: %s", raw_expected_rows)
        else:
            logger.info("FRONTEND PROVIDED EXPECTED ROWS: %d", expected_rows)

    return filename, expected_rows
# pylint: disable=invalid-name
def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
    """
    Return the query context form data stored under ``cache_key``.

    Thin wrapper around ``QueryContextCacheLoader.load``.

    NOTE(review): ``data_from_cache`` answers 404 on ``ChartDataCacheLoadError``
    around this call, so the loader presumably raises it for unknown keys -
    confirm against the loader.
    """
    return QueryContextCacheLoader.load(cache_key)
def _map_form_data_datasource_to_dataset_id(
self, form_data: dict[str, Any]
) -> dict[str, Any]:
return {
"dashboard_id": form_data.get("form_data", {}).get("dashboardId"),
"dataset_id": (
form_data.get("datasource", {}).get("id")
if isinstance(form_data.get("datasource"), dict)
and form_data.get("datasource", {}).get("type")
== DatasourceType.TABLE.value
else None
),
"slice_id": form_data.get("form_data", {}).get("slice_id"),
}
@logs_context(context_func=_map_form_data_datasource_to_dataset_id)
def _create_query_context_from_form(
    self, form_data: dict[str, Any]
) -> QueryContext:
    """
    Build a query context by loading the form data through
    ``ChartDataQueryContextSchema``.

    :param form_data: The chart form data
    :returns: The query context
    :raises ValidationError: If the request is incorrect
    """
    try:
        query_context = ChartDataQueryContextSchema().load(form_data)
    except KeyError as missing_key:
        # A KeyError during loading is reported as a validation failure so
        # callers only have to handle ValidationError.
        raise ValidationError("Request is incorrect") from missing_key
    return query_context
def _should_use_streaming(
self, result: dict[Any, Any], form_data: dict[str, Any] | None = None
) -> bool:
"""Determine if streaming should be used based on actual row count threshold."""
query_context = result["query_context"]
result_format = query_context.result_format
# Only support CSV streaming currently
if result_format.lower() != "csv":
return False
# Get streaming threshold from config
threshold = app.config.get("CSV_STREAMING_ROW_THRESHOLD", 100000)
# Extract actual row count (same logic as frontend)
actual_row_count: int | None = None
viz_type = form_data.get("viz_type") if form_data else None
# For table viz, try to get actual row count from query results
if viz_type == "table" and result.get("queries"):
# Check if we have rowcount in the second query result (like frontend does)
queries = result.get("queries", [])
if len(queries) > 1 and queries[1].get("data"):
data = queries[1]["data"]
if isinstance(data, list) and len(data) > 0:
rowcount = data[0].get("rowcount")
actual_row_count = int(rowcount) if rowcount else None
# Fallback to row_limit if actual count not available
if actual_row_count is None:
if form_data and "row_limit" in form_data:
row_limit = form_data.get("row_limit", 0)
actual_row_count = int(row_limit) if row_limit else 0
elif query_context.form_data and "row_limit" in query_context.form_data:
row_limit = query_context.form_data.get("row_limit", 0)
actual_row_count = int(row_limit) if row_limit else 0
# Use streaming if row count meets or exceeds threshold
return actual_row_count is not None and actual_row_count >= threshold
def _create_streaming_csv_response(
    self,
    result: dict[Any, Any],
    form_data: dict[str, Any] | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
) -> Response:
    """
    Create a streaming CSV response for large datasets.

    :param result: Command result; its ``query_context`` entry is exported
    :param form_data: Chart form data, used to derive a default file name
        from ``slice_name`` or ``viz_type``
    :param filename: File name supplied by the client, if any
    :param expected_rows: Row count supplied by the client; only logged here
    :returns: A response whose body is produced by the streaming export
        command's generator
    """
    query_context = result["query_context"]

    if filename:
        # The client-supplied name ends up inside a quoted header parameter:
        # drop non-printable characters (CR/LF included) and neutralise the
        # characters that could terminate or escape the quoted string.
        filename = "".join(
            "_" if char in '"\\' else char
            for char in filename
            if char.isprintable()
        )

    # Use filename from frontend if provided, otherwise generate one
    if not filename:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        chart_name = "export"
        if form_data and form_data.get("slice_name"):
            chart_name = form_data["slice_name"]
        elif form_data and form_data.get("viz_type"):
            chart_name = form_data["viz_type"]
        # Sanitize chart name for filename
        filename = secure_filename(f"superset_{chart_name}_{timestamp}.csv")

    logger.info("Creating streaming CSV response: %s", filename)
    if expected_rows:
        logger.info("Using expected_rows from frontend: %d", expected_rows)

    # Execute streaming command
    # TODO: Make chunk size configurable via SUPERSET_CONFIG
    chunk_size = 1024
    command = StreamingCSVExportCommand(query_context, chunk_size)
    command.validate()

    # Get the callable that returns the generator
    csv_generator_callable = command.run()

    # Get encoding from config
    encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8")

    # Create response with streaming headers
    response = Response(
        csv_generator_callable(),  # Call the callable to get generator
        mimetype=f"text/csv; charset={encoding}",
        headers={
            "Content-Disposition": f'attachment; filename="{filename}"',
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
        direct_passthrough=False,  # Flask must iterate generator
    )

    # Keep the generator from being converted into an in-memory sequence so
    # the body is sent as it is produced.
    response.implicit_sequence_conversion = False
    return response