feat: new splice RLSMethod

This commit is contained in:
Beto Dealmeida
2026-05-08 13:33:42 -04:00
parent 5bde86785f
commit 3fe2b2505f
6 changed files with 615 additions and 14 deletions

View File

@@ -2548,6 +2548,165 @@ def test_rls_predicate_transformer(
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
# Simple — no WHERE clause to extend.
(
"SELECT LAST_DAY(d) FROM some_table",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42",
),
# Append to an existing WHERE clause.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
"AND tenant_id = 42",
),
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
"AND tenant_id = 42 GROUP BY d",
),
# No WHERE, but GROUP BY and ORDER BY are present.
(
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 "
"GROUP BY d ORDER BY d",
),
# JOIN — predicate scoped to one of the tables.
(
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT o.id FROM some_table o JOIN locations l "
"ON o.loc_id = l.id WHERE tenant_id = 42",
),
# JOIN — different predicate per table, both spliced into one WHERE.
(
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
{
Table("some_table", "schema1", "catalog1"): "tenant_id = 42",
Table("events", "schema1", "catalog1"): "user_id = 99",
},
"SELECT * FROM some_table JOIN events "
"ON some_table.id = events.order_id "
"WHERE tenant_id = 42 AND user_id = 99",
),
# Subquery in FROM — splice into the inner SELECT.
(
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
"WHERE tenant_id = 42) sub",
),
# CTE — splice into the CTE body.
(
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
"WHERE tenant_id = 42) SELECT * FROM cte",
),
# Dialect-specific function (LAST_DAY) preserved verbatim.
(
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT id, LAST_DAY(created_at) FROM some_table "
"WHERE region = 'US' AND tenant_id = 42",
),
# Multiline + inline comment preserved exactly.
(
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US' AND tenant_id = 42",
),
# Schema-qualified table name (no default schema match) — no predicate.
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"SELECT t.foo FROM schema2.some_table AS t",
),
],
)
def test_rls_predicate_splice(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.
Splice mode rewrites the original SQL string instead of re-rendering the
AST through the dialect generator, so byte-level fidelity (including
dialect-specific functions, comments, and whitespace) is preserved.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
def test_rls_predicate_splice_requires_source() -> None:
"""
Splice mode requires the original SQL substring; constructing a statement
purely from an AST should make splice mode raise.
"""
ast = parse_one("SELECT * FROM some_table")
statement = SQLStatement(ast=ast, engine="postgresql")
with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
statement.apply_rls(
"catalog1",
"schema1",
{Table("some_table", "schema1", "catalog1"): ["id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
def test_rls_predicate_splice_preserves_dialect_function() -> None:
"""
Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
on the postgres dialect would otherwise be transpiled by the generator.
"""
sql = "SELECT LAST_DAY(d) FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert "LAST_DAY(d)" in statement.format()
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
"""
Splice mode accepts predicate strings directly — no ``parse_predicate`` is
needed at the call site.
"""
sql = "SELECT * FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42 AND active"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT * FROM some_table WHERE tenant_id = 42 AND active"
)
@pytest.mark.parametrize(
"sql, table, expected",
[