fix(chart-data-api): assert referenced columns are present in datasource (#10451)

* fix(chart-data-api): assert requested columns are present in datasource

* add filter tests

* add column_names to AnnotationDatasource

* add assertion for simple metrics

* lint
This commit is contained in:
Ville Brofeldt
2020-08-14 20:58:24 +03:00
committed by GitHub
parent 6c09b938fe
commit acb00f509c
7 changed files with 196 additions and 24 deletions

View File

@@ -17,11 +17,12 @@
import tests.test_app
from superset import db
from superset.charts.schemas import ChartDataQueryContextSchema
from superset.common.query_context import QueryContext
from superset.connectors.connector_registry import ConnectorRegistry
from superset.utils.core import (
AdhocMetricExpressionType,
ChartDataResultFormat,
ChartDataResultType,
FilterOperator,
TimeRangeEndpoint,
)
from tests.base_tests import SupersetTestCase
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
# construct baseline cache_key
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
db.session.commit()
# create new QueryContext with unchanged attributes and extract new cache_key
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_new = query_context.cache_key(query_object)
@@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
)
# construct baseline cache_key from query_context with post processing operation
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_original = query_context.cache_key(query_object)
# ensure added None post_processing operation doesn't change cache_key
payload["queries"][0]["post_processing"].append(None)
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_with_null = query_context.cache_key(query_object)
self.assertEqual(cache_key_original, cache_key_with_null)
# ensure query without post processing operation is different
payload["queries"][0].pop("post_processing")
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
cache_key_without_post_processing = query_context.cache_key(query_object)
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
del payload["queries"][0]["extras"]["time_range_endpoints"]
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
query_object = query_context.queries[0]
extras = query_object.to_dict()["extras"]
self.assertTrue("time_range_endpoints" in extras)
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["queries"][0]["granularity_sqla"] = "timecol"
payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
query_context = QueryContext(**payload)
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
query_context = ChartDataQueryContextSchema().load(payload)
self.assertEqual(len(query_context.queries), 1)
query_object = query_context.queries[0]
self.assertEqual(query_object.granularity, "timecol")
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_format"] = ChartDataResultFormat.CSV.value
payload["queries"][0]["row_limit"] = 10
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
self.assertIn("name,sum__num\n", data)
self.assertEqual(len(data.split("\n")), 12)
def test_sql_injection_via_groupby(self):
    """
    A SQL snippet smuggled in as a groupby column name must yield an error
    response instead of being executed against the database.
    """
    self.login(username="admin")
    birth_names = self.get_table_by_name("birth_names")
    payload = get_query_context(birth_names.name, birth_names.id, birth_names.type)
    # "currentDatabase()" is not a column of the datasource, so loading and
    # executing the query context should surface an error in the payload.
    payload["queries"][0]["groupby"] = ["currentDatabase()"]
    context = ChartDataQueryContextSchema().load(payload)
    responses = context.get_payload()
    assert responses[0].get("error") is not None
def test_sql_injection_via_columns(self):
    """
    A raw SQL fragment requested via the columns list must produce an error
    response rather than being interpolated into the query.
    """
    self.login(username="admin")
    target = self.get_table_by_name("birth_names")
    payload = get_query_context(target.name, target.id, target.type)
    query = payload["queries"][0]
    # Clear groupby/metrics so the injected columns entry is the only
    # referenced "column"; it is not present in the datasource.
    query["groupby"] = []
    query["metrics"] = []
    query["columns"] = ["*, 'extra'"]
    context = ChartDataQueryContextSchema().load(payload)
    responses = context.get_payload()
    assert responses[0].get("error") is not None
def test_sql_injection_via_filters(self):
    """
    A filter referencing a column that does not exist in the datasource must
    produce an error response instead of being executed.
    """
    self.login(username="admin")
    target = self.get_table_by_name("birth_names")
    payload = get_query_context(target.name, target.id, target.type)
    query = payload["queries"][0]
    query["groupby"] = ["name"]
    query["metrics"] = []
    # "*" is not a valid column name, so the filter must be rejected.
    bad_filter = {"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
    query["filters"] = [bad_filter]
    context = ChartDataQueryContextSchema().load(payload)
    responses = context.get_payload()
    assert responses[0].get("error") is not None
def test_sql_injection_via_metrics(self):
    """
    Ensure that invalid column names referenced by simple adhoc metrics
    are caught and reported as an error response.
    """
    # NOTE: the original docstring said "in filters" — a copy-paste from the
    # filters test above; this test targets the metrics field.
    self.login(username="admin")
    table_name = "birth_names"
    table = self.get_table_by_name(table_name)
    payload = get_query_context(table.name, table.id, table.type)
    payload["queries"][0]["groupby"] = ["name"]
    # A SIMPLE adhoc metric aggregating a column that does not exist in the
    # datasource must not make it into the generated SQL.
    payload["queries"][0]["metrics"] = [
        {
            "expressionType": AdhocMetricExpressionType.SIMPLE.value,
            "column": {"column_name": "invalid_col"},
            "aggregate": "SUM",
            "label": "My Simple Label",
        }
    ]
    query_context = ChartDataQueryContextSchema().load(payload)
    query_payload = query_context.get_payload()
    assert query_payload[0].get("error") is not None
def test_samples_response_type(self):
"""
Ensure that samples result type works
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.SAMPLES.value
payload["queries"][0]["row_limit"] = 5
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
data = responses[0]["data"]
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
table = self.get_table_by_name(table_name)
payload = get_query_context(table.name, table.id, table.type)
payload["result_type"] = ChartDataResultType.QUERY.value
query_context = QueryContext(**payload)
query_context = ChartDataQueryContextSchema().load(payload)
responses = query_context.get_payload()
self.assertEqual(len(responses), 1)
response = responses[0]