Files
superset2/superset/sql/rls_splice.py
Beto Dealmeida d1cd84931e Increase coverage
2026-05-08 16:40:29 -04:00

483 lines
15 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS predicate injection via text splicing.
Instead of round-tripping through sqlglot's generator (which transpiles
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
in the original SQL string.
3. For each ``SELECT`` scope that references a table with an RLS predicate,
finds the exact byte offset to inject at — either the end of an existing
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
/ ``LIMIT`` / the closing paren of a subquery.
4. Splices the predicate text directly into the original string at that
offset — never calling ``.sql()``, so the generator never runs.
Result: everything outside the splice points is the original SQL, byte for
byte. Dialect-specific functions, comments, and formatting are all preserved
exactly.
Known limitations:
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
A thin dialect subclass is still required for parsing — but only for
parsing, not generation.
- Predicate strings are spliced in as raw SQL. They must come from a trusted
source (the RLS config), not user input.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.tokens import Token, TokenType
if TYPE_CHECKING:
from superset.sql.parse import Table
# Token types that end a WHERE clause / FROM section at the current paren depth,
# indicating where a new predicate must be inserted just before.
_CLAUSE_ENDS = {
TokenType.GROUP_BY,
TokenType.HAVING,
TokenType.ORDER_BY,
TokenType.WINDOW,
TokenType.QUALIFY,
TokenType.LIMIT,
TokenType.FETCH,
TokenType.CLUSTER_BY,
TokenType.DISTRIBUTE_BY,
TokenType.SORT_BY,
TokenType.CONNECT_BY,
TokenType.START_WITH,
TokenType.UNION,
TokenType.INTERSECT,
TokenType.EXCEPT,
}
_JOIN_STARTS = {
TokenType.JOIN,
TokenType.STRAIGHT_JOIN,
TokenType.JOIN_MARKER,
}
def _splice_priority(text: str) -> int:
"""
Priority for applying splices at the same offset.
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
wrapping splices like ``pred AND (existing)`` compose correctly.
"""
return 1 if text != ")" else 0
def _before_whitespace(sql: str, offset: int) -> int:
"""Back up past any whitespace immediately before *offset*."""
while offset > 0 and sql[offset - 1] in (" ", "\t", "\n", "\r"):
offset -= 1
return offset
def _before_trivia(sql: str, offset: int) -> int:
"""
Back up past whitespace and adjacent comments immediately before *offset*.
This ensures insertion points land before inline/block comments that appear
between `FROM`/`WHERE` and the next clause keyword.
"""
while True:
offset = _before_whitespace(sql, offset)
# Inline comment ending at offset, eg: "... -- comment\nGROUP BY".
line_start = sql.rfind("\n", 0, offset) + 1
inline_comment_start = sql.rfind("--", line_start, offset)
if inline_comment_start != -1:
offset = inline_comment_start
continue
# Block comment ending at offset, eg: "... /* comment */GROUP BY".
if offset >= 2 and sql[offset - 2 : offset] == "*/":
block_comment_start = sql.rfind("/*", 0, offset - 2)
if block_comment_start != -1:
offset = block_comment_start
continue
return offset
def _table_from_node(
node: exp.Table,
catalog: str | None,
schema: str | None,
) -> Table:
"""
Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node, defaulting
unqualified parts to the supplied catalog/schema.
"""
# Imported lazily to avoid a circular import with ``superset.sql.parse``.
from superset.sql.parse import Table
return Table(
table=node.name,
schema=node.db if node.db else schema,
catalog=node.catalog if node.catalog else catalog,
)
def apply_rls_splice(
sql: str,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[str]],
dialect: str | None = None,
) -> str:
"""
Inject RLS predicates into ``sql`` by splicing text at the right positions.
:param sql: The original SQL query. Returned unchanged except at splice points.
:param catalog: The default catalog for non-qualified table names.
:param schema: The default schema for non-qualified table names.
:param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
maps a fully qualified table to one or more raw predicate strings to
``AND`` together when that table is referenced in a SELECT scope.
:param dialect: The sqlglot dialect used for *parsing only* — to understand
scope structure and locate token positions. The generator is never
called, so this does not affect output formatting.
:return: The query with RLS predicates injected into every relevant SELECT
scope.
"""
if not predicates or not any(predicates.values()):
return sql
resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
tokens = list(resolved_dialect.tokenize(sql))
tree = sqlglot.parse_one(sql, dialect=dialect)
splices: list[tuple[int, str]] = []
for scope in traverse_scope(tree):
splices.extend(
_splices_for_scope(
sql,
tokens,
scope,
predicates,
catalog,
schema,
dialect,
)
)
# Apply splices in reverse offset order so earlier positions stay valid.
# For equal offsets, apply predicate/WHERE/ON inserts before ")" inserts.
splices.sort(key=lambda item: (item[0], _splice_priority(item[1])), reverse=True)
result = sql
for offset, text in splices:
result = result[:offset] + text + result[offset:]
return result
def _splices_for_scope(
sql: str,
tokens: list[Token],
scope: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> list[tuple[int, str]]:
"""
Compute all splices for a single SELECT scope.
This mirrors ``RLSAsPredicateTransformer`` semantics:
- predicates for FROM tables are applied to the SELECT WHERE clause as
``pred AND (existing_where)``
- predicates for JOIN tables are applied to each JOIN ON clause as
``pred AND (existing_on)`` (or ``ON pred`` when ON is absent)
"""
from_predicates: list[str] = []
from_table_ends: list[int] = []
join_splices: list[tuple[int, str]] = []
for source in scope.sources.values(): # type: ignore[attr-defined]
source_type, table_end, pred_sql = _classify_source_predicate(
source,
predicates,
catalog,
schema,
dialect,
)
if source_type == "none" or table_end is None or pred_sql is None:
continue
if source_type == "from":
from_predicates.append(pred_sql)
from_table_ends.append(table_end)
continue
join_splice = _find_join_splice(sql, tokens, table_end, pred_sql)
if join_splice:
join_splices.extend(join_splice)
continue
if not from_predicates:
return join_splices
combined_predicates = " AND ".join(dict.fromkeys(from_predicates))
from_splice = _find_where_splice(
sql,
tokens,
max(from_table_ends),
combined_predicates,
)
return [*join_splices, *from_splice]
def _table_end(source: exp.Table) -> int | None:
ident = source.find(exp.Identifier)
if ident and getattr(ident, "_meta", None):
return ident._meta["end"]
return None
def _classify_source_predicate(
source: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> tuple[str, int | None, str | None]:
"""
Return source kind (from/join/none), table end offset, and predicate SQL.
"""
if not isinstance(source, exp.Table):
return ("none", None, None)
table = _table_from_node(source, catalog, schema)
table_predicates = [
_qualify_predicate(predicate, source, dialect)
for predicate in predicates.get(table, [])
if predicate
]
if not table_predicates:
return ("none", None, None)
table_end = _table_end(source)
if table_end is None:
return ("none", None, None)
pred_sql = " AND ".join(dict.fromkeys(table_predicates))
if isinstance(source.parent, exp.From):
return ("from", table_end, pred_sql)
if isinstance(source.parent, exp.Join):
return ("join", table_end, pred_sql)
return ("none", None, None)
def _qualify_predicate(
predicate: str,
table_node: exp.Table,
dialect: str | None,
) -> str:
"""
Qualify predicate columns with the table alias/name, mirroring
``RLSAsPredicateTransformer``.
"""
parsed = sqlglot.parse_one(predicate, dialect=dialect)
table = table_node.alias_or_name
table_expr = exp.to_identifier(table)
for column in parsed.find_all(exp.Column):
column.set("table", table_expr.copy())
return parsed.sql(dialect=dialect)
def _scan_until_scope_boundary(
tokens: list[Token],
anchor: int,
*,
stop_at_join: bool,
) -> tuple[str, int | None]:
"""
Scan tokens forward from ``anchor`` until a clause/scope boundary.
Returns ``("where", index)`` when a WHERE token is found at depth 0,
``("boundary", index)`` for a non-WHERE boundary token, and
``("eof", None)`` when no boundary token is found.
"""
depth = 0
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
return ("boundary", i)
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.WHERE:
return ("where", i)
if tok.token_type in _CLAUSE_ENDS or (
stop_at_join and tok.token_type in _JOIN_STARTS
):
return ("boundary", i)
return ("eof", None)
def _find_condition_end(
sql: str,
tokens: list[Token],
start_index: int,
*,
stop_at_join: bool,
) -> int:
"""
Find the end offset for a WHERE/ON condition body.
"""
depth = 0
prev_end = tokens[start_index].end
for tok in tokens[start_index + 1 :]:
if tok.token_type == TokenType.L_PAREN:
depth += 1
elif tok.token_type == TokenType.R_PAREN:
if depth == 0:
return _before_trivia(sql, tok.start)
depth -= 1
elif depth == 0 and (
(stop_at_join and tok.token_type == TokenType.WHERE)
or tok.token_type in _CLAUSE_ENDS
or (stop_at_join and tok.token_type in _JOIN_STARTS)
):
return _before_trivia(sql, tok.start)
prev_end = tok.end
return prev_end + 1
def _find_where_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to the SELECT WHERE clause:
``pred`` when absent, ``pred AND (existing)`` when present.
"""
kind, idx = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)
if kind == "where" and idx is not None:
if idx + 1 >= len(tokens):
return [(tokens[idx].end + 1, f" {pred_sql}")]
body_start = tokens[idx + 1].start
body_end = _find_condition_end(sql, tokens, idx, stop_at_join=False)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if kind == "boundary" and idx is not None:
return [(_before_trivia(sql, tokens[idx].start), f" WHERE {pred_sql}")]
return [(len(sql), f" WHERE {pred_sql}")]
def _find_join_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to a JOIN clause:
``ON pred`` when ON absent, ``ON pred AND (existing_on)`` when present.
"""
on_index, boundary_index = _scan_join_clause(tokens, anchor)
if on_index is not None:
if on_index + 1 >= len(tokens):
return [(tokens[on_index].end + 1, f" {pred_sql}")]
body_start = tokens[on_index + 1].start
body_end = _find_condition_end(sql, tokens, on_index, stop_at_join=True)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if boundary_index is not None:
return [(_before_trivia(sql, tokens[boundary_index].start), f" ON {pred_sql}")]
return [(len(sql), f" ON {pred_sql}")]
def _scan_join_clause(
tokens: list[Token],
anchor: int,
) -> tuple[int | None, int | None]:
"""
Find ON and boundary token indexes for a JOIN segment.
"""
depth = 0
on_index: int | None = None
boundary_index: int | None = None
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
boundary_index = i
break
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.ON and on_index is None:
on_index = i
continue
if tok.token_type == TokenType.WHERE:
boundary_index = i
break
if tok.token_type in _JOIN_STARTS or tok.token_type in _CLAUSE_ENDS:
boundary_index = i
break
return on_index, boundary_index