feat: implement RLS in sqlglot (#33524)

This commit is contained in:
Beto Dealmeida
2025-05-28 09:10:45 -04:00
committed by GitHub
parent e205846845
commit 0abe6eed89
2 changed files with 826 additions and 20 deletions

View File

@@ -55,7 +55,7 @@ SQLGLOT_DIALECTS = {
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"druid": Dialects.DRUID,
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
@@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class RLSMethod(enum.Enum):
"""
Methods for enforcing RLS.
"""
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
class RLSTransformer:
"""
AST transformer to apply RLS rules.
"""
def __init__(
self,
catalog: str | None,
schema: str | None,
rules: dict[Table, list[exp.Expression]],
) -> None:
self.catalog = catalog
self.schema = schema
self.rules = rules
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
"""
Get the combined RLS predicate for a table.
"""
table = Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
)
if predicates := self.rules.get(table):
return (
exp.And(
this=predicates[0],
expressions=predicates[1:],
)
if len(predicates) > 1
else predicates[0]
)
return None
class RLSAsPredicateTransformer(RLSTransformer):
"""
Apply Row Level Security role as a predicate.
This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:
table: some_table
clause: id = 42
If a user subject to the rule runs the following query:
SELECT foo FROM some_table WHERE bar = 'baz'
The query will be modified to:
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
This approach is probably less secure than using subqueries, so it's only used for
databases without support for subqueries.
"""
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node
predicate = self.get_predicate(node)
if not predicate:
return node
# qualify columns with table name
for column in predicate.find_all(exp.Column):
column.set("table", node.alias or node.this)
if isinstance(node.parent, exp.From):
select = node.parent.parent
if where := select.args.get("where"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=where.this),
)
select.set("where", exp.Where(this=predicate))
elif isinstance(node.parent, exp.Join):
join = node.parent
if on := join.args.get("on"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=on),
)
join.set("on", predicate)
return node
class RLSAsSubqueryTransformer(RLSTransformer):
"""
Apply Row Level Security role as a subquery.
This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:
table: some_table
clause: id = 42
If a user subject to the rule runs the following query:
SELECT foo FROM some_table WHERE bar = 'baz'
The query will be modified to:
SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
WHERE bar = 'baz'
This approach is probably more secure than using predicates, but it doesn't work for
all databases.
"""
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node
if predicate := self.get_predicate(node):
# use alias or name
alias = node.alias or node.sql()
node.set("alias", None)
node = exp.Subquery(
this=exp.Select(
expressions=[exp.Star()],
where=exp.Where(this=predicate),
**{"from": exp.From(this=node.copy())},
),
alias=alias,
)
return node
@dataclass(eq=True, frozen=True)
class Table:
"""
@@ -173,7 +317,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
elif statement:
self._parsed = self._parse_statement(statement, engine)
else:
raise SupersetParseError("Either statement or ast must be provided")
raise ValueError("Either statement or ast must be provided")
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@@ -293,6 +437,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
@@ -573,6 +733,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
engine=self.engine,
)
def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
}
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
class KQLSplitState(enum.Enum):
"""
@@ -966,7 +1150,7 @@ def extract_tables_from_statement(
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
Please note that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;

View File

@@ -18,13 +18,14 @@
import pytest
from sqlglot import Dialects
from sqlglot import Dialects, parse_one
from superset.exceptions import SupersetParseError
from superset.sql.parse import (
extract_tables_from_statement,
KustoKQLStatement,
LimitMethod,
RLSMethod,
split_kql,
SQLGLOT_DIALECTS,
SQLScript,
@@ -303,11 +304,13 @@ def test_format_no_dialect() -> None:
"""
assert (
SQLScript("SELECT col FROM t WHERE col NOT IN (1, 2)", "dremio").format()
== """SELECT
== """
SELECT
col
FROM t
WHERE
NOT col IN (1, 2)"""
NOT col IN (1, 2)
""".strip()
)
@@ -1118,7 +1121,8 @@ FROM some_table) AS anon_1
WHERE anon_1.a > 1 AND anon_1.b = 2
"""
optimized = """SELECT
optimized = """
SELECT
anon_1.a,
anon_1.b
FROM (
@@ -1131,9 +1135,11 @@ FROM (
some_table.a > 1 AND some_table.b = 2
) AS anon_1
WHERE
TRUE AND TRUE"""
TRUE AND TRUE
""".strip()
not_optimized = """SELECT
not_optimized = """
SELECT
anon_1.a,
anon_1.b
FROM (
@@ -1144,7 +1150,8 @@ FROM (
FROM some_table
) AS anon_1
WHERE
anon_1.a > 1 AND anon_1.b = 2"""
anon_1.a > 1 AND anon_1.b = 2
""".strip()
assert SQLStatement(sql, "sqlite").optimize().format() == optimized
assert SQLStatement(sql, "dremio").optimize().format() == not_optimized
@@ -1195,9 +1202,11 @@ def test_firebolt_old() -> None:
sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"
assert (
SQLStatement(sql, "firebolt").format()
== """SELECT
== """
SELECT
*
FROM t1 UNNEST(col1 AS foo)"""
FROM t1 UNNEST(col1 AS foo)
""".strip()
)
@@ -1216,9 +1225,11 @@ def test_firebolt_old_escape_string() -> None:
# but they normalize to ''
assert (
SQLStatement(sql, "firebolt").format()
== """SELECT
== """
SELECT
'foo''bar',
'foo''bar'"""
'foo''bar'
""".strip()
)
@@ -1410,7 +1421,8 @@ select TOP 100 * from currency
"mssql",
1000,
LimitMethod.FORCE_LIMIT,
"""WITH abc AS (
"""
WITH abc AS (
SELECT
*
FROM test
@@ -1422,7 +1434,8 @@ select TOP 100 * from currency
SELECT
TOP 1000
*
FROM currency""",
FROM currency
""".strip(),
),
(
"SELECT DISTINCT x from tbl",
@@ -1457,10 +1470,12 @@ FROM currency""",
"postgresql",
1000,
LimitMethod.FORCE_LIMIT,
"""SELECT
"""
SELECT
*
FROM birth_names /* SOME COMMENT WITH LIMIT 555 */
LIMIT 1000""",
LIMIT 1000
""".strip(),
),
(
"SELECT * FROM birth_names LIMIT 555",
@@ -1602,7 +1617,8 @@ UNION ALL
SELECT * FROM currency_2
""",
"postgresql",
"""WITH currency AS (
"""
WITH currency AS (
SELECT
'INR' AS cur
), currency_2 AS (
@@ -1616,7 +1632,8 @@ SELECT * FROM currency_2
SELECT
*
FROM currency_2
)""",
)
""".strip(),
),
],
)
@@ -1625,3 +1642,608 @@ def test_as_cte(sql: str, engine: str, expected: str) -> None:
Test that we can covert select to CTE.
"""
assert SQLStatement(sql, engine).as_cte().format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS t
WHERE
bar = 'baz'
""".strip(),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM schema1.some_table AS t",
{Table("some_table", "schema2"): "id = 42"},
"SELECT\n t.foo\nFROM schema1.some_table AS t",
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM (
SELECT
*
FROM catalog1.schema1.some_table
WHERE
id = 42
) AS t
""".strip(),
),
(
"SELECT t.foo FROM catalog1.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog2"): "id = 42"},
"SELECT\n t.foo\nFROM catalog1.schema1.some_table AS t",
),
(
"SELECT * FROM some_table WHERE 1=1",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS some_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS table
WHERE
1 = 1
""".strip(),
),
(
'SELECT * FROM "table" WHERE 1=1',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM other_table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN (
SELECT
*
FROM other_table
WHERE
id = 42
) AS other_table
ON table.id = other_table.id
""".strip(),
),
(
'SELECT * FROM "table" JOIN other_table ON "table".id = other_table.id',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM "table"
WHERE
id = 42
) AS "table"
JOIN other_table
ON "table".id = other_table.id
""".strip(),
),
(
"SELECT * FROM (SELECT * FROM some_table)",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
*
FROM some_table
WHERE
id = 42
) AS some_table
)
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM table
WHERE
id = 42
) AS table
UNION ALL
SELECT
*
FROM other_table
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
id = 42
) AS other_table
""".strip(),
),
(
"SELECT a.*, b.* FROM tbl_a AS a INNER JOIN tbl_b AS b ON a.col = b.col",
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
"""
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
),
(
"SELECT a.*, b.* FROM tbl_a a INNER JOIN tbl_b b ON a.col = b.col",
{Table("tbl_a", "schema1", "catalog1"): "id = 42"},
"""
SELECT
a.*,
b.*
FROM (
SELECT
*
FROM tbl_a
WHERE
id = 42
) AS a
INNER JOIN tbl_b AS b
ON a.col = b.col
""".strip(),
),
],
)
def test_rls_subquery_transformer(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test `RLSAsSubqueryTransformer`.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42
""".strip(),
),
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM schema2.some_table AS t
""".strip(),
),
(
"SELECT t.foo FROM catalog2.schema1.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM catalog2.schema1.some_table AS t
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz'
)
""".strip(),
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
t.foo
FROM some_table AS t
WHERE
t.id = 42 AND (
bar = 'baz' OR foo = 'qux'
)
""".strip(),
),
(
"SELECT * FROM some_table WHERE 1=1",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
1 = 1
)
""".strip(),
),
(
"SELECT * FROM some_table WHERE TRUE OR FALSE",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42 AND (
TRUE OR FALSE
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1
)
""".strip(),
),
(
'SELECT * FROM "table" WHERE 1=1',
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM "table"
WHERE
"table".id = 42 AND (
1 = 1
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM other_table WHERE 1=1",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM other_table
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
""".strip(),
),
(
"SELECT * FROM some_table",
{Table("some_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM some_table
WHERE
some_table.id = 42
""".strip(),
),
(
"SELECT * FROM table ORDER BY id",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
ORDER BY
id
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1 AND table.id=42",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND table.id = 42
)
""".strip(),
),
(
"""
SELECT * FROM table
JOIN other_table
ON table.id = other_table.id
AND other_table.id=42
""",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id AND other_table.id = 42
)
""".strip(),
),
(
"SELECT * FROM table WHERE 1=1 AND id=42",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42 AND (
1 = 1 AND id = 42
)
""".strip(),
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
""".strip(),
),
(
"""
SELECT *
FROM table
JOIN other_table
ON table.id = other_table.id
WHERE 1=1
""",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
JOIN other_table
ON other_table.id = 42 AND (
table.id = other_table.id
)
WHERE
1 = 1
""".strip(),
),
(
"SELECT * FROM (SELECT * FROM other_table)",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM (
SELECT
*
FROM other_table
WHERE
other_table.id = 42
)
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
WHERE
table.id = 42
UNION ALL
SELECT
*
FROM other_table
""".strip(),
),
(
"SELECT * FROM table UNION ALL SELECT * FROM other_table",
{Table("other_table", "schema1", "catalog1"): "id = 42"},
"""
SELECT
*
FROM table
UNION ALL
SELECT
*
FROM other_table
WHERE
other_table.id = 42
""".strip(),
),
],
)
def test_rls_predicate_transformer(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test `RLSPredicateTransformer`.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected