fix: fix csv and query result type and QueryObject schema (#10312)

This commit is contained in:
Ville Brofeldt
2020-07-14 16:37:19 +03:00
committed by Ville Brofeldt
parent 38bc62db4b
commit 522ed20a20
4 changed files with 48 additions and 8 deletions

View File

@@ -470,7 +470,7 @@ class ChartRestApi(BaseSupersetModelRestApi):
return self.response_401() return self.response_401()
payload = query_context.get_payload() payload = query_context.get_payload()
for query in payload: for query in payload:
if query["error"]: if query.get("error"):
return self.response_400(message=f"Error: {query['error']}") return self.response_400(message=f"Error: {query['error']}")
result_format = query_context.result_format result_format = query_context.result_format
if result_format == ChartDataResultFormat.CSV: if result_format == ChartDataResultFormat.CSV:

View File

@@ -680,7 +680,7 @@ class ChartDataQueryObjectSchema(Schema):
timeseries_limit = fields.Integer( timeseries_limit = fields.Integer(
description="Maximum row count for timeseries queries. Default: `0`", description="Maximum row count for timeseries queries. Default: `0`",
) )
timeseries_limit_metric = fields.Integer( timeseries_limit_metric = fields.Raw(
description="Metric used to limit timeseries queries by.", allow_none=True, description="Metric used to limit timeseries queries by.", allow_none=True,
) )
row_limit = fields.Integer( row_limit = fields.Integer(

View File

@@ -730,6 +730,28 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin):
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data") rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 400) self.assertEqual(rv.status_code, 400)
def test_chart_data_query_result_type(self):
"""
Chart data API: Test chart data with query result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_type"] = "query"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
def test_chart_data_csv_result_format(self):
"""
Chart data API: Test chart data with CSV result format
"""
self.login(username="admin")
table = self.get_table_by_name("birth_names")
request_payload = get_query_context(table.name, table.id, table.type)
request_payload["result_format"] = "csv"
rv = self.post_assert_metric(CHART_DATA_URI, request_payload, "data")
self.assertEqual(rv.status_code, 200)
def test_chart_data_mixed_case_filter_op(self): def test_chart_data_mixed_case_filter_op(self):
""" """
Chart data API: Ensure mixed case filter operator generates valid result Chart data API: Ensure mixed case filter operator generates valid result

View File

@@ -26,10 +26,6 @@ from tests.base_tests import SupersetTestCase
from tests.fixtures.query_context import get_query_context from tests.fixtures.query_context import get_query_context
def load_query_context(payload: Dict[str, Any]) -> Tuple[QueryContext, Dict[str, Any]]:
return ChartDataQueryContextSchema().load(payload)
class TestSchema(SupersetTestCase): class TestSchema(SupersetTestCase):
def test_query_context_limit_and_offset(self): def test_query_context_limit_and_offset(self):
self.login(username="admin") self.login(username="admin")
@@ -40,7 +36,7 @@ class TestSchema(SupersetTestCase):
# Use defaults # Use defaults
payload["queries"][0].pop("row_limit", None) payload["queries"][0].pop("row_limit", None)
payload["queries"][0].pop("row_offset", None) payload["queries"][0].pop("row_offset", None)
query_context = load_query_context(payload) query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0] query_object = query_context.queries[0]
self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"]) self.assertEqual(query_object.row_limit, app.config["ROW_LIMIT"])
self.assertEqual(query_object.row_offset, 0) self.assertEqual(query_object.row_offset, 0)
@@ -66,10 +62,32 @@ class TestSchema(SupersetTestCase):
table_name = "birth_names" table_name = "birth_names"
table = self.get_table_by_name(table_name) table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type) payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["extras"]["time_grain_sqla"] = None payload["queries"][0]["extras"]["time_grain_sqla"] = None
_ = ChartDataQueryContextSchema().load(payload) _ = ChartDataQueryContextSchema().load(payload)
def test_query_context_series_limit(self):
self.login(username="admin")
table_name = "birth_names"
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["timeseries_limit"] = 2
payload["queries"][0]["timeseries_limit_metric"] = {
"expressionType": "SIMPLE",
"column": {
"id": 334,
"column_name": "gender",
"filterable": True,
"groupby": True,
"is_dttm": False,
"type": "VARCHAR(16)",
"optionName": "_col_gender",
},
"aggregate": "COUNT_DISTINCT",
"label": "COUNT_DISTINCT(gender)",
}
_ = ChartDataQueryContextSchema().load(payload)
def test_query_context_null_post_processing_op(self): def test_query_context_null_post_processing_op(self):
self.login(username="admin") self.login(username="admin")
table_name = "birth_names" table_name = "birth_names"