mirror of
https://github.com/apache/superset.git
synced 2026-04-24 18:44:53 +00:00
feat: implement RLS in sqlglot (#33524)
This commit is contained in:
@@ -55,7 +55,7 @@ SQLGLOT_DIALECTS = {
|
||||
# "db2": ???
|
||||
# "dremio": ???
|
||||
"drill": Dialects.DRILL,
|
||||
# "druid": ???
|
||||
"druid": Dialects.DRUID,
|
||||
"duckdb": Dialects.DUCKDB,
|
||||
# "dynamodb": ???
|
||||
# "elasticsearch": ???
|
||||
@@ -108,6 +108,150 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
class RLSMethod(enum.Enum):
|
||||
"""
|
||||
Methods for enforcing RLS.
|
||||
"""
|
||||
|
||||
AS_PREDICATE = enum.auto()
|
||||
AS_SUBQUERY = enum.auto()
|
||||
|
||||
|
||||
class RLSTransformer:
|
||||
"""
|
||||
AST transformer to apply RLS rules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
rules: dict[Table, list[exp.Expression]],
|
||||
) -> None:
|
||||
self.catalog = catalog
|
||||
self.schema = schema
|
||||
self.rules = rules
|
||||
|
||||
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
|
||||
"""
|
||||
Get the combined RLS predicate for a table.
|
||||
"""
|
||||
table = Table(
|
||||
table_node.name,
|
||||
table_node.db if table_node.db else self.schema,
|
||||
table_node.catalog if table_node.catalog else self.catalog,
|
||||
)
|
||||
if predicates := self.rules.get(table):
|
||||
return (
|
||||
exp.And(
|
||||
this=predicates[0],
|
||||
expressions=predicates[1:],
|
||||
)
|
||||
if len(predicates) > 1
|
||||
else predicates[0]
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class RLSAsPredicateTransformer(RLSTransformer):
|
||||
"""
|
||||
Apply Row Level Security role as a predicate.
|
||||
|
||||
This transformer will apply any RLS predicates to the relevant tables. For example,
|
||||
given the RLS rule:
|
||||
|
||||
table: some_table
|
||||
clause: id = 42
|
||||
|
||||
If a user subject to the rule runs the following query:
|
||||
|
||||
SELECT foo FROM some_table WHERE bar = 'baz'
|
||||
|
||||
The query will be modified to:
|
||||
|
||||
SELECT foo FROM some_table WHERE bar = 'baz' AND id = 42
|
||||
|
||||
This approach is probably less secure than using subqueries, so it's only used for
|
||||
databases without support for subqueries.
|
||||
"""
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
if not isinstance(node, exp.Table):
|
||||
return node
|
||||
|
||||
predicate = self.get_predicate(node)
|
||||
if not predicate:
|
||||
return node
|
||||
|
||||
# qualify columns with table name
|
||||
for column in predicate.find_all(exp.Column):
|
||||
column.set("table", node.alias or node.this)
|
||||
|
||||
if isinstance(node.parent, exp.From):
|
||||
select = node.parent.parent
|
||||
if where := select.args.get("where"):
|
||||
predicate = exp.And(
|
||||
this=predicate,
|
||||
expression=exp.Paren(this=where.this),
|
||||
)
|
||||
select.set("where", exp.Where(this=predicate))
|
||||
|
||||
elif isinstance(node.parent, exp.Join):
|
||||
join = node.parent
|
||||
if on := join.args.get("on"):
|
||||
predicate = exp.And(
|
||||
this=predicate,
|
||||
expression=exp.Paren(this=on),
|
||||
)
|
||||
join.set("on", predicate)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
class RLSAsSubqueryTransformer(RLSTransformer):
|
||||
"""
|
||||
Apply Row Level Security role as a subquery.
|
||||
|
||||
This transformer will apply any RLS predicates to the relevant tables. For example,
|
||||
given the RLS rule:
|
||||
|
||||
table: some_table
|
||||
clause: id = 42
|
||||
|
||||
If a user subject to the rule runs the following query:
|
||||
|
||||
SELECT foo FROM some_table WHERE bar = 'baz'
|
||||
|
||||
The query will be modified to:
|
||||
|
||||
SELECT foo FROM (SELECT * FROM some_table WHERE id = 42) AS some_table
|
||||
WHERE bar = 'baz'
|
||||
|
||||
This approach is probably more secure than using predicates, but it doesn't work for
|
||||
all databases.
|
||||
"""
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
if not isinstance(node, exp.Table):
|
||||
return node
|
||||
|
||||
if predicate := self.get_predicate(node):
|
||||
# use alias or name
|
||||
alias = node.alias or node.sql()
|
||||
node.set("alias", None)
|
||||
node = exp.Subquery(
|
||||
this=exp.Select(
|
||||
expressions=[exp.Star()],
|
||||
where=exp.Where(this=predicate),
|
||||
**{"from": exp.From(this=node.copy())},
|
||||
),
|
||||
alias=alias,
|
||||
)
|
||||
|
||||
return node
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
"""
|
||||
@@ -173,7 +317,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
elif statement:
|
||||
self._parsed = self._parse_statement(statement, engine)
|
||||
else:
|
||||
raise SupersetParseError("Either statement or ast must be provided")
|
||||
raise ValueError("Either statement or ast must be provided")
|
||||
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
@@ -293,6 +437,22 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
predicates: dict[Table, list[InternalRepresentation]],
|
||||
method: RLSMethod,
|
||||
) -> None:
|
||||
"""
|
||||
Apply relevant RLS rules to the statement inplace.
|
||||
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:param method: The method to use for applying the rules.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.format()
|
||||
|
||||
@@ -573,6 +733,30 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
engine=self.engine,
|
||||
)
|
||||
|
||||
def apply_rls(
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
predicates: dict[Table, list[exp.Expression]],
|
||||
method: RLSMethod,
|
||||
) -> None:
|
||||
"""
|
||||
Apply relevant RLS rules to the statement inplace.
|
||||
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:param method: The method to use for applying the rules.
|
||||
"""
|
||||
transformers = {
|
||||
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
|
||||
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
|
||||
}
|
||||
if method not in transformers:
|
||||
raise ValueError(f"Invalid RLS method: {method}")
|
||||
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
@@ -966,7 +1150,7 @@ def extract_tables_from_statement(
|
||||
"""
|
||||
Extract all table references in a single statement.
|
||||
|
||||
Please not that this is not trivial; consider the following queries:
|
||||
Please note that this is not trivial; consider the following queries:
|
||||
|
||||
DESCRIBE some_table;
|
||||
SHOW PARTITIONS FROM some_table;
|
||||
|
||||
Reference in New Issue
Block a user