Compare commits

...

13 Commits

Author SHA1 Message Date
Beto Dealmeida
20b4f33710 Small fixes 2026-05-21 10:41:55 -04:00
Beto Dealmeida
c1b7a2d2ee Bump code coverage to 100% 2026-05-20 12:11:19 -04:00
Beto Dealmeida
0ce43abe5b Address comments 2026-05-12 10:17:50 -04:00
Beto Dealmeida
439130db54 Add dialects for Exa/Solr 2026-05-12 09:59:55 -04:00
Beto Dealmeida
ec9f2da81e Add to DB without sqlglot dialect 2026-05-12 09:41:30 -04:00
Beto Dealmeida
549e51bdc4 Fix lint 2026-05-12 09:25:57 -04:00
Beto Dealmeida
72ecb20e5c Run ruff-format 2026-05-08 17:27:11 -04:00
Beto Dealmeida
d1cd84931e Increase coverage 2026-05-08 16:40:29 -04:00
Beto Dealmeida
884649e3ed Increase parity 2026-05-08 15:32:10 -04:00
Beto Dealmeida
e9dd9a6107 Simplify function signature 2026-05-08 14:28:21 -04:00
Beto Dealmeida
0fe8293c7f Simplify 2026-05-08 14:12:46 -04:00
Beto Dealmeida
af2d3babec Improvements 2026-05-08 13:49:37 -04:00
Beto Dealmeida
3fe2b2505f feat: new splice RLSMethod 2026-05-08 13:33:42 -04:00
22 changed files with 1303 additions and 48 deletions

View File

@@ -557,6 +557,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
# if True, database will be listed as option in the upload file form # if True, database will be listed as option in the upload file form
supports_file_upload = True supports_file_upload = True
# RLS strategy for this engine spec. Override in engine-specific classes as
# needed (for example ``RLSMethod.AS_PREDICATE`` for engines that don't
# support subquery-based RLS, or ``RLSMethod.AS_PREDICATE_SPLICE`` for
# engines where sqlglot generation is not faithful).
rls_method = RLSMethod.AS_SUBQUERY
# Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501 # Is the DB engine spec able to change the default schema? This requires implementing # noqa: E501
# a custom `adjust_engine_params` method. # a custom `adjust_engine_params` method.
supports_dynamic_schema = False supports_dynamic_schema = False
@@ -618,21 +624,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
else cls.encrypted_extra_sensitive_fields else cls.encrypted_extra_sensitive_fields
) )
@classmethod
def get_rls_method(cls) -> RLSMethod:
"""
Returns the RLS method to be used for this engine.
There are two ways to insert RLS: either replacing the table with a subquery
that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is
safer, but not supported in all databases.
"""
return (
RLSMethod.AS_SUBQUERY
if cls.allows_subqueries and cls.allows_alias_in_select
else RLSMethod.AS_PREDICATE
)
@classmethod @classmethod
def is_oauth2_enabled(cls) -> bool: def is_oauth2_enabled(cls) -> bool:
return ( return (

View File

@@ -35,6 +35,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory, DatabaseCategory,
) )
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.network import is_hostname_valid, is_port_open from superset.utils.network import is_hostname_valid, is_port_open
@@ -82,6 +83,7 @@ class CouchbaseEngineSpec(BasicParametersMixin, BaseEngineSpec):
default_driver = "couchbase" default_driver = "couchbase"
allows_joins = False allows_joins = False
allows_subqueries = False allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
sqlalchemy_uri_placeholder = ( sqlalchemy_uri_placeholder = (
"couchbase://user:password@host[:port]?truststorepath=value?ssl=value" "couchbase://user:password@host[:port]?truststorepath=value?ssl=value"
) )

View File

@@ -23,6 +23,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
if TYPE_CHECKING: if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn from superset.connectors.sqla.models import TableColumn
@@ -68,6 +69,8 @@ class CrateEngineSpec(BaseEngineSpec):
TimeGrain.YEAR: "DATE_TRUNC('year', {col})", TimeGrain.YEAR: "DATE_TRUNC('year', {col})",
} }
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod @classmethod
def epoch_to_dttm(cls) -> str: def epoch_to_dttm(cls) -> str:
return "{col} * 1000" return "{col} * 1000"

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import (
) )
from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError from superset.db_engine_specs.exceptions import SupersetDBAPIDatabaseError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.sql.parse import RLSMethod
from superset.utils.core import GenericDataType from superset.utils.core import GenericDataType
from superset.utils.hashing import hash_from_str from superset.utils.hashing import hash_from_str
from superset.utils.network import is_hostname_valid, is_port_open from superset.utils.network import is_hostname_valid, is_port_open
@@ -55,6 +56,8 @@ class DatabendBaseEngineSpec(BaseEngineSpec):
time_secondary_columns = True time_secondary_columns = True
time_groupby_inline = True time_groupby_inline = True
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = { _time_grain_expressions = {
None: "{col}", None: "{col}",
TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})", TimeGrain.SECOND: "DATE_TRUNC('SECOND', {col})",

View File

@@ -26,6 +26,7 @@ from superset.db_engine_specs.base import (
DatabaseCategory, DatabaseCategory,
) )
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.sql.parse import RLSMethod
# Internal class for defining error message patterns (for translation) # Internal class for defining error message patterns (for translation)
@@ -58,6 +59,8 @@ class DenodoEngineSpec(BaseEngineSpec, BasicParametersMixin):
engine = "denodo" engine = "denodo"
engine_name = "Denodo" engine_name = "Denodo"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
default_driver = "psycopg2" default_driver = "psycopg2"
sqlalchemy_uri_placeholder = ( sqlalchemy_uri_placeholder = (
"denodo://user:password@host:port/dbname[?key=value&key=value...]" "denodo://user:password@host:port/dbname[?key=value&key=value...]"

View File

@@ -21,12 +21,15 @@ from sqlalchemy import types
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class DynamoDBEngineSpec(BaseEngineSpec): class DynamoDBEngineSpec(BaseEngineSpec):
engine = "dynamodb" engine = "dynamodb"
engine_name = "Amazon DynamoDB" engine_name = "Amazon DynamoDB"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = { metadata = {
"description": ( "description": (
"Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL." "Amazon DynamoDB is a serverless NoSQL database with SQL via PartiQL."

View File

@@ -28,6 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError, SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError, SupersetDBAPIProgrammingError,
) )
from superset.sql.parse import RLSMethod
logger = logging.getLogger() logger = logging.getLogger()
@@ -39,6 +40,7 @@ class ElasticSearchEngineSpec(BaseEngineSpec): # pylint: disable=abstract-metho
allows_joins = False allows_joins = False
allows_subqueries = True allows_subqueries = True
allows_sql_comments = False allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = { metadata = {
"description": ( "description": (

View File

@@ -21,7 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import LimitMethod from superset.sql.parse import LimitMethod, RLSMethod
class FirebirdEngineSpec(BaseEngineSpec): class FirebirdEngineSpec(BaseEngineSpec):
@@ -53,6 +53,8 @@ class FirebirdEngineSpec(BaseEngineSpec):
# Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table` # Firebird uses FIRST to limit: `SELECT FIRST 10 * FROM table`
limit_method = LimitMethod.FETCH_MANY limit_method = LimitMethod.FETCH_MANY
rls_method = RLSMethod.AS_PREDICATE_SPLICE
_time_grain_expressions = { _time_grain_expressions = {
None: "{col}", None: "{col}",
TimeGrain.SECOND: ( TimeGrain.SECOND: (

View File

@@ -14,6 +14,8 @@
# KIND, either express or implied. See the License for the # KIND, either express or implied. See the License for the
# specific language governing permissions and limitations # specific language governing permissions and limitations
# under the License. # under the License.
from superset.sql.parse import RLSMethod
from .db2 import Db2EngineSpec from .db2 import Db2EngineSpec
@@ -28,6 +30,8 @@ class IBMiEngineSpec(Db2EngineSpec):
engine_name = "IBM Db2 for i" engine_name = "IBM Db2 for i"
max_column_name_length = 128 max_column_name_length = 128
rls_method = RLSMethod.AS_PREDICATE_SPLICE
@classmethod @classmethod
def epoch_to_dttm(cls) -> str: def epoch_to_dttm(cls) -> str:
return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})" return "(DAYS({col}) - DAYS('1970-01-01')) * 86400 + MIDNIGHT_SECONDS({col})"

View File

@@ -28,7 +28,7 @@ from superset.db_engine_specs.exceptions import (
SupersetDBAPIOperationalError, SupersetDBAPIOperationalError,
SupersetDBAPIProgrammingError, SupersetDBAPIProgrammingError,
) )
from superset.sql.parse import LimitMethod from superset.sql.parse import LimitMethod, RLSMethod
from superset.utils.core import GenericDataType from superset.utils.core import GenericDataType
@@ -40,6 +40,7 @@ class KustoSqlEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
allows_joins = True allows_joins = True
allows_subqueries = True allows_subqueries = True
allows_sql_comments = False allows_sql_comments = False
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = { metadata = {
"description": ( "description": (

View File

@@ -21,6 +21,7 @@ from sqlalchemy import types
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -29,6 +30,8 @@ class KylinEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
engine = "kylin" engine = "kylin"
engine_name = "Apache Kylin" engine_name = "Apache Kylin"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
metadata = { metadata = {
"description": "Apache Kylin is an open-source OLAP engine for big data.", "description": "Apache Kylin is an open-source OLAP engine for big data.",
"logo": "apache-kylin.png", "logo": "apache-kylin.png",

View File

@@ -39,6 +39,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.errors import SupersetErrorType from superset.errors import SupersetErrorType
from superset.models.core import Database from superset.models.core import Database
from superset.models.sql_lab import Query from superset.models.sql_lab import Query
from superset.sql.parse import RLSMethod
# Regular expressions to catch custom errors # Regular expressions to catch custom errors
@@ -227,6 +228,8 @@ class OcientEngineSpec(BaseEngineSpec):
force_column_alias_quotes = True force_column_alias_quotes = True
max_column_name_length = 30 max_column_name_length = 30
rls_method = RLSMethod.AS_PREDICATE_SPLICE
allows_cte_in_subquery = False allows_cte_in_subquery = False
# Ocient does not support cte names starting with underscores # Ocient does not support cte names starting with underscores
cte_alias = "cte__" cte_alias = "cte__"

View File

@@ -20,6 +20,7 @@ from sqlalchemy.types import TypeEngine
from superset.constants import TimeGrain from superset.constants import TimeGrain
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class PinotEngineSpec(BaseEngineSpec): class PinotEngineSpec(BaseEngineSpec):
@@ -30,6 +31,7 @@ class PinotEngineSpec(BaseEngineSpec):
allows_joins = False allows_joins = False
allows_alias_in_select = False allows_alias_in_select = False
allows_alias_in_orderby = False allows_alias_in_orderby = False
rls_method = RLSMethod.AS_PREDICATE
# pinotdb only sets cursor.description when the response contains # pinotdb only sets cursor.description when the response contains
# columnDataTypes, which Pinot omits for zero-row results. # columnDataTypes, which Pinot omits for zero-row results.

View File

@@ -16,6 +16,7 @@
# under the License. # under the License.
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
@@ -27,6 +28,7 @@ class SolrEngineSpec(BaseEngineSpec): # pylint: disable=abstract-method
time_groupby_inline = False time_groupby_inline = False
allows_joins = False allows_joins = False
allows_subqueries = False allows_subqueries = False
rls_method = RLSMethod.AS_PREDICATE
metadata = { metadata = {
"description": "Apache Solr is an open-source enterprise search platform.", "description": "Apache Solr is an open-source enterprise search platform.",

View File

@@ -22,6 +22,7 @@ from urllib import parse
from sqlalchemy.engine.url import make_url, URL # noqa: F401 from sqlalchemy.engine.url import make_url, URL # noqa: F401
from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory from superset.db_engine_specs.base import BaseEngineSpec, DatabaseCategory
from superset.sql.parse import RLSMethod
class TDengineEngineSpec(BaseEngineSpec): class TDengineEngineSpec(BaseEngineSpec):
@@ -29,6 +30,8 @@ class TDengineEngineSpec(BaseEngineSpec):
engine_name = "TDengine" engine_name = "TDengine"
max_column_name_length = 64 max_column_name_length = 64
default_driver = "taosws" default_driver = "taosws"
rls_method = RLSMethod.AS_PREDICATE_SPLICE
sqlalchemy_uri_placeholder = ( sqlalchemy_uri_placeholder = (
"taosws://user:******@host:port/dbname[?key=value&key=value...]" "taosws://user:******@host:port/dbname[?key=value&key=value...]"
) )

View File

@@ -76,7 +76,7 @@ SQLGLOT_DIALECTS = {
"duckdb": Dialects.DUCKDB, "duckdb": Dialects.DUCKDB,
# "dynamodb": ??? # "dynamodb": ???
# "elasticsearch": ??? # "elasticsearch": ???
# "exa": ??? "exa": Dialects.EXASOL,
# "firebird": ??? # "firebird": ???
"firebolt": Firebolt, "firebolt": Firebolt,
"gsheets": Dialects.SQLITE, "gsheets": Dialects.SQLITE,
@@ -105,7 +105,7 @@ SQLGLOT_DIALECTS = {
"shillelagh": Dialects.SQLITE, "shillelagh": Dialects.SQLITE,
"singlestoredb": SingleStore, "singlestoredb": SingleStore,
"snowflake": Dialects.SNOWFLAKE, "snowflake": Dialects.SNOWFLAKE,
# "solr": ??? "solr": Dialects.SOLR,
"spark": Dialects.SPARK, "spark": Dialects.SPARK,
"sqlite": Dialects.SQLITE, "sqlite": Dialects.SQLITE,
"starrocks": Dialects.STARROCKS, "starrocks": Dialects.STARROCKS,
@@ -142,6 +142,7 @@ class RLSMethod(enum.Enum):
AS_PREDICATE = enum.auto() AS_PREDICATE = enum.auto()
AS_SUBQUERY = enum.auto() AS_SUBQUERY = enum.auto()
AS_PREDICATE_SPLICE = enum.auto()
class RLSTransformer: class RLSTransformer:
@@ -355,6 +356,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
statement: str | None = None, statement: str | None = None,
engine: str = "base", engine: str = "base",
ast: InternalRepresentation | None = None, ast: InternalRepresentation | None = None,
source: str | None = None,
): ):
if ast: if ast:
self._parsed = ast self._parsed = ast
@@ -365,6 +367,16 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self.engine = engine self.engine = engine
self.tables = self._extract_tables_from_statement(self._parsed, self.engine) self.tables = self._extract_tables_from_statement(self._parsed, self.engine)
# Original SQL substring for this statement, when known. Used by the
# splice-mode RLS path which rewrites this string instead of regenerating
# SQL from the AST. ``None`` means the statement was constructed from an
# AST without an associated source string (splice mode falls back).
self._source_sql: str | None = source if source is not None else statement
# Verbatim SQL to return from ``format()``. Set by string-rewriting
# operations (e.g. splice-mode RLS) that produce a final SQL string and
# need to bypass the dialect generator. Cleared by AST-mutating methods
# since those invalidate this cached text.
self._raw_sql: str | None = None
@classmethod @classmethod
def split_script( def split_script(
@@ -531,7 +543,7 @@ class BaseSQLStatement(Generic[InternalRepresentation]):
self, self,
catalog: str | None, catalog: str | None,
schema: str | None, schema: str | None,
predicates: dict[Table, list[InternalRepresentation]], predicates: dict[Table, list[str]],
method: RLSMethod, method: RLSMethod,
) -> None: ) -> None:
""" """
@@ -559,9 +571,10 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
statement: str | None = None, statement: str | None = None,
engine: str = "base", engine: str = "base",
ast: exp.Expression | None = None, ast: exp.Expression | None = None,
source: str | None = None,
): ):
self._dialect = SQLGLOT_DIALECTS.get(engine) self._dialect = SQLGLOT_DIALECTS.get(engine)
super().__init__(statement, engine, ast) super().__init__(statement, engine, ast, source)
@classmethod @classmethod
def _parse(cls, script: str, engine: str) -> list[exp.Expression]: def _parse(cls, script: str, engine: str) -> list[exp.Expression]:
@@ -626,10 +639,57 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
script: str, script: str,
engine: str, engine: str,
) -> list[SQLStatement]: ) -> list[SQLStatement]:
asts = [ast for ast in cls._parse(script, engine) if ast]
sources = cls._split_source(script, engine, len(asts))
return [ return [
cls(ast=ast, engine=engine) for ast in cls._parse(script, engine) if ast cls(ast=ast, engine=engine, source=source)
for ast, source in zip(asts, sources, strict=True)
] ]
@classmethod
def _split_source(
cls,
script: str,
engine: str,
expected_count: int,
) -> list[str | None]:
"""
Slice ``script`` into per-statement substrings using top-level semicolon
positions from the tokenizer. Returns a list of length ``expected_count``;
any entry is ``None`` if the slicing didn't yield a usable substring.
The returned substrings preserve the original byte content of the script
for each statement — necessary for splice-mode RLS, which rewrites the
original SQL rather than regenerating from the AST.
"""
none_result: list[str | None] = [None] * expected_count
dialect = SQLGLOT_DIALECTS.get(engine)
try:
tokens = list(Dialect.get_or_raise(dialect).tokenize(script))
except sqlglot.errors.SqlglotError:
return none_result
# Top-level semicolon offsets (depth 0).
boundaries: list[int] = []
depth = 0
for tok in tokens:
if tok.token_type == sqlglot.tokens.TokenType.L_PAREN:
depth += 1
elif tok.token_type == sqlglot.tokens.TokenType.R_PAREN:
# Clamp at 0 so malformed SQL with unbalanced ')' can't drive
# depth negative and misclassify later semicolons as nested.
depth = max(0, depth - 1)
elif tok.token_type == sqlglot.tokens.TokenType.SEMICOLON and depth == 0:
boundaries.append(tok.start)
starts = [0, *(b + 1 for b in boundaries)]
ends = [*boundaries, len(script)]
sources = [script[s:e].strip() for s, e in zip(starts, ends, strict=True)]
sources = [s for s in sources if s]
if len(sources) != expected_count:
return none_result
return list(sources)
@classmethod @classmethod
def _parse_statement( def _parse_statement(
cls, cls,
@@ -722,7 +782,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
def format(self, comments: bool = True) -> str: def format(self, comments: bool = True) -> str:
""" """
Pretty-format the SQL statement. Pretty-format the SQL statement.
When a string-rewriting operation (e.g. splice-mode RLS) has cached a
verbatim result in ``_raw_sql``, return it as-is — the whole point of
those operations is to avoid the dialect generator round-trip.
""" """
if self._raw_sql is not None:
return self._raw_sql
return Dialect.get_or_raise(self._dialect).generate( return Dialect.get_or_raise(self._dialect).generate(
self._parsed, self._parsed,
copy=True, copy=True,
@@ -808,6 +874,13 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
""" """
Modify the `LIMIT` or `TOP` value of the SQL statement inplace. Modify the `LIMIT` or `TOP` value of the SQL statement inplace.
""" """
# AST mutation invalidates any cached verbatim SQL (e.g. from splice).
# If we already have a rewritten SQL string, re-parse it first so further
# AST mutations (like LIMIT injection) preserve prior text-based rewrites.
if self._raw_sql is not None:
self._parsed = self._parse_statement(self._raw_sql, self.engine)
self._source_sql = self._raw_sql
self._raw_sql = None
if method == LimitMethod.FORCE_LIMIT: if method == LimitMethod.FORCE_LIMIT:
self._parsed.args["limit"] = exp.Limit( self._parsed.args["limit"] = exp.Limit(
expression=exp.Literal(this=str(limit), is_string=False) expression=exp.Literal(this=str(limit), is_string=False)
@@ -902,7 +975,7 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
self, self,
catalog: str | None, catalog: str | None,
schema: str | None, schema: str | None,
predicates: dict[Table, list[exp.Expression]], predicates: dict[Table, list[str]],
method: RLSMethod, method: RLSMethod,
) -> None: ) -> None:
""" """
@@ -910,11 +983,22 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
:param catalog: The default catalog for non-qualified table names :param catalog: The default catalog for non-qualified table names
:param schema: The default schema for non-qualified table names :param schema: The default schema for non-qualified table names
:param predicates: Mapping of fully qualified ``Table`` to raw predicate
SQL strings.
:param method: The method to use for applying the rules. :param method: The method to use for applying the rules.
""" """
if not predicates: if not predicates:
return return
if method == RLSMethod.AS_PREDICATE_SPLICE:
self._apply_rls_splice(catalog, schema, predicates)
return
parsed_predicates: dict[Table, list[exp.Expression]] = {
table: [self.parse_predicate(predicate) for predicate in table_predicates]
for table, table_predicates in predicates.items()
}
transformers = { transformers = {
RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer, RLSMethod.AS_PREDICATE: RLSAsPredicateTransformer,
RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer, RLSMethod.AS_SUBQUERY: RLSAsSubqueryTransformer,
@@ -922,9 +1006,39 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
if method not in transformers: if method not in transformers:
raise ValueError(f"Invalid RLS method: {method}") raise ValueError(f"Invalid RLS method: {method}")
transformer = transformers[method](catalog, schema, predicates) transformer = transformers[method](catalog, schema, parsed_predicates)
self._parsed = self._parsed.transform(transformer) self._parsed = self._parsed.transform(transformer)
def _apply_rls_splice(
self,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[str]],
) -> None:
"""
Apply RLS via text splicing on the original SQL.
Requires the source SQL substring to be available. Raises ``ValueError``
if it isn't — the caller must ensure the statement was constructed from
a source string (the standard ``SQLScript`` path does this).
"""
from superset.sql.rls_splice import apply_rls_splice
if self._source_sql is None:
raise ValueError(
"Splice-mode RLS requires the source SQL string; "
"this SQLStatement was constructed without one."
)
spliced = apply_rls_splice(
self._source_sql,
catalog,
schema,
predicates,
dialect=self._dialect,
)
self._raw_sql = spliced
class KQLSplitState(enum.Enum): class KQLSplitState(enum.Enum):
""" """

474
superset/sql/rls_splice.py Normal file
View File

@@ -0,0 +1,474 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
RLS predicate injection via text splicing.
Instead of round-tripping through sqlglot's generator (which transpiles
dialect-specific functions like ``LAST_DAY`` into something else), this approach:
1. Parses the SQL with sqlglot — only to understand structure (scope tree).
2. Uses sqlglot's tokenizer to get byte-accurate positions for every token
in the original SQL string.
3. For each ``SELECT`` scope that references a table with an RLS predicate,
finds the exact byte offset to inject at — either the end of an existing
``WHERE`` clause, or just before ``GROUP BY`` / ``ORDER BY`` / ``HAVING``
/ ``LIMIT`` / the closing paren of a subquery.
4. Splices the predicate text directly into the original string at that
offset — never calling ``.sql()``, so the generator never runs.
Result: everything outside the splice points is the original SQL, byte for
byte. Dialect-specific functions, comments, and formatting are all preserved
exactly.
Known limitations:
- SQL that fails to parse under the chosen dialect raises a ``ParseError``.
A thin dialect subclass is still required for parsing — but only for
parsing, not generation.
- Predicate strings are spliced in as raw SQL. They must come from a trusted
source (the RLS config), not user input.
- Predicate **column qualification** (prefixing bare columns with the table
alias) currently round-trips the predicate through the sqlglot generator
via ``_qualify_predicate``. Predicates that contain dialect-specific
functions can therefore still be transpiled by the generator at that step,
even though the surrounding query is preserved byte-for-byte. The
surrounding-query fidelity guarantee does not extend to the predicate
string itself.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import sqlglot
from sqlglot import exp
from sqlglot.optimizer.scope import traverse_scope
from sqlglot.tokens import Token, TokenType
if TYPE_CHECKING:
from superset.sql.parse import Table
# Token types that end a WHERE clause / FROM section at the current paren depth,
# indicating where a new predicate must be inserted just before.
_CLAUSE_ENDS = {
TokenType.GROUP_BY,
TokenType.HAVING,
TokenType.ORDER_BY,
TokenType.WINDOW,
TokenType.QUALIFY,
TokenType.LIMIT,
TokenType.FETCH,
TokenType.CLUSTER_BY,
TokenType.DISTRIBUTE_BY,
TokenType.SORT_BY,
TokenType.CONNECT_BY,
TokenType.START_WITH,
TokenType.UNION,
TokenType.INTERSECT,
TokenType.EXCEPT,
}
_JOIN_STARTS = {
TokenType.JOIN,
TokenType.STRAIGHT_JOIN,
TokenType.JOIN_MARKER,
}
def _splice_priority(text: str) -> int:
"""
Priority for applying splices at the same offset.
Insert full SQL fragments (WHERE/ON/predicates) before closing parens so
wrapping splices like ``pred AND (existing)`` compose correctly.
"""
return 1 if text != ")" else 0
def _after_previous_token(tokens: list[Token], index: int) -> int:
"""
Return the offset immediately after the token preceding *index*.
The sqlglot tokenizer strips comments and whitespace from the token stream,
so the previous token's ``end + 1`` is the splice point that lands right
after the last real SQL content — naturally skipping any intervening
comments or whitespace, and never confusing ``--`` or ``/*`` inside string
literals for real comment delimiters.
"""
if index <= 0:
return 0
return tokens[index - 1].end + 1
def _table_from_node(
node: exp.Table,
catalog: str | None,
schema: str | None,
) -> Table:
"""
Build a fully qualified ``Table`` from a sqlglot ``exp.Table`` node, defaulting
unqualified parts to the supplied catalog/schema.
"""
# Imported lazily to avoid a circular import with ``superset.sql.parse``.
from superset.sql.parse import Table
return Table(
table=node.name,
schema=node.db if node.db else schema,
catalog=node.catalog if node.catalog else catalog,
)
def apply_rls_splice(
sql: str,
catalog: str | None,
schema: str | None,
predicates: dict[Table, list[str]],
dialect: str | None = None,
) -> str:
"""
Inject RLS predicates into ``sql`` by splicing text at the right positions.
:param sql: The original SQL query. Returned unchanged except at splice points.
:param catalog: The default catalog for non-qualified table names.
:param schema: The default schema for non-qualified table names.
:param predicates: Mapping of ``Table`` to predicate SQL strings. Each entry
maps a fully qualified table to one or more raw predicate strings to
``AND`` together when that table is referenced in a SELECT scope.
:param dialect: The sqlglot dialect used for *parsing only* — to understand
scope structure and locate token positions. The generator is never
called, so this does not affect output formatting.
:return: The query with RLS predicates injected into every relevant SELECT
scope.
"""
if not predicates or not any(predicates.values()):
return sql
resolved_dialect = sqlglot.Dialect.get_or_raise(dialect)
tokens = list(resolved_dialect.tokenize(sql))
tree = sqlglot.parse_one(sql, dialect=dialect)
splices: list[tuple[int, str]] = []
for scope in traverse_scope(tree):
splices.extend(
_splices_for_scope(
sql,
tokens,
scope,
predicates,
catalog,
schema,
dialect,
)
)
# Apply splices in reverse offset order so earlier positions stay valid.
# For equal offsets, apply predicate/WHERE/ON inserts before ")" inserts.
splices.sort(key=lambda item: (item[0], _splice_priority(item[1])), reverse=True)
result = sql
for offset, text in splices:
result = result[:offset] + text + result[offset:]
return result
def _splices_for_scope(
sql: str,
tokens: list[Token],
scope: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> list[tuple[int, str]]:
"""
Compute all splices for a single SELECT scope.
This mirrors ``RLSAsPredicateTransformer`` semantics:
- predicates for FROM tables are applied to the SELECT WHERE clause as
``pred AND (existing_where)``
- predicates for JOIN tables are applied to each JOIN ON clause as
``pred AND (existing_on)`` (or ``ON pred`` when ON is absent)
"""
from_predicates: list[str] = []
from_table_ends: list[int] = []
join_splices: list[tuple[int, str]] = []
for source in scope.sources.values(): # type: ignore[attr-defined]
source_type, table_end, pred_sql = _classify_source_predicate(
source,
predicates,
catalog,
schema,
dialect,
)
if source_type == "none" or table_end is None or pred_sql is None:
continue
if source_type == "from":
from_predicates.append(pred_sql)
from_table_ends.append(table_end)
continue
join_splice = _find_join_splice(sql, tokens, table_end, pred_sql)
if join_splice:
join_splices.extend(join_splice)
if not from_predicates:
return join_splices
combined_predicates = " AND ".join(dict.fromkeys(from_predicates))
from_splice = _find_where_splice(
sql,
tokens,
max(from_table_ends),
combined_predicates,
)
return [*join_splices, *from_splice]
def _table_end(source: exp.Table) -> int | None:
ident = source.find(exp.Identifier)
meta = getattr(ident, "_meta", None) if ident else None
if meta is None:
return None
return meta.get("end")
def _classify_source_predicate(
source: object,
predicates: dict[Table, list[str]],
catalog: str | None,
schema: str | None,
dialect: str | None,
) -> tuple[str, int | None, str | None]:
"""
Return source kind (from/join/none), table end offset, and predicate SQL.
"""
if not isinstance(source, exp.Table):
return ("none", None, None)
table = _table_from_node(source, catalog, schema)
table_predicates = [
_qualify_predicate(predicate, source, dialect)
for predicate in predicates.get(table, [])
if predicate
]
if not table_predicates:
return ("none", None, None)
table_end = _table_end(source)
if table_end is None:
return ("none", None, None)
pred_sql = " AND ".join(dict.fromkeys(table_predicates))
if isinstance(source.parent, exp.From):
return ("from", table_end, pred_sql)
if isinstance(source.parent, exp.Join):
return ("join", table_end, pred_sql)
return ("none", None, None)
def _qualify_predicate(
predicate: str,
table_node: exp.Table,
dialect: str | None,
) -> str:
"""
Qualify predicate columns with the table alias/name, mirroring
``RLSAsPredicateTransformer``.
Note: this re-renders the predicate via the sqlglot generator, so the
splice-mode fidelity guarantee does not extend to the predicate text
itself. Predicates containing dialect-specific functions may be transpiled
here even though the surrounding query is preserved byte-for-byte.
"""
parsed = sqlglot.parse_one(predicate, dialect=dialect)
table = table_node.alias_or_name
table_expr = exp.to_identifier(table)
for column in parsed.find_all(exp.Column):
column.set("table", table_expr.copy())
return parsed.sql(dialect=dialect)
def _scan_until_scope_boundary(
tokens: list[Token],
anchor: int,
*,
stop_at_join: bool,
) -> tuple[str, int | None]:
"""
Scan tokens forward from ``anchor`` until a clause/scope boundary.
Returns ``("where", index)`` when a WHERE token is found at depth 0,
``("boundary", index)`` for a non-WHERE boundary token, and
``("eof", None)`` when no boundary token is found.
"""
depth = 0
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
return ("boundary", i)
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.WHERE:
return ("where", i)
if tok.token_type in _CLAUSE_ENDS or (
stop_at_join and tok.token_type in _JOIN_STARTS
):
return ("boundary", i)
return ("eof", None)
def _find_condition_end(
tokens: list[Token],
start_index: int,
*,
stop_at_join: bool,
) -> int:
"""
Find the end offset for a WHERE/ON condition body.
"""
depth = 0
prev_end = tokens[start_index].end
for tok in tokens[start_index + 1 :]:
if tok.token_type == TokenType.L_PAREN:
depth += 1
elif tok.token_type == TokenType.R_PAREN:
if depth == 0:
return prev_end + 1
depth -= 1
elif depth == 0 and (
(stop_at_join and tok.token_type == TokenType.WHERE)
or tok.token_type in _CLAUSE_ENDS
or (stop_at_join and tok.token_type in _JOIN_STARTS)
):
return prev_end + 1
prev_end = tok.end
return prev_end + 1
def _find_where_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to the SELECT WHERE clause:
``pred`` when absent, ``pred AND (existing)`` when present.
"""
kind, idx = _scan_until_scope_boundary(tokens, anchor, stop_at_join=False)
if kind == "where" and idx is not None:
if idx + 1 >= len(tokens):
return [(tokens[idx].end + 1, f" {pred_sql}")]
body_start = tokens[idx + 1].start
body_end = _find_condition_end(tokens, idx, stop_at_join=False)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if kind == "boundary" and idx is not None:
return [(_after_previous_token(tokens, idx), f" WHERE {pred_sql}")]
return [(len(sql), f" WHERE {pred_sql}")]
def _find_join_splice(
sql: str,
tokens: list[Token],
anchor: int,
pred_sql: str,
) -> list[tuple[int, str]]:
"""
Build splices for adding predicate semantics to a JOIN clause:
``ON pred`` when ON absent, ``ON pred AND (existing_on)`` when present.
"""
on_index, boundary_index = _scan_join_clause(tokens, anchor)
if on_index is not None:
if on_index + 1 >= len(tokens):
return [(tokens[on_index].end + 1, f" {pred_sql}")]
body_start = tokens[on_index + 1].start
body_end = _find_condition_end(tokens, on_index, stop_at_join=True)
return [
(body_start, f"{pred_sql} AND ("),
(body_end, ")"),
]
if boundary_index is not None:
return [(_after_previous_token(tokens, boundary_index), f" ON {pred_sql}")]
return [(len(sql), f" ON {pred_sql}")]
def _scan_join_clause(
tokens: list[Token],
anchor: int,
) -> tuple[int | None, int | None]:
"""
Find ON and boundary token indexes for a JOIN segment.
"""
depth = 0
on_index: int | None = None
boundary_index: int | None = None
for i, tok in enumerate(tokens):
if tok.start <= anchor:
continue
if tok.token_type == TokenType.L_PAREN:
depth += 1
continue
if tok.token_type == TokenType.R_PAREN:
if depth == 0:
boundary_index = i
break
depth -= 1
continue
if depth > 0:
continue
if tok.token_type == TokenType.ON and on_index is None:
on_index = i
continue
if tok.token_type == TokenType.WHERE:
boundary_index = i
break
if tok.token_type in _JOIN_STARTS or tok.token_type in _CLAUSE_ENDS:
boundary_index = i
break
return on_index, boundary_index

View File

@@ -40,17 +40,19 @@ def apply_rls(
:returns: True if any RLS predicates were actually applied, False otherwise. :returns: True if any RLS predicates were actually applied, False otherwise.
""" """
# There are two ways to insert RLS: either replacing the table with a subquery # There are three ways to insert RLS:
# that has the RLS, or appending the RLS to the ``WHERE`` clause. The former is # - replace the table with a subquery containing the RLS (safest, but not
# safer, but not supported in all databases. # supported in all databases)
method = database.db_engine_spec.get_rls_method() # - append the RLS to the ``WHERE`` clause via AST transformation
# - splice the RLS into the original SQL string (preserves dialect-specific
# syntax that the sqlglot generator would otherwise transpile)
method = database.db_engine_spec.rls_method
# collect all RLS predicates for all tables in the query predicates: dict[Table, list[str]] = {}
predicates: dict[Table, list[Any]] = {}
for table in parsed_statement.tables: for table in parsed_statement.tables:
table = table.qualify(catalog=catalog, schema=schema) table = table.qualify(catalog=catalog, schema=schema)
predicates[table] = [ raw_predicates = [
parsed_statement.parse_predicate(predicate) predicate
for predicate in get_predicates_for_table( for predicate in get_predicates_for_table(
table, table,
database, database,
@@ -58,6 +60,7 @@ def apply_rls(
) )
if predicate if predicate
] ]
predicates[table] = raw_predicates
has_predicates = any(predicates.values()) has_predicates = any(predicates.values())
parsed_statement.apply_rls(catalog, schema, predicates, method) parsed_statement.apply_rls(catalog, schema, predicates, method)

View File

@@ -36,7 +36,7 @@ from sqlalchemy.sql import sqltypes
from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2RedirectError from superset.exceptions import OAuth2RedirectError
from superset.sql.parse import Table from superset.sql.parse import RLSMethod, Table
from superset.superset_typing import ( from superset.superset_typing import (
OAuth2ClientConfig, OAuth2ClientConfig,
OAuth2State, OAuth2State,
@@ -1283,3 +1283,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
error = exc_info.value.error error = exc_info.value.error
assert error.extra["redirect_uri"] == fallback_uri assert error.extra["redirect_uri"] == fallback_uri
def test_default_rls_method_is_subquery() -> None:
"""Base engine spec defaults to subquery-based RLS."""
assert BaseEngineSpec.rls_method == RLSMethod.AS_SUBQUERY

View File

@@ -209,7 +209,7 @@ class TestApplyRlsReturnValue:
from superset.utils.rls import apply_rls from superset.utils.rls import apply_rls
database = MagicMock() database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock() database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None database.get_default_catalog.return_value = None
statement = MagicMock() statement = MagicMock()
@@ -237,7 +237,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = [] mock_get_predicates.return_value = []
database = MagicMock() database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock() database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None database.get_default_catalog.return_value = None
mock_table = MagicMock() mock_table = MagicMock()
@@ -268,7 +268,7 @@ class TestApplyRlsReturnValue:
mock_get_predicates.return_value = ["user_id = 42"] mock_get_predicates.return_value = ["user_id = 42"]
database = MagicMock() database = MagicMock()
database.db_engine_spec.get_rls_method.return_value = MagicMock() database.db_engine_spec.rls_method = RLSMethod.AS_SUBQUERY
database.get_default_catalog.return_value = None database.get_default_catalog.return_value = None
mock_table = MagicMock() mock_table = MagicMock()
@@ -276,8 +276,6 @@ class TestApplyRlsReturnValue:
statement = MagicMock() statement = MagicMock()
statement.tables = [mock_table] statement.tables = [mock_table]
statement.parse_predicate.return_value = MagicMock()
result = apply_rls( result = apply_rls(
database=database, database=database,
catalog=None, catalog=None,
@@ -312,11 +310,10 @@ class TestRLSSubqueryAlias:
""" """
sql = "SELECT pens.pen_id, pens.is_green FROM public.pens" sql = "SELECT pens.pen_id, pens.is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift") statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls( statement.apply_rls(
None, None,
"public", "public",
{Table("pens", "public", None): [predicate]}, {Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY, RLSMethod.AS_SUBQUERY,
) )
result = statement.format() result = statement.format()
@@ -333,11 +330,10 @@ class TestRLSSubqueryAlias:
""" """
sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens" sql = "SELECT pens.pen_id, pens.is_green FROM mycat.public.pens"
statement = SQLStatement(sql, engine="redshift") statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls( statement.apply_rls(
None, None,
"public", "public",
{Table("pens", "public", "mycat"): [predicate]}, {Table("pens", "public", "mycat"): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY, RLSMethod.AS_SUBQUERY,
) )
result = statement.format() result = statement.format()
@@ -351,11 +347,10 @@ class TestRLSSubqueryAlias:
""" """
sql = "SELECT p.pen_id, p.is_green FROM public.pens p" sql = "SELECT p.pen_id, p.is_green FROM public.pens p"
statement = SQLStatement(sql, engine="redshift") statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls( statement.apply_rls(
None, None,
"public", "public",
{Table("pens", "public", None): [predicate]}, {Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY, RLSMethod.AS_SUBQUERY,
) )
result = statement.format() result = statement.format()
@@ -369,11 +364,10 @@ class TestRLSSubqueryAlias:
""" """
sql = "SELECT pen_id, is_green FROM public.pens" sql = "SELECT pen_id, is_green FROM public.pens"
statement = SQLStatement(sql, engine="redshift") statement = SQLStatement(sql, engine="redshift")
predicate = statement.parse_predicate("user_id = 1")
statement.apply_rls( statement.apply_rls(
None, None,
"public", "public",
{Table("pens", "public", None): [predicate]}, {Table("pens", "public", None): ["user_id = 1"]},
RLSMethod.AS_SUBQUERY, RLSMethod.AS_SUBQUERY,
) )
result = statement.format() result = statement.format()

View File

@@ -1704,6 +1704,21 @@ def test_set_limit_value(
assert statement.format() == expected assert statement.format() == expected
def test_set_limit_value_after_splice_reparses_from_raw_sql() -> None:
"""
When a statement has cached verbatim SQL from splice-mode rewrites, setting
limit should reparse that SQL before mutating the AST.
"""
statement = SQLStatement("SELECT * FROM some_table", "postgresql")
statement._raw_sql = "SELECT * FROM some_table WHERE tenant_id = 42"
statement.set_limit_value(10, LimitMethod.FORCE_LIMIT)
formatted = statement.format()
assert "tenant_id = 42" in formatted
assert "LIMIT 10" in formatted
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kql, limit, expected", "kql, limit, expected",
[ [
@@ -2198,7 +2213,7 @@ def test_rls_subquery_transformer(
statement.apply_rls( statement.apply_rls(
"catalog1", "catalog1",
"schema1", "schema1",
{k: [parse_one(v)] for k, v in rules.items()}, {k: [v] for k, v in rules.items()},
RLSMethod.AS_SUBQUERY, RLSMethod.AS_SUBQUERY,
) )
assert statement.format() == expected assert statement.format() == expected
@@ -2542,12 +2557,337 @@ def test_rls_predicate_transformer(
statement.apply_rls( statement.apply_rls(
"catalog1", "catalog1",
"schema1", "schema1",
{k: [parse_one(v)] for k, v in rules.items()}, {k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE, RLSMethod.AS_PREDICATE,
) )
assert statement.format() == expected assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
(
"SELECT t.foo FROM some_table AS t",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM some_table AS t WHERE t.id = 42",
),
(
"SELECT t.foo FROM some_table AS t WHERE bar = 'baz' OR foo = 'qux'",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM some_table AS t WHERE t.id = 42 "
"AND (bar = 'baz' OR foo = 'qux')",
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
"AND (table.id = other_table.id)",
),
(
"SELECT * FROM table JOIN other_table",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42",
),
(
"SELECT * FROM table JOIN other_table ON table.id = other_table.id "
"WHERE 1=1",
{Table("other_table", "schema1", "catalog1"): "other_table.id = 42"},
"SELECT * FROM table JOIN other_table ON other_table.id = 42 "
"AND (table.id = other_table.id) WHERE 1=1",
),
],
)
def test_rls_predicate_splice_semantics_match_predicate(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Splice mode should preserve predicate-mode semantics for boolean grouping
and JOIN-vs-WHERE placement.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, rules, expected",
[
# Simple — no WHERE clause to extend.
(
"SELECT LAST_DAY(d) FROM some_table",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42",
),
# Append to an existing WHERE clause.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open')",
),
# WHERE precedes GROUP BY: predicate goes before GROUP BY.
(
"SELECT LAST_DAY(d) FROM some_table WHERE status = 'open' GROUP BY d",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (status = 'open') GROUP BY d",
),
# No WHERE, but GROUP BY and ORDER BY are present.
(
"SELECT LAST_DAY(d) FROM some_table GROUP BY d ORDER BY d",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42 "
"GROUP BY d ORDER BY d",
),
# JOIN — predicate scoped to one of the tables.
(
"SELECT o.id FROM some_table o JOIN locations l ON o.loc_id = l.id",
{Table("some_table", "schema1", "catalog1"): "o.tenant_id = 42"},
"SELECT o.id FROM some_table o JOIN locations l "
"ON o.loc_id = l.id WHERE o.tenant_id = 42",
),
# JOIN — different predicate per table, both spliced into one WHERE.
(
"SELECT * FROM some_table JOIN events ON some_table.id = events.order_id",
{
Table("events", "schema1", "catalog1"): "events.user_id = 99",
Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42",
},
"SELECT * FROM some_table JOIN events "
"ON events.user_id = 99 AND (some_table.id = events.order_id) "
"WHERE some_table.tenant_id = 42",
),
# Subquery in FROM — splice into the inner SELECT.
(
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table) sub",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT x FROM (SELECT LAST_DAY(d) AS x FROM some_table "
"WHERE some_table.tenant_id = 42) sub",
),
# CTE — splice into the CTE body.
(
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table) SELECT * FROM cte",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"WITH cte AS (SELECT LAST_DAY(d) FROM some_table "
"WHERE some_table.tenant_id = 42) SELECT * FROM cte",
),
# Dialect-specific function (LAST_DAY) preserved verbatim.
(
"SELECT id, LAST_DAY(created_at) FROM some_table WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT id, LAST_DAY(created_at) FROM some_table "
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Multiline + inline comment preserved exactly.
(
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE region = 'US'",
{Table("some_table", "schema1", "catalog1"): "some_table.tenant_id = 42"},
"SELECT LAST_DAY(created_at) -- last day of month\n"
"FROM some_table\n"
"WHERE some_table.tenant_id = 42 AND (region = 'US')",
),
# Schema-qualified table name (no default schema match) — no predicate.
(
"SELECT t.foo FROM schema2.some_table AS t",
{Table("some_table", "schema1", "catalog1"): "t.id = 42"},
"SELECT t.foo FROM schema2.some_table AS t",
),
],
)
def test_rls_predicate_splice(
sql: str,
rules: dict[Table, str],
expected: str,
) -> None:
"""
Test the splice-mode RLS via ``RLSMethod.AS_PREDICATE_SPLICE``.
Splice mode rewrites the original SQL string instead of re-rendering the
AST through the dialect generator, so byte-level fidelity (including
dialect-specific functions, comments, and whitespace) is preserved.
"""
statement = SQLStatement(sql)
statement.apply_rls(
"catalog1",
"schema1",
{k: [v] for k, v in rules.items()},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
def test_rls_predicate_splice_requires_source() -> None:
"""
Splice mode requires the original SQL substring; constructing a statement
purely from an AST should make splice mode raise.
"""
ast = parse_one("SELECT * FROM some_table")
statement = SQLStatement(ast=ast, engine="postgresql")
with pytest.raises(ValueError, match="Splice-mode RLS requires the source SQL"):
statement.apply_rls(
"catalog1",
"schema1",
{Table("some_table", "schema1", "catalog1"): ["id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
def test_rls_predicate_splice_preserves_dialect_function() -> None:
"""
Splice mode must NOT round-trip through the sqlglot generator. ``LAST_DAY``
on the postgres dialect would otherwise be transpiled by the generator.
"""
sql = "SELECT LAST_DAY(d) FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT LAST_DAY(d) FROM some_table WHERE some_table.tenant_id = 42"
)
def test_rls_predicate_splice_combines_multiple_predicates() -> None:
"""
Splice mode should AND together multiple predicates configured for the same
table into a single injected condition.
"""
sql = "SELECT * FROM some_table WHERE status = 'open'"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{
Table("some_table"): [
"some_table.tenant_id = 42",
"some_table.region = 'US'",
],
},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT * FROM some_table "
"WHERE some_table.tenant_id = 42 AND some_table.region = 'US' "
"AND (status = 'open')"
)
def test_rls_predicate_splice_string_predicates_skip_parse() -> None:
"""
Splice mode accepts predicate strings directly — no ``parse_predicate`` is
needed at the call site.
"""
sql = "SELECT * FROM some_table"
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42 AND some_table.active"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 AND some_table.active"
)
@pytest.mark.parametrize(
"sql, expected",
[
(
"SELECT * FROM some_table -- hi\nGROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"-- hi\nGROUP BY id",
),
(
"SELECT * FROM some_table /* inline */ GROUP BY id",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"/* inline */ GROUP BY id",
),
],
)
def test_rls_predicate_splice_inserts_before_comments(sql: str, expected: str) -> None:
"""
Splice mode should insert predicates before comments that precede the next
clause boundary, so comments do not swallow the injected SQL.
"""
statement = SQLStatement(sql, engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
@pytest.mark.parametrize(
"sql, engine, expected",
[
(
"SELECT * FROM some_table QUALIFY row_number() OVER "
"(PARTITION BY id ORDER BY ts DESC) = 1",
"snowflake",
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"QUALIFY row_number() OVER (PARTITION BY id ORDER BY ts DESC) = 1",
),
(
"SELECT sum(v) OVER () FROM some_table WINDOW w AS (PARTITION BY id)",
"postgresql",
"SELECT sum(v) OVER () FROM some_table "
"WHERE some_table.tenant_id = 42 "
"WINDOW w AS (PARTITION BY id)",
),
],
)
def test_rls_predicate_splice_handles_additional_clause_boundaries(
sql: str,
engine: str,
expected: str,
) -> None:
"""
Splice mode should insert WHERE before clause types that can legally follow
FROM/WHERE (for example QUALIFY and WINDOW).
"""
statement = SQLStatement(sql, engine=engine)
statement.apply_rls(
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
assert statement.format() == expected
def test_rls_predicate_splice_then_limit_keeps_rls() -> None:
"""
LIMIT rewrites after splice-mode RLS should retain injected predicates.
"""
statement = SQLStatement("SELECT * FROM some_table", engine="postgresql")
statement.apply_rls(
None,
None,
{Table("some_table"): ["tenant_id = 42"]},
RLSMethod.AS_PREDICATE_SPLICE,
)
statement.set_limit_value(101, LimitMethod.FORCE_LIMIT)
formatted = statement.format()
assert "some_table.tenant_id = 42" in formatted
assert "LIMIT 101" in formatted
@pytest.mark.parametrize( @pytest.mark.parametrize(
"sql, table, expected", "sql, table, expected",
[ [

View File

@@ -0,0 +1,298 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
import sqlglot
from sqlglot import Dialect, exp
from superset.sql.parse import SQLStatement, Table
from superset.sql.rls_splice import (
_after_previous_token,
_classify_source_predicate,
_find_condition_end,
_find_join_splice,
_find_where_splice,
_scan_join_clause,
_scan_until_scope_boundary,
_splices_for_scope,
_table_end,
apply_rls_splice,
)
def _tokenize(sql: str) -> list[sqlglot.tokens.Token]:
return list(Dialect.get_or_raise(None).tokenize(sql))
def _token_index(tokens: list[sqlglot.tokens.Token], token_type: object) -> int:
return next(i for i, token in enumerate(tokens) if token.token_type == token_type)
def _token_by_text(
tokens: list[sqlglot.tokens.Token], text: str
) -> sqlglot.tokens.Token:
return next(token for token in tokens if token.text == text)
def test_split_source_returns_none_result_when_tokenize_fails(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _BrokenDialect:
@staticmethod
def tokenize(_: str) -> list[sqlglot.tokens.Token]:
raise sqlglot.errors.SqlglotError("boom")
monkeypatch.setattr(
"superset.sql.parse.Dialect.get_or_raise",
lambda _: _BrokenDialect(),
)
assert SQLStatement._split_source("SELECT 1", "postgresql", 2) == [None, None]
def test_apply_rls_splice_ignores_empty_predicates() -> None:
sql = "SELECT 1"
assert apply_rls_splice(sql, None, None, {Table("foo"): []}) == sql
def test_apply_rls_splice_ignores_dash_dash_inside_string_literal() -> None:
"""
Regression: the splice point must not be confused by ``--`` appearing
inside a string literal. Earlier ``rfind("--", ...)`` logic mistook this
for an inline comment and inserted the predicate inside the quoted text.
"""
sql = "SELECT * FROM some_table WHERE note = '--x' GROUP BY id"
spliced = apply_rls_splice(
sql,
None,
None,
{Table("some_table"): ["some_table.tenant_id = 42"]},
dialect="postgres",
)
assert spliced == (
"SELECT * FROM some_table WHERE some_table.tenant_id = 42 "
"AND (note = '--x') GROUP BY id"
)
def test_table_end_returns_none_without_metadata() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
assert _table_end(source) is None
def test_classify_source_predicate_returns_none_without_table_metadata() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
exp.From(this=source)
result = _classify_source_predicate(
source,
{Table("foo"): ["id = 1"]},
None,
None,
None,
)
assert result == ("none", None, None)
def test_classify_source_predicate_returns_none_for_unsupported_parent() -> None:
source = exp.Table(this=exp.Identifier(this="foo"))
source.this.meta["end"] = 3
exp.Alias(this=source, alias=exp.Identifier(this="alias"))
result = _classify_source_predicate(
source,
{Table("foo"): ["id = 1"]},
None,
None,
None,
)
assert result == ("none", None, None)
def test_after_previous_token_returns_zero_at_stream_start() -> None:
tokens = _tokenize("SELECT 1")
assert _after_previous_token(tokens, 0) == 0
def test_scan_until_scope_boundary_tracks_parenthesis_depth() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_token = _token_by_text(tokens, "WHERE")
assert _scan_until_scope_boundary(
tokens, where_token.start, stop_at_join=False
) == (
"eof",
None,
)
def test_find_condition_end_handles_subquery_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM t WHERE a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert sql[end] == ")"
def test_find_condition_end_handles_parenthesized_expression() -> None:
sql = "SELECT * FROM t WHERE (a = 1)"
tokens = _tokenize(sql)
where_index = _token_index(tokens, sqlglot.tokens.TokenType.WHERE)
end = _find_condition_end(tokens, where_index, stop_at_join=False)
assert end == len(sql)
def test_find_where_splice_handles_trailing_where_keyword() -> None:
sql = "SELECT * FROM t WHERE"
tokens = _tokenize(sql)
splices = _find_where_splice(sql, tokens, anchor=0, pred_sql="t.id = 1")
assert splices == [(len(sql), " t.id = 1")]
def test_find_join_splice_handles_trailing_on_keyword() -> None:
sql = "SELECT * FROM a JOIN b ON"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
assert splices == [(len(sql), " b.id = 1")]
def test_find_join_splice_inserts_on_before_where_boundary() -> None:
sql = "SELECT * FROM a JOIN b WHERE x = 1"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
splices = _find_join_splice(sql, tokens, b_token.end, "b.id = 1")
assert splices == [(sql.index("WHERE") - 1, " ON b.id = 1")]
def test_scan_join_clause_covers_nested_parentheses_and_join_boundary() -> None:
sql = "SELECT * FROM a JOIN b ON (a.id = b.id) JOIN c ON 1 = 1"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
on_index, boundary_index = _scan_join_clause(tokens, b_token.end)
assert on_index is not None
assert boundary_index is not None
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.JOIN
def test_scan_join_clause_stops_at_outer_closing_paren() -> None:
sql = "SELECT * FROM (SELECT * FROM a JOIN b) sub"
tokens = _tokenize(sql)
b_token = _token_by_text(tokens, "b")
_, boundary_index = _scan_join_clause(tokens, b_token.end)
assert boundary_index is not None
assert tokens[boundary_index].token_type == sqlglot.tokens.TokenType.R_PAREN
def test_splices_for_scope_handles_empty_join_splice_result(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"x": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate",
lambda *args, **kwargs: ("join", 0, "x.id = 1"),
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [],
)
assert (
_splices_for_scope(
sql,
tokens,
_Scope(),
{Table("x"): ["x.id = 1"]},
None,
None,
None,
)
== []
)
def test_splices_for_scope_combines_join_and_from_splices(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"f": object(), "j": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
calls = [("from", 3, "f.id = 1"), ("join", 6, "j.id = 2")]
def _fake_classify(*args: object, **kwargs: object) -> tuple[str, int, str]:
return calls.pop(0)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [(50, " ON j.id = 2")],
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_where_splice",
lambda *args, **kwargs: [(20, " WHERE f.id = 1")],
)
assert _splices_for_scope(
sql,
tokens,
_Scope(),
{Table("f"): ["id = 1"], Table("j"): ["id = 2"]},
None,
None,
None,
) == [(50, " ON j.id = 2"), (20, " WHERE f.id = 1")]
def test_splices_for_scope_join_then_next_source(
monkeypatch: pytest.MonkeyPatch,
) -> None:
class _Scope:
sources = {"j": object(), "f": object()}
sql = "SELECT 1"
tokens = _tokenize(sql)
calls = [("join", 3, "j.id = 2"), ("none", None, None)]
def _fake_classify(
*args: object, **kwargs: object
) -> tuple[str, int | None, str | None]:
return calls.pop(0)
monkeypatch.setattr(
"superset.sql.rls_splice._classify_source_predicate", _fake_classify
)
monkeypatch.setattr(
"superset.sql.rls_splice._find_join_splice",
lambda *args, **kwargs: [],
)
assert (
_splices_for_scope(
sql,
tokens,
_Scope(),
{Table("j"): ["id = 2"]},
None,
None,
None,
)
== []
)