chore(sqla): refactor query utils (#21811)
Co-authored-by: Ville Brofeldt <ville.brofeldt@apple.com>
@@ -92,6 +92,7 @@ from superset.exceptions import (
     DatasetInvalidPermissionEvaluationException,
     QueryClauseValidationException,
     QueryObjectValidationError,
+    SupersetSecurityException,
 )
 from superset.extensions import feature_flag_manager
 from superset.jinja_context import (
@@ -647,19 +648,19 @@ def _process_sql_expression(
     expression: Optional[str],
     database_id: int,
     schema: str,
-    template_processor: Optional[BaseTemplateProcessor],
+    template_processor: Optional[BaseTemplateProcessor] = None,
 ) -> Optional[str]:
     if template_processor and expression:
         expression = template_processor.process_template(expression)
     if expression:
-        expression = validate_adhoc_subquery(
-            expression,
-            database_id,
-            schema,
-        )
         try:
+            expression = validate_adhoc_subquery(
+                expression,
+                database_id,
+                schema,
+            )
             expression = sanitize_clause(expression)
-        except QueryClauseValidationException as ex:
+        except (QueryClauseValidationException, SupersetSecurityException) as ex:
             raise QueryObjectValidationError(ex.message) from ex
     return expression
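
What this hunk changes: validate_adhoc_subquery now runs inside the same try block as sanitize_clause, and SupersetSecurityException is caught next to QueryClauseValidationException, so both failure modes reach callers as a single QueryObjectValidationError; template_processor also gains a default of None, so callers may omit it. Below is a minimal, self-contained sketch of that pattern with stand-in exception classes and validators — it is not Superset's implementation.

from typing import Optional


class ClauseValidationError(Exception):
    """Stand-in for QueryClauseValidationException."""


class SecurityError(Exception):
    """Stand-in for SupersetSecurityException."""


class QueryValidationError(Exception):
    """Stand-in for QueryObjectValidationError."""


def validate_subquery(expression: str) -> str:
    # hypothetical rule: reject ad-hoc subqueries outright
    if "select" in expression.lower():
        raise SecurityError("Ad-hoc subqueries are not allowed")
    return expression


def sanitize(expression: str) -> str:
    # hypothetical rule: reject statement terminators
    if ";" in expression:
        raise ClauseValidationError("Clause must not contain a semicolon")
    return expression


def process_sql_expression(expression: Optional[str]) -> Optional[str]:
    if not expression:
        return expression
    try:
        # both steps sit in one try block, so either failure mode surfaces
        # to the caller as the same domain-level error
        expression = validate_subquery(expression)
        expression = sanitize(expression)
    except (ClauseValidationError, SecurityError) as ex:
        raise QueryValidationError(str(ex)) from ex
    return expression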
@@ -1639,6 +1640,11 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                where = _process_sql_expression(
+                    expression=where,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 where_clause_and += [self.text(where)]
             having = extras.get("having")
             if having:
@@ -1651,7 +1657,13 @@ class SqlaTable(Model, BaseDatasource): # pylint: disable=too-many-public-metho
                             msg=ex.message,
                         )
                     ) from ex
+                having = _process_sql_expression(
+                    expression=having,
+                    database_id=self.database_id,
+                    schema=self.schema,
+                )
                 having_clause_and += [self.text(having)]

         if apply_fetch_values_predicate and self.fetch_values_predicate:
             qry = qry.where(self.get_fetch_values_predicate())
         if granularity:
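
The two hunks above route the free-form WHERE and HAVING strings from a query's "extras" through _process_sql_expression before they are wrapped in self.text(...), so ad-hoc subqueries are validated and template or security failures are re-raised as QueryObjectValidationError instead of the raw clause going straight to SQLAlchemy. The sketch below shows how such a clause arrives from the chart data API, based on the test payloads added further down; the dataset id is illustrative and CHART_DATA_URI is the test suite's name for the /api/v1/chart/data endpoint.

query_context = {
    "datasource": {"type": "table", "id": 1},  # illustrative id
    "queries": [
        {
            "columns": ["col1"],
            "metrics": ["count"],
            # free-form HAVING clause carried in the query "extras"
            "extras": {"having": "count(*) > (select count(*) from physical_dataset)"},
        }
    ],
    "result_type": "full",
}
# With the ALLOW_ADHOC_SUBQUERY feature flag disabled, this payload is rejected
# with HTTP 400 because the HAVING clause contains a subquery; with the flag
# enabled it returns 200 (see the test hunks below).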
@@ -21,7 +21,7 @@ import unittest
 import copy
 from datetime import datetime
 from io import BytesIO
-from typing import Optional
+from typing import Any, Dict, List, Optional
 from unittest import mock
 from zipfile import ZipFile
@@ -963,3 +963,191 @@ class TestGetChartDataApi(BaseTestChartDataApi):
         unique_genders = {row["male_or_female"] for row in data}
         assert unique_genders == {"male", "female"}
         assert result["applied_filters"] == [{"column": "male_or_female"}]
+
+
+@pytest.fixture()
+def physical_query_context(physical_dataset) -> Dict[str, Any]:
+    return {
+        "datasource": {
+            "type": physical_dataset.type,
+            "id": physical_dataset.id,
+        },
+        "queries": [
+            {
+                "columns": ["col1"],
+                "metrics": ["count"],
+                "orderby": [["col1", True]],
+            }
+        ],
+        "result_type": ChartDataResultType.FULL,
+        "force": True,
+    }
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "CACHE_DEFAULT_TIMEOUT": 1234,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": None,
+        },
+    },
+)
+def test_cache_default_timeout(test_client, login_as_admin, physical_query_context):
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1234
+
+
+def test_custom_cache_timeout(test_client, login_as_admin, physical_query_context):
+    physical_query_context["custom_cache_timeout"] = 5678
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 5678
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "CACHE_DEFAULT_TIMEOUT": 100000,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 3456,
+        },
+    },
+)
+def test_data_cache_default_timeout(
+    test_client,
+    login_as_admin,
+    physical_query_context,
+):
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 3456
+
+
+def test_chart_cache_timeout(
+    test_client,
+    login_as_admin,
+    physical_query_context,
+    load_energy_table_with_slice: List[Slice],
+):
+    # should override datasource cache timeout
+
+    slice_with_cache_timeout = load_energy_table_with_slice[0]
+    slice_with_cache_timeout.cache_timeout = 20
+    db.session.merge(slice_with_cache_timeout)
+
+    datasource: SqlaTable = (
+        db.session.query(SqlaTable)
+        .filter(SqlaTable.id == physical_query_context["datasource"]["id"])
+        .first()
+    )
+    datasource.cache_timeout = 1254
+    db.session.merge(datasource)
+
+    db.session.commit()
+
+    physical_query_context["form_data"] = {"slice_id": slice_with_cache_timeout.id}
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 20
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 1010,
+        },
+    },
+)
+def test_chart_cache_timeout_not_present(
+    test_client, login_as_admin, physical_query_context
+):
+    # should use datasource cache, if it's present
+
+    datasource: SqlaTable = (
+        db.session.query(SqlaTable)
+        .filter(SqlaTable.id == physical_query_context["datasource"]["id"])
+        .first()
+    )
+    datasource.cache_timeout = 1980
+    db.session.merge(datasource)
+    db.session.commit()
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1980
+
+
+@mock.patch(
+    "superset.common.query_context_processor.config",
+    {
+        **app.config,
+        "DATA_CACHE_CONFIG": {
+            **app.config["DATA_CACHE_CONFIG"],
+            "CACHE_DEFAULT_TIMEOUT": 1010,
+        },
+    },
+)
+def test_chart_cache_timeout_chart_not_found(
+    test_client, login_as_admin, physical_query_context
+):
+    # should use default timeout
+
+    physical_query_context["form_data"] = {"slice_id": 0}
+
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+    assert rv.json["result"][0]["cache_timeout"] == 1010
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (400, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (400, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=False)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_not_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
+
+
+@pytest.mark.parametrize(
+    "status_code,extras",
+    [
+        (200, {"where": "1 = 1"}),
+        (200, {"having": "count(*) > 0"}),
+        (200, {"where": "col1 in (select distinct col1 from physical_dataset)"}),
+        (200, {"having": "count(*) > (select count(*) from physical_dataset)"}),
+    ],
+)
+@with_feature_flags(ALLOW_ADHOC_SUBQUERY=True)
+@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+def test_chart_data_subquery_allowed(
+    test_client,
+    login_as_admin,
+    physical_dataset,
+    physical_query_context,
+    status_code,
+    extras,
+):
+    physical_query_context["queries"][0]["extras"] = extras
+    rv = test_client.post(CHART_DATA_URI, json=physical_query_context)
+
+    assert rv.status_code == status_code
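
Taken together, the new cache tests pin down how the effective cache_timeout reported by the chart data endpoint is chosen: an explicit custom_cache_timeout in the payload, then the chart's cache_timeout, then the datasource's, then DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"], then the global CACHE_DEFAULT_TIMEOUT. The stand-alone sketch below is inferred from those assertions (the relative order of custom_cache_timeout and the chart timeout is an assumption, since the two are never set together) and is not copied from Superset's query_context_processor.

from typing import Optional


def resolve_cache_timeout(
    custom_cache_timeout: Optional[int],      # "custom_cache_timeout" in the payload
    chart_cache_timeout: Optional[int],       # Slice.cache_timeout
    datasource_cache_timeout: Optional[int],  # SqlaTable.cache_timeout
    data_cache_default: Optional[int],        # DATA_CACHE_CONFIG["CACHE_DEFAULT_TIMEOUT"]
    cache_default: Optional[int],             # CACHE_DEFAULT_TIMEOUT
) -> Optional[int]:
    # first non-None value wins
    for candidate in (
        custom_cache_timeout,
        chart_cache_timeout,
        datasource_cache_timeout,
        data_cache_default,
        cache_default,
    ):
        if candidate is not None:
            return candidate
    return None


# mirrors test_data_cache_default_timeout: the data-cache default beats the global one
assert resolve_cache_timeout(None, None, None, 3456, 100000) == 3456
# mirrors test_chart_cache_timeout: the chart's timeout beats the datasource's
assert resolve_cache_timeout(None, 20, 1254, None, None) == 20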
@@ -261,7 +261,7 @@ class TestDatabaseModel(SupersetTestCase):
         )
         db.session.commit()

-        with pytest.raises(SupersetSecurityException):
+        with pytest.raises(QueryObjectValidationError):
             table.get_sqla_query(**base_query_obj)
         # Cleanup
         db.session.delete(table)