feat: current_user_rls_rules Jinja macro (#33614)

This commit is contained in:
Vitor Avila
2025-05-29 11:58:40 -03:00
committed by GitHub
parent e20a08cb14
commit fdea4e21b0
3 changed files with 150 additions and 68 deletions

View File

@@ -250,6 +250,14 @@ Will be rendered as:
SELECT * FROM users WHERE role IN ('admin', 'viewer')
```
**Current User RLS Rules**
The `{{ current_user_rls_rules() }}` macro returns an array of RLS rules applied to the current dataset for the logged in user.
If you have caching enabled in your Superset configuration, then the list of RLS Rules will be used
by Superset when calculating the cache key. A cache key is a unique identifier that determines if there's a
cache hit in the future and Superset can retrieve cached data.
**Custom URL Parameters**
The `{{ url_param('custom_variable') }}` macro lets you define arbitrary URL

View File

@@ -22,7 +22,7 @@ import re
from dataclasses import dataclass
from datetime import datetime
from functools import lru_cache, partial
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
import dateutil
from flask import current_app, g, has_request_context, request
@@ -109,6 +109,7 @@ class ExtraCache:
r"current_user_id\([^()]*\)|"
r"current_username\([^()]*\)|"
r"current_user_email\([^()]*\)|"
r"current_user_rls_rules\([^()]*\)|"
r"current_user_roles\([^()]*\)|"
r"cache_key_wrapper\([^()]*\)|"
r"url_param\([^()]*\)"
@@ -118,12 +119,12 @@ class ExtraCache:
def __init__( # pylint: disable=too-many-arguments
self,
extra_cache_keys: Optional[list[Any]] = None,
applied_filters: Optional[list[str]] = None,
removed_filters: Optional[list[str]] = None,
database: Optional[Database] = None,
dialect: Optional[Dialect] = None,
table: Optional[SqlaTable] = None,
extra_cache_keys: list[Any] | None = None,
applied_filters: list[str] | None = None,
removed_filters: list[str] | None = None,
database: Database | None = None,
dialect: Dialect | None = None,
table: SqlaTable | None = None,
):
self.extra_cache_keys = extra_cache_keys
self.applied_filters = applied_filters if applied_filters is not None else []
@@ -132,7 +133,7 @@ class ExtraCache:
self.dialect = dialect
self.table = table
def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
def current_user_id(self, add_to_cache_keys: bool = True) -> int | None:
"""
Return the user ID of the user who is currently logged in.
@@ -146,7 +147,7 @@ class ExtraCache:
return user_id
return None
def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
def current_username(self, add_to_cache_keys: bool = True) -> str | None:
"""
Return the username of the user who is currently logged in.
@@ -160,7 +161,7 @@ class ExtraCache:
return username
return None
def current_user_email(self, add_to_cache_keys: bool = True) -> Optional[str]:
def current_user_email(self, add_to_cache_keys: bool = True) -> str | None:
"""
Return the email address of the user who is currently logged in.
@@ -193,6 +194,31 @@ class ExtraCache:
except Exception: # pylint: disable=broad-except
return None
def current_user_rls_rules(self) -> list[str] | None:
"""
Return the row level security rules applied to the current user and dataset.
"""
if not self.table:
return None
rls_rules = (
sorted(
[
rule["clause"]
for rule in security_manager.get_guest_rls_filters(self.table)
]
)
if security_manager.is_guest_user()
else sorted(
[rule.clause for rule in security_manager.get_rls_filters(self.table)]
)
)
if not rls_rules:
return None
self.cache_key_wrapper(json.dumps(rls_rules))
return rls_rules
def cache_key_wrapper(self, key: Any) -> Any:
"""
Adds values to a list that is added to the query object used for calculating a
@@ -213,10 +239,10 @@ class ExtraCache:
def url_param(
self,
param: str,
default: Optional[str] = None,
default: str | None = None,
add_to_cache_keys: bool = True,
escape_result: bool = True,
) -> Optional[str]:
) -> str | None:
"""
Read a url or post parameter and use it in your SQL Lab query.
@@ -259,7 +285,7 @@ class ExtraCache:
return result
def filter_values(
self, column: str, default: Optional[str] = None, remove_filter: bool = False
self, column: str, default: str | None = None, remove_filter: bool = False
) -> list[Any]:
"""Gets a values for a particular filter as a list
@@ -524,7 +550,7 @@ def validate_context_types(context: dict[str, Any]) -> dict[str, Any]:
def validate_template_context(
engine: Optional[str], context: dict[str, Any]
engine: str | None, context: dict[str, Any]
) -> dict[str, Any]:
if engine and engine in context:
# validate engine context separately to allow for engine-specific methods
@@ -543,7 +569,7 @@ class WhereInMacro: # pylint: disable=too-few-public-methods
def __call__(
self,
values: list[Any],
mark: Optional[str] = None,
mark: str | None = None,
default_to_none: bool = False,
) -> str | None:
"""
@@ -605,17 +631,17 @@ class BaseTemplateProcessor:
Base class for database-specific jinja context
"""
engine: Optional[str] = None
engine: str | None = None
# pylint: disable=too-many-arguments
def __init__(
self,
database: "Database",
query: Optional["Query"] = None,
table: Optional["SqlaTable"] = None,
extra_cache_keys: Optional[list[Any]] = None,
removed_filters: Optional[list[str]] = None,
applied_filters: Optional[list[str]] = None,
query: "Query" | None = None,
table: "SqlaTable" | None = None,
extra_cache_keys: list[Any] | None = None,
removed_filters: list[str] | None = None,
applied_filters: list[str] | None = None,
**kwargs: Any,
) -> None:
self._database = database
@@ -667,7 +693,7 @@ class BaseTemplateProcessor:
class JinjaTemplateProcessor(BaseTemplateProcessor):
def _parse_datetime(self, dttm: str) -> Optional[datetime]:
def _parse_datetime(self, dttm: str) -> datetime | None:
"""
Try to parse a datetime and default to None in the worst case.
@@ -719,6 +745,9 @@ class JinjaTemplateProcessor(BaseTemplateProcessor):
"current_user_roles": partial(
safe_proxy, extra_cache.current_user_roles
),
"current_user_rls_rules": partial(
safe_proxy, extra_cache.current_user_rls_rules
),
"cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper),
"filter_values": partial(safe_proxy, extra_cache.filter_values),
"get_filters": partial(safe_proxy, extra_cache.get_filters),
@@ -763,14 +792,12 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
}
@staticmethod
def _schema_table(
table_name: str, schema: Optional[str]
) -> tuple[str, Optional[str]]:
def _schema_table(table_name: str, schema: str | None) -> tuple[str, str | None]:
if "." in table_name:
schema, table_name = table_name.split(".")
return table_name, schema
def first_latest_partition(self, table_name: str) -> Optional[str]:
def first_latest_partition(self, table_name: str) -> str | None:
"""
Gets the first value in the array of all latest partitions
@@ -782,7 +809,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
latest_partitions = self.latest_partitions(table_name)
return latest_partitions[0] if latest_partitions else None
def latest_partitions(self, table_name: str) -> Optional[list[str]]:
def latest_partitions(self, table_name: str) -> list[str] | None:
"""
Gets the array of all latest partitions
@@ -864,8 +891,8 @@ def get_template_processors() -> dict[str, Any]:
def get_template_processor(
database: "Database",
table: Optional["SqlaTable"] = None,
query: Optional["Query"] = None,
table: "SqlaTable" | None = None,
query: "Query" | None = None,
**kwargs: Any,
) -> BaseTemplateProcessor:
if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
@@ -880,9 +907,9 @@ def get_template_processor(
def dataset_macro(
dataset_id: int,
include_metrics: bool = False,
columns: Optional[list[str]] = None,
from_dttm: Optional[datetime] = None,
to_dttm: Optional[datetime] = None,
columns: list[str] | None = None,
from_dttm: datetime | None = None,
to_dttm: datetime | None = None,
) -> str:
"""
Given a dataset ID, return the SQL that represents it.
@@ -964,7 +991,7 @@ def metric_macro(
env: Environment,
context: dict[str, Any],
metric_key: str,
dataset_id: Optional[int] = None,
dataset_id: int | None = None,
) -> str:
"""
Given a metric key, returns its syntax.

View File

@@ -31,7 +31,12 @@ from sqlalchemy.dialects.postgresql import dialect
from superset import app
from superset.commands.dataset.exceptions import DatasetNotFoundError
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
from superset.connectors.sqla.models import (
RowLevelSecurityFilter,
SqlaTable,
SqlMetric,
TableColumn,
)
from superset.exceptions import SupersetTemplateException
from superset.jinja_context import (
dataset_macro,
@@ -46,6 +51,7 @@ from superset.jinja_context import (
from superset.models.core import Database
from superset.models.slice import Slice
from superset.utils import json
from tests.unit_tests.conftest import with_feature_flags
def test_filter_values_adhoc_filters() -> None:
@@ -355,16 +361,29 @@ def test_safe_proxy_nested_lambda() -> None:
safe_proxy(func, {"foo": lambda: "bar"})
def test_user_macros(mocker: MockerFixture):
@pytest.mark.parametrize(
"add_to_cache_keys,mock_cache_key_wrapper_call_count",
[
(True, 4),
(False, 0),
],
)
def test_user_macros(
mocker: MockerFixture,
add_to_cache_keys: bool,
mock_cache_key_wrapper_call_count: int,
):
"""
Test all user macros:
- ``current_user_id``
- ``current_username``
- ``current_user_email``
- ``current_user_roles``
- ``current_user_rls_rules``
"""
mock_g = mocker.patch("superset.utils.core.g")
mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles")
mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters")
mock_cache_key_wrapper = mocker.patch(
"superset.jinja_context.ExtraCache.cache_key_wrapper"
)
@@ -372,36 +391,20 @@ def test_user_macros(mocker: MockerFixture):
mock_g.user.username = "my_username"
mock_g.user.email = "my_email@test.com"
mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")]
cache = ExtraCache()
assert cache.current_user_id() == 1
assert cache.current_username() == "my_username"
assert cache.current_user_email() == "my_email@test.com"
assert cache.current_user_roles() == ["my_role1", "my_role2"]
assert mock_cache_key_wrapper.call_count == 4
mock_get_user_rls.return_value = [
RowLevelSecurityFilter(group_key="test", clause="1=1"),
RowLevelSecurityFilter(group_key="other_test", clause="product_id=1"),
]
cache = ExtraCache(table=mocker.MagicMock())
assert cache.current_user_id(add_to_cache_keys) == 1
assert cache.current_username(add_to_cache_keys) == "my_username"
assert cache.current_user_email(add_to_cache_keys) == "my_email@test.com"
assert cache.current_user_roles(add_to_cache_keys) == ["my_role1", "my_role2"]
assert mock_cache_key_wrapper.call_count == mock_cache_key_wrapper_call_count
mock_get_user_roles.return_value = []
assert cache.current_user_roles() is None
def test_user_macros_without_cache_key_inclusion(mocker: MockerFixture):
"""
Test all user macros with ``add_to_cache_keys`` set to ``False``.
"""
mock_g = mocker.patch("superset.utils.core.g")
mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles")
mock_cache_key_wrapper = mocker.patch(
"superset.jinja_context.ExtraCache.cache_key_wrapper"
)
mock_g.user.id = 1
mock_g.user.username = "my_username"
mock_g.user.email = "my_email@test.com"
mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")]
cache = ExtraCache()
assert cache.current_user_id(False) == 1
assert cache.current_username(False) == "my_username"
assert cache.current_user_email(False) == "my_email@test.com"
assert cache.current_user_roles(False) == ["my_role1", "my_role2"]
assert mock_cache_key_wrapper.call_count == 0
# Testing {{ current_user_rls_rules() }} macro isolated and always without
# the param because it does not support it to avoid shared cache.
assert cache.current_user_rls_rules() == ["1=1", "product_id=1"]
def test_user_macros_without_user_info(mocker: MockerFixture):
@@ -410,11 +413,55 @@ def test_user_macros_without_user_info(mocker: MockerFixture):
"""
mock_g = mocker.patch("superset.utils.core.g")
mock_g.user = None
cache = ExtraCache(table=mocker.MagicMock())
assert cache.current_user_id() is None
assert cache.current_username() is None
assert cache.current_user_email() is None
assert cache.current_user_roles() is None
assert cache.current_user_rls_rules() is None
def test_current_user_rls_rules_with_no_table(mocker: MockerFixture):
"""
Test the ``current_user_rls_rules`` macro when no table is provided.
"""
mock_g = mocker.patch("superset.utils.core.g")
mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters")
mock_is_guest_user = mocker.patch("superset.security_manager.is_guest_user")
mock_cache_key_wrapper = mocker.patch(
"superset.jinja_context.ExtraCache.cache_key_wrapper"
)
mock_g.user.id = 1
mock_g.user.username = "my_username"
mock_g.user.email = "my_email@test.com"
cache = ExtraCache()
assert cache.current_user_id() == None # noqa: E711
assert cache.current_username() == None # noqa: E711
assert cache.current_user_email() == None # noqa: E711
assert cache.current_user_roles() == None # noqa: E711
assert cache.current_user_rls_rules() is None
assert mock_cache_key_wrapper.call_count == 0
assert mock_get_user_rls.call_count == 0
assert mock_is_guest_user.call_count == 0
@with_feature_flags(EMBEDDED_SUPERSET=True)
def test_current_user_rls_rules_guest_user(mocker: MockerFixture):
"""
Test the ``current_user_rls_rules`` with an embedded user.
"""
mock_g = mocker.patch("superset.utils.core.g")
mock_gg = mocker.patch("superset.tasks.utils.g")
mock_ggg = mocker.patch("superset.security.manager.g")
mock_get_user_rls = mocker.patch("superset.security_manager.get_guest_rls_filters")
mock_user = mocker.MagicMock()
mock_user.username = "my_username"
mock_user.is_guest_user = True
mock_user.is_anonymous = False
mock_g.user = mock_gg.user = mock_ggg.user = mock_user
mock_get_user_rls.return_value = [
{"group_key": "test", "clause": "1=1"},
{"group_key": "other_test", "clause": "product_id=1"},
]
cache = ExtraCache(table=mocker.MagicMock())
assert cache.current_user_rls_rules() == ["1=1", "product_id=1"]
def test_where_in() -> None: