mirror of
https://github.com/apache/superset.git
synced 2026-04-17 23:25:05 +00:00
feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
@@ -14,45 +14,80 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import DataFrame, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import resample
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
|
||||
|
||||
|
||||
def test_resample_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.resample(df=_timeseries_df, rule="1D", method="ffill")
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_resample():
|
||||
df = timeseries_df.copy()
|
||||
df.index.name = "time_column"
|
||||
df.reset_index(inplace=True)
|
||||
|
||||
post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",)
|
||||
assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
|
||||
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
|
||||
|
||||
post_df = resample(
|
||||
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
|
||||
"""
|
||||
label y
|
||||
2019-01-01 x 1.0
|
||||
2019-01-02 y 2.0
|
||||
2019-01-03 y 2.0
|
||||
2019-01-04 y 2.0
|
||||
2019-01-05 z 3.0
|
||||
2019-01-06 z 3.0
|
||||
2019-01-07 q 4.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", "y", "y", "z", "z", "q"],
|
||||
"y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
|
||||
|
||||
|
||||
def test_resample_with_groupby():
|
||||
"""
|
||||
The Dataframe contains a timestamp column, a string column and a numeric column.
|
||||
__timestamp city val
|
||||
0 2022-01-13 Chicago 6.0
|
||||
1 2022-01-13 LA 5.0
|
||||
2 2022-01-13 NY 4.0
|
||||
3 2022-01-11 Chicago 3.0
|
||||
4 2022-01-11 LA 2.0
|
||||
5 2022-01-11 NY 1.0
|
||||
"""
|
||||
df = DataFrame(
|
||||
{
|
||||
"__timestamp": to_datetime(
|
||||
def test_resample_zero_fill():
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq", fill_value=0)
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", 0, 0, "z", 0, "q"],
|
||||
"y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_resample_after_pivot():
|
||||
df = pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
[
|
||||
"2022-01-13",
|
||||
"2022-01-13",
|
||||
@@ -66,42 +101,53 @@ __timestamp city val
|
||||
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
|
||||
}
|
||||
)
|
||||
post_df = resample(
|
||||
pivot_df = pp.pivot(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city",),
|
||||
index=["__timestamp"],
|
||||
columns=["city"],
|
||||
aggregates={"val": {"operator": "sum"},},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
assert list(post_df.columns) == [
|
||||
"__timestamp",
|
||||
"city",
|
||||
"val",
|
||||
]
|
||||
assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
|
||||
["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,)
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-12 0.0 0.0 0.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
flat_df = pp.flatten(resample_df)
|
||||
"""
|
||||
__timestamp val, Chicago val, LA val, NY
|
||||
0 2022-01-11 3.0 2.0 1.0
|
||||
1 2022-01-12 0.0 0.0 0.0
|
||||
2 2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
["2022-01-11", "2022-01-12", "2022-01-13"]
|
||||
),
|
||||
"val, Chicago": [3.0, 0, 6.0],
|
||||
"val, LA": [2.0, 0, 5.0],
|
||||
"val, NY": [1.0, 0, 4.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
|
||||
|
||||
# should raise error when get a non-existent column
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", "unkonw_column",),
|
||||
)
|
||||
|
||||
# should raise error when get a None value in groupby list
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", None,),
|
||||
def test_resample_should_raise_ex():
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.resample(
|
||||
df=categories_df, rule="1D", method="asfreq",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user