Compare commits

...

11 Commits

Author SHA1 Message Date
Beto Dealmeida
0ce43abe5b Address comments 2026-05-12 10:17:50 -04:00
Beto Dealmeida
439130db54 Add dialects for Exa/Solr 2026-05-12 09:59:55 -04:00
Beto Dealmeida
ec9f2da81e Add to DB without sqlglot dialect 2026-05-12 09:41:30 -04:00
Beto Dealmeida
549e51bdc4 Fix lint 2026-05-12 09:25:57 -04:00
Beto Dealmeida
72ecb20e5c Run ruff-format 2026-05-08 17:27:11 -04:00
Beto Dealmeida
d1cd84931e Increase coverage 2026-05-08 16:40:29 -04:00
Beto Dealmeida
884649e3ed Increase parity 2026-05-08 15:32:10 -04:00
Beto Dealmeida
e9dd9a6107 Simplify function signature 2026-05-08 14:28:21 -04:00
Beto Dealmeida
0fe8293c7f Simplify 2026-05-08 14:12:46 -04:00
Beto Dealmeida
af2d3babec Improvements 2026-05-08 13:49:37 -04:00
Beto Dealmeida
3fe2b2505f feat: new splice RLSMethod 2026-05-08 13:33:42 -04:00
22 changed files with 1258 additions and 48 deletions

View File

@@ -557,6 +557,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# if True, database will be listed as option in the upload file form
supports_file_upload = True
# RLS strategy for this engine spec. Override in engine-specific classes as
# needed (for example ``RLSMethod.AS_PREDICATE`` for engines that don't
# support subquery-based RLS, or ``RLSMethod.AS_PREDICATE_SPLICE`` for
# engines where sqlglot generation is not faithful).
rls_method = RLSMethod.AS_SUBQUERY
# Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501
# a custom `adjust_engine_params` method.
supports_dynamic_schema = False
@@ -618,21 +624,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
else cls.encrypted_extra_sensitive_fields
)
@classmethod
def get_rls_method(cls) -> RLSMethod:
"""
Returns the RLS method to be used for this engine.
There are two ways to insert RLS: either replacing the table with a subquery
that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
safer, but not supported in all databases.
"""
return (
RLSMethod.AS_SUBQUERY
if cls.allows_subqueries and cls.allows_alias_in_select
else RLSMethod.AS_PREDICATE
)
@classmethod
def is_oauth2_enabled(cls) -> bool:
return (

View File

@@ -35,6 +35,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.network import is_hostname_valid, is_port_open
@@ -82,6 +83,7 @@ class CouchbaseEngineSpec(BasicParametersMixin, BaseEngineSpec):
default_driver = "couchbase"
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
sqlalchemy_uri_placeholder = (
"couchbase://user:password@host[:port]?truststorepath=value?ssl=value"
)

View File

@@ -23,6 +23,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -68,6 +69,8 @@ class CrateEngineSpec(BaseEngineSpec):
TimeGrain.YEAR: "DATE_TRUNC('year', {col})",
}
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "{col} * 1000"

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import (
)
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.core import GenericDataType
from superset.utils.hashing import hash_from_str
from superset.utils.network import is_hostname_valid, is_port_open
@@ -55,6 +56,8 @@ class DatabendBaseEngineSpec(BaseEngineSpec):
time_secondary_columns = True
time_groupby_inline = True
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",

View File

@@ -26,6 +26,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory,
)
from superset.errors import SupersetErrorType
from superset.sql.parse import RLSMethod
# Internal class for defining error message patterns (for translation)
@@ -58,6 +59,8 @@ class DenodoEngineSpec(BaseEngineSpec, BasicParametersMixin):
engine = "denodo"
engine_name = "Denodo"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
default_driver = "psycopg2"
sqlalchemy_uri_placeholder = (
"denodo://user:password@host:port/dbname[?key=value&key=value...]"

View File

@@ -21,12 +21,15 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class DynamoDBEngineSpec(BaseEngineSpec):
engine = "dynamodb"
engine_name = "Amazon DynamoDB"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (
"Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL."

View File

@@ -28,6 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import RLSMethod
logger = logging.getLogger()
@@ -39,6 +40,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
allows_joins = False
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,7 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
class FirebirdEngineSpec(BaseEngineSpec):
@@ -53,6 +53,8 @@ class FirebirdEngineSpec(BaseEngineSpec):
# Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table`
limit_method = LimitMethod.FETCH_MANY
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = {
None: "{col}",
TimeGrain.SECOND: (

View File

@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.sql.parse import RLSMethod
from .db2 import Db2EngineSpec
@@ -28,6 +30,8 @@ class IBMiEngineSpec(Db2EngineSpec):
engine_name = "IBM Db2 for i"
max_column_name_length = 128
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod
def epoch_to_dttm(cls) -> str:
return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})"

View File

@@ -28,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError,
)
from superset.sql.parse import LimitMethod
from superset.sql.parse import LimitMethod, RLSMethod
from superset.utils.core import GenericDataType
@@ -40,6 +40,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_joins = True
allows_subqueries = True
allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": (

View File

@@ -21,6 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -29,6 +30,8 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
engine = "kylin"
engine_name = "Apache Kylin"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = {
"description": "Apache Kylin is an open-source OLAP engine for big data.",
"logo": "apache-kylin.png",

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.errors import SupersetErrorType
from superset.models.core import Database
from superset.models.sql_lab import Query
from superset.sql.parse import RLSMethod
# Regular expressions to catch custom errors
@@ -227,6 +228,8 @@ class OcientEngineSpec(BaseEngineSpec):
force_column_alias_quotes = True
max_column_name_length = 30
rls_method = RLSMethod.AS_PREDICATE_SPLICE
allows_cte_in_subquery = False
# Ocient does not support cte names starting with underscores
cte_alias = "cte__"

View File

@@ -20,6 +20,7 @@ from sqlalchemy.types import TypeEngine
from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class PinotEngineSpec(BaseEngineSpec):
@@ -30,6 +31,7 @@ class PinotEngineSpec(BaseEngineSpec):
allows_joins = False
allows_alias_in_select = False
allows_alias_in_orderby = False
rls_method = RLSMethod.AS_PREDICATE
# pinotdb only sets cursor.description when the response contains
# columnDataTypes, which Pinot omits for zero-row results.

View File

@@ -16,6 +16,7 @@
# under the License.
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -27,6 +28,7 @@ class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
time_groupby_inline = False
allows_joins = False
allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
metadata = {
"description": "Apache Solr is an open-source enterprise search platform.",

View File

@@ -22,6 +22,7 @@ from urllib import parse
from sqlalchemy.engine.url import make_url, URL # noqa: F401
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class TDengineEngineSpec(BaseEngineSpec):
@@ -29,6 +30,8 @@ class TDengineEngineSpec(BaseEngineSpec):
engine_name = "TDengine"
max_column_name_length = 64
default_driver = "taosws"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
sqlalchemy_uri_placeholder = (
"taosws://user:******@host:port/dbname[?key=value&key=value...]"
)

View File

@@ -76,7 +76,7 @@ SQLGLOT_DIALECTS = {
"duckdb": Dialects.DUCKDB,
# "dynamodb": ???
# "elasticsearch": ???
# "exa": ???
"exa": Dialects.EXASOL,
# "firebird": ???
"firebolt": Firebolt,
"gsheets": Dialects.SQLITE,
@@ -105,7 +105,7 @@ SQLGLOT_DIALECTS = {
"shillelagh": Dialects.SQLITE,
"singlestoredb": SingleStore,
"snowflake": Dialects.SNOWFLAKE,
# "solr": ???
"solr": Dialects.SOLR,
"spark": Dialects.SPARK,
"sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS,
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto()
AS_PREDICATE_SPLICE = enum.auto()
class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
statement: str | None = None,
engine: str = "base",
ast: InternalRepresentation | None = None,
source: str | None = None,
):
if ast:
self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
# Original SQL substring for this statement, when known. Used by the
# splice-mode RLS path which rewrites this string instead of regenerating
# SQL from the AST. ``None`` means the statement was constructed from an
# AST without an associated source string (splice mode falls back).
self._source_sql: str | None = source if source is not None else statement
# Verbatim SQL to return from ``format()``. Set by string-rewriting
# operations (e.g. splice-mode RLS) that produce a final SQL string and
# need to bypass the dialect generator. Cleared by AST-mutating methods
# since those invalidate this cached text.
self._raw_sql: str | None = None
@classmethod
def split_script(
@@ -531,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[InternalRepresentation]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
statement: str | None = None,
engine: str = "base",
ast: exp.Expression | None = None,
source: str | None = None,
):
self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine, ast)
super().__init__(statement, engine, ast, source)
@classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str,
engine: str,
) -> list[SQLStatement]:
asts = [ast for ast in cls._parse(script, engine) if ast]
sources = cls._split_source(script, engine, len(asts))
return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
cls(ast=ast, engine=engine, source=source)
for ast, source in zip(asts, sources, strict=False)
]
@classmethod
def _split_source(
    cls,
    script: str,
    engine: str,
    expected_count: int,
) -> list[str | None]:
    """
    Slice ``script`` into per-statement substrings using top-level semicolon
    positions from the tokenizer.

    Returns a list of length ``expected_count``; every entry is ``None``
    when the slicing didn't yield a usable set of substrings.

    The substrings preserve the original byte content of the script for each
    statement — necessary for splice-mode RLS, which rewrites the original
    SQL rather than regenerating from the AST.
    """
    fallback: list[str | None] = [None] * expected_count
    dialect = SQLGLOT_DIALECTS.get(engine)
    try:
        tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
    except sqlglot.errors.SqlglotError:
        return fallback

    # Offsets of semicolons at paren depth 0 — statement separators.
    # Semicolons inside strings/comments never appear as separate tokens.
    cut_points: list[int] = []
    depth = 0
    for token in tokens:
        if token.token_type == sqlglot.tokens.TokenType.L_PAREN:
            depth += 1
        elif token.token_type == sqlglot.tokens.TokenType.R_PAREN:
            depth -= 1
        elif (
            token.token_type == sqlglot.tokens.TokenType.SEMICOLON
            and depth == 0
        ):
            cut_points.append(token.start)

    segment_starts = [0] + [cut + 1 for cut in cut_points]
    segment_ends = cut_points + [len(script)]
    segments = [
        stripped
        for start, end in zip(segment_starts, segment_ends, strict=False)
        if (stripped := script[start:end].strip())
    ]
    # Only trust the slicing when it produced exactly one substring per AST.
    return list(segments) if len(segments) == expected_count else fallback
@classmethod
def _parse_statement(
cls,
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def format(self, comments: bool = True) -> str:
"""
Pretty-format the SQL statement.
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
verbatim result in ``_raw_sql``, return it as-is — the whole point of
those operations is to avoid the dialect generator round-trip.
"""
if self._raw_sql is not None:
return self._raw_sql
return Dialect.get_or_raise(self._dialect).generate(
self._parsed,
copy=True,
@@ -808,6 +872,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
"""
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
# If we already have a rewritten SQL string, re-parse it first so further
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
if self._raw_sql is not None:
self._parsed = self._parse_statement(self._raw_sql, self.engine)
self._source_sql = self._raw_sql
self._raw_sql = None
if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +973,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[exp.Expression]],
predicates: dict[Table, list[str]],
method: RLSMethod,
) -> None:
"""
@@ -910,11 +981,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
:param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names
:param predicates: Mapping of fully qualified ``Table`` to raw predicate
SQL strings.
:param method: The method to use for applying the rules.
"""
if not predicates:
return
if method == RLSMethod.AS_PREDICATE_SPLICE:
self._apply_rls_splice(catalog, schema, predicates)
return
parsed_predicates: dict[Table, list[exp.Expression]] = {
table: [self.parse_predicate(predicate) for predicate in table_predicates]
for table, table_predicates in predicates.items()
}
transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -922,9 +1004,39 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates)
transformer = transformers[method](catalog, schema, parsed_predicates)
self._parsed = self._parsed.transform(transformer)
def _apply_rls_splice(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
) -> None:
    """
    Apply RLS by splicing predicate text into the original SQL string.

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: Mapping of fully qualified ``Table`` to raw predicate
        SQL strings
    :raises ValueError: when no source SQL string is available — splice mode
        rewrites text, so the AST alone is not enough. The standard
        ``SQLScript`` construction path always provides the source string.
    """
    # Imported lazily to avoid a circular import with this module.
    from superset.sql.rls_splice import apply_rls_splice

    if self._source_sql is None:
        raise ValueError(
            "Splice-mode RLS requires the source SQL string; "
            "this SQLStatement was constructed without one."
        )

    # Cache the rewritten SQL verbatim; ``format()`` returns it as-is so
    # the dialect generator never runs on the spliced text.
    self._raw_sql = apply_rls_splice(
        self._source_sql,
        catalog,
        schema,
        predicates,
        dialect=self._dialect,
    )
class KQLSplitState(enum.Enum):
"""

462
superset/sql/rls_splice.py Normal file
View File

@@ -0,0 +1,462 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS predicate injection via text splicing.
Instead of round-tripping through sqlglot's generator (which transpiles
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
in the original SQL string.
3. For each ``SELECT`` scope that references a table with an RLS predicate,
finds the exact byte offset to inject at — either the end of an existing
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
/ ``LIMIT`` / the closing paren of a subquery.
4. Splices the predicate text directly into the original string at that
offset — never calling ``.sql()``, so the generator never runs.
Result: everything outside the splice points is the original SQL, byte for
byte. Dialect-specific functions, comments, and formatting are all preserved
exactly.
Known limitations:
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
A thin dialect subclass is still required for parsing — but only for
parsing, not generation.
- Predicate strings are spliced in as raw SQL. They must come from a trusted
source (the RLS config), not user input.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.tokens import Token, TokenType
if TYPE_CHECKING:
from superset.sql.parse import Table
# Token types that end a WHERE clause / FROM section at the current paren depth,
# indicating where a new predicate must be inserted just before.
_CLAUSE_ENDS = {
    TokenType.GROUP_BY,
    TokenType.HAVING,
    TokenType.ORDER_BY,
    TokenType.WINDOW,
    TokenType.QUALIFY,
    TokenType.LIMIT,
    TokenType.FETCH,
    TokenType.CLUSTER_BY,
    TokenType.DISTRIBUTE_BY,
    TokenType.SORT_BY,
    TokenType.CONNECT_BY,
    TokenType.START_WITH,
    TokenType.UNION,
    TokenType.INTERSECT,
    TokenType.EXCEPT,
}

# Token types that introduce a JOIN segment. Used when scanning an ON
# condition (``stop_at_join=True``) so the next JOIN's keyword terminates the
# current condition instead of being absorbed into it.
_JOIN_STARTS = {
    TokenType.JOIN,
    TokenType.STRAIGHT_JOIN,
    TokenType.JOIN_MARKER,
}
def _splice_priority(text: str) -> int:
"""
Priority for applying splices at the same offset.
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
wrapping splices like ``pred AND (existing)`` compose correctly.
"""
return 1 if text != ")" else 0
def _after_previous_token(tokens: list[Token], index: int) -> int:
"""
Return the offset immediately after the token preceding *index*.
The sqlglot tokenizer strips comments and whitespace from the token stream,
so the previous token's ``end + 1`` is the splice point that lands right
after the last real SQL content — naturally skipping any intervening
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
literals for real comment delimiters.
"""
if index <= 0:
return 0
return tokens[index - 1].end + 1
def _table_from_node(
    node: exp.Table,
    catalog: str | None,
    schema: str | None,
) -> Table:
    """
    Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node,
    falling back to the supplied catalog/schema for unqualified parts.
    """
    # Imported lazily to avoid a circular import with ``superset.sql.parse``.
    from superset.sql.parse import Table

    # ``node.db``/``node.catalog`` are empty strings when absent, so ``or``
    # applies the defaults exactly where the node is unqualified.
    return Table(
        table=node.name,
        schema=node.db or schema,
        catalog=node.catalog or catalog,
    )
def apply_rls_splice(
    sql: str,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
    dialect: str | None = None,
) -> str:
    """
    Inject RLS predicates into ``sql`` by splicing text at the right positions.

    :param sql: The original SQL query. Returned unchanged except at splice
        points.
    :param catalog: The default catalog for non-qualified table names.
    :param schema: The default schema for non-qualified table names.
    :param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
        maps a fully qualified table to one or more raw predicate strings to
        ``AND`` together when that table is referenced in a SELECT scope.
    :param dialect: The sqlglot dialect used for *parsing only* — to understand
        scope structure and locate token positions. The generator is never
        called, so this does not affect output formatting.
    :return: The query with RLS predicates injected into every relevant SELECT
        scope.
    """
    # Nothing to inject: return the original SQL untouched.
    if not predicates or not any(predicates.values()):
        return sql

    resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
    # The token stream supplies byte-accurate offsets into the original
    # string; the parsed tree is only used to walk SELECT scopes.
    tokens = list(resolved_dialect.tokenize(sql))
    tree = sqlglot.parse_one(sql, dialect=dialect)

    splices: list[tuple[int, str]] = []
    for scope in traverse_scope(tree):
        splices.extend(
            _splices_for_scope(
                sql,
                tokens,
                scope,
                predicates,
                catalog,
                schema,
                dialect,
            )
        )

    # Apply splices in reverse offset order so earlier positions stay valid.
    # For equal offsets, apply predicate/WHERE/ON inserts before ")" inserts.
    splices.sort(key=lambda item: (item[0], _splice_priority(item[1])), reverse=True)
    result = sql
    for offset, text in splices:
        result = result[:offset] + text + result[offset:]
    return result
def _splices_for_scope(
    sql: str,
    tokens: list[Token],
    scope: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> list[tuple[int, str]]:
    """
    Compute every splice needed for a single SELECT scope.

    Mirrors ``RLSAsPredicateTransformer`` semantics:

    - predicates for FROM tables are ANDed into the SELECT's WHERE clause as
      ``pred AND (existing_where)``
    - predicates for JOIN tables are ANDed into each JOIN's ON clause as
      ``pred AND (existing_on)`` (or a bare ``ON pred`` when ON is absent)
    """
    where_preds: list[str] = []
    where_anchors: list[int] = []
    result: list[tuple[int, str]] = []

    for src in scope.sources.values():  # type: ignore[attr-defined]
        kind, end_offset, predicate = _classify_source_predicate(
            src,
            predicates,
            catalog,
            schema,
            dialect,
        )
        if kind == "none" or end_offset is None or predicate is None:
            continue
        if kind == "from":
            # FROM predicates are collected and emitted as one WHERE splice.
            where_preds.append(predicate)
            where_anchors.append(end_offset)
        else:
            # JOIN predicates splice directly into their ON clause.
            result.extend(_find_join_splice(sql, tokens, end_offset, predicate))

    if where_preds:
        # De-duplicate while preserving insertion order.
        combined = " AND ".join(dict.fromkeys(where_preds))
        result.extend(
            _find_where_splice(sql, tokens, max(where_anchors), combined)
        )
    return result
def _table_end(source: exp.Table) -> int | None:
    """
    Byte offset just past the table reference, read from the ``_meta`` token
    positions sqlglot attaches to the identifier — ``None`` when unavailable.
    """
    ident = source.find(exp.Identifier)
    if ident is None or not getattr(ident, "_meta", None):
        return None
    return ident._meta["end"]
def _classify_source_predicate(
    source: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> tuple[str, int | None, str | None]:
    """
    Classify a scope source for RLS splicing.

    :return: ``(kind, table_end_offset, predicate_sql)`` where ``kind`` is
        ``"from"``, ``"join"``, or ``"none"`` (nothing to splice).
    """
    no_match = ("none", None, None)
    if not isinstance(source, exp.Table):
        return no_match

    table = _table_from_node(source, catalog, schema)
    qualified = [
        _qualify_predicate(pred, source, dialect)
        for pred in predicates.get(table, [])
        if pred
    ]
    if not qualified:
        return no_match

    end_offset = _table_end(source)
    if end_offset is None:
        # Without token positions there is nowhere safe to splice.
        return no_match

    # De-duplicate predicates, keeping first-seen order.
    pred_sql = " AND ".join(dict.fromkeys(qualified))
    parent = source.parent
    if isinstance(parent, exp.From):
        return ("from", end_offset, pred_sql)
    if isinstance(parent, exp.Join):
        return ("join", end_offset, pred_sql)
    return no_match
def _qualify_predicate(
    predicate: str,
    table_node: exp.Table,
    dialect: str | None,
) -> str:
    """
    Prefix every column in *predicate* with the table's alias (or name),
    mirroring ``RLSAsPredicateTransformer``.
    """
    tree = sqlglot.parse_one(predicate, dialect=dialect)
    qualifier = exp.to_identifier(table_node.alias_or_name)
    for col in tree.find_all(exp.Column):
        # Copy so each column owns its qualifier node.
        col.set("table", qualifier.copy())
    return tree.sql(dialect=dialect)
def _scan_until_scope_boundary(
    tokens: list[Token],
    anchor: int,
    *,
    stop_at_join: bool,
) -> tuple[str, int | None]:
    """
    Scan tokens forward from ``anchor`` until a clause/scope boundary.

    Returns ``("where", index)`` when a WHERE token is found at depth 0,
    ``("boundary", index)`` for a non-WHERE boundary token, and
    ``("eof", None)`` when no boundary token is found.
    """
    depth = 0
    for i, tok in enumerate(tokens):
        # Skip everything at or before the anchor offset.
        if tok.start <= anchor:
            continue
        if tok.token_type == TokenType.L_PAREN:
            depth += 1
            continue
        if tok.token_type == TokenType.R_PAREN:
            if depth == 0:
                # Unmatched close paren: the enclosing subquery ends here.
                return ("boundary", i)
            depth -= 1
            continue
        if depth > 0:
            # Inside nested parens; clause keywords there belong to an inner
            # scope and must not end ours.
            continue
        if tok.token_type == TokenType.WHERE:
            return ("where", i)
        if tok.token_type in _CLAUSE_ENDS or (
            stop_at_join and tok.token_type in _JOIN_STARTS
        ):
            return ("boundary", i)
    return ("eof", None)
def _find_condition_end(
    tokens: list[Token],
    start_index: int,
    *,
    stop_at_join: bool,
) -> int:
    """
    Find the end offset for a WHERE/ON condition body.

    Walks forward from the WHERE/ON token at ``start_index`` and returns the
    offset just past the last token that belongs to the condition.
    """
    depth = 0
    prev_end = tokens[start_index].end
    for tok in tokens[start_index + 1 :]:
        if tok.token_type == TokenType.L_PAREN:
            depth += 1
        elif tok.token_type == TokenType.R_PAREN:
            if depth == 0:
                # Unmatched close paren: the enclosing subquery ends here.
                return prev_end + 1
            depth -= 1
        elif depth == 0 and (
            (stop_at_join and tok.token_type == TokenType.WHERE)
            or tok.token_type in _CLAUSE_ENDS
            or (stop_at_join and tok.token_type in _JOIN_STARTS)
        ):
            # A following clause keyword — or, when scanning an ON body, a
            # WHERE/JOIN keyword — terminates the condition just before it.
            return prev_end + 1
        prev_end = tok.end
    # Condition runs to the end of the statement.
    return prev_end + 1
def _find_where_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Build the splices that AND a predicate into a SELECT's WHERE clause:
    a bare ``WHERE pred`` when no clause exists, otherwise
    ``pred AND (existing)`` wrapped around the current condition.
    """
    kind, index = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)

    if kind == "where" and index is not None:
        # Degenerate query ending in a bare WHERE keyword: just append.
        if index + 1 >= len(tokens):
            return [(tokens[index].end + 1, f" {pred_sql}")]
        # Wrap the existing condition: predicate + "AND (" goes at the start
        # of the body, with a matching close paren after its end.
        opening = (tokens[index + 1].start, f"{pred_sql} AND (")
        closing = (_find_condition_end(tokens, index, stop_at_join=False), ")")
        return [opening, closing]

    if kind == "boundary" and index is not None:
        # No WHERE: insert one just before the next clause/scope boundary.
        return [(_after_previous_token(tokens, index), f" WHERE {pred_sql}")]

    # No boundary at all: the statement simply ends; append a WHERE clause.
    return [(len(sql), f" WHERE {pred_sql}")]
def _find_join_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Build the splices that AND a predicate into a JOIN's ON clause:
    ``ON pred`` when ON is absent, otherwise ``ON pred AND (existing_on)``.
    """
    on_index, boundary_index = _scan_join_clause(tokens, anchor)

    if on_index is not None:
        # Degenerate input ending in a bare ON keyword: just append.
        if on_index + 1 >= len(tokens):
            return [(tokens[on_index].end + 1, f" {pred_sql}")]
        # Wrap the existing ON condition with the predicate.
        opening = (tokens[on_index + 1].start, f"{pred_sql} AND (")
        closing = (_find_condition_end(tokens, on_index, stop_at_join=True), ")")
        return [opening, closing]

    if boundary_index is not None:
        # No ON clause: insert one right before the segment boundary.
        return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")]

    # Statement ends without a boundary: append the ON clause at the end.
    return [(len(sql), f" ON {pred_sql}")]
def _scan_join_clause(
    tokens: list[Token],
    anchor: int,
) -> tuple[int | None, int | None]:
    """
    Find ON and boundary token indexes for a JOIN segment.

    Scans forward from ``anchor`` (the end offset of the joined table) and
    returns ``(on_index, boundary_index)``: the index of the JOIN's ON token
    (``None`` if absent) and the index of the token that ends the JOIN
    segment (``None`` when the statement simply ends).
    """
    depth = 0
    on_index: int | None = None
    boundary_index: int | None = None
    for i, tok in enumerate(tokens):
        # Skip everything at or before the anchor offset.
        if tok.start <= anchor:
            continue
        if tok.token_type == TokenType.L_PAREN:
            depth += 1
            continue
        if tok.token_type == TokenType.R_PAREN:
            if depth == 0:
                # Unmatched close paren: the enclosing subquery ends here.
                boundary_index = i
                break
            depth -= 1
            continue
        if depth > 0:
            # Nested parens: keywords inside don't belong to this JOIN.
            continue
        if tok.token_type == TokenType.ON and on_index is None:
            on_index = i
            continue
        if tok.token_type == TokenType.WHERE:
            boundary_index = i
            break
        if tok.token_type in _JOIN_STARTS or tok.token_type in _CLAUSE_ENDS:
            boundary_index = i
            break
    return on_index, boundary_index

View File

@@ -40,17 +40,19 @@ def apply_rls(
:returns: True if any RLS predicates were actually applied, False otherwise.
"""
# There are two ways to insert RLS: either replacing the table with a subquery
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
# safer, but not supported in all databases.
method = database.db_engine_spec.get_rls_method()
# There are three ways to insert RLS:
# - replace the table with a subquery containing the RLS (safest, but not
# supported in all databases)
# - append the RLS to the ``WHERE`` clause via AST transformation
# - splice the RLS into the original SQL string (preserves dialect-specific
# syntax that the sqlglot generator would otherwise transpile)
method = database.db_engine_spec.rls_method
# collect all RLS predicates for all tables in the query
predicates: dict[Table, list[Any]] = {}
predicates: dict[Table, list[str]] = {}
for table in parsed_statement.tables:
table = table.qualify(catalog=catalog, schema=schema)
predicates[table] = [
parsed_statement.parse_predicate(predicate)
raw_predicates = [
predicate
for predicate in get_predicates_for_table(
table,
database,
@@ -58,6 +60,7 @@ def apply_rls(
)
if predicate
]
predicates[table] = raw_predicates
has_predicates = any(predicates.values())
parsed_statement.apply_rls(catalog, schema, predicates, method)

View File

@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError
from superset.sql.parse import Table
from superset.sql.parse import RLSMethod, Table
from superset.superset_typing import (
OAuth2ClientConfig,
OAuth2State,
@@ -1283,3 +1283,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
error = exc_info.value.error
assert error.extra["redirect_uri"] == fallback_uri
def test_default_rls_method_is_subquery() -> None:
"""Base engine spec defaults to subquery-based RLS."""
assert BaseEngineSpec.rls_method == RLSMethod.AS_SUBQUERY

View File

@@ -209,7 +209,7 @@ class TestApplyRlsReturnValue:
from superset.utils.rls import apply_rls
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
statement = MagicMock()
@@ -237,7 +237,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = []
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -268,7 +268,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = ["user_id = 42"]
database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock()
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None
mock_table = MagicMock()
@@ -276,8 +276,6 @@ class TestApplyRlsReturnValue:
statement = MagicMock()
statement.tables = [mock_table]
statement.parse_predicate.return_value = MagicMock()
result = apply_rls(
database=database,
catalog=None,
@@ -312,11 +310,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -333,11 +330,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", "mycat"): [predicate]},
{Table("pens", "public", "mycat"): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -351,11 +347,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT p.pen_id, p.is_green FROM public.pens p"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()
@@ -369,11 +364,10 @@ class TestRLSSubqueryAlias:
"""
sql = "SELECT pen_id, is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls(
None,
"public",
{Table("pens", "public", None): [predicate]},
{Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY,
)
result = statement.format()

View File

@@ -1704,6 +1704,21 @@ def test_set_limit_value(
assert statement.format() == expected
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
    """
    If splice-mode rewrites left verbatim SQL cached on the statement, a
    later limit change must reparse that SQL rather than mutate a stale AST.
    """
    stmt = SQLStatement("SELECT * FROM some_table", "postgresql")
    # White-box: simulate the cached verbatim SQL a splice rewrite leaves behind.
    stmt._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
    stmt.set_limit_value(10, LimitMethod.FORCE_LIMIT)
    rendered = stmt.format()
    assert "tenant_id = 42" in rendered
    assert "LIMIT 10" in rendered
@pytest.mark.parametrize(
"kql, limit, expected",
[
@@ -2198,7 +2213,7 @@ def test_rls_subquery_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY,
)
assert statement.format() == expected
@@ -2542,12 +2557,312 @@ def test_rls_predicate_transformer(
statement.apply_rls(
"catalog1",
"schema1",
{k: [parse_one(v)] for k, v in rules.items()},
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # No existing WHERE: the predicate becomes the whole WHERE clause.
        (
            "SELECT t.foo FROM some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42",
        ),
        # Existing WHERE is parenthesized and ANDed after the predicate so the
        # OR cannot bypass the RLS condition.
        (
            "SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
            "AND (bar = 'baz' OR foo = 'qux')",
        ),
        # Predicate for a JOINed table lands in the ON clause, with the
        # original join condition parenthesized.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id)",
        ),
        # A JOIN without an ON clause gets one created for the predicate.
        (
            "SELECT * FROM table JOIN other_table",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42",
        ),
        # ON-clause injection leaves an unrelated WHERE clause untouched.
        (
            "SELECT * FROM table JOIN other_table ON table.id = other_table.id "
            "WHERE 1=1",
            {Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
            "SELECT * FROM table JOIN other_table ON other_table.id = 42 "
            "AND (table.id = other_table.id) WHERE 1=1",
        ),
    ],
)
def test_rls_predicate_splice_semantics_match_predicate(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Splice mode should preserve predicate-mode semantics for boolean grouping
    and JOIN-vs-WHERE placement.
    """
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        # apply_rls expects a list of predicates per table.
        {k: [v] for k, v in rules.items()},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, rules, expected",
    [
        # Simple — no WHERE clause to extend.
        (
            "SELECT LAST_DAY(d) FROM some_table",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
        ),
        # Append to an existing WHERE clause.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (status = 'open')",
        ),
        # WHERE precedes GROUP BY: predicate goes before GROUP BY.
        (
            "SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
        ),
        # No WHERE, but GROUP BY and ORDER BY are present.
        (
            "SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
            "GROUP BY d ORDER BY d",
        ),
        # JOIN — predicate scoped to one of the tables.
        (
            "SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
            {Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
            "SELECT o.id FROM some_table o JOIN locations l "
            "ON o.loc_id = l.id WHERE o.tenant_id = 42",
        ),
        # JOIN — different predicate per table, both spliced into one WHERE.
        (
            "SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
            {
                Table("events", "schema1", "catalog1"): "events.user_id = 99",
                Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
            },
            "SELECT * FROM some_table JOIN events "
            "ON events.user_id = 99 AND (some_table.id = events.order_id) "
            "WHERE some_table.tenant_id = 42",
        ),
        # Subquery in FROM — splice into the inner SELECT.
        (
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
            "WHERE some_table.tenant_id = 42) sub",
        ),
        # CTE — splice into the CTE body.
        (
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
            "WHERE some_table.tenant_id = 42) SELECT * FROM cte",
        ),
        # Dialect-specific function (LAST_DAY) preserved verbatim.
        (
            "SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT id, LAST_DAY(created_at) FROM some_table "
            "WHERE some_table.tenant_id = 42 AND (region = 'US')",
        ),
        # Multiline + inline comment preserved exactly.
        (
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE region = 'US'",
            {Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
            "SELECT LAST_DAY(created_at) -- last day of month\n"
            "FROM some_table\n"
            "WHERE some_table.tenant_id = 42 AND (region = 'US')",
        ),
        # Schema-qualified table name (no default schema match) — no predicate.
        (
            "SELECT t.foo FROM schema2.some_table AS t",
            {Table("some_table", "schema1", "catalog1"): "t.id = 42"},
            "SELECT t.foo FROM schema2.some_table AS t",
        ),
    ],
)
def test_rls_predicate_splice(
    sql: str,
    rules: dict[Table, str],
    expected: str,
) -> None:
    """
    Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.

    Splice mode rewrites the original SQL string instead of re-rendering the
    AST through the dialect generator, so byte-level fidelity (including
    dialect-specific functions, comments, and whitespace) is preserved.
    """
    statement = SQLStatement(sql)
    statement.apply_rls(
        "catalog1",
        "schema1",
        # apply_rls expects a list of predicates per table.
        {k: [v] for k, v in rules.items()},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
def test_rls_predicate_splice_requires_source() -> None:
    """
    A statement constructed directly from an AST carries no verbatim SQL to
    splice into, so splice-mode RLS must raise a ``ValueError``.
    """
    tree = parse_one("SELECT * FROM some_table")
    stmt = SQLStatement(ast=tree, engine="postgresql")
    rls_rules = {Table("some_table", "schema1", "catalog1"): ["id = 42"]}
    with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
        stmt.apply_rls(
            "catalog1",
            "schema1",
            rls_rules,
            RLSMethod.AS_PREDICATE_SPLICE,
        )
def test_rls_predicate_splice_preserves_dialect_function() -> None:
    """
    Splice mode must not round-trip through the sqlglot generator: the
    postgres generator would transpile ``LAST_DAY``, so the output has to
    keep it verbatim.
    """
    stmt = SQLStatement("SELECT LAST_DAY(d) FROM some_table", engine="postgresql")
    rls_rules = {Table("some_table"): ["some_table.tenant_id = 42"]}
    stmt.apply_rls(None, None, rls_rules, RLSMethod.AS_PREDICATE_SPLICE)
    expected = "SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
    assert stmt.format() == expected
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
    """
    Splice mode consumes predicate strings as-is; callers never need to run
    ``parse_predicate`` first.
    """
    stmt = SQLStatement("SELECT * FROM some_table", engine="postgresql")
    rls_rules = {
        Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]
    }
    stmt.apply_rls(None, None, rls_rules, RLSMethod.AS_PREDICATE_SPLICE)
    expected = (
        "SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
    )
    assert stmt.format() == expected
@pytest.mark.parametrize(
    "sql, expected",
    [
        # Line comment (``-- hi``) immediately before the GROUP BY boundary.
        (
            "SELECT * FROM some_table -- hi\nGROUP BY id",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "-- hi\nGROUP BY id",
        ),
        # Block comment (``/* inline */``) immediately before GROUP BY.
        (
            "SELECT * FROM some_table /* inline */ GROUP BY id",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "/* inline */ GROUP BY id",
        ),
    ],
)
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
    """
    Splice mode should insert predicates before comments that precede the next
    clause boundary, so comments do not swallow the injected SQL.
    """
    statement = SQLStatement(sql, engine="postgresql")
    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["some_table.tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
@pytest.mark.parametrize(
    "sql, engine, expected",
    [
        # Snowflake QUALIFY clause: WHERE must be inserted before it.
        (
            "SELECT * FROM some_table QUALIFY row_number() OVER "
            "(PARTITION BY id ORDER BY ts DESC) = 1",
            "snowflake",
            "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
            "QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
        ),
        # PostgreSQL WINDOW clause: WHERE must be inserted before it.
        (
            "SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
            "postgresql",
            "SELECT sum(v) OVER () FROM some_table "
            "WHERE some_table.tenant_id = 42 "
            "WINDOW w AS (PARTITION BY id)",
        ),
    ],
)
def test_rls_predicate_splice_handles_additional_clause_boundaries(
    sql: str,
    engine: str,
    expected: str,
) -> None:
    """
    Splice mode should insert WHERE before clause types that can legally follow
    FROM/WHERE (for example QUALIFY and WINDOW).
    """
    statement = SQLStatement(sql, engine=engine)
    statement.apply_rls(
        None,
        None,
        {Table("some_table"): ["some_table.tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    assert statement.format() == expected
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
    """
    A FORCE_LIMIT rewrite applied after splice-mode RLS must keep the
    injected predicate in the final SQL.
    """
    stmt = SQLStatement("SELECT * FROM some_table", engine="postgresql")
    stmt.apply_rls(
        None,
        None,
        {Table("some_table"): ["tenant_id = 42"]},
        RLSMethod.AS_PREDICATE_SPLICE,
    )
    stmt.set_limit_value(101, LimitMethod.FORCE_LIMIT)
    rendered = stmt.format()
    assert "some_table.tenant_id = 42" in rendered
    assert "LIMIT 101" in rendered
@pytest.mark.parametrize(
"sql, table, expected",
[

View File

@@ -0,0 +1,292 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
_find_where_splice,
_scan_join_clause,
_scan_until_scope_boundary,
_splices_for_scope,
_table_end,
apply_rls_splice,
)
def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
    """Tokenize *sql* with sqlglot's default (dialect-less) dialect."""
    dialect = Dialect.get_or_raise(None)
    return list(dialect.tokenize(sql))
def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
    """Return the position of the first token whose type equals *token_type*."""
    matches = (idx for idx, tok in enumerate(tokens) if tok.token_type == token_type)
    return next(matches)
def _token_by_text(
    tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
    """Return the first token whose text equals *text*."""
    matches = (tok for tok in tokens if tok.text == text)
    return next(matches)
def test_split_source_returns_none_result_when_tokenize_fails(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A tokenizer failure degrades to a list of ``None`` source slices."""

    class _ExplodingDialect:
        @staticmethod
        def tokenize(_: str) -> list[sqlglot.tokens.Token]:
            raise sqlglot.errors.SqlglotError("boom")

    monkeypatch.setattr(
        "superset.sql.parse.Dialect.get_or_raise",
        lambda _: _ExplodingDialect(),
    )
    result = SQLStatement._split_source("SELECT 1", "postgresql", 2)
    assert result == [None, None]
def test_apply_rls_splice_ignores_empty_predicates() -> None:
    """A table entry with no predicates leaves the SQL unchanged."""
    query = "SELECT 1"
    result = apply_rls_splice(query, None, None, {Table("foo"): []})
    assert result == query
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
    """
    Regression test: a ``--`` sequence inside a quoted string must not be
    mistaken for an inline comment when locating the splice point. The old
    ``rfind("--", ...)`` logic injected the predicate into the literal.
    """
    original = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
    result = apply_rls_splice(
        original,
        None,
        None,
        {Table("some_table"): ["some_table.tenant_id = 42"]},
        dialect="postgres",
    )
    expected = (
        "SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
        "AND (note = '--x') GROUP BY id"
    )
    assert result == expected
def test_table_end_returns_none_without_metadata() -> None:
    """A table whose identifier has no position metadata yields no end offset."""
    table_node = exp.Table(this=exp.Identifier(this="foo"))
    assert _table_end(table_node) is None
def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
    """A source without token-position metadata cannot be classified."""
    table_node = exp.Table(this=exp.Identifier(this="foo"))
    # Constructing the FROM node attaches it as the table's parent.
    exp.From(this=table_node)
    outcome = _classify_source_predicate(
        table_node,
        {Table("foo"): ["id = 1"]},
        None,
        None,
        None,
    )
    assert outcome == ("none", None, None)
def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
    """A source whose parent node type is unsupported yields no classification."""
    table_node = exp.Table(this=exp.Identifier(this="foo"))
    table_node.this.meta["end"] = 3
    # Constructing the alias attaches it as the table's (unsupported) parent.
    exp.Alias(this=table_node, alias=exp.Identifier(this="alias"))
    outcome = _classify_source_predicate(
        table_node,
        {Table("foo"): ["id = 1"]},
        None,
        None,
        None,
    )
    assert outcome == ("none", None, None)
def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
    """A parenthesized WHERE condition scans all the way to end-of-input."""
    sql = "SELECT * FROM t WHERE (a = 1)"
    tokens = _tokenize(sql)
    anchor = _token_by_text(tokens, "WHERE")
    outcome = _scan_until_scope_boundary(tokens, anchor.start, stop_at_join=False)
    assert outcome == ("eof", None)
def test_find_condition_end_handles_subquery_closing_paren() -> None:
    """Inside a subquery the condition ends at the closing parenthesis."""
    sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
    tokens = _tokenize(sql)
    idx = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
    position = _find_condition_end(tokens, idx, stop_at_join=False)
    assert sql[position] == ")"
def test_find_condition_end_handles_parenthesized_expression() -> None:
    """A balanced parenthesized condition runs to the end of the SQL string."""
    sql = "SELECT * FROM t WHERE (a = 1)"
    tokens = _tokenize(sql)
    idx = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
    position = _find_condition_end(tokens, idx, stop_at_join=False)
    assert position == len(sql)
def test_find_where_splice_handles_trailing_where_keyword() -> None:
    """A dangling WHERE keyword appends the predicate at end-of-string."""
    sql = "SELECT * FROM t WHERE"
    tokens = _tokenize(sql)
    result = _find_where_splice(sql, tokens, anchor=0, pred_sql="t.id = 1")
    assert result == [(len(sql), " t.id = 1")]
def test_find_join_splice_handles_trailing_on_keyword() -> None:
    """A dangling ON keyword appends the predicate at end-of-string."""
    sql = "SELECT * FROM a JOIN b ON"
    tokens = _tokenize(sql)
    join_table = _token_by_text(tokens, "b")
    result = _find_join_splice(sql, tokens, join_table.end, "b.id = 1")
    assert result == [(len(sql), " b.id = 1")]
def test_find_join_splice_inserts_on_before_where_boundary() -> None:
    """A JOIN lacking ON gets an ON clause inserted just before the WHERE."""
    sql = "SELECT * FROM a JOIN b WHERE x = 1"
    tokens = _tokenize(sql)
    join_table = _token_by_text(tokens, "b")
    result = _find_join_splice(sql, tokens, join_table.end, "b.id = 1")
    assert result == [(sql.index("WHERE") - 1, " ON b.id = 1")]
def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
    """Scanning skips a parenthesized ON condition and stops at the next JOIN."""
    sql = "SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1"
    tokens = _tokenize(sql)
    join_table = _token_by_text(tokens, "b")
    on_index, boundary_index = _scan_join_clause(tokens, join_table.end)
    assert on_index is not None
    assert boundary_index is not None
    assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.JOIN
def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
    """The scan ends at the closing parenthesis of the enclosing subquery."""
    sql = "SELECT * FROM (SELECT * FROM a JOIN b) sub"
    tokens = _tokenize(sql)
    join_table = _token_by_text(tokens, "b")
    _, boundary_index = _scan_join_clause(tokens, join_table.end)
    assert boundary_index is not None
    assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.R_PAREN
def test_splices_for_scope_handles_empty_join_splice_result(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A JOIN classification that yields no splice points produces nothing."""

    class _FakeScope:
        sources = {"x": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        lambda *args, **kwargs: ("join", 0, "x.id = 1"),
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )
    result = _splices_for_scope(
        sql,
        tokens,
        _FakeScope(),
        {Table("x"): ["x.id = 1"]},
        None,
        None,
        None,
    )
    assert result == []
def test_splices_for_scope_combines_join_and_from_splices(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """JOIN splice points and FROM (WHERE) splice points are both collected."""

    class _FakeScope:
        sources = {"f": object(), "j": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    # One classification per source, consumed in source order.
    classifications = iter([("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")])
    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        lambda *args, **kwargs: next(classifications),
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [(50, " ON j.id = 2")],
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_where_splice",
        lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
    )
    result = _splices_for_scope(
        sql,
        tokens,
        _FakeScope(),
        {Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
        None,
        None,
        None,
    )
    assert result == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]
def test_splices_for_scope_join_then_next_source(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An empty JOIN splice result does not stop the scan of later sources."""

    class _FakeScope:
        sources = {"j": object(), "f": object()}

    sql = "SELECT 1"
    tokens = _tokenize(sql)
    # First source classifies as a JOIN, second yields nothing.
    classifications = iter([("join", 3, "j.id = 2"), ("none", None, None)])
    monkeypatch.setattr(
        "superset.sql.rls_splice._classify_source_predicate",
        lambda *args, **kwargs: next(classifications),
    )
    monkeypatch.setattr(
        "superset.sql.rls_splice._find_join_splice",
        lambda *args, **kwargs: [],
    )
    result = _splices_for_scope(
        sql,
        tokens,
        _FakeScope(),
        {Table("j"): ["id = 2"]},
        None,
        None,
        None,
    )
    assert result == []