mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat: Add post processing to QueryObject (#9427)
* Add post processing to QueryObject * Simplify sort signature and require explicit sort order * Add new operations and unit tests * linting * Address comments * Simplify test method names * Address comments * Linting * remove unnecessary logic * Apply strict whitelisting to all getattr calls * Add checking of rolling_type_options and add/improve docs
This commit is contained in:
389
superset/utils/pandas_postprocessing.py
Normal file
389
superset/utils/pandas_postprocessing.py
Normal file
@@ -0,0 +1,389 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, NamedAgg
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
|
||||
WHITELIST_NUMPY_FUNCTIONS = (
|
||||
"average",
|
||||
"argmin",
|
||||
"argmax",
|
||||
"cumsum",
|
||||
"cumprod",
|
||||
"max",
|
||||
"mean",
|
||||
"median",
|
||||
"nansum",
|
||||
"nanmin",
|
||||
"nanmax",
|
||||
"nanmean",
|
||||
"nanmedian",
|
||||
"min",
|
||||
"percentile",
|
||||
"prod",
|
||||
"product",
|
||||
"std",
|
||||
"sum",
|
||||
"var",
|
||||
)
|
||||
|
||||
WHITELIST_ROLLING_FUNCTIONS = (
|
||||
"count",
|
||||
"corr",
|
||||
"cov",
|
||||
"kurt",
|
||||
"max",
|
||||
"mean",
|
||||
"median",
|
||||
"min",
|
||||
"std",
|
||||
"skew",
|
||||
"sum",
|
||||
"var",
|
||||
"quantile",
|
||||
)
|
||||
|
||||
WHITELIST_CUMULATIVE_FUNCTIONS = (
|
||||
"cummax",
|
||||
"cummin",
|
||||
"cumprod",
|
||||
"cumsum",
|
||||
)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable:
|
||||
def wrapper(func):
|
||||
def wrapped(df, **options):
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
if name in options and not all(
|
||||
elem in columns for elem in options[name]
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Referenced columns not available in DataFrame.")
|
||||
)
|
||||
return func(df, **options)
|
||||
|
||||
return wrapped
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def _get_aggregate_funcs(
|
||||
df: DataFrame, aggregates: Dict[str, Dict[str, Any]],
|
||||
) -> Dict[str, NamedAgg]:
|
||||
"""
|
||||
Converts a set of aggregate config objects into functions that pandas can use as
|
||||
aggregators. Currently only numpy aggregators are supported.
|
||||
|
||||
:param df: DataFrame on which to perform aggregate operation.
|
||||
:param aggregates: Mapping from column name to aggregat config.
|
||||
:return: Mapping from metric name to function that takes a single input argument.
|
||||
"""
|
||||
agg_funcs: Dict[str, NamedAgg] = {}
|
||||
for name, agg_obj in aggregates.items():
|
||||
column = agg_obj.get("column", name)
|
||||
if column not in df:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Column referenced by aggregate is undefined: %(column)s",
|
||||
column=column,
|
||||
)
|
||||
)
|
||||
if "operator" not in agg_obj:
|
||||
raise QueryObjectValidationError(
|
||||
_("Operator undefined for aggregator: %(name)s", name=name,)
|
||||
)
|
||||
operator = agg_obj["operator"]
|
||||
if operator not in WHITELIST_NUMPY_FUNCTIONS or not hasattr(np, operator):
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid numpy function: %(operator)s", operator=operator,)
|
||||
)
|
||||
func = getattr(np, operator)
|
||||
options = agg_obj.get("options", {})
|
||||
agg_funcs[name] = NamedAgg(column=column, aggfunc=partial(func, **options))
|
||||
|
||||
return agg_funcs
|
||||
|
||||
|
||||
def _append_columns(
|
||||
base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str]
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Function for adding columns from one DataFrame to another DataFrame. Calls the
|
||||
assign method, which overwrites the original column in `base_df` if the column
|
||||
already exists, and appends the column if the name is not defined.
|
||||
|
||||
:param base_df: DataFrame which to use as the base
|
||||
:param append_df: DataFrame from which to select data.
|
||||
:param columns: columns on which to append, mapping source column to
|
||||
target column. For instance, `{'y': 'y'}` will replace the values in
|
||||
column `y` in `base_df` with the values in `y` in `append_df`,
|
||||
while `{'y': 'y2'}` will add a column `y2` to `base_df` based
|
||||
on values in column `y` in `append_df`, leaving the original column `y`
|
||||
in `base_df` unchanged.
|
||||
:return: new DataFrame with combined data from `base_df` and `append_df`
|
||||
"""
|
||||
return base_df.assign(
|
||||
**{
|
||||
target: append_df[append_df.columns[idx]]
|
||||
for idx, target in enumerate(columns.values())
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@validate_column_args("index", "columns")
|
||||
def pivot( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
index: List[str],
|
||||
columns: List[str],
|
||||
aggregates: Dict[str, Dict[str, Any]],
|
||||
metric_fill_value: Optional[Any] = None,
|
||||
column_fill_value: Optional[str] = None,
|
||||
drop_missing_columns: Optional[bool] = True,
|
||||
combine_value_with_metric=False,
|
||||
marginal_distributions: Optional[bool] = None,
|
||||
marginal_distribution_name: Optional[str] = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Perform a pivot operation on a DataFrame.
|
||||
|
||||
:param df: Object on which pivot operation will be performed
|
||||
:param index: Columns to group by on the table index (=rows)
|
||||
:param columns: Columns to group by on the table columns
|
||||
:param metric_fill_value: Value to replace missing values with
|
||||
:param column_fill_value: Value to replace missing pivot columns with
|
||||
:param drop_missing_columns: Do not include columns whose entries are all missing
|
||||
:param combine_value_with_metric: Display metrics side by side within each column,
|
||||
as opposed to each column being displayed side by side for each metric.
|
||||
:param aggregates: A mapping from aggregate column name to the the aggregate
|
||||
config.
|
||||
:param marginal_distributions: Add totals for row/column. Default to False
|
||||
:param marginal_distribution_name: Name of row/column with marginal distribution.
|
||||
Default to 'All'.
|
||||
:return: A pivot table
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
if not index:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one index")
|
||||
)
|
||||
if not columns:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation requires at least one column")
|
||||
)
|
||||
if not aggregates:
|
||||
raise QueryObjectValidationError(
|
||||
_("Pivot operation must include at least one aggregate")
|
||||
)
|
||||
|
||||
if column_fill_value:
|
||||
df[columns] = df[columns].fillna(value=column_fill_value)
|
||||
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
|
||||
# TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table.
|
||||
# Remove once/if support is added.
|
||||
aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
|
||||
|
||||
df = df.pivot_table(
|
||||
values=aggfunc.keys(),
|
||||
index=index,
|
||||
columns=columns,
|
||||
aggfunc=aggfunc,
|
||||
fill_value=metric_fill_value,
|
||||
dropna=drop_missing_columns,
|
||||
margins=marginal_distributions,
|
||||
margins_name=marginal_distribution_name,
|
||||
)
|
||||
|
||||
if combine_value_with_metric:
|
||||
df = df.stack(0).unstack()
|
||||
|
||||
return df
|
||||
|
||||
|
||||
@validate_column_args("groupby")
|
||||
def aggregate(
|
||||
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply aggregations to a DataFrame.
|
||||
|
||||
:param df: Object to aggregate.
|
||||
:param groupby: columns to aggregate
|
||||
:param aggregates: A mapping from metric column to the function used to
|
||||
aggregate values.
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
aggregates = aggregates or {}
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
return df.groupby(by=groupby).agg(**aggregate_funcs).reset_index()
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
|
||||
"""
|
||||
Sort a DataFrame.
|
||||
|
||||
:param df: DataFrame to sort.
|
||||
:param columns: columns by by which to sort. The key specifies the column name,
|
||||
value specifies if sorting in ascending order.
|
||||
:return: Sorted DataFrame
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def rolling( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
columns: Dict[str, str],
|
||||
rolling_type: str,
|
||||
window: int,
|
||||
rolling_type_options: Optional[Dict[str, Any]] = None,
|
||||
center: bool = False,
|
||||
win_type: Optional[str] = None,
|
||||
min_periods: Optional[int] = None,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply a rolling window on the dataset. See the Pandas docs for further details:
|
||||
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html
|
||||
|
||||
:param df: DataFrame on which the rolling period will be based.
|
||||
:param columns: columns on which to perform rolling, mapping source column to
|
||||
target column. For instance, `{'y': 'y'}` will replace the column `y` with
|
||||
the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based
|
||||
on rolling values calculated from `y`, leaving the original column `y`
|
||||
unchanged.
|
||||
:param rolling_type: Type of rolling window. Any numpy function will work.
|
||||
:param rolling_type_options: Optional options to pass to rolling method. Needed
|
||||
for e.g. quantile operation.
|
||||
:param center: Should the label be at the center of the window.
|
||||
:param win_type: Type of window function.
|
||||
:param window: Size of the window.
|
||||
:param min_periods:
|
||||
:return: DataFrame with the rolling columns
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
rolling_type_options = rolling_type_options or {}
|
||||
df_rolling = df[columns.keys()]
|
||||
kwargs: Dict[str, Union[str, int]] = {}
|
||||
if not window:
|
||||
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
|
||||
|
||||
kwargs["window"] = window
|
||||
if min_periods is not None:
|
||||
kwargs["min_periods"] = min_periods
|
||||
if center is not None:
|
||||
kwargs["center"] = center
|
||||
if win_type is not None:
|
||||
kwargs["win_type"] = win_type
|
||||
|
||||
df_rolling = df_rolling.rolling(**kwargs)
|
||||
if rolling_type not in WHITELIST_ROLLING_FUNCTIONS or not hasattr(
|
||||
df_rolling, rolling_type
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid rolling_type: %(type)s", type=rolling_type)
|
||||
)
|
||||
try:
|
||||
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
|
||||
except TypeError:
|
||||
raise QueryObjectValidationError(
|
||||
_(
|
||||
"Invalid options for %(rolling_type)s: %(options)s",
|
||||
rolling_type=rolling_type,
|
||||
options=rolling_type_options,
|
||||
)
|
||||
)
|
||||
df = _append_columns(df, df_rolling, columns)
|
||||
if min_periods:
|
||||
df = df[min_periods:]
|
||||
return df
|
||||
|
||||
|
||||
@validate_column_args("columns", "rename")
|
||||
def select(
|
||||
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Only select a subset of columns in the original dataset. Can be useful for
|
||||
removing unnecessary intermediate results, renaming and reordering columns.
|
||||
|
||||
:param df: DataFrame on which the rolling period will be based.
|
||||
:param columns: Columns which to select from the DataFrame, in the desired order.
|
||||
If columns are renamed, the new column name should be referenced
|
||||
here.
|
||||
:param rename: columns which to rename, mapping source column to target column.
|
||||
For instance, `{'y': 'y2'}` will rename the column `y` to
|
||||
`y2`.
|
||||
:return: Subset of columns in original DataFrame
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
df_select = df[columns]
|
||||
if rename is not None:
|
||||
df_select = df_select.rename(columns=rename)
|
||||
return df_select
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
|
||||
"""
|
||||
|
||||
:param df: DataFrame on which the diff will be based.
|
||||
:param columns: columns on which to perform diff, mapping source column to
|
||||
target column. For instance, `{'y': 'y'}` will replace the column `y` with
|
||||
the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based
|
||||
on diff values calculated from `y`, leaving the original column `y`
|
||||
unchanged.
|
||||
:param periods: periods to shift for calculating difference.
|
||||
:return: DataFrame with diffed columns
|
||||
:raises ChartDataValidationError: If the request in incorrect
|
||||
"""
|
||||
df_diff = df[columns.keys()]
|
||||
df_diff = df_diff.diff(periods=periods)
|
||||
return _append_columns(df, df_diff, columns)
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
|
||||
"""
|
||||
|
||||
:param df: DataFrame on which the cumulative operation will be based.
|
||||
:param columns: columns on which to perform a cumulative operation, mapping source
|
||||
column to target column. For instance, `{'y': 'y'}` will replace the column
|
||||
`y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column
|
||||
`y2` based on cumulative values calculated from `y`, leaving the original
|
||||
column `y` unchanged.
|
||||
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
|
||||
:return:
|
||||
"""
|
||||
df_cum = df[columns.keys()]
|
||||
operation = "cum" + operator
|
||||
if operation not in WHITELIST_CUMULATIVE_FUNCTIONS or not hasattr(
|
||||
df_cum, operation
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
_("Invalid cumulative operator: %(operator)s", operator=operator)
|
||||
)
|
||||
return _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
Reference in New Issue
Block a user