diff --git a/docs/docs/configuration/sql-templating.mdx b/docs/docs/configuration/sql-templating.mdx index dcd5bb0869b..d2b74afa1ab 100644 --- a/docs/docs/configuration/sql-templating.mdx +++ b/docs/docs/configuration/sql-templating.mdx @@ -250,6 +250,14 @@ Will be rendered as: SELECT * FROM users WHERE role IN ('admin', 'viewer') ``` +**Current User RLS Rules** + +The `{{ current_user_rls_rules() }}` macro returns an array of RLS rules applied to the current dataset for the logged in user. + +If you have caching enabled in your Superset configuration, then the list of RLS Rules will be used +by Superset when calculating the cache key. A cache key is a unique identifier that determines if there's a +cache hit in the future and Superset can retrieve cached data. + **Custom URL Parameters** The `{{ url_param('custom_variable') }}` macro lets you define arbitrary URL diff --git a/superset/jinja_context.py b/superset/jinja_context.py index 9765ccbdc1a..837150f0333 100644 --- a/superset/jinja_context.py +++ b/superset/jinja_context.py @@ -22,7 +22,7 @@ import re from dataclasses import dataclass from datetime import datetime from functools import lru_cache, partial -from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypedDict, Union +from typing import Any, Callable, cast, TYPE_CHECKING, TypedDict, Union import dateutil from flask import current_app, g, has_request_context, request @@ -109,6 +109,7 @@ class ExtraCache: r"current_user_id\([^()]*\)|" r"current_username\([^()]*\)|" r"current_user_email\([^()]*\)|" + r"current_user_rls_rules\([^()]*\)|" r"current_user_roles\([^()]*\)|" r"cache_key_wrapper\([^()]*\)|" r"url_param\([^()]*\)" @@ -118,12 +119,12 @@ class ExtraCache: def __init__( # pylint: disable=too-many-arguments self, - extra_cache_keys: Optional[list[Any]] = None, - applied_filters: Optional[list[str]] = None, - removed_filters: Optional[list[str]] = None, - database: Optional[Database] = None, - dialect: Optional[Dialect] = None, - table: Optional[SqlaTable] = None, + extra_cache_keys: list[Any] | None = None, + applied_filters: list[str] | None = None, + removed_filters: list[str] | None = None, + database: Database | None = None, + dialect: Dialect | None = None, + table: SqlaTable | None = None, ): self.extra_cache_keys = extra_cache_keys self.applied_filters = applied_filters if applied_filters is not None else [] @@ -132,7 +133,7 @@ class ExtraCache: self.dialect = dialect self.table = table - def current_user_id(self, add_to_cache_keys: bool = True) -> Optional[int]: + def current_user_id(self, add_to_cache_keys: bool = True) -> int | None: """ Return the user ID of the user who is currently logged in. @@ -146,7 +147,7 @@ class ExtraCache: return user_id return None - def current_username(self, add_to_cache_keys: bool = True) -> Optional[str]: + def current_username(self, add_to_cache_keys: bool = True) -> str | None: """ Return the username of the user who is currently logged in. @@ -160,7 +161,7 @@ class ExtraCache: return username return None - def current_user_email(self, add_to_cache_keys: bool = True) -> Optional[str]: + def current_user_email(self, add_to_cache_keys: bool = True) -> str | None: """ Return the email address of the user who is currently logged in. @@ -193,6 +194,31 @@ class ExtraCache: except Exception: # pylint: disable=broad-except return None + def current_user_rls_rules(self) -> list[str] | None: + """ + Return the row level security rules applied to the current user and dataset. + """ + if not self.table: + return None + + rls_rules = ( + sorted( + [ + rule["clause"] + for rule in security_manager.get_guest_rls_filters(self.table) + ] + ) + if security_manager.is_guest_user() + else sorted( + [rule.clause for rule in security_manager.get_rls_filters(self.table)] + ) + ) + if not rls_rules: + return None + + self.cache_key_wrapper(json.dumps(rls_rules)) + return rls_rules + def cache_key_wrapper(self, key: Any) -> Any: """ Adds values to a list that is added to the query object used for calculating a @@ -213,10 +239,10 @@ class ExtraCache: def url_param( self, param: str, - default: Optional[str] = None, + default: str | None = None, add_to_cache_keys: bool = True, escape_result: bool = True, - ) -> Optional[str]: + ) -> str | None: """ Read a url or post parameter and use it in your SQL Lab query. @@ -259,7 +285,7 @@ class ExtraCache: return result def filter_values( - self, column: str, default: Optional[str] = None, remove_filter: bool = False + self, column: str, default: str | None = None, remove_filter: bool = False ) -> list[Any]: """Gets a values for a particular filter as a list @@ -524,7 +550,7 @@ def validate_context_types(context: dict[str, Any]) -> dict[str, Any]: def validate_template_context( - engine: Optional[str], context: dict[str, Any] + engine: str | None, context: dict[str, Any] ) -> dict[str, Any]: if engine and engine in context: # validate engine context separately to allow for engine-specific methods @@ -543,7 +569,7 @@ class WhereInMacro: # pylint: disable=too-few-public-methods def __call__( self, values: list[Any], - mark: Optional[str] = None, + mark: str | None = None, default_to_none: bool = False, ) -> str | None: """ @@ -605,17 +631,17 @@ class BaseTemplateProcessor: Base class for database-specific jinja context """ - engine: Optional[str] = None + engine: str | None = None # pylint: disable=too-many-arguments def __init__( self, database: "Database", - query: Optional["Query"] = None, - table: Optional["SqlaTable"] = None, - extra_cache_keys: Optional[list[Any]] = None, - removed_filters: Optional[list[str]] = None, - applied_filters: Optional[list[str]] = None, + query: "Query" | None = None, + table: "SqlaTable" | None = None, + extra_cache_keys: list[Any] | None = None, + removed_filters: list[str] | None = None, + applied_filters: list[str] | None = None, **kwargs: Any, ) -> None: self._database = database @@ -667,7 +693,7 @@ class BaseTemplateProcessor: class JinjaTemplateProcessor(BaseTemplateProcessor): - def _parse_datetime(self, dttm: str) -> Optional[datetime]: + def _parse_datetime(self, dttm: str) -> datetime | None: """ Try to parse a datetime and default to None in the worst case. @@ -719,6 +745,9 @@ class JinjaTemplateProcessor(BaseTemplateProcessor): "current_user_roles": partial( safe_proxy, extra_cache.current_user_roles ), + "current_user_rls_rules": partial( + safe_proxy, extra_cache.current_user_rls_rules + ), "cache_key_wrapper": partial(safe_proxy, extra_cache.cache_key_wrapper), "filter_values": partial(safe_proxy, extra_cache.filter_values), "get_filters": partial(safe_proxy, extra_cache.get_filters), @@ -763,14 +792,12 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): } @staticmethod - def _schema_table( - table_name: str, schema: Optional[str] - ) -> tuple[str, Optional[str]]: + def _schema_table(table_name: str, schema: str | None) -> tuple[str, str | None]: if "." in table_name: schema, table_name = table_name.split(".") return table_name, schema - def first_latest_partition(self, table_name: str) -> Optional[str]: + def first_latest_partition(self, table_name: str) -> str | None: """ Gets the first value in the array of all latest partitions @@ -782,7 +809,7 @@ class PrestoTemplateProcessor(JinjaTemplateProcessor): latest_partitions = self.latest_partitions(table_name) return latest_partitions[0] if latest_partitions else None - def latest_partitions(self, table_name: str) -> Optional[list[str]]: + def latest_partitions(self, table_name: str) -> list[str] | None: """ Gets the array of all latest partitions @@ -864,8 +891,8 @@ def get_template_processors() -> dict[str, Any]: def get_template_processor( database: "Database", - table: Optional["SqlaTable"] = None, - query: Optional["Query"] = None, + table: "SqlaTable" | None = None, + query: "Query" | None = None, **kwargs: Any, ) -> BaseTemplateProcessor: if feature_flag_manager.is_feature_enabled("ENABLE_TEMPLATE_PROCESSING"): @@ -880,9 +907,9 @@ def get_template_processor( def dataset_macro( dataset_id: int, include_metrics: bool = False, - columns: Optional[list[str]] = None, - from_dttm: Optional[datetime] = None, - to_dttm: Optional[datetime] = None, + columns: list[str] | None = None, + from_dttm: datetime | None = None, + to_dttm: datetime | None = None, ) -> str: """ Given a dataset ID, return the SQL that represents it. @@ -964,7 +991,7 @@ def metric_macro( env: Environment, context: dict[str, Any], metric_key: str, - dataset_id: Optional[int] = None, + dataset_id: int | None = None, ) -> str: """ Given a metric key, returns its syntax. diff --git a/tests/unit_tests/jinja_context_test.py b/tests/unit_tests/jinja_context_test.py index fa79cc04936..4cb8cbec2d2 100644 --- a/tests/unit_tests/jinja_context_test.py +++ b/tests/unit_tests/jinja_context_test.py @@ -31,7 +31,12 @@ from sqlalchemy.dialects.postgresql import dialect from superset import app from superset.commands.dataset.exceptions import DatasetNotFoundError -from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn +from superset.connectors.sqla.models import ( + RowLevelSecurityFilter, + SqlaTable, + SqlMetric, + TableColumn, +) from superset.exceptions import SupersetTemplateException from superset.jinja_context import ( dataset_macro, @@ -46,6 +51,7 @@ from superset.jinja_context import ( from superset.models.core import Database from superset.models.slice import Slice from superset.utils import json +from tests.unit_tests.conftest import with_feature_flags def test_filter_values_adhoc_filters() -> None: @@ -355,16 +361,29 @@ def test_safe_proxy_nested_lambda() -> None: safe_proxy(func, {"foo": lambda: "bar"}) -def test_user_macros(mocker: MockerFixture): +@pytest.mark.parametrize( + "add_to_cache_keys,mock_cache_key_wrapper_call_count", + [ + (True, 4), + (False, 0), + ], +) +def test_user_macros( + mocker: MockerFixture, + add_to_cache_keys: bool, + mock_cache_key_wrapper_call_count: int, +): """ Test all user macros: - ``current_user_id`` - ``current_username`` - ``current_user_email`` - ``current_user_roles`` + - ``current_user_rls_rules`` """ mock_g = mocker.patch("superset.utils.core.g") mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles") + mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters") mock_cache_key_wrapper = mocker.patch( "superset.jinja_context.ExtraCache.cache_key_wrapper" ) @@ -372,36 +391,20 @@ def test_user_macros(mocker: MockerFixture): mock_g.user.username = "my_username" mock_g.user.email = "my_email@test.com" mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")] - cache = ExtraCache() - assert cache.current_user_id() == 1 - assert cache.current_username() == "my_username" - assert cache.current_user_email() == "my_email@test.com" - assert cache.current_user_roles() == ["my_role1", "my_role2"] - assert mock_cache_key_wrapper.call_count == 4 + mock_get_user_rls.return_value = [ + RowLevelSecurityFilter(group_key="test", clause="1=1"), + RowLevelSecurityFilter(group_key="other_test", clause="product_id=1"), + ] + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_id(add_to_cache_keys) == 1 + assert cache.current_username(add_to_cache_keys) == "my_username" + assert cache.current_user_email(add_to_cache_keys) == "my_email@test.com" + assert cache.current_user_roles(add_to_cache_keys) == ["my_role1", "my_role2"] + assert mock_cache_key_wrapper.call_count == mock_cache_key_wrapper_call_count - mock_get_user_roles.return_value = [] - assert cache.current_user_roles() is None - - -def test_user_macros_without_cache_key_inclusion(mocker: MockerFixture): - """ - Test all user macros with ``add_to_cache_keys`` set to ``False``. - """ - mock_g = mocker.patch("superset.utils.core.g") - mock_get_user_roles = mocker.patch("superset.security_manager.get_user_roles") - mock_cache_key_wrapper = mocker.patch( - "superset.jinja_context.ExtraCache.cache_key_wrapper" - ) - mock_g.user.id = 1 - mock_g.user.username = "my_username" - mock_g.user.email = "my_email@test.com" - mock_get_user_roles.return_value = [Role(name="my_role1"), Role(name="my_role2")] - cache = ExtraCache() - assert cache.current_user_id(False) == 1 - assert cache.current_username(False) == "my_username" - assert cache.current_user_email(False) == "my_email@test.com" - assert cache.current_user_roles(False) == ["my_role1", "my_role2"] - assert mock_cache_key_wrapper.call_count == 0 + # Testing {{ current_user_rls_rules() }} macro isolated and always without + # the param because it does not support it to avoid shared cache. + assert cache.current_user_rls_rules() == ["1=1", "product_id=1"] def test_user_macros_without_user_info(mocker: MockerFixture): @@ -410,11 +413,55 @@ def test_user_macros_without_user_info(mocker: MockerFixture): """ mock_g = mocker.patch("superset.utils.core.g") mock_g.user = None + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_id() is None + assert cache.current_username() is None + assert cache.current_user_email() is None + assert cache.current_user_roles() is None + assert cache.current_user_rls_rules() is None + + +def test_current_user_rls_rules_with_no_table(mocker: MockerFixture): + """ + Test the ``current_user_rls_rules`` macro when no table is provided. + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_get_user_rls = mocker.patch("superset.security_manager.get_rls_filters") + mock_is_guest_user = mocker.patch("superset.security_manager.is_guest_user") + mock_cache_key_wrapper = mocker.patch( + "superset.jinja_context.ExtraCache.cache_key_wrapper" + ) + mock_g.user.id = 1 + mock_g.user.username = "my_username" + mock_g.user.email = "my_email@test.com" cache = ExtraCache() - assert cache.current_user_id() == None # noqa: E711 - assert cache.current_username() == None # noqa: E711 - assert cache.current_user_email() == None # noqa: E711 - assert cache.current_user_roles() == None # noqa: E711 + assert cache.current_user_rls_rules() is None + assert mock_cache_key_wrapper.call_count == 0 + assert mock_get_user_rls.call_count == 0 + assert mock_is_guest_user.call_count == 0 + + +@with_feature_flags(EMBEDDED_SUPERSET=True) +def test_current_user_rls_rules_guest_user(mocker: MockerFixture): + """ + Test the ``current_user_rls_rules`` with an embedded user. + """ + mock_g = mocker.patch("superset.utils.core.g") + mock_gg = mocker.patch("superset.tasks.utils.g") + mock_ggg = mocker.patch("superset.security.manager.g") + mock_get_user_rls = mocker.patch("superset.security_manager.get_guest_rls_filters") + mock_user = mocker.MagicMock() + mock_user.username = "my_username" + mock_user.is_guest_user = True + mock_user.is_anonymous = False + mock_g.user = mock_gg.user = mock_ggg.user = mock_user + + mock_get_user_rls.return_value = [ + {"group_key": "test", "clause": "1=1"}, + {"group_key": "other_test", "clause": "product_id=1"}, + ] + cache = ExtraCache(table=mocker.MagicMock()) + assert cache.current_user_rls_rules() == ["1=1", "product_id=1"] def test_where_in() -> None: