From 3fe2b2505f3d35009d7f7958b33da78933b24f75 Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Fri, 8 May 2026 13:33:42 -0400 Subject: [PATCH] feat: new splice RLSMethod --- superset/db_engine_specs/base.py | 22 +- superset/sql/parse.py | 117 +++++++- superset/sql/rls_splice.py | 263 ++++++++++++++++++ superset/utils/rls.py | 25 +- tests/unit_tests/db_engine_specs/test_base.py | 43 ++- tests/unit_tests/sql/parse_tests.py | 159 +++++++++++ 6 files changed, 615 insertions(+), 14 deletions(-) create mode 100644 superset/sql/rls_splice.py diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 4d26ca8517a..78fc2e78afa 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -557,6 +557,14 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # if True, database will be listed as option in the upload file form supports_file_upload = True + # Optional override for the RLS method used by ``get_rls_method``. When set, + # the engine spec opts into a specific strategy regardless of the + # ``allows_subqueries`` / ``allows_alias_in_select`` defaults. Use + # ``RLSMethod.AS_PREDICATE_SPLICE`` for engines whose sqlglot dialect can + # parse but not faithfully regenerate the SQL — splice mode rewrites the + # original query string instead of round-tripping through the generator. + rls_method: RLSMethod | None = None + # Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501 # a custom `adjust_engine_params` method. supports_dynamic_schema = False @@ -623,10 +631,18 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods """ Returns the RLS method to be used for this engine. - There are two ways to insert RLS: either replacing the table with a subquery - that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is - safer, but not supported in all databases. 
+ There are three ways to insert RLS: replacing the table with a subquery + that has the RLS (safest, but not supported in all databases), appending + the RLS to the ``WHERE`` clause via AST transformation, or splicing the + RLS into the original SQL string (preserves dialect-specific syntax that + the sqlglot generator would otherwise transpile). + + Engine specs can opt into a specific strategy by setting the class-level + ``rls_method`` attribute; otherwise the choice falls back to subquery + when supported, and predicate otherwise. """ + if cls.rls_method is not None: + return cls.rls_method return ( RLSMethod.AS_SUBQUERY if cls.allows_subqueries and cls.allows_alias_in_select diff --git a/superset/sql/parse.py b/superset/sql/parse.py index bb3ef5e1c4b..a45943ea463 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -142,6 +142,7 @@ class RLSMethod(enum.Enum): AS_PREDICATE = enum.auto() AS_SUBQUERY = enum.auto() + AS_PREDICATE_SPLICE = enum.auto() class RLSTransformer: @@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]): statement: str | None = None, engine: str = "base", ast: InternalRepresentation | None = None, + source: str | None = None, ): if ast: self._parsed = ast @@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]): self.engine = engine self.tables = self._extract_tables_from_statement(self._parsed, self.engine) + # Original SQL substring for this statement, when known. Used by the + # splice-mode RLS path which rewrites this string instead of regenerating + # SQL from the AST. ``None`` means the statement was constructed from an + # AST without an associated source string (splice mode falls back). + self._source_sql: str | None = source if source is not None else statement + # Verbatim SQL to return from ``format()``. Set by string-rewriting + # operations (e.g. splice-mode RLS) that produce a final SQL string and + # need to bypass the dialect generator. 
Cleared by AST-mutating methods + # since those invalidate this cached text. + self._raw_sql: str | None = None @classmethod def split_script( @@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): statement: str | None = None, engine: str = "base", ast: exp.Expression | None = None, + source: str | None = None, ): self._dialect = SQLGLOT_DIALECTS.get(engine) - super().__init__(statement, engine, ast) + super().__init__(statement, engine, ast, source) @classmethod def _parse(cls, script: str, engine: str) -> list[exp.Expression]: @@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): script: str, engine: str, ) -> list[SQLStatement]: + asts = [ast for ast in cls._parse(script, engine) if ast] + sources = cls._split_source(script, engine, len(asts)) return [ - cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast + cls(ast=ast, engine=engine, source=source) + for ast, source in zip(asts, sources, strict=False) ] + @classmethod + def _split_source( + cls, + script: str, + engine: str, + expected_count: int, + ) -> list[str | None]: + """ + Slice ``script`` into per-statement substrings using top-level semicolon + positions from the tokenizer. Returns a list of length ``expected_count``; + any entry is ``None`` if the slicing didn't yield a usable substring. + + The returned substrings preserve the original byte content of the script + for each statement — necessary for splice-mode RLS, which rewrites the + original SQL rather than regenerating from the AST. + """ + none_result: list[str | None] = [None] * expected_count + dialect = SQLGLOT_DIALECTS.get(engine) + try: + tokens = list(Dialect.get_or_raise(dialect).tokenize(script)) + except sqlglot.errors.SqlglotError: + return none_result + + # Top-level semicolon offsets (depth 0). 
+ boundaries: list[int] = [] + depth = 0 + for tok in tokens: + if tok.token_type == sqlglot.tokens.TokenType.L_PAREN: + depth += 1 + elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN: + depth -= 1 + elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and depth == 0: + boundaries.append(tok.start) + + starts = [0, *(b + 1 for b in boundaries)] + ends = [*boundaries, len(script)] + sources = [script[s:e].strip() for s, e in zip(starts, ends, strict=False)] + sources = [s for s in sources if s] + if len(sources) != expected_count: + return none_result + return list(sources) + @classmethod def _parse_statement( cls, @@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): def format(self, comments: bool = True) -> str: """ Pretty-format the SQL statement. + + When a string-rewriting operation (e.g. splice-mode RLS) has cached a + verbatim result in ``_raw_sql``, return it as-is — the whole point of + those operations is to avoid the dialect generator round-trip. """ + if self._raw_sql is not None: + return self._raw_sql return Dialect.get_or_raise(self._dialect).generate( self._parsed, copy=True, @@ -808,6 +872,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ Modify the `LIMIT` or `TOP` value of the SQL statement inplace. """ + # AST mutation invalidates any cached verbatim SQL (e.g. from splice). 
+ self._raw_sql = None if method == LimitMethod.FORCE_LIMIT: self._parsed.args["limit"] = exp.Limit( expression=exp.Literal(this=str(limit), is_string=False) @@ -902,7 +968,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): self, catalog: str | None, schema: str | None, - predicates: dict[Table, list[exp.Expression]], + predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]], method: RLSMethod, ) -> None: """ @@ -910,11 +976,18 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): :param catalog: The default catalog for non-qualified table names :param schema: The default schema for non-qualified table names + :param predicates: Mapping of fully qualified ``Table`` to predicates. + For ``AS_PREDICATE`` and ``AS_SUBQUERY`` the predicates are sqlglot + expressions. For ``AS_PREDICATE_SPLICE`` they are raw SQL strings. :param method: The method to use for applying the rules. """ if not predicates: return + if method == RLSMethod.AS_PREDICATE_SPLICE: + self._apply_rls_splice(catalog, schema, predicates) + return + transformers = { RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer, RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer, @@ -925,6 +998,44 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): transformer = transformers[method](catalog, schema, predicates) self._parsed = self._parsed.transform(transformer) + def _apply_rls_splice( + self, + catalog: str | None, + schema: str | None, + predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]], + ) -> None: + """ + Apply RLS via text splicing on the original SQL. + + Requires the source SQL substring to be available. Raises ``ValueError`` + if it isn't — the caller must ensure the statement was constructed from + a source string (the standard ``SQLScript`` path does this). 
+ """ + from superset.sql.rls_splice import apply_rls_splice + + if self._source_sql is None: + raise ValueError( + "Splice-mode RLS requires the source SQL string; " + "this SQLStatement was constructed without one." + ) + + # Splice operates on raw predicate strings; coerce expressions if needed. + string_predicates: dict[Table, list[str]] = { + table: [ + pred if isinstance(pred, str) else pred.sql(dialect=self._dialect) + for pred in preds + ] + for table, preds in predicates.items() + } + spliced = apply_rls_splice( + self._source_sql, + catalog, + schema, + string_predicates, + dialect=self._dialect, + ) + self._raw_sql = spliced + class KQLSplitState(enum.Enum): """ diff --git a/superset/sql/rls_splice.py b/superset/sql/rls_splice.py new file mode 100644 index 00000000000..0166a13fe49 --- /dev/null +++ b/superset/sql/rls_splice.py @@ -0,0 +1,263 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +RLS predicate injection via text splicing. + +Instead of round-tripping through sqlglot's generator (which transpiles +dialect-specific functions like ``LAST_DAY`` into something else), this approach: + + 1. Parses the SQL with sqlglot — only to understand structure (scope tree). + 2. 
Uses sqlglot's tokenizer to get byte-accurate positions for every token + in the original SQL string. + 3. For each ``SELECT`` scope that references a table with an RLS predicate, + finds the exact byte offset to inject at — either the end of an existing + ``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING`` + / ``LIMIT`` / the closing paren of a subquery. + 4. Splices the predicate text directly into the original string at that + offset — never calling ``.sql()``, so the generator never runs. + +Result: everything outside the splice points is the original SQL, byte for +byte. Dialect-specific functions, comments, and formatting are all preserved +exactly. + +Known limitations: + - SQL that fails to parse under the chosen dialect raises a ``ParseError``. + A thin dialect subclass is still required for parsing — but only for + parsing, not generation. + - Predicate strings are spliced in as raw SQL. They must come from a trusted + source (the RLS config), not user input. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import sqlglot +from sqlglot import exp +from sqlglot.optimizer.scope import traverse_scope +from sqlglot.tokens import Token, TokenType + +if TYPE_CHECKING: + from superset.sql.parse import Table + + +# Token types that end a WHERE clause / FROM section at the current paren depth, +# indicating where a new predicate must be inserted just before. 
_CLAUSE_ENDS = {
    TokenType.GROUP_BY,
    TokenType.HAVING,
    TokenType.QUALIFY,
    TokenType.ORDER_BY,
    TokenType.LIMIT,
    TokenType.FETCH,
    TokenType.UNION,
    TokenType.INTERSECT,
    TokenType.EXCEPT,
    # A statement terminator also ends the clause; a predicate placed after it
    # would fall outside the statement it is meant to restrict.
    TokenType.SEMICOLON,
}


def _before_whitespace(sql: str, offset: int) -> int:
    """
    Back up past any whitespace immediately before *offset*.

    Kept for backward compatibility. Splice offsets are no longer derived from
    it, because backing up over a newline can land at the end of a ``--`` line
    comment, which would comment out the injected predicate.
    """
    while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"):
        offset -= 1
    return offset


def _offset_after_token(sql: str, tokens: list[Token], index: int) -> int:
    """
    Return the offset just past the token that precedes ``tokens[index]``.

    ``Token.end`` is the index of the token's last character, hence the ``+ 1``.
    The offsets used here only ever fall on token boundaries, so text inserted
    at the returned position sits directly after real SQL and before any
    whitespace or comment that follows it.
    """
    if index <= 0:
        return 0
    return min(tokens[index - 1].end + 1, len(sql))


def _has_loose_connector(node: exp.Expression | None) -> bool:
    """
    Tell whether ``node`` has a non-``AND`` connector (``OR``, ``XOR``) at the
    top level of its ``AND`` chain, i.e. not protected by parentheses.

    Such a condition cannot be combined with `` AND <predicate>`` textually:
    ``a OR b AND p`` means ``a OR (b AND p)``, which lets rows matching ``a``
    escape the predicate.
    """
    if isinstance(node, exp.And):
        return _has_loose_connector(node.this) or _has_loose_connector(
            node.expression
        )
    return isinstance(node, exp.Connector)


def _guard_predicate(predicate: str, dialect: str | None) -> str:
    """
    Return ``predicate`` wrapped in parentheses when joining it with ``AND``
    could change its meaning, or when it cannot be parsed to find out.
    """
    try:
        parsed = sqlglot.parse_one(predicate, dialect=dialect)
    except sqlglot.errors.SqlglotError:
        return f"({predicate})"
    return f"({predicate})" if _has_loose_connector(parsed) else predicate


def _table_from_node(
    node: exp.Table,
    catalog: str | None,
    schema: str | None,
) -> Table:
    """
    Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node, defaulting
    unqualified parts to the supplied catalog/schema.
    """
    # Imported lazily to avoid a circular import with ``superset.sql.parse``.
    from superset.sql.parse import Table

    return Table(
        table=node.name,
        schema=node.db if node.db else schema,
        catalog=node.catalog if node.catalog else catalog,
    )


def apply_rls_splice(
    sql: str,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
    dialect: str | None = None,
) -> str:
    """
    Inject RLS predicates into ``sql`` by splicing text at the right positions.

    :param sql: The original SQL query. Returned unchanged except at splice points.
    :param catalog: The default catalog for non-qualified table names.
    :param schema: The default schema for non-qualified table names.
    :param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
        maps a fully qualified table to one or more raw predicate strings to
        ``AND`` together when that table is referenced in a SELECT scope.
    :param dialect: The sqlglot dialect used for *parsing only* — to understand
        scope structure and locate token positions. The generator is never
        called, so this does not affect output formatting.
    :return: The query with RLS predicates injected into every relevant SELECT
        scope. Parentheses are added around a predicate, or around an existing
        ``WHERE`` condition, only when it contains an unparenthesized ``OR`` /
        ``XOR`` that would otherwise bind more loosely than the injected ``AND``.
    :raises ValueError: If a scope needs predicates but no safe splice position
        can be determined; the query is never returned partially filtered.
    """
    if not predicates or not any(predicates.values()):
        return sql

    resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
    tokens = list(resolved_dialect.tokenize(sql))
    tree = sqlglot.parse_one(sql, dialect=dialect)

    # Make every predicate safe to join with ``AND`` before it is spliced.
    guarded: dict[Table, list[str]] = {
        table: [_guard_predicate(pred, dialect) for pred in preds if pred]
        for table, preds in predicates.items()
    }

    splices: list[tuple[int, str]] = []
    for scope in traverse_scope(tree):
        splices.extend(
            _splice_for_scope(sql, tokens, scope, guarded, catalog, schema)
        )

    # Apply splices in reverse offset order so earlier positions stay valid.
    splices.sort(key=lambda item: item[0], reverse=True)
    result = sql
    for offset, text in splices:
        result = result[:offset] + text + result[offset:]
    return result


def _splice_for_scope(
    sql: str,
    tokens: list[Token],
    scope: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
) -> list[tuple[int, str]]:
    """
    Compute the ``(offset, text)`` splices for a single SELECT scope.

    :return: An empty list when no predicate applies to the scope; otherwise
        one splice, or two when the existing ``WHERE`` condition has to be
        wrapped in parentheses.
    :raises ValueError: If predicates apply but none of the scope's tables
        carries position metadata to anchor the splice on.
    """
    scope_preds = _collect_scope_predicates(scope, predicates, catalog, schema)
    if not scope_preds:
        return []

    # Anchor: rightmost character position among the table-name identifiers
    # directly owned by this scope. Used to skip past tokens that belong to
    # earlier parts of the query (projections, JOIN ON clauses, etc.).
    table_ends: list[int] = []
    for source in scope.sources.values():  # type: ignore[attr-defined]
        if not isinstance(source, exp.Table):
            continue
        ident = source.find(exp.Identifier)
        if ident is None:
            continue
        end = (getattr(ident, "_meta", None) or {}).get("end")
        if end is not None:
            table_ends.append(end)
    if not table_ends:
        # Returning nothing here would run the query without its RLS filter.
        raise ValueError(
            "Splice-mode RLS could not locate the table positions in the source "
            "SQL; refusing to return the query without its RLS predicates."
        )

    where = scope.expression.args.get("where")  # type: ignore[attr-defined]
    has_where = where is not None
    wrap_where = has_where and _has_loose_connector(where.this)
    pred_sql = " AND ".join(scope_preds)
    return _find_splice_point(
        sql, tokens, max(table_ends), has_where, pred_sql, wrap_where
    )


def _collect_scope_predicates(
    scope: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
) -> list[str]:
    """
    Collect the predicates that apply to direct Table sources in ``scope``,
    deduped while preserving order.
    """
    scope_preds: list[str] = []
    for source in scope.sources.values():  # type: ignore[attr-defined]
        if not isinstance(source, exp.Table):
            continue
        table = _table_from_node(source, catalog, schema)
        for predicate in predicates.get(table, []):
            if predicate and predicate not in scope_preds:
                scope_preds.append(predicate)
    return scope_preds


def _find_splice_point(
    sql: str,
    tokens: list[Token],
    anchor: int,
    has_where: bool,
    pred_sql: str,
    wrap_where: bool = False,
) -> list[tuple[int, str]]:
    """
    Scan tokens forward from ``anchor``, tracking paren depth, to find where to
    insert the RLS predicate for a single scope.

    :param anchor: Character position of the end of the scope's last table name.
    :param has_where: Whether the scope already has a ``WHERE`` clause.
    :param pred_sql: The predicate text, already safe to join with ``AND``.
    :param wrap_where: Whether the existing ``WHERE`` condition must be wrapped
        in parentheses before ``AND`` is appended.
    :raises ValueError: If ``wrap_where`` is set but no ``WHERE`` keyword is
        found, since appending without the wrap would weaken the filter.
    """
    keyword = "AND" if has_where else "WHERE"
    depth = 0
    stop = len(tokens)  # default: append after the last token
    for index, tok in enumerate(tokens):
        if tok.start <= anchor:
            continue

        kind = tok.token_type
        if kind == TokenType.L_PAREN:
            depth += 1
        elif kind == TokenType.R_PAREN:
            if depth == 0:
                # Closing paren of our subquery — insert just before it.
                stop = index
                break
            depth -= 1
        elif depth > 0:
            continue
        elif has_where and kind == TokenType.WHERE:
            return _find_after_where(sql, tokens, index, pred_sql, wrap_where)
        elif not has_where and kind in _CLAUSE_ENDS:
            # Insert WHERE before this clause keyword.
            stop = index
            break

    if wrap_where:
        raise ValueError(
            "Splice-mode RLS could not locate the WHERE clause to parenthesize; "
            "refusing to append a predicate that could be bypassed."
        )
    return [(_offset_after_token(sql, tokens, stop), f" {keyword} {pred_sql}")]


def _find_after_where(
    sql: str,
    tokens: list[Token],
    where_index: int,
    pred_sql: str,
    wrap_where: bool = False,
) -> list[tuple[int, str]]:
    """
    Given the index of a ``WHERE`` token in ``tokens``, build the splice that
    appends ``AND <pred_sql>`` right after the WHERE clause body.

    When ``wrap_where`` is true, a second splice opens a parenthesis before the
    first token of the body and the appended text closes it, producing
    ``WHERE (<body>) AND <pred_sql>``.
    """
    depth = 0
    stop = len(tokens)
    for index in range(where_index + 1, len(tokens)):
        kind = tokens[index].token_type
        if kind == TokenType.L_PAREN:
            depth += 1
        elif kind == TokenType.R_PAREN:
            if depth == 0:
                stop = index
                break
            depth -= 1
        elif depth == 0 and kind in _CLAUSE_ENDS:
            stop = index
            break

    close_offset = _offset_after_token(sql, tokens, stop)
    if not wrap_where:
        return [(close_offset, f" AND {pred_sql}")]

    if stop <= where_index + 1:
        raise ValueError(
            "Splice-mode RLS found an empty WHERE clause; "
            "refusing to append a predicate that could be bypassed."
        )
    open_offset = tokens[where_index + 1].start
    return [(open_offset, "("), (close_offset, f") AND {pred_sql}")]
