mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)
This commit is contained in:
@@ -116,7 +116,6 @@ from superset.superset_typing import (
|
||||
)
|
||||
from superset.utils import core as utils, json
|
||||
from superset.utils.backports import StrEnum
|
||||
from superset.utils.core import GenericDataType, is_adhoc_column, MediumText
|
||||
|
||||
config = app.config
|
||||
metadata = Model.metadata # pylint: disable=no-member
|
||||
@@ -477,7 +476,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
||||
]
|
||||
|
||||
filtered_columns: list[Column] = []
|
||||
column_types: set[GenericDataType] = set()
|
||||
column_types: set[utils.GenericDataType] = set()
|
||||
for column_ in data["columns"]:
|
||||
generic_type = column_.get("type_generic")
|
||||
if generic_type is not None:
|
||||
@@ -511,7 +510,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
|
||||
def filter_values_handler( # pylint: disable=too-many-arguments
|
||||
values: FilterValues | None,
|
||||
operator: str,
|
||||
target_generic_type: GenericDataType,
|
||||
target_generic_type: utils.GenericDataType,
|
||||
target_native_type: str | None = None,
|
||||
is_list_target: bool = False,
|
||||
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
|
||||
@@ -829,10 +828,10 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
|
||||
advanced_data_type = Column(String(255))
|
||||
groupby = Column(Boolean, default=True)
|
||||
filterable = Column(Boolean, default=True)
|
||||
description = Column(MediumText())
|
||||
description = Column(utils.MediumText())
|
||||
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
|
||||
is_dttm = Column(Boolean, default=False)
|
||||
expression = Column(MediumText())
|
||||
expression = Column(utils.MediumText())
|
||||
python_date_format = Column(String(255))
|
||||
extra = Column(Text)
|
||||
|
||||
@@ -892,21 +891,21 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
|
||||
"""
|
||||
Check if the column has a boolean datatype.
|
||||
"""
|
||||
return self.type_generic == GenericDataType.BOOLEAN
|
||||
return self.type_generic == utils.GenericDataType.BOOLEAN
|
||||
|
||||
@property
|
||||
def is_numeric(self) -> bool:
|
||||
"""
|
||||
Check if the column has a numeric datatype.
|
||||
"""
|
||||
return self.type_generic == GenericDataType.NUMERIC
|
||||
return self.type_generic == utils.GenericDataType.NUMERIC
|
||||
|
||||
@property
|
||||
def is_string(self) -> bool:
|
||||
"""
|
||||
Check if the column has a string datatype.
|
||||
"""
|
||||
return self.type_generic == GenericDataType.STRING
|
||||
return self.type_generic == utils.GenericDataType.STRING
|
||||
|
||||
@property
|
||||
def is_temporal(self) -> bool:
|
||||
@@ -918,7 +917,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
|
||||
"""
|
||||
if self.is_dttm is not None:
|
||||
return self.is_dttm
|
||||
return self.type_generic == GenericDataType.TEMPORAL
|
||||
return self.type_generic == utils.GenericDataType.TEMPORAL
|
||||
|
||||
@property
|
||||
def database(self) -> Database:
|
||||
@@ -935,7 +934,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
|
||||
@property
|
||||
def type_generic(self) -> utils.GenericDataType | None:
|
||||
if self.is_dttm:
|
||||
return GenericDataType.TEMPORAL
|
||||
return utils.GenericDataType.TEMPORAL
|
||||
|
||||
return (
|
||||
column_spec.generic_type
|
||||
@@ -1038,12 +1037,12 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
|
||||
metric_name = Column(String(255), nullable=False)
|
||||
verbose_name = Column(String(1024))
|
||||
metric_type = Column(String(32))
|
||||
description = Column(MediumText())
|
||||
description = Column(utils.MediumText())
|
||||
d3format = Column(String(128))
|
||||
currency = Column(String(128))
|
||||
warning_text = Column(Text)
|
||||
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
|
||||
expression = Column(MediumText(), nullable=False)
|
||||
expression = Column(utils.MediumText(), nullable=False)
|
||||
extra = Column(Text)
|
||||
|
||||
table: Mapped[SqlaTable] = relationship(
|
||||
@@ -1185,7 +1184,7 @@ class SqlaTable(
|
||||
)
|
||||
schema = Column(String(255))
|
||||
catalog = Column(String(256), nullable=True, default=None)
|
||||
sql = Column(MediumText())
|
||||
sql = Column(utils.MediumText())
|
||||
is_sqllab_view = Column(Boolean, default=False)
|
||||
template_params = Column(Text)
|
||||
extra = Column(Text)
|
||||
@@ -1980,10 +1979,26 @@ class SqlaTable(
|
||||
templatable_statements.append(extras["where"])
|
||||
if "having" in extras:
|
||||
templatable_statements.append(extras["having"])
|
||||
if "columns" in query_obj:
|
||||
templatable_statements += [
|
||||
c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c)
|
||||
]
|
||||
if columns := query_obj.get("columns"):
|
||||
calculated_columns: dict[str, Any] = {
|
||||
c.column_name: c.expression for c in self.columns if c.expression
|
||||
}
|
||||
for column_ in columns:
|
||||
if utils.is_adhoc_column(column_):
|
||||
templatable_statements.append(column_["sqlExpression"])
|
||||
elif isinstance(column_, str) and column_ in calculated_columns:
|
||||
templatable_statements.append(calculated_columns[column_])
|
||||
if metrics := query_obj.get("metrics"):
|
||||
metrics_by_name: dict[str, Any] = {
|
||||
m.metric_name: m.expression for m in self.metrics
|
||||
}
|
||||
for metric in metrics:
|
||||
if utils.is_adhoc_metric(metric) and (
|
||||
sql := metric.get("sqlExpression")
|
||||
):
|
||||
templatable_statements.append(sql)
|
||||
elif isinstance(metric, str) and metric in metrics_by_name:
|
||||
templatable_statements.append(metrics_by_name[metric])
|
||||
if self.is_rls_supported:
|
||||
templatable_statements += [
|
||||
f.clause for f in security_manager.get_rls_filters(self)
|
||||
@@ -2125,4 +2140,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
|
||||
secondary=RLSFilterTables,
|
||||
backref="row_level_security_filters",
|
||||
)
|
||||
clause = Column(MediumText(), nullable=False)
|
||||
clause = Column(utils.MediumText(), nullable=False)
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
# isort:skip_file
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple, Optional, Union
|
||||
from typing import Any, Literal, NamedTuple, Optional, Union
|
||||
from re import Pattern
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import Mock, patch
|
||||
import pytest
|
||||
|
||||
import numpy as np
|
||||
@@ -913,54 +915,99 @@ def test_extra_cache_keys_in_sql_expression(
|
||||
|
||||
@pytest.mark.usefixtures("app_context")
|
||||
@pytest.mark.parametrize(
|
||||
"sql_expression,expected_cache_keys,has_extra_cache_keys",
|
||||
"sql_expression,expected_cache_keys,has_extra_cache_keys,item_type",
|
||||
[
|
||||
("'{{ current_username() }}'", ["abc"], True),
|
||||
("(user != 'abc')", [], False),
|
||||
("'{{ current_username() }}'", ["abc"], True, "columns"),
|
||||
("(user != 'abc')", [], False, "columns"),
|
||||
("{{ current_user_id() }}", [1], True, "metrics"),
|
||||
("COUNT(*)", [], False, "metrics"),
|
||||
],
|
||||
)
|
||||
@patch("superset.jinja_context.get_user_id", return_value=1)
|
||||
@patch("superset.jinja_context.get_username", return_value="abc")
|
||||
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
|
||||
def test_extra_cache_keys_in_columns(
|
||||
mock_user_email,
|
||||
mock_username,
|
||||
mock_user_id,
|
||||
sql_expression,
|
||||
expected_cache_keys,
|
||||
has_extra_cache_keys,
|
||||
def test_extra_cache_keys_in_adhoc_metrics_and_columns(
|
||||
mock_username: Mock,
|
||||
mock_user_id: Mock,
|
||||
sql_expression: str,
|
||||
expected_cache_keys: list[str | None],
|
||||
has_extra_cache_keys: bool,
|
||||
item_type: Literal["columns", "metrics"],
|
||||
):
|
||||
table = SqlaTable(
|
||||
table_name="test_has_no_extra_cache_keys_table",
|
||||
sql="SELECT 'abc' as user",
|
||||
database=get_example_database(),
|
||||
)
|
||||
base_query_obj = {
|
||||
base_query_obj: dict[str, Any] = {
|
||||
"granularity": None,
|
||||
"from_dttm": None,
|
||||
"to_dttm": None,
|
||||
"groupby": [],
|
||||
"metrics": [],
|
||||
"columns": [],
|
||||
"is_timeseries": False,
|
||||
"filter": [],
|
||||
}
|
||||
|
||||
query_obj = dict(
|
||||
**base_query_obj,
|
||||
columns=[
|
||||
items: dict[str, Any] = {
|
||||
item_type: [
|
||||
{
|
||||
"label": None,
|
||||
"expressionType": "SQL",
|
||||
"sqlExpression": sql_expression,
|
||||
}
|
||||
],
|
||||
)
|
||||
}
|
||||
|
||||
query_obj = {**base_query_obj, **items}
|
||||
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
|
||||
assert extra_cache_keys == expected_cache_keys
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("app_context")
|
||||
@patch("superset.jinja_context.get_user_id", return_value=1)
|
||||
@patch("superset.jinja_context.get_username", return_value="abc")
|
||||
def test_extra_cache_keys_in_dataset_metrics_and_columns(
|
||||
mock_username: Mock,
|
||||
mock_user_id: Mock,
|
||||
):
|
||||
table = SqlaTable(
|
||||
table_name="test_has_no_extra_cache_keys_table",
|
||||
sql="SELECT 'abc' as user",
|
||||
database=get_example_database(),
|
||||
columns=[
|
||||
TableColumn(column_name="user", type="VARCHAR(255)"),
|
||||
TableColumn(
|
||||
column_name="username",
|
||||
type="VARCHAR(255)",
|
||||
expression="{{ current_username() }}",
|
||||
),
|
||||
],
|
||||
metrics=[
|
||||
SqlMetric(
|
||||
metric_name="variable_profit",
|
||||
expression="SUM(price) * {{ url_param('multiplier') }}",
|
||||
),
|
||||
],
|
||||
)
|
||||
query_obj: dict[str, Any] = {
|
||||
"granularity": None,
|
||||
"from_dttm": None,
|
||||
"to_dttm": None,
|
||||
"groupby": [],
|
||||
"columns": ["username"],
|
||||
"metrics": ["variable_profit"],
|
||||
"is_timeseries": False,
|
||||
"filter": [],
|
||||
}
|
||||
|
||||
extra_cache_keys = table.get_extra_cache_keys(query_obj)
|
||||
assert table.has_extra_cache_key_calls(query_obj) is True
|
||||
assert set(extra_cache_keys) == {"abc", None}
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("app_context")
|
||||
@pytest.mark.parametrize(
|
||||
"row,dimension,result",
|
||||
|
||||
Reference in New Issue
Block a user