feat: post-processing for pivot table v2 (#15879)

* feat: add pivot v2 post-processing

* Fix lint
This commit is contained in:
Beto Dealmeida
2021-07-29 11:05:56 -07:00
committed by GitHub
parent 6afa840659
commit f4739f427e
3 changed files with 504 additions and 1 deletions

View File

@@ -33,6 +33,13 @@ import pandas as pd
from superset.utils.core import DTTM_ALIAS, extract_dataframe_dtypes, get_metric_name
def sql_like_sum(series: pd.Series) -> pd.Series:
    """
    Sum a series the way SQL's ``SUM`` does.

    ``min_count=1`` makes pandas require at least one valid value, so a
    series with no non-null entries sums to a null value rather than ``0``.
    """
    total = series.sum(min_count=1)
    return total
def pivot_table(
result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None
) -> Dict[Any, Any]:
@@ -53,7 +60,7 @@ def pivot_table(
aggfunc = form_data.get("pandas_aggfunc") or "sum"
if pd.api.types.is_numeric_dtype(df[metric]):
if aggfunc == "sum":
aggfunc = lambda x: x.sum(min_count=1)
aggfunc = sql_like_sum
elif aggfunc not in {"min", "max"}:
aggfunc = "max"
aggfuncs[metric] = aggfunc
@@ -95,6 +102,120 @@ def pivot_table(
return result
def list_unique_values(series: pd.Series) -> str:
    """
    List unique values in a series.

    Every value is rendered with ``str`` and the distinct renderings are
    joined with ", " in order of first appearance, so the same input always
    produces the same string.

    :param series: the values to summarize
    :returns: the comma separated distinct values; empty for an empty series
    """
    # ``dict.fromkeys`` de-duplicates while preserving insertion order; a
    # ``set`` would leave the order of the joined labels up to string hashing.
    labels = dict.fromkeys(str(value) for value in series.unique())
    return ", ".join(labels)
# Maps the aggregation names accepted in ``aggregateFunction`` to the callable
# handed to ``DataFrame.pivot_table``. The "... as Fraction of ..." entries
# only aggregate with a plain sum/count here; ``pivot_table_v2`` performs the
# division afterwards.
pivot_v2_aggfunc_map = {
    "Count": pd.Series.count,
    "Count Unique Values": pd.Series.nunique,
    "List Unique Values": list_unique_values,
    "Sum": pd.Series.sum,
    "Average": pd.Series.mean,
    "Median": pd.Series.median,
    # Groups with fewer than two values report 0 for the sample statistics.
    "Sample Variance": lambda series: pd.Series.var(series) if len(series) > 1 else 0,
    # No trailing comma inside the parentheses: it would turn the value into a
    # one-element tuple instead of a callable.
    "Sample Standard Deviation": (
        lambda series: pd.Series.std(series) if len(series) > 1 else 0
    ),
    "Minimum": pd.Series.min,
    "Maximum": pd.Series.max,
    "First": lambda series: series[:1],
    "Last": lambda series: series[-1:],
    "Sum as Fraction of Total": pd.Series.sum,
    "Sum as Fraction of Rows": pd.Series.sum,
    "Sum as Fraction of Columns": pd.Series.sum,
    "Count as Fraction of Total": pd.Series.count,
    "Count as Fraction of Rows": pd.Series.count,
    "Count as Fraction of Columns": pd.Series.count,
}
def pivot_table_v2(  # pylint: disable=too-many-branches
    result: Dict[Any, Any], form_data: Optional[Dict[str, Any]] = None,
) -> Dict[Any, Any]:
    """
    Pivot table v2.

    Pivots the data of every query in ``result["queries"]`` and rewrites the
    query's ``data``, ``colnames``, ``coltypes`` and ``rowcount`` entries in
    place.

    :param result: payload holding a ``queries`` list; each query has ``data``
    :param form_data: chart settings; the keys read here are ``metrics``,
        ``aggregateFunction``, ``groupbyRows``, ``groupbyColumns``,
        ``transposePivot``, ``rowTotals``, ``colTotals``, ``combineMetric``
        and ``granularity_sqla``
    :returns: the same ``result`` object, mutated
    :raises KeyError: if ``metrics`` is missing from ``form_data`` or the
        aggregate function name is not in ``pivot_v2_aggfunc_map``
    """
    form_data = form_data or {}

    for query in result["queries"]:
        df = pd.DataFrame(query["data"])

        if form_data.get("granularity_sqla") == "all" and DTTM_ALIAS in df:
            del df[DTTM_ALIAS]

        # TODO (betodealmeida): implement metricsLayout
        metrics = [get_metric_name(m) for m in form_data["metrics"]]
        # ``or`` (rather than a ``get`` default) also covers an explicit null.
        aggregate_function = form_data.get("aggregateFunction") or "Sum"
        groupby = form_data.get("groupbyRows") or []
        columns = form_data.get("groupbyColumns") or []
        if form_data.get("transposePivot"):
            groupby, columns = columns, groupby

        df = df.pivot_table(
            index=groupby,
            columns=columns,
            values=metrics,
            aggfunc=pivot_v2_aggfunc_map[aggregate_function],
            margins=True,
        )

        # The pandas `pivot_table` method either brings both row/column
        # totals, or none at all. We pass `margins=True` to get both, and
        # remove any dimension that was not requested.
        if not form_data.get("rowTotals"):
            # NOTE(review): only the last column is removed; with several
            # metrics pandas presumably adds one totals column per metric --
            # confirm whether all of them should be dropped.
            df = df.drop(df.columns[-1], axis=1)
        if not form_data.get("colTotals"):
            df = df[:-1]

        # Compute fractions, if needed. If `colTotals` or `rowTotals` are
        # present we need to adjust for including them in the sum
        if aggregate_function.endswith(" as Fraction of Total"):
            # The grand total is a scalar, so there is no per-column dtype to
            # cast to; true division already yields floats.
            total = df.sum().sum()
            df = df / total
            if form_data.get("colTotals"):
                df *= 2
            if form_data.get("rowTotals"):
                df *= 2
        elif aggregate_function.endswith(" as Fraction of Columns"):
            total = df.sum(axis=0)
            df = df.astype(total.dtypes).div(total, axis=1)
            if form_data.get("colTotals"):
                df *= 2
        elif aggregate_function.endswith(" as Fraction of Rows"):
            total = df.sum(axis=1)
            df = df.astype(total.dtypes).div(total, axis=0)
            if form_data.get("rowTotals"):
                df *= 2

        # Re-order the columns adhering to the metric ordering.
        df = df[metrics]

        # Display metrics side by side with each column
        if form_data.get("combineMetric"):
            df = df.stack(0).unstack().reindex(level=-1, columns=metrics)

        # flatten column names; a label that is not a tuple is kept whole
        # (joining a bare string would split it into characters) and every
        # level value is stringified so non-text values can be joined
        df.columns = [
            " ".join(str(level) for level in column)
            if isinstance(column, tuple)
            else str(column)
            for column in df.columns
        ]

        # re-arrange data into a list of dicts
        multi_index = isinstance(df.index, pd.MultiIndex)
        records = []
        for label in df.index:
            row = {col: df[col][label] for col in df.columns}
            if multi_index:
                # one key per row grouping instead of a single `None` key
                # holding the whole tuple
                row.update(zip(df.index.names, label))
            else:
                row[df.index.name] = label
            records.append(row)

        query["data"] = records
        query["colnames"] = list(df.columns)
        query["coltypes"] = extract_dataframe_dtypes(df)
        query["rowcount"] = len(df.index)

    return result
# Registry of the post-processing functions defined in this module, keyed by
# the name each one is looked up under.
post_processors = {
    "pivot_table": pivot_table,
    "pivot_table_v2": pivot_table_v2,
}