""" - # There are two ways to insert RLS: either replacing the table with a subquery - # that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is - # safer, but not supported in all databases. + # There are three ways to insert RLS: + # - replace the table with a subquery containing the RLS (safest, but not + # supported in all databases) + # - append the RLS to the ``WHERE`` clause via AST transformation + # - splice the RLS into the original SQL string (preserves dialect-specific + # syntax that the sqlglot generator would otherwise transpile) method = database.db_engine_spec.get_rls_method() - # collect all RLS predicates for all tables in the query + # In splice mode predicates stay as raw SQL strings and are inserted verbatim + # into the source query — re-parsing them would force a generator round-trip + # later and defeat the purpose. + use_splice = method == RLSMethod.AS_PREDICATE_SPLICE predicates: dict[Table, list[Any]] = {} for table in parsed_statement.tables: table = table.qualify(catalog=catalog, schema=schema) - predicates[table] = [ - parsed_statement.parse_predicate(predicate) + raw_predicates = [ + predicate for predicate in get_predicates_for_table( table, database, @@ -58,6 +64,11 @@ def apply_rls( ) if predicate ] + predicates[table] = ( + raw_predicates + if use_splice + else [parsed_statement.parse_predicate(p) for p in raw_predicates] + ) has_predicates = any(predicates.values()) parsed_statement.apply_rls(catalog, schema, predicates, method) diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 5eae41458de..1b58809b3b2 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions 
def test_get_rls_method_default_subquery() -> None:
    """
    With no explicit ``rls_method``, a spec that allows both subqueries and
    aliases in ``SELECT`` resolves to the subquery strategy.
    """

    class SubqueryCapableSpec(BaseEngineSpec):
        allows_subqueries = True
        allows_alias_in_select = True

    chosen = SubqueryCapableSpec.get_rls_method()
    assert chosen == RLSMethod.AS_SUBQUERY


def test_get_rls_method_default_predicate_when_no_subqueries() -> None:
    """
    With no explicit ``rls_method``, a spec that disallows subqueries resolves
    to the AST predicate strategy.
    """

    class NoSubquerySpec(BaseEngineSpec):
        allows_subqueries = False
        allows_alias_in_select = True

    chosen = NoSubquerySpec.get_rls_method()
    assert chosen == RLSMethod.AS_PREDICATE


def test_get_rls_method_class_attribute_override() -> None:
    """
    An explicit ``rls_method`` on the spec wins over the subquery/alias
    defaults, even when those defaults would select the subquery strategy.
    """

    class ExplicitSpliceSpec(BaseEngineSpec):
        rls_method = RLSMethod.AS_PREDICATE_SPLICE
        allows_subqueries = True
        allows_alias_in_select = True

    chosen = ExplicitSpliceSpec.get_rls_method()
    assert chosen == RLSMethod.AS_PREDICATE_SPLICE
+ ( + "SELECT LAST_DAY(d) FROM some_table", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42", + ), + # Append to an existing WHERE clause. + ( + "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' " + "AND tenant_id = 42", + ), + # WHERE precedes GROUP BY: predicate goes before GROUP BY. + ( + "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' " + "AND tenant_id = 42 GROUP BY d", + ), + # No WHERE, but GROUP BY and ORDER BY are present. + ( + "SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 " + "GROUP BY d ORDER BY d", + ), + # JOIN — predicate scoped to one of the tables. + ( + "SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT o.id FROM some_table o JOIN locations l " + "ON o.loc_id = l.id WHERE tenant_id = 42", + ), + # JOIN — different predicate per table, both spliced into one WHERE. + ( + "SELECT * FROM some_table JOIN events ON some_table.id = events.order_id", + { + Table("some_table", "schema1", "catalog1"): "tenant_id = 42", + Table("events", "schema1", "catalog1"): "user_id = 99", + }, + "SELECT * FROM some_table JOIN events " + "ON some_table.id = events.order_id " + "WHERE tenant_id = 42 AND user_id = 99", + ), + # Subquery in FROM — splice into the inner SELECT. 
+ ( + "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table " + "WHERE tenant_id = 42) sub", + ), + # CTE — splice into the CTE body. + ( + "WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "WITH cte AS (SELECT LAST_DAY(d) FROM some_table " + "WHERE tenant_id = 42) SELECT * FROM cte", + ), + # Dialect-specific function (LAST_DAY) preserved verbatim. + ( + "SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT id, LAST_DAY(created_at) FROM some_table " + "WHERE region = 'US' AND tenant_id = 42", + ), + # Multiline + inline comment preserved exactly. + ( + "SELECT LAST_DAY(created_at) -- last day of month\n" + "FROM some_table\n" + "WHERE region = 'US'", + {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"}, + "SELECT LAST_DAY(created_at) -- last day of month\n" + "FROM some_table\n" + "WHERE region = 'US' AND tenant_id = 42", + ), + # Schema-qualified table name (no default schema match) — no predicate. + ( + "SELECT t.foo FROM schema2.some_table AS t", + {Table("some_table", "schema1", "catalog1"): "id = 42"}, + "SELECT t.foo FROM schema2.some_table AS t", + ), + ], +) +def test_rls_predicate_splice( + sql: str, + rules: dict[Table, str], + expected: str, +) -> None: + """ + Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``. + + Splice mode rewrites the original SQL string instead of re-rendering the + AST through the dialect generator, so byte-level fidelity (including + dialect-specific functions, comments, and whitespace) is preserved. 
def test_rls_predicate_splice_requires_source() -> None:
    """
    Splice mode requires the original SQL substring; constructing a statement
    purely from an AST should make splice mode raise.
    """
    ast = parse_one("SELECT * FROM some_table")
    statement = SQLStatement(ast=ast, engine="postgresql")
    with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
        statement.apply_rls(
            "catalog1",
            "schema1",
            {Table("some_table", "schema1", "catalog1"): ["id = 42"]},
            RLSMethod.AS_PREDICATE_SPLICE,
        )


def test_rls_predicate_splice_preserves_dialect_function() -> None:
    """
    Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
    on the postgres dialect would otherwise be transpiled by the generator.

    The full output is pinned rather than a substring, so the test also fails
    if the predicate is missing or misplaced.
    """
    sql = "SELECT LAST_DAY(d) FROM some_table"
    statement = SQLStatement(sql, engine="postgresql")
    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == (
        "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42"
    )


def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
    """
    Splice mode accepts predicate strings directly — no ``parse_predicate`` is
    needed at the call site.
    """
    sql = "SELECT * FROM some_table"
    statement = SQLStatement(sql, engine="postgresql")
    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42 AND active"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == (
        "SELECT * FROM some_table WHERE tenant_id = 42 AND active"
    )