# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import copy
from typing import Any, Callable, TYPE_CHECKING

from flask_babel import _

from superset.common.chart_data import ChartDataResultType
from superset.common.db_query_status import QueryStatus
from superset.exceptions import QueryObjectValidationError, SupersetParseError
from superset.explorables.base import Explorable
from superset.utils.core import (
    extract_column_dtype,
    extract_dataframe_dtypes,
    ExtraFiltersReasonType,
    get_column_name,
    get_time_filter_status,
)

if TYPE_CHECKING:
    from superset.common.query_context import QueryContext
    from superset.common.query_object import QueryObject


def _get_datasource(query_context: QueryContext, query_obj: QueryObject) -> Explorable:
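    """Return the query object's datasource, falling back to the query context's."""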
    return query_obj.datasource or query_context.datasource


def _get_columns(
    query_context: QueryContext, query_obj: QueryObject, _: bool
) -> dict[str, Any]:
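    """Return metadata (name, verbose name, dtype) for the datasource's columns."""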
    datasource = _get_datasource(query_context, query_obj)
    return {
        "data": [
            {
                "column_name": col.column_name,
                "verbose_name": col.verbose_name,
                "dtype": extract_column_dtype(col),
            }
            for col in datasource.columns
        ]
    }


def _get_timegrains(
    query_context: QueryContext, query_obj: QueryObject, _: bool
) -> dict[str, Any]:
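    """Return the time grains supported by the datasource."""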
    datasource = _get_datasource(query_context, query_obj)
    # Time grains are exposed via the Explorable protocol's get_time_grains()
    grains = datasource.get_time_grains()
    return {"data": grains}


def _get_query(
    query_context: QueryContext,
    query_obj: QueryObject,
    _: bool,
) -> dict[str, Any]:
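    """Return the query string for the query object without executing it."""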
    datasource = _get_datasource(query_context, query_obj)
    result = {"language": datasource.query_language}
    try:
        result["query"] = datasource.get_query_str(query_obj.to_dict())
    except QueryObjectValidationError as err:
        # Validation errors (missing required fields, invalid config):
        # no SQL was generated, so only the error is reported.
        result["error"] = err.message
    except SupersetParseError as err:
        # Parsing errors (SQL optimization/parsing failed):
        # SQL was generated but couldn't be optimized, so show both.
        if err.error.extra and (sql := err.error.extra.get("sql")) is not None:
            result["query"] = sql
        result["error"] = err.error.message
    return result


def _get_full(
    query_context: QueryContext,
    query_obj: QueryObject,
    force_cached: bool | None = False,
) -> dict[str, Any]:
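    """
    Execute the query and return the full result payload.

    Attaches column names, index names, dtypes and serialized data to the
    payload, and annotates which filters were applied or rejected.
    """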
    datasource = _get_datasource(query_context, query_obj)
    result_type = query_obj.result_type or query_context.result_type
    payload = query_context.get_df_payload(query_obj, force_cached=force_cached)
    df = payload["df"]
    status = payload["status"]
    if status != QueryStatus.FAILED:
        payload["colnames"] = list(df.columns)
        payload["indexnames"] = list(df.index)
        payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
        payload["data"] = query_context.get_data(df, payload["coltypes"])
        payload["result_format"] = query_context.result_format
    del payload["df"]

    applied_time_columns, rejected_time_columns = get_time_filter_status(
        datasource, query_obj.applied_time_extras
    )

    applied_filter_columns = payload.get("applied_filter_columns", [])
    rejected_filter_columns = payload.get("rejected_filter_columns", [])
    del payload["applied_filter_columns"]
    del payload["rejected_filter_columns"]
    payload["applied_filters"] = [
        {"column": get_column_name(col)} for col in applied_filter_columns
    ] + applied_time_columns
    payload["rejected_filters"] = [
        {
            "reason": ExtraFiltersReasonType.COL_NOT_IN_DATASOURCE,
            "column": get_column_name(col),
        }
        for col in rejected_filter_columns
    ] + rejected_time_columns

    if result_type == ChartDataResultType.RESULTS and status != QueryStatus.FAILED:
        return {
            "data": payload.get("data"),
            "colnames": payload.get("colnames"),
            "coltypes": payload.get("coltypes"),
            "rowcount": payload.get("rowcount"),
            "sql_rowcount": payload.get("sql_rowcount"),
        }
    return payload


def _get_samples(
    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> dict[str, Any]:
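    """
    Return raw sample rows from the datasource.

    Drops metrics, ordering, post-processing and the time range from the query
    object and selects all columns before delegating to _get_full().
    """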
    datasource = _get_datasource(query_context, query_obj)
    query_obj = copy.copy(query_obj)
    query_obj.is_timeseries = False
    query_obj.orderby = []
    query_obj.metrics = None
    query_obj.post_processing = []
    qry_obj_cols = []
    for o in datasource.columns:
        if isinstance(o, dict):
            if column_name := o.get("column_name"):
                qry_obj_cols.append(column_name)
        else:
            qry_obj_cols.append(o.column_name)
    query_obj.columns = qry_obj_cols
    query_obj.from_dttm = None
    query_obj.to_dttm = None
    return _get_full(query_context, query_obj, force_cached)


def _get_drill_detail(
    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> dict[str, Any]:
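    """
    Return row-level results for drill to detail requests.

    Like _get_samples(), but keeps the time range on the query object and
    orders by the first column.
    """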
    # todo(yongjie): Remove this function once it is decided
    #  whether samples should be applied to the time filter.
    datasource = _get_datasource(query_context, query_obj)
    query_obj = copy.copy(query_obj)
    query_obj.is_timeseries = False
    query_obj.metrics = None
    query_obj.post_processing = []
    qry_obj_cols = []
    for o in datasource.columns:
        if isinstance(o, dict):
            if column_name := o.get("column_name"):
                qry_obj_cols.append(column_name)
        else:
            qry_obj_cols.append(o.column_name)
    query_obj.columns = qry_obj_cols
    query_obj.orderby = [(query_obj.columns[0], True)]
    return _get_full(query_context, query_obj, force_cached)


def _get_results(
    query_context: QueryContext, query_obj: QueryObject, force_cached: bool = False
) -> dict[str, Any]:
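    """Return the results payload by delegating to _get_full()."""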
    payload = _get_full(query_context, query_obj, force_cached)
    return payload


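# Maps each ChartDataResultType to the handler that builds its payload; every
# handler shares the (QueryContext, QueryObject, bool) signature.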
_result_type_functions: dict[
    ChartDataResultType, Callable[[QueryContext, QueryObject, bool], dict[str, Any]]
] = {
    ChartDataResultType.COLUMNS: _get_columns,
    ChartDataResultType.TIMEGRAINS: _get_timegrains,
    ChartDataResultType.QUERY: _get_query,
    ChartDataResultType.SAMPLES: _get_samples,
    ChartDataResultType.FULL: _get_full,
    ChartDataResultType.RESULTS: _get_results,
    # For post-processed data requests we return the full results and
    # post-process them later, where the chart context is available, since
    # post-processing is unique to each visualization type.
    ChartDataResultType.POST_PROCESSED: _get_full,
    ChartDataResultType.DRILL_DETAIL: _get_drill_detail,
}


def get_query_results(
    result_type: ChartDataResultType,
    query_context: QueryContext,
    query_obj: QueryObject,
    force_cached: bool,
) -> dict[str, Any]:
"""
|
|
Return result payload for a chart data request.
|
|
|
|
:param result_type: the type of result to return
|
|
:param query_context: query context to which the query object belongs
|
|
:param query_obj: query object for which to retrieve the results
|
|
:param force_cached: should results be forcefully retrieved from cache
|
|
:raises QueryObjectValidationError: if an unsupported result type is requested
|
|
:return: JSON serializable result payload
|
|
"""
|
|
if result_func := _result_type_functions.get(result_type):
|
|
return result_func(query_context, query_obj, force_cached)
|
|
raise QueryObjectValidationError(
|
|
_("Invalid result type: %(result_type)s", result_type=result_type)
|
|
)
|