mirror of
https://github.com/apache/superset.git
synced 2026-06-05 07:39:19 +00:00
feat: implement RLS in sqlglot (#33524)
This commit is contained in:
@@ -55,7 +55,7 @@ SQLGLOT_DIALECTS = {
|
||||
# "db2": ???
|
||||
# "dremio": ???
|
||||
"drill": Dialects.DRILL,
|
||||
# "druid": ???
|
||||
"druid": Dialects.DRUID,
|
||||
"duckdb": Dialects.DUCKDB,
|
||||
# "dynamodb": ???
|
||||
# "elasticsearch": ???
|
||||
@@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
class RLSMethod(enum.Enum):
    """
    Methods for enforcing RLS.
    """

    # Add the rule to the WHERE/ON clause of the query reading the table
    # (see `RLSAsPredicateTransformer`).
    AS_PREDICATE = enum.auto()
    # Replace the table with a filtered subquery
    # (see `RLSAsSubqueryTransformer`).
    AS_SUBQUERY = enum.auto()
|
||||
|
||||
|
||||
class RLSTransformer:
    """
    AST transformer to apply RLS rules.

    Base class that stores the rules and resolves which of them apply to a
    given table node. Subclasses implement ``__call__`` so that an instance can
    be passed to ``Expression.transform``.
    """

    def __init__(
        self,
        catalog: str | None,
        schema: str | None,
        rules: dict[Table, list[exp.Expression]],
    ) -> None:
        """
        :param catalog: The default catalog for non-qualified table names
        :param schema: The default schema for non-qualified table names
        :param rules: RLS predicates keyed by the fully qualified table they
            apply to
        """
        self.catalog = catalog
        self.schema = schema
        self.rules = rules

    def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
        """
        Get the combined RLS predicate for a table.

        Missing schema/catalog parts of the table reference are filled in with
        the defaults given to the constructor before the rules are looked up.

        :param table_node: The table reference found in the query
        :return: A fresh expression ANDing every rule for the table, or `None`
            when no rule applies. The result is a copy, so callers may modify
            it without affecting the stored rules.
        """
        table = Table(
            table_node.name,
            table_node.db or self.schema,
            table_node.catalog or self.catalog,
        )
        predicates = self.rules.get(table)
        if not predicates:
            return None

        # `exp.And` is a binary node (`this`/`expression`), so the rules must
        # be folded pairwise. `exp.and_` does that, copies its operands and
        # parenthesizes nested connectors so that e.g. an `OR` inside one rule
        # cannot change the meaning of the combined predicate.
        return exp.and_(*predicates)
|
||||
|
||||
|
||||
class RLSAsPredicateTransformer(RLSTransformer):
    """
    Apply Row Level Security rule as a predicate.

    This transformer will apply any RLS predicates to the relevant tables. For example,
    given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM some_table WHERE some_table.id = 42 AND (bar = 'baz')

    Tables in a `FROM` get the predicate added to the `WHERE` of the enclosing
    query; tables in a `JOIN` get it added to the `ON` clause. Table nodes with
    any other parent are returned unchanged.

    This approach is probably less secure than using subqueries, so it's only used for
    databases without support for subqueries.
    """

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Add the RLS predicate for `node` to its enclosing query, if any applies.

        :param node: Any node visited while transforming the statement
        :return: The same node; the modification is done on its ancestors
        """
        if not isinstance(node, exp.Table):
            return node

        predicate = self.get_predicate(node)
        if predicate is None:
            return node

        # Work on a private copy: the predicate is modified below and then
        # attached to the tree, and the same rule may apply to several table
        # references (self joins, unions) in one statement.
        predicate = predicate.copy()
        self._qualify_columns(predicate, node)

        if isinstance(node.parent, exp.From):
            select = node.parent.parent
            where = select.args.get("where")
            existing = where.this if where is not None else None
            select.set("where", exp.Where(this=self._combine(predicate, existing)))

        elif isinstance(node.parent, exp.Join):
            join = node.parent
            join.set("on", self._combine(predicate, join.args.get("on")))

        return node

    @staticmethod
    def _qualify_columns(predicate: exp.Expression, node: exp.Table) -> None:
        """
        Qualify every column in `predicate` with the alias or name of `node`.
        """
        alias = node.args.get("alias")
        if isinstance(alias, exp.TableAlias) and isinstance(
            alias.this, exp.Identifier
        ):
            # reuse the identifier node so that quoting is preserved
            qualifier = alias.this
        else:
            qualifier = node.alias or node.this

        for column in predicate.find_all(exp.Column):
            # each column gets its own copy, so no node ends up with two parents
            column.set(
                "table",
                qualifier.copy() if isinstance(qualifier, exp.Expression) else qualifier,
            )

    @staticmethod
    def _combine(
        predicate: exp.Expression,
        existing: exp.Expression | None,
    ) -> exp.Expression:
        """
        AND the RLS predicate with an existing condition, if there is one.
        """
        if existing is None:
            return predicate

        # Parenthesize a rule whose top level is a connector (e.g. `a OR b`);
        # otherwise `a OR b AND (<existing>)` would let rows matching `a` through.
        if isinstance(predicate, exp.Connector):
            predicate = exp.Paren(this=predicate)

        return exp.And(this=predicate, expression=exp.Paren(this=existing))
|
||||
|
||||
|
||||
class RLSAsSubqueryTransformer(RLSTransformer):
    """
    Apply Row Level Security rule as a subquery.

    This transformer will apply any RLS predicates to the relevant tables. For example,
    given the RLS rule:

        table: some_table
        clause: id = 42

    If a user subject to the rule runs the following query:

        SELECT foo FROM some_table WHERE bar = 'baz'

    The query will be modified to:

        SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
        WHERE bar = 'baz'

    The subquery is aliased with the alias of the original table reference or,
    when there is none, with the bare table name (without schema or catalog).

    This approach is probably more secure than using predicates, but it doesn't work for
    all databases.
    """

    def __call__(self, node: exp.Expression) -> exp.Expression:
        """
        Replace a table subject to RLS with a filtered subquery.

        :param node: Any node visited while transforming the statement
        :return: The subquery replacing the table, or `node` itself when it is
            not a table or no rule applies to it
        """
        if not isinstance(node, exp.Table):
            return node

        predicate = self.get_predicate(node)
        if predicate is None:
            return node

        # the alias moves from the table to the subquery wrapping it
        inner_table = node.copy()
        inner_table.set("alias", None)

        return exp.Subquery(
            this=exp.Select(
                expressions=[exp.Star()],
                # copy, since the same rule may be attached in several places
                where=exp.Where(this=predicate.copy()),
                **{"from": exp.From(this=inner_table)},
            ),
            alias=self._get_alias(node),
        )

    @staticmethod
    def _get_alias(node: exp.Table) -> exp.Expression | str:
        """
        Build the alias under which the subquery replaces `node`.
        """
        alias = node.args.get("alias")
        if isinstance(alias, exp.TableAlias) and alias.args.get("this") is not None:
            # keep the original alias node, preserving quoting and column aliases
            return alias.copy()
        if node.alias:
            return node.alias

        # A dotted name such as `schema1.some_table` is not a valid alias, so
        # only the table identifier itself is used (quoting is preserved).
        if isinstance(node.this, exp.Identifier):
            return exp.TableAlias(this=node.this.copy())

        return node.sql()
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
"""
|
||||
@@ -173,7 +317,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
elif statement:
|
||||
self._parsed = self._parse_statement(statement, engine)
|
||||
else:
|
||||
raise SupersetParseError("Either statement or ast must be provided")
|
||||
raise ValueError("Either statement or ast must be provided")
|
||||
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
@@ -293,6 +437,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[InternalRepresentation]],
    method: RLSMethod,
) -> None:
    """
    Apply relevant RLS rules to the statement inplace.

    This base implementation always raises; subclasses must override it.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: The RLS predicates to apply, keyed by table, each in the
        statement's internal representation
    :param method: The method to use for applying the rules.
    :raises NotImplementedError: Always, in this base class
    """
    raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
