diff --git a/superset/models/helpers.py b/superset/models/helpers.py index a4fb9e3fea1..d0ca58a7651 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -120,6 +120,7 @@ from superset.utils.core import ( from superset.utils.date_parser import get_past_or_future, normalize_time_delta from superset.utils.dates import datetime_to_epoch from superset.utils.rls import apply_rls +from superset.data_access_rules.utils import apply_data_access_rules class ValidationResultDict(TypedDict): @@ -1049,6 +1050,22 @@ class ExploreMixin: # pylint: disable=too-many-public-methods ) sql = self._apply_cte(sql, sqlaq.cte) + # Apply Data Access Rules (RLS and CLS) if enabled + if is_feature_enabled("DATA_ACCESS_RULES"): + try: + default_schema = self.database.get_default_schema(self.catalog) + parsed_script = SQLScript(sql, engine=self.db_engine_spec.engine) + for statement in parsed_script.statements: + apply_data_access_rules( + self.database, + self.catalog, + self.schema or default_schema or "", + statement, + ) + sql = parsed_script.format() + except Exception as ex: + logger.warning("Failed to apply Data Access Rules: %s", ex) + if mutate: sql = self.database.mutate_sql_based_on_config(sql) return QueryStringExtended( @@ -2051,6 +2068,23 @@ class ExploreMixin: # pylint: disable=too-many-public-methods # Log the error but don't fail - RLS application is best-effort logger.warning("Failed to apply RLS to virtual dataset SQL: %s", ex) + # Apply Data Access Rules to virtual dataset SQL + if is_feature_enabled("DATA_ACCESS_RULES") and parsed_script.statements: + default_schema = self.database.get_default_schema(self.catalog) + try: + for statement in parsed_script.statements: + apply_data_access_rules( + self.database, + self.catalog, + self.schema or default_schema or "", + statement, + ) + from_sql = parsed_script.format() + except Exception as ex: + logger.warning( + "Failed to apply Data Access Rules to virtual dataset SQL: %s", ex + ) + cte = self.db_engine_spec.get_cte_query(from_sql) from_clause = ( sa.table(self.db_engine_spec.cte_alias) diff --git a/tests/unit_tests/models/helpers_test.py b/tests/unit_tests/models/helpers_test.py index e913526975f..0e5a89189ad 100644 --- a/tests/unit_tests/models/helpers_test.py +++ b/tests/unit_tests/models/helpers_test.py @@ -1703,3 +1703,122 @@ def test_adhoc_column_with_spaces_in_full_query(database: Database) -> None: # Verify SELECT and FROM clauses are present assert "SELECT" in sql assert "FROM" in sql + + +def test_get_query_str_extended_with_data_access_rules( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test that get_query_str_extended calls apply_data_access_rules when enabled. + + This test mocks the get_sqla_query to return a simple SQL query, then verifies + that apply_data_access_rules is called when the feature flag is enabled. + """ + from unittest.mock import MagicMock + + import sqlalchemy as sa + + from superset.connectors.sqla.models import SqlaTable, TableColumn + from superset.models.helpers import SqlaQuery + + table = SqlaTable( + database=database, + schema=None, + table_name="t", + columns=[TableColumn(column_name="a")], + ) + + # Create a mock SqlaQuery result (fields must be in order per NamedTuple) + mock_sqla_query = SqlaQuery( + applied_template_filters=[], + applied_filter_columns=[], + rejected_filter_columns=[], + cte=None, + extra_cache_keys=[], + labels_expected=["a"], + prequeries=[], + sqla_query=sa.select(sa.column("a")).select_from(sa.table("t")), + ) + + # Mock get_sqla_query to return our simple query + mocker.patch.object(table, "get_sqla_query", return_value=mock_sqla_query) + + # Mock apply_data_access_rules + mock_apply_dar = mocker.patch( + "superset.models.helpers.apply_data_access_rules" + ) + + # Mock is_feature_enabled to return True for DATA_ACCESS_RULES + mocker.patch( + "superset.models.helpers.is_feature_enabled", + return_value=True, + ) + + # Mock database.get_default_schema + mocker.patch.object( + database, "get_default_schema", return_value="public" + ) + + query_obj = { + "granularity": None, + "from_dttm": None, + "to_dttm": None, + "columns": ["a"], + "metrics": [], + "orderby": [], + "row_limit": 100, + "filter": [], + } + + result = table.get_query_str_extended(query_obj) + + # Verify we got a result + assert result is not None + assert result.sql is not None + + # Verify apply_data_access_rules was called + assert mock_apply_dar.called + + +def test_get_from_clause_with_data_access_rules( + mocker: MockerFixture, + database: Database, +) -> None: + """ + Test that get_from_clause calls apply_data_access_rules for virtual datasets. + """ + from superset.connectors.sqla.models import SqlaTable, TableColumn + + # Create a virtual dataset (has SQL defined) + table = SqlaTable( + database=database, + schema=None, + table_name="virtual_table", + sql="SELECT * FROM base_table", + columns=[TableColumn(column_name="a")], + ) + + # Mock apply_data_access_rules + mock_apply_dar = mocker.patch( + "superset.models.helpers.apply_data_access_rules" + ) + + # Mock is_feature_enabled to return True for DATA_ACCESS_RULES + mocker.patch( + "superset.models.helpers.is_feature_enabled", + return_value=True, + ) + + # Mock database.get_default_schema + mocker.patch.object( + database, "get_default_schema", return_value="public" + ) + + try: + table.get_from_clause() + except Exception: + pass # We're just testing that apply_data_access_rules is called + + # Verify apply_data_access_rules was called + assert mock_apply_dar.called