mirror of
https://github.com/apache/superset.git
synced 2026-05-23 16:55:19 +00:00
feat: new splice RLSMethod
This commit is contained in:
@@ -2548,6 +2548,165 @@ def test_rls_predicate_transformer(
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # Bare SELECT: a WHERE clause has to be created for the predicate.
        (
            "SELECT LAST_DAY(d) FROM some_table",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42",
        ),
        # Existing WHERE clause: the predicate is AND-ed onto its end.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
            "AND tenant_id = 42",
        ),
        # WHERE followed by GROUP BY: the predicate lands ahead of GROUP BY.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' "
            "AND tenant_id = 42 GROUP BY d",
        ),
        # GROUP BY and ORDER BY without WHERE: new WHERE goes before both.
        (
            "SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE tenant_id = 42 "
            "GROUP BY d ORDER BY d",
        ),
        # JOIN where only one of the joined tables has a rule.
        (
            "SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT o.id FROM some_table o JOIN locations l "
            "ON o.loc_id = l.id WHERE tenant_id = 42",
        ),
        # JOIN where each table has its own rule: both end up in one WHERE.
        (
            "SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
            {
                Table("some_table", "schema1", "catalog1"): "tenant_id = 42",
                Table("events", "schema1", "catalog1"): "user_id = 99",
            },
            "SELECT * FROM some_table JOIN events "
            "ON some_table.id = events.order_id "
            "WHERE tenant_id = 42 AND user_id = 99",
        ),
        # Derived table in FROM: the predicate is placed in the inner SELECT.
        (
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
            "WHERE tenant_id = 42) sub",
        ),
        # CTE: the predicate is placed inside the CTE body.
        (
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
            "WHERE tenant_id = 42) SELECT * FROM cte",
        ),
        # A dialect-specific function call (LAST_DAY) is kept as written.
        (
            "SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT id, LAST_DAY(created_at) FROM some_table "
            "WHERE region = 'US' AND tenant_id = 42",
        ),
        # Line breaks and a trailing inline comment come through untouched.
        (
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "tenant_id = 42"},
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE region = 'US' AND tenant_id = 42",
        ),
        # Table qualified with a schema that has no rule: SQL is unchanged.
        (
            "SELECT t.foo FROM schema2.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "id = 42"},
            "SELECT t.foo FROM schema2.some_table AS t",
        ),
    ],
)
def test_rls_predicate_splice(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Exercise ``RLSMethod.AS_PREDICATE_SPLICE`` across a range of query shapes.

    In splice mode the predicate is inserted into the original SQL text rather
    than regenerating the query from the AST, so the surrounding text (dialect
    functions, comments, whitespace) is expected to come back exactly as given.

    Each case supplies a single predicate string per table; it is wrapped in a
    one-element list before being handed to ``apply_rls``, with ``catalog1`` /
    ``schema1`` as the default catalog and schema.
    """
    spliced = SQLStatement(sql)
    predicates_by_table = {table: [predicate] for table, predicate in rules.items()}
    spliced.apply_rls(
        "catalog1",
        "schema1",
        predicates_by_table,
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    assert spliced.format() == expected
|
||||
|
||||
|
||||
def test_rls_predicate_splice_requires_source() -> None:
    """
    Splice mode must refuse a statement that has no source SQL text.

    The statement here is built from a pre-parsed AST only, so there is no
    original string to splice into and ``apply_rls`` is expected to raise
    ``ValueError``.
    """
    tree = parse_one("SELECT * FROM some_table")
    ast_only = SQLStatement(ast=tree, engine="postgresql")
    rules = {Table("some_table", "schema1", "catalog1"): ["id = 42"]}

    # Keep only the call under test inside the ``raises`` context.
    with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
        ast_only.apply_rls("catalog1", "schema1", rules, RLSMethod.AS_PREDICATE_SPLICE)
|
||||
|
||||
|
||||
def test_rls_predicate_splice_preserves_dialect_function() -> None:
    """
    Splice mode must NOT round-trip through the sqlglot generator.

    ``LAST_DAY`` is used as the canary: the intent of this test is that, on the
    postgres dialect, regenerating the SQL from the AST would rewrite the call,
    whereas splice mode leaves the original text alone.

    The predicate itself is also checked so the test cannot pass when
    ``apply_rls`` leaves the statement untouched.
    """
    sql = "SELECT LAST_DAY(d) FROM some_table"
    statement = SQLStatement(sql, engine="postgresql")
    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    result = statement.format()
    # The dialect-specific call must appear exactly as written in the input.
    assert "LAST_DAY(d)" in result
    # Guard against a vacuous pass: the RLS predicate must have been applied.
    assert "tenant_id = 42" in result
|
||||
|
||||
|
||||
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
    """
    Splice mode takes raw predicate strings.

    The caller passes the predicate as plain text, without going through
    ``parse_predicate`` first, and it is expected to show up verbatim in the
    formatted statement.
    """
    stmt = SQLStatement("SELECT * FROM some_table", engine="postgresql")
    stmt.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42 AND active"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )

    expected = "SELECT * FROM some_table WHERE tenant_id = 42 AND active"
    assert stmt.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, table, expected",
|
||||
[
|
||||
|
||||
Reference in New Issue
Block a user