chore(sqla): refactor query utils (#21811)

Co-authored-by: Ville Brofeldt <ville.brofeldt@apple.com>
This commit is contained in:
Ville Brofeldt
2022-10-17 10:40:42 +01:00
committed by AAfghahi
parent aa11e5486b
commit d779789652
3 changed files with 209 additions and 9 deletions

View File

@@ -92,6 +92,7 @@ from superset.exceptions import (
DatasetInvalidPermissionEvaluationException, DatasetInvalidPermissionEvaluationException,
QueryClauseValidationException, QueryClauseValidationException,
QueryObjectValidationError, QueryObjectValidationError,
SupersetSecurityException,
) )
from superset.extensions import feature_flag_manager from superset.extensions import feature_flag_manager
from superset.jinja_context import ( from superset.jinja_context import (
@@ -647,19 +648,19 @@ def _process_sql_expression(
expression: Optional[str], expression: Optional[str],
database_id: int, database_id: int,
schema: str, schema: str,
template_processor: Optional[BaseTemplateProcessor], template_processor: Optional[BaseTemplateProcessor] = None,
) -> Optional[str]: ) -> Optional[str]:
if template_processor and expression: if template_processor and expression:
expression = template_processor.process_template(expression) expression = template_processor.process_template(expression)
if expression: if expression:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
try: try:
expression = validate_adhoc_subquery(
expression,
database_id,
schema,
)
expression = sanitize_clause(expression) expression = sanitize_clause(expression)
except QueryClauseValidationException as ex: except (QueryClauseValidationException, SupersetSecurityException) as ex:
raise QueryObjectValidationError(ex.message) from ex raise QueryObjectValidationError(ex.message) from ex
return expression return expression
@@ -1639,6 +1640,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message, msg=ex.message,
) )
) from ex ) from ex
where = _process_sql_expression(
expression=where,
database_id=self.database_id,
schema=self.schema,
)
where_clause_and += [self.text(where)] where_clause_and += [self.text(where)]
having = extras.get("having") having = extras.get("having")
if having: if having:
@@ -1651,7 +1657,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
msg=ex.message, msg=ex.message,
) )
) from ex ) from ex
having = _process_sql_expression(
expression=having,
database_id=self.database_id,
schema=self.schema,
)
having_clause_and += [self.text(having)] having_clause_and += [self.text(having)]
if apply_fetch_values_predicate and self.fetch_values_predicate: if apply_fetch_values_predicate and self.fetch_values_predicate:
qry = qry.where(self.get_fetch_values_predicate()) qry = qry.where(self.get_fetch_values_predicate())
if granularity: if granularity:

View File

@@ -21,7 +21,7 @@ import unittest
import copy import copy
from datetime import datetime from datetime import datetime
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Any, Dict, List, Optional
from unittest import mock from unittest import mock
from zipfile import ZipFile from zipfile import ZipFile
@@ -963,3 +963,191 @@ class TestGetChartDataApi(BaseTestChartDataApi):
unique_genders = {row["male_or_female"] for row in data} unique_genders = {row["male_or_female"] for row in data}
assert unique_genders == {"male", "female"} assert unique_genders == {"male", "female"}
assert result["applied_filters"] == [{"column": "male_or_female"}] assert result["applied_filters"] == [{"column": "male_or_female"}]
@pytest.fixture()
def physical_query_context(physical_dataset) -> Dict[str, Any]:
    """Build a minimal chart-data request payload against the physical dataset."""
    datasource_ref = {
        "type": physical_dataset.type,
        "id": physical_dataset.id,
    }
    query = {
        "columns": ["col1"],
        "metrics": ["count"],
        "orderby": [["col1", True]],
    }
    return {
        "datasource": datasource_ref,
        "queries": [query],
        "result_type": ChartDataResultType.FULL,
        # force recomputation so each test observes its own cache settings
        "force": True,
    }
@mock.patch(
    "superset.common.query_context_processor.config",
    {
        **app.config,
        "CACHE_DEFAULT_TIMEOUT": 1234,
        "DATA_CACHE_CONFIG": {
            **app.config["DATA_CACHE_CONFIG"],
            "CACHE_DEFAULT_TIMEOUT": None,
        },
    },
)
def test_cache_default_timeout(test_client, login_as_admin, physical_query_context):
    """With no data-cache timeout configured, CACHE_DEFAULT_TIMEOUT applies."""
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    first_result = response.json["result"][0]
    assert first_result["cache_timeout"] == 1234
def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context):
    """A custom_cache_timeout in the request payload overrides configured values."""
    requested_timeout = 5678
    physical_query_context["custom_cache_timeout"] = requested_timeout
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    first_result = response.json["result"][0]
    assert first_result["cache_timeout"] == requested_timeout
