mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
committed by
Ville Brofeldt
parent
f8a92de75c
commit
9bc76337cf
@@ -20,6 +20,7 @@ from superset.utils.pandas_postprocessing.compare import compare
|
||||
from superset.utils.pandas_postprocessing.contribution import contribution
|
||||
from superset.utils.pandas_postprocessing.cum import cum
|
||||
from superset.utils.pandas_postprocessing.diff import diff
|
||||
from superset.utils.pandas_postprocessing.flatten import flatten
|
||||
from superset.utils.pandas_postprocessing.geography import (
|
||||
geodetic_parse,
|
||||
geohash_decode,
|
||||
@@ -49,5 +50,6 @@ __all__ = [
|
||||
"rolling",
|
||||
"select",
|
||||
"sort",
|
||||
"flatten",
|
||||
"_flatten_column_after_pivot",
|
||||
]
|
||||
|
||||
@@ -35,7 +35,7 @@ def aggregate(
|
||||
:param groupby: columns to aggregate
|
||||
:param aggregates: A mapping from metric column to the function used to
|
||||
aggregate values.
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
aggregates = aggregates or {}
|
||||
aggregate_funcs = _get_aggregate_funcs(df, aggregates)
|
||||
|
||||
@@ -20,7 +20,7 @@ import numpy as np
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, Series
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingBoxplotWhiskerType
|
||||
from superset.utils.pandas_postprocessing.aggregate import aggregate
|
||||
|
||||
@@ -84,7 +84,7 @@ def boxplot(
|
||||
or not isinstance(percentiles[1], (int, float))
|
||||
or percentiles[0] >= percentiles[1]
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"percentiles must be a list or tuple with two numeric values, "
|
||||
"of which the first is lower than the second value"
|
||||
|
||||
@@ -21,7 +21,7 @@ from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.constants import PandasPostprocessingCompare
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import TIME_COMPARISION
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
|
||||
@@ -31,7 +31,7 @@ def compare( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
source_columns: List[str],
|
||||
compare_columns: List[str],
|
||||
compare_type: Optional[PandasPostprocessingCompare],
|
||||
compare_type: PandasPostprocessingCompare,
|
||||
drop_original_columns: Optional[bool] = False,
|
||||
precision: Optional[int] = 4,
|
||||
) -> DataFrame:
|
||||
@@ -46,31 +46,38 @@ def compare( # pylint: disable=too-many-arguments
|
||||
compare columns.
|
||||
:param precision: Round a change rate to a variable number of decimal places.
|
||||
:return: DataFrame with compared columns.
|
||||
:raises QueryObjectValidationError: If the request in incorrect.
|
||||
:raises InvalidPostProcessingError: If the request is incorrect.
|
||||
"""
|
||||
if len(source_columns) != len(compare_columns):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`compare_columns` must have the same length as `source_columns`.")
|
||||
)
|
||||
if compare_type not in tuple(PandasPostprocessingCompare):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`compare_type` must be `difference`, `percentage` or `ratio`")
|
||||
)
|
||||
if len(source_columns) == 0:
|
||||
return df
|
||||
|
||||
for s_col, c_col in zip(source_columns, compare_columns):
|
||||
s_df = df.loc[:, [s_col]]
|
||||
s_df.rename(columns={s_col: "__intermediate"}, inplace=True)
|
||||
c_df = df.loc[:, [c_col]]
|
||||
c_df.rename(columns={c_col: "__intermediate"}, inplace=True)
|
||||
if compare_type == PandasPostprocessingCompare.DIFF:
|
||||
diff_series = df[s_col] - df[c_col]
|
||||
diff_df = c_df - s_df
|
||||
elif compare_type == PandasPostprocessingCompare.PCT:
|
||||
diff_series = (
|
||||
((df[s_col] - df[c_col]) / df[c_col]).astype(float).round(precision)
|
||||
)
|
||||
# https://en.wikipedia.org/wiki/Relative_change_and_difference#Percentage_change
|
||||
diff_df = ((c_df - s_df) / s_df).astype(float).round(precision)
|
||||
else:
|
||||
# compare_type == "ratio"
|
||||
diff_series = (df[s_col] / df[c_col]).astype(float).round(precision)
|
||||
diff_df = diff_series.to_frame(
|
||||
name=TIME_COMPARISION.join([compare_type, s_col, c_col])
|
||||
diff_df = (c_df / s_df).astype(float).round(precision)
|
||||
|
||||
diff_df.rename(
|
||||
columns={
|
||||
"__intermediate": TIME_COMPARISION.join([compare_type, s_col, c_col])
|
||||
},
|
||||
inplace=True,
|
||||
)
|
||||
df = pd.concat([df, diff_df], axis=1)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import List, Optional
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingContributionOrientation
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
|
||||
@@ -55,7 +55,7 @@ def contribution(
|
||||
numeric_columns = numeric_df.columns.tolist()
|
||||
for col in columns:
|
||||
if col not in numeric_columns:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
'Column "%(column)s" is not numeric or does not '
|
||||
"exists in the query results.",
|
||||
@@ -65,7 +65,7 @@ def contribution(
|
||||
columns = columns or numeric_df.columns
|
||||
rename_columns = rename_columns or columns
|
||||
if len(rename_columns) != len(columns):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("`rename_columns` must have the same length as `columns`.")
|
||||
)
|
||||
# limit to selected columns
|
||||
|
||||
@@ -14,27 +14,21 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Dict
|
||||
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_append_columns,
|
||||
_flatten_column_after_pivot,
|
||||
ALLOWLIST_CUMULATIVE_FUNCTIONS,
|
||||
validate_column_args,
|
||||
)
|
||||
|
||||
|
||||
@validate_column_args("columns")
|
||||
def cum(
|
||||
df: DataFrame,
|
||||
operator: str,
|
||||
columns: Optional[Dict[str, str]] = None,
|
||||
is_pivot_df: bool = False,
|
||||
) -> DataFrame:
|
||||
def cum(df: DataFrame, operator: str, columns: Dict[str, str],) -> DataFrame:
|
||||
"""
|
||||
Calculate cumulative sum/product/min/max for select columns.
|
||||
|
||||
@@ -45,29 +39,16 @@ def cum(
|
||||
`y2` based on cumulative values calculated from `y`, leaving the original
|
||||
column `y` unchanged.
|
||||
:param operator: cumulative operator, e.g. `sum`, `prod`, `min`, `max`
|
||||
:param is_pivot_df: Dataframe is pivoted or not
|
||||
:return: DataFrame with cumulated columns
|
||||
"""
|
||||
columns = columns or {}
|
||||
if is_pivot_df:
|
||||
df_cum = df
|
||||
else:
|
||||
df_cum = df[columns.keys()]
|
||||
df_cum = df.loc[:, columns.keys()]
|
||||
operation = "cum" + operator
|
||||
if operation not in ALLOWLIST_CUMULATIVE_FUNCTIONS or not hasattr(
|
||||
df_cum, operation
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid cumulative operator: %(operator)s", operator=operator)
|
||||
)
|
||||
if is_pivot_df:
|
||||
df_cum = getattr(df_cum, operation)()
|
||||
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
|
||||
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
|
||||
df_cum.columns = [
|
||||
_flatten_column_after_pivot(col, agg) for col in df_cum.columns
|
||||
]
|
||||
df_cum.reset_index(level=0, inplace=True)
|
||||
else:
|
||||
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
df_cum = _append_columns(df, getattr(df_cum, operation)(), columns)
|
||||
return df_cum
|
||||
|
||||
@@ -44,7 +44,7 @@ def diff(
|
||||
:param periods: periods to shift for calculating difference.
|
||||
:param axis: 0 for row, 1 for column. default 0.
|
||||
:return: DataFrame with diffed columns
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
df_diff = df[columns.keys()]
|
||||
df_diff = df_diff.diff(periods=periods, axis=axis)
|
||||
|
||||
81
superset/utils/pandas_postprocessing/flatten.py
Normal file
81
superset/utils/pandas_postprocessing/flatten.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_is_multi_index_on_columns,
|
||||
FLAT_COLUMN_SEPARATOR,
|
||||
)
|
||||
|
||||
|
||||
def flatten(df: pd.DataFrame, reset_index: bool = True,) -> pd.DataFrame:
    """
    Flatten an N-dimensional DataFrame into a two-dimensional one.

    MultiIndex columns are collapsed into single string labels by joining
    each tuple of level values with ``FLAT_COLUMN_SEPARATOR``, and, when
    requested, a non-RangeIndex row index is converted into a regular column.

    :param df: N-dimensional DataFrame.
    :param reset_index: Convert index to column when df.index isn't RangeIndex
    :return: a flat DataFrame

    Examples
    -----------

    Convert DatetimeIndex into columns.

    >>> index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03",])
    >>> index.name = "__timestamp"
    >>> df = pd.DataFrame(index=index, data={"metric": [1, 2, 3]})
    >>> df = flatten(df)
    >>> df
      __timestamp  metric
    0  2021-01-01       1
    1  2021-01-02       2
    2  2021-01-03       3

    Convert DatetimeIndex and MultiIndex into columns

    >>> iterables = [["foo", "bar"], ["one", "two"]]
    >>> columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
    >>> df = pd.DataFrame(index=index, columns=columns, data=1)
    >>> flatten(df)
      __timestamp  foo, one  foo, two  bar, one  bar, two
    0  2021-01-01         1         1         1         1
    1  2021-01-02         1         1         1         1
    2  2021-01-03         1         1         1         1
    """
    if _is_multi_index_on_columns(df):
        flat_labels = []
        for level_values in df.columns.to_flat_index():
            # level values may be timestamps or numbers, so stringify each
            # cell before joining into a single flat column label
            flat_labels.append(
                FLAT_COLUMN_SEPARATOR.join(str(cell) for cell in level_values)
            )
        df.columns = flat_labels

    has_range_index = isinstance(df.index, pd.RangeIndex)
    if reset_index and not has_range_index:
        df = df.reset_index(level=0)

    return df
|
||||
@@ -21,7 +21,7 @@ from flask_babel import gettext as _
|
||||
from geopy.point import Point
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import _append_columns
|
||||
|
||||
|
||||
@@ -46,7 +46,7 @@ def geohash_decode(
|
||||
df, lonlat_df, {"latitude": latitude, "longitude": longitude}
|
||||
)
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid geohash string")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid geohash string")) from ex
|
||||
|
||||
|
||||
def geohash_encode(
|
||||
@@ -69,7 +69,7 @@ def geohash_encode(
|
||||
)
|
||||
return _append_columns(df, encode_df, {"geohash": geohash})
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid longitude/latitude")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid longitude/latitude")) from ex
|
||||
|
||||
|
||||
def geodetic_parse(
|
||||
@@ -111,4 +111,4 @@ def geodetic_parse(
|
||||
columns["altitude"] = altitude
|
||||
return _append_columns(df, geodetic_df, columns)
|
||||
except ValueError as ex:
|
||||
raise QueryObjectValidationError(_("Invalid geodetic string")) from ex
|
||||
raise InvalidPostProcessingError(_("Invalid geodetic string")) from ex
|
||||
|
||||
@@ -20,7 +20,7 @@ from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.constants import NULL_STRING, PandasAxis
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_flatten_column_after_pivot,
|
||||
_get_aggregate_funcs,
|
||||
@@ -64,14 +64,14 @@ def pivot( # pylint: disable=too-many-arguments,too-many-locals
|
||||
:param flatten_columns: Convert column names to strings
|
||||
:param reset_index: Convert index to column
|
||||
:return: A pivot table
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
if not index:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Pivot operation requires at least one index")
|
||||
)
|
||||
if not aggregates:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Pivot operation must include at least one aggregate")
|
||||
)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from typing import Optional, Union
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing.utils import PROPHET_TIME_GRAIN_MAP
|
||||
|
||||
@@ -58,7 +58,7 @@ def _prophet_fit_and_predict( # pylint: disable=too-many-arguments
|
||||
prophet_logger.setLevel(logging.CRITICAL)
|
||||
prophet_logger.setLevel(logging.NOTSET)
|
||||
except ModuleNotFoundError as ex:
|
||||
raise QueryObjectValidationError(_("`prophet` package not installed")) from ex
|
||||
raise InvalidPostProcessingError(_("`prophet` package not installed")) from ex
|
||||
model = Prophet(
|
||||
interval_width=confidence_interval,
|
||||
yearly_seasonality=yearly_seasonality,
|
||||
@@ -111,24 +111,24 @@ def prophet( # pylint: disable=too-many-arguments
|
||||
index = index or DTTM_ALIAS
|
||||
# validate inputs
|
||||
if not time_grain:
|
||||
raise QueryObjectValidationError(_("Time grain missing"))
|
||||
raise InvalidPostProcessingError(_("Time grain missing"))
|
||||
if time_grain not in PROPHET_TIME_GRAIN_MAP:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Unsupported time grain: %(time_grain)s", time_grain=time_grain,)
|
||||
)
|
||||
freq = PROPHET_TIME_GRAIN_MAP[time_grain]
|
||||
# check type at runtime due to marhsmallow schema not being able to handle
|
||||
# union types
|
||||
if not isinstance(periods, int) or periods < 0:
|
||||
raise QueryObjectValidationError(_("Periods must be a whole number"))
|
||||
raise InvalidPostProcessingError(_("Periods must be a whole number"))
|
||||
if not confidence_interval or confidence_interval <= 0 or confidence_interval >= 1:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Confidence interval must be between 0 and 1 (exclusive)")
|
||||
)
|
||||
if index not in df.columns:
|
||||
raise QueryObjectValidationError(_("DataFrame must include temporal column"))
|
||||
raise InvalidPostProcessingError(_("DataFrame must include temporal column"))
|
||||
if len(df.columns) < 2:
|
||||
raise QueryObjectValidationError(_("DataFrame include at least one series"))
|
||||
raise InvalidPostProcessingError(_("DataFrame include at least one series"))
|
||||
|
||||
target_df = DataFrame()
|
||||
for column in [column for column in df.columns if column != index]:
|
||||
|
||||
@@ -14,48 +14,35 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from typing import Optional, Tuple, Union
|
||||
from typing import Optional, Union
|
||||
|
||||
from pandas import DataFrame
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
|
||||
from superset.utils.pandas_postprocessing.utils import validate_column_args
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
|
||||
|
||||
@validate_column_args("groupby_columns")
|
||||
def resample( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
def resample(
|
||||
df: pd.DataFrame,
|
||||
rule: str,
|
||||
method: str,
|
||||
time_column: str,
|
||||
groupby_columns: Optional[Tuple[Optional[str], ...]] = None,
|
||||
fill_value: Optional[Union[float, int]] = None,
|
||||
) -> DataFrame:
|
||||
) -> pd.DataFrame:
|
||||
"""
|
||||
support upsampling in resample
|
||||
|
||||
:param df: DataFrame to resample.
|
||||
:param rule: The offset string representing target conversion.
|
||||
:param method: How to fill the NaN value after resample.
|
||||
:param time_column: existing columns in DataFrame.
|
||||
:param groupby_columns: columns except time_column in dataframe
|
||||
:param fill_value: What values do fill missing.
|
||||
:return: DataFrame after resample
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
if not isinstance(df.index, pd.DatetimeIndex):
|
||||
raise InvalidPostProcessingError(_("Resample operation requires DatetimeIndex"))
|
||||
|
||||
def _upsampling(_df: DataFrame) -> DataFrame:
|
||||
_df = _df.set_index(time_column)
|
||||
if method == "asfreq" and fill_value is not None:
|
||||
return _df.resample(rule).asfreq(fill_value=fill_value)
|
||||
return getattr(_df.resample(rule), method)()
|
||||
|
||||
if groupby_columns:
|
||||
df = (
|
||||
df.set_index(keys=list(groupby_columns))
|
||||
.groupby(by=list(groupby_columns))
|
||||
.apply(_upsampling)
|
||||
)
|
||||
df = df.reset_index().set_index(time_column).sort_index()
|
||||
if method == "asfreq" and fill_value is not None:
|
||||
_df = df.resample(rule).asfreq(fill_value=fill_value)
|
||||
else:
|
||||
df = _upsampling(df)
|
||||
return df.reset_index()
|
||||
_df = getattr(df.resample(rule), method)()
|
||||
return _df
|
||||
|
||||
@@ -19,10 +19,9 @@ from typing import Any, Dict, Optional, Union
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.utils import (
|
||||
_append_columns,
|
||||
_flatten_column_after_pivot,
|
||||
DENYLIST_ROLLING_FUNCTIONS,
|
||||
validate_column_args,
|
||||
)
|
||||
@@ -32,13 +31,12 @@ from superset.utils.pandas_postprocessing.utils import (
|
||||
def rolling( # pylint: disable=too-many-arguments
|
||||
df: DataFrame,
|
||||
rolling_type: str,
|
||||
columns: Optional[Dict[str, str]] = None,
|
||||
columns: Dict[str, str],
|
||||
window: Optional[int] = None,
|
||||
rolling_type_options: Optional[Dict[str, Any]] = None,
|
||||
center: bool = False,
|
||||
win_type: Optional[str] = None,
|
||||
min_periods: Optional[int] = None,
|
||||
is_pivot_df: bool = False,
|
||||
) -> DataFrame:
|
||||
"""
|
||||
Apply a rolling window on the dataset. See the Pandas docs for further details:
|
||||
@@ -58,21 +56,17 @@ def rolling( # pylint: disable=too-many-arguments
|
||||
:param win_type: Type of window function.
|
||||
:param min_periods: The minimum amount of periods required for a row to be included
|
||||
in the result set.
|
||||
:param is_pivot_df: Dataframe is pivoted or not
|
||||
:return: DataFrame with the rolling columns
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
rolling_type_options = rolling_type_options or {}
|
||||
columns = columns or {}
|
||||
if is_pivot_df:
|
||||
df_rolling = df
|
||||
else:
|
||||
df_rolling = df[columns.keys()]
|
||||
df_rolling = df.loc[:, columns.keys()]
|
||||
|
||||
kwargs: Dict[str, Union[str, int]] = {}
|
||||
if window is None:
|
||||
raise QueryObjectValidationError(_("Undefined window for rolling operation"))
|
||||
raise InvalidPostProcessingError(_("Undefined window for rolling operation"))
|
||||
if window == 0:
|
||||
raise QueryObjectValidationError(_("Window must be > 0"))
|
||||
raise InvalidPostProcessingError(_("Window must be > 0"))
|
||||
|
||||
kwargs["window"] = window
|
||||
if min_periods is not None:
|
||||
@@ -86,13 +80,13 @@ def rolling( # pylint: disable=too-many-arguments
|
||||
if rolling_type not in DENYLIST_ROLLING_FUNCTIONS or not hasattr(
|
||||
df_rolling, rolling_type
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid rolling_type: %(type)s", type=rolling_type)
|
||||
)
|
||||
try:
|
||||
df_rolling = getattr(df_rolling, rolling_type)(**rolling_type_options)
|
||||
except TypeError as ex:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Invalid options for %(rolling_type)s: %(options)s",
|
||||
rolling_type=rolling_type,
|
||||
@@ -100,15 +94,7 @@ def rolling( # pylint: disable=too-many-arguments
|
||||
)
|
||||
) from ex
|
||||
|
||||
if is_pivot_df:
|
||||
agg_in_pivot_df = df.columns.get_level_values(0).drop_duplicates().to_list()
|
||||
agg: Dict[str, Dict[str, Any]] = {col: {} for col in agg_in_pivot_df}
|
||||
df_rolling.columns = [
|
||||
_flatten_column_after_pivot(col, agg) for col in df_rolling.columns
|
||||
]
|
||||
df_rolling.reset_index(level=0, inplace=True)
|
||||
else:
|
||||
df_rolling = _append_columns(df, df_rolling, columns)
|
||||
df_rolling = _append_columns(df, df_rolling, columns)
|
||||
|
||||
if min_periods:
|
||||
df_rolling = df_rolling[min_periods:]
|
||||
|
||||
@@ -42,7 +42,7 @@ def select(
|
||||
For instance, `{'y': 'y2'}` will rename the column `y` to
|
||||
`y2`.
|
||||
:return: Subset of columns in original DataFrame
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
df_select = df.copy(deep=False)
|
||||
if columns:
|
||||
|
||||
@@ -30,6 +30,6 @@ def sort(df: DataFrame, columns: Dict[str, bool]) -> DataFrame:
|
||||
:param columns: columns by which to sort. The key specifies the column name,
|
||||
value specifies if sorting in ascending order.
|
||||
:return: Sorted DataFrame
|
||||
:raises QueryObjectValidationError: If the request in incorrect
|
||||
:raises InvalidPostProcessingError: If the request is incorrect
|
||||
"""
|
||||
return df.sort_values(by=list(columns.keys()), ascending=list(columns.values()))
|
||||
|
||||
@@ -18,10 +18,11 @@ from functools import partial
|
||||
from typing import Any, Callable, Dict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from flask_babel import gettext as _
|
||||
from pandas import DataFrame, NamedAgg, Timestamp
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
|
||||
NUMPY_FUNCTIONS = {
|
||||
"average": np.average,
|
||||
@@ -91,6 +92,8 @@ PROPHET_TIME_GRAIN_MAP = {
|
||||
"P1W/1970-01-04T00:00:00Z": "W",
|
||||
}
|
||||
|
||||
FLAT_COLUMN_SEPARATOR = ", "
|
||||
|
||||
|
||||
def _flatten_column_after_pivot(
|
||||
column: Union[float, Timestamp, str, Tuple[str, ...]],
|
||||
@@ -113,21 +116,26 @@ def _flatten_column_after_pivot(
|
||||
# drop aggregate for single aggregate pivots with multiple groupings
|
||||
# from column name (aggregates always come first in column name)
|
||||
column = column[1:]
|
||||
return ", ".join([str(col) for col in column])
|
||||
return FLAT_COLUMN_SEPARATOR.join([str(col) for col in column])
|
||||
|
||||
|
||||
def _is_multi_index_on_columns(df: DataFrame) -> bool:
    # A pivoted frame carries its groupings as a hierarchical column index;
    # detect that so callers can flatten or validate on the first level only.
    column_index = df.columns
    return isinstance(column_index, pd.MultiIndex)
|
||||
|
||||
|
||||
def validate_column_args(*argnames: str) -> Callable[..., Any]:
|
||||
def wrapper(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
def wrapped(df: DataFrame, **options: Any) -> Any:
|
||||
if options.get("is_pivot_df"):
|
||||
# skip validation when pivot Dataframe
|
||||
return func(df, **options)
|
||||
columns = df.columns.tolist()
|
||||
if _is_multi_index_on_columns(df):
|
||||
# MultiIndex column validate first level
|
||||
columns = df.columns.get_level_values(0)
|
||||
else:
|
||||
columns = df.columns.tolist()
|
||||
for name in argnames:
|
||||
if name in options and not all(
|
||||
elem in columns for elem in options.get(name) or []
|
||||
):
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Referenced columns not available in DataFrame.")
|
||||
)
|
||||
return func(df, **options)
|
||||
@@ -152,14 +160,14 @@ def _get_aggregate_funcs(
|
||||
for name, agg_obj in aggregates.items():
|
||||
column = agg_obj.get("column", name)
|
||||
if column not in df:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_(
|
||||
"Column referenced by aggregate is undefined: %(column)s",
|
||||
column=column,
|
||||
)
|
||||
)
|
||||
if "operator" not in agg_obj:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Operator undefined for aggregator: %(name)s", name=name,)
|
||||
)
|
||||
operator = agg_obj["operator"]
|
||||
@@ -168,7 +176,7 @@ def _get_aggregate_funcs(
|
||||
else:
|
||||
func = NUMPY_FUNCTIONS.get(operator)
|
||||
if not func:
|
||||
raise QueryObjectValidationError(
|
||||
raise InvalidPostProcessingError(
|
||||
_("Invalid numpy function: %(operator)s", operator=operator,)
|
||||
)
|
||||
options = agg_obj.get("options", {})
|
||||
@@ -186,6 +194,8 @@ def _append_columns(
|
||||
assign method, which overwrites the original column in `base_df` if the column
|
||||
already exists, and appends the column if the name is not defined.
|
||||
|
||||
Note that this is a memory-intensive operation.
|
||||
|
||||
:param base_df: DataFrame which to use as the base
|
||||
:param append_df: DataFrame from which to select data.
|
||||
:param columns: columns on which to append, mapping source column to
|
||||
@@ -196,6 +206,10 @@ def _append_columns(
|
||||
in `base_df` unchanged.
|
||||
:return: new DataFrame with combined data from `base_df` and `append_df`
|
||||
"""
|
||||
return base_df.assign(
|
||||
**{target: append_df[source] for source, target in columns.items()}
|
||||
)
|
||||
if all(key == value for key, value in columns.items()):
|
||||
# make sure to return a new DataFrame instead of changing the `base_df`.
|
||||
_base_df = base_df.copy()
|
||||
_base_df.loc[:, columns.keys()] = append_df
|
||||
return _base_df
|
||||
append_df = append_df.rename(columns=columns)
|
||||
return pd.concat([base_df, append_df], axis="columns")
|
||||
|
||||
Reference in New Issue
Block a user