mirror of
https://github.com/apache/superset.git
synced 2026-05-13 11:55:16 +00:00
Compare commits
11 Commits
embedded-e
...
rls-splice
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0ce43abe5b | ||
|
|
439130db54 | ||
|
|
ec9f2da81e | ||
|
|
549e51bdc4 | ||
|
|
72ecb20e5c | ||
|
|
d1cd84931e | ||
|
|
884649e3ed | ||
|
|
e9dd9a6107 | ||
|
|
0fe8293c7f | ||
|
|
af2d3babec | ||
|
|
3fe2b2505f |
@@ -557,6 +557,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
# if True, database will be listed as option in the upload file form
|
||||
supports_file_upload = True
|
||||
|
||||
# RLS strategy for this engine spec. Override in engine-specific classes as
|
||||
# needed (for example ``RLSMethod.AS_PREDICATE`` for engines that don't
|
||||
# support subquery-based RLS, or ``RLSMethod.AS_PREDICATE_SPLICE`` for
|
||||
# engines where sqlglot generation is not faithful).
|
||||
rls_method = RLSMethod.AS_SUBQUERY
|
||||
|
||||
# Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501
|
||||
# a custom `adjust_engine_params` method.
|
||||
supports_dynamic_schema = False
|
||||
@@ -618,21 +624,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
else cls.encrypted_extra_sensitive_fields
|
||||
)
|
||||
|
||||
@classmethod
def get_rls_method(cls) -> RLSMethod:
    """
    Returns the RLS method to be used for this engine.

    There are two ways to insert RLS: either replacing the table with a subquery
    that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
    safer, but not supported in all databases.
    """
    # Subquery-based RLS requires both subquery support and aliases inside
    # SELECT; when either capability flag is off, fall back to appending the
    # predicate to the WHERE clause.
    return (
        RLSMethod.AS_SUBQUERY
        if cls.allows_subqueries and cls.allows_alias_in_select
        else RLSMethod.AS_PREDICATE
    )
