Files
superset2/superset/charts/data/api.py

645 lines
25 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations
import contextlib
import logging
from datetime import datetime
from typing import Any, Callable, TYPE_CHECKING
from flask import current_app as app, g, make_response, request, Response
from flask_appbuilder.api import expose, protect
from flask_babel import gettext as _
from marshmallow import ValidationError
from werkzeug.utils import secure_filename
from superset import is_feature_enabled, security_manager
from superset.async_events.async_query_manager import AsyncQueryTokenException
from superset.charts.api import ChartRestApi
from superset.charts.client_processing import apply_client_processing
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.commands.chart.data.create_async_job_command import (
CreateAsyncChartDataJobCommand,
)
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.data.streaming_export_command import (
StreamingCSVExportCommand,
)
from superset.commands.chart.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.connectors.sqla.models import BaseDatasource
from superset.daos.exceptions import DatasourceNotFound
from superset.exceptions import QueryObjectValidationError
from superset.extensions import event_logger
from superset.models.sql_lab import Query
from superset.utils import json
from superset.utils.core import (
create_zip,
DatasourceType,
get_user_id,
)
from superset.utils.decorators import logs_context
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
from superset.views.base_api import statsd_metrics
if TYPE_CHECKING:
    # Imported only for static type checking; keeps the query-context
    # machinery out of the runtime import graph.
    from superset.common.query_context import QueryContext

logger = logging.getLogger(__name__)
class ChartDataRestApi(ChartRestApi):
    """REST endpoints that return (or export) the data behind a chart."""

    # Only these view methods are registered as routes for this API
    # (Flask-AppBuilder `include_route_methods` semantics).
    include_route_methods = {"get_data", "data", "data_from_cache"}