@mock.patch(
    "superset.common.query_context_processor.config",
    {
        **app.config,
        "CACHE_DEFAULT_TIMEOUT": 100000,
        "DATA_CACHE_CONFIG": {
            **app.config["DATA_CACHE_CONFIG"],
            "CACHE_DEFAULT_TIMEOUT": 3456,
        },
    },
)
def test_data_cache_default_timeout(
    test_client,
    login_as_admin,
    physical_query_context,
):
    """A DATA_CACHE_CONFIG timeout wins over the global CACHE_DEFAULT_TIMEOUT."""
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    first_result = response.json["result"][0]
    assert first_result["cache_timeout"] == 3456
def test_chart_cache_timeout(
    test_client,
    login_as_admin,
    physical_query_context,
    load_energy_table_with_slice: List[Slice],
):
    """A timeout set on the chart (slice) takes precedence over the datasource's."""
    # Give the datasource a timeout that should lose to the chart-level one.
    datasource_id = physical_query_context["datasource"]["id"]
    table: SqlaTable = (
        db.session.query(SqlaTable).filter(SqlaTable.id == datasource_id).first()
    )
    table.cache_timeout = 1254
    db.session.merge(table)

    # Give the chart itself the timeout that should win.
    chart = load_energy_table_with_slice[0]
    chart.cache_timeout = 20
    db.session.merge(chart)
    db.session.commit()

    physical_query_context["form_data"] = {"slice_id": chart.id}
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    assert response.json["result"][0]["cache_timeout"] == 20
@mock.patch(
    "superset.common.query_context_processor.config",
    {
        **app.config,
        "DATA_CACHE_CONFIG": {
            **app.config["DATA_CACHE_CONFIG"],
            "CACHE_DEFAULT_TIMEOUT": 1010,
        },
    },
)
def test_chart_cache_timeout_not_present(
    test_client, login_as_admin, physical_query_context
):
    """Without a chart-level timeout, the datasource's own timeout applies."""
    datasource_id = physical_query_context["datasource"]["id"]
    table: SqlaTable = (
        db.session.query(SqlaTable).filter(SqlaTable.id == datasource_id).first()
    )
    table.cache_timeout = 1980
    db.session.merge(table)
    db.session.commit()

    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    assert response.json["result"][0]["cache_timeout"] == 1980
@mock.patch(
    "superset.common.query_context_processor.config",
    {
        **app.config,
        "DATA_CACHE_CONFIG": {
            **app.config["DATA_CACHE_CONFIG"],
            "CACHE_DEFAULT_TIMEOUT": 1010,
        },
    },
)
def test_chart_cache_timeout_chart_not_found(
    test_client, login_as_admin, physical_query_context
):
    """When the referenced slice does not exist, the default timeout is used."""
    physical_query_context["form_data"] = {"slice_id": 0}  # no such chart
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    first_result = response.json["result"][0]
    assert first_result["cache_timeout"] == 1010
@pytest.mark.parametrize(
    "status_code,extras",
    [
        (200, {"where": "1 = 1"}),
        (200, {"having": "count(*) > 0"}),
        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
    ],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_not_allowed(
    test_client,
    login_as_admin,
    physical_dataset,
    physical_query_context,
    status_code,
    extras,
):
    """With ALLOW_ADHOC_SUBQUERY off, subqueries in where/having are rejected (400)."""
    first_query = physical_query_context["queries"][0]
    first_query["extras"] = extras
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    assert response.status_code == status_code
@pytest.mark.parametrize(
    "status_code,extras",
    [
        (200, {"where": "1 = 1"}),
        (200, {"having": "count(*) > 0"}),
        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
    ],
)
@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_chart_data_subquery_allowed(
    test_client,
    login_as_admin,
    physical_dataset,
    physical_query_context,
    status_code,
    extras,
):
    """With ALLOW_ADHOC_SUBQUERY on, subqueries in where/having are accepted (200)."""
    first_query = physical_query_context["queries"][0]
    first_query["extras"] = extras
    response = test_client.post(CHART_DATA_URI, json=physical_query_context)
    assert response.status_code == status_code

View File

@@ -261,7 +261,7 @@ class TestDatabaseModel(SupersetTestCase):
) )
db.session.commit() db.session.commit()
with pytest.raises(SupersetSecurityException): with pytest.raises(QueryObjectValidationError):
table.get_sqla_query(**base_query_obj) table.get_sqla_query(**base_query_obj)
# Cleanup # Cleanup
db.session.delete(table) db.session.delete(table)