mirror of
https://github.com/apache/superset.git
synced 2026-05-28 19:25:20 +00:00
Increase parity
This commit is contained in:
@@ -2563,89 +2563,142 @@ def test_rls_predicate_transformer(
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # No WHERE clause yet: the alias-qualified predicate becomes the WHERE.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42",
        ),
        # Existing OR condition: it is wrapped in parentheses and the RLS
        # predicate is AND-ed in front of it, so the OR cannot widen access.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
            "AND (bar = 'baz' OR foo = 'qux')",
        ),
        # Rule on the joined table: the predicate lands in the ON clause and
        # the original join condition is parenthesized after it.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id)",
        ),
        # JOIN without an ON clause: one is created holding only the predicate.
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42",
        ),
        # JOIN with ON plus an outer WHERE: the predicate still goes into the
        # ON clause and the existing WHERE is left as written.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
            "WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id) WHERE 1=1",
        ),
    ],
)
def test_rls_predicate_splice_semantics_match_predicate(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Check that splice mode keeps the semantics of predicate mode.

    Each case applies a single RLS predicate per table with
    ``RLSMethod.AS_PREDICATE_SPLICE`` and compares the formatted statement
    against the expected SQL, covering boolean grouping of pre-existing
    conditions and whether the predicate is placed in a JOIN's ON clause or in
    the WHERE clause.
    """
    # ``apply_rls`` is given a list of predicates per table, so wrap each rule.
    predicates = {table: [predicate] for table, predicate in rules.items()}

    parsed = SQLStatement(sql)
    parsed.apply_rls(
        "catalog1",
        "schema1",
        predicates,
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    assert parsed.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
# Simple — no WHERE clause to extend.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
|
||||
),
|
||||
# Append to an existing WHERE clause.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
|
||||
"AND tenant_id = 42",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 AND (status = 'open')",
|
||||
),
|
||||
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
|
||||
"AND tenant_id = 42 GROUP BY d",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
|
||||
),
|
||||
# No WHERE, but GROUP BY and ORDER BY are present.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 "
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"GROUP BY d ORDER BY d",
|
||||
),
|
||||
# JOIN — predicate scoped to one of the tables.
|
||||
(
|
||||
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
|
||||
"SELECT o.id FROM some_table o JOIN locations l "
|
||||
"ON o.loc_id = l.id WHERE tenant_id = 42",
|
||||
"ON o.loc_id = l.id WHERE o.tenant_id = 42",
|
||||
),
|
||||
# JOIN — different predicate per table, both spliced into one WHERE.
|
||||
(
|
||||
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
|
||||
{
|
||||
Table("some_table", "schema1", "catalog1"): "tenant_id = 42",
|
||||
Table("events", "schema1", "catalog1"): "user_id = 99",
|
||||
Table("events", "schema1", "catalog1"): "events.user_id = 99",
|
||||
Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
|
||||
},
|
||||
"SELECT * FROM some_table JOIN events "
|
||||
"ON some_table.id = events.order_id "
|
||||
"WHERE tenant_id = 42 AND user_id = 99",
|
||||
"ON events.user_id = 99 AND (some_table.id = events.order_id) "
|
||||
"WHERE some_table.tenant_id = 42",
|
||||
),
|
||||
# Subquery in FROM — splice into the inner SELECT.
|
||||
(
|
||||
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
|
||||
"WHERE tenant_id = 42) sub",
|
||||
"WHERE some_table.tenant_id = 42) sub",
|
||||
),
|
||||
# CTE — splice into the CTE body.
|
||||
(
|
||||
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE tenant_id = 42) SELECT * FROM cte",
|
||||
"WHERE some_table.tenant_id = 42) SELECT * FROM cte",
|
||||
),
|
||||
# Dialect-specific function (LAST_DAY) preserved verbatim.
|
||||
(
|
||||
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT id, LAST_DAY(created_at) FROM some_table "
|
||||
"WHERE region = 'US' AND tenant_id = 42",
|
||||
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
|
||||
),
|
||||
# Multiline + inline comment preserved exactly.
|
||||
(
|
||||
"SELECT LAST_DAY(created_at) -- last day of month\n"
|
||||
"FROM some_table\n"
|
||||
"WHERE region = 'US'",
|
||||
{Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(created_at) -- last day of month\n"
|
||||
"FROM some_table\n"
|
||||
"WHERE region = 'US' AND tenant_id = 42",
|
||||
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
|
||||
),
|
||||
# Schema-qualified table name (no default schema match) — no predicate.
|
||||
(
|
||||
"SELECT t.foo FROM schema2.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
|
||||
"SELECT t.foo FROM schema2.some_table AS t",
|
||||
),
|
||||
],
|
||||
@@ -2698,10 +2751,12 @@ def test_rls_predicate_splice_preserves_dialect_function() -> None:
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["tenant_id = 42"]},
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert "LAST_DAY(d)" in statement.format()
|
||||
assert statement.format() == (
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
|
||||
)
|
||||
|
||||
|
||||
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
|
||||
@@ -2714,11 +2769,11 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["tenant_id = 42 AND active"]},
|
||||
{Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == (
|
||||
"SELECT * FROM some_table WHERE tenant_id = 42 AND active"
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
|
||||
)
|
||||
|
||||
|
||||
@@ -2727,11 +2782,13 @@ def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
|
||||
[
|
||||
(
|
||||
"SELECT * FROM some_table -- hi\nGROUP BY id",
|
||||
"SELECT * FROM some_table WHERE tenant_id = 42 -- hi\nGROUP BY id",
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"-- hi\nGROUP BY id",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table /* inline */ GROUP BY id",
|
||||
"SELECT * FROM some_table WHERE tenant_id = 42 /* inline */ GROUP BY id",
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"/* inline */ GROUP BY id",
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -2744,7 +2801,7 @@ def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["tenant_id = 42"]},
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
@@ -2757,13 +2814,14 @@ def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -
|
||||
"SELECT * FROM some_table QUALIFY row_number() OVER "
|
||||
"(PARTITION BY id ORDER BY ts DESC) = 1",
|
||||
"snowflake",
|
||||
"SELECT * FROM some_table WHERE tenant_id = 42 "
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
|
||||
),
|
||||
(
|
||||
"SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
|
||||
"postgresql",
|
||||
"SELECT sum(v) OVER () FROM some_table WHERE tenant_id = 42 "
|
||||
"SELECT sum(v) OVER () FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 "
|
||||
"WINDOW w AS (PARTITION BY id)",
|
||||
),
|
||||
],
|
||||
@@ -2781,7 +2839,7 @@ def test_rls_predicate_splice_handles_additional_clause_boundaries(
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["tenant_id = 42"]},
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
@@ -2801,7 +2859,7 @@ def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
|
||||
statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
|
||||
|
||||
formatted = statement.format()
|
||||
assert "tenant_id = 42" in formatted
|
||||
assert "some_table.tenant_id = 42" in formatted
|
||||
assert "LIMIT 101" in formatted
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user