mirror of
https://github.com/apache/superset.git
synced 2026-05-29 20:29:34 +00:00
feat: new splice RLSMethod
This commit is contained in:
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
|
||||
|
||||
AS_PREDICATE = enum.auto()
|
||||
AS_SUBQUERY = enum.auto()
|
||||
AS_PREDICATE_SPLICE = enum.auto()
|
||||
|
||||
|
||||
class RLSTransformer:
|
||||
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
statement: str | None = None,
|
||||
engine: str = "base",
|
||||
ast: InternalRepresentation | None = None,
|
||||
source: str | None = None,
|
||||
):
|
||||
if ast:
|
||||
self._parsed = ast
|
||||
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
# Original SQL substring for this statement, when known. Used by the
|
||||
# splice-mode RLS path which rewrites this string instead of regenerating
|
||||
# SQL from the AST. ``None`` means the statement was constructed from an
|
||||
# AST without an associated source string (splice mode raises ValueError).
|
||||
self._source_sql: str | None = source if source is not None else statement
|
||||
# Verbatim SQL to return from ``format()``. Set by string-rewriting
|
||||
# operations (e.g. splice-mode RLS) that produce a final SQL string and
|
||||
# need to bypass the dialect generator. Cleared by AST-mutating methods
|
||||
# since those invalidate this cached text.
|
||||
self._raw_sql: str | None = None
|
||||
|
||||
@classmethod
|
||||
def split_script(
|
||||
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
statement: str | None = None,
|
||||
engine: str = "base",
|
||||
ast: exp.Expression | None = None,
|
||||
source: str | None = None,
|
||||
):
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
super().__init__(statement, engine, ast)
|
||||
super().__init__(statement, engine, ast, source)
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
|
||||
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
script: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
asts = [ast for ast in cls._parse(script, engine) if ast]
|
||||
sources = cls._split_source(script, engine, len(asts))
|
||||
return [
|
||||
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
|
||||
cls(ast=ast, engine=engine, source=source)
|
||||
for ast, source in zip(asts, sources, strict=False)
|
||||
]
|
||||
|
||||
@classmethod
def _split_source(
    cls,
    script: str,
    engine: str,
    expected_count: int,
) -> list[str | None]:
    """
    Cut ``script`` into one source substring per statement.

    The script is tokenized with the dialect mapped to ``engine`` and split
    at every semicolon token that is not nested inside parentheses. Each
    slice is whitespace-stripped and empty slices are discarded; apart from
    that trimming the text is taken verbatim from ``script``, which is what
    splice-mode RLS needs since it rewrites the original SQL instead of
    regenerating it from the AST.

    :param script: The full SQL script.
    :param engine: Engine name used to look up the sqlglot dialect.
    :param expected_count: Number of statements the caller parsed.
    :returns: A list of length ``expected_count``. It holds the substrings
        when exactly ``expected_count`` non-empty slices were found;
        otherwise (including when tokenizing fails) every entry is ``None``.
    """
    fallback: list[str | None] = [None] * expected_count
    dialect = SQLGLOT_DIALECTS.get(engine)
    try:
        tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
    except sqlglot.errors.SqlglotError:
        return fallback

    token_types = sqlglot.tokens.TokenType

    # Offsets of the semicolons seen while the parenthesis depth is zero.
    split_points: list[int] = []
    nesting = 0
    for token in tokens:
        kind = token.token_type
        if kind == token_types.L_PAREN:
            nesting += 1
        elif kind == token_types.R_PAREN:
            nesting -= 1
        elif kind == token_types.SEMICOLON and nesting == 0:
            split_points.append(token.start)

    # Walk the split points once; each slice runs from just after the
    # previous semicolon up to (not including) the next one, and the last
    # slice runs to the end of the script.
    pieces: list[str | None] = []
    cursor = 0
    for stop in (*split_points, len(script)):
        text = script[cursor:stop].strip()
        if text:
            pieces.append(text)
        cursor = stop + 1

    # A count mismatch means the slices cannot be paired with the parsed
    # statements, so report "unknown" for all of them.
    if len(pieces) != expected_count:
        return fallback
    return pieces
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
cls,
|
||||
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
|
||||
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
|
||||
verbatim result in ``_raw_sql``, return it as-is — the whole point of
|
||||
those operations is to avoid the dialect generator round-trip.
|
||||
"""
|
||||
if self._raw_sql is not None:
|
||||
return self._raw_sql
|
||||
return Dialect.get_or_raise(self._dialect).generate(
|
||||
self._parsed,
|
||||
copy=True,
|
||||
@@ -808,6 +872,8 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
|
||||
"""
|
||||
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
|
||||
self._raw_sql = None
|
||||
if method == LimitMethod.FORCE_LIMIT:
|
||||
self._parsed.args["limit"] = exp.Limit(
|
||||
expression=exp.Literal(this=str(limit), is_string=False)
|
||||
@@ -902,7 +968,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
predicates: dict[Table, list[exp.Expression]],
|
||||
predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
|
||||
method: RLSMethod,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -910,11 +976,18 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:param predicates: Mapping of fully qualified ``Table`` to predicates.
|
||||
For ``AS_PREDICATE`` and ``AS_SUBQUERY`` the predicates are sqlglot
|
||||
expressions. For ``AS_PREDICATE_SPLICE`` they are raw SQL strings.
|
||||
:param method: The method to use for applying the rules.
|
||||
"""
|
||||
if not predicates:
|
||||
return
|
||||
|
||||
if method == RLSMethod.AS_PREDICATE_SPLICE:
|
||||
self._apply_rls_splice(catalog, schema, predicates)
|
||||
return
|
||||
|
||||
transformers = {
|
||||
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
|
||||
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
|
||||
@@ -925,6 +998,44 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
def _apply_rls_splice(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[exp.Expression]] | dict[Table, list[str]],
) -> None:
    """
    Apply RLS by splicing predicate text into the statement's source SQL.

    The spliced SQL is stored in ``_raw_sql``, which ``format()`` returns
    verbatim when set.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: Mapping of ``Table`` to predicates; entries may be
        raw SQL strings or sqlglot expressions (the latter are rendered to
        SQL text with this statement's dialect before splicing).
    :raises ValueError: If no source SQL string is known for this statement.
    """
    from superset.sql.rls_splice import apply_rls_splice

    original_sql = self._source_sql
    if original_sql is None:
        raise ValueError(
            "Splice-mode RLS requires the source SQL string; "
            "this SQLStatement was constructed without one."
        )

    def as_sql_text(predicate: exp.Expression | str) -> str:
        # Strings pass through untouched; expressions are rendered.
        if isinstance(predicate, str):
            return predicate
        return predicate.sql(dialect=self._dialect)

    string_predicates: dict[Table, list[str]] = {
        table: [as_sql_text(predicate) for predicate in table_predicates]
        for table, table_predicates in predicates.items()
    }

    # NOTE(review): only ``_raw_sql`` is updated here; ``_parsed`` is left
    # as-is. Anything that later resets ``_raw_sql`` to ``None`` makes
    # ``format()`` generate SQL from ``_parsed`` again, i.e. without the
    # spliced predicates -- confirm callers apply RLS last.
    self._raw_sql = apply_rls_splice(
        original_sql,
        catalog,
        schema,
        string_predicates,
        dialect=self._dialect,
    )
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user