feat: implement RLS in sqlglot (#33524)

This commit is contained in:
Beto Dealmeida
2025-05-28 09:10:45 -04:00
committed by GitHub
parent e205846845
commit 0abe6eed89
2 changed files with 826 additions and 20 deletions

View File

@@ -55,7 +55,7 @@ SQLGLOT_DIALECTS = {
# "db2": ???
# "dremio": ???
"drill": Dialects.DRILL,
# "druid": ???
"druid": Dialects.DRUID,
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
@@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class RLSMethod(enum.Enum):
"""
Methods for enforcing RLS.
"""
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
class RLSTransformer:
"""
AST transformer to apply RLS rules.
"""
def __init__(
self,
catalog: str | None,
schema: str | None,
rules: dict[Table, list[exp.Expression]],
) -> None:
self.catalog = catalog
self.schema = schema
self.rules = rules
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
"""
Get the combined RLS predicate for a table.
"""
table = Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
)
if predicates := self.rules.get(table):
return (
exp.And(
this=predicates[0],
expressions=predicates[1:],
)
if len(predicates) > 1
else predicates[0]
)
return None
class RLSAsPredicateTransformer(RLSTransformer):
"""
Apply Row Level Security role as a predicate.
This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:
table: some_table
clause: id = 42
If a user subject to the rule runs the following query:
SELECT foo FROM some_table WHERE bar = 'baz'
The query will be modified to:
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
This approach is probably less secure than using subqueries, so it's only used for
databases without support for subqueries.
"""
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node
predicate = self.get_predicate(node)
if not predicate:
return node
# qualify columns with table name
for column in predicate.find_all(exp.Column):
column.set("table", node.alias or node.this)
if isinstance(node.parent, exp.From):
select = node.parent.parent
if where := select.args.get("where"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=where.this),
)
select.set("where", exp.Where(this=predicate))
elif isinstance(node.parent, exp.Join):
join = node.parent
if on := join.args.get("on"):
predicate = exp.And(
this=predicate,
expression=exp.Paren(this=on),
)
join.set("on", predicate)
return node
class RLSAsSubqueryTransformer(RLSTransformer):
"""
Apply Row Level Security role as a subquery.
This transformer will apply any RLS predicates to the relevant tables. For example,
given the RLS rule:
table: some_table
clause: id = 42
If a user subject to the rule runs the following query:
SELECT foo FROM some_table WHERE bar = 'baz'
The query will be modified to:
SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
WHERE bar = 'baz'
This approach is probably more secure than using predicates, but it doesn't work for
all databases.
"""
def __call__(self, node: exp.Expression) -> exp.Expression:
if not isinstance(node, exp.Table):
return node
if predicate := self.get_predicate(node):
# use alias or name
alias = node.alias or node.sql()
node.set("alias", None)
node = exp.Subquery(
this=exp.Select(
expressions=[exp.Star()],
where=exp.Where(this=predicate),
**{"from": exp.From(this=node.copy())},
),
alias=alias,
)
return node
@dataclass(eq=True, frozen=True)
class Table:
"""
@@ -173,7 +317,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
elif statement:
self._parsed = self._parse_statement(statement, engine)
else:
raise SupersetParseError("Either statement or ast must be provided")
raise ValueError("Either statement or ast must be provided")
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
@@ -293,6 +437,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
"""
raise NotImplementedError()
def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
raise NotImplementedError()
def __str__(self) -> str:
return self.format()
@@ -573,6 +733,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
engine=self.engine,
)
def apply_rls(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
method: RLSMethod,
) -> None:
"""
Apply relevant RLS rules to the statement inplace.
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param method: The method to use for applying the rules.
"""
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
}
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
class KQLSplitState(enum.Enum):
"""
@@ -966,7 +1150,7 @@ def extract_tables_from_statement(
"""
Extract all table references in a single statement.
Please not that this is not trivial; consider the following queries:
Please note that this is not trivial; consider the following queries:
DESCRIBE some_table;
SHOW PARTITIONS FROM some_table;