mirror of
https://github.com/apache/superset.git
synced 2026-04-25 19:14:27 +00:00
feat(advanced analysis): support MultiIndex column in post processing stage (#19116)
This commit is contained in:
@@ -172,18 +172,21 @@ POSTPROCESSING_OPERATIONS = {
|
||||
{
|
||||
"operation": "aggregate",
|
||||
"options": {
|
||||
"groupby": ["gender"],
|
||||
"groupby": ["name"],
|
||||
"aggregates": {
|
||||
"q1": {
|
||||
"operator": "percentile",
|
||||
"column": "sum__num",
|
||||
"options": {"q": 25},
|
||||
# todo: rename "interpolation" to "method" when we updated
|
||||
# numpy.
|
||||
# https://numpy.org/doc/stable/reference/generated/numpy.percentile.html
|
||||
"options": {"q": 25, "interpolation": "lower"},
|
||||
},
|
||||
"median": {"operator": "median", "column": "sum__num",},
|
||||
},
|
||||
},
|
||||
},
|
||||
{"operation": "sort", "options": {"columns": {"q1": False, "gender": True},},},
|
||||
{"operation": "sort", "options": {"columns": {"q1": False, "name": True},},},
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import datetime
|
||||
import re
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
@@ -30,7 +29,7 @@ from superset.common.query_object import QueryObject
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.connectors.sqla.models import SqlMetric
|
||||
from superset.extensions import cache_manager
|
||||
from superset.utils.core import AdhocMetricExpressionType, backend
|
||||
from superset.utils.core import AdhocMetricExpressionType, backend, QueryStatus
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
from tests.integration_tests.fixtures.birth_names_dashboard import (
|
||||
load_birth_names_dashboard_with_slices,
|
||||
@@ -91,8 +90,9 @@ class TestQueryContext(SupersetTestCase):
|
||||
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
|
||||
def test_cache(self):
|
||||
table_name = "birth_names"
|
||||
table = self.get_table(name=table_name)
|
||||
payload = get_query_context(table_name, table.id)
|
||||
payload = get_query_context(
|
||||
query_name=table_name, add_postprocessing_operations=True,
|
||||
)
|
||||
payload["force"] = True
|
||||
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
@@ -100,6 +100,10 @@ class TestQueryContext(SupersetTestCase):
|
||||
query_cache_key = query_context.query_cache_key(query_object)
|
||||
|
||||
response = query_context.get_payload(cache_query_context=True)
|
||||
# MUST BE a successful query
|
||||
query_dump = response["queries"][0]
|
||||
assert query_dump["status"] == QueryStatus.SUCCESS
|
||||
|
||||
cache_key = response["cache_key"]
|
||||
assert cache_key is not None
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import PostProcessingBoxplotWhiskerType
|
||||
from superset.utils.pandas_postprocessing import boxplot
|
||||
from tests.unit_tests.fixtures.dataframes import names_df
|
||||
@@ -90,7 +90,7 @@ def test_boxplot_percentile():
|
||||
|
||||
|
||||
def test_boxplot_percentile_incorrect_params():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
@@ -98,7 +98,7 @@ def test_boxplot_percentile_incorrect_params():
|
||||
metrics=["cars"],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
@@ -107,7 +107,7 @@ def test_boxplot_percentile_incorrect_params():
|
||||
percentiles=[10],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
@@ -116,7 +116,7 @@ def test_boxplot_percentile_incorrect_params():
|
||||
percentiles=[90, 10],
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
boxplot(
|
||||
df=names_df,
|
||||
groupby=["region"],
|
||||
|
||||
@@ -14,49 +14,220 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils.pandas_postprocessing import compare
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
from superset.constants import PandasPostprocessingCompare as PPC
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import multiple_metrics_df, timeseries_df2
|
||||
|
||||
|
||||
def test_compare():
|
||||
def test_compare_should_not_side_effect():
|
||||
_timeseries_df2 = timeseries_df2.copy()
|
||||
pp.compare(
|
||||
df=_timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type=PPC.DIFF,
|
||||
)
|
||||
assert _timeseries_df2.equals(timeseries_df2)
|
||||
|
||||
|
||||
def test_compare_diff():
|
||||
# `difference` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="difference",
|
||||
compare_type=PPC.DIFF,
|
||||
)
|
||||
"""
|
||||
label y z difference__y__z
|
||||
2019-01-01 x 2.0 2.0 0.0
|
||||
2019-01-02 y 2.0 4.0 2.0
|
||||
2019-01-05 z 2.0 10.0 8.0
|
||||
2019-01-07 q 2.0 8.0 6.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "difference__y__z"]
|
||||
assert series_to_list(post_df["difference__y__z"]) == [0.0, -2.0, -8.0, -6.0]
|
||||
|
||||
# drop original columns
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="difference",
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "difference__y__z"]
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"difference__y__z": [0.0, 2.0, 8.0, 6.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_percentage():
|
||||
# `percentage` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="percentage",
|
||||
compare_type=PPC.PCT,
|
||||
)
|
||||
"""
|
||||
label y z percentage__y__z
|
||||
2019-01-01 x 2.0 2.0 0.0
|
||||
2019-01-02 y 2.0 4.0 1.0
|
||||
2019-01-05 z 2.0 10.0 4.0
|
||||
2019-01-07 q 2.0 8.0 3.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"percentage__y__z": [0.0, 1.0, 4.0, 3.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "percentage__y__z"]
|
||||
assert series_to_list(post_df["percentage__y__z"]) == [0.0, -0.5, -0.8, -0.75]
|
||||
|
||||
|
||||
def test_compare_ratio():
|
||||
# `ratio` comparison
|
||||
post_df = compare(
|
||||
post_df = pp.compare(
|
||||
df=timeseries_df2,
|
||||
source_columns=["y"],
|
||||
compare_columns=["z"],
|
||||
compare_type="ratio",
|
||||
compare_type=PPC.RAT,
|
||||
)
|
||||
"""
|
||||
label y z ratio__y__z
|
||||
2019-01-01 x 2.0 2.0 1.0
|
||||
2019-01-02 y 2.0 4.0 2.0
|
||||
2019-01-05 z 2.0 10.0 5.0
|
||||
2019-01-07 q 2.0 8.0 4.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=timeseries_df2.index,
|
||||
data={
|
||||
"label": ["x", "y", "z", "q"],
|
||||
"y": [2.0, 2.0, 2.0, 2.0],
|
||||
"z": [2.0, 4.0, 10.0, 8.0],
|
||||
"ratio__y__z": [1.0, 2.0, 5.0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_multi_index_column():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
iterables = [["m1", "m2"], ["a", "b"], ["x", "y"]]
|
||||
columns = pd.MultiIndex.from_product(iterables, names=[None, "level1", "level2"])
|
||||
df = pd.DataFrame(index=index, columns=columns, data=1)
|
||||
"""
|
||||
m1 m2
|
||||
level1 a b a b
|
||||
level2 x y x y x y x y
|
||||
__timestamp
|
||||
2021-01-01 1 1 1 1 1 1 1 1
|
||||
2021-01-02 1 1 1 1 1 1 1 1
|
||||
2021-01-03 1 1 1 1 1 1 1 1
|
||||
"""
|
||||
post_df = pp.compare(
|
||||
df,
|
||||
source_columns=["m1"],
|
||||
compare_columns=["m2"],
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
flat_df = pp.flatten(post_df)
|
||||
"""
|
||||
__timestamp difference__m1__m2, a, x difference__m1__m2, a, y difference__m1__m2, b, x difference__m1__m2, b, y
|
||||
0 2021-01-01 0 0 0 0
|
||||
1 2021-01-02 0 0 0 0
|
||||
2 2021-01-03 0 0 0 0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
["2021-01-01", "2021-01-02", "2021-01-03"]
|
||||
),
|
||||
"difference__m1__m2, a, x": [0, 0, 0],
|
||||
"difference__m1__m2, a, y": [0, 0, 0],
|
||||
"difference__m1__m2, b, x": [0, 0, 0],
|
||||
"difference__m1__m2, b, y": [0, 0, 0],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_compare_after_pivot():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
aggregates={
|
||||
"sum_metric": {"operator": "sum"},
|
||||
"count_metric": {"operator": "sum"},
|
||||
},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
compared_df = pp.compare(
|
||||
pivot_df,
|
||||
source_columns=["count_metric"],
|
||||
compare_columns=["sum_metric"],
|
||||
compare_type=PPC.DIFF,
|
||||
drop_original_columns=True,
|
||||
)
|
||||
"""
|
||||
difference__count_metric__sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 4 4
|
||||
2019-01-02 4 4
|
||||
"""
|
||||
flat_df = pp.flatten(compared_df)
|
||||
"""
|
||||
dttm difference__count_metric__sum_metric, UK difference__count_metric__sum_metric, US
|
||||
0 2019-01-01 4 4
|
||||
1 2019-01-02 4 4
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(
|
||||
["difference__count_metric__sum_metric", "UK"]
|
||||
): [4, 4],
|
||||
FLAT_COLUMN_SEPARATOR.join(
|
||||
["difference__count_metric__sum_metric", "US"]
|
||||
): [4, 4],
|
||||
}
|
||||
)
|
||||
)
|
||||
assert post_df.columns.tolist() == ["label", "y", "z", "ratio__y__z"]
|
||||
assert series_to_list(post_df["ratio__y__z"]) == [1.0, 0.5, 0.2, 0.25]
|
||||
|
||||
@@ -22,7 +22,7 @@ from numpy import nan
|
||||
from numpy.testing import assert_array_equal
|
||||
from pandas import DataFrame
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS, PostProcessingContributionOrientation
|
||||
from superset.utils.pandas_postprocessing import contribution
|
||||
|
||||
@@ -40,10 +40,10 @@ def test_contribution():
|
||||
"c": [nan, nan, nan],
|
||||
}
|
||||
)
|
||||
with pytest.raises(QueryObjectValidationError, match="not numeric"):
|
||||
with pytest.raises(InvalidPostProcessingError, match="not numeric"):
|
||||
contribution(df, columns=[DTTM_ALIAS])
|
||||
|
||||
with pytest.raises(QueryObjectValidationError, match="same length"):
|
||||
with pytest.raises(InvalidPostProcessingError, match="same length"):
|
||||
contribution(df, columns=["a"], rename_columns=["aa", "bb"])
|
||||
|
||||
# cell contribution across row
|
||||
|
||||
@@ -14,11 +14,12 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import cum, pivot
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
@@ -27,33 +28,41 @@ from tests.unit_tests.fixtures.dataframes import (
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_cum_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.cum(
|
||||
df=timeseries_df, columns={"y": "y2"}, operator="sum",
|
||||
)
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_cum():
|
||||
# create new column (cumsum)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
|
||||
assert post_df.columns.tolist() == ["label", "y", "y2"]
|
||||
assert series_to_list(post_df["label"]) == ["x", "y", "z", "q"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
|
||||
assert series_to_list(post_df["y2"]) == [1.0, 3.0, 6.0, 10.0]
|
||||
|
||||
# overwrite column (cumprod)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
|
||||
assert post_df.columns.tolist() == ["label", "y"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 6.0, 24.0]
|
||||
|
||||
# overwrite column (cummin)
|
||||
post_df = cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
|
||||
post_df = pp.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
|
||||
assert post_df.columns.tolist() == ["label", "y"]
|
||||
assert series_to_list(post_df["y"]) == [1.0, 1.0, 1.0, 1.0]
|
||||
|
||||
# invalid operator
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
cum(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.cum(
|
||||
df=timeseries_df, columns={"y": "y"}, operator="abc",
|
||||
)
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_single_metric():
|
||||
pivot_df = pivot(
|
||||
def test_cum_after_pivot_with_single_metric():
|
||||
pivot_df = pp.pivot(
|
||||
df=single_metric_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
@@ -61,19 +70,40 @@ def test_cum_with_pivot_df_and_single_metric():
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
||||
# dttm UK US
|
||||
# 0 2019-01-01 5 6
|
||||
# 1 2019-01-02 12 14
|
||||
assert cum_df["UK"].to_list() == [5.0, 12.0]
|
||||
assert cum_df["US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5 6
|
||||
2019-01-02 7 8
|
||||
"""
|
||||
cum_df = pp.cum(df=pivot_df, operator="sum", columns={"sum_metric": "sum_metric"})
|
||||
"""
|
||||
sum_metric
|
||||
country UK US
|
||||
dttm
|
||||
2019-01-01 5 6
|
||||
2019-01-02 12 14
|
||||
"""
|
||||
cum_and_flat_df = pp.flatten(cum_df)
|
||||
"""
|
||||
dttm sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 5 6
|
||||
1 2019-01-02 12 14
|
||||
"""
|
||||
assert cum_and_flat_df.equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_cum_with_pivot_df_and_multiple_metrics():
|
||||
pivot_df = pivot(
|
||||
def test_cum_after_pivot_with_multiple_metrics():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
@@ -84,14 +114,39 @@ def test_cum_with_pivot_df_and_multiple_metrics():
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
cum_df = cum(df=pivot_df, operator="sum", is_pivot_df=True,)
|
||||
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
# 0 2019-01-01 1 2 5 6
|
||||
# 1 2019-01-02 4 6 12 14
|
||||
assert cum_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
||||
assert cum_df["count_metric, US"].to_list() == [2.0, 6.0]
|
||||
assert cum_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
||||
assert cum_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
cum_df["dttm"].to_list() == to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
cum_df = pp.cum(
|
||||
df=pivot_df,
|
||||
operator="sum",
|
||||
columns={"sum_metric": "sum_metric", "count_metric": "count_metric"},
|
||||
)
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 4 6 12 14
|
||||
"""
|
||||
flat_df = pp.flatten(cum_df)
|
||||
"""
|
||||
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 1 2 5 6
|
||||
1 2019-01-02 4 6 12 14
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1, 4],
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2, 6],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5, 12],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6, 14],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import diff
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df, timeseries_df2
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
@@ -39,7 +39,7 @@ def test_diff():
|
||||
assert series_to_list(post_df["y1"]) == [-1.0, -1.0, -1.0, None]
|
||||
|
||||
# invalid column reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
diff(
|
||||
df=timeseries_df, columns={"abc": "abc"},
|
||||
)
|
||||
|
||||
64
tests/unit_tests/pandas_postprocessing/test_flatten.py
Normal file
64
tests/unit_tests/pandas_postprocessing/test_flatten.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
|
||||
|
||||
def test_flat_should_not_change():
|
||||
df = pd.DataFrame(data={"foo": [1, 2, 3], "bar": [4, 5, 6],})
|
||||
|
||||
assert pp.flatten(df).equals(df)
|
||||
|
||||
|
||||
def test_flat_should_not_reset_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
|
||||
|
||||
assert pp.flatten(df, reset_index=False).equals(df)
|
||||
|
||||
|
||||
def test_flat_should_flat_datetime_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
df = pd.DataFrame(index=index, data={"foo": [1, 2, 3], "bar": [4, 5, 6]})
|
||||
|
||||
assert pp.flatten(df).equals(
|
||||
pd.DataFrame({"__timestamp": index, "foo": [1, 2, 3], "bar": [4, 5, 6],})
|
||||
)
|
||||
|
||||
|
||||
def test_flat_should_flat_multiple_index():
|
||||
index = pd.to_datetime(["2021-01-01", "2021-01-02", "2021-01-03"])
|
||||
index.name = "__timestamp"
|
||||
iterables = [["foo", "bar"], [1, "two"]]
|
||||
columns = pd.MultiIndex.from_product(iterables, names=["level1", "level2"])
|
||||
df = pd.DataFrame(index=index, columns=columns, data=1)
|
||||
|
||||
assert pp.flatten(df).equals(
|
||||
pd.DataFrame(
|
||||
{
|
||||
"__timestamp": index,
|
||||
FLAT_COLUMN_SEPARATOR.join(["foo", "1"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["foo", "two"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["bar", "1"]): [1, 1, 1],
|
||||
FLAT_COLUMN_SEPARATOR.join(["bar", "two"]): [1, 1, 1],
|
||||
}
|
||||
)
|
||||
)
|
||||
@@ -19,7 +19,7 @@ import numpy as np
|
||||
import pytest
|
||||
from pandas import DataFrame, Timestamp, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import _flatten_column_after_pivot, pivot
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, single_metric_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import (
|
||||
@@ -172,7 +172,7 @@ def test_pivot_exceptions():
|
||||
pivot(df=categories_df, columns=["dept"], aggregates=AGGREGATES_SINGLE)
|
||||
|
||||
# invalid index reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["abc"],
|
||||
@@ -181,7 +181,7 @@ def test_pivot_exceptions():
|
||||
)
|
||||
|
||||
# invalid column reference
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["dept"],
|
||||
@@ -190,7 +190,7 @@ def test_pivot_exceptions():
|
||||
)
|
||||
|
||||
# invalid aggregate options
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pivot(
|
||||
df=categories_df,
|
||||
index=["name"],
|
||||
|
||||
@@ -19,7 +19,7 @@ from importlib.util import find_spec
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.core import DTTM_ALIAS
|
||||
from superset.utils.pandas_postprocessing import prophet
|
||||
from tests.unit_tests.fixtures.dataframes import prophet_df
|
||||
@@ -75,40 +75,40 @@ def test_prophet_valid_zero_periods():
|
||||
def test_prophet_import():
|
||||
dynamic_module = find_spec("prophet")
|
||||
if dynamic_module is None:
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
|
||||
|
||||
|
||||
def test_prophet_missing_temporal_column():
|
||||
df = prophet_df.drop(DTTM_ALIAS, axis=1)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=df, time_grain="P1M", periods=3, confidence_interval=0.9,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_confidence_interval():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.0,
|
||||
)
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=3, confidence_interval=1.0,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_periods():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="P1M", periods=-1, confidence_interval=0.8,
|
||||
)
|
||||
|
||||
|
||||
def test_prophet_incorrect_time_grain():
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
prophet(
|
||||
df=prophet_df, time_grain="yearly", periods=10, confidence_interval=0.8,
|
||||
)
|
||||
|
||||
@@ -14,45 +14,80 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import DataFrame, to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import resample
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df, timeseries_df
|
||||
|
||||
|
||||
def test_resample_should_not_side_effect():
|
||||
_timeseries_df = timeseries_df.copy()
|
||||
pp.resample(df=_timeseries_df, rule="1D", method="ffill")
|
||||
assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_resample():
|
||||
df = timeseries_df.copy()
|
||||
df.index.name = "time_column"
|
||||
df.reset_index(inplace=True)
|
||||
|
||||
post_df = resample(df=df, rule="1D", method="ffill", time_column="time_column",)
|
||||
assert post_df["label"].tolist() == ["x", "y", "y", "y", "z", "z", "q"]
|
||||
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0]
|
||||
|
||||
post_df = resample(
|
||||
df=df, rule="1D", method="asfreq", time_column="time_column", fill_value=0,
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="ffill")
|
||||
"""
|
||||
label y
|
||||
2019-01-01 x 1.0
|
||||
2019-01-02 y 2.0
|
||||
2019-01-03 y 2.0
|
||||
2019-01-04 y 2.0
|
||||
2019-01-05 z 3.0
|
||||
2019-01-06 z 3.0
|
||||
2019-01-07 q 4.0
|
||||
"""
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", "y", "y", "z", "z", "q"],
|
||||
"y": [1.0, 2.0, 2.0, 2.0, 3.0, 3.0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
assert post_df["label"].tolist() == ["x", "y", 0, 0, "z", 0, "q"]
|
||||
assert post_df["y"].tolist() == [1.0, 2.0, 0, 0, 3.0, 0, 4.0]
|
||||
|
||||
|
||||
def test_resample_with_groupby():
|
||||
"""
|
||||
The Dataframe contains a timestamp column, a string column and a numeric column.
|
||||
__timestamp city val
|
||||
0 2022-01-13 Chicago 6.0
|
||||
1 2022-01-13 LA 5.0
|
||||
2 2022-01-13 NY 4.0
|
||||
3 2022-01-11 Chicago 3.0
|
||||
4 2022-01-11 LA 2.0
|
||||
5 2022-01-11 NY 1.0
|
||||
"""
|
||||
df = DataFrame(
|
||||
{
|
||||
"__timestamp": to_datetime(
|
||||
def test_resample_zero_fill():
|
||||
post_df = pp.resample(df=timeseries_df, rule="1D", method="asfreq", fill_value=0)
|
||||
assert post_df.equals(
|
||||
pd.DataFrame(
|
||||
index=pd.to_datetime(
|
||||
[
|
||||
"2019-01-01",
|
||||
"2019-01-02",
|
||||
"2019-01-03",
|
||||
"2019-01-04",
|
||||
"2019-01-05",
|
||||
"2019-01-06",
|
||||
"2019-01-07",
|
||||
]
|
||||
),
|
||||
data={
|
||||
"label": ["x", "y", 0, 0, "z", 0, "q"],
|
||||
"y": [1.0, 2.0, 0, 0, 3.0, 0, 4.0],
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_resample_after_pivot():
|
||||
df = pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
[
|
||||
"2022-01-13",
|
||||
"2022-01-13",
|
||||
@@ -66,42 +101,53 @@ __timestamp city val
|
||||
"val": [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],
|
||||
}
|
||||
)
|
||||
post_df = resample(
|
||||
pivot_df = pp.pivot(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city",),
|
||||
index=["__timestamp"],
|
||||
columns=["city"],
|
||||
aggregates={"val": {"operator": "sum"},},
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
assert list(post_df.columns) == [
|
||||
"__timestamp",
|
||||
"city",
|
||||
"val",
|
||||
]
|
||||
assert [str(dt.date()) for dt in post_df["__timestamp"]] == (
|
||||
["2022-01-11"] * 3 + ["2022-01-12"] * 3 + ["2022-01-13"] * 3
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
resample_df = pp.resample(df=pivot_df, rule="1D", method="asfreq", fill_value=0,)
|
||||
"""
|
||||
val
|
||||
city Chicago LA NY
|
||||
__timestamp
|
||||
2022-01-11 3.0 2.0 1.0
|
||||
2022-01-12 0.0 0.0 0.0
|
||||
2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
flat_df = pp.flatten(resample_df)
|
||||
"""
|
||||
__timestamp val, Chicago val, LA val, NY
|
||||
0 2022-01-11 3.0 2.0 1.0
|
||||
1 2022-01-12 0.0 0.0 0.0
|
||||
2 2022-01-13 6.0 5.0 4.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"__timestamp": pd.to_datetime(
|
||||
["2022-01-11", "2022-01-12", "2022-01-13"]
|
||||
),
|
||||
"val, Chicago": [3.0, 0, 6.0],
|
||||
"val, LA": [2.0, 0, 5.0],
|
||||
"val, NY": [1.0, 0, 4.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
assert list(post_df["val"]) == [3.0, 2.0, 1.0, 0, 0, 0, 6.0, 5.0, 4.0]
|
||||
|
||||
# should raise error when get a non-existent column
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", "unkonw_column",),
|
||||
)
|
||||
|
||||
# should raise error when get a None value in groupby list
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
resample(
|
||||
df=df,
|
||||
rule="1D",
|
||||
method="asfreq",
|
||||
fill_value=0,
|
||||
time_column="__timestamp",
|
||||
groupby_columns=("city", None,),
|
||||
def test_resample_should_raise_ex():
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.resample(
|
||||
df=categories_df, rule="1D", method="asfreq",
|
||||
)
|
||||
|
||||
@@ -14,11 +14,12 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import pandas as pd
|
||||
import pytest
|
||||
from pandas import to_datetime
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.utils.pandas_postprocessing import pivot, rolling
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils import pandas_postprocessing as pp
|
||||
from superset.utils.pandas_postprocessing.utils import FLAT_COLUMN_SEPARATOR
|
||||
from tests.unit_tests.fixtures.dataframes import (
|
||||
multiple_metrics_df,
|
||||
single_metric_df,
|
||||
@@ -27,9 +28,21 @@ from tests.unit_tests.fixtures.dataframes import (
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
|
||||
|
||||
def test_rolling_should_not_side_effect():
    """Check that ``pp.rolling`` leaves the DataFrame passed to it unchanged."""
    # Snapshot the fixture before the call so it can be compared afterwards.
    _timeseries_df = timeseries_df.copy()
    pp.rolling(
        df=timeseries_df,
        columns={"y": "y"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    # The return value is deliberately ignored: only the input frame matters.
    assert _timeseries_df.equals(timeseries_df)
|
||||
|
||||
|
||||
def test_rolling():
|
||||
# sum rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "y"},
|
||||
rolling_type="sum",
|
||||
@@ -41,7 +54,7 @@ def test_rolling():
|
||||
assert series_to_list(post_df["y"]) == [1.0, 3.0, 5.0, 7.0]
|
||||
|
||||
# mean rolling type with alias
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
rolling_type="mean",
|
||||
columns={"y": "y_mean"},
|
||||
@@ -52,7 +65,7 @@ def test_rolling():
|
||||
assert series_to_list(post_df["y_mean"]) == [1.0, 1.5, 2.0, 2.5]
|
||||
|
||||
# count rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
rolling_type="count",
|
||||
columns={"y": "y"},
|
||||
@@ -63,7 +76,7 @@ def test_rolling():
|
||||
assert series_to_list(post_df["y"]) == [1.0, 2.0, 3.0, 4.0]
|
||||
|
||||
# quantile rolling type
|
||||
post_df = rolling(
|
||||
post_df = pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "q1"},
|
||||
rolling_type="quantile",
|
||||
@@ -75,14 +88,14 @@ def test_rolling():
|
||||
assert series_to_list(post_df["q1"]) == [1.0, 1.25, 1.5, 1.75]
|
||||
|
||||
# incorrect rolling type
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
rolling(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.rolling(
|
||||
df=timeseries_df, columns={"y": "y"}, rolling_type="abc", window=2,
|
||||
)
|
||||
|
||||
# incorrect rolling type options
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
rolling(
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
pp.rolling(
|
||||
df=timeseries_df,
|
||||
columns={"y": "y"},
|
||||
rolling_type="quantile",
|
||||
@@ -91,8 +104,8 @@ def test_rolling():
|
||||
)
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_single_metric():
|
||||
pivot_df = pivot(
|
||||
def test_rolling_should_empty_df():
|
||||
pivot_df = pp.pivot(
|
||||
df=single_metric_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
@@ -100,27 +113,65 @@ def test_rolling_with_pivot_df_and_single_metric():
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
||||
)
|
||||
# dttm UK US
|
||||
# 0 2019-01-01 5 6
|
||||
# 1 2019-01-02 12 14
|
||||
assert rolling_df["UK"].to_list() == [5.0, 12.0]
|
||||
assert rolling_df["US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
rolling_df["dttm"].to_list()
|
||||
== to_datetime(["2019-01-01", "2019-01-02"]).to_list()
|
||||
)
|
||||
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=2, is_pivot_df=True,
|
||||
rolling_df = pp.rolling(
|
||||
df=pivot_df,
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=2,
|
||||
columns={"sum_metric": "sum_metric"},
|
||||
)
|
||||
assert rolling_df.empty is True
|
||||
|
||||
|
||||
def test_rolling_with_pivot_df_and_multiple_metrics():
|
||||
pivot_df = pivot(
|
||||
def test_rolling_after_pivot_with_single_metric():
    """Pivot a single metric, apply a rolling sum, then flatten the result.

    The pivot is created with ``flatten_columns=False`` and
    ``reset_index=False``; the illustrative tables below show the expected
    intermediate frames. The final assertion compares the flattened frame
    with an explicitly constructed DataFrame.
    """
    pivot_df = pp.pivot(
        df=single_metric_df,
        index=["dttm"],
        columns=["country"],
        aggregates={"sum_metric": {"operator": "sum"}},
        flatten_columns=False,
        reset_index=False,
    )
    """
               sum_metric
    country            UK US
    dttm
    2019-01-01          5  6
    2019-01-02          7  8
    """
    rolling_df = pp.rolling(
        df=pivot_df,
        columns={"sum_metric": "sum_metric"},
        rolling_type="sum",
        window=2,
        min_periods=0,
    )
    """
               sum_metric
    country            UK    US
    dttm
    2019-01-01        5.0   6.0
    2019-01-02       12.0  14.0
    """
    flat_df = pp.flatten(rolling_df)
    """
            dttm  sum_metric, UK  sum_metric, US
    0 2019-01-01             5.0             6.0
    1 2019-01-02            12.0            14.0
    """
    # Column labels are built with FLAT_COLUMN_SEPARATOR rather than
    # hard-coded, so the expectation follows the separator constant.
    assert flat_df.equals(
        pd.DataFrame(
            data={
                "dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
                FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
                FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
            }
        )
    )
|
||||
|
||||
|
||||
def test_rolling_after_pivot_with_multiple_metrics():
|
||||
pivot_df = pp.pivot(
|
||||
df=multiple_metrics_df,
|
||||
index=["dttm"],
|
||||
columns=["country"],
|
||||
@@ -131,17 +182,41 @@ def test_rolling_with_pivot_df_and_multiple_metrics():
|
||||
flatten_columns=False,
|
||||
reset_index=False,
|
||||
)
|
||||
rolling_df = rolling(
|
||||
df=pivot_df, rolling_type="sum", window=2, min_periods=0, is_pivot_df=True,
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1 2 5 6
|
||||
2019-01-02 3 4 7 8
|
||||
"""
|
||||
rolling_df = pp.rolling(
|
||||
df=pivot_df,
|
||||
columns={"count_metric": "count_metric", "sum_metric": "sum_metric",},
|
||||
rolling_type="sum",
|
||||
window=2,
|
||||
min_periods=0,
|
||||
)
|
||||
# dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
# 0 2019-01-01 1.0 2.0 5.0 6.0
|
||||
# 1 2019-01-02 4.0 6.0 12.0 14.0
|
||||
assert rolling_df["count_metric, UK"].to_list() == [1.0, 4.0]
|
||||
assert rolling_df["count_metric, US"].to_list() == [2.0, 6.0]
|
||||
assert rolling_df["sum_metric, UK"].to_list() == [5.0, 12.0]
|
||||
assert rolling_df["sum_metric, US"].to_list() == [6.0, 14.0]
|
||||
assert (
|
||||
rolling_df["dttm"].to_list()
|
||||
== to_datetime(["2019-01-01", "2019-01-02",]).to_list()
|
||||
"""
|
||||
count_metric sum_metric
|
||||
country UK US UK US
|
||||
dttm
|
||||
2019-01-01 1.0 2.0 5.0 6.0
|
||||
2019-01-02 4.0 6.0 12.0 14.0
|
||||
"""
|
||||
flat_df = pp.flatten(rolling_df)
|
||||
"""
|
||||
dttm count_metric, UK count_metric, US sum_metric, UK sum_metric, US
|
||||
0 2019-01-01 1.0 2.0 5.0 6.0
|
||||
1 2019-01-02 4.0 6.0 12.0 14.0
|
||||
"""
|
||||
assert flat_df.equals(
|
||||
pd.DataFrame(
|
||||
data={
|
||||
"dttm": pd.to_datetime(["2019-01-01", "2019-01-02"]),
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "UK"]): [1.0, 4.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["count_metric", "US"]): [2.0, 6.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "UK"]): [5.0, 12.0],
|
||||
FLAT_COLUMN_SEPARATOR.join(["sum_metric", "US"]): [6.0, 14.0],
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing.select import select
|
||||
from tests.unit_tests.fixtures.dataframes import timeseries_df
|
||||
|
||||
@@ -47,9 +47,9 @@ def test_select():
|
||||
assert post_df.columns.tolist() == ["y1"]
|
||||
|
||||
# invalid columns
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
select(df=timeseries_df, columns=["abc"], rename={"abc": "qwerty"})
|
||||
|
||||
# select renamed column by new name
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
select(df=timeseries_df, columns=["label_new"], rename={"label": "label_new"})
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
# under the License.
|
||||
import pytest
|
||||
|
||||
from superset.exceptions import QueryObjectValidationError
|
||||
from superset.exceptions import InvalidPostProcessingError
|
||||
from superset.utils.pandas_postprocessing import sort
|
||||
from tests.unit_tests.fixtures.dataframes import categories_df
|
||||
from tests.unit_tests.pandas_postprocessing.utils import series_to_list
|
||||
@@ -26,5 +26,5 @@ def test_sort():
|
||||
df = sort(df=categories_df, columns={"category": True, "asc_idx": False})
|
||||
assert series_to_list(df["asc_idx"])[1] == 96
|
||||
|
||||
with pytest.raises(QueryObjectValidationError):
|
||||
with pytest.raises(InvalidPostProcessingError):
|
||||
sort(df=df, columns={"abc": True})
|
||||
|
||||
Reference in New Issue
Block a user