fix: ensure metric_macro expands templates (#32344)

This commit is contained in:
Beto Dealmeida
2025-02-24 18:08:50 -05:00
committed by GitHub
parent b0dac046e6
commit 83071d0e5f
2 changed files with 137 additions and 85 deletions

View File

@@ -21,6 +21,8 @@ from typing import Any
import pytest
from freezegun import freeze_time
from jinja2 import DebugUndefined
from jinja2.sandbox import SandboxedEnvironment
from pytest_mock import MockerFixture
from sqlalchemy.dialects import mysql
from sqlalchemy.dialects.postgresql import dialect
@@ -32,6 +34,7 @@ from superset.exceptions import SupersetTemplateException
from superset.jinja_context import (
dataset_macro,
ExtraCache,
get_template_processor,
metric_macro,
safe_proxy,
TimeFilter,
@@ -540,7 +543,8 @@ def test_metric_macro_with_dataset_id(mocker: MockerFixture) -> None:
schema="my_schema",
sql=None,
)
assert metric_macro("count", 1) == "COUNT(*)"
env = SandboxedEnvironment(undefined=DebugUndefined)
assert metric_macro(env, {}, "count", 1) == "COUNT(*)"
mock_get_form_data.assert_not_called()
@@ -548,32 +552,64 @@ def test_metric_macro_recursive(mocker: MockerFixture) -> None:
"""
Test the ``metric_macro`` when the definition is recursive.
"""
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = SqlaTable(
table_name="test_dataset",
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
dataset = SqlaTable(
id=1,
metrics=[
SqlMetric(metric_name="a", expression="COUNT(*)"),
SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
table_name="test_dataset",
database=database,
schema="my_schema",
sql=None,
)
assert metric_macro("c", 1) == "COUNT(*)"
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = dataset
processor = get_template_processor(database=database)
assert processor.process_template("{{ metric('c', 1) }}") == "COUNT(*)"
def test_metric_macro_expansion(mocker: MockerFixture) -> None:
"""
Test that the ``metric_macro`` expands other macros.
"""
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
dataset = SqlaTable(
id=1,
metrics=[
SqlMetric(metric_name="a", expression="{{ current_user_id() }}"),
SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
],
table_name="test_dataset",
database=database,
schema="my_schema",
sql=None,
)
mocker.patch("superset.jinja_context.get_user_id", return_value=42)
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = dataset
processor = get_template_processor(database=database)
assert processor.process_template("{{ metric('c') }}") == "42"
def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None:
"""
Test the ``metric_macro`` when the definition is compound.
"""
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = SqlaTable(
table_name="test_dataset",
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
dataset = SqlaTable(
id=1,
metrics=[
SqlMetric(metric_name="a", expression="SUM(*)"),
SqlMetric(metric_name="b", expression="COUNT(*)"),
@@ -582,11 +618,19 @@ def test_metric_macro_recursive_compound(mocker: MockerFixture) -> None:
expression="{{ metric('a') }} / {{ metric('b') }}",
),
],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
table_name="test_dataset",
database=database,
schema="my_schema",
sql=None,
)
assert metric_macro("c", 1) == "SUM(*) / COUNT(*)"
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = dataset
processor = get_template_processor(database=database)
assert processor.process_template("{{ metric('c') }}") == "SUM(*) / COUNT(*)"
def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None:
@@ -595,23 +639,29 @@ def test_metric_macro_recursive_cyclic(mocker: MockerFixture) -> None:
In this case it should stop, and not go into an infinite loop.
"""
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = SqlaTable(
table_name="test_dataset",
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
dataset = SqlaTable(
id=1,
metrics=[
SqlMetric(metric_name="a", expression="{{ metric('c') }}"),
SqlMetric(metric_name="b", expression="{{ metric('a') }}"),
SqlMetric(metric_name="c", expression="{{ metric('b') }}"),
],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
table_name="test_dataset",
database=database,
schema="my_schema",
sql=None,
)
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = dataset
processor = get_template_processor(database=database)
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("c", 1)
assert str(excinfo.value) == "Cyclic metric macro detected"
processor.process_template("{{ metric('c') }}")
assert str(excinfo.value) == "Infinite recursion detected in template"
def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None:
@@ -620,21 +670,27 @@ def test_metric_macro_recursive_infinite(mocker: MockerFixture) -> None:
In this case it should stop, and not go into an infinite loop.
"""
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = SqlaTable(
table_name="test_dataset",
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
dataset = SqlaTable(
id=1,
metrics=[
SqlMetric(metric_name="a", expression="{{ metric('a') }}"),
],
database=Database(database_name="my_database", sqlalchemy_uri="sqlite://"),
table_name="test_dataset",
database=database,
schema="my_schema",
sql=None,
)
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"datasource": {"id": 1}}
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = dataset
processor = get_template_processor(database=database)
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("a", 1)
assert str(excinfo.value) == "Cyclic metric macro detected"
processor.process_template("{{ metric('a') }}")
assert str(excinfo.value) == "Infinite recursion detected in template"
def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None:
@@ -652,8 +708,9 @@ def test_metric_macro_with_dataset_id_invalid_key(mocker: MockerFixture) -> None
schema="my_schema",
sql=None,
)
env = SandboxedEnvironment(undefined=DebugUndefined)
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("blah", 1)
metric_macro(env, {}, "blah", 1)
assert str(excinfo.value) == "Metric ``blah`` not found in test_dataset."
mock_get_form_data.assert_not_called()
@@ -665,8 +722,9 @@ def test_metric_macro_invalid_dataset_id(mocker: MockerFixture) -> None:
mock_get_form_data = mocker.patch("superset.views.utils.get_form_data")
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
DatasetDAO.find_by_id.return_value = None
env = SandboxedEnvironment(undefined=DebugUndefined)
with pytest.raises(DatasetNotFoundError) as excinfo:
metric_macro("macro_key", 100)
metric_macro(env, {}, "macro_key", 100)
assert str(excinfo.value) == "Dataset ID 100 not found."
mock_get_form_data.assert_not_called()
@@ -679,9 +737,10 @@ def test_metric_macro_no_dataset_id_no_context(mocker: MockerFixture) -> None:
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {}
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context():
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -698,6 +757,8 @@ def test_metric_macro_no_dataset_id_with_context_missing_info(
DatasetDAO = mocker.patch("superset.daos.dataset.DatasetDAO") # noqa: N806
mock_g = mocker.patch("superset.jinja_context.g")
mock_g.form_data = {"queries": []}
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -716,7 +777,7 @@ def test_metric_macro_no_dataset_id_with_context_missing_info(
}
):
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -744,6 +805,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -759,7 +821,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
)
}
):
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
# Getting data from g's form_data
mock_g.form_data = {
@@ -772,7 +834,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id(
],
}
with app.test_request_context():
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
@@ -786,6 +848,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -802,7 +865,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
}
):
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -819,7 +882,7 @@ def test_metric_macro_no_dataset_id_with_context_datasource_id_none(
}
with app.test_request_context():
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -851,6 +914,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -866,7 +930,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
)
}
):
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
# Getting data from g's form_data
mock_g.form_data = {
@@ -879,7 +943,7 @@ def test_metric_macro_no_dataset_id_with_context_chart_id(
],
}
with app.test_request_context():
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
def test_metric_macro_no_dataset_id_with_context_slice_id_none(
@@ -893,6 +957,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -909,7 +974,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
}
):
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -926,7 +991,7 @@ def test_metric_macro_no_dataset_id_with_context_slice_id_none(
}
with app.test_request_context():
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -945,6 +1010,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -961,7 +1027,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
}
):
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -978,7 +1044,7 @@ def test_metric_macro_no_dataset_id_with_context_deleted_chart(
}
with app.test_request_context():
with pytest.raises(SupersetTemplateException) as excinfo:
metric_macro("macro_key")
metric_macro(env, {}, "macro_key")
assert str(excinfo.value) == (
"Please specify the Dataset ID for the ``macro_key`` metric in the Jinja macro." # noqa: E501
)
@@ -1006,6 +1072,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
mock_g.form_data = {}
# Getting the data from the request context
env = SandboxedEnvironment(undefined=DebugUndefined)
with app.test_request_context(
data={
"form_data": json.dumps(
@@ -1017,7 +1084,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
)
}
):
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
# Getting data from g's form_data
mock_g.form_data = {
@@ -1025,7 +1092,7 @@ def test_metric_macro_no_dataset_id_available_in_request_form_data(
}
with app.test_request_context():
assert metric_macro("macro_key") == "COUNT(*)"
assert metric_macro(env, {}, "macro_key") == "COUNT(*)"
@pytest.mark.parametrize(