|
||||
|
||||
@classmethod
|
||||
def is_oauth2_enabled(cls) -> bool:
|
||||
return (
|
||||
|
||||
@@ -35,6 +35,7 @@ from superset.db_engine_specs.base import (
|
||||
DatabaseCategory,
|
||||
)
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql.parse import RLSMethod
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
|
||||
|
||||
@@ -82,6 +83,7 @@ class CouchbaseEngineSpec(BasicParametersMixin, BaseEngineSpec):
|
||||
default_driver = "couchbase"
|
||||
allows_joins = False
|
||||
allows_subqueries = False
|
||||
rls_method = RLSMethod.AS_PREDICATE
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"couchbase://user:password@host[:port]?truststorepath=value?ssl=value"
|
||||
)
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy import types
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import TableColumn
|
||||
@@ -68,6 +69,8 @@ class CrateEngineSpec(BaseEngineSpec):
|
||||
TimeGrain.YEAR: "DATE_TRUNC('year', {col})",
|
||||
}
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
return "{col} * 1000"
|
||||
|
||||
@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import (
|
||||
)
|
||||
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.sql.parse import RLSMethod
|
||||
from superset.utils.core import GenericDataType
|
||||
from superset.utils.hashing import hash_from_str
|
||||
from superset.utils.network import is_hostname_valid, is_port_open
|
||||
@@ -55,6 +56,8 @@ class DatabendBaseEngineSpec(BaseEngineSpec):
|
||||
time_secondary_columns = True
|
||||
time_groupby_inline = True
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",
|
||||
|
||||
@@ -26,6 +26,7 @@ from superset.db_engine_specs.base import (
|
||||
DatabaseCategory,
|
||||
)
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
# Internal class for defining error message patterns (for translation)
|
||||
@@ -58,6 +59,8 @@ class DenodoEngineSpec(BaseEngineSpec, BasicParametersMixin):
|
||||
engine = "denodo"
|
||||
engine_name = "Denodo"
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
default_driver = "psycopg2"
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"denodo://user:password@host:port/dbname[?key=value&key=value...]"
|
||||
|
||||
@@ -21,12 +21,15 @@ from sqlalchemy import types
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
class DynamoDBEngineSpec(BaseEngineSpec):
|
||||
engine = "dynamodb"
|
||||
engine_name = "Amazon DynamoDB"
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
metadata = {
|
||||
"description": (
|
||||
"Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL."
|
||||
|
||||
@@ -28,6 +28,7 @@ from superset.db_engine_specs.exceptions import (
|
||||
SupersetDBAPIOperationalError,
|
||||
SupersetDBAPIProgrammingError,
|
||||
)
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
@@ -39,6 +40,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
|
||||
allows_joins = False
|
||||
allows_subqueries = True
|
||||
allows_sql_comments = False
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
metadata = {
|
||||
"description": (
|
||||
|
||||
@@ -21,7 +21,7 @@ from sqlalchemy import types
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import LimitMethod
|
||||
from superset.sql.parse import LimitMethod, RLSMethod
|
||||
|
||||
|
||||
class FirebirdEngineSpec(BaseEngineSpec):
|
||||
@@ -53,6 +53,8 @@ class FirebirdEngineSpec(BaseEngineSpec):
|
||||
# Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table`
|
||||
limit_method = LimitMethod.FETCH_MANY
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
TimeGrain.SECOND: (
|
||||
|
||||
@@ -14,6 +14,8 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
from .db2 import Db2EngineSpec
|
||||
|
||||
|
||||
@@ -28,6 +30,8 @@ class IBMiEngineSpec(Db2EngineSpec):
|
||||
engine_name = "IBM Db2 for i"
|
||||
max_column_name_length = 128
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
@classmethod
|
||||
def epoch_to_dttm(cls) -> str:
|
||||
return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})"
|
||||
|
||||
@@ -28,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
|
||||
SupersetDBAPIOperationalError,
|
||||
SupersetDBAPIProgrammingError,
|
||||
)
|
||||
from superset.sql.parse import LimitMethod
|
||||
from superset.sql.parse import LimitMethod, RLSMethod
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
|
||||
@@ -40,6 +40,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
allows_joins = True
|
||||
allows_subqueries = True
|
||||
allows_sql_comments = False
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
metadata = {
|
||||
"description": (
|
||||
|
||||
@@ -21,6 +21,7 @@ from sqlalchemy import types
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
@@ -29,6 +30,8 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
engine = "kylin"
|
||||
engine_name = "Apache Kylin"
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
metadata = {
|
||||
"description": "Apache Kylin is an open-source OLAP engine for big data.",
|
||||
"logo": "apache-kylin.png",
|
||||
|
||||
@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.errors import SupersetErrorType
|
||||
from superset.models.core import Database
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
# Regular expressions to catch custom errors
|
||||
|
||||
@@ -227,6 +228,8 @@ class OcientEngineSpec(BaseEngineSpec):
|
||||
force_column_alias_quotes = True
|
||||
max_column_name_length = 30
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
|
||||
allows_cte_in_subquery = False
|
||||
# Ocient does not support cte names starting with underscores
|
||||
cte_alias = "cte__"
|
||||
|
||||
@@ -20,6 +20,7 @@ from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset.constants import TimeGrain
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
class PinotEngineSpec(BaseEngineSpec):
|
||||
@@ -30,6 +31,7 @@ class PinotEngineSpec(BaseEngineSpec):
|
||||
allows_joins = False
|
||||
allows_alias_in_select = False
|
||||
allows_alias_in_orderby = False
|
||||
rls_method = RLSMethod.AS_PREDICATE
|
||||
|
||||
# pinotdb only sets cursor.description when the response contains
|
||||
# columnDataTypes, which Pinot omits for zero-row results.
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# under the License.
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
@@ -27,6 +28,7 @@ class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
|
||||
time_groupby_inline = False
|
||||
allows_joins = False
|
||||
allows_subqueries = False
|
||||
rls_method = RLSMethod.AS_PREDICATE
|
||||
|
||||
metadata = {
|
||||
"description": "Apache Solr is an open-source enterprise search platform.",
|
||||
|
||||
@@ -22,6 +22,7 @@ from urllib import parse
|
||||
from sqlalchemy.engine.url import make_url, URL # noqa: F401
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
|
||||
from superset.sql.parse import RLSMethod
|
||||
|
||||
|
||||
class TDengineEngineSpec(BaseEngineSpec):
|
||||
@@ -29,6 +30,8 @@ class TDengineEngineSpec(BaseEngineSpec):
|
||||
engine_name = "TDengine"
|
||||
max_column_name_length = 64
|
||||
default_driver = "taosws"
|
||||
|
||||
rls_method = RLSMethod.AS_PREDICATE_SPLICE
|
||||
sqlalchemy_uri_placeholder = (
|
||||
"taosws://user:******@host:port/dbname[?key=value&key=value...]"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ SQLGLOT_DIALECTS = {
|
||||
"duckdb": Dialects.DUCKDB,
|
||||
# "dynamodb": ???
|
||||
# "elasticsearch": ???
|
||||
# "exa": ???
|
||||
"exa": Dialects.EXASOL,
|
||||
# "firebird": ???
|
||||
"firebolt": Firebolt,
|
||||
"gsheets": Dialects.SQLITE,
|
||||
@@ -105,7 +105,7 @@ SQLGLOT_DIALECTS = {
|
||||
"shillelagh": Dialects.SQLITE,
|
||||
"singlestoredb": SingleStore,
|
||||
"snowflake": Dialects.SNOWFLAKE,
|
||||
# "solr": ???
|
||||
"solr": Dialects.SOLR,
|
||||
"spark": Dialects.SPARK,
|
||||
"sqlite": Dialects.SQLITE,
|
||||
"starrocks": Dialects.STARROCKS,
|
||||
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
|
||||
|
||||
AS_PREDICATE = enum.auto()
|
||||
AS_SUBQUERY = enum.auto()
|
||||
AS_PREDICATE_SPLICE = enum.auto()
|
||||
|
||||
|
||||
class RLSTransformer:
|
||||
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
statement: str | None = None,
|
||||
engine: str = "base",
|
||||
ast: InternalRepresentation | None = None,
|
||||
source: str | None = None,
|
||||
):
|
||||
if ast:
|
||||
self._parsed = ast
|
||||
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
|
||||
self.engine = engine
|
||||
self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
|
||||
# Original SQL substring for this statement, when known. Used by the
|
||||
# splice-mode RLS path which rewrites this string instead of regenerating
|
||||
# SQL from the AST. ``None`` means the statement was constructed from an
|
||||
# AST without an associated source string (splice mode falls back).
|
||||
self._source_sql: str | None = source if source is not None else statement
|
||||
# Verbatim SQL to return from ``format()``. Set by string-rewriting
|
||||
# operations (e.g. splice-mode RLS) that produce a final SQL string and
|
||||
# need to bypass the dialect generator. Cleared by AST-mutating methods
|
||||
# since those invalidate this cached text.
|
||||
self._raw_sql: str | None = None
|
||||
|
||||
@classmethod
|
||||
def split_script(
|
||||
@@ -531,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
predicates: dict[Table, list[InternalRepresentation]],
|
||||
predicates: dict[Table, list[str]],
|
||||
method: RLSMethod,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
statement: str | None = None,
|
||||
engine: str = "base",
|
||||
ast: exp.Expression | None = None,
|
||||
source: str | None = None,
|
||||
):
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine)
|
||||
super().__init__(statement, engine, ast)
|
||||
super().__init__(statement, engine, ast, source)
|
||||
|
||||
@classmethod
|
||||
def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
|
||||
@@ -626,10 +639,55 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
script: str,
|
||||
engine: str,
|
||||
) -> list[SQLStatement]:
|
||||
asts = [ast for ast in cls._parse(script, engine) if ast]
|
||||
sources = cls._split_source(script, engine, len(asts))
|
||||
return [
|
||||
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast
|
||||
cls(ast=ast, engine=engine, source=source)
|
||||
for ast, source in zip(asts, sources, strict=False)
|
||||
]
|
||||
|
||||
@classmethod
def _split_source(
    cls,
    script: str,
    engine: str,
    expected_count: int,
) -> list[str | None]:
    """
    Slice ``script`` into per-statement substrings using top-level semicolon
    positions from the tokenizer. Returns a list of length ``expected_count``;
    any entry is ``None`` if the slicing didn't yield a usable substring.

    The returned substrings preserve the original byte content of the script
    for each statement — necessary for splice-mode RLS, which rewrites the
    original SQL rather than regenerating from the AST.
    """
    # Fallback result: all-None, so callers degrade gracefully when the
    # script cannot be tokenized or the slice count doesn't line up.
    none_result: list[str | None] = [None] * expected_count
    dialect = SQLGLOT_DIALECTS.get(engine)
    try:
        tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
    except sqlglot.errors.SqlglotError:
        return none_result

    # Top-level semicolon offsets (depth 0).
    # Parenthesis depth is tracked so semicolons inside parenthesized
    # constructs are not treated as statement boundaries.
    boundaries: list[int] = []
    depth = 0
    for tok in tokens:
        if tok.token_type == sqlglot.tokens.TokenType.L_PAREN:
            depth += 1
        elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN:
            depth -= 1
        elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and depth == 0:
            boundaries.append(tok.start)

    # Each statement spans from just after the previous semicolon to the next
    # one (or end of script); empty slices (e.g. trailing semicolons) drop out.
    starts = [0, *(b + 1 for b in boundaries)]
    ends = [*boundaries, len(script)]
    sources = [script[s:e].strip() for s, e in zip(starts, ends, strict=False)]
    sources = [s for s in sources if s]
    # Only trust the slicing when it agrees with the parser's statement count.
    if len(sources) != expected_count:
        return none_result
    return list(sources)
|
||||
|
||||
@classmethod
|
||||
def _parse_statement(
|
||||
cls,
|
||||
@@ -722,7 +780,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
def format(self, comments: bool = True) -> str:
|
||||
"""
|
||||
Pretty-format the SQL statement.
|
||||
|
||||
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
|
||||
verbatim result in ``_raw_sql``, return it as-is — the whole point of
|
||||
those operations is to avoid the dialect generator round-trip.
|
||||
"""
|
||||
if self._raw_sql is not None:
|
||||
return self._raw_sql
|
||||
return Dialect.get_or_raise(self._dialect).generate(
|
||||
self._parsed,
|
||||
copy=True,
|
||||
@@ -808,6 +872,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
"""
|
||||
Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
|
||||
"""
|
||||
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
|
||||
# If we already have a rewritten SQL string, re-parse it first so further
|
||||
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
|
||||
if self._raw_sql is not None:
|
||||
self._parsed = self._parse_statement(self._raw_sql, self.engine)
|
||||
self._source_sql = self._raw_sql
|
||||
self._raw_sql = None
|
||||
if method == LimitMethod.FORCE_LIMIT:
|
||||
self._parsed.args["limit"] = exp.Limit(
|
||||
expression=exp.Literal(this=str(limit), is_string=False)
|
||||
@@ -902,7 +973,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
self,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
predicates: dict[Table, list[exp.Expression]],
|
||||
predicates: dict[Table, list[str]],
|
||||
method: RLSMethod,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -910,11 +981,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
|
||||
:param catalog: The default catalog for non-qualified table names
|
||||
:param schema: The default schema for non-qualified table names
|
||||
:param predicates: Mapping of fully qualified ``Table`` to raw predicate
|
||||
SQL strings.
|
||||
:param method: The method to use for applying the rules.
|
||||
"""
|
||||
if not predicates:
|
||||
return
|
||||
|
||||
if method == RLSMethod.AS_PREDICATE_SPLICE:
|
||||
self._apply_rls_splice(catalog, schema, predicates)
|
||||
return
|
||||
|
||||
parsed_predicates: dict[Table, list[exp.Expression]] = {
|
||||
table: [self.parse_predicate(predicate) for predicate in table_predicates]
|
||||
for table, table_predicates in predicates.items()
|
||||
}
|
||||
|
||||
transformers = {
|
||||
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
|
||||
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
|
||||
@@ -922,9 +1004,39 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
if method not in transformers:
|
||||
raise ValueError(f"Invalid RLS method: {method}")
|
||||
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
transformer = transformers[method](catalog, schema, parsed_predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
def _apply_rls_splice(
    self,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
) -> None:
    """
    Apply RLS via text splicing on the original SQL.

    Requires the source SQL substring to be available. Raises ``ValueError``
    if it isn't — the caller must ensure the statement was constructed from
    a source string (the standard ``SQLScript`` path does this).

    :param catalog: The default catalog for non-qualified table names
    :param schema: The default schema for non-qualified table names
    :param predicates: Mapping of fully qualified ``Table`` to raw predicate
        SQL strings.
    """
    # Imported here rather than at module level — NOTE(review): presumably to
    # avoid a circular import with ``superset.sql.rls_splice``; confirm.
    from superset.sql.rls_splice import apply_rls_splice

    if self._source_sql is None:
        raise ValueError(
            "Splice-mode RLS requires the source SQL string; "
            "this SQLStatement was constructed without one."
        )

    spliced = apply_rls_splice(
        self._source_sql,
        catalog,
        schema,
        predicates,
        dialect=self._dialect,
    )
    # Cache the rewritten text verbatim; ``format()`` returns it as-is so the
    # dialect generator never runs on spliced SQL.
    self._raw_sql = spliced
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
|
||||
462
superset/sql/rls_splice.py
Normal file
462
superset/sql/rls_splice.py
Normal file
@@ -0,0 +1,462 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
RLS predicate injection via text splicing.
|
||||
|
||||
Instead of round-tripping through sqlglot's generator (which transpiles
|
||||
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
|
||||
|
||||
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
|
||||
2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
|
||||
in the original SQL string.
|
||||
3. For each ``SELECT`` scope that references a table with an RLS predicate,
|
||||
finds the exact byte offset to inject at — either the end of an existing
|
||||
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
|
||||
/ ``LIMIT`` / the closing paren of a subquery.
|
||||
4. Splices the predicate text directly into the original string at that
|
||||
offset — never calling ``.sql()``, so the generator never runs.
|
||||
|
||||
Result: everything outside the splice points is the original SQL, byte for
|
||||
byte. Dialect-specific functions, comments, and formatting are all preserved
|
||||
exactly.
|
||||
|
||||
Known limitations:
|
||||
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
|
||||
A thin dialect subclass is still required for parsing — but only for
|
||||
parsing, not generation.
|
||||
- Predicate strings are spliced in as raw SQL. They must come from a trusted
|
||||
source (the RLS config), not user input.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sqlglot
|
||||
from sqlglot import exp
|
||||
from sqlglot.optimizer.scope import traverse_scope
|
||||
from sqlglot.tokens import Token, TokenType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.sql.parse import Table
|
||||
|
||||
|
||||
# Token types that end a WHERE clause / FROM section at the current paren depth,
|
||||
# indicating where a new predicate must be inserted just before.
|
||||
_CLAUSE_ENDS = {
|
||||
TokenType.GROUP_BY,
|
||||
TokenType.HAVING,
|
||||
TokenType.ORDER_BY,
|
||||
TokenType.WINDOW,
|
||||
TokenType.QUALIFY,
|
||||
TokenType.LIMIT,
|
||||
TokenType.FETCH,
|
||||
TokenType.CLUSTER_BY,
|
||||
TokenType.DISTRIBUTE_BY,
|
||||
TokenType.SORT_BY,
|
||||
TokenType.CONNECT_BY,
|
||||
TokenType.START_WITH,
|
||||
TokenType.UNION,
|
||||
TokenType.INTERSECT,
|
||||
TokenType.EXCEPT,
|
||||
}
|
||||
|
||||
_JOIN_STARTS = {
|
||||
TokenType.JOIN,
|
||||
TokenType.STRAIGHT_JOIN,
|
||||
TokenType.JOIN_MARKER,
|
||||
}
|
||||
|
||||
|
||||
def _splice_priority(text: str) -> int:
|
||||
"""
|
||||
Priority for applying splices at the same offset.
|
||||
|
||||
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
|
||||
wrapping splices like ``pred AND (existing)`` compose correctly.
|
||||
"""
|
||||
return 1 if text != ")" else 0
|
||||
|
||||
|
||||
def _after_previous_token(tokens: list[Token], index: int) -> int:
|
||||
"""
|
||||
Return the offset immediately after the token preceding *index*.
|
||||
|
||||
The sqlglot tokenizer strips comments and whitespace from the token stream,
|
||||
so the previous token's ``end + 1`` is the splice point that lands right
|
||||
after the last real SQL content — naturally skipping any intervening
|
||||
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
|
||||
literals for real comment delimiters.
|
||||
"""
|
||||
if index <= 0:
|
||||
return 0
|
||||
return tokens[index - 1].end + 1
|
||||
|
||||
|
||||
def _table_from_node(
    node: exp.Table,
    catalog: str | None,
    schema: str | None,
) -> Table:
    """
    Construct a fully qualified ``Table`` from a sqlglot ``exp.Table`` node.

    Parts missing from the node fall back to the supplied default
    catalog/schema.
    """
    # Imported lazily to avoid a circular import with ``superset.sql.parse``.
    from superset.sql.parse import Table

    resolved_schema = node.db if node.db else schema
    resolved_catalog = node.catalog if node.catalog else catalog
    return Table(
        table=node.name,
        schema=resolved_schema,
        catalog=resolved_catalog,
    )
|
||||
|
||||
|
||||
def apply_rls_splice(
    sql: str,
    catalog: str | None,
    schema: str | None,
    predicates: dict[Table, list[str]],
    dialect: str | None = None,
) -> str:
    """
    Inject RLS predicates into ``sql`` by splicing text at the right positions.

    :param sql: The original SQL query. Returned unchanged except at splice points.
    :param catalog: The default catalog for non-qualified table names.
    :param schema: The default schema for non-qualified table names.
    :param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
        maps a fully qualified table to one or more raw predicate strings to
        ``AND`` together when that table is referenced in a SELECT scope.
    :param dialect: The sqlglot dialect used for *parsing only* — to understand
        scope structure and locate token positions. The generator is never
        called, so this does not affect output formatting.
    :return: The query with RLS predicates injected into every relevant SELECT
        scope.
    """
    # Fast exit: no tables mapped, or every table maps to an empty list.
    if not predicates or not any(predicates.values()):
        return sql

    resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
    # Tokenize for byte-accurate offsets; parse only to get the scope tree.
    tokens = list(resolved_dialect.tokenize(sql))
    tree = sqlglot.parse_one(sql, dialect=dialect)

    # Gather (offset, text) insertions for every SELECT scope in the query.
    splices: list[tuple[int, str]] = []
    for scope in traverse_scope(tree):
        splices.extend(
            _splices_for_scope(
                sql,
                tokens,
                scope,
                predicates,
                catalog,
                schema,
                dialect,
            )
        )

    # Apply splices in reverse offset order so earlier positions stay valid.
    # For equal offsets, apply predicate/WHERE/ON inserts before ")" inserts.
    splices.sort(key=lambda item: (item[0], _splice_priority(item[1])), reverse=True)
    result = sql
    for offset, text in splices:
        result = result[:offset] + text + result[offset:]
    return result
|
||||
|
||||
|
||||
def _splices_for_scope(
    sql: str,
    tokens: list[Token],
    scope: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> list[tuple[int, str]]:
    """
    Compute all splices for a single SELECT scope.

    This mirrors ``RLSAsPredicateTransformer`` semantics:
    - predicates for FROM tables are applied to the SELECT WHERE clause as
      ``pred AND (existing_where)``
    - predicates for JOIN tables are applied to each JOIN ON clause as
      ``pred AND (existing_on)`` (or ``ON pred`` when ON is absent)
    """
    from_predicates: list[str] = []
    from_table_ends: list[int] = []
    join_splices: list[tuple[int, str]] = []

    for source in scope.sources.values():  # type: ignore[attr-defined]
        source_type, table_end, pred_sql = _classify_source_predicate(
            source,
            predicates,
            catalog,
            schema,
            dialect,
        )
        # Skip non-table sources and tables without RLS predicates.
        if source_type == "none" or table_end is None or pred_sql is None:
            continue

        if source_type == "from":
            # FROM tables share a single WHERE splice; collect and defer.
            from_predicates.append(pred_sql)
            from_table_ends.append(table_end)
            continue

        # JOIN tables each get their own ON-clause splice.
        join_splice = _find_join_splice(sql, tokens, table_end, pred_sql)
        if join_splice:
            join_splices.extend(join_splice)
            continue

    if not from_predicates:
        return join_splices

    # ``dict.fromkeys`` de-duplicates repeated predicates while preserving
    # insertion order.
    combined_predicates = " AND ".join(dict.fromkeys(from_predicates))
    # Anchor at the last FROM-table identifier so the WHERE scan starts after
    # every table in the FROM clause.
    from_splice = _find_where_splice(
        sql,
        tokens,
        max(from_table_ends),
        combined_predicates,
    )
    return [*join_splices, *from_splice]
|
||||
|
||||
|
||||
def _table_end(source: exp.Table) -> int | None:
    """
    Byte offset just past the table's identifier in the original SQL, or
    ``None`` when no identifier with recorded token positions is found.
    """
    identifier = source.find(exp.Identifier)
    if identifier and getattr(identifier, "_meta", None):
        return identifier._meta["end"]
    return None
|
||||
|
||||
|
||||
def _classify_source_predicate(
    source: object,
    predicates: dict[Table, list[str]],
    catalog: str | None,
    schema: str | None,
    dialect: str | None,
) -> tuple[str, int | None, str | None]:
    """
    Return source kind (from/join/none), table end offset, and predicate SQL.

    ``("none", None, None)`` is returned for anything that should not receive
    an RLS splice: non-table sources (subqueries, CTEs), tables without
    predicates, tables without recorded token positions, or tables whose
    parent node is neither FROM nor JOIN.
    """
    if not isinstance(source, exp.Table):
        return ("none", None, None)

    table = _table_from_node(source, catalog, schema)
    # Qualify each predicate's columns with this table's alias/name; empty
    # predicate strings are dropped.
    table_predicates = [
        _qualify_predicate(predicate, source, dialect)
        for predicate in predicates.get(table, [])
        if predicate
    ]
    if not table_predicates:
        return ("none", None, None)

    table_end = _table_end(source)
    if table_end is None:
        return ("none", None, None)

    # ``dict.fromkeys`` de-duplicates predicates while preserving order.
    pred_sql = " AND ".join(dict.fromkeys(table_predicates))
    # The parent node distinguishes FROM tables (WHERE splice) from JOIN
    # tables (ON splice).
    if isinstance(source.parent, exp.From):
        return ("from", table_end, pred_sql)
    if isinstance(source.parent, exp.Join):
        return ("join", table_end, pred_sql)
    return ("none", None, None)
|
||||
|
||||
|
||||
def _qualify_predicate(
    predicate: str,
    table_node: exp.Table,
    dialect: str | None,
) -> str:
    """
    Prefix every column in ``predicate`` with the table's alias (or name),
    mirroring ``RLSAsPredicateTransformer``.
    """
    tree = sqlglot.parse_one(predicate, dialect=dialect)
    qualifier = exp.to_identifier(table_node.alias_or_name)
    for col in tree.find_all(exp.Column):
        col.set("table", qualifier.copy())
    return tree.sql(dialect=dialect)
|
||||
|
||||
|
||||
def _scan_until_scope_boundary(
    tokens: list[Token],
    anchor: int,
    *,
    stop_at_join: bool,
) -> tuple[str, int | None]:
    """
    Walk the token stream forward from ``anchor`` until a clause or scope
    boundary at the current paren depth.

    Returns ``("where", index)`` when a WHERE token is found at depth 0,
    ``("boundary", index)`` for a non-WHERE boundary token (a closing paren of
    the enclosing subquery, a clause-ending keyword, or — with
    ``stop_at_join`` — a JOIN keyword), and ``("eof", None)`` when the stream
    ends without one.
    """
    depth = 0
    for index, token in enumerate(tokens):
        # Only consider tokens strictly past the anchor offset.
        if token.start <= anchor:
            continue

        kind = token.token_type
        if kind == TokenType.L_PAREN:
            depth += 1
        elif kind == TokenType.R_PAREN:
            if depth == 0:
                # Closing paren of the enclosing subquery: scope ends here.
                return ("boundary", index)
            depth -= 1
        elif depth == 0:
            if kind == TokenType.WHERE:
                return ("where", index)
            found_boundary = kind in _CLAUSE_ENDS or (
                stop_at_join and kind in _JOIN_STARTS
            )
            if found_boundary:
                return ("boundary", index)

    return ("eof", None)
|
||||
|
||||
|
||||
def _find_condition_end(
    tokens: list[Token],
    start_index: int,
    *,
    stop_at_join: bool,
) -> int:
    """
    Find the end offset for a WHERE/ON condition body.

    :param tokens: the tokenized SQL statement
    :param start_index: index of the WHERE/ON keyword token that opens the
        condition
    :param stop_at_join: when True, also treat WHERE and JOIN-start tokens as
        terminators (used when the condition is a JOIN's ON clause)
    :returns: character offset one past the last token of the condition
    """
    depth = 0
    # ``prev_end`` tracks the end offset of the last token known to belong to
    # the condition; the return value is always one past that offset.
    prev_end = tokens[start_index].end
    for tok in tokens[start_index + 1 :]:
        if tok.token_type == TokenType.L_PAREN:
            depth += 1
        elif tok.token_type == TokenType.R_PAREN:
            if depth == 0:
                # Unmatched closing paren: the condition lives inside a
                # parenthesized scope (e.g. a subquery) that closes here.
                return prev_end + 1
            depth -= 1
        elif depth == 0 and (
            (stop_at_join and tok.token_type == TokenType.WHERE)
            or tok.token_type in _CLAUSE_ENDS
            or (stop_at_join and tok.token_type in _JOIN_STARTS)
        ):
            # A new clause starts at depth 0, so the condition ended with the
            # previous token.
            return prev_end + 1
        prev_end = tok.end
    # No boundary found: the condition runs to the end of the statement.
    return prev_end + 1
|
||||
|
||||
|
||||
def _find_where_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Compute splices that merge a predicate into the SELECT's WHERE clause:
    a bare ``pred`` when no WHERE exists, ``pred AND (existing)`` otherwise.
    """
    kind, index = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)

    if kind == "where" and index is not None:
        if index + 1 >= len(tokens):
            # WHERE is the final token: append the predicate right after it.
            return [(tokens[index].end + 1, f" {pred_sql}")]
        body_start = tokens[index + 1].start
        body_end = _find_condition_end(tokens, index, stop_at_join=False)
        # Wrap the existing condition in parens and AND the predicate in.
        return [(body_start, f"{pred_sql} AND ("), (body_end, ")")]

    if kind == "boundary" and index is not None:
        insert_at = _after_previous_token(tokens, index)
        return [(insert_at, f" WHERE {pred_sql}")]

    # No boundary at all: append a new WHERE clause at the statement's end.
    return [(len(sql), f" WHERE {pred_sql}")]
|
||||
|
||||
|
||||
def _find_join_splice(
    sql: str,
    tokens: list[Token],
    anchor: int,
    pred_sql: str,
) -> list[tuple[int, str]]:
    """
    Compute splices that merge a predicate into a JOIN clause:
    ``ON pred`` when no ON exists, ``ON pred AND (existing_on)`` otherwise.
    """
    on_index, boundary_index = _scan_join_clause(tokens, anchor)

    if on_index is not None:
        if on_index + 1 >= len(tokens):
            # ON is the final token: append the predicate right after it.
            return [(tokens[on_index].end + 1, f" {pred_sql}")]
        body_start = tokens[on_index + 1].start
        body_end = _find_condition_end(tokens, on_index, stop_at_join=True)
        # Wrap the existing ON condition in parens and AND the predicate in.
        return [(body_start, f"{pred_sql} AND ("), (body_end, ")")]

    if boundary_index is not None:
        insert_at = _after_previous_token(tokens, boundary_index)
        return [(insert_at, f" ON {pred_sql}")]

    # No boundary at all: append a new ON clause at the statement's end.
    return [(len(sql), f" ON {pred_sql}")]
|
||||
|
||||
|
||||
def _scan_join_clause(
    tokens: list[Token],
    anchor: int,
) -> tuple[int | None, int | None]:
    """
    Locate the ON keyword and the boundary token indexes for the JOIN
    segment that starts just after ``anchor``. Either index may be ``None``.
    """
    paren_depth = 0
    on_at: int | None = None
    boundary_at: int | None = None

    for index, token in enumerate(tokens):
        if token.start <= anchor:
            # Still at or before the anchor position: skip.
            continue

        kind = token.token_type
        if kind == TokenType.L_PAREN:
            paren_depth += 1
        elif kind == TokenType.R_PAREN:
            if not paren_depth:
                # The enclosing subquery closes: this JOIN segment ends.
                boundary_at = index
                break
            paren_depth -= 1
        elif paren_depth:
            # Tokens inside nested parens belong to a subexpression.
            continue
        elif kind == TokenType.ON and on_at is None:
            on_at = index
        elif (
            kind == TokenType.WHERE
            or kind in _JOIN_STARTS
            or kind in _CLAUSE_ENDS
        ):
            boundary_at = index
            break

    return on_at, boundary_at
|
||||
@@ -40,17 +40,19 @@ def apply_rls(
|
||||
|
||||
:returns: True if any RLS predicates were actually applied, False otherwise.
|
||||
"""
|
||||
# There are two ways to insert RLS: either replacing the table with a subquery
|
||||
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
|
||||
# safer, but not supported in all databases.
|
||||
method = database.db_engine_spec.get_rls_method()
|
||||
# There are three ways to insert RLS:
|
||||
# - replace the table with a subquery containing the RLS (safest, but not
|
||||
# supported in all databases)
|
||||
# - append the RLS to the ``WHERE`` clause via AST transformation
|
||||
# - splice the RLS into the original SQL string (preserves dialect-specific
|
||||
# syntax that the sqlglot generator would otherwise transpile)
|
||||
method = database.db_engine_spec.rls_method
|
||||
|
||||
# collect all RLS predicates for all tables in the query
|
||||
predicates: dict[Table, list[Any]] = {}
|
||||
predicates: dict[Table, list[str]] = {}
|
||||
for table in parsed_statement.tables:
|
||||
table = table.qualify(catalog=catalog, schema=schema)
|
||||
predicates[table] = [
|
||||
parsed_statement.parse_predicate(predicate)
|
||||
raw_predicates = [
|
||||
predicate
|
||||
for predicate in get_predicates_for_table(
|
||||
table,
|
||||
database,
|
||||
@@ -58,6 +60,7 @@ def apply_rls(
|
||||
)
|
||||
if predicate
|
||||
]
|
||||
predicates[table] = raw_predicates
|
||||
|
||||
has_predicates = any(predicates.values())
|
||||
parsed_statement.apply_rls(catalog, schema, predicates, method)
|
||||
|
||||
@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import OAuth2RedirectError
|
||||
from superset.sql.parse import Table
|
||||
from superset.sql.parse import RLSMethod, Table
|
||||
from superset.superset_typing import (
|
||||
OAuth2ClientConfig,
|
||||
OAuth2State,
|
||||
@@ -1283,3 +1283,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
|
||||
error = exc_info.value.error
|
||||
|
||||
assert error.extra["redirect_uri"] == fallback_uri
|
||||
|
||||
|
||||
def test_default_rls_method_is_subquery() -> None:
|
||||
"""Base engine spec defaults to subquery-based RLS."""
|
||||
assert BaseEngineSpec.rls_method == RLSMethod.AS_SUBQUERY
|
||||
|
||||
@@ -209,7 +209,7 @@ class TestApplyRlsReturnValue:
|
||||
from superset.utils.rls import apply_rls
|
||||
|
||||
database = MagicMock()
|
||||
database.db_engine_spec.get_rls_method.return_value = MagicMock()
|
||||
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
|
||||
database.get_default_catalog.return_value = None
|
||||
|
||||
statement = MagicMock()
|
||||
@@ -237,7 +237,7 @@ class TestApplyRlsReturnValue:
|
||||
mock_get_predicates.return_value = []
|
||||
|
||||
database = MagicMock()
|
||||
database.db_engine_spec.get_rls_method.return_value = MagicMock()
|
||||
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
|
||||
database.get_default_catalog.return_value = None
|
||||
|
||||
mock_table = MagicMock()
|
||||
@@ -268,7 +268,7 @@ class TestApplyRlsReturnValue:
|
||||
mock_get_predicates.return_value = ["user_id = 42"]
|
||||
|
||||
database = MagicMock()
|
||||
database.db_engine_spec.get_rls_method.return_value = MagicMock()
|
||||
database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
|
||||
database.get_default_catalog.return_value = None
|
||||
|
||||
mock_table = MagicMock()
|
||||
@@ -276,8 +276,6 @@ class TestApplyRlsReturnValue:
|
||||
|
||||
statement = MagicMock()
|
||||
statement.tables = [mock_table]
|
||||
statement.parse_predicate.return_value = MagicMock()
|
||||
|
||||
result = apply_rls(
|
||||
database=database,
|
||||
catalog=None,
|
||||
@@ -312,11 +310,10 @@ class TestRLSSubqueryAlias:
|
||||
"""
|
||||
sql = "SELECT pens.pen_id, pens.is_green FROM public.pens"
|
||||
statement = SQLStatement(sql, engine="redshift")
|
||||
predicate = statement.parse_predicate("user_id = 1")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
"public",
|
||||
{Table("pens", "public", None): [predicate]},
|
||||
{Table("pens", "public", None): ["user_id = 1"]},
|
||||
RLSMethod.AS_SUBQUERY,
|
||||
)
|
||||
result = statement.format()
|
||||
@@ -333,11 +330,10 @@ class TestRLSSubqueryAlias:
|
||||
"""
|
||||
sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens"
|
||||
statement = SQLStatement(sql, engine="redshift")
|
||||
predicate = statement.parse_predicate("user_id = 1")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
"public",
|
||||
{Table("pens", "public", "mycat"): [predicate]},
|
||||
{Table("pens", "public", "mycat"): ["user_id = 1"]},
|
||||
RLSMethod.AS_SUBQUERY,
|
||||
)
|
||||
result = statement.format()
|
||||
@@ -351,11 +347,10 @@ class TestRLSSubqueryAlias:
|
||||
"""
|
||||
sql = "SELECT p.pen_id, p.is_green FROM public.pens p"
|
||||
statement = SQLStatement(sql, engine="redshift")
|
||||
predicate = statement.parse_predicate("user_id = 1")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
"public",
|
||||
{Table("pens", "public", None): [predicate]},
|
||||
{Table("pens", "public", None): ["user_id = 1"]},
|
||||
RLSMethod.AS_SUBQUERY,
|
||||
)
|
||||
result = statement.format()
|
||||
@@ -369,11 +364,10 @@ class TestRLSSubqueryAlias:
|
||||
"""
|
||||
sql = "SELECT pen_id, is_green FROM public.pens"
|
||||
statement = SQLStatement(sql, engine="redshift")
|
||||
predicate = statement.parse_predicate("user_id = 1")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
"public",
|
||||
{Table("pens", "public", None): [predicate]},
|
||||
{Table("pens", "public", None): ["user_id = 1"]},
|
||||
RLSMethod.AS_SUBQUERY,
|
||||
)
|
||||
result = statement.format()
|
||||
|
||||
@@ -1704,6 +1704,21 @@ def test_set_limit_value(
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
|
||||
"""
|
||||
When a statement has cached verbatim SQL from splice-mode rewrites, setting
|
||||
limit should reparse that SQL before mutating the AST.
|
||||
"""
|
||||
statement = SQLStatement("SELECT * FROM some_table", "postgresql")
|
||||
statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
|
||||
|
||||
statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)
|
||||
formatted = statement.format()
|
||||
|
||||
assert "tenant_id = 42" in formatted
|
||||
assert "LIMIT 10" in formatted
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"kql, limit, expected",
|
||||
[
|
||||
@@ -2198,7 +2213,7 @@ def test_rls_subquery_transformer(
|
||||
statement.apply_rls(
|
||||
"catalog1",
|
||||
"schema1",
|
||||
{k: [parse_one(v)] for k, v in rules.items()},
|
||||
{k: [v] for k, v in rules.items()},
|
||||
RLSMethod.AS_SUBQUERY,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
@@ -2542,12 +2557,312 @@ def test_rls_predicate_transformer(
|
||||
statement.apply_rls(
|
||||
"catalog1",
|
||||
"schema1",
|
||||
{k: [parse_one(v)] for k, v in rules.items()},
|
||||
{k: [v] for k, v in rules.items()},
|
||||
RLSMethod.AS_PREDICATE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
|
||||
"SELECT t.foo FROM some_table AS t WHERE t.id = 42",
|
||||
),
|
||||
(
|
||||
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
|
||||
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
|
||||
"SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
|
||||
"AND (bar = 'baz' OR foo = 'qux')",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
|
||||
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
|
||||
"AND (table.id = other_table.id)",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table",
|
||||
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id = 42",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM table JOIN other_table ON table.id = other_table.id "
|
||||
"WHERE 1=1",
|
||||
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
|
||||
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
|
||||
"AND (table.id = other_table.id) WHERE 1=1",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_splice_semantics_match_predicate(
|
||||
sql: str,
|
||||
rules: dict[Table, str],
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""
|
||||
Splice mode should preserve predicate-mode semantics for boolean grouping
|
||||
and JOIN-vs-WHERE placement.
|
||||
"""
|
||||
statement = SQLStatement(sql)
|
||||
statement.apply_rls(
|
||||
"catalog1",
|
||||
"schema1",
|
||||
{k: [v] for k, v in rules.items()},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, rules, expected",
|
||||
[
|
||||
# Simple — no WHERE clause to extend.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
|
||||
),
|
||||
# Append to an existing WHERE clause.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 AND (status = 'open')",
|
||||
),
|
||||
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
|
||||
),
|
||||
# No WHERE, but GROUP BY and ORDER BY are present.
|
||||
(
|
||||
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"GROUP BY d ORDER BY d",
|
||||
),
|
||||
# JOIN — predicate scoped to one of the tables.
|
||||
(
|
||||
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
|
||||
{Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
|
||||
"SELECT o.id FROM some_table o JOIN locations l "
|
||||
"ON o.loc_id = l.id WHERE o.tenant_id = 42",
|
||||
),
|
||||
# JOIN — different predicate per table, both spliced into one WHERE.
|
||||
(
|
||||
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
|
||||
{
|
||||
Table("events", "schema1", "catalog1"): "events.user_id = 99",
|
||||
Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
|
||||
},
|
||||
"SELECT * FROM some_table JOIN events "
|
||||
"ON events.user_id = 99 AND (some_table.id = events.order_id) "
|
||||
"WHERE some_table.tenant_id = 42",
|
||||
),
|
||||
# Subquery in FROM — splice into the inner SELECT.
|
||||
(
|
||||
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42) sub",
|
||||
),
|
||||
# CTE — splice into the CTE body.
|
||||
(
|
||||
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42) SELECT * FROM cte",
|
||||
),
|
||||
# Dialect-specific function (LAST_DAY) preserved verbatim.
|
||||
(
|
||||
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT id, LAST_DAY(created_at) FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
|
||||
),
|
||||
# Multiline + inline comment preserved exactly.
|
||||
(
|
||||
"SELECT LAST_DAY(created_at) -- last day of month\n"
|
||||
"FROM some_table\n"
|
||||
"WHERE region = 'US'",
|
||||
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
|
||||
"SELECT LAST_DAY(created_at) -- last day of month\n"
|
||||
"FROM some_table\n"
|
||||
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
|
||||
),
|
||||
# Schema-qualified table name (no default schema match) — no predicate.
|
||||
(
|
||||
"SELECT t.foo FROM schema2.some_table AS t",
|
||||
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
|
||||
"SELECT t.foo FROM schema2.some_table AS t",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_splice(
|
||||
sql: str,
|
||||
rules: dict[Table, str],
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""
|
||||
Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.
|
||||
|
||||
Splice mode rewrites the original SQL string instead of re-rendering the
|
||||
AST through the dialect generator, so byte-level fidelity (including
|
||||
dialect-specific functions, comments, and whitespace) is preserved.
|
||||
"""
|
||||
statement = SQLStatement(sql)
|
||||
statement.apply_rls(
|
||||
"catalog1",
|
||||
"schema1",
|
||||
{k: [v] for k, v in rules.items()},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
def test_rls_predicate_splice_requires_source() -> None:
|
||||
"""
|
||||
Splice mode requires the original SQL substring; constructing a statement
|
||||
purely from an AST should make splice mode raise.
|
||||
"""
|
||||
ast = parse_one("SELECT * FROM some_table")
|
||||
statement = SQLStatement(ast=ast, engine="postgresql")
|
||||
with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
|
||||
statement.apply_rls(
|
||||
"catalog1",
|
||||
"schema1",
|
||||
{Table("some_table", "schema1", "catalog1"): ["id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
|
||||
|
||||
def test_rls_predicate_splice_preserves_dialect_function() -> None:
|
||||
"""
|
||||
Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
|
||||
on the postgres dialect would otherwise be transpiled by the generator.
|
||||
"""
|
||||
sql = "SELECT LAST_DAY(d) FROM some_table"
|
||||
statement = SQLStatement(sql, engine="postgresql")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == (
|
||||
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
|
||||
)
|
||||
|
||||
|
||||
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
|
||||
"""
|
||||
Splice mode accepts predicate strings directly — no ``parse_predicate`` is
|
||||
needed at the call site.
|
||||
"""
|
||||
sql = "SELECT * FROM some_table"
|
||||
statement = SQLStatement(sql, engine="postgresql")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == (
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, expected",
|
||||
[
|
||||
(
|
||||
"SELECT * FROM some_table -- hi\nGROUP BY id",
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"-- hi\nGROUP BY id",
|
||||
),
|
||||
(
|
||||
"SELECT * FROM some_table /* inline */ GROUP BY id",
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"/* inline */ GROUP BY id",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
|
||||
"""
|
||||
Splice mode should insert predicates before comments that precede the next
|
||||
clause boundary, so comments do not swallow the injected SQL.
|
||||
"""
|
||||
statement = SQLStatement(sql, engine="postgresql")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, engine, expected",
|
||||
[
|
||||
(
|
||||
"SELECT * FROM some_table QUALIFY row_number() OVER "
|
||||
"(PARTITION BY id ORDER BY ts DESC) = 1",
|
||||
"snowflake",
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
|
||||
),
|
||||
(
|
||||
"SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
|
||||
"postgresql",
|
||||
"SELECT sum(v) OVER () FROM some_table "
|
||||
"WHERE some_table.tenant_id = 42 "
|
||||
"WINDOW w AS (PARTITION BY id)",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_rls_predicate_splice_handles_additional_clause_boundaries(
|
||||
sql: str,
|
||||
engine: str,
|
||||
expected: str,
|
||||
) -> None:
|
||||
"""
|
||||
Splice mode should insert WHERE before clause types that can legally follow
|
||||
FROM/WHERE (for example QUALIFY and WINDOW).
|
||||
"""
|
||||
statement = SQLStatement(sql, engine=engine)
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
assert statement.format() == expected
|
||||
|
||||
|
||||
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
|
||||
"""
|
||||
LIMIT rewrites after splice-mode RLS should retain injected predicates.
|
||||
"""
|
||||
statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")
|
||||
statement.apply_rls(
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["tenant_id = 42"]},
|
||||
RLSMethod.AS_PREDICATE_SPLICE,
|
||||
)
|
||||
statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
|
||||
|
||||
formatted = statement.format()
|
||||
assert "some_table.tenant_id = 42" in formatted
|
||||
assert "LIMIT 101" in formatted
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sql, table, expected",
|
||||
[
|
||||
|
||||
292
tests/unit_tests/sql/rls_splice_unit_tests.py
Normal file
292
tests/unit_tests/sql/rls_splice_unit_tests.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import pytest
|
||||
import sqlglot
|
||||
from sqlglot import Dialect, exp
|
||||
|
||||
from superset.sql.parse import SQLStatement, Table
|
||||
from superset.sql.rls_splice import (
|
||||
_classify_source_predicate,
|
||||
_find_condition_end,
|
||||
_find_join_splice,
|
||||
_find_where_splice,
|
||||
_scan_join_clause,
|
||||
_scan_until_scope_boundary,
|
||||
_splices_for_scope,
|
||||
_table_end,
|
||||
apply_rls_splice,
|
||||
)
|
||||
|
||||
|
||||
def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
    """Tokenize ``sql`` with sqlglot's default (dialect-less) dialect."""
    dialect = Dialect.get_or_raise(None)
    return list(dialect.tokenize(sql))
|
||||
|
||||
|
||||
def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
    """Return the index of the first token whose type equals ``token_type``."""
    for index, token in enumerate(tokens):
        if token.token_type == token_type:
            return index
    # Mirror the exhausted-``next()`` behavior when no token matches.
    raise StopIteration
|
||||
|
||||
|
||||
def _token_by_text(
    tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
    """Return the first token whose raw text equals ``text``."""
    matches = (token for token in tokens if token.text == text)
    return next(matches)
|
||||
|
||||
|
||||
def test_split_source_returns_none_result_when_tokenize_fails(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _BrokenDialect:
|
||||
@staticmethod
|
||||
def tokenize(_: str) -> list[sqlglot.tokens.Token]:
|
||||
raise sqlglot.errors.SqlglotError("boom")
|
||||
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.parse.Dialect.get_or_raise",
|
||||
lambda _: _BrokenDialect(),
|
||||
)
|
||||
assert SQLStatement._split_source("SELECT 1", "postgresql", 2) == [None, None]
|
||||
|
||||
|
||||
def test_apply_rls_splice_ignores_empty_predicates() -> None:
|
||||
sql = "SELECT 1"
|
||||
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
|
||||
|
||||
|
||||
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
|
||||
"""
|
||||
Regression: the splice point must not be confused by ``--`` appearing
|
||||
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
|
||||
for an inline comment and inserted the predicate inside the quoted text.
|
||||
"""
|
||||
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
|
||||
spliced = apply_rls_splice(
|
||||
sql,
|
||||
None,
|
||||
None,
|
||||
{Table("some_table"): ["some_table.tenant_id = 42"]},
|
||||
dialect="postgres",
|
||||
)
|
||||
assert spliced == (
|
||||
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
|
||||
"AND (note = '--x') GROUP BY id"
|
||||
)
|
||||
|
||||
|
||||
def test_table_end_returns_none_without_metadata() -> None:
|
||||
source = exp.Table(this=exp.Identifier(this="foo"))
|
||||
assert _table_end(source) is None
|
||||
|
||||
|
||||
def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
|
||||
source = exp.Table(this=exp.Identifier(this="foo"))
|
||||
exp.From(this=source)
|
||||
result = _classify_source_predicate(
|
||||
source,
|
||||
{Table("foo"): ["id = 1"]},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
assert result == ("none", None, None)
|
||||
|
||||
|
||||
def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
|
||||
source = exp.Table(this=exp.Identifier(this="foo"))
|
||||
source.this.meta["end"] = 3
|
||||
exp.Alias(this=source, alias=exp.Identifier(this="alias"))
|
||||
result = _classify_source_predicate(
|
||||
source,
|
||||
{Table("foo"): ["id = 1"]},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
assert result == ("none", None, None)
|
||||
|
||||
|
||||
def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
|
||||
sql = "SELECT * FROM t WHERE (a = 1)"
|
||||
tokens = _tokenize(sql)
|
||||
where_token = _token_by_text(tokens, "WHERE")
|
||||
assert _scan_until_scope_boundary(
|
||||
tokens, where_token.start, stop_at_join=False
|
||||
) == (
|
||||
"eof",
|
||||
None,
|
||||
)
|
||||
|
||||
|
||||
def test_find_condition_end_handles_subquery_closing_paren() -> None:
|
||||
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
|
||||
tokens = _tokenize(sql)
|
||||
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
|
||||
end = _find_condition_end(tokens, where_index, stop_at_join=False)
|
||||
assert sql[end] == ")"
|
||||
|
||||
|
||||
def test_find_condition_end_handles_parenthesized_expression() -> None:
|
||||
sql = "SELECT * FROM t WHERE (a = 1)"
|
||||
tokens = _tokenize(sql)
|
||||
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
|
||||
end = _find_condition_end(tokens, where_index, stop_at_join=False)
|
||||
assert end == len(sql)
|
||||
|
||||
|
||||
def test_find_where_splice_handles_trailing_where_keyword() -> None:
|
||||
sql = "SELECT * FROM t WHERE"
|
||||
tokens = _tokenize(sql)
|
||||
splices = _find_where_splice(sql, tokens, anchor=0, pred_sql="t.id = 1")
|
||||
assert splices == [(len(sql), " t.id = 1")]
|
||||
|
||||
|
||||
def test_find_join_splice_handles_trailing_on_keyword() -> None:
|
||||
sql = "SELECT * FROM a JOIN b ON"
|
||||
tokens = _tokenize(sql)
|
||||
b_token = _token_by_text(tokens, "b")
|
||||
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
|
||||
assert splices == [(len(sql), " b.id = 1")]
|
||||
|
||||
|
||||
def test_find_join_splice_inserts_on_before_where_boundary() -> None:
|
||||
sql = "SELECT * FROM a JOIN b WHERE x = 1"
|
||||
tokens = _tokenize(sql)
|
||||
b_token = _token_by_text(tokens, "b")
|
||||
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
|
||||
assert splices == [(sql.index("WHERE") - 1, " ON b.id = 1")]
|
||||
|
||||
|
||||
def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
|
||||
sql = "SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1"
|
||||
tokens = _tokenize(sql)
|
||||
b_token = _token_by_text(tokens, "b")
|
||||
on_index, boundary_index = _scan_join_clause(tokens, b_token.end)
|
||||
assert on_index is not None
|
||||
assert boundary_index is not None
|
||||
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.JOIN
|
||||
|
||||
|
||||
def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
|
||||
sql = "SELECT * FROM (SELECT * FROM a JOIN b) sub"
|
||||
tokens = _tokenize(sql)
|
||||
b_token = _token_by_text(tokens, "b")
|
||||
_, boundary_index = _scan_join_clause(tokens, b_token.end)
|
||||
assert boundary_index is not None
|
||||
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.R_PAREN
|
||||
|
||||
|
||||
def test_splices_for_scope_handles_empty_join_splice_result(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _Scope:
|
||||
sources = {"x": object()}
|
||||
|
||||
sql = "SELECT 1"
|
||||
tokens = _tokenize(sql)
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._classify_source_predicate",
|
||||
lambda *args, **kwargs: ("join", 0, "x.id = 1"),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._find_join_splice",
|
||||
lambda *args, **kwargs: [],
|
||||
)
|
||||
assert (
|
||||
_splices_for_scope(
|
||||
sql,
|
||||
tokens,
|
||||
_Scope(),
|
||||
{Table("x"): ["x.id = 1"]},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
== []
|
||||
)
|
||||
|
||||
|
||||
def test_splices_for_scope_combines_join_and_from_splices(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _Scope:
|
||||
sources = {"f": object(), "j": object()}
|
||||
|
||||
sql = "SELECT 1"
|
||||
tokens = _tokenize(sql)
|
||||
calls = [("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")]
|
||||
|
||||
def _fake_classify(*args: object, **kwargs: object) -> tuple[str, int, str]:
|
||||
return calls.pop(0)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._find_join_splice",
|
||||
lambda *args, **kwargs: [(50, " ON j.id = 2")],
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._find_where_splice",
|
||||
lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
|
||||
)
|
||||
|
||||
assert _splices_for_scope(
|
||||
sql,
|
||||
tokens,
|
||||
_Scope(),
|
||||
{Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
) == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]
|
||||
|
||||
|
||||
def test_splices_for_scope_join_then_next_source(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
class _Scope:
|
||||
sources = {"j": object(), "f": object()}
|
||||
|
||||
sql = "SELECT 1"
|
||||
tokens = _tokenize(sql)
|
||||
calls = [("join", 3, "j.id = 2"), ("none", None, None)]
|
||||
|
||||
def _fake_classify(
|
||||
*args: object, **kwargs: object
|
||||
) -> tuple[str, int | None, str | None]:
|
||||
return calls.pop(0)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"superset.sql.rls_splice._find_join_splice",
|
||||
lambda *args, **kwargs: [],
|
||||
)
|
||||
|
||||
assert (
|
||||
_splices_for_scope(
|
||||
sql,
|
||||
tokens,
|
||||
_Scope(),
|
||||
{Table("j"): ["id = 2"]},
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
== []
|
||||
)
|
||||
Reference in New Issue
Block a user