feat: Add post processing to QueryObject (#9427)

* Add post processing to QueryObject

* Simplify sort signature and require explicit sort order

* Add new operations and unit tests

* linting

* Address comments

* Simplify test method names

* Address comments

* Linting

* remove unnecessary logic

* Apply strict whitelisting to all getattr calls

* Add checking of rolling_type_options and add/improve docs
This commit is contained in:
Ville Brofeldt
2020-04-10 20:50:11 +03:00
committed by GitHub
parent 5ec0192bcc
commit a8ce3bccdf
9 changed files with 899 additions and 12 deletions

View File

@@ -0,0 +1,389 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from functools import partial
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
from flask_babel import gettext as _
from pandas import DataFrame, NamedAgg
from superset.exceptions import QueryObjectValidationError
WHITELIST_NUMPY_FUNCTIONS = (
"average",
"argmin",
"argmax",
"cumsum",
"cumprod",
"max",
"mean",
"median",
"nansum",
"nanmin",
"nanmax",
"nanmean",
"nanmedian",
"min",
"percentile",
"prod",
"product",
"std",
"sum",
"var",
)
WHITELIST_ROLLING_FUNCTIONS = (
"count",
"corr",
"cov",
"kurt",
"max",
"mean",
"median",
"min",
"std",
"skew",
"sum",
"var",
"quantile",
)
WHITELIST_CUMULATIVE_FUNCTIONS = (
"cummax",
"cummin",
"cumprod",
"cumsum",
)
def validate_column_args(*argnames: str) -> Callable:
def wrapper(func):
def wrapped(df, **options):
columns = df.columns.tolist()
for name in argnames:
if name in options and not all(
elem in columns for elem in options[name]
):
raise QueryObjectValidationError(
_("Referenced columns not available in DataFrame.")
)
return func(df, **options)
return wrapped
return wrapper
def _get_aggregate_funcs(
df: DataFrame, aggregates: Dict[str, Dict[str, Any]],
) -> Dict[str, NamedAgg]:
"""
Converts a set of aggregate config objects into functions that pandas can use as
aggregators. Currently only numpy aggregators are supported.
:param df: DataFrame on which to perform aggregate operation.
:param aggregates: Mapping from column name to aggregat config.
:return: Mapping from metric name to function that takes a single input argument.
"""
agg_funcs: Dict[str, NamedAgg] = {}
for name, agg_obj in aggregates.items():
column = agg_obj.get("column", name)
if column not in df:
raise QueryObjectValidationError(
_(
"Column referenced by aggregate is undefined: %(column)s",
column=column,
)
)
if "operator" not in agg_obj:
raise QueryObjectValidationError(
_("Operator undefined for aggregator: %(name)s", name=name,)
)
operator = agg_obj["operator"]
if operator not in WHITELIST_NUMPY_FUNCTIONS or not hasattr(np, operator):
raise QueryObjectValidationError(
_("Invalid numpy function: %(operator)s", operator=operator,)
)
func = getattr(np, operator)
options = agg_obj.get("options", {})
agg_funcs[name] = NamedAgg(column=column, aggfunc=partial(func, **options))
return agg_funcs
def _append_columns(
base_df: DataFrame, append_df: DataFrame, columns: Dict[str, str]
) -> DataFrame:
"""
Function for adding columns from one DataFrame to another DataFrame. Calls the
assign method, which overwrites the original column in `base_df` if the column
already exists, and appends the column if the name is not defined.
:param base_df: DataFrame which to use as the base
:param append_df: DataFrame from which to select data.
:param columns: columns on which to append, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the values in
column `y` in `base_df` with the values in `y` in `append_df`,
while `{'y': 'y2'}` will add a column `y2` to `base_df` based
on values in column `y` in `append_df`, leaving the original column `y`
in `base_df` unchanged.
:return: new DataFrame with combined data from `base_df` and `append_df`
"""
return base_df.assign(
**{
target: append_df[append_df.columns[idx]]
for idx, target in enumerate(columns.values())
}
)
@validate_column_args("index", "columns")
def pivot( # pylint: disable=too-many-arguments
df: DataFrame,
index: List[str],
columns: List[str],
aggregates: Dict[str, Dict[str, Any]],
metric_fill_value: Optional[Any] = None,
column_fill_value: Optional[str] = None,
drop_missing_columns: Optional[bool] = True,
combine_value_with_metric=False,
marginal_distributions: Optional[bool] = None,
marginal_distribution_name: Optional[str] = None,
) -> DataFrame:
"""
Perform a pivot operation on a DataFrame.
:param df: Object on which pivot operation will be performed
:param index: Columns to group by on the table index (=rows)
:param columns: Columns to group by on the table columns
:param metric_fill_value: Value to replace missing values with
:param column_fill_value: Value to replace missing pivot columns with
:param drop_missing_columns: Do not include columns whose entries are all missing
:param combine_value_with_metric: Display metrics side by side within each column,
as opposed to each column being displayed side by side for each metric.
:param aggregates: A mapping from aggregate column name to the the aggregate
config.
:param marginal_distributions: Add totals for row/column. Default to False
:param marginal_distribution_name: Name of row/column with marginal distribution.
Default to 'All'.
:return: A pivot table
:raises ChartDataValidationError: If the request in incorrect
"""
if not index:
raise QueryObjectValidationError(
_("Pivot operation requires at least one index")
)
if not columns:
raise QueryObjectValidationError(
_("Pivot operation requires at least one column")
)
if not aggregates:
raise QueryObjectValidationError(
_("Pivot operation must include at least one aggregate")
)
if column_fill_value:
df[columns] = df[columns].fillna(value=column_fill_value)
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
# TODO (villebro): Pandas 1.0.3 doesn't yet support NamedAgg in pivot_table.
# Remove once/if support is added.
aggfunc = {na.column: na.aggfunc for na in aggregate_funcs.values()}
df = df.pivot_table(
values=aggfunc.keys(),
index=index,
columns=columns,
aggfunc=aggfunc,
fill_value=metric_fill_value,
dropna=drop_missing_columns,
margins=marginal_distributions,
margins_name=marginal_distribution_name,
)
if combine_value_with_metric:
df = df.stack(0).unstack()
return df
@validate_column_args("groupby")
def aggregate(
df: DataFrame, groupby: List[str], aggregates: Dict[str, Dict[str, Any]]
) -> DataFrame:
"""
Apply aggregations to a DataFrame.
:param df: Object to aggregate.
:param groupby: columns to aggregate
:param aggregates: A mapping from metric column to the function used to
aggregate values.
:raises ChartDataValidationError: If the request in incorrect
"""
aggregates = aggregates or {}
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
return df.groupby(by=groupby).agg(**aggregate_funcs).reset_index()
@validate_column_args("columns")
def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
"""
Sort a DataFrame.
:param df: DataFrame to sort.
:param columns: columns by by which to sort. The key specifies the column name,
value specifies if sorting in ascending order.
:return: Sorted DataFrame
:raises ChartDataValidationError: If the request in incorrect
"""
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
@validate_column_args("columns")
def rolling( # pylint: disable=too-many-arguments
df: DataFrame,
columns: Dict[str, str],
rolling_type: str,
window: int,
rolling_type_options: Optional[Dict[str, Any]] = None,
center: bool = False,
win_type: Optional[str] = None,
min_periods: Optional[int] = None,
) -> DataFrame:
"""
Apply a rolling window on the dataset. See the Pandas docs for further details:
https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.rolling.html
:param df: DataFrame on which the rolling period will be based.
:param columns: columns on which to perform rolling, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the column `y` with
the rolling value in `y`, while `{'y': 'y2'}` will add a column `y2` based
on rolling values calculated from `y`, leaving the original column `y`
unchanged.
:param rolling_type: Type of rolling window. Any numpy function will work.
:param rolling_type_options: Optional options to pass to rolling method. Needed
for e.g. quantile operation.
:param center: Should the label be at the center of the window.
:param win_type: Type of window function.
:param window: Size of the window.
:param min_periods:
:return: DataFrame with the rolling columns
:raises ChartDataValidationError: If the request in incorrect
"""
rolling_type_options = rolling_type_options or {}
df_rolling = df[columns.keys()]
kwargs: Dict[str, Union[str, int]] = {}
if not window:
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
kwargs["window"] = window
if min_periods is not None:
kwargs["min_periods"] = min_periods
if center is not None:
kwargs["center"] = center
if win_type is not None:
kwargs["win_type"] = win_type
df_rolling = df_rolling.rolling(**kwargs)
if rolling_type not in WHITELIST_ROLLING_FUNCTIONS or not hasattr(
df_rolling, rolling_type
):
raise QueryObjectValidationError(
_("Invalid rolling_type: %(type)s", type=rolling_type)
)
try:
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
except TypeError:
raise QueryObjectValidationError(
_(
"Invalid options for %(rolling_type)s: %(options)s",
rolling_type=rolling_type,
options=rolling_type_options,
)
)
df = _append_columns(df, df_rolling, columns)
if min_periods:
df = df[min_periods:]
return df
@validate_column_args("columns", "rename")
def select(
df: DataFrame, columns: List[str], rename: Optional[Dict[str, str]] = None
) -> DataFrame:
"""
Only select a subset of columns in the original dataset. Can be useful for
removing unnecessary intermediate results, renaming and reordering columns.
:param df: DataFrame on which the rolling period will be based.
:param columns: Columns which to select from the DataFrame, in the desired order.
If columns are renamed, the new column name should be referenced
here.
:param rename: columns which to rename, mapping source column to target column.
For instance, `{'y': 'y2'}` will rename the column `y` to
`y2`.
:return: Subset of columns in original DataFrame
:raises ChartDataValidationError: If the request in incorrect
"""
df_select = df[columns]
if rename is not None:
df_select = df_select.rename(columns=rename)
return df_select
@validate_column_args("columns")
def diff(df: DataFrame, columns: Dict[str, str], periods: int = 1,) -> DataFrame:
"""
:param df: DataFrame on which the diff will be based.
:param columns: columns on which to perform diff, mapping source column to
target column. For instance, `{'y': 'y'}` will replace the column `y` with
the diff value in `y`, while `{'y': 'y2'}` will add a column `y2` based
on diff values calculated from `y`, leaving the original column `y`
unchanged.
:param periods: periods to shift for calculating difference.
:return: DataFrame with diffed columns
:raises ChartDataValidationError: If the request in incorrect
"""
df_diff = df[columns.keys()]
df_diff = df_diff.diff(periods=periods)
return _append_columns(df, df_diff, columns)
@validate_column_args("columns")
def cum(df: DataFrame, columns: Dict[str, str], operator: str) -> DataFrame:
"""
:param df: DataFrame on which the cumulative operation will be based.
:param columns: columns on which to perform a cumulative operation, mapping source
column to target column. For instance, `{'y': 'y'}` will replace the column
`y` with the cumulative value in `y`, while `{'y': 'y2'}` will add a column
`y2` based on cumulative values calculated from `y`, leaving the original
column `y` unchanged.
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
:return:
"""
df_cum = df[columns.keys()]
operation = "cum" + operator
if operation not in WHITELIST_CUMULATIVE_FUNCTIONS or not hasattr(
df_cum, operation
):
raise QueryObjectValidationError(
_("Invalid cumulative operator: %(operator)s", operator=operator)
)
return _append_columns(df, getattr(df_cum, operation)(), columns)