mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix: Always use temporal type for dttm columns [ID-2] (#17458)
* fix: Always use temporal type for dttm columns
* move inference and implement in chart postproc
* fix test
* fix test case

Co-authored-by: Ville Brofeldt <ville.v.brofeldt@gmail.com>
This commit is contained in:
committed by
GitHub
parent
66d756955b
commit
1f8eff72de
@@ -40,6 +40,7 @@ from superset.charts.data.query_context_cache_loader import QueryContextCacheLoa
|
||||
from superset.charts.post_processing import apply_post_process
|
||||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.extensions import event_logger
|
||||
from superset.utils.async_query_manager import AsyncQueryTokenException
|
||||
@@ -158,7 +159,9 @@ class ChartDataRestApi(ChartRestApi):
|
||||
except (TypeError, json.decoder.JSONDecodeError):
|
||||
form_data = {}
|
||||
|
||||
return self._get_data_response(command, form_data=form_data)
|
||||
return self._get_data_response(
|
||||
command=command, form_data=form_data, datasource=query_context.datasource
|
||||
)
|
||||
|
||||
@expose("/data", methods=["POST"])
|
||||
@protect()
|
||||
@@ -327,7 +330,10 @@ class ChartDataRestApi(ChartRestApi):
|
||||
return self.response(202, **result)
|
||||
|
||||
def _send_chart_response(
|
||||
self, result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
|
||||
self,
|
||||
result: Dict[Any, Any],
|
||||
form_data: Optional[Dict[str, Any]] = None,
|
||||
datasource: Optional[BaseDatasource] = None,
|
||||
) -> Response:
|
||||
result_type = result["query_context"].result_type
|
||||
result_format = result["query_context"].result_format
|
||||
@@ -336,7 +342,7 @@ class ChartDataRestApi(ChartRestApi):
|
||||
# This is needed for sending reports based on text charts that do the
|
||||
# post-processing of data, eg, the pivot table.
|
||||
if result_type == ChartDataResultType.POST_PROCESSED:
|
||||
result = apply_post_process(result, form_data)
|
||||
result = apply_post_process(result, form_data, datasource)
|
||||
|
||||
if result_format == ChartDataResultFormat.CSV:
|
||||
# Verify user has permission to export CSV file
|
||||
@@ -364,6 +370,7 @@ class ChartDataRestApi(ChartRestApi):
|
||||
command: ChartDataCommand,
|
||||
force_cached: bool = False,
|
||||
form_data: Optional[Dict[str, Any]] = None,
|
||||
datasource: Optional[BaseDatasource] = None,
|
||||
) -> Response:
|
||||
try:
|
||||
result = command.run(force_cached=force_cached)
|
||||
@@ -372,7 +379,7 @@ class ChartDataRestApi(ChartRestApi):
|
||||
except ChartDataQueryFailedError as exc:
|
||||
return self.response_400(message=exc.message)
|
||||
|
||||
return self._send_chart_response(result, form_data)
|
||||
return self._send_chart_response(result, form_data, datasource)
|
||||
|
||||
# pylint: disable=invalid-name, no-self-use
|
||||
def _load_query_context_form_from_cache(self, cache_key: str) -> Dict[str, Any]:
|
||||
|
||||
@@ -27,13 +27,16 @@ for these chart types.
|
||||
"""
|
||||
|
||||
from io import StringIO
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
import pandas as pd
|
||||
|
||||
from superset.common.chart_data import ChartDataResultFormat
|
||||
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.base.models import BaseDatasource
|
||||
|
||||
|
||||
def get_column_key(label: Tuple[str, ...], metrics: List[str]) -> Tuple[Any, ...]:
|
||||
"""
|
||||
@@ -284,7 +287,9 @@ post_processors = {
|
||||
|
||||
|
||||
def apply_post_process(
|
||||
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
|
||||
result: Dict[Any, Any],
|
||||
form_data: Optional[Dict[str, Any]] = None,
|
||||
datasource: Optional["BaseDatasource"] = None,
|
||||
) -> Dict[Any, Any]:
|
||||
form_data = form_data or {}
|
||||
|
||||
@@ -306,7 +311,7 @@ def apply_post_process(
|
||||
|
||||
query["colnames"] = list(processed_df.columns)
|
||||
query["indexnames"] = list(processed_df.index)
|
||||
query["coltypes"] = extract_dataframe_dtypes(processed_df)
|
||||
query["coltypes"] = extract_dataframe_dtypes(processed_df, datasource)
|
||||
query["rowcount"] = len(processed_df.index)
|
||||
|
||||
# Flatten hierarchical columns/index since they are represented as
|
||||
|
||||
@@ -104,7 +104,7 @@ def _get_full(
|
||||
if status != QueryStatus.FAILED:
|
||||
payload["colnames"] = list(df.columns)
|
||||
payload["indexnames"] = list(df.index)
|
||||
payload["coltypes"] = extract_dataframe_dtypes(df)
|
||||
payload["coltypes"] = extract_dataframe_dtypes(df, datasource)
|
||||
payload["data"] = query_context.get_data(df)
|
||||
payload["result_format"] = query_context.result_format
|
||||
del payload["df"]
|
||||
|
||||
@@ -1597,7 +1597,9 @@ def get_column_names_from_metrics(metrics: List[Metric]) -> List[str]:
|
||||
return [col for col in map(get_column_name_from_metric, metrics) if col]
|
||||
|
||||
|
||||
def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
|
||||
def extract_dataframe_dtypes(
|
||||
df: pd.DataFrame, datasource: Optional["BaseDatasource"] = None,
|
||||
) -> List[GenericDataType]:
|
||||
"""Serialize pandas/numpy dtypes to generic types"""
|
||||
|
||||
# omitting string types as those will be the default type
|
||||
@@ -1612,11 +1614,21 @@ def extract_dataframe_dtypes(df: pd.DataFrame) -> List[GenericDataType]:
|
||||
"date": GenericDataType.TEMPORAL,
|
||||
}
|
||||
|
||||
columns_by_name = (
|
||||
{column.column_name: column for column in datasource.columns}
|
||||
if datasource
|
||||
else {}
|
||||
)
|
||||
generic_types: List[GenericDataType] = []
|
||||
for column in df.columns:
|
||||
column_object = columns_by_name.get(column)
|
||||
series = df[column]
|
||||
inferred_type = infer_dtype(series)
|
||||
generic_type = inferred_type_map.get(inferred_type, GenericDataType.STRING)
|
||||
generic_type = (
|
||||
GenericDataType.TEMPORAL
|
||||
if column_object and column_object.is_dttm
|
||||
else inferred_type_map.get(inferred_type, GenericDataType.STRING)
|
||||
)
|
||||
generic_types.append(generic_type)
|
||||
|
||||
return generic_types
|
||||
|
||||
@@ -1121,7 +1121,9 @@ class TestUtils(SupersetTestCase):
|
||||
generated_token = get_form_data_token({})
|
||||
assert re.match(r"^token_[a-z0-9]{8}$", generated_token) is not None
|
||||
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_extract_dataframe_dtypes(self):
|
||||
slc = self.get_slice("Girls", db.session)
|
||||
cols: Tuple[Tuple[str, GenericDataType, List[Any]], ...] = (
|
||||
("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]),
|
||||
(
|
||||
@@ -1147,10 +1149,13 @@ class TestUtils(SupersetTestCase):
|
||||
("float_null", GenericDataType.NUMERIC, [None, 0.5]),
|
||||
("bool_null", GenericDataType.BOOLEAN, [None, False]),
|
||||
("obj_null", GenericDataType.STRING, [None, {"a": 1}]),
|
||||
# Non-timestamp columns should be identified as temporal if
|
||||
# `is_dttm` is set to `True` in the underlying datasource
|
||||
("ds", GenericDataType.TEMPORAL, [None, {"ds": "2017-01-01"}]),
|
||||
)
|
||||
|
||||
df = pd.DataFrame(data={col[0]: col[2] for col in cols})
|
||||
assert extract_dataframe_dtypes(df) == [col[1] for col in cols]
|
||||
assert extract_dataframe_dtypes(df, slc.datasource) == [col[1] for col in cols]
|
||||
|
||||
def test_normalize_dttm_col(self):
|
||||
def normalize_col(
|
||||
|
||||
Reference in New Issue
Block a user