diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 6fb132e1613..de2d41f3878 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -2072,6 +2072,7 @@ class SqlaTable( self.database, self.catalog, self.schema or default_schema or "", + exclude_dataset_id=self.id, ) # Add each predicate as a separate cache key component extra_cache_keys.extend(rls_predicates) diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 6634d25910e..7a2668a49ed 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -2060,12 +2060,17 @@ class ExploreMixin: # pylint: disable=too-many-public-methods default_schema = self.database.get_default_schema(self.catalog) try: rls_applied = False + # ``id`` lives on concrete subclasses (e.g. SqlaTable), not on + # ExploreMixin itself. getattr keeps this safe for non-dataset + # subclasses (e.g. SQL Lab Query), which have no RLS to dedupe. + self_id = getattr(self, "id", None) for statement in parsed_script.statements: if apply_rls( self.database, self.catalog, self.schema or default_schema or "", statement, + exclude_dataset_id=self_id, ): rls_applied = True diff --git a/superset/utils/rls.py b/superset/utils/rls.py index 7e6cdf2aee7..f6643b9fc51 100644 --- a/superset/utils/rls.py +++ b/superset/utils/rls.py @@ -34,10 +34,16 @@ def apply_rls( catalog: str | None, schema: str, parsed_statement: BaseSQLStatement[Any], + exclude_dataset_id: int | None = None, ) -> bool: """ Modify statement inplace to ensure RLS rules are applied. + :param exclude_dataset_id: When applying RLS to a virtual dataset's inner SQL, + pass the virtual dataset's id here so its own RLS isn't injected again + on top of the outer-WHERE application (avoids double-apply when the + virtual dataset's table_name collides with a table in its own SQL — for + example, after converting a physical dataset with RLS to virtual). :returns: True if any RLS predicates were actually applied, False otherwise. """ # There are two ways to insert RLS: either replacing the table with a subquery @@ -55,6 +61,7 @@ def apply_rls( table, database, database.get_default_catalog(), + exclude_dataset_id=exclude_dataset_id, ) if predicate ] @@ -68,6 +75,7 @@ def get_predicates_for_table( table: Table, database: Database, default_catalog: str | None, + exclude_dataset_id: int | None = None, ) -> list[str]: """ Get the RLS predicates for a table. @@ -87,18 +95,21 @@ def get_predicates_for_table( SqlaTable.catalog.is_(None), ) - dataset = ( - db.session.query(SqlaTable) - .filter( - and_( - SqlaTable.database_id == database.id, - catalog_predicate, - SqlaTable.schema == table.schema, - SqlaTable.table_name == table.table, - ) - ) - .one_or_none() - ) + filters = [ + SqlaTable.database_id == database.id, + catalog_predicate, + SqlaTable.schema == table.schema, + SqlaTable.table_name == table.table, + ] + # When applying RLS to a virtual dataset's inner SQL, skip a match against + # the dataset itself — its RLS is already applied on the outer WHERE via + # get_sqla_row_level_filters(). Without this, a virtual dataset whose + # table_name happens to equal a table in its own SQL (e.g. after a + # physical→virtual conversion) double-applies its own predicates. + if exclude_dataset_id is not None: + filters.append(SqlaTable.id != exclude_dataset_id) + + dataset = db.session.query(SqlaTable).filter(and_(*filters)).one_or_none() if not dataset: return [] @@ -130,6 +141,7 @@ def collect_rls_predicates_for_sql( database: Database, catalog: str | None, schema: str, + exclude_dataset_id: int | None = None, ) -> list[str]: """ Collect all RLS predicates that would be applied to tables in the given SQL. @@ -141,6 +153,9 @@ def collect_rls_predicates_for_sql( :param database: The database the query runs against :param catalog: The default catalog for the query :param schema: The default schema for the query + :param exclude_dataset_id: Mirror of the same parameter on apply_rls — pass + the virtual dataset's id so its self-match is excluded from the cache key + (kept consistent with what's actually applied at query time). :return: List of RLS predicate strings that would be applied """ from superset.sql.parse import SQLScript @@ -161,6 +176,7 @@ def collect_rls_predicates_for_sql( table, database, default_catalog, + exclude_dataset_id=exclude_dataset_id, ) } ) diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 62cafdd5692..a7b670bcfc3 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -285,8 +285,18 @@ def test_apply_rls(mocker: MockerFixture) -> None: get_predicates_for_table.assert_has_calls( [ - mocker.call(Table("t1", "public", "examples"), database, "examples"), - mocker.call(Table("t2", "public", "examples"), database, "examples"), + mocker.call( + Table("t1", "public", "examples"), + database, + "examples", + exclude_dataset_id=None, + ), + mocker.call( + Table("t2", "public", "examples"), + database, + "examples", + exclude_dataset_id=None, + ), ] ) @@ -329,3 +339,27 @@ def test_get_predicates_for_table(mocker: MockerFixture) -> None: dataset.get_sqla_row_level_filters.assert_called_once_with( include_global_guest_rls=False ) + + +def test_get_predicates_for_table_excludes_self(mocker: MockerFixture) -> None: + """ + When ``exclude_dataset_id`` is supplied, the lookup query must add an + ``id != exclude_dataset_id`` filter so a virtual dataset whose + ``table_name`` matches a table referenced inside its own SQL doesn't get + its own RLS injected into the inner SQL (would double-apply on top of the + outer WHERE). Regression test for the physical→virtual conversion bug. + """ + database = mocker.MagicMock() + db = mocker.patch("superset.utils.rls.db") + db.session.query().filter().one_or_none.return_value = None + + table = Table("orders", "public", "examples") + assert ( + get_predicates_for_table(table, database, "examples", exclude_dataset_id=42) + == [] + ) + # The filter call should have received four base filters plus the exclusion + # filter, i.e. five total positional args inside and_(). + filter_call = db.session.query().filter.call_args + and_clause = filter_call.args[0] + assert len(and_clause.clauses) == 5