feat(advanced analysis): support MultiIndex column in post processing stage (#19116)

This commit is contained in:
Yongjie Zhao
2022-03-23 13:46:28 +08:00
committed by GitHub
parent 6083545e86
commit 375c03e084
55 changed files with 1267 additions and 772 deletions

View File

@@ -172,18 +172,21 @@ POSTPROCESSING_OPERATIONS = {
{
"operation": "aggregate",
"options": {
"groupby": ["gender"],
"groupby": ["name"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
# TODO: rename "interpolation" to "method" once numpy is upgraded.
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html
"options": {"q": 25, "interpolation": "lower"},
},
"median": {"operator": "median", "column": "sum__num",},
},
},
},
{"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
{"operation": "sort", "options": {"columns": {"q1": False, "name": True},},},
]
}

View File

@@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import datetime
import re
import time
from typing import Any, Dict
@@ -30,7 +29,7 @@ from superset.common.query_object import QueryObject
from superset.connectors.connector_registry import ConnectorRegistry
from superset.connectors.sqla.models import SqlMetric
from superset.extensions import cache_manager
from superset.utils.core import AdhocMetricExpressionType, backend
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
from tests.integration_tests.base_tests import SupersetTestCase
from tests.integration_tests.fixtures.birth_names_dashboard import (
load_birth_names_dashboard_with_slices,
@@ -91,8 +90,9 @@ class TestQueryContext(SupersetTestCase):
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_cache(self):
table_name = "birth_names"
table = self.get_table(name=table_name)
payload = get_query_context(table_name, table.id)
payload = get_query_context(
query_name=table_name, add_postprocessing_operations=True,
)
payload["force"] = True
query_context = ChartDataQueryContextSchema().load(payload)
@@ -100,6 +100,10 @@ class TestQueryContext(SupersetTestCase):
query_cache_key = query_context.query_cache_key(query_object)
response = query_context.get_payload(cache_query_context=True)
# MUST BE a successful query
query_dump = response["queries"][0]
assert query_dump["status"] == QueryStatus.SUCCESS
cache_key = response["cache_key"]
assert cache_key is not None

View File

@@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import PostProcessingBoxplotWhiskerType
from superset.utils.pandas_postprocessing import boxplot
from tests.unit_tests.fixtures.dataframes import names_df
@@ -90,7 +90,7 @@ def test_boxplot_percentile():
def test_boxplot_percentile_incorrect_params():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -98,7 +98,7 @@ def test_boxplot_percentile_incorrect_params():
metrics=["cars"],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -107,7 +107,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[10],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],
@@ -116,7 +116,7 @@ def test_boxplot_percentile_incorrect_params():
percentiles=[90, 10],
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
boxplot(
df=names_df,
groupby=["region"],

View File

@@ -14,49 +14,220 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils.pandas_postprocessing import compare
from tests.unit_tests.fixtures.dataframes import timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
from superset.constants import PandasPostprocessingCompare as PPC
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import multiple_metrics_df, timeseries_df2
def test_compare():
def test_compare_should_not_side_effect():
    """compare() must not mutate the DataFrame passed to it."""
    snapshot = timeseries_df2.copy()
    pp.compare(
        df=snapshot,
        source_columns=["y"],
        compare_columns=["z"],
        compare_type=PPC.DIFF,
    )
    # The frame we passed in is still equal to the untouched fixture.
    assert snapshot.equals(timeseries_df2)
def test_compare_diff():
# `difference` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="difference",
compare_type=PPC.DIFF,
)
"""
label y z difference__y__z
2019-01-01 x 2.0 2.0 0.0
2019-01-02 y 2.0 4.0 2.0
2019-01-05 z 2.0 10.0 8.0
2019-01-07 q 2.0 8.0 6.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
},
)
)
assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0]
# drop original columns
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="difference",
compare_type=PPC.DIFF,
drop_original_columns=True,
)
assert post_df.columns.tolist() == ["label", "difference__y__z"]
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
},
)
)
def test_compare_percentage():
# `percentage` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="percentage",
compare_type=PPC.PCT,
)
"""
label y z percentage__y__z
2019-01-01 x 2.0 2.0 0.0
2019-01-02 y 2.0 4.0 1.0
2019-01-05 z 2.0 10.0 4.0
2019-01-07 q 2.0 8.0 3.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"percentage__y__z": [0.0, 1.0, 4.0, 3.0],
},
)
)
assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75]
def test_compare_ratio():
# `ratio` comparison
post_df = compare(
post_df = pp.compare(
df=timeseries_df2,
source_columns=["y"],
compare_columns=["z"],
compare_type="ratio",
compare_type=PPC.RAT,
)
"""
label y z ratio__y__z
2019-01-01 x 2.0 2.0 1.0
2019-01-02 y 2.0 4.0 2.0
2019-01-05 z 2.0 10.0 5.0
2019-01-07 q 2.0 8.0 4.0
"""
assert post_df.equals(
pd.DataFrame(
index=timeseries_df2.index,
data={
"label": ["x", "y", "z", "q"],
"y": [2.0, 2.0, 2.0, 2.0],
"z": [2.0, 4.0, 10.0, 8.0],
"ratio__y__z": [1.0, 2.0, 5.0, 4.0],
},
)
)
def test_compare_multi_index_column():
    """compare() supports frames whose columns are a 3-level MultiIndex."""
    idx = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    idx.name = "__timestamp"
    level_values = [["m1", "m2"], ["a", "b"], ["x", "y"]]
    multi_cols = pd.MultiIndex.from_product(
        level_values, names=[None, "level1", "level2"]
    )
    # Every cell is 1, so every m1-vs-m2 difference is 0.
    #                          m1          m2
    # level1        a     b         a     b
    # level2       x  y  x  y      x  y  x  y
    # __timestamp
    # 2021-01-01   1  1  1  1      1  1  1  1
    # ...
    source = pd.DataFrame(index=idx, columns=multi_cols, data=1)
    compared = pp.compare(
        source,
        source_columns=["m1"],
        compare_columns=["m2"],
        compare_type=PPC.DIFF,
        drop_original_columns=True,
    )
    flattened = pp.flatten(compared)
    # Flattening yields one column per (a/b, x/y) combination plus the
    # timestamp column pulled out of the index.
    expected = pd.DataFrame(
        data={
            "__timestamp": pd.to_datetime(
                ["2021-01-01", "2021-01-02", "2021-01-03"]
            ),
            "difference__m1__m2, a, x": [0, 0, 0],
            "difference__m1__m2, a, y": [0, 0, 0],
            "difference__m1__m2, b, x": [0, 0, 0],
            "difference__m1__m2, b, y": [0, 0, 0],
        }
    )
    assert flattened.equals(expected)
def test_compare_after_pivot():
    """compare() works on the MultiIndex columns produced by pivot()."""
    pivoted = pp.pivot(
        df=multiple_metrics_df,
        index=["dttm"],
        columns=["country"],
        aggregates={
            "sum_metric": {"operator": "sum"},
            "count_metric": {"operator": "sum"},
        },
        flatten_columns=False,
        reset_index=False,
    )
    # pivoted:
    #            count_metric     sum_metric
    # country    UK  US           UK  US
    # dttm
    # 2019-01-01  1   2            5   6
    # 2019-01-02  3   4            7   8
    diffed = pp.compare(
        pivoted,
        source_columns=["count_metric"],
        compare_columns=["sum_metric"],
        compare_type=PPC.DIFF,
        drop_original_columns=True,
    )
    # diffed: difference__count_metric__sum_metric is 4 everywhere.
    flattened = pp.flatten(diffed)
    uk_col = FLAT_COLUMN_SEPARATOR.join(
        ["difference__count_metric__sum_metric", "UK"]
    )
    us_col = FLAT_COLUMN_SEPARATOR.join(
        ["difference__count_metric__sum_metric", "US"]
    )
    expected = pd.DataFrame(
        data={
            "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
            uk_col: [4, 4],
            us_col: [4, 4],
        }
    )
    assert flattened.equals(expected)
assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]

