fix(Jinja): Extra cache keys for calculated columns and metrics using Jinja (#30735)

This commit is contained in:
Vitor Avila
2024-10-29 10:14:27 -03:00
committed by GitHub
parent c03bf80864
commit 09d3f60d85
2 changed files with 98 additions and 36 deletions

View File

@@ -116,7 +116,6 @@ from superset.superset_typing import (
)
from superset.utils import core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import GenericDataType, is_adhoc_column, MediumText
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -477,7 +476,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
]
filtered_columns: list[Column] = []
column_types: set[GenericDataType] = set()
column_types: set[utils.GenericDataType] = set()
for column_ in data["columns"]:
generic_type = column_.get("type_generic")
if generic_type is not None:
@@ -511,7 +510,7 @@ class BaseDatasource(AuditMixinNullable, ImportExportMixin): # pylint: disable=
def filter_values_handler( # pylint: disable=too-many-arguments
values: FilterValues | None,
operator: str,
target_generic_type: GenericDataType,
target_generic_type: utils.GenericDataType,
target_native_type: str | None = None,
is_list_target: bool = False,
db_engine_spec: builtins.type[BaseEngineSpec] | None = None,
@@ -829,10 +828,10 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
advanced_data_type = Column(String(255))
groupby = Column(Boolean, default=True)
filterable = Column(Boolean, default=True)
description = Column(MediumText())
description = Column(utils.MediumText())
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
is_dttm = Column(Boolean, default=False)
expression = Column(MediumText())
expression = Column(utils.MediumText())
python_date_format = Column(String(255))
extra = Column(Text)
@@ -892,21 +891,21 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
"""
Check if the column has a boolean datatype.
"""
return self.type_generic == GenericDataType.BOOLEAN
return self.type_generic == utils.GenericDataType.BOOLEAN
@property
def is_numeric(self) -> bool:
"""
Check if the column has a numeric datatype.
"""
return self.type_generic == GenericDataType.NUMERIC
return self.type_generic == utils.GenericDataType.NUMERIC
@property
def is_string(self) -> bool:
"""
Check if the column has a string datatype.
"""
return self.type_generic == GenericDataType.STRING
return self.type_generic == utils.GenericDataType.STRING
@property
def is_temporal(self) -> bool:
@@ -918,7 +917,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
"""
if self.is_dttm is not None:
return self.is_dttm
return self.type_generic == GenericDataType.TEMPORAL
return self.type_generic == utils.GenericDataType.TEMPORAL
@property
def database(self) -> Database:
@@ -935,7 +934,7 @@ class TableColumn(AuditMixinNullable, ImportExportMixin, CertificationMixin, Mod
@property
def type_generic(self) -> utils.GenericDataType | None:
if self.is_dttm:
return GenericDataType.TEMPORAL
return utils.GenericDataType.TEMPORAL
return (
column_spec.generic_type
@@ -1038,12 +1037,12 @@ class SqlMetric(AuditMixinNullable, ImportExportMixin, CertificationMixin, Model
metric_name = Column(String(255), nullable=False)
verbose_name = Column(String(1024))
metric_type = Column(String(32))
description = Column(MediumText())
description = Column(utils.MediumText())
d3format = Column(String(128))
currency = Column(String(128))
warning_text = Column(Text)
table_id = Column(Integer, ForeignKey("tables.id", ondelete="CASCADE"))
expression = Column(MediumText(), nullable=False)
expression = Column(utils.MediumText(), nullable=False)
extra = Column(Text)
table: Mapped[SqlaTable] = relationship(
@@ -1185,7 +1184,7 @@ class SqlaTable(
)
schema = Column(String(255))
catalog = Column(String(256), nullable=True, default=None)
sql = Column(MediumText())
sql = Column(utils.MediumText())
is_sqllab_view = Column(Boolean, default=False)
template_params = Column(Text)
extra = Column(Text)
@@ -1980,10 +1979,26 @@ class SqlaTable(
templatable_statements.append(extras["where"])
if "having" in extras:
templatable_statements.append(extras["having"])
if "columns" in query_obj:
templatable_statements += [
c["sqlExpression"] for c in query_obj["columns"] if is_adhoc_column(c)
]
if columns := query_obj.get("columns"):
calculated_columns: dict[str, Any] = {
c.column_name: c.expression for c in self.columns if c.expression
}
for column_ in columns:
if utils.is_adhoc_column(column_):
templatable_statements.append(column_["sqlExpression"])
elif isinstance(column_, str) and column_ in calculated_columns:
templatable_statements.append(calculated_columns[column_])
if metrics := query_obj.get("metrics"):
metrics_by_name: dict[str, Any] = {
m.metric_name: m.expression for m in self.metrics
}
for metric in metrics:
if utils.is_adhoc_metric(metric) and (
sql := metric.get("sqlExpression")
):
templatable_statements.append(sql)
elif isinstance(metric, str) and metric in metrics_by_name:
templatable_statements.append(metrics_by_name[metric])
if self.is_rls_supported:
templatable_statements += [
f.clause for f in security_manager.get_rls_filters(self)
@@ -2125,4 +2140,4 @@ class RowLevelSecurityFilter(Model, AuditMixinNullable):
secondary=RLSFilterTables,
backref="row_level_security_filters",
)
clause = Column(MediumText(), nullable=False)
clause = Column(utils.MediumText(), nullable=False)

View File

@@ -15,11 +15,13 @@
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
from __future__ import annotations
import re
from datetime import datetime
from typing import Any, NamedTuple, Optional, Union
from typing import Any, Literal, NamedTuple, Optional, Union
from re import Pattern
from unittest.mock import patch
from unittest.mock import Mock, patch
import pytest
import numpy as np
@@ -913,54 +915,99 @@ def test_extra_cache_keys_in_sql_expression(
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"sql_expression,expected_cache_keys,has_extra_cache_keys",
"sql_expression,expected_cache_keys,has_extra_cache_keys,item_type",
[
("'{{ current_username() }}'", ["abc"], True),
("(user != 'abc')", [], False),
("'{{ current_username() }}'", ["abc"], True, "columns"),
("(user != 'abc')", [], False, "columns"),
("{{ current_user_id() }}", [1], True, "metrics"),
("COUNT(*)", [], False, "metrics"),
],
)
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
@patch("superset.jinja_context.get_user_email", return_value="abc@test.com")
def test_extra_cache_keys_in_columns(
mock_user_email,
mock_username,
mock_user_id,
sql_expression,
expected_cache_keys,
has_extra_cache_keys,
def test_extra_cache_keys_in_adhoc_metrics_and_columns(
mock_username: Mock,
mock_user_id: Mock,
sql_expression: str,
expected_cache_keys: list[str | None],
has_extra_cache_keys: bool,
item_type: Literal["columns", "metrics"],
):
table = SqlaTable(
table_name="test_has_no_extra_cache_keys_table",
sql="SELECT 'abc' as user",
database=get_example_database(),
)
base_query_obj = {
base_query_obj: dict[str, Any] = {
"granularity": None,
"from_dttm": None,
"to_dttm": None,
"groupby": [],
"metrics": [],
"columns": [],
"is_timeseries": False,
"filter": [],
}
query_obj = dict(
**base_query_obj,
columns=[
items: dict[str, Any] = {
item_type: [
{
"label": None,
"expressionType": "SQL",
"sqlExpression": sql_expression,
}
],
)
}
query_obj = {**base_query_obj, **items}
extra_cache_keys = table.get_extra_cache_keys(query_obj)
assert table.has_extra_cache_key_calls(query_obj) == has_extra_cache_keys
assert extra_cache_keys == expected_cache_keys
@pytest.mark.usefixtures("app_context")
@patch("superset.jinja_context.get_user_id", return_value=1)
@patch("superset.jinja_context.get_username", return_value="abc")
def test_extra_cache_keys_in_dataset_metrics_and_columns(
    mock_username: Mock,
    mock_user_id: Mock,
):
    """
    Verify that Jinja templates embedded in *saved* (dataset-level) calculated
    columns and metrics contribute extra cache keys when the query references
    them by name (as plain strings), not only when they are ad-hoc expressions.

    The dataset defines:
      - a calculated column ``username`` whose expression is
        ``{{ current_username() }}``, and
      - a metric ``variable_profit`` whose expression uses
        ``{{ url_param('multiplier') }}``.
    The query object selects both by name, so the templated expressions must
    be scanned for cache-key-affecting Jinja calls.
    """
    # Dataset with one plain column, one Jinja calculated column, and one
    # Jinja-templated metric.
    table = SqlaTable(
        table_name="test_has_no_extra_cache_keys_table",
        sql="SELECT 'abc' as user",
        database=get_example_database(),
        columns=[
            TableColumn(column_name="user", type="VARCHAR(255)"),
            TableColumn(
                column_name="username",
                type="VARCHAR(255)",
                expression="{{ current_username() }}",
            ),
        ],
        metrics=[
            SqlMetric(
                metric_name="variable_profit",
                expression="SUM(price) * {{ url_param('multiplier') }}",
            ),
        ],
    )
    # Columns/metrics are referenced by name (strings), not as ad-hoc dicts —
    # this is the case the fix adds coverage for.
    query_obj: dict[str, Any] = {
        "granularity": None,
        "from_dttm": None,
        "to_dttm": None,
        "groupby": [],
        "columns": ["username"],
        "metrics": ["variable_profit"],
        "is_timeseries": False,
        "filter": [],
    }
    extra_cache_keys = table.get_extra_cache_keys(query_obj)
    assert table.has_extra_cache_key_calls(query_obj) is True
    # "abc" comes from the patched get_username(); None is the value the
    # url_param('multiplier') call yields here — presumably because no
    # request/URL parameter is present in the test context (NOTE(review):
    # confirm against jinja_context.url_param's default behavior).
    assert set(extra_cache_keys) == {"abc", None}
@pytest.mark.usefixtures("app_context")
@pytest.mark.parametrize(
"row,dimension,result",