feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

This commit is contained in:
Yongjie Zhao
2022-03-23 13:46:28 +08:00
committed by GitHub
parent 6083545e86
commit 375c03e084
55 changed files with 1267 additions and 772 deletions

View File

@@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
from flask_babel import gettext as _
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.utils import (
_append_columns,
_flatten_column_after_pivot,
DENYLIST_ROLLING_FUNCTIONS,
validate_column_args,
)
@@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
rolling_type: str,
columns: Optional[Dict[str, str]] = None,
columns: Dict[str, str],
window: Optional[int] = None,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
is_pivot_df: bool = False,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
@@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
:param win_type: Type of window function.
:param min_periods: The minimum number of periods required for a row to be included
in the result set.
:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with the rolling columns
:raises QueryObjectValidationError: If the request is incorrect
:raises InvalidPostProcessingError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
columns = columns or {}
if is_pivot_df:
df_rolling = df
else:
df_rolling = df[columns.keys()]
df_rolling = df.loc[:, columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if window is None:
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
if window == 0:
raise QueryObjectValidationError(_("Window must be > 0"))
raise InvalidPostProcessingError(_("Window must be > 0"))
kwargs["window"] = window
if min_periods is not None:
@@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError as ex:
raise QueryObjectValidationError(
raise InvalidPostProcessingError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
@@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
)
) from ex
if is_pivot_df:
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
df_rolling.columns = [
_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
]
df_rolling.reset_index(level=0, inplace=True)
else:
df_rolling = _append_columns(df, df_rolling, columns)
df_rolling = _append_columns(df, df_rolling, columns)
if min_periods:
df_rolling = df_rolling[min_periods:]