feat: Add post processing to QueryObject (#9427)

* Add post processing to QueryObject

* Simplify sort signature and require explicit sort order

* Add new operations and unit tests

* linting

* Address comments

* Simplify test method names

* Address comments

* Linting

* remove unnecessary logic

* Apply strict whitelisting to all getattr calls

* Add checking of rolling_type_options and add/improve docs
This commit is contained in:
Ville Brofeldt
2020-04-10 20:50:11 +03:00
committed by GitHub
parent 5ec0192bcc
commit a8ce3bccdf
9 changed files with 899 additions and 12 deletions

View File

@@ -111,7 +111,7 @@ class CoreTests(SupersetTestCase):
resp = self.client.get("/superset/slice/-1/")
assert resp.status_code == 404
def _get_query_context_dict(self) -> Dict[str, Any]:
def _get_query_context(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
@@ -127,6 +127,45 @@ class CoreTests(SupersetTestCase):
],
}
def _get_query_context_with_post_processing(self) -> Dict[str, Any]:
self.login(username="admin")
slc = self.get_slice("Girl Name Cloud", db.session)
return {
"datasource": {"id": slc.datasource_id, "type": slc.datasource_type},
"queries": [
{
"granularity": "ds",
"groupby": ["name", "state"],
"metrics": [{"label": "sum__num"}],
"filters": [],
"row_limit": 100,
"post_processing": [
{
"operation": "aggregate",
"options": {
"groupby": ["state"],
"aggregates": {
"q1": {
"operator": "percentile",
"column": "sum__num",
"options": {"q": 25},
},
"median": {
"operator": "median",
"column": "sum__num",
},
},
},
},
{
"operation": "sort",
"options": {"columns": {"q1": False, "state": True},},
},
],
}
],
}
def test_viz_cache_key(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)
@@ -140,7 +179,7 @@ class CoreTests(SupersetTestCase):
self.assertNotEqual(cache_key, viz.cache_key(qobj))
def test_cache_key_changes_when_datasource_is_updated(self):
qc_dict = self._get_query_context_dict()
qc_dict = self._get_query_context()
# construct baseline cache_key
query_context = QueryContext(**qc_dict)
@@ -168,7 +207,7 @@ class CoreTests(SupersetTestCase):
self.assertNotEqual(cache_key_original, cache_key_new)
def test_query_context_time_range_endpoints(self):
query_context = QueryContext(**self._get_query_context_dict())
query_context = QueryContext(**self._get_query_context())
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
@@ -217,11 +256,18 @@ class CoreTests(SupersetTestCase):
def test_api_v1_query_endpoint(self):
self.login(username="admin")
qc_dict = self._get_query_context_dict()
qc_dict = self._get_query_context()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 100)
def test_api_v1_query_endpoint_with_post_processing(self):
self.login(username="admin")
qc_dict = self._get_query_context_with_post_processing()
data = json.dumps(qc_dict)
resp = json.loads(self.get_resp("/api/v1/query/", {"query_context": data}))
self.assertEqual(resp[0]["rowcount"], 6)
def test_old_slice_json_endpoint(self):
self.login(username="admin")
slc = self.get_slice("Girls", db.session)

121
tests/fixtures/dataframes.py vendored Normal file
View File