@@ -573,6 +733,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
engine=self.engine,
|
||||
)
|
||||
|
||||
def apply_rls(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[exp.Expression]],
    method: RLSMethod,
) -> None:
    """
    Rewrite the parsed statement so that the given RLS rules are enforced.

    The statement is modified inplace: the parsed tree is replaced by the
    transformed one.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: The RLS predicates to apply, keyed by table
    :param method: The method to use for applying the rules.
    :raises ValueError: If `method` has no associated transformer
    """
    transformer_classes = {
        RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
        RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
    }
    transformer_class = transformer_classes.get(method)
    if transformer_class is None:
        raise ValueError(f"Invalid RLS method: {method}")

    rls_transformer = transformer_class(catalog, schema, predicates)
    self._parsed = self._parsed.transform(rls_transformer)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -966,7 +1150,7 @@ def extract_tables_from_statement(
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
||||
Please not that this is not trivial; consider the following queries:
|
||||
Please note that this is not trivial; consider the following queries:
|
||||
|
||||
DESCRIBE some_table;
|
||||
SHOW PARTITIONS FROM some_table;
|
||||
|
||||
@@ -18,13 +18,14 @@
|
||||
|
||||
|
||||
import pytest
|
||||
from sqlglot import Dialects
|
||||
from sqlglot import Dialects, parse_one
|
||||
|
||||
from superset.exceptions import SupersetParseError
|
||||
from superset.sql.parse import (
|
||||
extract_tables_from_statement,
|
||||
KustoKQLStatement,
|
||||
LimitMethod,
|
||||
RLSMethod,
|
||||
split_kql,
|
||||
SQLGLOT_DIALECTS,
|
||||
SQLScript,
|
||||
@@ -303,11 +304,13 @@ def test_format_no_dialect() -> None:
|
||||
"""
|
||||
assert (
|
||||
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
|
||||
== """SELECT
|
||||
== """
|
||||
SELECT
|
||||
col
|
||||
FROM t
|
||||
WHERE
|
||||
NOT col IN (1, 2)"""
|
||||
NOT col IN (1, 2)
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -1118,7 +1121,8 @@ FROM some_table) AS anon_1
|
||||
WHERE anon_1.a > 1 AND anon_1.b = 2
|
||||
"""
|
||||
|
||||
optimized = """SELECT
|
||||
optimized = """
|
||||
SELECT
|
||||
anon_1.a,
|
||||
anon_1.b
|
||||
FROM (
|
||||
@@ -1131,9 +1135,11 @@ FROM (
|
||||
some_table.a > 1 AND some_table.b = 2
|
||||
) AS anon_1
|
||||
WHERE
|
||||
TRUE AND TRUE"""
|
||||
TRUE AND TRUE
|
||||
""".strip()
|
||||
|
||||
not_optimized = """SELECT
|
||||
not_optimized = """
|
||||
SELECT
|
||||
anon_1.a,
|
||||
anon_1.b
|
||||
FROM (
|
||||
@@ -1144,7 +1150,8 @@ FROM (
|
||||
FROM some_table
|
||||
) AS anon_1
|
||||
WHERE
|
||||
anon_1.a > 1 AND anon_1.b = 2"""
|
||||
anon_1.a > 1 AND anon_1.b = 2
|
||||
""".strip()
|
||||
|
||||
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
|
||||
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
|
||||
@@ -1195,9 +1202,11 @@ def test_firebolt_old() -> None:
|
||||
sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
|
||||
assert (
|
||||
SQLStatement(sql, "firebolt").format()
|
||||
== """SELECT
|
||||
== """
|
||||
SELECT
|
||||
*
|
||||
FROM t1 UNNEST(col1 AS foo)"""
|
||||
FROM t1 UNNEST(col1 AS foo)
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -1216,9 +1225,11 @@ def test_firebolt_old_escape_string() -> None:
|
||||
# but they normalize to ''
|
||||
assert (
|
||||
SQLStatement(sql, "firebolt").format()
|
||||
== """SELECT
|
||||
== """
|
||||
SELECT
|
||||
'foo''bar',
|
||||
'foo''bar'"""
|
||||
'foo''bar'
|
||||
""".strip()
|
||||
)
|
||||
|
||||
|
||||
@@ -1410,7 +1421,8 @@ select TOP 100 * from currency
|
||||
"mssql",
|
||||
1000,
|
||||
LimitMethod.FORCE_LIMIT,
|
||||
"""WITH abc AS (
|
||||
"""
|
||||
WITH abc AS (
|
||||
SELECT
|
||||
*
|
||||
FROM test
|
||||
@@ -1422,7 +1434,8 @@ select TOP 100 * from currency
|
||||
SELECT
|
||||
TOP 1000
|
||||
*
|
||||
FROM currency""",
|
||||
FROM currency
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT DISTINCT x from tbl",
|
||||
@@ -1457,10 +1470,12 @@ FROM currency""",
|
||||
"postgresql",
|
||||
1000,
|
||||
LimitMethod.FORCE_LIMIT,
|
||||
"""SELECT
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
|
||||
LIMIT 1000""",
|
||||
LIMIT 1000
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM birth_names LIMIT 555",
|
||||
@@ -1602,7 +1617,8 @@ UNION ALL
|
||||
SELECT * FROM currency_2
|
||||
""",
|
||||
"postgresql",
|
||||
"""WITH currency AS (
|
||||
"""
|
||||
WITH currency AS (
|
||||
SELECT
|
||||
'INR' AS cur
|
||||
), currency_2 AS (
|
||||
@@ -1616,7 +1632,8 @@ SELECT * FROM currency_2
|
||||
SELECT
|
||||
*
|
||||
FROM currency_2
|
||||
)""",
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -1625,3 +1642,608 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None:
|
||||
Test that we can covert select to CTE.
|
||||
"""
|
||||
assert SQLStatement(sql, engine).as_cte().format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS t
|
||||
WHERE
|
||||
bar = 'baz'
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM schema1.some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM schema1.some_table AS t",
|
||||
{Table("some_table", "schema2"): "id = 42"},
|
||||
"SELECT\n t.foo\nFROM schema1.some_table AS t",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM catalog1.schema1.some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog2"): "id = 42"},
|
||||
"SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table WHERE 1=1",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS some_table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
'SELECT * FROM "table" WHERE 1=1',
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM "table"
|
||||
WHERE
|
||||
id = 42
|
||||
) AS "table"
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM other_table WHERE 1=1",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
JOIN (
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS other_table
|
||||
ON table.id = other_table.id
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM "table"
|
||||
WHERE
|
||||
id = 42
|
||||
) AS "table"
|
||||
JOIN other_table
|
||||
ON "table".id = other_table.id
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM (SELECT * FROM some_table)",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS some_table
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS table
|
||||
UNION ALL
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
UNION ALL
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
id = 42
|
||||
) AS other_table
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
|
||||
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
a.*,
|
||||
b.*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM tbl_a
|
||||
WHERE
|
||||
id = 42
|
||||
) AS a
|
||||
INNER JOIN tbl_b AS b
|
||||
ON a.col = b.col
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
|
||||
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
a.*,
|
||||
b.*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM tbl_a
|
||||
WHERE
|
||||
id = 42
|
||||
) AS a
|
||||
INNER JOIN tbl_b AS b
|
||||
ON a.col = b.col
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_subquery_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Check the SQL produced by `RLSAsSubqueryTransformer`.

    Each rule clause is parsed into an expression and applied with `catalog1`
    and `schema1` as the defaults for non-qualified table names.
    """
    statement = SQLStatement(sql)
    parsed_rules = {table: [parse_one(clause)] for table, clause in rules.items()}

    statement.apply_rls("catalog1", "schema1", parsed_rules, RLSMethod.AS_SUBQUERY)

    assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM some_table AS t
|
||||
WHERE
|
||||
t.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM schema2.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM schema2.some_table AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM catalog2.schema1.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM catalog2.schema1.some_table AS t
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM some_table AS t
|
||||
WHERE
|
||||
t.id = 42 AND (
|
||||
bar = 'baz'
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
t.foo
|
||||
FROM some_table AS t
|
||||
WHERE
|
||||
t.id = 42 AND (
|
||||
bar = 'baz' OR foo = 'qux'
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table WHERE 1=1",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
some_table.id = 42 AND (
|
||||
1 = 1
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table WHERE TRUE OR FALSE",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
some_table.id = 42 AND (
|
||||
TRUE OR FALSE
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42 AND (
|
||||
1 = 1
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
'SELECT * FROM "table" WHERE 1=1',
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM "table"
|
||||
WHERE
|
||||
"table".id = 42 AND (
|
||||
1 = 1
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM other_table WHERE 1=1",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table",
|
||||
{Table("some_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM some_table
|
||||
WHERE
|
||||
some_table.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table ORDER BY id",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42
|
||||
ORDER BY
|
||||
id
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1 AND table.id=42",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42 AND (
|
||||
1 = 1 AND table.id = 42
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"""
|
||||
SELECT * FROM table
|
||||
JOIN other_table
|
||||
ON table.id = other_table.id
|
||||
AND other_table.id=42
|
||||
""",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
JOIN other_table
|
||||
ON other_table.id = 42 AND (
|
||||
table.id = other_table.id AND other_table.id = 42
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table WHERE 1=1 AND id=42",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42 AND (
|
||||
1 = 1 AND id = 42
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
JOIN other_table
|
||||
ON other_table.id = 42 AND (
|
||||
table.id = other_table.id
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"""
|
||||
SELECT *
|
||||
FROM table
|
||||
JOIN other_table
|
||||
ON table.id = other_table.id
|
||||
WHERE 1=1
|
||||
""",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
JOIN other_table
|
||||
ON other_table.id = 42 AND (
|
||||
table.id = other_table.id
|
||||
)
|
||||
WHERE
|
||||
1 = 1
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM (SELECT * FROM other_table)",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM (
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
other_table.id = 42
|
||||
)
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
{Table("table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
WHERE
|
||||
table.id = 42
|
||||
UNION ALL
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
""".strip(),
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
|
||||
{Table("other_table", "schema1", "catalog1"): "id = 42"},
|
||||
"""
|
||||
SELECT
|
||||
*
|
||||
FROM table
|
||||
UNION ALL
|
||||
SELECT
|
||||
*
|
||||
FROM other_table
|
||||
WHERE
|
||||
other_table.id = 42
|
||||
""".strip(),
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_transformer(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Test `RLSAsPredicateTransformer`.

    Each rule clause is parsed into an expression and applied with `catalog1`
    and `schema1` as the defaults for non-qualified table names.
    """
    statement = SQLStatement(sql)
    parsed_rules = {table: [parse_one(clause)] for table, clause in rules.items()}

    statement.apply_rls("catalog1", "schema1", parsed_rules, RLSMethod.AS_PREDICATE)

    assert statement.format() == expected
|
||||
|
||||
Reference in New Issue
Block a user