@expose("/<int:pk>/data/", methods=("GET",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
log_to_statsd=False,
allow_extra_payload=True,
)
def get_data( # noqa: C901
self,
pk: int,
add_extra_log_payload: Callable[..., None] = lambda **kwargs: None,
) -> Response:
"""
Take a chart ID and uses the query context stored when the chart was saved
to return payload data response.
---
get:
summary: Return payload data response for a chart
description: >-
Takes a chart ID and uses the query context stored when the chart was saved
to return payload data response.
parameters:
- in: path
schema:
type: integer
name: pk
description: The chart ID
- in: query
name: format
description: The format in which the data should be returned
schema:
type: string
- in: query
name: type
description: The type in which the data should be returned
schema:
type: string
- in: query
name: force
description: Should the queries be forced to load from the source
schema:
type: boolean
responses:
200:
description: Query result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
202:
description: Async job details
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataAsyncResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
500:
$ref: '#/components/responses/500'
"""
chart = self.datamodel.get(pk, self._base_filters)
if not chart:
return self.response_404()
try:
json_body = json.loads(chart.query_context)
except (TypeError, json.JSONDecodeError):
json_body = None
if json_body is None:
return self.response_400(
message=_(
"Chart has no query context saved. Please save the chart again."
)
)
# override saved query context
json_body["result_format"] = request.args.get(
"format", ChartDataResultFormat.JSON
)
json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
json_body["force"] = request.args.get("force")
try:
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command.validate()
except DatasourceNotFound:
return self.response_404()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
except ValidationError as error:
return self.response_400(
message=_(
"Request is incorrect: %(error)s", error=error.normalized_messages()
)
)
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(json_body, command, add_extra_log_payload)
try:
form_data = json.loads(chart.params)
except (TypeError, json.JSONDecodeError):
form_data = {}
return self._get_data_response(
command=command,
form_data=form_data,
datasource=query_context.datasource,
add_extra_log_payload=add_extra_log_payload,
)
@expose("/data", methods=("POST",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
log_to_statsd=False,
allow_extra_payload=True,
)
def data( # noqa: C901
self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None
) -> Response:
"""
Take a query context constructed in the client and return payload
data response for the given query
---
post:
summary: Return payload data response for the given query
description: >-
Takes a query context constructed in the client and returns payload data
response for the given query.
requestBody:
description: >-
A query context consists of a datasource from which to fetch data
and one or many query objects.
required: true
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataQueryContextSchema"
responses:
200:
description: Query result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
202:
description: Async job details
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataAsyncResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
500:
$ref: '#/components/responses/500'
"""
json_body = None
if request.is_json:
json_body = request.json
elif request.form.get("form_data"):
# CSV export submits regular form data
with contextlib.suppress(TypeError, json.JSONDecodeError):
json_body = json.loads(request.form["form_data"])
if json_body is None:
return self.response_400(message=_("Request is not JSON"))
try:
query_context = self._create_query_context_from_form(json_body)
command = ChartDataCommand(query_context)
command.validate()
except DatasourceNotFound:
return self.response_404()
except QueryObjectValidationError as error:
return self.response_400(message=error.message)
except ValidationError as error:
return self.response_400(
message=_(
"Request is incorrect: %(error)s", error=error.normalized_messages()
)
)
# TODO: support CSV, SQL query and other non-JSON types
if (
is_feature_enabled("GLOBAL_ASYNC_QUERIES")
and query_context.result_format == ChartDataResultFormat.JSON
and query_context.result_type == ChartDataResultType.FULL
):
return self._run_async(json_body, command, add_extra_log_payload)
form_data = json_body.get("form_data")
filename, expected_rows = self._extract_export_params_from_request()
return self._get_data_response(
command,
form_data=form_data,
datasource=query_context.datasource,
add_extra_log_payload=add_extra_log_payload,
filename=filename,
expected_rows=expected_rows,
)
@expose("/data/<cache_key>", methods=("GET",))
@protect()
@statsd_metrics
@event_logger.log_this_with_context(
action=lambda self, *args, **kwargs: f"{self.__class__.__name__}"
f".data_from_cache",
log_to_statsd=False,
)
def data_from_cache(self, cache_key: str) -> Response:
"""
Take a query context cache key and return payload
data response for the given query.
---
get:
summary: Return payload data response for the given query
description: >-
Takes a query context cache key and returns payload data
response for the given query.
parameters:
- in: path
schema:
type: string
name: cache_key
responses:
200:
description: Query result
content:
application/json:
schema:
$ref: "#/components/schemas/ChartDataResponseSchema"
400:
$ref: '#/components/responses/400'
401:
$ref: '#/components/responses/401'
404:
$ref: '#/components/responses/404'
422:
$ref: '#/components/responses/422'
500:
$ref: '#/components/responses/500'
"""
try:
cached_data = self._load_query_context_form_from_cache(cache_key)
# Set form_data in Flask Global as it is used as a fallback
# for async queries with jinja context
g.form_data = cached_data
query_context = self._create_query_context_from_form(cached_data)
command = ChartDataCommand(query_context)
command.validate()
except ChartDataCacheLoadError:
return self.response_404()
except ValidationError as error:
return self.response_400(
message=_("Request is incorrect: %(error)s", error=error.messages)
)
return self._get_data_response(command, True)
def _run_async(
    self,
    form_data: dict[str, Any],
    command: ChartDataCommand,
    add_extra_log_payload: Callable[..., None] | None = None,
) -> Response:
    """
    Execute the chart data command as an async query.

    Serves the payload immediately when it is already cached; otherwise
    schedules a background job and responds with 202 plus job details.
    """
    # First, look for the chart query results in the cache. A cache miss
    # raises ChartDataCacheLoadError and falls through to the async path.
    cached_result = None
    try:
        cached_result = command.run(force_cached=True)
    except ChartDataCacheLoadError:
        pass
    if cached_result is not None:
        # No async job was triggered: data was already cached, so record
        # is_cached and return a synchronous response right away.
        self._log_is_cached(cached_result, add_extra_log_payload)
        return self._send_chart_response(cached_result)

    # Otherwise kick off a background job to run the chart query. Clients
    # will either poll or be notified of completion, then call the
    # /data/<cache_key> endpoint to retrieve the results.
    job_command = CreateAsyncChartDataJobCommand()
    try:
        job_command.validate(request)
    except AsyncQueryTokenException:
        return self.response_401()
    job_metadata = job_command.run(form_data, get_user_id())
    return self.response(202, **job_metadata)
def _send_chart_response(  # noqa: C901
    self,
    result: dict[Any, Any],
    form_data: dict[str, Any] | None = None,
    datasource: BaseDatasource | Query | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
) -> Response:
    """
    Render a command result dict as an HTTP response in the requested format.

    Handles table-like exports (CSV/XLSX for one query, a zip bundle for
    several, streaming CSV for large CSV results) and JSON; any other
    format yields a 400 response.
    """
    result_type = result["query_context"].result_type
    result_format = result["query_context"].result_format

    # Post-process the data so it matches the data presented in the chart.
    # This is needed for sending reports based on text charts that do the
    # post-processing of data, eg, the pivot table.
    if result_type == ChartDataResultType.POST_PROCESSED:
        result = apply_client_processing(result, form_data, datasource)

    if result_format in ChartDataResultFormat.table_like():
        # Verify user has permission to export file
        if not security_manager.can_access("can_csv", "Superset"):
            return self.response_403()

        if not result["queries"]:
            return self.response_400(_("Empty query result"))

        is_csv_format = result_format == ChartDataResultFormat.CSV

        # Check if we should use streaming for large datasets
        if is_csv_format and self._should_use_streaming(result, form_data):
            return self._create_streaming_csv_response(
                result, form_data, filename=filename, expected_rows=expected_rows
            )

        if len(result["queries"]) == 1:
            # return single query results
            data = result["queries"][0]["data"]
            if is_csv_format:
                return CsvResponse(data, headers=generate_download_headers("csv"))

            return XlsxResponse(data, headers=generate_download_headers("xlsx"))

        # return multi-query results bundled as a zip file
        def _process_data(query_data: Any) -> Any:
            # CSV payloads are text and must be encoded with the configured
            # charset before being zipped; other formats pass through as-is.
            if result_format == ChartDataResultFormat.CSV:
                encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8")
                return query_data.encode(encoding)
            return query_data

        files = {
            f"query_{idx + 1}.{result_format}": _process_data(query["data"])
            for idx, query in enumerate(result["queries"])
        }
        return Response(
            create_zip(files),
            headers=generate_download_headers("zip"),
            mimetype="application/zip",
        )

    if result_format == ChartDataResultFormat.JSON:
        queries = result["queries"]
        # Guest users must not receive the underlying SQL query text.
        if security_manager.is_guest_user():
            for query in queries:
                query.pop("query", None)
        with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"):
            response_data = json.dumps(
                {"result": queries},
                default=json.json_int_dttm_ser,
                ignore_nan=True,
            )
        resp = make_response(response_data, 200)
        resp.headers["Content-Type"] = "application/json; charset=utf-8"
        return resp

    return self.response_400(message=f"Unsupported result_format: {result_format}")
def _log_is_cached(
    self,
    result: dict[str, Any],
    add_extra_log_payload: Callable[..., None] | None,
) -> None:
    """
    Report per-query ``is_cached`` flags to the event logger.

    A single-query result is logged as a bare boolean; multi-query
    results are logged as a list. Nothing is logged when no callback is
    supplied, the result is falsy, or it carries no queries.
    """
    if not add_extra_log_payload or not result or "queries" not in result:
        return
    cached_flags = [query.get("is_cached") for query in result["queries"]]
    if not cached_flags:
        return
    payload = cached_flags[0] if len(cached_flags) == 1 else cached_flags
    add_extra_log_payload(is_cached=payload)
@event_logger.log_this
def _get_data_response(
    self,
    command: ChartDataCommand,
    force_cached: bool = False,
    form_data: dict[str, Any] | None = None,
    datasource: BaseDatasource | Query | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
    add_extra_log_payload: Callable[..., None] | None = None,
) -> Response:
    """
    Run ``command`` and translate its result (or failure) into a response.

    :param command: The chart data command to execute
    :param force_cached: Only accept results already present in the cache
    :param form_data: Chart form data, used for client-side post-processing
    :param datasource: The datasource the query ran against
    :param filename: Optional export filename forwarded to streaming CSV
    :param expected_rows: Optional expected row count for streaming CSV
    :param add_extra_log_payload: Optional event-logger callback
    :returns: The chart data response, 422 when the cache could not be
        loaded, or 400 when the query failed
    """
    try:
        result = command.run(force_cached=force_cached)
    except ChartDataCacheLoadError as exc:
        return self.response_422(message=exc.message)
    except ChartDataQueryFailedError as exc:
        return self.response_400(message=exc.message)

    # Delegate is_cached logging to the shared helper instead of
    # duplicating its logic inline. This keeps the sync and async paths
    # consistent: a single query logs a bare boolean, several queries log
    # a list, and an empty query list logs nothing.
    self._log_is_cached(result, add_extra_log_payload)

    return self._send_chart_response(
        result, form_data, datasource, filename, expected_rows
    )
def _extract_export_params_from_request(self) -> tuple[str | None, int | None]:
    """
    Pull the optional ``filename`` and ``expected_rows`` streaming-export
    parameters out of the submitted form data.

    :returns: ``(filename, expected_rows)``; either element is ``None``
        when absent, and ``expected_rows`` is ``None`` when non-numeric.
    """
    filename = request.form.get("filename")
    if filename:
        logger.info("FRONTEND PROVIDED FILENAME: %s", filename)

    expected_rows: int | None = None
    raw_expected_rows = request.form.get("expected_rows")
    if raw_expected_rows:
        try:
            expected_rows = int(raw_expected_rows)
            logger.info("FRONTEND PROVIDED EXPECTED ROWS: %d", expected_rows)
        except (ValueError, TypeError):
            logger.warning("Invalid expected_rows value: %s", raw_expected_rows)

    return filename, expected_rows
# pylint: disable=invalid-name
def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
    """Load the query context form data stored under ``cache_key``."""
    # Thin indirection over the cache loader; data_from_cache() treats a
    # ChartDataCacheLoadError raised here as a missing entry (404).
    return QueryContextCacheLoader.load(cache_key)
def _map_form_data_datasource_to_dataset_id(
    self, form_data: dict[str, Any]
) -> dict[str, Any]:
    """
    Build the logging context (dashboard/dataset/slice ids) from form data.

    The dataset id is only reported when the datasource entry is a dict
    describing a table-type datasource; otherwise it is ``None``.
    """
    inner_form_data = form_data.get("form_data", {})
    datasource = form_data.get("datasource")

    dataset_id = None
    if (
        isinstance(datasource, dict)
        and datasource.get("type") == DatasourceType.TABLE.value
    ):
        dataset_id = datasource.get("id")

    return {
        "dashboard_id": inner_form_data.get("dashboardId"),
        "dataset_id": dataset_id,
        "slice_id": inner_form_data.get("slice_id"),
    }
@logs_context(context_func=_map_form_data_datasource_to_dataset_id)
def _create_query_context_from_form(
    self, form_data: dict[str, Any]
) -> QueryContext:
    """
    Create the query context from the form data.

    :param form_data: The chart form data
    :returns: The query context
    :raises ValidationError: If the request is incorrect
    """
    try:
        return ChartDataQueryContextSchema().load(form_data)
    except KeyError as ex:
        # NOTE(review): schema deserialization can apparently surface a
        # bare KeyError for malformed payloads; normalize it so callers
        # only need to handle ValidationError.
        raise ValidationError("Request is incorrect") from ex
def _should_use_streaming(
    self, result: dict[Any, Any], form_data: dict[str, Any] | None = None
) -> bool:
    """
    Decide whether a CSV export is large enough to warrant streaming.

    Compares the best-available row count (actual rowcount for table
    charts when present, otherwise the configured row_limit) against the
    ``CSV_STREAMING_ROW_THRESHOLD`` config value.
    """
    query_context = result["query_context"]

    # Only CSV streaming is supported at the moment.
    if query_context.result_format.lower() != "csv":
        return False

    threshold = app.config.get("CSV_STREAMING_ROW_THRESHOLD", 100000)

    # Try to determine the actual row count (same logic as the frontend):
    # table charts may carry a rowcount in the second query's payload.
    row_count: int | None = None
    viz_type = form_data.get("viz_type") if form_data else None
    queries = result.get("queries") or []
    if viz_type == "table" and len(queries) > 1:
        payload = queries[1].get("data")
        if isinstance(payload, list) and payload:
            rowcount = payload[0].get("rowcount")
            row_count = int(rowcount) if rowcount else None

    # Fall back to the configured row_limit when no actual count is known.
    if row_count is None:
        limit_source = None
        if form_data and "row_limit" in form_data:
            limit_source = form_data
        elif query_context.form_data and "row_limit" in query_context.form_data:
            limit_source = query_context.form_data
        if limit_source is not None:
            row_limit = limit_source.get("row_limit", 0)
            row_count = int(row_limit) if row_limit else 0

    # Stream only when a row count was found and it meets the threshold.
    return row_count is not None and row_count >= threshold
def _create_streaming_csv_response(
    self,
    result: dict[Any, Any],
    form_data: dict[str, Any] | None = None,
    filename: str | None = None,
    expected_rows: int | None = None,
) -> Response:
    """
    Create a streaming CSV response for large datasets.

    :param result: Command result dict carrying the query context
    :param form_data: Chart form data used to derive a fallback filename
    :param filename: Optional filename supplied by the frontend
    :param expected_rows: Optional expected row count (logging only)
    :returns: A chunked ``text/csv`` response driven by a generator
    """
    query_context = result["query_context"]

    # Use filename from frontend if provided, otherwise generate one
    # from the chart name (or viz type) plus a timestamp.
    if not filename:
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        chart_name = "export"
        if form_data and form_data.get("slice_name"):
            chart_name = form_data["slice_name"]
        elif form_data and form_data.get("viz_type"):
            chart_name = form_data["viz_type"]
        # Sanitize chart name so it is safe inside a filename header.
        filename = secure_filename(f"superset_{chart_name}_{timestamp}.csv")

    logger.info("Creating streaming CSV response: %s", filename)
    if expected_rows:
        logger.info("Using expected_rows from frontend: %d", expected_rows)

    # Execute streaming command
    # TODO: Make chunk size configurable via SUPERSET_CONFIG
    chunk_size = 1024
    command = StreamingCSVExportCommand(query_context, chunk_size)
    command.validate()

    # Get the callable that returns the generator
    csv_generator_callable = command.run()

    # Get encoding from config
    encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8")

    # Create response with streaming headers
    response = Response(
        csv_generator_callable(),  # Call the callable to get generator
        mimetype=f"text/csv; charset={encoding}",
        headers={
            # BUG FIX: the computed/sanitized ``filename`` was previously
            # discarded in favor of a hard-coded placeholder string; use
            # the real filename so downloads are named correctly.
            "Content-Disposition": f'attachment; filename="{filename}"',
            "Cache-Control": "no-cache",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
        direct_passthrough=False,  # Flask must iterate generator
    )
    # Force chunked transfer encoding
    response.implicit_sequence_conversion = False
    return response