mirror of
https://github.com/apache/superset.git
synced 2026-05-06 16:34:32 +00:00
fix(chart-data-api): assert referenced columns are present in datasource (#10451)
* fix(chart-data-api): assert requested columns are present in datasource * add filter tests * add column_names to AnnotationDatasource * add assertion for simple metrics * lint
This commit is contained in:
@@ -17,11 +17,12 @@
|
||||
import tests.test_app
|
||||
from superset import db
|
||||
from superset.charts.schemas import ChartDataQueryContextSchema
|
||||
from superset.common.query_context import QueryContext
|
||||
from superset.connectors.connector_registry import ConnectorRegistry
|
||||
from superset.utils.core import (
|
||||
AdhocMetricExpressionType,
|
||||
ChartDataResultFormat,
|
||||
ChartDataResultType,
|
||||
FilterOperator,
|
||||
TimeRangeEndpoint,
|
||||
)
|
||||
from tests.base_tests import SupersetTestCase
|
||||
@@ -75,7 +76,7 @@ class TestQueryContext(SupersetTestCase):
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
|
||||
# construct baseline cache_key
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_original = query_context.cache_key(query_object)
|
||||
|
||||
@@ -92,7 +93,7 @@ class TestQueryContext(SupersetTestCase):
|
||||
db.session.commit()
|
||||
|
||||
# create new QueryContext with unchanged attributes and extract new cache_key
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_new = query_context.cache_key(query_object)
|
||||
|
||||
@@ -108,20 +109,20 @@ class TestQueryContext(SupersetTestCase):
|
||||
)
|
||||
|
||||
# construct baseline cache_key from query_context with post processing operation
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_original = query_context.cache_key(query_object)
|
||||
|
||||
# ensure added None post_processing operation doesn't change cache_key
|
||||
payload["queries"][0]["post_processing"].append(None)
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_with_null = query_context.cache_key(query_object)
|
||||
self.assertEqual(cache_key_original, cache_key_with_null)
|
||||
|
||||
# ensure query without post processing operation is different
|
||||
payload["queries"][0].pop("post_processing")
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
cache_key_without_post_processing = query_context.cache_key(query_object)
|
||||
self.assertNotEqual(cache_key_original, cache_key_without_post_processing)
|
||||
@@ -136,7 +137,7 @@ class TestQueryContext(SupersetTestCase):
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
del payload["queries"][0]["extras"]["time_range_endpoints"]
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_object = query_context.queries[0]
|
||||
extras = query_object.to_dict()["extras"]
|
||||
self.assertTrue("time_range_endpoints" in extras)
|
||||
@@ -155,8 +156,8 @@ class TestQueryContext(SupersetTestCase):
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["granularity_sqla"] = "timecol"
|
||||
payload["queries"][0]["having_filters"] = {"col": "a", "op": "==", "val": "b"}
|
||||
query_context = QueryContext(**payload)
|
||||
payload["queries"][0]["having_filters"] = [{"col": "a", "op": "==", "val": "b"}]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
self.assertEqual(len(query_context.queries), 1)
|
||||
query_object = query_context.queries[0]
|
||||
self.assertEqual(query_object.granularity, "timecol")
|
||||
@@ -172,13 +173,79 @@ class TestQueryContext(SupersetTestCase):
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_format"] = ChartDataResultFormat.CSV.value
|
||||
payload["queries"][0]["row_limit"] = 10
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
data = responses[0]["data"]
|
||||
self.assertIn("name,sum__num\n", data)
|
||||
self.assertEqual(len(data.split("\n")), 12)
|
||||
|
||||
def test_sql_injection_via_groupby(self):
|
||||
"""
|
||||
Ensure that calling invalid columns names in groupby are caught
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table_name = "birth_names"
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["groupby"] = ["currentDatabase()"]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_payload = query_context.get_payload()
|
||||
assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_columns(self):
|
||||
"""
|
||||
Ensure that calling invalid columns names in columns are caught
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table_name = "birth_names"
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["groupby"] = []
|
||||
payload["queries"][0]["metrics"] = []
|
||||
payload["queries"][0]["columns"] = ["*, 'extra'"]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_payload = query_context.get_payload()
|
||||
assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_filters(self):
|
||||
"""
|
||||
Ensure that calling invalid columns names in filters are caught
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table_name = "birth_names"
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["groupby"] = ["name"]
|
||||
payload["queries"][0]["metrics"] = []
|
||||
payload["queries"][0]["filters"] = [
|
||||
{"col": "*", "op": FilterOperator.EQUALS.value, "val": ";"}
|
||||
]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_payload = query_context.get_payload()
|
||||
assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_sql_injection_via_metrics(self):
|
||||
"""
|
||||
Ensure that calling invalid columns names in filters are caught
|
||||
"""
|
||||
self.login(username="admin")
|
||||
table_name = "birth_names"
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["queries"][0]["groupby"] = ["name"]
|
||||
payload["queries"][0]["metrics"] = [
|
||||
{
|
||||
"expressionType": AdhocMetricExpressionType.SIMPLE.value,
|
||||
"column": {"column_name": "invalid_col"},
|
||||
"aggregate": "SUM",
|
||||
"label": "My Simple Label",
|
||||
}
|
||||
]
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
query_payload = query_context.get_payload()
|
||||
assert query_payload[0].get("error") is not None
|
||||
|
||||
def test_samples_response_type(self):
|
||||
"""
|
||||
Ensure that samples result type works
|
||||
@@ -189,7 +256,7 @@ class TestQueryContext(SupersetTestCase):
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_type"] = ChartDataResultType.SAMPLES.value
|
||||
payload["queries"][0]["row_limit"] = 5
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
data = responses[0]["data"]
|
||||
@@ -206,7 +273,7 @@ class TestQueryContext(SupersetTestCase):
|
||||
table = self.get_table_by_name(table_name)
|
||||
payload = get_query_context(table.name, table.id, table.type)
|
||||
payload["result_type"] = ChartDataResultType.QUERY.value
|
||||
query_context = QueryContext(**payload)
|
||||
query_context = ChartDataQueryContextSchema().load(payload)
|
||||
responses = query_context.get_payload()
|
||||
self.assertEqual(len(responses), 1)
|
||||
response = responses[0]
|
||||
|
||||
Reference in New Issue
Block a user