diff --git a/superset/sql/parse.py b/superset/sql/parse.py index d944dce5a2f..db84bb2dc7b 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -543,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]): self, catalog: str | None, schema: str | None, - predicates: dict[Table, list[InternalRepresentation]], + predicates: dict[Table, list[str]], method: RLSMethod, ) -> None: """ @@ -973,7 +973,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): self, catalog: str | None, schema: str | None, - predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]], + predicates: dict[Table, list[str]], method: RLSMethod, ) -> None: """ @@ -981,9 +981,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): :param catalog: The default catalog for non-qualified table names :param schema: The default schema for non-qualified table names - :param predicates: Mapping of fully qualified ``Table`` to predicates. - For ``AS_PREDICATE`` and ``AS_SUBQUERY`` the predicates are sqlglot - expressions. For ``AS_PREDICATE_SPLICE`` they are raw SQL strings. + :param predicates: Mapping of fully qualified ``Table`` to raw predicate + SQL strings. :param method: The method to use for applying the rules. """ if not predicates: @@ -993,6 +992,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): self._apply_rls_splice(catalog, schema, predicates) return + parsed_predicates: dict[Table, list[exp.Expression]] = { + table: [self.parse_predicate(predicate) for predicate in table_predicates] + for table, table_predicates in predicates.items() + } + transformers = { RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer, RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer, @@ -1000,14 +1004,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): if method not in transformers: raise ValueError(f"Invalid RLS method: {method}") - transformer = transformers[method](catalog, schema, predicates) + transformer = transformers[method](catalog, schema, parsed_predicates) self._parsed = self._parsed.transform(transformer) def _apply_rls_splice( self, catalog: str | None, schema: str | None, - predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]], + predicates: dict[Table, list[str]], ) -> None: """ Apply RLS via text splicing on the original SQL. @@ -1024,19 +1028,11 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): "this SQLStatement was constructed without one." ) - # Splice operates on raw predicate strings; coerce expressions if needed. - string_predicates: dict[Table, list[str]] = { - table: [ - pred if isinstance(pred, str) else pred.sql(dialect=self._dialect) - for pred in preds - ] - for table, preds in predicates.items() - } spliced = apply_rls_splice( self._source_sql, catalog, schema, - string_predicates, + predicates, dialect=self._dialect, ) self._raw_sql = spliced diff --git a/superset/utils/rls.py b/superset/utils/rls.py index e7b989f349a..3d18bc9d2e6 100644 --- a/superset/utils/rls.py +++ b/superset/utils/rls.py @@ -22,7 +22,7 @@ from typing import Any, TYPE_CHECKING from sqlalchemy import and_, or_ from superset import db -from superset.sql.parse import RLSMethod, Table +from superset.sql.parse import Table if TYPE_CHECKING: from superset.models.core import Database @@ -48,11 +48,7 @@ def apply_rls( # syntax that the sqlglot generator would otherwise transpile) method = database.db_engine_spec.rls_method - # In splice mode predicates stay as raw SQL strings and are inserted verbatim - # into the source query — re-parsing them would force a generator round-trip - # later and defeat the purpose. - use_splice = method == RLSMethod.AS_PREDICATE_SPLICE - predicates: dict[Table, list[Any]] = {} + predicates: dict[Table, list[str]] = {} for table in parsed_statement.tables: table = table.qualify(catalog=catalog, schema=schema) raw_predicates = [ @@ -64,11 +60,7 @@ def apply_rls( ) if predicate ] - predicates[table] = ( - raw_predicates - if use_splice - else [parsed_statement.parse_predicate(p) for p in raw_predicates] - ) + predicates[table] = raw_predicates has_predicates = any(predicates.values()) parsed_statement.apply_rls(catalog, schema, predicates, method) diff --git a/tests/unit_tests/models/test_virtual_dataset_format.py b/tests/unit_tests/models/test_virtual_dataset_format.py index 921abe4048d..bb8f96bc217 100644 --- a/tests/unit_tests/models/test_virtual_dataset_format.py +++ b/tests/unit_tests/models/test_virtual_dataset_format.py @@ -209,7 +209,7 @@ class TestApplyRlsReturnValue: from superset.utils.rls import apply_rls database = MagicMock() - database.db_engine_spec.get_rls_method.return_value = MagicMock() + database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY database.get_default_catalog.return_value = None statement = MagicMock() @@ -237,7 +237,7 @@ class TestApplyRlsReturnValue: mock_get_predicates.return_value = [] database = MagicMock() - database.db_engine_spec.get_rls_method.return_value = MagicMock() + database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY database.get_default_catalog.return_value = None mock_table = MagicMock() @@ -268,7 +268,7 @@ class TestApplyRlsReturnValue: mock_get_predicates.return_value = ["user_id = 42"] database = MagicMock() - database.db_engine_spec.get_rls_method.return_value = MagicMock() + database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY database.get_default_catalog.return_value = None mock_table = MagicMock() @@ -276,8 +276,6 @@ class TestApplyRlsReturnValue: statement = MagicMock() statement.tables = [mock_table] - statement.parse_predicate.return_value = MagicMock() - result = apply_rls( database=database, catalog=None, @@ -312,11 +310,10 @@ class TestRLSSubqueryAlias: """ sql = "SELECT pens.pen_id, pens.is_green FROM public.pens" statement = SQLStatement(sql, engine="redshift") - predicate = statement.parse_predicate("user_id = 1") statement.apply_rls( None, "public", - {Table("pens", "public", None): [predicate]}, + {Table("pens", "public", None): ["user_id = 1"]}, RLSMethod.AS_SUBQUERY, ) result = statement.format() @@ -333,11 +330,10 @@ class TestRLSSubqueryAlias: """ sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens" statement = SQLStatement(sql, engine="redshift") - predicate = statement.parse_predicate("user_id = 1") statement.apply_rls( None, "public", - {Table("pens", "public", "mycat"): [predicate]}, + {Table("pens", "public", "mycat"): ["user_id = 1"]}, RLSMethod.AS_SUBQUERY, ) result = statement.format() @@ -351,11 +347,10 @@ class TestRLSSubqueryAlias: """ sql = "SELECT p.pen_id, p.is_green FROM public.pens p" statement = SQLStatement(sql, engine="redshift") - predicate = statement.parse_predicate("user_id = 1") statement.apply_rls( None, "public", - {Table("pens", "public", None): [predicate]}, + {Table("pens", "public", None): ["user_id = 1"]}, RLSMethod.AS_SUBQUERY, ) result = statement.format() @@ -369,11 +364,10 @@ class TestRLSSubqueryAlias: """ sql = "SELECT pen_id, is_green FROM public.pens" statement = SQLStatement(sql, engine="redshift") - predicate = statement.parse_predicate("user_id = 1") statement.apply_rls( None, "public", - {Table("pens", "public", None): [predicate]}, + {Table("pens", "public", None): ["user_id = 1"]}, RLSMethod.AS_SUBQUERY, ) result = statement.format() diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index a8de0dcbe39..e21429d0d01 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -2213,7 +2213,7 @@ def test_rls_subquery_transformer( statement.apply_rls( "catalog1", "schema1", - {k: [parse_one(v)] for k, v in rules.items()}, + {k: [v] for k, v in rules.items()}, RLSMethod.AS_SUBQUERY, ) assert statement.format() == expected @@ -2557,7 +2557,7 @@ def test_rls_predicate_transformer( statement.apply_rls( "catalog1", "schema1", - {k: [parse_one(v)] for k, v in rules.items()}, + {k: [v] for k, v in rules.items()}, RLSMethod.AS_PREDICATE, ) assert statement.format() == expected