mirror of
https://github.com/apache/superset.git
synced 2026-05-29 20:29:34 +00:00
Add error-level logging to all failure paths in the chart data API to help diagnose intermittent 400 BAD REQUEST failures during CSV exports. Previously, JSON parsing errors were silently swallowed by contextlib.suppress and validation/query errors returned 400 without any logging, making it impossible to identify which failure path was hit. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
732 lines
29 KiB
Python
732 lines
29 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
from __future__ import annotations
|
|
|
|
import contextlib
|
|
import logging
|
|
from datetime import datetime
|
|
from typing import Any, Callable, TYPE_CHECKING
|
|
|
|
from flask import current_app as app, g, make_response, request, Response
|
|
from flask_appbuilder.api import expose, protect
|
|
from flask_babel import gettext as _
|
|
from marshmallow import ValidationError
|
|
from werkzeug.utils import secure_filename
|
|
|
|
from superset import is_feature_enabled, security_manager
|
|
from superset.async_events.async_query_manager import AsyncQueryTokenException
|
|
from superset.charts.api import ChartRestApi
|
|
from superset.charts.client_processing import apply_client_processing
|
|
from superset.charts.data.query_context_cache_loader import QueryContextCacheLoader
|
|
from superset.charts.schemas import ChartDataQueryContextSchema
|
|
from superset.commands.chart.data.create_async_job_command import (
|
|
CreateAsyncChartDataJobCommand,
|
|
)
|
|
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
|
from superset.commands.chart.data.streaming_export_command import (
|
|
StreamingCSVExportCommand,
|
|
)
|
|
from superset.commands.chart.exceptions import (
|
|
ChartDataCacheLoadError,
|
|
ChartDataQueryFailedError,
|
|
)
|
|
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
|
from superset.connectors.sqla.models import BaseDatasource
|
|
from superset.constants import CACHE_DISABLED_TIMEOUT
|
|
from superset.daos.exceptions import DatasourceNotFound
|
|
from superset.exceptions import QueryObjectValidationError
|
|
from superset.extensions import event_logger
|
|
from superset.models.sql_lab import Query
|
|
from superset.utils import json
|
|
from superset.utils.core import (
|
|
create_zip,
|
|
DatasourceType,
|
|
get_user_id,
|
|
)
|
|
from superset.utils.decorators import logs_context
|
|
from superset.views.base import CsvResponse, generate_download_headers, XlsxResponse
|
|
from superset.views.base_api import statsd_metrics
|
|
|
|
if TYPE_CHECKING:
|
|
from superset.common.query_context import QueryContext
|
|
|
|
# Module-level logger; the chart data endpoints below report their failure
# paths (rejected requests, validation errors, failed queries) through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class ChartDataRestApi(ChartRestApi):
    """
    REST endpoints that execute chart queries and return their results.

    Routes:

    * ``GET /<pk>/data/`` runs the query context saved on a chart.
    * ``POST /data`` runs a client-built query context. It is also used by
      CSV/XLSX exports, which submit a regular form post.
    * ``GET /data/<cache_key>`` runs a query context previously stored in the
      cache (used by async queries).

    Validation and query failures are logged at error level with request
    context, to help diagnose intermittent failures such as 400s during exports.
    """

    include_route_methods = {"get_data", "data", "data_from_cache"}

    @expose("/<int:pk>/data/", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
        allow_extra_payload=True,
    )
    def get_data(  # noqa: C901
        self,
        pk: int,
        add_extra_log_payload: Callable[..., None] = lambda **kwargs: None,
    ) -> Response:
        """
        Take a chart ID and uses the query context stored when the chart was saved
        to return payload data response.
        ---
        get:
          summary: Return payload data response for a chart
          description: >-
            Takes a chart ID and uses the query context stored when the chart was saved
            to return payload data response.
          parameters:
          - in: path
            schema:
              type: integer
            name: pk
            description: The chart ID
          - in: query
            name: format
            description: The format in which the data should be returned
            schema:
              type: string
          - in: query
            name: type
            description: The type in which the data should be returned
            schema:
              type: string
          - in: query
            name: force
            description: Should the queries be forced to load from the source
            schema:
              type: boolean
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        chart = self.datamodel.get(pk, self._base_filters)
        if not chart:
            return self.response_404()

        try:
            json_body = json.loads(chart.query_context)
        except (TypeError, json.JSONDecodeError):
            json_body = None

        # The saved query context is mutated key-by-key below. Anything that is
        # not a JSON object (missing, unparsable, a list, a scalar) cannot be
        # used, so reject it here instead of failing with a TypeError later.
        if not isinstance(json_body, dict):
            logger.error(
                "Chart data request rejected: chart has no usable query context. "
                "chart_id=%s, referrer=%s",
                pk,
                request.referrer,
            )
            return self.response_400(
                message=_(
                    "Chart has no query context saved. Please save the chart again."
                )
            )

        # override saved query context
        json_body["result_format"] = request.args.get(
            "format", ChartDataResultFormat.JSON
        )
        json_body["result_type"] = request.args.get("type", ChartDataResultType.FULL)
        json_body["force"] = request.args.get("force")

        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            logger.error(
                "Chart data request: DatasourceNotFound. "
                "chart_id=%s, datasource=%s, referrer=%s",
                pk,
                json_body.get("datasource"),
                request.referrer,
            )
            return self.response_404()
        except QueryObjectValidationError as error:
            logger.error(
                "Chart data request: QueryObjectValidationError: %s. "
                "chart_id=%s, result_format=%s, referrer=%s",
                error.message,
                pk,
                json_body.get("result_format"),
                request.referrer,
            )
            return self.response_400(message=error.message)
        except ValidationError as error:
            logger.error(
                "Chart data request: ValidationError: %s. "
                "chart_id=%s, result_format=%s, referrer=%s",
                error.normalized_messages(),
                pk,
                json_body.get("result_format"),
                request.referrer,
            )
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        # TODO: support CSV, SQL query and other non-JSON types
        if self._should_run_async(query_context):
            return self._run_async(json_body, command, add_extra_log_payload)

        try:
            form_data = json.loads(chart.params)
        except (TypeError, json.JSONDecodeError):
            form_data = {}
        # Downstream helpers call ``form_data.get``, so any non-object params
        # payload is treated as "no form data".
        if not isinstance(form_data, dict):
            form_data = {}

        return self._get_data_response(
            command=command,
            form_data=form_data,
            datasource=query_context.datasource,
            add_extra_log_payload=add_extra_log_payload,
        )

    @expose("/data", methods=("POST",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.data",
        log_to_statsd=False,
        allow_extra_payload=True,
    )
    def data(  # noqa: C901
        self, add_extra_log_payload: Callable[..., None] = lambda **kwargs: None
    ) -> Response:
        """
        Take a query context constructed in the client and return payload
        data response for the given query
        ---
        post:
          summary: Return payload data response for the given query
          description: >-
            Takes a query context constructed in the client and returns payload data
            response for the given query.
          requestBody:
            description: >-
              A query context consists of a datasource from which to fetch data
              and one or many query objects.
            required: true
            content:
              application/json:
                schema:
                  $ref: "#/components/schemas/ChartDataQueryContextSchema"
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            202:
              description: Async job details
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataAsyncResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            500:
              $ref: '#/components/responses/500'
        """
        json_body = None
        if request.is_json:
            json_body = request.json
        elif request.form.get("form_data"):
            # CSV export submits regular form data
            try:
                json_body = json.loads(request.form["form_data"])
            except (TypeError, json.JSONDecodeError):
                logger.error(
                    "Failed to parse form_data JSON: "
                    "content_type=%s, content_length=%s, form_data_length=%s, "
                    "referrer=%s",
                    request.content_type,
                    request.content_length,
                    len(request.form.get("form_data", "")),
                    request.referrer,
                )
        if json_body is None:
            logger.error(
                "Chart data request rejected: json_body is None. "
                "is_json=%s, content_type=%s, content_length=%s, "
                "has_form_data=%s, referrer=%s",
                request.is_json,
                request.content_type,
                request.content_length,
                bool(request.form.get("form_data")),
                request.referrer,
            )
            return self.response_400(message=_("Request is not JSON"))

        # NOTE: json_body is client-controlled and may not be a JSON object (or
        # may carry a null/non-object "form_data"). The error handlers below
        # therefore read it through the shape-tolerant helpers, so a logging
        # call can never turn a 400/404 into a 500.
        try:
            query_context = self._create_query_context_from_form(json_body)
            command = ChartDataCommand(query_context)
            command.validate()
        except DatasourceNotFound:
            logger.error(
                "Chart data request: DatasourceNotFound. "
                "datasource=%s, result_format=%s, "
                "slice_id=%s, referrer=%s",
                self._get_body_field(json_body, "datasource"),
                self._get_body_field(json_body, "result_format"),
                self._get_slice_id(json_body),
                request.referrer,
            )
            return self.response_404()
        except QueryObjectValidationError as error:
            logger.error(
                "Chart data request: QueryObjectValidationError: %s. "
                "result_format=%s, slice_id=%s, referrer=%s",
                error.message,
                self._get_body_field(json_body, "result_format"),
                self._get_slice_id(json_body),
                request.referrer,
            )
            return self.response_400(message=error.message)
        except ValidationError as error:
            logger.error(
                "Chart data request: ValidationError: %s. "
                "result_format=%s, datasource=%s, "
                "slice_id=%s, referrer=%s",
                error.normalized_messages(),
                self._get_body_field(json_body, "result_format"),
                self._get_body_field(json_body, "datasource"),
                self._get_slice_id(json_body),
                request.referrer,
            )
            return self.response_400(
                message=_(
                    "Request is incorrect: %(error)s", error=error.normalized_messages()
                )
            )

        # TODO: support CSV, SQL query and other non-JSON types
        if self._should_run_async(query_context):
            return self._run_async(json_body, command, add_extra_log_payload)

        form_data = json_body.get("form_data")
        filename, expected_rows = self._extract_export_params_from_request()

        return self._get_data_response(
            command,
            form_data=form_data,
            datasource=query_context.datasource,
            add_extra_log_payload=add_extra_log_payload,
            filename=filename,
            expected_rows=expected_rows,
        )

    @expose("/data/<cache_key>", methods=("GET",))
    @protect()
    @statsd_metrics
    @event_logger.log_this_with_context(
        action=lambda self, *args, **kwargs: (
            f"{self.__class__.__name__}.data_from_cache"
        ),
        log_to_statsd=False,
    )
    def data_from_cache(self, cache_key: str) -> Response:
        """
        Take a query context cache key and return payload
        data response for the given query.
        ---
        get:
          summary: Return payload data response for the given query
          description: >-
            Takes a query context cache key and returns payload data
            response for the given query.
          parameters:
          - in: path
            schema:
              type: string
            name: cache_key
          responses:
            200:
              description: Query result
              content:
                application/json:
                  schema:
                    $ref: "#/components/schemas/ChartDataResponseSchema"
            400:
              $ref: '#/components/responses/400'
            401:
              $ref: '#/components/responses/401'
            404:
              $ref: '#/components/responses/404'
            422:
              $ref: '#/components/responses/422'
            500:
              $ref: '#/components/responses/500'
        """
        try:
            cached_data = self._load_query_context_form_from_cache(cache_key)
            # Set form_data in Flask Global as it is used as a fallback
            # for async queries with jinja context
            g.form_data = cached_data
            query_context = self._create_query_context_from_form(cached_data)
            command = ChartDataCommand(query_context)
            command.validate()
        except ChartDataCacheLoadError as exc:
            logger.error(
                "Chart data from cache: ChartDataCacheLoadError: %s. "
                "cache_key=%s, referrer=%s",
                exc.message,
                cache_key,
                request.referrer,
            )
            return self.response_404()
        except ValidationError as error:
            logger.error(
                "Chart data from cache: ValidationError: %s. "
                "cache_key=%s, referrer=%s",
                error.messages,
                cache_key,
                request.referrer,
            )
            return self.response_400(
                message=_("Request is incorrect: %(error)s", error=error.messages)
            )

        return self._get_data_response(command, True)

    @staticmethod
    def _should_run_async(query_context: QueryContext) -> bool:
        """
        Decide whether a validated query context should run as an async job.

        Async execution is used only when ``GLOBAL_ASYNC_QUERIES`` is enabled and
        the request wants the full JSON result. It is skipped when the cache is
        disabled (``cache_timeout == CACHE_DISABLED_TIMEOUT``), because async
        queries rely on the cache to hand results back to the client.
        """
        # Evaluated up front (not inside the short-circuit chain) so the cache
        # timeout is always resolved, as in the original inline logic.
        cache_timeout = query_context.get_cache_timeout()
        return (
            is_feature_enabled("GLOBAL_ASYNC_QUERIES")
            and query_context.result_format == ChartDataResultFormat.JSON
            and query_context.result_type == ChartDataResultType.FULL
            and cache_timeout != CACHE_DISABLED_TIMEOUT
        )

    @staticmethod
    def _get_body_field(json_body: Any, key: str) -> Any:
        """Return ``json_body[key]`` for diagnostics, or None if the body is not a dict."""
        return json_body.get(key) if isinstance(json_body, dict) else None

    @staticmethod
    def _get_slice_id(payload: Any) -> Any:
        """
        Return ``payload["form_data"]["slice_id"]`` for diagnostics.

        Tolerates a non-dict payload and a missing, null or non-dict
        ``form_data``, returning None in those cases instead of raising.
        """
        form_data = payload.get("form_data") if isinstance(payload, dict) else None
        return form_data.get("slice_id") if isinstance(form_data, dict) else None

    def _run_async(
        self,
        form_data: dict[str, Any],
        command: ChartDataCommand,
        add_extra_log_payload: Callable[..., None] | None = None,
    ) -> Response:
        """
        Execute command as an async query.

        If the request is not forced, the cache is checked first. A hit is
        returned synchronously. Otherwise a background job is created and a
        202 with the job details is returned.

        :param form_data: The query context payload submitted with the request
        :param command: The already-validated chart data command
        :param add_extra_log_payload: Optional event-logger callback used to
            record ``is_cached`` when results are served from the cache
        :returns: The chart response on a cache hit, a 401 if the async token
            is invalid, or a 202 with the async job details
        """
        # First, look for the chart query results in the cache,
        # but only if we're not forcing a refresh.
        if not form_data.get("force"):
            with contextlib.suppress(ChartDataCacheLoadError):
                result = command.run(force_cached=True)
                if result is not None:
                    # Log is_cached if extra payload callback is provided.
                    # This indicates no async job was triggered - data was already
                    # cached and a synchronous response is being returned immediately.
                    self._log_is_cached(result, add_extra_log_payload)
                    return self._send_chart_response(result)
        # Otherwise, kick off a background job to run the chart query.
        # Clients will either poll or be notified of query completion,
        # at which point they will call the /data/<cache_key> endpoint
        # to retrieve the results.
        async_command = CreateAsyncChartDataJobCommand()
        try:
            async_command.validate(request)
        except AsyncQueryTokenException:
            return self.response_401()

        result = async_command.run(form_data, get_user_id())
        return self.response(202, **result)

    def _send_chart_response(  # noqa: C901
        self,
        result: dict[Any, Any],
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
    ) -> Response:
        """
        Serialize a command result into the HTTP response for its result format.

        Table-like formats (CSV/XLSX) require export permission. They are
        streamed when the row count reaches the streaming threshold. Otherwise
        they are returned as a single file, or as a zip when there are several
        queries. JSON is returned as ``{"result": [...]}``, with the SQL text
        removed for guest users. Any other format is rejected with a 400.
        """
        result_type = result["query_context"].result_type
        result_format = result["query_context"].result_format

        # Post-process the data so it matches the data presented in the chart.
        # This is needed for sending reports based on text charts that do the
        # post-processing of data, eg, the pivot table.
        if result_type == ChartDataResultType.POST_PROCESSED:
            result = apply_client_processing(result, form_data, datasource)

        if result_format in ChartDataResultFormat.table_like():
            # Verify user has permission to export file
            if is_feature_enabled("GRANULAR_EXPORT_CONTROLS"):
                has_export_perm = security_manager.can_access(
                    "can_export_data", "Superset"
                )
            else:
                has_export_perm = security_manager.can_access("can_csv", "Superset")
            if not has_export_perm:
                logger.error(
                    "Chart data request: export permission denied. "
                    "result_format=%s, referrer=%s",
                    result_format,
                    request.referrer,
                )
                return self.response_403()

            if not result["queries"]:
                logger.error(
                    "Chart data request: empty query result. "
                    "result_format=%s, referrer=%s",
                    result_format,
                    request.referrer,
                )
                return self.response_400(_("Empty query result"))

            is_csv_format = result_format == ChartDataResultFormat.CSV

            # Check if we should use streaming for large datasets
            if is_csv_format and self._should_use_streaming(result, form_data):
                return self._create_streaming_csv_response(
                    result, form_data, filename=filename, expected_rows=expected_rows
                )

            if len(result["queries"]) == 1:
                # return single query results
                data = result["queries"][0]["data"]
                if is_csv_format:
                    return CsvResponse(data, headers=generate_download_headers("csv"))

                return XlsxResponse(data, headers=generate_download_headers("xlsx"))

            # return multi-query results bundled as a zip file
            def _process_data(query_data: Any) -> Any:
                # CSV payloads are text and must be encoded before zipping.
                if result_format == ChartDataResultFormat.CSV:
                    encoding = app.config["CSV_EXPORT"].get("encoding", "utf-8")
                    return query_data.encode(encoding)
                return query_data

            files = {
                f"query_{idx + 1}.{result_format}": _process_data(query["data"])
                for idx, query in enumerate(result["queries"])
            }
            return Response(
                create_zip(files),
                headers=generate_download_headers("zip"),
                mimetype="application/zip",
            )

        if result_format == ChartDataResultFormat.JSON:
            queries = result["queries"]
            if security_manager.is_guest_user():
                # Guest (embedded) users must not see the generated SQL.
                for query in queries:
                    query.pop("query", None)
            with event_logger.log_context(f"{self.__class__.__name__}.json_dumps"):
                response_data = json.dumps(
                    {"result": queries},
                    default=json.json_int_dttm_ser,
                    ignore_nan=True,
                )
            resp = make_response(response_data, 200)
            resp.headers["Content-Type"] = "application/json; charset=utf-8"
            return resp

        logger.error(
            "Chart data request: unsupported result_format. "
            "result_format=%s, referrer=%s",
            result_format,
            request.referrer,
        )
        return self.response_400(message=f"Unsupported result_format: {result_format}")

    def _log_is_cached(
        self,
        result: dict[str, Any],
        add_extra_log_payload: Callable[..., None] | None,
    ) -> None:
        """
        Log is_cached values from query results to event logger.

        Extracts is_cached from each query in the result and logs it.
        If there's a single query, logs the boolean value directly.
        If multiple queries, logs as a list.
        """
        if add_extra_log_payload and result and "queries" in result:
            is_cached_values = [query.get("is_cached") for query in result["queries"]]
            if len(is_cached_values) == 1:
                add_extra_log_payload(is_cached=is_cached_values[0])
            elif is_cached_values:
                add_extra_log_payload(is_cached=is_cached_values)

    @event_logger.log_this
    def _get_data_response(
        self,
        command: ChartDataCommand,
        force_cached: bool = False,
        form_data: dict[str, Any] | None = None,
        datasource: BaseDatasource | Query | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
        add_extra_log_payload: Callable[..., None] | None = None,
    ) -> Response:
        """
        Get data response and optionally log is_cached information.

        :param command: The validated chart data command to run
        :param force_cached: Passed to ``command.run``; when True, results are
            loaded from the cache
        :param form_data: Chart form data, used for post-processing and for
            choosing whether to stream CSV
        :returns: The serialized chart response; a 422 if loading cached results
            fails, or a 400 if the query itself fails
        """
        # form_data may come straight from the client, so read it defensively
        # for the diagnostic log lines below.
        safe_form_data = form_data if isinstance(form_data, dict) else {}
        try:
            result = command.run(force_cached=force_cached)
        except ChartDataCacheLoadError as exc:
            logger.error(
                "Chart data cache load failed: %s. "
                "force_cached=%s, slice_id=%s, referrer=%s",
                exc.message,
                force_cached,
                safe_form_data.get("slice_id"),
                request.referrer,
            )
            return self.response_422(message=exc.message)
        except ChartDataQueryFailedError as exc:
            logger.error(
                "Chart data query failed: %s. "
                "result_format=%s, force_cached=%s, slice_id=%s, referrer=%s",
                exc.message,
                safe_form_data.get("result_format"),
                force_cached,
                safe_form_data.get("slice_id"),
                request.referrer,
            )
            return self.response_400(message=exc.message)

        # Log is_cached if extra payload callback is provided
        if add_extra_log_payload and result and "queries" in result:
            is_cached_values = [query.get("is_cached") for query in result["queries"]]
            add_extra_log_payload(is_cached=is_cached_values)

        return self._send_chart_response(
            result, form_data, datasource, filename, expected_rows
        )

    def _extract_export_params_from_request(self) -> tuple[str | None, int | None]:
        """Extract filename and expected_rows from request for streaming exports."""
        filename = request.form.get("filename")
        if filename:
            logger.info("FRONTEND PROVIDED FILENAME: %s", filename)

        expected_rows = None
        if expected_rows_str := request.form.get("expected_rows"):
            try:
                expected_rows = int(expected_rows_str)
                logger.info("FRONTEND PROVIDED EXPECTED ROWS: %d", expected_rows)
            except (ValueError, TypeError):
                logger.warning("Invalid expected_rows value: %s", expected_rows_str)

        return filename, expected_rows

    # pylint: disable=invalid-name
    def _load_query_context_form_from_cache(self, cache_key: str) -> dict[str, Any]:
        """Load a cached query context payload; raises ChartDataCacheLoadError on a miss."""
        return QueryContextCacheLoader.load(cache_key)

    def _map_form_data_datasource_to_dataset_id(
        self, form_data: dict[str, Any]
    ) -> dict[str, Any]:
        """
        Build the logging context (dashboard, dataset and slice IDs) for a payload.

        This runs before schema validation, so the payload may be malformed (not
        a dict, or with a null or non-dict ``form_data``). Missing values resolve
        to None instead of raising, so validation can reject the request with a
        proper 400.
        """
        payload = form_data if isinstance(form_data, dict) else {}
        inner_form_data = payload.get("form_data")
        if not isinstance(inner_form_data, dict):
            inner_form_data = {}
        datasource = payload.get("datasource")
        return {
            "dashboard_id": inner_form_data.get("dashboardId"),
            "dataset_id": (
                datasource.get("id")
                if isinstance(datasource, dict)
                and datasource.get("type") == DatasourceType.TABLE.value
                else None
            ),
            "slice_id": inner_form_data.get("slice_id"),
        }

    @logs_context(context_func=_map_form_data_datasource_to_dataset_id)
    def _create_query_context_from_form(
        self, form_data: dict[str, Any]
    ) -> QueryContext:
        """
        Create the query context from the form data.

        :param form_data: The chart form data
        :returns: The query context
        :raises ValidationError: If the request is incorrect
        """

        try:
            return ChartDataQueryContextSchema().load(form_data)
        except KeyError as ex:
            raise ValidationError("Request is incorrect") from ex

    @staticmethod
    def _to_int(value: Any) -> int | None:
        """Convert ``value`` to int, returning None when it is not numeric."""
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError):
            return None

    def _should_use_streaming(
        self, result: dict[Any, Any], form_data: dict[str, Any] | None = None
    ) -> bool:
        """
        Determine if streaming should be used based on actual row count threshold.

        Only CSV is streamed. For table charts, the row count is taken from the
        second query's ``rowcount`` when present. Otherwise it falls back to
        ``row_limit`` from the request form data, then from the query context's
        form data. Non-numeric client values are treated as "unknown" (or 0 for
        ``row_limit``) instead of raising.
        """
        query_context = result["query_context"]
        result_format = query_context.result_format

        # Only support CSV streaming currently
        if result_format.lower() != "csv":
            return False

        # Get streaming threshold from config
        threshold = app.config.get("CSV_STREAMING_ROW_THRESHOLD", 100000)

        # Extract actual row count (same logic as frontend)
        actual_row_count: int | None = None
        viz_type = form_data.get("viz_type") if form_data else None

        # For table viz, try to get actual row count from query results
        if viz_type == "table" and result.get("queries"):
            # Check if we have rowcount in the second query result (like frontend does)
            queries = result.get("queries", [])
            if len(queries) > 1 and queries[1].get("data"):
                data = queries[1]["data"]
                if isinstance(data, list) and data and isinstance(data[0], dict):
                    rowcount = data[0].get("rowcount")
                    actual_row_count = self._to_int(rowcount) if rowcount else None

        # Fallback to row_limit if actual count not available
        if actual_row_count is None:
            if form_data and "row_limit" in form_data:
                actual_row_count = self._to_int(form_data.get("row_limit", 0)) or 0
            elif query_context.form_data and "row_limit" in query_context.form_data:
                actual_row_count = (
                    self._to_int(query_context.form_data.get("row_limit", 0)) or 0
                )

        # Use streaming if row count meets or exceeds threshold
        return actual_row_count is not None and actual_row_count >= threshold

    @staticmethod
    def _sanitize_header_filename(filename: str) -> str:
        """
        Strip characters that could break a quoted ``Content-Disposition`` value.

        Removes double quotes, backslashes and all non-printable characters
        (including CR/LF), then trims surrounding whitespace.
        """
        return "".join(
            char for char in filename if char.isprintable() and char not in '"\\'
        ).strip()

    def _create_streaming_csv_response(
        self,
        result: dict[Any, Any],
        form_data: dict[str, Any] | None = None,
        filename: str | None = None,
        expected_rows: int | None = None,
    ) -> Response:
        """
        Create a streaming CSV response for large datasets.

        :param result: Command result; only its ``query_context`` is used
        :param form_data: Chart form data, used to derive a default filename
        :param filename: Optional filename from the frontend (untrusted input)
        :param expected_rows: Optional row count hint from the frontend; only logged here
        :returns: A chunked ``text/csv`` response backed by a generator
        """
        query_context = result["query_context"]

        # A frontend-provided filename is request input that ends up inside a
        # quoted header value. Strip anything that could end the quoted string
        # or inject header line breaks; fall back to a generated name if
        # nothing usable remains.
        if filename:
            filename = self._sanitize_header_filename(filename)

        # Use filename from frontend if provided, otherwise generate one
        if not filename:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            chart_name = "export"

            if form_data and form_data.get("slice_name"):
                chart_name = form_data["slice_name"]
            elif form_data and form_data.get("viz_type"):
                chart_name = form_data["viz_type"]

            # Sanitize chart name for filename
            filename = secure_filename(f"superset_{chart_name}_{timestamp}.csv")

        logger.info("Creating streaming CSV response: %s", filename)
        if expected_rows:
            logger.info("Using expected_rows from frontend: %d", expected_rows)

        # Execute streaming command
        # TODO: Make chunk size configurable via SUPERSET_CONFIG
        chunk_size = 1024
        command = StreamingCSVExportCommand(query_context, chunk_size)
        command.validate()

        # Get the callable that returns the generator
        csv_generator_callable = command.run()

        # Get encoding from config
        encoding = app.config.get("CSV_EXPORT", {}).get("encoding", "utf-8")

        # Create response with streaming headers
        response = Response(
            csv_generator_callable(),  # Call the callable to get generator
            mimetype=f"text/csv; charset={encoding}",
            headers={
                "Content-Disposition": f'attachment; filename="{filename}"',
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            },
            direct_passthrough=False,  # Flask must iterate generator
        )

        # Force chunked transfer encoding
        response.implicit_sequence_conversion = False

        return response
|