feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

This commit is contained in:
Yongjie Zhao
2022-03-23 13:46:28 +08:00
committed by Ville Brofeldt
parent f8a92de75c
commit 9bc76337cf
55 changed files with 1272 additions and 772 deletions

View File

@@ -18,10 +18,11 @@ from functools import partial
from typing import Any, Callable, Dict, Tuple, Union
import numpy as np
import pandas as pd
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg, Timestamp
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
NUMPY_FUNCTIONS = {
"average": np.average,
@@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
"P1W/1970-01-04T00:00:00Z": "W",
}
FLAT_COLUMN_SEPARATOR = ", "
def _flatten_column_after_pivot(
column: Union[float, Timestamp, str, Tuple[str, ...]],
@@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
# drop aggregate for single aggregate pivots with multiple groupings
# from column name (aggregates always come first in column name)
column = column[1:]
return ", ".join([str(col) for col in column])
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
def _is_multi_index_on_columns(df: DataFrame) -> bool:
return isinstance(df.columns, pd.MultiIndex)
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
if options.get("is_pivot_df"):
# skip validation when pivot Dataframe
return func(df, **options)
columns = df.columns.tolist()
if _is_multi_index_on_columns(df):
# MultiIndex column validate first level
columns = df.columns.get_level_values(0)
else:
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options.get(name) or []
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
@@ -152,14 +160,14 @@ def _get_aggregate_funcs(
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
@@ -168,7 +176,7 @@ def _get_aggregate_funcs(
else:
func = NUMPY_FUNCTIONS.get(operator)
if not func:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Invalid numpy function: %(operator)s", operator=operator,)
)
options = agg_obj.get("options", {})
@@ -186,6 +194,8 @@ def _append_columns(
assign method, which overwrites the original column in `base_df` if the column
already exists, and appends the column if the name is not defined.
Note that this is a memory-intensive operation.
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
@@ -196,6 +206,10 @@ def _append_columns(
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{target: append_df[source] for source, target in columns.items()}
)
if all(key == value for key, value in columns.items()):
# make sure to return a new DataFrame instead of changing the `base_df`.
_base_df = base_df.copy()
_base_df.loc[:, columns.keys()] = append_df
return _base_df
append_df = append_df.rename(columns=columns)
return pd.concat([base_df, append_df], axis="columns")