mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
feat: current_user_rls_rules Jinja macro (#33614)
This commit is contained in:
@@ -250,6 +250,14 @@ Will be rendered as:
|
||||
SELECT * FROM users WHERE role IN ('admin', 'viewer')
|
||||
```
|
||||
|
||||
**Current User RLS Rules**
|
||||
|
||||
The `{{ current_user_rls_rules() }}` macro returns an array of RLS rules applied to the current dataset for the logged in user.
|
||||
|
||||
If you have caching enabled in your Superset configuration, then the list of RLS Rules will be used
|
||||
by Superset when calculating the cache key. A cache key is a unique identifier that determines if there's a
|
||||
cache hit in the future and Superset can retrieve cached data.
|
||||
|
||||
**Custom URL Parameters**
|
||||
|
||||
The `{{ url_param('custom_variable') }}` macro lets you define arbitrary URL
|
||||
|
||||
@@ -22,7 +22,7 @@ import re
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from functools import lru_cache, partial
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union
|
||||
from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union
|
||||
|
||||
import dateutil
|
||||
from flask import current_app, g, has_request_context, request
|
||||
@@ -109,6 +109,7 @@ class ExtraCache:
|
||||
r"current_user_id\([^()]*\)|"
|
||||
r"current_username\([^()]*\)|"
|
||||
r"current_user_email\([^()]*\)|"
|
||||
r"current_user_rls_rules\([^()]*\)|"
|
||||
r"current_user_roles\([^()]*\)|"
|
||||
r"cache_key_wrapper\([^()]*\)|"
|
||||
r"url_param\([^()]*\)"
|
||||
@@ -118,12 +119,12 @@ class ExtraCache:
|
||||
|
||||
def __init__( # pylint: disable=too-many-arguments
|
||||
self,
|
||||
extra_cache_keys: Optional[list[Any]] = None,
|
||||
applied_filters: Optional[list[str]] = None,
|
||||
removed_filters: Optional[list[str]] = None,
|
||||
database: Optional[Database] = None,
|
||||
dialect: Optional[Dialect] = None,
|
||||
table: Optional[SqlaTable] = None,
|
||||
extra_cache_keys: list[Any] | None = None,
|
||||
applied_filters: list[str] | None = None,
|
||||
removed_filters: list[str] | None = None,
|
||||
database: Database | None = None,
|
||||
dialect: Dialect | None = None,
|
||||
table: SqlaTable | None = None,
|
||||
):
|
||||
self.extra_cache_keys = extra_cache_keys
|
||||
self.applied_filters = applied_filters if applied_filters is not None else []
|
||||
@@ -132,7 +133,7 @@ class ExtraCache:
|
||||
self.dialect = dialect
|
||||
self.table = table
|
||||
|
||||
def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]:
|
||||
def current_user_id(self, add_to_cache_keys: bool = True) -> int | None:
|
||||
"""
|
||||
Return the user ID of the user who is currently logged in.
|
||||
|
||||
@@ -146,7 +147,7 @@ class ExtraCache:
|
||||
return user_id
|
||||
return None
|
||||
|
||||
def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]:
|
||||
def current_username(self, add_to_cache_keys: bool = True) -> str | None:
|
||||
"""
|
||||
Return the username of the user who is currently logged in.
|
||||
|
||||
@@ -160,7 +161,7 @@ class ExtraCache:
|
||||
return username
|
||||
return None
|
||||
|
||||
def current_user_email(self, add_to_cache_keys: bool = True) -> Optional[str]:
|
||||
def current_user_email(self, add_to_cache_keys: bool = True) -> str | None:
|
||||
"""
|
||||
Return the email address of the user who is currently logged in.
|
||||
|
||||
@@ -193,6 +194,31 @@ class ExtraCache:
|
||||
except Exception: # pylint: disable=broad-except
|
||||
return None
|
||||
|
||||
def current_user_rls_rules(self) -> list[str] | None:
|
||||
"""
|
||||
Return the row level security rules applied to the current user and dataset.
|
||||
"""
|
||||
if not self.table:
|
||||
return None
|
||||
|
||||
rls_rules = (
|
||||
sorted(
|
||||
[
|
||||
rule["clause"]
|
||||
for rule in security_manager.get_guest_rls_filters(self.table)
|
||||
]
|
||||
)
|
||||
if security_manager.is_guest_user()
|
||||
else sorted(
|
||||
[rule.clause for rule in security_manager.get_rls_filters(self.table)]
|
||||
)
|
||||
)
|
||||
if not rls_rules:
|
||||
return None
|
||||
|
||||
self.cache_key_wrapper(json.dumps(rls_rules))
|
||||
return rls_rules
|
||||
|
||||
def cache_key_wrapper(self, key: Any) -> Any:
|
||||
"""
|
||||
Adds values to a list that is added to the query object used for calculating a
|
||||
@@ -213,10 +239,10 @@ class ExtraCache:
|
||||
def url_param(
|
||||
self,
|
||||
param: str,
|
||||
default: Optional[str] = None,
|
||||
default: str | None = None,
|
||||
add_to_cache_keys: bool = True,
|
||||
escape_result: bool = True,
|
||||
) -> Optional[str]:
|
||||
) -> str | None:
|
||||
"""
|
||||
Read a url or post parameter and use it in your SQL Lab query.
|
||||
|
||||
@@ -259,7 +285,7 @@ class ExtraCache:
|
||||
return result
|
||||
|
||||
def filter_values(
|
||||
self, column: str, default: Optional[str] = None, remove_filter: bool = False
|
||||
self, column: str, default: str | None = None, remove_filter: bool = False
|
||||
) -> list[Any]:
|
||||
"""Gets a values for a particular filter as a list
|
||||
|
||||
@@ -524,7 +550,7 @@ def validate_context_types(context: dict[str, Any]) -> dict[str, Any]:
|
||||
|
||||
|
||||
def validate_template_context(
|
||||
engine: Optional[str], context: dict[str, Any]
|
||||
engine: str | None, context: dict[str, Any]
|
||||
) -> dict[str, Any]:
|
||||
if engine and engine in context:
|
||||
# validate engine context separately to allow for engine-specific methods
|
||||
@@ -543,7 +569,7 @@ class WhereInMacro: # pylint: disable=too-few-public-methods
|
||||
def __call__(
|
||||
self,
|
||||
values: list[Any],
|
||||
mark: Optional[str] = None,
|
||||
mark: str | None = None,
|
||||
default_to_none: bool = False,
|
||||
) -> str | None:
|
||||
"""
|
||||
@@ -605,17 +631,17 @@ class BaseTemplateProcessor:
|
||||
Base class for database-specific jinja context
|
||||
"""
|
||||
|
||||
engine: Optional[str] = None
|
||||
engine: str | None = None
|
||||
|
||||
# pylint: disable=too-many-arguments
|
||||
def __init__(
|
||||
self,
|
||||
database: "Database",
|
||||
query: Optional["Query"] = None,
|
||||
table: Optional["SqlaTable"] = None,
|
||||
extra_cache_keys: Optional[list[Any]] = None,
|
||||
removed_filters: Optional[list[str]] = None,
|
||||
applied_filters: Optional[list[str]] = None,
|
||||
query: "Query" | None = None,
|
||||
table: "SqlaTable" | None = None,
|
||||
extra_cache_keys: list[Any] | None = None,
|
||||
removed_filters: list[str] | None = None,
|
||||
applied_filters: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self._database = database
|
||||
@@ -667,7 +693,7 @@ class BaseTemplateProcessor:
|
||||
|
||||
|
||||
class JinjaTemplateProcessor(BaseTemplateProcessor):
|
||||
def _parse_datetime(self, dttm: str) -> Optional[datetime]:
|
||||
def _parse_datetime(self, dttm: str) -> datetime | None:
|
||||
"""
|
||||
Try to parse a datetime and default to None in the worst case.
|
||||
|
||||
@@ -719,6 +745,9 @@ class JinjaTemplateProcessor(BaseTemplateProcessor):
|
||||
"current_user_roles": partial(
|
||||
safe_proxy, extra_cache.current_user_roles
|
||||
),
|
||||
"current_user_rls_rules": partial(
|
||||
safe_proxy, extra_cache.current_user_rls_rules
|
||||
),
|
||||
"cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper),
|
||||
"filter_values": partial(safe_proxy, extra_cache.filter_values),
|
||||
"get_filters": partial(safe_proxy, extra_cache.get_filters),
|
||||
@@ -763,14 +792,12 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _schema_table(
|
||||
table_name: str, schema: Optional[str]
|
||||
) -> tuple[str, Optional[str]]:
|
||||
def _schema_table(table_name: str, schema: str | None) -> tuple[str, str | None]:
|
||||
if "." in table_name:
|
||||
schema, table_name = table_name.split(".")
|
||||
return table_name, schema
|
||||
|
||||
def first_latest_partition(self, table_name: str) -> Optional[str]:
|
||||
def first_latest_partition(self, table_name: str) -> str | None:
|
||||
"""
|
||||
Gets the first value in the array of all latest partitions
|
||||
|
||||
@@ -782,7 +809,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor):
|
||||
latest_partitions = self.latest_partitions(table_name)
|
||||
return latest_partitions[0] if latest_partitions else None
|
||||
|
||||
def latest_partitions(self, table_name: str) -> Optional[list[str]]:
|
||||
def latest_partitions(self, table_name: str) -> list[str] | None:
|
||||
"""
|
||||
Gets the array of all latest partitions
|
||||
|
||||
@@ -864,8 +891,8 @@ def get_template_processors() -> dict[str, Any]:
|
||||
|
||||
def get_template_processor(
|
||||
database: "Database",
|
||||
table: Optional["SqlaTable"] = None,
|
||||
query: Optional["Query"] = None,
|
||||
table: "SqlaTable" | None = None,
|
||||
query: "Query" | None = None,
|
||||
**kwargs: Any,
|
||||
) -> BaseTemplateProcessor:
|
||||
if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"):
|
||||
@@ -880,9 +907,9 @@ def get_template_processor(
|
||||
def dataset_macro(
|
||||
dataset_id: int,
|
||||
include_metrics: bool = False,
|
||||
columns: Optional[list[str]] = None,
|
||||
from_dttm: Optional[datetime] = None,
|
||||
to_dttm: Optional[datetime] = None,
|
||||
columns: list[str] | None = None,
|
||||
from_dttm: datetime | None = None,
|
||||
to_dttm: datetime | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Given a dataset ID, return the SQL that represents it.
|
||||
@@ -964,7 +991,7 @@ def metric_macro(
|
||||
env: Environment,
|
||||
context: dict[str, Any],
|
||||
metric_key: str,
|
||||
dataset_id: Optional[int] = None,
|
||||
dataset_id: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Given a metric key, returns its syntax.
|
||||
|
||||
@@ -31,7 +31,12 @@ from sqlalchemy.dialects.postgresql import dialect
|
||||
|
||||
from superset import app
|
||||
from superset.commands.dataset.exceptions import DatasetNotFoundError
|
||||
from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn
|
||||
from superset.connectors.sqla.models import (
|
||||
RowLevelSecurityFilter,
|
||||
SqlaTable,
|
||||
SqlMetric,
|
||||
TableColumn,
|
||||
)
|
||||
from superset.exceptions import SupersetTemplateException
|
||||
from superset.jinja_context import (
|
||||
dataset_macro,
|
||||
@@ -46,6 +51,7 @@ from superset.jinja_context import (
|
||||
from superset.models.core import Database
|
||||
from superset.models.slice import Slice
|
||||
from superset.utils import json
|
||||
from tests.unit_tests.conftest import with_feature_flags
|
||||
|
||||
|
||||
def test_filter_values_adhoc_filters() -> None:
|
||||
@@ -355,16 +361,29 @@ def test_safe_proxy_nested_lambda() -> None:
|
||||
safe_proxy(func, {"foo": lambda: "bar"})
|
||||
|
||||
|
||||
def test_user_macros(mocker: MockerFixture):
|
||||
@pytest.mark.parametrize(
|
||||
"add_to_cache_keys,mock_cache_key_wrapper_call_count",
|
||||
[
|
||||
(True, 4),
|
||||
(False, 0),
|
||||
],
|
||||
)
|
||||
def test_user_macros(
|
||||
mocker: MockerFixture,
|
||||
add_to_cache_keys: bool,
|
||||
mock_cache_key_wrapper_call_count: int,
|
||||
):
|
||||
"""
|
||||
Test all user macros:
|
||||
- ``current_user_id``
|
||||
- ``current_username``
|
||||
- ``current_user_email``
|
||||
- ``current_user_roles``
|
||||
- ``current_user_rls_rules``
|
||||
"""
|
||||
mock_g = mocker.patch("superset.utils.core.g")
|
||||
mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles")
|
||||
mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters")
|
||||
mock_cache_key_wrapper = mocker.patch(
|
||||
"superset.jinja_context.ExtraCache.cache_key_wrapper"
|
||||
)
|
||||
@@ -372,36 +391,20 @@ def test_user_macros(mocker: MockerFixture):
|
||||
mock_g.user.username = "my_username"
|
||||
mock_g.user.email = "my_email@test.com"
|
||||
mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")]
|
||||
cache = ExtraCache()
|
||||
assert cache.current_user_id() == 1
|
||||
assert cache.current_username() == "my_username"
|
||||
assert cache.current_user_email() == "my_email@test.com"
|
||||
assert cache.current_user_roles() == ["my_role1", "my_role2"]
|
||||
assert mock_cache_key_wrapper.call_count == 4
|
||||
mock_get_user_rls.return_value = [
|
||||
RowLevelSecurityFilter(group_key="test", clause="1=1"),
|
||||
RowLevelSecurityFilter(group_key="other_test", clause="product_id=1"),
|
||||
]
|
||||
cache = ExtraCache(table=mocker.MagicMock())
|
||||
assert cache.current_user_id(add_to_cache_keys) == 1
|
||||
assert cache.current_username(add_to_cache_keys) == "my_username"
|
||||
assert cache.current_user_email(add_to_cache_keys) == "my_email@test.com"
|
||||
assert cache.current_user_roles(add_to_cache_keys) == ["my_role1", "my_role2"]
|
||||
assert mock_cache_key_wrapper.call_count == mock_cache_key_wrapper_call_count
|
||||
|
||||
mock_get_user_roles.return_value = []
|
||||
assert cache.current_user_roles() is None
|
||||
|
||||
|
||||
def test_user_macros_without_cache_key_inclusion(mocker: MockerFixture):
|
||||
"""
|
||||
Test all user macros with ``add_to_cache_keys`` set to ``False``.
|
||||
"""
|
||||
mock_g = mocker.patch("superset.utils.core.g")
|
||||
mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles")
|
||||
mock_cache_key_wrapper = mocker.patch(
|
||||
"superset.jinja_context.ExtraCache.cache_key_wrapper"
|
||||
)
|
||||
mock_g.user.id = 1
|
||||
mock_g.user.username = "my_username"
|
||||
mock_g.user.email = "my_email@test.com"
|
||||
mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")]
|
||||
cache = ExtraCache()
|
||||
assert cache.current_user_id(False) == 1
|
||||
assert cache.current_username(False) == "my_username"
|
||||
assert cache.current_user_email(False) == "my_email@test.com"
|
||||
assert cache.current_user_roles(False) == ["my_role1", "my_role2"]
|
||||
assert mock_cache_key_wrapper.call_count == 0
|
||||
# Testing {{ current_user_rls_rules() }} macro isolated and always without
|
||||
# the param because it does not support it to avoid shared cache.
|
||||
assert cache.current_user_rls_rules() == ["1=1", "product_id=1"]
|
||||
|
||||
|
||||
def test_user_macros_without_user_info(mocker: MockerFixture):
|
||||
@@ -410,11 +413,55 @@ def test_user_macros_without_user_info(mocker: MockerFixture):
|
||||
"""
|
||||
mock_g = mocker.patch("superset.utils.core.g")
|
||||
mock_g.user = None
|
||||
cache = ExtraCache(table=mocker.MagicMock())
|
||||
assert cache.current_user_id() is None
|
||||
assert cache.current_username() is None
|
||||
assert cache.current_user_email() is None
|
||||
assert cache.current_user_roles() is None
|
||||
assert cache.current_user_rls_rules() is None
|
||||
|
||||
|
||||
def test_current_user_rls_rules_with_no_table(mocker: MockerFixture):
|
||||
"""
|
||||
Test the ``current_user_rls_rules`` macro when no table is provided.
|
||||
"""
|
||||
mock_g = mocker.patch("superset.utils.core.g")
|
||||
mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters")
|
||||
mock_is_guest_user = mocker.patch("superset.security_manager.is_guest_user")
|
||||
mock_cache_key_wrapper = mocker.patch(
|
||||
"superset.jinja_context.ExtraCache.cache_key_wrapper"
|
||||
)
|
||||
mock_g.user.id = 1
|
||||
mock_g.user.username = "my_username"
|
||||
mock_g.user.email = "my_email@test.com"
|
||||
cache = ExtraCache()
|
||||
assert cache.current_user_id() == None # noqa: E711
|
||||
assert cache.current_username() == None # noqa: E711
|
||||
assert cache.current_user_email() == None # noqa: E711
|
||||
assert cache.current_user_roles() == None # noqa: E711
|
||||
assert cache.current_user_rls_rules() is None
|
||||
assert mock_cache_key_wrapper.call_count == 0
|
||||
assert mock_get_user_rls.call_count == 0
|
||||
assert mock_is_guest_user.call_count == 0
|
||||
|
||||
|
||||
@with_feature_flags(EMBEDDED_SUPERSET=True)
|
||||
def test_current_user_rls_rules_guest_user(mocker: MockerFixture):
|
||||
"""
|
||||
Test the ``current_user_rls_rules`` with an embedded user.
|
||||
"""
|
||||
mock_g = mocker.patch("superset.utils.core.g")
|
||||
mock_gg = mocker.patch("superset.tasks.utils.g")
|
||||
mock_ggg = mocker.patch("superset.security.manager.g")
|
||||
mock_get_user_rls = mocker.patch("superset.security_manager.get_guest_rls_filters")
|
||||
mock_user = mocker.MagicMock()
|
||||
mock_user.username = "my_username"
|
||||
mock_user.is_guest_user = True
|
||||
mock_user.is_anonymous = False
|
||||
mock_g.user = mock_gg.user = mock_ggg.user = mock_user
|
||||
|
||||
mock_get_user_rls.return_value = [
|
||||
{"group_key": "test", "clause": "1=1"},
|
||||
{"group_key": "other_test", "clause": "product_id=1"},
|
||||
]
|
||||
cache = ExtraCache(table=mocker.MagicMock())
|
||||
assert cache.current_user_rls_rules() == ["1=1", "product_id=1"]
|
||||
|
||||
|
||||
def test_where_in() -> None:
|
||||
|
||||
Reference in New Issue
Block a user