feat: new splice RLSMethod

This commit is contained in:
Beto Dealmeida
2026-05-08 13:33:42 -04:00
parent 5bde86785f
commit 3fe2b2505f
6 changed files with 615 additions and 14 deletions

View File

@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
AS_PREDICATE_SPLICE = enum.auto()
class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
source: str | None = None,
):
if ast:
self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
# Original SQL substring for this statement, when known. Used by the
# splice-mode RLS path which rewrites this string instead of regenerating
# SQL from the AST. ``None`` means the statement was constructed from an
# AST without an associated source string (splice mode falls back).
self._source_sql: str | None = source if source is not None else statement
# Verbatim SQL to return from ``format()``. Set by string-rewriting
# operations (e.g. splice-mode RLS) that produce a final SQL string and
# need to bypass the dialect generator. Cleared by AST-mutating methods
# since those invalidate this cached text.
self._raw_sql: str | None = None
@classmethod
def split_script(
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
source: str | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine, ast)
super().__init__(statement, engine, ast, source)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
asts = [ast for ast in cls._parse(script, engine) if ast]
sources = cls._split_source(script, engine, len(asts))
return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
cls(ast=ast, engine=engine, source=source)
for ast, source in zip(asts, sources, strict=False)
]
@classmethod
def _split_source(
cls,
script: str,
engine: str,
expected_count: int,
) -> list[str | None]:
"""
Slice ``script`` into per-statement substrings using top-level semicolon
positions from the tokenizer. Returns a list of length ``expected_count``;
any entry is ``None`` if the slicing didn't yield a usable substring.
The returned substrings preserve the original byte content of the script
for each statement — necessary for splice-mode RLS, which rewrites the
original SQL rather than regenerating from the AST.
"""
none_result: list[str | None] = [None] * expected_count
dialect = SQLGLOT_DIALECTS.get(engine)
try:
tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
except sqlglot.errors.SqlglotError:
return none_result
# Top-level semicolon offsets (depth 0).
boundaries: list[int] = []
depth = 0
for tok in tokens:
if tok.token_type == sqlglot.tokens.TokenType.L_PAREN:
depth += 1
elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN:
depth -= 1
elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and depth == 0:
boundaries.append(tok.start)
starts = [0, *(b + 1 for b in boundaries)]
ends = [*boundaries, len(script)]
sources = [script[s:e].strip() for s, e in zip(starts, ends, strict=False)]
sources = [s for s in sources if s]
if len(sources) != expected_count:
return none_result
return list(sources)
@classmethod
def _parse_statement(
cls,
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
verbatim result in ``_raw_sql``, return it as-is — the whole point of
those operations is to avoid the dialect generator round-trip.
"""
if self._raw_sql is not None:
return self._raw_sql
return Dialect.get_or_raise(self._dialect).generate(
self._parsed,
copy=True,
@@ -808,6 +872,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
"""
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
self._raw_sql = None
if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +968,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -910,11 +976,18 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param predicates: Mapping of fully qualified ``Table`` to predicates.
For ``AS_PREDICATE`` and ``AS_SUBQUERY`` the predicates are sqlglot
expressions. For ``AS_PREDICATE_SPLICE`` they are raw SQL strings.
:param method: The method to use for applying the rules.
"""
if not predicates:
return
if method == RLSMethod.AS_PREDICATE_SPLICE:
self._apply_rls_splice(catalog, schema, predicates)
return
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -925,6 +998,44 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
def _apply_rls_splice(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
) -> None:
"""
Apply RLS via text splicing on the original SQL.
Requires the source SQL substring to be available. Raises ``ValueError``
if it isn't — the caller must ensure the statement was constructed from
a source string (the standard ``SQLScript`` path does this).
"""
from superset.sql.rls_splice import apply_rls_splice
if self._source_sql is None:
raise ValueError(
"Splice-mode RLS requires the source SQL string; "
"this SQLStatement was constructed without one."
)
# Splice operates on raw predicate strings; coerce expressions if needed.
string_predicates: dict[Table, list[str]] = {
table: [
pred if isinstance(pred, str) else pred.sql(dialect=self._dialect)
for pred in preds
]
for table, preds in predicates.items()
}
spliced = apply_rls_splice(
self._source_sql,
catalog,
schema,
string_predicates,
dialect=self._dialect,
)
self._raw_sql = spliced
class KQLSplitState(enum.Enum):
"""