@@ -0,0 +1,121 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import date
from pandas import DataFrame, to_datetime
names_df = DataFrame(
[
{
"dt": date(2020, 1, 2),
"name": "John",
"country": "United Kingdom",
"cars": 3,
"bikes": 1,
"seconds": 30,
},
{
"dt": date(2020, 1, 2),
"name": "Peter",
"country": "Sweden",
"cars": 4,
"bikes": 2,
"seconds": 1,
},
{
"dt": date(2020, 1, 3),
"name": "Mary",
"country": "Finland",
"cars": 5,
"bikes": 3,
"seconds": None,
},
{
"dt": date(2020, 1, 3),
"name": "Peter",
"country": "India",
"cars": 6,
"bikes": 4,
"seconds": 12,
},
{
"dt": date(2020, 1, 4),
"name": "John",
"country": "Portugal",
"cars": 7,
"bikes": None,
"seconds": 75,
},
{
"dt": date(2020, 1, 4),
"name": "Peter",
"country": "Italy",
"cars": None,
"bikes": 5,
"seconds": 600,
},
{
"dt": date(2020, 1, 4),
"name": "Mary",
"country": None,
"cars": 9,
"bikes": 6,
"seconds": 2,
},
{
"dt": date(2020, 1, 4),
"name": None,
"country": "Australia",
"cars": 10,
"bikes": 7,
"seconds": 99,
},
{
"dt": date(2020, 1, 1),
"name": "John",
"country": "USA",
"cars": 1,
"bikes": 8,
"seconds": None,
},
{
"dt": date(2020, 1, 1),
"name": "Mary",
"country": "Fiji",
"cars": 2,
"bikes": 9,
"seconds": 50,
},
]
)
categories_df = DataFrame(
{
"constant": ["dummy" for _ in range(0, 101)],
"category": [f"cat{i%3}" for i in range(0, 101)],
"dept": [f"dept{i%5}" for i in range(0, 101)],
"name": [f"person{i}" for i in range(0, 101)],
"asc_idx": [i for i in range(0, 101)],
"desc_idx": [i for i in range(100, -1, -1)],
"idx_nulls": [i if i % 5 == 0 else None for i in range(0, 101)],
}
)
timeseries_df = DataFrame(
index=to_datetime(["2019-01-01", "2019-01-02", "2019-01-05", "2019-01-07"]),
data={"label": ["x", "y", "z", "q"], "y": [1.0, 2.0, 3.0, 4.0]},
)

View File

