Increase parity

This commit is contained in:
Beto Dealmeida
2026-05-08 15:32:10 -04:00
parent e9dd9a6107
commit 884649e3ed
2 changed files with 348 additions and 107 deletions

View File

@@ -2563,89 +2563,142 @@ def test_rls_predicate_transformer(
assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # No existing WHERE: the rule is expected to become the WHERE clause.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42",
        ),
        # Existing OR condition: expected output puts the rule first and
        # wraps the original condition in parentheses.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
            "AND (bar = 'baz' OR foo = 'qux')",
        ),
        # Rule on the joined table: expected output places it in the ON
        # clause, ahead of the parenthesized original join condition.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id)",
        ),
        # JOIN without an ON clause: expected output adds one holding the rule.
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42",
        ),
        # JOIN plus an outer WHERE: the rule is still expected in the ON
        # clause and the WHERE clause is left as written.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
            "WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id) WHERE 1=1",
        ),
    ],
)
def test_rls_predicate_splice_semantics_match_predicate(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Splice mode should keep predicate-mode semantics.

    Covers boolean grouping of an existing condition and JOIN-vs-WHERE
    placement of the rule: each case applies its rules with
    ``RLSMethod.AS_PREDICATE_SPLICE`` and compares the formatted statement
    against ``expected``.
    """
    parsed = SQLStatement(sql)

    # ``apply_rls`` is handed a list of predicates per table, so wrap each
    # single rule string in a one-element list.
    predicates_by_table = {
        table: [predicate] for table, predicate in rules.items()
    }

    parsed.apply_rls(
        "catalog1",
        "schema1",
        predicates_by_table,
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    assert parsed.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
# Simple — no WHERE clause to extend.
(
"SELECT LAST_DAY(d) FROM some_table",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
),
# Append to an existing WHERE clause.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
"AND tenant_id = 42",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open')",
),
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
"AND tenant_id = 42 GROUP BY d",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
),
# No WHERE, but GROUP BY and ORDER BY are present.
(
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 "
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
"GROUP BY d ORDER BY d",
),
# JOIN — predicate scoped to one of the tables.
(
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
{Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
"SELECT o.id FROM some_table o JOIN locations l "
"ON o.loc_id = l.id WHERE tenant_id = 42",
"ON o.loc_id = l.id WHERE o.tenant_id = 42",
),
# JOIN — different predicate per table, both spliced into one WHERE.
(
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
{
Table("some_table", "schema1", "catalog1"): "tenant_id = 42",
Table("events", "schema1", "catalog1"): "user_id = 99",
Table("events", "schema1", "catalog1"): "events.user_id = 99",
Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
},
"SELECT * FROM some_table JOIN events "
"ON some_table.id = events.order_id "
"WHERE tenant_id = 42 AND user_id = 99",
"ON events.user_id = 99 AND (some_table.id = events.order_id) "
"WHERE some_table.tenant_id = 42",
),
# Subquery in FROM — splice into the inner SELECT.
(
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
"WHERE tenant_id = 42) sub",
"WHERE some_table.tenant_id = 42) sub",
),
# CTE — splice into the CTE body.
(
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
"WHERE tenant_id = 42) SELECT * FROM cte",
"WHERE some_table.tenant_id = 42) SELECT * FROM cte",
),
# Dialect-specific function (LAST_DAY) preserved verbatim.
(
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT id, LAST_DAY(created_at) FROM some_table "
"WHERE region = 'US' AND tenant_id = 42",
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Multiline + inline comment preserved exactly.
(
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US' AND tenant_id = 42",
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Schema-qualified table name (no default schema match) — no predicate.
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM schema2.some_table AS t",
),
],
@@ -2698,10 +2751,12 @@ def test_rls_predicate_splice_preserves_dialect_function() -> None:
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert "LAST_DAY(d)" in statement.format()
assert statement.format() == (
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
)
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
@@ -2714,11 +2769,11 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42 AND active"]},
{Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT * FROM some_table WHERE tenant_id = 42 AND active"
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
)
@@ -2727,11 +2782,13 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
[
(
"SELECT * FROM some_table -- hi\nGROUP BY id",
"SELECT * FROM some_table WHERE tenant_id = 42 -- hi\nGROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"-- hi\nGROUP BY id",
),
(
"SELECT * FROM some_table /* inline */ GROUP BY id",
"SELECT * FROM some_table WHERE tenant_id = 42 /* inline */ GROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"/* inline */ GROUP BY id",
),
],
)
@@ -2744,7 +2801,7 @@ def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@@ -2757,13 +2814,14 @@ def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -
"SELECT * FROM some_table QUALIFY row_number() OVER "
"(PARTITION BY id ORDER BY ts DESC) = 1",
"snowflake",
"SELECT * FROM some_table WHERE tenant_id = 42 "
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
),
(
"SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
"postgresql",
"SELECT sum(v) OVER () FROM some_table WHERE tenant_id = 42 "
"SELECT sum(v) OVER () FROM some_table "
"WHERE some_table.tenant_id = 42 "
"WINDOW w AS (PARTITION BY id)",
),
],
@@ -2781,7 +2839,7 @@ def test_rls_predicate_splice_handles_additional_clause_boundaries(
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@@ -2801,7 +2859,7 @@ def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
formatted = statement.format()
assert "tenant_id = 42" in formatted
assert "some_table.tenant_id = 42" in formatted
assert "LIMIT 101" in formatted