fix: rolling and cum operator on multiple series (#16945)

* fix: rolling and cum operator on multiple series

* add unit tests (UT)

* updates
This commit is contained in:
Yongjie Zhao
2021-10-07 09:42:08 +01:00
committed by GitHub
parent 6dc00b3e3f
commit fd8461406d
3 changed files with 185 additions and 9 deletions

View File

@@ -131,6 +131,9 @@ def _flatten_column_after_pivot(
def validate_column_args(*argnames: str) -> Callable[..., Any]:
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
def wrapped(df: DataFrame, **options: Any) -> Any:
if options.get("is_pivot_df"):
# skip validation when pivot Dataframe
return func(df, **options)
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
@@ -223,6 +226,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
flatten_columns: bool = True,
reset_index: bool = True,
) -> DataFrame:
"""
Perform a pivot operation on a DataFrame.
@@ -243,6 +247,7 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
:param marginal_distribution_name: Name of row/column with marginal distribution.
Default to 'All'.
:param flatten_columns: Convert column names to strings
:param reset_index: Convert index to column
:return: A pivot table
:raises QueryObjectValidationError: If the request is incorrect
"""
@@ -300,7 +305,8 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
_flatten_column_after_pivot(col, aggregates) for col in df.columns
]
# return index as regular column
df.reset_index(level=0, inplace=True)
if reset_index:
df.reset_index(level=0, inplace=True)
return df
@@ -343,13 +349,14 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
@validate_column_args("columns")
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
columns: Dict[str, str],
rolling_type: str,
columns: Optional[Dict[str, str]] = None,
window: Optional[int] = None,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
is_pivot_df: bool = False,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
@@ -369,11 +376,16 @@ def rolling( # pylint: disable=too-many-arguments
:param win_type: Type of window function.
:param min_periods: The minimum amount of periods required for a row to be included
in the result set.
:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with the rolling columns
:raises QueryObjectValidationError: If the request is incorrect
"""
rolling_type_options = rolling_type_options or {}
df_rolling = df[columns.keys()]
columns = columns or {}
if is_pivot_df:
df_rolling = df
else:
df_rolling = df[columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if window is None:
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
@@ -405,10 +417,20 @@ def rolling( # pylint: disable=too-many-arguments
options=rolling_type_options,
)
) from ex
df = _append_columns(df, df_rolling, columns)
if is_pivot_df:
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
df_rolling.columns = [
_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
]
df_rolling.reset_index(level=0, inplace=True)
else:
df_rolling = _append_columns(df, df_rolling, columns)
if min_periods:
df = df[min_periods:]
return df
df_rolling = df_rolling[min_periods:]
return df_rolling
@validate_column_args("columns", "drop", "rename")
@@ -524,7 +546,12 @@ def compare( # pylint: disable=too-many-arguments
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
def cum(
df: DataFrame,
operator: str,
columns: Optional[Dict[str, str]] = None,
is_pivot_df: bool = False,
) -> DataFrame:
"""
Calculate cumulative sum/product/min/max for select columns.
@@ -535,9 +562,14 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:param is_pivot_df: Dataframe is pivoted or not
:return: DataFrame with cumulated columns
"""
df_cum = df[columns.keys()]
columns = columns or {}
if is_pivot_df:
df_cum = df
else:
df_cum = df[columns.keys()]
operation = "cum" + operator
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
@@ -545,7 +577,17 @@ def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
raise QueryObjectValidationError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)
if is_pivot_df:
df_cum = getattr(df_cum, operation)()
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
df_cum.columns = [
_flatten_column_after_pivot(col, agg) for col in df_cum.columns
]
df_cum.reset_index(level=0, inplace=True)
else:
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
return df_cum
def geohash_decode(