mirror of
https://github.com/apache/superset.git
synced 2026-04-09 19:35:21 +00:00
279 lines
11 KiB
Python
279 lines
11 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
from __future__ import annotations
|
|
|
|
from typing import Any
|
|
|
|
from flask import current_app
|
|
|
|
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
|
from superset.common.query_context import QueryContext
|
|
from superset.common.query_object import QueryObject
|
|
from superset.common.query_object_factory import QueryObjectFactory
|
|
from superset.daos.chart import ChartDAO
|
|
from superset.daos.datasource import DatasourceDAO
|
|
from superset.explorables.base import Explorable
|
|
from superset.models.slice import Slice
|
|
from superset.superset_typing import Column
|
|
from superset.utils.core import DatasourceDict, DatasourceType, is_adhoc_column
|
|
|
|
|
|
def create_query_object_factory() -> QueryObjectFactory:
    """Return a new QueryObjectFactory bound to the app config and a DatasourceDAO."""
    app_config = current_app.config
    return QueryObjectFactory(app_config, DatasourceDAO())
|
|
|
|
|
|
class QueryContextFactory:  # pylint: disable=too-few-public-methods
    """Build ``QueryContext`` objects from chart-data request payloads."""

    _query_object_factory: QueryObjectFactory

    def __init__(self) -> None:
        self._query_object_factory = create_query_object_factory()

    def create(  # pylint: disable=too-many-arguments
        self,
        *,
        datasource: DatasourceDict,
        queries: list[dict[str, Any]],
        form_data: dict[str, Any] | None = None,
        result_type: ChartDataResultType | None = None,
        result_format: ChartDataResultFormat | None = None,
        force: bool = False,
        custom_cache_timeout: int | None = None,
    ) -> QueryContext:
        """Create a ``QueryContext`` for the given datasource and queries.

        :param datasource: Datasource reference; its ``type`` and ``id`` keys
            are used to load the datasource model when it is truthy.
        :param queries: Raw query dicts; each is turned into a ``QueryObject``
            by the query object factory and then post-processed.
        :param form_data: Optional chart form data. ``slice_id``,
            ``server_pagination``, ``x_axis`` and ``tooltip_contents`` are
            read from it.
        :param result_type: Defaults to ``ChartDataResultType.FULL``.
        :param result_format: Defaults to ``ChartDataResultFormat.JSON``.
        :param force: Passed through to the ``QueryContext``.
        :param custom_cache_timeout: Passed through to the ``QueryContext``.
        :returns: The assembled ``QueryContext``.
        """
        datasource_model_instance = None
        if datasource:
            datasource_model_instance = self._convert_to_model(datasource)

        slice_ = None
        if form_data and (slice_id := form_data.get("slice_id")) is not None:
            slice_ = self._get_slice(slice_id)

        result_type = result_type or ChartDataResultType.FULL
        result_format = result_format or ChartDataResultFormat.JSON

        # The server_pagination flag is read from the form data and forwarded
        # to every query object built below. This flag only exists for the
        # table viz type.
        server_pagination = (
            bool(form_data.get("server_pagination")) if form_data else False
        )

        queries_ = [
            self._process_query_object(
                datasource_model_instance,
                form_data,
                self._query_object_factory.create(
                    result_type,
                    datasource=datasource,
                    server_pagination=server_pagination,
                    **query_obj,
                ),
            )
            for query_obj in queries
        ]
        cache_values = {
            "datasource": datasource,
            "queries": queries,
            "result_type": result_type,
            "result_format": result_format,
        }
        return QueryContext(
            datasource=datasource_model_instance,
            queries=queries_,
            slice_=slice_,
            form_data=form_data,
            result_type=result_type,
            result_format=result_format,
            force=force,
            custom_cache_timeout=custom_cache_timeout,
            cache_values=cache_values,
        )

    def _convert_to_model(self, datasource: DatasourceDict) -> Explorable:
        """Load the datasource model referenced by ``datasource['type']``/``['id']``."""
        return DatasourceDAO.get_datasource(
            datasource_type=DatasourceType(datasource["type"]),
            database_id_or_uuid=datasource["id"],
        )

    def _get_slice(self, slice_id: Any) -> Slice | None:
        """Look up the chart (slice) with the given id via ``ChartDAO``."""
        return ChartDAO.find_by_id(slice_id)

    def _process_query_object(
        self,
        datasource: Explorable,
        form_data: dict[str, Any] | None,
        query_object: QueryObject,
    ) -> QueryObject:
        """Apply granularity, filter and tooltip adjustments to ``query_object``.

        The query object is mutated in place and also returned.
        """
        self._apply_granularity(query_object, form_data, datasource)
        self._apply_filters(query_object)
        self._add_tooltip_columns(query_object, form_data)
        return query_object

    def _add_tooltip_columns(
        self,
        query_object: QueryObject,
        form_data: dict[str, Any] | None,
    ) -> None:
        """Add columns referenced by ``form_data['tooltip_contents']`` to the query."""
        if not form_data:
            return

        tooltip_columns = self._extract_tooltip_columns(form_data)
        if not tooltip_columns:
            return

        existing_columns = self._get_existing_column_names(query_object.columns)
        self._append_missing_tooltip_columns(
            query_object, tooltip_columns, existing_columns
        )

    def _get_existing_column_names(self, columns: list[Column]) -> set[str]:
        """Collect names of plain string columns and dicts with a ``column_name``."""
        column_names: set[str] = set()
        for col in columns:
            if isinstance(col, dict):
                column_name = col.get("column_name")
                if column_name and isinstance(column_name, str):
                    column_names.add(column_name)
            elif isinstance(col, str):
                column_names.add(col)
        return column_names

    def _append_missing_tooltip_columns(
        self,
        query_object: QueryObject,
        tooltip_columns: list[str],
        existing_columns: set[str],
    ) -> None:
        """Append each tooltip column not already selected, at most once.

        The datasource's column definition is appended when one is found,
        otherwise the bare column name.
        """
        # Work on a copy so the caller's set is untouched, and record every
        # column appended here so duplicates in tooltip_columns are skipped.
        seen = set(existing_columns)
        for col in tooltip_columns:
            if col in seen:
                continue
            column_def = self._find_column_definition(query_object, col)
            query_object.columns.append(column_def or col)
            seen.add(col)

    def _find_column_definition(
        self, query_object: QueryObject, column_name: str
    ) -> Any | None:
        """Return the first datasource column named ``column_name``, else ``None``."""
        if not (
            query_object.datasource and hasattr(query_object.datasource, "columns")
        ):
            return None

        return next(
            (
                c
                for c in query_object.datasource.columns
                if c.column_name == column_name
            ),
            None,
        )

    def _extract_tooltip_columns(self, form_data: dict[str, Any]) -> list[str]:
        """Extract column names from the ``tooltip_contents`` configuration.

        String items are taken as-is; dict items contribute their
        ``column_name`` only when ``item_type == "column"``. Order and
        duplicates are preserved.
        """
        tooltip_columns: list[str] = []
        if tooltip_contents := form_data.get("tooltip_contents", []):
            for item in tooltip_contents:
                if isinstance(item, str):
                    tooltip_columns.append(item)
                elif isinstance(item, dict) and item.get("item_type") == "column":
                    column_name = item.get("column_name")
                    if column_name:
                        tooltip_columns.append(column_name)
        return tooltip_columns

    def _apply_granularity(  # noqa: C901
        self,
        query_object: QueryObject,
        form_data: dict[str, Any] | None,
        datasource: Explorable,
    ) -> None:
        """Rewrite the temporal x-axis and filters to use the query's granularity.

        Does nothing when ``query_object.granularity`` is falsy. Otherwise:

        * if the x-axis (plain name, or the ``sqlExpression`` of an adhoc
          x-axis) is a temporal column of ``datasource``, the matching entry in
          ``query_object.columns`` is replaced by the granularity and every
          ``pivot`` post-processing step is re-indexed on it;
        * one temporal filter is dropped from ``query_object.filter``: the
          x-axis one, else the granularity one, else the first
          ``TEMPORAL_RANGE`` filter.

        :param query_object: Query object, mutated in place.
        :param form_data: Optional chart form data; ``x_axis`` is read from it.
        :param datasource: Datasource whose ``columns`` (model objects or dicts)
            determine which names are temporal via ``is_dttm``.
        """
        granularity = query_object.granularity
        if not granularity:
            return

        temporal_columns = {
            column["column_name"] if isinstance(column, dict) else column.column_name
            for column in datasource.columns
            if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
        }

        x_axis = form_data and form_data.get("x_axis")
        # An adhoc x-axis is identified by its SQL expression.
        if isinstance(x_axis, dict) and "sqlExpression" in x_axis:
            x_axis = x_axis.get("sqlExpression")

        filter_to_remove = None
        # Only string names can be temporal column names; this also avoids a
        # TypeError from testing an unhashable value for set membership.
        if isinstance(x_axis, str) and x_axis in temporal_columns:
            filter_to_remove = x_axis
            x_axis_column = next(
                (
                    column
                    for column in query_object.columns
                    if column == x_axis
                    or (
                        isinstance(column, dict)
                        # .get: dict columns are not guaranteed to carry
                        # an sqlExpression key.
                        and column.get("sqlExpression") == x_axis
                    )
                ),
                None,
            )
            # Replaces x-axis column values with granularity
            if x_axis_column:
                if isinstance(x_axis_column, dict):
                    x_axis_column["sqlExpression"] = granularity
                    x_axis_column["label"] = granularity
                else:
                    query_object.columns = [
                        granularity if column == x_axis_column else column
                        for column in query_object.columns
                    ]
                for post_processing in query_object.post_processing:
                    if post_processing.get("operation") == "pivot":
                        post_processing["options"]["index"] = [granularity]

        # If no temporal x-axis, then get the default temporal filter
        if not filter_to_remove:
            temporal_filters = [
                flt["col"]
                for flt in query_object.filter
                if flt["op"] == "TEMPORAL_RANGE"
            ]
            if temporal_filters:
                # Use granularity if it's already in the filters, otherwise
                # fall back to the first temporal filter.
                if granularity in temporal_filters:
                    filter_to_remove = granularity
                else:
                    filter_to_remove = temporal_filters[0]

        # Removes the temporal filter which may be an x-axis or
        # another temporal filter. A new filter based on the value of
        # the granularity will be added later in the code.
        # In practice, this is replacing the previous default temporal filter.
        if isinstance(filter_to_remove, dict) and is_adhoc_column(
            filter_to_remove  # type: ignore
        ):
            filter_to_remove = filter_to_remove.get("sqlExpression")

        if filter_to_remove:
            query_object.filter = [
                flt
                for flt in query_object.filter
                if flt["col"] != filter_to_remove
            ]

    def _apply_filters(self, query_object: QueryObject) -> None:
        """Set every ``TEMPORAL_RANGE`` filter's value to the query's time range.

        Does nothing when ``query_object.time_range`` is falsy.
        """
        if query_object.time_range:
            for filter_object in query_object.filter:
                if filter_object["op"] == "TEMPORAL_RANGE":
                    filter_object["val"] = query_object.time_range