fix: Time Column on Generic X-axis (#23021)

This commit is contained in:
Michael S. Molina
2023-02-10 13:33:07 -05:00
committed by GitHub
parent 85f07798bf
commit 464ddee4b4
4 changed files with 97 additions and 56 deletions

View File

@@ -22,6 +22,7 @@ from superset import app, db
from superset.charts.dao import ChartDAO
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_object import QueryObject
from superset.common.query_object_factory import QueryObjectFactory
from superset.datasource.dao import DatasourceDAO
from superset.models.slice import Slice
@@ -65,8 +66,12 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
result_type = result_type or ChartDataResultType.FULL
result_format = result_format or ChartDataResultFormat.JSON
queries_ = [
self._query_object_factory.create(
result_type, datasource=datasource, **query_obj
self._process_query_object(
datasource_model_instance,
form_data,
self._query_object_factory.create(
result_type, datasource=datasource, **query_obj
),
)
for query_obj in queries
]
@@ -90,7 +95,6 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
# pylint: disable=no-self-use
def _convert_to_model(self, datasource: DatasourceDict) -> BaseDatasource:
return DatasourceDAO.get_datasource(
session=db.session,
datasource_type=DatasourceType(datasource["type"]),
@@ -99,3 +103,89 @@ class QueryContextFactory: # pylint: disable=too-few-public-methods
def _get_slice(self, slice_id: Any) -> Optional[Slice]:
return ChartDAO.find_by_id(slice_id)
def _process_query_object(
self,
datasource: BaseDatasource,
form_data: Optional[Dict[str, Any]],
query_object: QueryObject,
) -> QueryObject:
self._apply_granularity(query_object, form_data, datasource)
self._apply_filters(query_object)
return query_object
def _apply_granularity(
self,
query_object: QueryObject,
form_data: Optional[Dict[str, Any]],
datasource: BaseDatasource,
) -> None:
temporal_columns = {
column.column_name
for column in datasource.columns
if (column["is_dttm"] if isinstance(column, dict) else column.is_dttm)
}
granularity = query_object.granularity
x_axis = form_data and form_data.get("x_axis")
if granularity:
filter_to_remove = None
if x_axis and x_axis in temporal_columns:
filter_to_remove = x_axis
x_axis_column = next(
(
column
for column in query_object.columns
if column == x_axis
or (
isinstance(column, dict)
and column["sqlExpression"] == x_axis
)
),
None,
)
# Replaces x-axis column values with granularity
if x_axis_column:
if isinstance(x_axis_column, dict):
x_axis_column["sqlExpression"] = granularity
x_axis_column["label"] = granularity
else:
query_object.columns = [
granularity if column == x_axis_column else column
for column in query_object.columns
]
for post_processing in query_object.post_processing:
if post_processing.get("operation") == "pivot":
post_processing["options"]["index"] = [granularity]
# If no temporal x-axis, then get the default temporal filter
if not filter_to_remove:
temporal_filters = [
filter["col"]
for filter in query_object.filter
if filter["op"] == "TEMPORAL_RANGE"
]
if len(temporal_filters) > 0:
# Use granularity if it's already in the filters
if granularity in temporal_filters:
filter_to_remove = granularity
else:
# Use the first temporal filter
filter_to_remove = temporal_filters[0]
# Removes the temporal filter which may be an x-axis or
# another temporal filter. A new filter based on the value of
# the granularity will be added later in the code.
# In practice, this is replacing the previous default temporal filter.
if filter_to_remove:
query_object.filter = [
filter
for filter in query_object.filter
if filter["col"] != filter_to_remove
]
def _apply_filters(self, query_object: QueryObject) -> None:
if query_object.time_range:
for filter_object in query_object.filter:
if filter_object["op"] == "TEMPORAL_RANGE":
filter_object["val"] = query_object.time_range