diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index a4c12dd00cc..31eadbd8d8f 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -1913,13 +1913,31 @@ class SqlaTable( The cache key of a SqlaTable needs to consider any keys added by the parent class and any keys added via `ExtraCache`. + For virtual datasets, RLS predicates are included in the cache key to ensure + users with different RLS rules get different cached results. + :param query_obj: query object to analyze :return: The extra cache keys """ + from superset.utils.rls import collect_rls_predicates_for_sql + extra_cache_keys = super().get_extra_cache_keys(query_obj) if self.has_extra_cache_key_calls(query_obj): sqla_query = self.get_sqla_query(**query_obj) extra_cache_keys += sqla_query.extra_cache_keys + + # For virtual datasets, include RLS predicates in the cache key + if self.is_virtual and self.sql: + default_schema = self.database.get_default_schema(self.catalog) + rls_predicates = collect_rls_predicates_for_sql( + self.sql, + self.database, + self.catalog, + self.schema or default_schema or "", + ) + # Add each predicate as a separate cache key component + extra_cache_keys.extend(rls_predicates) + return list(set(extra_cache_keys)) @property diff --git a/superset/models/helpers.py b/superset/models/helpers.py index cb77c723e10..ea8526bd732 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -1188,6 +1188,9 @@ class ExploreMixin: # pylint: disable=too-many-public-methods Return where to select the columns and metrics from. Either a physical table or a virtual table with it's own subquery. If the FROM is referencing a CTE, the CTE is returned as the second value in the return tuple. + + For virtual datasets, RLS filters from underlying tables are applied to + prevent RLS bypass. """ from_sql = self.get_rendered_sql(template_processor) + "\n" parsed_script = SQLScript(from_sql, engine=self.db_engine_spec.engine) @@ -1196,6 +1199,24 @@ class ExploreMixin: # pylint: disable=too-many-public-methods _("Virtual dataset query must be read-only") ) + # Apply RLS filters to virtual dataset SQL to prevent RLS bypass + # For each table referenced in the virtual dataset, apply its RLS filters + if parsed_script.statements: + default_schema = self.database.get_default_schema(self.catalog) + try: + for statement in parsed_script.statements: + apply_rls( + self.database, + self.catalog, + self.schema or default_schema or "", + statement, + ) + # Regenerate the SQL after RLS application + from_sql = parsed_script.format() + except Exception as ex: + # Log the error but don't fail - RLS application is best-effort + logger.warning("Failed to apply RLS to virtual dataset SQL: %s", ex) + cte = self.db_engine_spec.get_cte_query(from_sql) from_clause = ( sa.table(self.db_engine_spec.cte_alias) diff --git a/superset/security/manager.py b/superset/security/manager.py index fb7c71050e4..4ed4ba08fa2 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -2346,10 +2346,9 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods query, template_params ) tables = { - Table( - table_.table, - table_.schema or default_schema, - table_.catalog or query.catalog or default_catalog, + table_.qualify( + catalog=query.catalog or default_catalog, + schema=default_schema, ) for table_ in process_jinja_sql( query.sql, database, template_params @@ -2357,9 +2356,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods } elif table: # Make sure table has the default catalog, if not specified. - tables = { - Table(table.table, table.schema, table.catalog or default_catalog) - } + tables = {table.qualify(catalog=default_catalog)} denied = set() diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 09a6cfe49e5..cc1801f5800 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -168,14 +168,7 @@ class RLSTransformer: table_node.catalog if table_node.catalog else self.catalog, ) if predicates := self.rules.get(table): - return ( - exp.And( - this=predicates[0], - expressions=predicates[1:], - ) - if len(predicates) > 1 - else predicates[0] - ) + return sqlglot.and_(*predicates) return None @@ -312,6 +305,21 @@ class Table: def __eq__(self, other: Any) -> bool: return str(self) == str(other) + def qualify( + self, + *, + catalog: str | None = None, + schema: str | None = None, + ) -> Table: + """ + Return a new Table with the given schema and/or catalog, if not already set. + """ + return Table( + table=self.table, + schema=self.schema or schema, + catalog=self.catalog or catalog, + ) + # To avoid unnecessary parsing/formatting of queries, the statement has the concept of # an "internal representation", which is the AST of the SQL statement. For most of the diff --git a/superset/utils/rls.py b/superset/utils/rls.py index cf422b01884..43f7da9ffe8 100644 --- a/superset/utils/rls.py +++ b/superset/utils/rls.py @@ -46,13 +46,7 @@ def apply_rls( # collect all RLS predicates for all tables in the query predicates: dict[Table, list[Any]] = {} for table in parsed_statement.tables: - # fully qualify table - table = Table( - table.table, - table.schema or schema, - table.catalog or catalog, - ) - + table = table.qualify(catalog=catalog, schema=schema) predicates[table] = [ parsed_statement.parse_predicate(predicate) for predicate in get_predicates_for_table( @@ -113,3 +107,48 @@ def get_predicates_for_table( ) for predicate in dataset.get_sqla_row_level_filters() ] + + +def collect_rls_predicates_for_sql( + sql: str, + database: Database, + catalog: str | None, + schema: str, +) -> list[str]: + """ + Collect all RLS predicates that would be applied to tables in the given SQL. + + This is used for cache key generation for virtual datasets to ensure that + different users with different RLS rules get different cache keys. + + :param sql: The SQL query to analyze + :param database: The database the query runs against + :param catalog: The default catalog for the query + :param schema: The default schema for the query + :return: List of RLS predicate strings that would be applied + """ + from superset.sql.parse import SQLScript + + try: + parsed_script = SQLScript(sql, engine=database.db_engine_spec.engine) + tables = { + table.qualify(catalog=catalog, schema=schema) + for statement in parsed_script.statements + for table in statement.tables + } + default_catalog = database.get_default_catalog() + return sorted( + { + predicate + for table in tables + for predicate in get_predicates_for_table( + table, + database, + default_catalog, + ) + } + ) + except Exception: + # If we can't parse the SQL, return empty list + # This ensures RLS application failure doesn't break caching + return [] diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 2d5dd28c17f..cab80b49476 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -311,6 +311,105 @@ class TestRowLevelSecurity(SupersetTestCase): "gender = 'boy'-gender", ] + @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") + def test_rls_filter_applies_to_virtual_dataset(self): + """ + Test that RLS filters from underlying tables are applied to virtual + datasets. + """ + # Get the physical birth_names table which has RLS filters + physical_table = self.get_table(name="birth_names") + + # Create a virtual dataset that queries the birth_names table + virtual_dataset = SqlaTable( + table_name="virtual_birth_names", + database=physical_table.database, + schema=physical_table.schema, + sql="SELECT * FROM birth_names", + ) + db.session.add(virtual_dataset) + db.session.commit() + + try: + # Test as gamma user who has RLS filters + g.user = self.get_user(username="gamma") + + # Get the SQL query for the virtual dataset + sql = virtual_dataset.get_query_str(self.query_obj) + + # Verify that RLS filters from the physical table are applied + # Gamma user should have the name filters (A%, B%, Q%) and gender filter + # Note: SQL uses uppercase LIKE and %% escaping + sql_lower = sql.lower() + assert "name like 'a%" in sql_lower or "name like 'q%" in sql_lower, ( + f"RLS name filters not found in virtual dataset query: {sql}" + ) + assert "gender = 'boy'" in sql_lower, ( + f"RLS gender filter not found in virtual dataset query: {sql}" + ) + + # Test as admin user who has no RLS filters + g.user = self.get_user(username="admin") + sql = virtual_dataset.get_query_str(self.query_obj) + + # Admin should not have RLS filters applied + assert not self.NAMES_A_REGEX.search(sql) + assert not self.NAMES_B_REGEX.search(sql) + assert not self.NAMES_Q_REGEX.search(sql) + assert not self.BASE_FILTER_REGEX.search(sql) + + finally: + # Cleanup + db.session.delete(virtual_dataset) + db.session.commit() + + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "load_energy_table_with_slice" + ) + def test_rls_filter_applies_to_virtual_dataset_with_join(self): + """ + Test that RLS filters are applied when virtual dataset joins + multiple tables. + """ + # Get the physical tables + birth_names_table = self.get_table(name="birth_names") + self.get_table(name="energy_usage") # Load the table for the test + + # Create a virtual dataset with a JOIN query + virtual_dataset = SqlaTable( + table_name="virtual_joined", + database=birth_names_table.database, + schema=birth_names_table.schema, + sql="SELECT b.name, e.value FROM birth_names b JOIN energy_usage e ON 1=1", + ) + db.session.add(virtual_dataset) + db.session.commit() + + try: + # Test as gamma user who has RLS filters on both tables + g.user = self.get_user(username="gamma") + + # Get the SQL query for the virtual dataset + sql = virtual_dataset.get_query_str(self.query_obj) + + # Verify that RLS filters from both physical tables are applied + # birth_names filters + sql_lower = sql.lower() + assert "name like 'a%" in sql_lower or "name like 'q%" in sql_lower, ( + f"birth_names RLS filters not found: {sql}" + ) + assert "gender = 'boy'" in sql_lower, ( + f"birth_names gender filter not found: {sql}" + ) + + # energy_usage filter + assert "value > 1" in sql_lower, f"energy_usage RLS filter not found: {sql}" + + finally: + # Cleanup + db.session.delete(virtual_dataset) + db.session.commit() + class TestRowLevelSecurityCreateAPI(SupersetTestCase): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") diff --git a/tests/integration_tests/sqla_models_tests.py b/tests/integration_tests/sqla_models_tests.py index 2e9c4c00e5a..0852c62362d 100644 --- a/tests/integration_tests/sqla_models_tests.py +++ b/tests/integration_tests/sqla_models_tests.py @@ -190,7 +190,7 @@ class TestDatabaseModel(SupersetTestCase): query = table.database.compile_sqla_query(sqla_query.sqla_query) # assert virtual dataset - assert "SELECT 'user_abc' as user, 'xyz_P1D' as time_grain" in query + assert "SELECT\n 'user_abc' AS user,\n 'xyz_P1D' AS time_grain" in query # assert dataset calculated column assert "case when 'abc' = 'abc' then 'yes' else 'no' end" in query # assert adhoc column diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 6d89a3356e8..0af93696595 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -62,6 +62,81 @@ def test_table() -> None: ) +def test_table_qualify() -> None: + """ + Test the `Table.qualify` method. + + The qualify method should add schema and/or catalog if not already set, + but should not override existing values. + """ + # Table with no schema or catalog + table = Table("tbname") + + # Add schema only + qualified = table.qualify(schema="schemaname") + assert qualified.table == "tbname" + assert qualified.schema == "schemaname" + assert qualified.catalog is None + assert str(qualified) == "schemaname.tbname" + + # Add catalog only + qualified = table.qualify(catalog="catalogname") + assert qualified.table == "tbname" + assert qualified.schema is None + assert qualified.catalog == "catalogname" + assert str(qualified) == "catalogname.tbname" + + # Add both schema and catalog + qualified = table.qualify(schema="schemaname", catalog="catalogname") + assert qualified.table == "tbname" + assert qualified.schema == "schemaname" + assert qualified.catalog == "catalogname" + assert str(qualified) == "catalogname.schemaname.tbname" + + # Table with existing schema - should not override + table_with_schema = Table("tbname", "existingschema") + qualified = table_with_schema.qualify(schema="newschema") + assert qualified.schema == "existingschema" + assert str(qualified) == "existingschema.tbname" + + # Table with existing catalog - should not override + table_with_catalog = Table("tbname", catalog="existingcatalog") + qualified = table_with_catalog.qualify(catalog="newcatalog") + assert qualified.catalog == "existingcatalog" + assert str(qualified) == "existingcatalog.tbname" + + # Table with existing schema and catalog - should not override + fully_qualified = Table("tbname", "existingschema", "existingcatalog") + qualified = fully_qualified.qualify(schema="newschema", catalog="newcatalog") + assert qualified.schema == "existingschema" + assert qualified.catalog == "existingcatalog" + assert str(qualified) == "existingcatalog.existingschema.tbname" + + # Table with schema but no catalog - should add catalog only + table_with_schema_only = Table("tbname", "existingschema") + qualified = table_with_schema_only.qualify( + schema="newschema", catalog="catalogname" + ) + assert qualified.schema == "existingschema" + assert qualified.catalog == "catalogname" + assert str(qualified) == "catalogname.existingschema.tbname" + + # Table with catalog but no schema - should add schema only + table_with_catalog_only = Table("tbname", catalog="existingcatalog") + qualified = table_with_catalog_only.qualify( + schema="schemaname", catalog="newcatalog" + ) + assert qualified.schema == "schemaname" + assert qualified.catalog == "existingcatalog" + assert str(qualified) == "existingcatalog.schemaname.tbname" + + # Calling qualify with no arguments should return equivalent table + qualified = table.qualify() + assert qualified.table == table.table + assert qualified.schema == table.schema + assert qualified.catalog == table.catalog + + def extract_tables_from_sql(sql: str, engine: str = "postgresql") -> set[Table]: """ Helper function to extract tables from SQL.