fix: rolling and cum operators on multiple series (#16945)

* fix: rolling and cum operators on multiple series

* add UT

* updates

(cherry picked from commit fd8461406d)
This commit is contained in:
Yongjie Zhao
2021-10-07 09:42:08 +01:00
committed by Elizabeth Thompson
parent 4d33f7b7b6
commit 19d2fef490
3 changed files with 185 additions and 9 deletions

View File

@@ -35,6 +35,8 @@ from superset.utils.core import (
from .base_tests import SupersetTestCase
from .fixtures.dataframes import (
categories_df,
single_metric_df,
multiple_metrics_df,
lonlat_df,
names_df,
timeseries_df,
@@ -305,6 +307,23 @@ class TestPostProcessing(SupersetTestCase):
)
self.assertTrue(np.isnan(df["metric, 1, 1"][0]))
def test_pivot_without_flatten_columns_and_reset_index(self):
    """With flatten_columns/reset_index disabled, pivot keeps a
    MultiIndex on the columns and leaves the datetime index in place."""
    pivoted = proc.pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # Expected frame layout:
    # metric
    # country      UK  US
    # dttm
    # 2019-01-01    5   6
    # 2019-01-02    7   8
    expected_columns = [("sum_metric", "UK"), ("sum_metric", "US")]
    expected_index = to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    assert pivoted.columns.to_list() == expected_columns
    assert pivoted.index.to_list() == expected_index
def test_aggregate(self):
aggregates = {
"asc sum": {"column": "asc_idx", "operator": "sum"},
@@ -405,6 +424,60 @@ class TestPostProcessing(SupersetTestCase):
window=2,
)
def test_rolling_with_pivot_df_and_single_metric(self):
    """Rolling over a pivoted (MultiIndex-column) frame with one metric
    flattens the columns to the country labels and restores ``dttm`` as
    a regular column.

    Also verifies that a window larger than the available periods (with
    ``min_periods`` equal to the window) yields an empty frame.
    """
    pivot_df = proc.pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    rolling_df = proc.rolling(
        df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
    )
    # Expected result:
    #         dttm  UK  US
    # 0 2019-01-01   5   6
    # 1 2019-01-02  12  14
    assert rolling_df["UK"].to_list() == [5.0, 12.0]
    assert rolling_df["US"].to_list() == [6.0, 14.0]
    assert (
        rolling_df["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    )
    # min_periods=2 on a 2-row window leaves no complete window after
    # NaN rows are dropped, so the result must be empty.
    rolling_df = proc.rolling(
        df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
    )
    # Truthiness check instead of the fragile `is True` identity
    # comparison on the `.empty` property.
    assert rolling_df.empty
def test_rolling_with_pivot_df_and_multiple_metrics(self):
    """Rolling over a pivoted frame with two metrics flattens each
    (metric, country) column pair into a "metric, country" label and
    restores ``dttm`` as a regular column."""
    pivot_df = proc.pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    rolling_df = proc.rolling(
        df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
    )
    # Expected result:
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01               1.0               2.0             5.0             6.0
    # 1 2019-01-02               4.0               6.0            12.0            14.0
    expected = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, values in expected.items():
        assert rolling_df[column].to_list() == values
    assert (
        rolling_df["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    )
def test_select(self):
# reorder columns
post_df = proc.select(df=timeseries_df, columns=["y", "label"])
@@ -557,6 +630,51 @@ class TestPostProcessing(SupersetTestCase):
operator="abc",
)
def test_cum_with_pivot_df_and_single_metric(self):
    """Cumulative sum over a pivoted frame with one metric flattens the
    columns to the country labels and restores ``dttm`` as a column."""
    pivot_df = proc.pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
    # Expected result:
    #         dttm  UK  US
    # 0 2019-01-01   5   6
    # 1 2019-01-02  12  14
    expected = {"UK": [5.0, 12.0], "US": [6.0, 14.0]}
    for country, values in expected.items():
        assert cum_df[country].to_list() == values
    assert (
        cum_df["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    )
def test_cum_with_pivot_df_and_multiple_metrics(self):
    """Cumulative sum over a pivoted frame with two metrics flattens
    each (metric, country) pair into a "metric, country" column label
    and restores ``dttm`` as a regular column."""
    pivot_df = proc.pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    cum_df = proc.cum(df=pivot_df, operator="sum", is_pivot_df=True,)
    # Expected result:
    #         dttm  count_metric, UK  count_metric, US  sum_metric, UK  sum_metric, US
    # 0 2019-01-01                 1                 2               5               6
    # 1 2019-01-02                 4                 6              12              14
    expected = {
        "count_metric, UK": [1.0, 4.0],
        "count_metric, US": [2.0, 6.0],
        "sum_metric, UK": [5.0, 12.0],
        "sum_metric, US": [6.0, 14.0],
    }
    for column, values in expected.items():
        assert cum_df[column].to_list() == values
    assert (
        cum_df["dttm"].to_list()
        == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
    )
def test_geohash_decode(self):
# decode lon/lat from geohash
post_df = proc.geohash_decode(