diff --git a/superset/security/manager.py b/superset/security/manager.py index 6c47c6e163d..47e772d765d 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -79,7 +79,7 @@ from superset.utils.urls import get_url_host if TYPE_CHECKING: from superset.common.query_context import QueryContext from superset.connectors.base.models import BaseDatasource - from superset.connectors.sqla.models import SqlaTable + from superset.connectors.sqla.models import RowLevelSecurityFilter, SqlaTable from superset.models.core import Database from superset.models.dashboard import Dashboard from superset.models.sql_lab import Query @@ -2091,28 +2091,30 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods ) return query.all() - def get_rls_ids(self, table: "BaseDatasource") -> list[int]: + def get_rls_sorted(self, table: "BaseDatasource") -> list["RowLevelSecurityFilter"]: """ - Retrieves the appropriate row level security filters IDs for the current user - and the passed table. + Retrieves a list RLS filters sorted by ID for + the current user and the passed table. :param table: The table to check against - :returns: A list of IDs + :returns: A list RLS filters """ - ids = [f.id for f in self.get_rls_filters(table)] - ids.sort() # Combinations rather than permutations - return ids + filters = self.get_rls_filters(table) + filters.sort(key=lambda f: f.id) + return filters def get_guest_rls_filters_str(self, table: "BaseDatasource") -> list[str]: return [f.get("clause", "") for f in self.get_guest_rls_filters(table)] def get_rls_cache_key(self, datasource: "BaseDatasource") -> list[str]: - rls_ids = [] + rls_clauses_with_group_key = [] if datasource.is_rls_supported: - rls_ids = self.get_rls_ids(datasource) - rls_str = [str(rls_id) for rls_id in rls_ids] + rls_clauses_with_group_key = [ + f"{f.clause}-{f.group_key or ''}" + for f in self.get_rls_sorted(datasource) + ] guest_rls = self.get_guest_rls_filters_str(datasource) - return guest_rls + rls_str + return guest_rls + rls_clauses_with_group_key @staticmethod def _get_current_epoch_time() -> float: diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index c29ebe9afef..41ca0d5e798 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -305,6 +305,21 @@ class TestRowLevelSecurity(SupersetTestCase): assert not self.NAMES_Q_REGEX.search(sql) assert not self.BASE_FILTER_REGEX.search(sql) + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_get_rls_cache_key(self): + g.user = self.get_user(username="admin") + tbl = self.get_table(name="birth_names") + clauses = security_manager.get_rls_cache_key(tbl) + assert clauses == [] + + g.user = self.get_user(username="gamma") + clauses = security_manager.get_rls_cache_key(tbl) + assert clauses == [ + "name like 'A%' or name like 'B%'-name", + "name like 'Q%'-name", + "gender = 'boy'-gender", + ] + class TestRowLevelSecurityCreateAPI(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")