View File

@@ -22,7 +22,7 @@ from numpy import nan
from numpy.testing import assert_array_equal
from pandas import DataFrame
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
from superset.utils.pandas_postprocessing import contribution
@@ -40,10 +40,10 @@ def test_contribution():
"c": [nan, nan, nan],
}
)
with pytest.raises(QueryObjectValidationError, match="not numeric"):
with pytest.raises(InvalidPostProcessingError, match="not numeric"):
contribution(df, columns=[DTTM_ALIAS])
with pytest.raises(QueryObjectValidationError, match="same length"):
with pytest.raises(InvalidPostProcessingError, match="same length"):
contribution(df, columns=["a"], rename_columns=["aa", "bb"])
# cell contribution across row

View File

@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import cum, pivot
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@@ -27,33 +28,41 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_cum_should_not_side_effect():
    """cum() must leave its source DataFrame unchanged."""
    snapshot = timeseries_df.copy()
    pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum")
    assert snapshot.equals(timeseries_df)
def test_cum():
# create new column (cumsum)
post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
assert post_df.columns.tolist() == ["label", "y", "y2"]
assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
# overwrite column (cumprod)
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
# overwrite column (cummin)
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
assert post_df.columns.tolist() == ["label", "y"]
assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
# invalid operator
with pytest.raises(QueryObjectValidationError):
cum(
with pytest.raises(InvalidPostProcessingError):
pp.cum(
df=timeseries_df, columns={"y": "y"}, operator="abc",
)
def test_cum_with_pivot_df_and_single_metric():
pivot_df = pivot(
def test_cum_after_pivot_with_single_metric():
pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@@ -61,19 +70,40 @@ def test_cum_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
# dttm UK US
# 0 2019-01-01 5 6
# 1 2019-01-02 12 14
assert cum_df["UK"].to_list() == [5.0, 12.0]
assert cum_df["US"].to_list() == [6.0, 14.0]
assert (
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
"""
sum_metric
country UK US
dttm
2019-01-01 5 6
2019-01-02 7 8
"""
cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric": "sum_metric"})
"""
sum_metric
country UK US
dttm
2019-01-01 5 6
2019-01-02 12 14
"""
cum_and_flat_df = pp.flatten(cum_df)
"""
dttm sum_metric, UK sum_metric, US
0 2019-01-01 5 6
1 2019-01-02 12 14
"""
assert cum_and_flat_df.equals(
pd.DataFrame(
{
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
}
)
)
def test_cum_with_pivot_df_and_multiple_metrics():
pivot_df = pivot(
def test_cum_after_pivot_with_multiple_metrics():
pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@@ -84,14 +114,39 @@ def test_cum_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
# 0 2019-01-01 1 2 5 6
# 1 2019-01-02 4 6 12 14
assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
assert (
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 3 4 7 8
"""
cum_df = pp.cum(
df=pivot_df,
operator="sum",
columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
)
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 4 6 12 14
"""
flat_df = pp.flatten(cum_df)
"""
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
0 2019-01-01 1 2 5 6
1 2019-01-02 4 6 12 14
"""
assert flat_df.equals(
pd.DataFrame(
{
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
}
)
)

View File

@@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import diff
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@@ -39,7 +39,7 @@ def test_diff():
assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
# invalid column reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
diff(
df=timeseries_df, columns={"abc": "abc"},
)

View File

@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
def test_flat_should_not_change():
    """A frame with flat columns and a default index passes through unchanged."""
    frame = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
    assert pp.flatten(frame).equals(frame)
def test_flat_should_not_reset_index():
    """With reset_index=False the named datetime index is kept in place."""
    stamps = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    stamps.name = "__timestamp"
    frame = pd.DataFrame(index=stamps, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
    assert pp.flatten(frame, reset_index=False).equals(frame)
def test_flat_should_flat_datetime_index():
    """flatten() turns a named datetime index into an ordinary column."""
    stamps = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    stamps.name = "__timestamp"
    frame = pd.DataFrame(index=stamps, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
    expected = pd.DataFrame(
        {"__timestamp": stamps, "foo": [1, 2, 3], "bar": [4, 5, 6]}
    )
    assert pp.flatten(frame).equals(expected)
def test_flat_should_flat_multiple_index():
    """MultiIndex column tuples are joined into single flat column names."""
    stamps = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
    stamps.name = "__timestamp"
    level_values = [["foo", "bar"], [1, "two"]]
    multi_cols = pd.MultiIndex.from_product(
        level_values, names=["level1", "level2"]
    )
    frame = pd.DataFrame(index=stamps, columns=multi_cols, data=1)
    # Non-string level values (the int 1) are stringified before joining.
    expected = pd.DataFrame(
        {
            "__timestamp": stamps,
            FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
            FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
            FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
            FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
        }
    )
    assert pp.flatten(frame).equals(expected)

View File

@@ -19,7 +19,7 @@ import numpy as np
import pytest
from pandas import DataFrame, Timestamp, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
from tests.unit_tests.pandas_postprocessing.utils import (
@@ -172,7 +172,7 @@ def test_pivot_exceptions():
pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
# invalid index reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["abc"],
@@ -181,7 +181,7 @@ def test_pivot_exceptions():
)
# invalid column reference
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["dept"],
@@ -190,7 +190,7 @@ def test_pivot_exceptions():
)
# invalid aggregate options
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
pivot(
df=categories_df,
index=["name"],

View File

@@ -19,7 +19,7 @@ from importlib.util import find_spec
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.core import DTTM_ALIAS
from superset.utils.pandas_postprocessing import prophet
from tests.unit_tests.fixtures.dataframes import prophet_df
@@ -75,40 +75,40 @@ def test_prophet_valid_zero_periods():
def test_prophet_import():
dynamic_module = find_spec("prophet")
if dynamic_module is None:
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
def test_prophet_missing_temporal_column():
df = prophet_df.drop(DTTM_ALIAS, axis=1)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
)
def test_prophet_incorrect_confidence_interval():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0,
)
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0,
)
def test_prophet_incorrect_periods():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8,
)
def test_prophet_incorrect_time_grain():
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
prophet(
df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8,
)

View File

@@ -14,45 +14,80 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import DataFrame, to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import resample
from tests.unit_tests.fixtures.dataframes import timeseries_df
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
def test_resample_should_not_side_effect():
    """resample() must not mutate the frame it receives."""
    working_copy = timeseries_df.copy()
    pp.resample(df=working_copy, rule="1D", method="ffill")
    assert working_copy.equals(timeseries_df)
def test_resample():
df = timeseries_df.copy()
df.index.name = "time_column"
df.reset_index(inplace=True)
post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",)
assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
post_df = resample(
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
"""
label y
2019-01-01 x 1.0
2019-01-02 y 2.0
2019-01-03 y 2.0
2019-01-04 y 2.0
2019-01-05 z 3.0
2019-01-06 z 3.0
2019-01-07 q 4.0
"""
assert post_df.equals(
pd.DataFrame(
index=pd.to_datetime(
[
"2019-01-01",
"2019-01-02",
"2019-01-03",
"2019-01-04",
"2019-01-05",
"2019-01-06",
"2019-01-07",
]
),
data={
"label": ["x", "y", "y", "y", "z", "z", "q"],
"y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
},
)
)
assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
def test_resample_with_groupby():
"""
The Dataframe contains a timestamp column, a string column and a numeric column.
__timestamp city val
0 2022-01-13 Chicago 6.0
1 2022-01-13 LA 5.0
2 2022-01-13 NY 4.0
3 2022-01-11 Chicago 3.0
4 2022-01-11 LA 2.0
5 2022-01-11 NY 1.0
"""
df = DataFrame(
{
"__timestamp": to_datetime(
def test_resample_zero_fill():
    """asfreq resampling fills the newly created gap rows with fill_value."""
    resampled = pp.resample(
        df=timeseries_df, rule="1D", method="asfreq", fill_value=0
    )
    daily_index = pd.to_datetime(
        [
            "2019-01-01",
            "2019-01-02",
            "2019-01-03",
            "2019-01-04",
            "2019-01-05",
            "2019-01-06",
            "2019-01-07",
        ]
    )
    # Rows absent from the source appear with 0 in every column.
    expected = pd.DataFrame(
        index=daily_index,
        data={
            "label": ["x", "y", 0, 0, "z", 0, "q"],
            "y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
        },
    )
    assert resampled.equals(expected)
def test_resample_after_pivot():
df = pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
[
"2022-01-13",
"2022-01-13",
@@ -66,42 +101,53 @@ __timestamp city val
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
}
)
post_df = resample(
pivot_df = pp.pivot(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city",),
index=["__timestamp"],
columns=["city"],
aggregates={"val": {"operator": "sum"},},
flatten_columns=False,
reset_index=False,
)
assert list(post_df.columns) == [
"__timestamp",
"city",
"val",
]
assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
"""
val
city Chicago LA NY
__timestamp
2022-01-11 3.0 2.0 1.0
2022-01-13 6.0 5.0 4.0
"""
resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,)
"""
val
city Chicago LA NY
__timestamp
2022-01-11 3.0 2.0 1.0
2022-01-12 0.0 0.0 0.0
2022-01-13 6.0 5.0 4.0
"""
flat_df = pp.flatten(resample_df)
"""
__timestamp val, Chicago val, LA val, NY
0 2022-01-11 3.0 2.0 1.0
1 2022-01-12 0.0 0.0 0.0
2 2022-01-13 6.0 5.0 4.0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"__timestamp": pd.to_datetime(
["2022-01-11", "2022-01-12", "2022-01-13"]
),
"val, Chicago": [3.0, 0, 6.0],
"val, LA": [2.0, 0, 5.0],
"val, NY": [1.0, 0, 4.0],
}
)
)
assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
# should raise error when get a non-existent column
with pytest.raises(QueryObjectValidationError):
resample(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city", "unkonw_column",),
)
# should raise error when get a None value in groupby list
with pytest.raises(QueryObjectValidationError):
resample(
df=df,
rule="1D",
method="asfreq",
fill_value=0,
time_column="__timestamp",
groupby_columns=("city", None,),
def test_resample_should_raise_ex():
    """resample() rejects the categories fixture with InvalidPostProcessingError."""
    with pytest.raises(InvalidPostProcessingError):
        pp.resample(df=categories_df, rule="1D", method="asfreq")

View File

@@ -14,11 +14,12 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pytest
from pandas import to_datetime
from superset.exceptions import QueryObjectValidationError
from superset.utils.pandas_postprocessing import pivot, rolling
from superset.exceptions import InvalidPostProcessingError
from superset.utils import pandas_postprocessing as pp
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
from tests.unit_tests.fixtures.dataframes import (
multiple_metrics_df,
single_metric_df,
@@ -27,9 +28,21 @@ from tests.unit_tests.fixtures.dataframes import (
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
def test_rolling_should_not_side_effect():
    """rolling() must leave its input DataFrame untouched."""
    snapshot = timeseries_df.copy()
    pp.rolling(
        df=timeseries_df,
        columns={"y": "y"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    assert snapshot.equals(timeseries_df)
def test_rolling():
# sum rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
@@ -41,7 +54,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
# mean rolling type with alias
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
rolling_type="mean",
columns={"y": "y_mean"},
@@ -52,7 +65,7 @@ def test_rolling():
assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
# count rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
rolling_type="count",
columns={"y": "y"},
@@ -63,7 +76,7 @@ def test_rolling():
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
# quantile rolling type
post_df = rolling(
post_df = pp.rolling(
df=timeseries_df,
columns={"y": "q1"},
rolling_type="quantile",
@@ -75,14 +88,14 @@ def test_rolling():
assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
# incorrect rolling type
with pytest.raises(QueryObjectValidationError):
rolling(
with pytest.raises(InvalidPostProcessingError):
pp.rolling(
df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2,
)
# incorrect rolling type options
with pytest.raises(QueryObjectValidationError):
rolling(
with pytest.raises(InvalidPostProcessingError):
pp.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="quantile",
@@ -91,8 +104,8 @@ def test_rolling():
)
def test_rolling_with_pivot_df_and_single_metric():
pivot_df = pivot(
def test_rolling_should_empty_df():
pivot_df = pp.pivot(
df=single_metric_df,
index=["dttm"],
columns=["country"],
@@ -100,27 +113,65 @@ def test_rolling_with_pivot_df_and_single_metric():
flatten_columns=False,
reset_index=False,
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
)
# dttm UK US
# 0 2019-01-01 5 6
# 1 2019-01-02 12 14
assert rolling_df["UK"].to_list() == [5.0, 12.0]
assert rolling_df["US"].to_list() == [6.0, 14.0]
assert (
rolling_df["dttm"].to_list()
== to_datetime(["2019-01-01", "2019-01-02"]).to_list()
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
rolling_df = pp.rolling(
df=pivot_df,
rolling_type="sum",
window=2,
min_periods=2,
columns={"sum_metric": "sum_metric"},
)
assert rolling_df.empty is True
def test_rolling_with_pivot_df_and_multiple_metrics():
pivot_df = pivot(
def test_rolling_after_pivot_with_single_metric():
    """rolling() handles the MultiIndex columns produced by pivot()."""
    pivoted = pp.pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    # pivoted:
    #            sum_metric
    # country    UK  US
    # dttm
    # 2019-01-01  5   6
    # 2019-01-02  7   8
    rolled = pp.rolling(
        df=pivoted,
        columns={"sum_metric": "sum_metric"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    # rolled (window-2 running sums):
    #            sum_metric
    # country    UK    US
    # dttm
    # 2019-01-01  5.0   6.0
    # 2019-01-02 12.0  14.0
    flattened = pp.flatten(rolled)
    expected = pd.DataFrame(
        data={
            "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
            FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
            FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
        }
    )
    assert flattened.equals(expected)
def test_rolling_after_pivot_with_multiple_metrics():
pivot_df = pp.pivot(
df=multiple_metrics_df,
index=["dttm"],
columns=["country"],
@@ -131,17 +182,41 @@ def test_rolling_with_pivot_df_and_multiple_metrics():
flatten_columns=False,
reset_index=False,
)
rolling_df = rolling(
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1 2 5 6
2019-01-02 3 4 7 8
"""
rolling_df = pp.rolling(
df=pivot_df,
columns={"count_metric": "count_metric", "sum_metric": "sum_metric",},
rolling_type="sum",
window=2,
min_periods=0,
)
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
# 0 2019-01-01 1.0 2.0 5.0 6.0
# 1 2019-01-02 4.0 6.0 12.0 14.0
assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
assert (
rolling_df["dttm"].to_list()
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
"""
count_metric sum_metric
country UK US UK US
dttm
2019-01-01 1.0 2.0 5.0 6.0
2019-01-02 4.0 6.0 12.0 14.0
"""
flat_df = pp.flatten(rolling_df)
"""
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
0 2019-01-01 1.0 2.0 5.0 6.0
1 2019-01-02 4.0 6.0 12.0 14.0
"""
assert flat_df.equals(
pd.DataFrame(
data={
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
}
)
)

View File

@@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing.select import select
from tests.unit_tests.fixtures.dataframes import timeseries_df
@@ -47,9 +47,9 @@ def test_select():
assert post_df.columns.tolist() == ["y1"]
# invalid columns
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
# select renamed column by new name
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})

View File

@@ -16,7 +16,7 @@
# under the License.
import pytest
from superset.exceptions import QueryObjectValidationError
from superset.exceptions import InvalidPostProcessingError
from superset.utils.pandas_postprocessing import sort
from tests.unit_tests.fixtures.dataframes import categories_df
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
@@ -26,5 +26,5 @@ def test_sort():
df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
assert series_to_list(df["asc_idx"])[1] == 96
with pytest.raises(QueryObjectValidationError):
with pytest.raises(InvalidPostProcessingError):
sort(df=df, columns={"abc": True})