@@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import math
from typing import Any, List
from pandas import Series
from superset.exceptions import QueryObjectValidationError
from superset.utils import pandas_postprocessing as proc
from .base_tests import SupersetTestCase
from .fixtures.dataframes import categories_df, timeseries_df
def series_to_list(series: Series) -> List[Any]:
"""
Converts a `Series` to a regular list, and replaces non-numeric values to
Nones.
:param series: Series to convert
:return: list without nan or inf
"""
return [
None
if not isinstance(val, str) and (math.isnan(val) or math.isinf(val))
else val
for val in series.tolist()
]
class PostProcessingTestCase(SupersetTestCase):
def test_pivot(self):
aggregates = {"idx_nulls": {"operator": "sum"}}
# regular pivot
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
aggregates=aggregates,
)
self.assertListEqual(
df.columns.tolist(),
[("idx_nulls", "cat0"), ("idx_nulls", "cat1"), ("idx_nulls", "cat2")],
)
self.assertEqual(len(df), 101)
self.assertEqual(df.sum()[0], 315)
# regular pivot
df = proc.pivot(
df=categories_df,
index=["dept"],
columns=["category"],
aggregates=aggregates,
)
self.assertEqual(len(df), 5)
# fill value
df = proc.pivot(
df=categories_df,
index=["name"],
columns=["category"],
metric_fill_value=1,
aggregates={"idx_nulls": {"operator": "sum"}},
)
self.assertEqual(df.sum()[0], 382)
# invalid index reference
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["abc"],
columns=["dept"],
aggregates=aggregates,
)
# invalid column reference
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["dept"],
columns=["abc"],
aggregates=aggregates,
)
# invalid aggregate options
self.assertRaises(
QueryObjectValidationError,
proc.pivot,
df=categories_df,
index=["name"],
columns=["category"],
aggregates={"idx_nulls": {}},
)
def test_aggregate(self):
aggregates = {
"asc sum": {"column": "asc_idx", "operator": "sum"},
"asc q2": {
"column": "asc_idx",
"operator": "percentile",
"options": {"q": 75},
},
"desc q1": {
"column": "desc_idx",
"operator": "percentile",
"options": {"q": 25},
},
}
df = proc.aggregate(
df=categories_df, groupby=["constant"], aggregates=aggregates
)
self.assertListEqual(
df.columns.tolist(), ["constant", "asc sum", "asc q2", "desc q1"]
)
self.assertEqual(series_to_list(df["asc sum"])[0], 5050)
self.assertEqual(series_to_list(df["asc q2"])[0], 75)
self.assertEqual(series_to_list(df["desc q1"])[0], 25)
def test_sort(self):
df = proc.sort(df=categories_df, columns={"category": True, "asc_idx": False})
self.assertEqual(96, series_to_list(df["asc_idx"])[1])
self.assertRaises(
QueryObjectValidationError, proc.sort, df=df, columns={"abc": True}
)
def test_rolling(self):
# sum rolling type
post_df = proc.rolling(
df=timeseries_df,
columns={"y": "y"},
rolling_type="sum",
window=2,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 3.0, 5.0, 7.0])
# mean rolling type with alias
post_df = proc.rolling(
df=timeseries_df,
rolling_type="mean",
columns={"y": "y_mean"},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y_mean"])
self.assertListEqual(series_to_list(post_df["y_mean"]), [1.0, 1.5, 2.0, 2.5])
# count rolling type
post_df = proc.rolling(
df=timeseries_df,
rolling_type="count",
columns={"y": "y"},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
# quantile rolling type
post_df = proc.rolling(
df=timeseries_df,
columns={"y": "q1"},
rolling_type="quantile",
rolling_type_options={"quantile": 0.25},
window=10,
min_periods=0,
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "q1"])
self.assertListEqual(series_to_list(post_df["q1"]), [1.0, 1.25, 1.5, 1.75])
# incorrect rolling type
self.assertRaises(
QueryObjectValidationError,
proc.rolling,
df=timeseries_df,
columns={"y": "y"},
rolling_type="abc",
window=2,
)
# incorrect rolling type options
self.assertRaises(
QueryObjectValidationError,
proc.rolling,
df=timeseries_df,
columns={"y": "y"},
rolling_type="quantile",
rolling_type_options={"abc": 123},
window=2,
)
def test_select(self):
# reorder columns
post_df = proc.select(df=timeseries_df, columns=["y", "label"])
self.assertListEqual(post_df.columns.tolist(), ["y", "label"])
# one column
post_df = proc.select(df=timeseries_df, columns=["label"])
self.assertListEqual(post_df.columns.tolist(), ["label"])
# rename one column
post_df = proc.select(df=timeseries_df, columns=["y"], rename={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["y1"])
# rename one and leave one unchanged
post_df = proc.select(
df=timeseries_df, columns=["label", "y"], rename={"y": "y1"}
)
self.assertListEqual(post_df.columns.tolist(), ["label", "y1"])
# invalid columns
self.assertRaises(
QueryObjectValidationError,
proc.select,
df=timeseries_df,
columns=["qwerty"],
rename={"abc": "qwerty"},
)
def test_diff(self):
# overwrite column
post_df = proc.diff(df=timeseries_df, columns={"y": "y"})
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [None, 1.0, 1.0, 1.0])
# add column
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"})
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y1"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
self.assertListEqual(series_to_list(post_df["y1"]), [None, 1.0, 1.0, 1.0])
# look ahead
post_df = proc.diff(df=timeseries_df, columns={"y": "y1"}, periods=-1)
self.assertListEqual(series_to_list(post_df["y1"]), [-1.0, -1.0, -1.0, None])
# invalid column reference
self.assertRaises(
QueryObjectValidationError,
proc.diff,
df=timeseries_df,
columns={"abc": "abc"},
)
def test_cum(self):
# create new column (cumsum)
post_df = proc.cum(df=timeseries_df, columns={"y": "y2"}, operator="sum",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y", "y2"])
self.assertListEqual(series_to_list(post_df["label"]), ["x", "y", "z", "q"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 3.0, 4.0])
self.assertListEqual(series_to_list(post_df["y2"]), [1.0, 3.0, 6.0, 10.0])
# overwrite column (cumprod)
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="prod",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 2.0, 6.0, 24.0])
# overwrite column (cummin)
post_df = proc.cum(df=timeseries_df, columns={"y": "y"}, operator="min",)
self.assertListEqual(post_df.columns.tolist(), ["label", "y"])
self.assertListEqual(series_to_list(post_df["y"]), [1.0, 1.0, 1.0, 1.0])
# invalid operator
self.assertRaises(
QueryObjectValidationError,
proc.cum,
df=timeseries_df,
columns={"y": "y"},
operator="abc",
)