mirror of
https://github.com/apache/superset.git
synced 2026-05-09 09:55:19 +00:00
feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
committed by
Ville Brofeldt
parent
f8a92de75c
commit
9bc76337cf
@@ -18,10 +18,11 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, NamedAgg, Timestamp
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
|
||||
NUMPY_FUNCTIONS = {
|
||||
"average": np.average,
|
||||
@@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
|
||||
"P1W/1970-01-04T00:00:00Z": "W",
|
||||
}
|
||||
|
||||
FLAT_COLUMN_SEPARATOR = ", "
|
||||
|
||||
|
||||
def _flatten_column_after_pivot(
|
||||
column: Union[float, Timestamp, str, Tuple[str, ...]],
|
||||
@@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
|
||||
# drop aggregate for single aggregate pivots with multiple groupings
|
||||
# from column name (aggregates always come first in column name)
|
||||
column = column[1:]
|
||||
return ", ".join([str(col) for col in column])
|
||||
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
|
||||
|
||||
|
||||
def _is_multi_index_on_columns(df: DataFrame) -> bool:
|
||||
return isinstance(df.columns, pd.MultiIndex)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||
if options.get("is_pivot_df"):
|
||||
# skip validation when pivot Dataframe
|
||||
return func(df, **options)
|
||||
columns = df.columns.tolist()
|
||||
if _is_multi_index_on_columns(df):
|
||||
# for MultiIndex columns, validate against the first level only
|
||||
columns = df.columns.get_level_values(0)
|
||||
else:
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
if name in options and not all(
|
||||
elem in columns for elem in options.get(name) or []
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Referenced columns not available in DataFrame.")
|
||||
)
|
||||
return func(df, **options)
|
||||
@@ -152,14 +160,14 @@ def _get_aggregate_funcs(
|
||||
for name, agg_obj in aggregates.items():
|
||||
column = agg_obj.get("column", name)
|
||||
if column not in df:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Column referenced by aggregate is undefined: %(column)s",
|
||||
column=column,
|
||||
)
|
||||
)
|
||||
if "operator" not in agg_obj:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Operator undefined for aggregator: %(name)s", name=name,)
|
||||
)
|
||||
operator = agg_obj["operator"]
|
||||
@@ -168,7 +176,7 @@ def _get_aggregate_funcs(
|
||||
else:
|
||||
func = NUMPY_FUNCTIONS.get(operator)
|
||||
if not func:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid numpy function: %(operator)s", operator=operator,)
|
||||
)
|
||||
options = agg_obj.get("options", {})
|
||||
@@ -186,6 +194,8 @@ def _append_columns(
|
||||
assign method, which overwrites the original column in `base_df` if the column
|
||||
already exists, and appends the column if the name is not defined.
|
||||
|
||||
Note that this is a memory-intensive operation.
|
||||
|
||||
:param base_df: DataFrame which to use as the base
|
||||
:param append_df: DataFrame from which to select data.
|
||||
:param columns: columns on which to append, mapping source column to
|
||||
@@ -196,6 +206,10 @@ def _append_columns(
|
||||
in `base_df` unchanged.
|
||||
:return: new DataFrame with combined data from `base_df` and `append_df`
|
||||
"""
|
||||
return base_df.assign(
|
||||
**{target: append_df[source] for source, target in columns.items()}
|
||||
)
|
||||
if all(key == value for key, value in columns.items()):
|
||||
# make sure to return a new DataFrame instead of changing the `base_df`.
|
||||
_base_df = base_df.copy()
|
||||
_base_df.loc[:, columns.keys()] = append_df
|
||||
return _base_df
|
||||
append_df = append_df.rename(columns=columns)
|
||||
return pd.concat([base_df, append_df], axis="columns")
|
||||
|
||||
Reference in New Issue
Block a user