diff --git a/superset/config.py b/superset/config.py index a33294ed655..7c1dde5e914 100644 --- a/superset/config.py +++ b/superset/config.py @@ -581,6 +581,10 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = { # Apply RLS rules to SQL Lab queries. This requires parsing and manipulating the # query, and might break queries and/or allow users to bypass RLS. Use with care! "RLS_IN_SQLLAB": False, + # Enable the new Data Access Rules system for table-level access control, + # row-level security (RLS), and column-level security (CLS). This replaces + # the FAB-based permission system with a more flexible JSON-based rule system. + "DATA_ACCESS_RULES": False, # Try to optimize SQL queries — for now only predicate pushdown is supported. "OPTIMIZE_SQL": False, # When impersonating a user, use the email prefix instead of the username diff --git a/superset/data_access_rules/__init__.py b/superset/data_access_rules/__init__.py new file mode 100644 index 00000000000..49f197356a5 --- /dev/null +++ b/superset/data_access_rules/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Data Access Rules module. 
+ +This module provides a new approach to data access control in Superset, +supporting: +- Table-level access control (allow/deny patterns) +- Row-level security (RLS) with predicates +- Column-level security (CLS) with masking/hiding options + +Unlike the FAB-based permission system, rules are stored as JSON documents +and can reference tables directly without requiring a priori permission creation. +""" diff --git a/superset/data_access_rules/api.py b/superset/data_access_rules/api.py new file mode 100644 index 00000000000..602b57d34ef --- /dev/null +++ b/superset/data_access_rules/api.py @@ -0,0 +1,145 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Data Access Rules REST API. + +This module provides the REST API for managing Data Access Rules, +including CRUD operations and a group_keys discovery endpoint. 
+""" + +import logging + +from flask import Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod +from superset.data_access_rules.models import DataAccessRule +from superset.data_access_rules.schemas import ( + DataAccessRuleListSchema, + DataAccessRulePostSchema, + DataAccessRulePutSchema, + DataAccessRuleShowSchema, +) +from superset.data_access_rules.utils import get_all_group_keys +from superset.extensions import event_logger +from superset.views.base_api import ( + BaseSupersetModelRestApi, + statsd_metrics, +) + +logger = logging.getLogger(__name__) + + +class DataAccessRulesRestApi(BaseSupersetModelRestApi): + """REST API for Data Access Rules.""" + + datamodel = SQLAInterface(DataAccessRule) + include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { + RouteMethod.RELATED, + "group_keys", + } + resource_name = "data_access_rule" + class_permission_name = "DataAccessRule" + openapi_spec_tag = "Data Access Rules" + method_permission_name = MODEL_API_RW_METHOD_PERMISSION_MAP + allow_browser_login = True + + list_columns = [ + "id", + "role_id", + "role.name", + "rule", + "changed_on_delta_humanized", + "changed_by.first_name", + "changed_by.last_name", + "changed_by.id", + ] + order_columns = [ + "id", + "role_id", + "changed_on_delta_humanized", + ] + add_columns = [ + "role_id", + "rule", + ] + edit_columns = [ + "role_id", + "rule", + ] + show_columns = [ + "id", + "role_id", + "role.name", + "rule", + "created_on", + "changed_on", + "created_by.first_name", + "created_by.last_name", + "changed_by.first_name", + "changed_by.last_name", + ] + + add_model_schema = DataAccessRulePostSchema() + edit_model_schema = DataAccessRulePutSchema() + list_model_schema = DataAccessRuleListSchema() + show_model_schema = DataAccessRuleShowSchema() + + openapi_spec_methods = { + "get": {"get": {"summary": "Get 
a data access rule"}}, + "get_list": {"get": {"summary": "Get a list of data access rules"}}, + "post": {"post": {"summary": "Create a data access rule"}}, + "put": {"put": {"summary": "Update a data access rule"}}, + "delete": {"delete": {"summary": "Delete a data access rule"}}, + } + + @expose("/group_keys/", methods=("GET",)) + @protect() + @safe + @statsd_metrics + @event_logger.log_this_with_context( + action=lambda self, *args, **kwargs: f"{self.__class__.__name__}.group_keys", + log_to_statsd=False, + ) + def group_keys(self) -> Response: + """ + Get all distinct group_keys used in RLS rules. + + This endpoint is useful for UI discoverability - showing users + what group_keys already exist so they can reuse them. + --- + get: + summary: Get all distinct RLS group keys + description: >- + Returns a list of all unique group_key values used in RLS rules + across all Data Access Rules. This helps users discover existing + keys for consistent rule grouping. + responses: + 200: + description: List of group keys + content: + application/json: + schema: + $ref: '#/components/schemas/GroupKeysResponseSchema' + 401: + $ref: '#/components/responses/401' + 500: + $ref: '#/components/responses/500' + """ + group_keys = get_all_group_keys() + return self.response(200, result=sorted(group_keys)) diff --git a/superset/data_access_rules/models.py b/superset/data_access_rules/models.py new file mode 100644 index 00000000000..1aa166f7c2a --- /dev/null +++ b/superset/data_access_rules/models.py @@ -0,0 +1,117 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Data Access Rules models. + +This module defines the DataAccessRule model for storing access rules +as JSON documents associated with roles. + +Example rule document structure: +{ + "allowed": [ + { + "database": "sales", + "schema": "orders", + "table": "ord_main" + }, + { + "database": "logs", + "catalog": "public" + }, + { + "database": "sales", + "schema": "orders", + "table": "prices", + "rls": { + "predicate": "org = 495", + "group_key": "org_filter" + } + }, + { + "database": "sales", + "schema": "orders", + "table": "user_info", + "cls": { + "name": "mask", + "age": "nullify", + "email": "hash", + "lastname": "hide" + } + } + ], + "denied": [ + { + "database": "logs", + "catalog": "public", + "schema": "pii" + } + ] +} +""" + +from __future__ import annotations + +from typing import Any + +from flask_appbuilder import Model +from sqlalchemy import Column, ForeignKey, Integer, Text +from sqlalchemy.orm import relationship + +from superset import security_manager +from superset.models.helpers import AuditMixinNullable + + +class DataAccessRule(Model, AuditMixinNullable): + """ + Data access rule associated with a role. + + Each rule is a JSON document that describes what databases, catalogs, + schemas, and tables a role can access, along with optional RLS predicates + and CLS column restrictions. 
+ """ + + __tablename__ = "data_access_rules" + + id = Column(Integer, primary_key=True) + role_id = Column(Integer, ForeignKey("ab_role.id"), nullable=False) + rule = Column(Text, nullable=False) + + role = relationship( + security_manager.role_model, + backref="data_access_rules", + foreign_keys=[role_id], + ) + + def __repr__(self) -> str: + return f"<DataAccessRule {self.id}>" + + @property + def rule_dict(self) -> dict[str, Any]: + """Parse the rule JSON string into a dictionary.""" + import json + + try: + return json.loads(self.rule) if self.rule else {} + except json.JSONDecodeError: + return {} + + @rule_dict.setter + def rule_dict(self, value: dict[str, Any]) -> None: + """Serialize a dictionary to JSON for storage.""" + import json + + self.rule = json.dumps(value) diff --git a/superset/data_access_rules/schemas.py b/superset/data_access_rules/schemas.py new file mode 100644 index 00000000000..8a87bde8c56 --- /dev/null +++ b/superset/data_access_rules/schemas.py @@ -0,0 +1,216 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Data Access Rules schemas for API serialization/deserialization.
+""" + +from marshmallow import fields, Schema, validates_schema, ValidationError + +from superset.dashboards.schemas import UserSchema + +# Field descriptions for OpenAPI documentation +rule_description = """ +A JSON document describing the access rule. The document should have two optional keys: +- `allowed`: List of entries describing what is allowed +- `denied`: List of entries describing what is denied + +Each entry can specify: +- `database` (required): The database name +- `catalog` (optional): The catalog name +- `schema` (optional): The schema name +- `table` (optional): The table name +- `rls` (optional): Row-level security config with `predicate` and optional `group_key` +- `cls` (optional): Column-level security config mapping column names to actions + +Example: +{ + "allowed": [ + {"database": "sales", "schema": "orders"}, + {"database": "sales", "schema": "orders", "table": "prices", + "rls": {"predicate": "org_id = 123", "group_key": "org"}}, + {"database": "sales", "schema": "users", "table": "info", + "cls": {"email": "mask", "ssn": "hide"}} + ], + "denied": [ + {"database": "sales", "schema": "internal"} + ] +} + +CLS actions: "hash", "nullify", "mask", "hide" +""" + + +class RoleSchema(Schema): + """Schema for role information.""" + + name = fields.String() + id = fields.Integer() + + +class DataAccessRuleListSchema(Schema): + """Schema for listing data access rules.""" + + id = fields.Integer(metadata={"description": "Unique ID of the rule"}) + role_id = fields.Integer(metadata={"description": "ID of the associated role"}) + role = fields.Nested(RoleSchema) + rule = fields.String(metadata={"description": rule_description}) + changed_on_delta_humanized = fields.String() + changed_by = fields.Nested(UserSchema(exclude=["username"])) + + +class DataAccessRuleShowSchema(Schema): + """Schema for showing a single data access rule.""" + + id = fields.Integer(metadata={"description": "Unique ID of the rule"}) + role_id = 
fields.Integer(metadata={"description": "ID of the associated role"}) + role = fields.Nested(RoleSchema) + rule = fields.String(metadata={"description": rule_description}) + created_on = fields.DateTime() + changed_on = fields.DateTime() + created_by = fields.Nested(UserSchema(exclude=["username"])) + changed_by = fields.Nested(UserSchema(exclude=["username"])) + + +class DataAccessRulePostSchema(Schema): + """Schema for creating a data access rule.""" + + role_id = fields.Integer( + metadata={"description": "ID of the role this rule applies to"}, + required=True, + allow_none=False, + ) + rule = fields.String( + metadata={"description": rule_description}, + required=True, + allow_none=False, + ) + + @validates_schema + def validate_rule_json(self, data: dict, **kwargs: dict) -> None: + """Validate that the rule field contains valid JSON.""" + import json + + if rule := data.get("rule"): + try: + parsed = json.loads(rule) + if not isinstance(parsed, dict): + raise ValidationError( + "Rule must be a JSON object", field_name="rule" + ) + + # Validate structure + allowed = parsed.get("allowed", []) + denied = parsed.get("denied", []) + + if not isinstance(allowed, list): + raise ValidationError("'allowed' must be a list", field_name="rule") + if not isinstance(denied, list): + raise ValidationError("'denied' must be a list", field_name="rule") + + # Validate entries + for entry in allowed + denied: + if not isinstance(entry, dict): + raise ValidationError( + "Each entry must be an object", field_name="rule" + ) + if "database" not in entry: + raise ValidationError( + "Each entry must have a 'database' field", + field_name="rule", + ) + + # Validate CLS actions if present + if cls_config := entry.get("cls"): + valid_actions = {"hash", "nullify", "mask", "hide"} + for col, action in cls_config.items(): + if action.lower() not in valid_actions: + raise ValidationError( + f"Invalid CLS action '{action}' for column '{col}'. 
" + f"Valid actions: {valid_actions}", + field_name="rule", + ) + except json.JSONDecodeError as ex: + raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex + + +class DataAccessRulePutSchema(Schema): + """Schema for updating a data access rule.""" + + role_id = fields.Integer( + metadata={"description": "ID of the role this rule applies to"}, + required=False, + allow_none=False, + ) + rule = fields.String( + metadata={"description": rule_description}, + required=False, + allow_none=False, + ) + + @validates_schema + def validate_rule_json(self, data: dict, **kwargs: dict) -> None: + """Validate that the rule field contains valid JSON if provided.""" + import json + + if rule := data.get("rule"): + try: + parsed = json.loads(rule) + if not isinstance(parsed, dict): + raise ValidationError( + "Rule must be a JSON object", field_name="rule" + ) + + # Same validation as POST schema + allowed = parsed.get("allowed", []) + denied = parsed.get("denied", []) + + if not isinstance(allowed, list): + raise ValidationError("'allowed' must be a list", field_name="rule") + if not isinstance(denied, list): + raise ValidationError("'denied' must be a list", field_name="rule") + + for entry in allowed + denied: + if not isinstance(entry, dict): + raise ValidationError( + "Each entry must be an object", field_name="rule" + ) + if "database" not in entry: + raise ValidationError( + "Each entry must have a 'database' field", + field_name="rule", + ) + + if cls_config := entry.get("cls"): + valid_actions = {"hash", "nullify", "mask", "hide"} + for col, action in cls_config.items(): + if action.lower() not in valid_actions: + raise ValidationError( + f"Invalid CLS action '{action}' for column '{col}'. 
" + f"Valid actions: {valid_actions}", + field_name="rule", + ) + except json.JSONDecodeError as ex: + raise ValidationError(f"Invalid JSON: {ex}", field_name="rule") from ex + + +class GroupKeysResponseSchema(Schema): + """Schema for the group_keys endpoint response.""" + + result = fields.List( + fields.String(), + metadata={"description": "List of unique group_key values used in RLS rules"}, + ) diff --git a/superset/data_access_rules/utils.py b/superset/data_access_rules/utils.py new file mode 100644 index 00000000000..6de09bb3aa3 --- /dev/null +++ b/superset/data_access_rules/utils.py @@ -0,0 +1,542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Data Access Rules utility functions. 
+ +This module provides functions for: +- Checking if a user has access to a table +- Collecting RLS predicates for a table +- Collecting CLS rules for a table +- Applying RLS and CLS to SQL queries +""" + +from __future__ import annotations + +import logging +from collections import defaultdict +from dataclasses import dataclass +from enum import Enum +from typing import Any, TYPE_CHECKING + +from flask import g + +from superset import db, is_feature_enabled, security_manager +from superset.sql.parse import CLSAction, Table + +if TYPE_CHECKING: + from superset.data_access_rules.models import DataAccessRule + from superset.models.core import Database + from superset.sql.parse import BaseSQLStatement + +logger = logging.getLogger(__name__) + + +class AccessCheckResult(Enum): + """Result of an access check.""" + + ALLOWED = "allowed" + DENIED = "denied" + NO_RULE = "no_rule" + + +@dataclass +class RLSPredicate: + """An RLS predicate with optional group_key.""" + + predicate: str + group_key: str | None = None + + +@dataclass +class TableAccessInfo: + """Information about access to a specific table.""" + + access: AccessCheckResult + rls_predicates: list[RLSPredicate] + cls_rules: dict[str, CLSAction] + + +def get_user_rules() -> list[DataAccessRule]: + """ + Get all data access rules for the current user's roles. + + Returns: + List of DataAccessRule objects for the current user's roles. 
+ """ + from superset.data_access_rules.models import DataAccessRule + + if not hasattr(g, "user") or not g.user: + return [] + + user_roles = security_manager.get_user_roles() + role_ids = [role.id for role in user_roles] + + if not role_ids: + return [] + + return ( + db.session.query(DataAccessRule) + .filter(DataAccessRule.role_id.in_(role_ids)) + .all() + ) + + +def _matches_rule_entry( + entry: dict[str, Any], + database_name: str, + catalog: str | None, + schema: str | None, + table_name: str | None, +) -> bool: + """ + Check if a rule entry matches the given database/catalog/schema/table. + + The rule entry can specify any level of the hierarchy: + - database only: matches all catalogs/schemas/tables in that database + - database + catalog: matches all schemas/tables in that catalog + - database + schema: matches all tables in that schema (for DBs without catalogs) + - database + catalog + schema: matches all tables in that schema + - database + schema + table: matches the specific table (for DBs without catalogs) + - database + catalog + schema + table: matches the specific table + + Args: + entry: The rule entry dict + database_name: The database name to check + catalog: The catalog to check (None if DB doesn't support catalogs) + schema: The schema to check + table_name: The table name to check + + Returns: + True if the entry matches, False otherwise. 
+ """ + # Database must always match + if entry.get("database") != database_name: + return False + + entry_catalog = entry.get("catalog") + entry_schema = entry.get("schema") + entry_table = entry.get("table") + + # If the entry specifies a catalog, it must match (or catalog must be None/default) + if entry_catalog is not None: + if catalog is not None and entry_catalog != catalog: + return False + + # If the entry specifies a schema, it must match + if entry_schema is not None: + if schema is not None and entry_schema != schema: + return False + + # If the entry specifies a table, it must match + if entry_table is not None: + if table_name is not None and entry_table != table_name: + return False + + # Check specificity: entry must be at least as specific as the query + # If querying a specific table, entry must specify that table or be broader + if table_name is not None and entry_table is not None and entry_table != table_name: + return False + + return True + + +def _is_more_specific(entry: dict[str, Any], other: dict[str, Any]) -> bool: + """ + Check if 'entry' is more specific than 'other'. + + More specific means it specifies more levels of the hierarchy. + """ + entry_specificity = sum( + [ + entry.get("catalog") is not None, + entry.get("schema") is not None, + entry.get("table") is not None, + ] + ) + other_specificity = sum( + [ + other.get("catalog") is not None, + other.get("schema") is not None, + other.get("table") is not None, + ] + ) + return entry_specificity > other_specificity + + +def check_table_access( + database_name: str, + table: Table, + rules: list[DataAccessRule] | None = None, +) -> TableAccessInfo: + """ + Check if the current user has access to a specific table. + + The function evaluates all rules for the user's roles and determines: + 1. Whether access is allowed, denied, or no rule applies + 2. Any RLS predicates that should be applied + 3. 
Any CLS rules for column masking/hiding + + Denied rules take precedence over allowed rules when at the same specificity level. + More specific rules take precedence over less specific rules. + + Args: + database_name: The database name + table: The Table object with catalog, schema, and table name + rules: Optional list of rules to check (defaults to current user's rules) + + Returns: + TableAccessInfo with access result, RLS predicates, and CLS rules. + """ + if rules is None: + rules = get_user_rules() + + if not rules: + return TableAccessInfo( + access=AccessCheckResult.NO_RULE, + rls_predicates=[], + cls_rules={}, + ) + + # Collect all matching rules + allowed_entries: list[dict[str, Any]] = [] + denied_entries: list[dict[str, Any]] = [] + + for rule in rules: + rule_dict = rule.rule_dict + + # Check allowed entries + for entry in rule_dict.get("allowed", []): + if _matches_rule_entry( + entry, database_name, table.catalog, table.schema, table.table + ): + allowed_entries.append(entry) + + # Check denied entries + for entry in rule_dict.get("denied", []): + if _matches_rule_entry( + entry, database_name, table.catalog, table.schema, table.table + ): + denied_entries.append(entry) + + # If no rules match, return NO_RULE + if not allowed_entries and not denied_entries: + return TableAccessInfo( + access=AccessCheckResult.NO_RULE, + rls_predicates=[], + cls_rules={}, + ) + + # Find the most specific denied entry + most_specific_denied = None + for entry in denied_entries: + if most_specific_denied is None or _is_more_specific( + entry, most_specific_denied + ): + most_specific_denied = entry + + # Find the most specific allowed entry + most_specific_allowed = None + for entry in allowed_entries: + if most_specific_allowed is None or _is_more_specific( + entry, most_specific_allowed + ): + most_specific_allowed = entry + + # Determine access: deny wins at same specificity, more specific wins otherwise + if most_specific_denied is not None and 
most_specific_allowed is not None: + if _is_more_specific(most_specific_denied, most_specific_allowed): + return TableAccessInfo( + access=AccessCheckResult.DENIED, + rls_predicates=[], + cls_rules={}, + ) + elif _is_more_specific(most_specific_allowed, most_specific_denied): + # Access allowed, collect RLS and CLS from matching entries + pass + else: + # Same specificity: denied wins + return TableAccessInfo( + access=AccessCheckResult.DENIED, + rls_predicates=[], + cls_rules={}, + ) + elif most_specific_denied is not None: + return TableAccessInfo( + access=AccessCheckResult.DENIED, + rls_predicates=[], + cls_rules={}, + ) + elif most_specific_allowed is None: + return TableAccessInfo( + access=AccessCheckResult.NO_RULE, + rls_predicates=[], + cls_rules={}, + ) + + # Collect RLS predicates from all matching allowed entries + # (RLS is cumulative - all predicates are applied) + rls_predicates: list[RLSPredicate] = [] + for entry in allowed_entries: + rls_config = entry.get("rls") + if rls_config and "predicate" in rls_config: + rls_predicates.append( + RLSPredicate( + predicate=rls_config["predicate"], + group_key=rls_config.get("group_key"), + ) + ) + + # Collect CLS rules from all matching allowed entries + # (CLS is cumulative - strictest action wins per column) + cls_rules: dict[str, CLSAction] = {} + cls_precedence = { + CLSAction.HIDE: 4, + CLSAction.NULLIFY: 3, + CLSAction.MASK: 2, + CLSAction.HASH: 1, + } + action_map = { + "hide": CLSAction.HIDE, + "nullify": CLSAction.NULLIFY, + "mask": CLSAction.MASK, + "hash": CLSAction.HASH, + } + + for entry in allowed_entries: + cls_config = entry.get("cls", {}) + for column, action_str in cls_config.items(): + action = action_map.get(action_str.lower()) + if action is None: + logger.warning("Unknown CLS action: %s", action_str) + continue + + existing = cls_rules.get(column) + if existing is None or cls_precedence[action] > cls_precedence[existing]: + cls_rules[column] = action + + return TableAccessInfo( + 
access=AccessCheckResult.ALLOWED, + rls_predicates=rls_predicates, + cls_rules=cls_rules, + ) + + +def get_rls_predicates_for_table( + table: Table, + database: Database, + rules: list[DataAccessRule] | None = None, +) -> list[str]: + """ + Get the RLS predicates for a table using the new Data Access Rules system. + + This function collects all RLS predicates from matching rules and combines them + using the group_key logic: + - Predicates without group_key are ANDed together + - Predicates with the same group_key are ORed together + - Groups are ANDed together + + Args: + table: The fully qualified Table object + database: The Database object + rules: Optional list of rules to check (defaults to current user's rules) + + Returns: + List of SQL predicate strings to be ANDed together. + """ + access_info = check_table_access( + database_name=database.database_name, + table=table, + rules=rules, + ) + + if access_info.access != AccessCheckResult.ALLOWED: + return [] + + if not access_info.rls_predicates: + return [] + + # Group predicates by group_key + ungrouped: list[str] = [] + groups: dict[str, list[str]] = defaultdict(list) + + for pred in access_info.rls_predicates: + if pred.group_key: + groups[pred.group_key].append(f"({pred.predicate})") + else: + ungrouped.append(f"({pred.predicate})") + + # Build result: ungrouped predicates + OR'd groups + result = ungrouped.copy() + for group_predicates in groups.values(): + if len(group_predicates) == 1: + result.append(group_predicates[0]) + else: + result.append(f"({' OR '.join(group_predicates)})") + + return result + + +def get_cls_rules_for_table( + table: Table, + database: Database, + rules: list[DataAccessRule] | None = None, +) -> dict[str, CLSAction]: + """ + Get the CLS rules for a table using the new Data Access Rules system. 
+ + Args: + table: The fully qualified Table object + database: The Database object + rules: Optional list of rules to check (defaults to current user's rules) + + Returns: + Dict mapping column names to CLSAction values. + """ + access_info = check_table_access( + database_name=database.database_name, + table=table, + rules=rules, + ) + + if access_info.access != AccessCheckResult.ALLOWED: + return {} + + return access_info.cls_rules + + +def apply_data_access_rules( + database: Database, + catalog: str | None, + schema: str, + parsed_statement: BaseSQLStatement[Any], +) -> None: + """ + Apply Data Access Rules (RLS and CLS) to a parsed SQL statement. + + This function: + 1. Checks if the DATA_ACCESS_RULES feature is enabled + 2. For each table in the query, checks access and collects RLS/CLS rules + 3. Applies RLS predicates using the existing infrastructure + 4. Applies CLS rules using the existing infrastructure + + Args: + database: The Database object + catalog: The default catalog for the query + schema: The default schema for the query + parsed_statement: The parsed SQL statement to modify in place + """ + if not is_feature_enabled("DATA_ACCESS_RULES"): + return + + from superset.sql.parse import CLSRules + + rules = get_user_rules() + if not rules: + return + + # Get the RLS method for this database + method = database.db_engine_spec.get_rls_method() + + # Collect RLS predicates and CLS rules for all tables + rls_predicates: dict[Table, list[Any]] = {} + cls_rules: CLSRules = {} + + for table in parsed_statement.tables: + qualified_table = table.qualify(catalog=catalog, schema=schema) + + # Check access first + access_info = check_table_access( + database_name=database.database_name, + table=qualified_table, + rules=rules, + ) + + if access_info.access == AccessCheckResult.DENIED: + # TODO: How should we handle denied access mid-query? + # For now, log a warning. In the future, we might raise an exception. 
+ logger.warning( + "Access denied to table %s for user %s", + qualified_table, + getattr(g, "user", "unknown"), + ) + continue + + # Collect RLS predicates + predicates = get_rls_predicates_for_table(qualified_table, database, rules) + if predicates: + rls_predicates[qualified_table] = [ + parsed_statement.parse_predicate(pred) for pred in predicates if pred + ] + + # Collect CLS rules + table_cls = get_cls_rules_for_table(qualified_table, database, rules) + if table_cls: + cls_rules[qualified_table] = table_cls + + # Apply RLS if we have predicates + if rls_predicates: + parsed_statement.apply_rls(catalog, schema, rls_predicates, method) + + # Apply CLS if we have rules + if cls_rules: + parsed_statement.apply_cls(cls_rules) + + +def get_all_group_keys( + database_name: str | None = None, + table: Table | None = None, +) -> set[str]: + """ + Get all distinct group_keys used in RLS rules. + + This is useful for UI discoverability - showing users what group_keys + already exist so they can reuse them for consistent rule grouping. + + Args: + database_name: Optional filter by database + table: Optional Table object to filter by catalog/schema/table + + Returns: + Set of unique group_key values. 
+ """ + from superset.data_access_rules.models import DataAccessRule + + query = db.session.query(DataAccessRule) + rules = query.all() + + group_keys: set[str] = set() + + for rule in rules: + rule_dict = rule.rule_dict + + for entry in rule_dict.get("allowed", []): + # Apply filters if specified + if database_name and entry.get("database") != database_name: + continue + if table is not None: + if table.catalog and entry.get("catalog") != table.catalog: + continue + if table.schema and entry.get("schema") != table.schema: + continue + if table.table and entry.get("table") != table.table: + continue + + rls_config = entry.get("rls", {}) + if group_key := rls_config.get("group_key"): + group_keys.add(group_key) + + return group_keys diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index 4f4fd361a49..b16e7f7713e 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -161,6 +161,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods from superset.dashboards.api import DashboardRestApi from superset.dashboards.filter_state.api import DashboardFilterStateRestApi from superset.dashboards.permalink.api import DashboardPermalinkRestApi + from superset.data_access_rules.api import DataAccessRulesRestApi from superset.databases.api import DatabaseRestApi from superset.datasets.api import DatasetRestApi from superset.datasets.columns.api import DatasetColumnsRestApi @@ -264,6 +265,7 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods appbuilder.add_api(ReportScheduleRestApi) appbuilder.add_api(ReportExecutionLogRestApi) appbuilder.add_api(RLSRestApi) + appbuilder.add_api(DataAccessRulesRestApi) appbuilder.add_api(SavedQueryRestApi) appbuilder.add_api(TagRestApi) appbuilder.add_api(SqlLabRestApi) diff --git a/superset/migrations/versions/2025-12-17_10-00_a352d7609189_add_data_access_rules_table.py 
b/superset/migrations/versions/2025-12-17_10-00_a352d7609189_add_data_access_rules_table.py new file mode 100644 index 00000000000..08790fb53cc --- /dev/null +++ b/superset/migrations/versions/2025-12-17_10-00_a352d7609189_add_data_access_rules_table.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""add_data_access_rules_table + +Revision ID: a352d7609189 +Revises: a9c01ec10479 +Create Date: 2025-12-17 10:00:00.000000 + +""" + +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +from superset.migrations.shared.utils import ( + create_fks_for_table, + create_table, + drop_table, +) + +# revision identifiers, used by Alembic. 
+revision = "a352d7609189" +down_revision = "a9c01ec10479" + + +def upgrade(): + create_table( + "data_access_rules", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("role_id", sa.Integer(), nullable=False), + sa.Column( + "rule", + sa.Text().with_variant(mysql.MEDIUMTEXT(), "mysql"), + nullable=False, + ), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("id"), + ) + + # Create foreign key constraints + create_fks_for_table( + "fk_data_access_rules_role_id_ab_role", + "data_access_rules", + "ab_role", + ["role_id"], + ["id"], + ondelete="CASCADE", + ) + + create_fks_for_table( + "fk_data_access_rules_created_by_fk_ab_user", + "data_access_rules", + "ab_user", + ["created_by_fk"], + ["id"], + ) + + create_fks_for_table( + "fk_data_access_rules_changed_by_fk_ab_user", + "data_access_rules", + "ab_user", + ["changed_by_fk"], + ["id"], + ) + + +def downgrade(): + drop_table("data_access_rules") diff --git a/tests/unit_tests/data_access_rules/__init__.py b/tests/unit_tests/data_access_rules/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/data_access_rules/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/data_access_rules/schemas_test.py b/tests/unit_tests/data_access_rules/schemas_test.py new file mode 100644 index 00000000000..1119d85e828 --- /dev/null +++ b/tests/unit_tests/data_access_rules/schemas_test.py @@ -0,0 +1,261 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Unit tests for Data Access Rules schemas. 
+""" + +import json + +import pytest +from marshmallow import ValidationError + +from superset.data_access_rules.schemas import ( + DataAccessRulePostSchema, + DataAccessRulePutSchema, +) + + +def test_post_schema_valid_rule(): + """Test that valid rule JSON is accepted.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps( + { + "allowed": [{"database": "mydb", "schema": "public"}], + "denied": [], + } + ), + } + result = schema.load(data) + assert result["role_id"] == 1 + assert "allowed" in json.loads(result["rule"]) + + +def test_post_schema_complex_rule(): + """Test that complex rule with RLS and CLS is accepted.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps( + { + "allowed": [ + {"database": "mydb", "schema": "public"}, + { + "database": "mydb", + "schema": "orders", + "table": "items", + "rls": {"predicate": "org_id = 123", "group_key": "org"}, + }, + { + "database": "mydb", + "schema": "users", + "table": "info", + "cls": {"email": "mask", "ssn": "hide", "name": "hash"}, + }, + ], + "denied": [{"database": "mydb", "schema": "internal"}], + } + ), + } + result = schema.load(data) + assert result["role_id"] == 1 + + +def test_post_schema_invalid_json(): + """Test that invalid JSON is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": "not valid json", + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "Invalid JSON" in str(exc_info.value) + + +def test_post_schema_rule_not_object(): + """Test that non-object rule is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps(["not", "an", "object"]), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "must be a JSON object" in str(exc_info.value) + + +def test_post_schema_allowed_not_list(): + """Test that non-list 'allowed' is rejected.""" + schema = DataAccessRulePostSchema() + data = { + 
"role_id": 1, + "rule": json.dumps({"allowed": "not a list", "denied": []}), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "'allowed' must be a list" in str(exc_info.value) + + +def test_post_schema_denied_not_list(): + """Test that non-list 'denied' is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps({"allowed": [], "denied": "not a list"}), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "'denied' must be a list" in str(exc_info.value) + + +def test_post_schema_entry_not_object(): + """Test that non-object entry is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps({"allowed": ["not an object"], "denied": []}), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "must be an object" in str(exc_info.value) + + +def test_post_schema_entry_missing_database(): + """Test that entry without 'database' is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps({"allowed": [{"schema": "public"}], "denied": []}), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "'database' field" in str(exc_info.value) + + +def test_post_schema_invalid_cls_action(): + """Test that invalid CLS action is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps( + { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "cls": {"email": "invalid_action"}, + } + ], + "denied": [], + } + ), + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "Invalid CLS action" in str(exc_info.value) + + +def test_post_schema_missing_role_id(): + """Test that missing role_id is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "rule": json.dumps({"allowed": [], "denied": []}), + } + with pytest.raises(ValidationError) as exc_info: + 
schema.load(data) + assert "role_id" in str(exc_info.value) + + +def test_post_schema_missing_rule(): + """Test that missing rule is rejected.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "rule" in str(exc_info.value) + + +def test_put_schema_partial_update(): + """Test that PUT schema allows partial updates.""" + schema = DataAccessRulePutSchema() + + # Only updating role_id + data = {"role_id": 2} + result = schema.load(data) + assert result == {"role_id": 2} + + # Only updating rule + data = {"rule": json.dumps({"allowed": [{"database": "newdb"}], "denied": []})} + result = schema.load(data) + assert "rule" in result + + +def test_put_schema_validates_rule_if_provided(): + """Test that PUT schema validates rule if provided.""" + schema = DataAccessRulePutSchema() + data = { + "rule": "invalid json", + } + with pytest.raises(ValidationError) as exc_info: + schema.load(data) + assert "Invalid JSON" in str(exc_info.value) + + +def test_post_schema_empty_allowed_denied(): + """Test that empty allowed and denied lists are valid.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps({"allowed": [], "denied": []}), + } + result = schema.load(data) + assert result["role_id"] == 1 + + +def test_post_schema_cls_all_valid_actions(): + """Test all valid CLS actions are accepted.""" + schema = DataAccessRulePostSchema() + data = { + "role_id": 1, + "rule": json.dumps( + { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "cls": { + "col1": "hash", + "col2": "HASH", # Case insensitive + "col3": "nullify", + "col4": "NULLIFY", + "col5": "mask", + "col6": "MASK", + "col7": "hide", + "col8": "HIDE", + }, + } + ], + "denied": [], + } + ), + } + result = schema.load(data) + assert result["role_id"] == 1 diff --git a/tests/unit_tests/data_access_rules/utils_test.py b/tests/unit_tests/data_access_rules/utils_test.py new file 
mode 100644 index 00000000000..686cb5c2bc0 --- /dev/null +++ b/tests/unit_tests/data_access_rules/utils_test.py @@ -0,0 +1,542 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Unit tests for Data Access Rules utility functions. +""" + +from unittest.mock import MagicMock, patch + +from superset.data_access_rules.models import DataAccessRule +from superset.data_access_rules.utils import ( + _is_more_specific, + _matches_rule_entry, + AccessCheckResult, + check_table_access, + get_all_group_keys, + get_cls_rules_for_table, + get_rls_predicates_for_table, +) +from superset.sql.parse import CLSAction, Table + + +# Tests for _matches_rule_entry +def test_matches_rule_entry_database_only(): + """Test matching when rule specifies only database.""" + entry = {"database": "mydb"} + assert _matches_rule_entry(entry, "mydb", None, None, None) is True + assert _matches_rule_entry(entry, "mydb", "catalog1", "schema1", "table1") is True + assert _matches_rule_entry(entry, "otherdb", None, None, None) is False + + +def test_matches_rule_entry_database_and_catalog(): + """Test matching when rule specifies database and catalog.""" + entry = {"database": "mydb", "catalog": "cat1"} + assert _matches_rule_entry(entry, "mydb", "cat1", None, 
None) is True + assert _matches_rule_entry(entry, "mydb", "cat1", "schema1", "table1") is True + assert _matches_rule_entry(entry, "mydb", "cat2", None, None) is False + assert _matches_rule_entry(entry, "otherdb", "cat1", None, None) is False + + +def test_matches_rule_entry_database_and_schema(): + """Test matching when rule specifies database and schema (no catalog).""" + entry = {"database": "mydb", "schema": "public"} + assert _matches_rule_entry(entry, "mydb", None, "public", None) is True + assert _matches_rule_entry(entry, "mydb", None, "public", "table1") is True + assert _matches_rule_entry(entry, "mydb", None, "other", None) is False + + +def test_matches_rule_entry_full_table(): + """Test matching when rule specifies full table path.""" + entry = {"database": "mydb", "schema": "public", "table": "users"} + assert _matches_rule_entry(entry, "mydb", None, "public", "users") is True + assert _matches_rule_entry(entry, "mydb", None, "public", "orders") is False + assert _matches_rule_entry(entry, "mydb", None, "other", "users") is False + + +def test_matches_rule_entry_with_catalog(): + """Test matching with catalog in the path.""" + entry = { + "database": "mydb", + "catalog": "main", + "schema": "public", + "table": "users", + } + assert _matches_rule_entry(entry, "mydb", "main", "public", "users") is True + assert _matches_rule_entry(entry, "mydb", "other", "public", "users") is False + + +# Tests for _is_more_specific +def test_is_more_specific(): + """Test specificity comparison between entries.""" + db_only = {"database": "mydb"} + db_schema = {"database": "mydb", "schema": "public"} + db_table = {"database": "mydb", "schema": "public", "table": "users"} + db_catalog = {"database": "mydb", "catalog": "main"} + db_catalog_schema = {"database": "mydb", "catalog": "main", "schema": "public"} + + # More specific should win + assert _is_more_specific(db_schema, db_only) is True + assert _is_more_specific(db_table, db_schema) is True + assert 
_is_more_specific(db_table, db_only) is True + + # Less specific should lose + assert _is_more_specific(db_only, db_schema) is False + assert _is_more_specific(db_schema, db_table) is False + + # Same specificity + assert _is_more_specific(db_schema, db_catalog) is False + assert _is_more_specific(db_catalog, db_schema) is False + + +# Tests for check_table_access +def test_check_table_access_no_rules(): + """Test access check when no rules are provided.""" + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[]) + assert result.access == AccessCheckResult.NO_RULE + assert result.rls_predicates == [] + assert result.cls_rules == {} + + +def test_check_table_access_allowed(): + """Test access check when table is allowed.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb", "schema": "public"}], + "denied": [], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule]) + assert result.access == AccessCheckResult.ALLOWED + assert result.rls_predicates == [] + assert result.cls_rules == {} + + +def test_check_table_access_denied(): + """Test access check when table is denied.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [], + "denied": [{"database": "mydb", "schema": "public"}], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule]) + assert result.access == AccessCheckResult.DENIED + + +def test_check_table_access_denied_more_specific(): + """Test that more specific deny wins over less specific allow.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb"}], # Less specific + "denied": [{"database": "mydb", "schema": "secret"}], # More specific + } + + # Table in non-denied schema should be allowed + table_public = Table(table="users", schema="public", 
catalog=None) + result = check_table_access("mydb", table_public, rules=[rule]) + assert result.access == AccessCheckResult.ALLOWED + + # Table in denied schema should be denied + table_secret = Table(table="data", schema="secret", catalog=None) + result = check_table_access("mydb", table_secret, rules=[rule]) + assert result.access == AccessCheckResult.DENIED + + +def test_check_table_access_allowed_more_specific(): + """Test that more specific allow wins over less specific deny.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb", "schema": "public", "table": "users"}], + "denied": [{"database": "mydb", "schema": "public"}], + } + + # The specific table is allowed despite schema being denied + table_users = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table_users, rules=[rule]) + assert result.access == AccessCheckResult.ALLOWED + + # Other tables in the schema are still denied + table_orders = Table(table="orders", schema="public", catalog=None) + result = check_table_access("mydb", table_orders, rules=[rule]) + assert result.access == AccessCheckResult.DENIED + + +def test_check_table_access_same_specificity_deny_wins(): + """Test that deny wins when rules have same specificity.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb", "schema": "public"}], + "denied": [{"database": "mydb", "schema": "public"}], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule]) + assert result.access == AccessCheckResult.DENIED + + +def test_check_table_access_with_rls(): + """Test access check collects RLS predicates.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "table": "users", + "rls": {"predicate": "org_id = 123", "group_key": "org"}, + } + ], + "denied": [], + } + + table = 
Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule]) + assert result.access == AccessCheckResult.ALLOWED + assert len(result.rls_predicates) == 1 + assert result.rls_predicates[0].predicate == "org_id = 123" + assert result.rls_predicates[0].group_key == "org" + + +def test_check_table_access_multiple_rls(): + """Test access check collects multiple RLS predicates from different rules.""" + rule1 = MagicMock(spec=DataAccessRule) + rule1.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 123", "group_key": "org"}, + } + ], + "denied": [], + } + + rule2 = MagicMock(spec=DataAccessRule) + rule2.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "region = 'US'"}, + } + ], + "denied": [], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule1, rule2]) + assert result.access == AccessCheckResult.ALLOWED + assert len(result.rls_predicates) == 2 + + predicates = [p.predicate for p in result.rls_predicates] + assert "org_id = 123" in predicates + assert "region = 'US'" in predicates + + +def test_check_table_access_with_cls(): + """Test access check collects CLS rules.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "table": "users", + "cls": {"email": "mask", "ssn": "hide", "name": "hash"}, + } + ], + "denied": [], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule]) + assert result.access == AccessCheckResult.ALLOWED + assert result.cls_rules == { + "email": CLSAction.MASK, + "ssn": CLSAction.HIDE, + "name": CLSAction.HASH, + } + + +def test_check_table_access_cls_strictest_wins(): + """Test that strictest CLS action wins when multiple rules apply.""" + rule1 = 
MagicMock(spec=DataAccessRule) + rule1.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "cls": {"email": "mask"}, # Less strict + } + ], + "denied": [], + } + + rule2 = MagicMock(spec=DataAccessRule) + rule2.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "cls": {"email": "hide"}, # More strict - should win + } + ], + "denied": [], + } + + table = Table(table="users", schema="public", catalog=None) + result = check_table_access("mydb", table, rules=[rule1, rule2]) + assert result.access == AccessCheckResult.ALLOWED + assert result.cls_rules["email"] == CLSAction.HIDE + + +# Tests for get_rls_predicates_for_table +def test_get_rls_predicates_for_table_no_predicates(): + """Test getting RLS predicates when there are none.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb", "schema": "public"}], + "denied": [], + } + + predicates = get_rls_predicates_for_table(table, database, rules=[rule]) + assert predicates == [] + + +def test_get_rls_predicates_for_table_with_predicates(): + """Test getting RLS predicates.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 123"}, + } + ], + "denied": [], + } + + predicates = get_rls_predicates_for_table(table, database, rules=[rule]) + assert predicates == ["(org_id = 123)"] + + +def test_get_rls_predicates_for_table_with_group_key(): + """Test getting RLS predicates with group_key combines with OR.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule1 = MagicMock(spec=DataAccessRule) + 
rule1.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 1", "group_key": "org"}, + } + ], + "denied": [], + } + + rule2 = MagicMock(spec=DataAccessRule) + rule2.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 2", "group_key": "org"}, + } + ], + "denied": [], + } + + predicates = get_rls_predicates_for_table(table, database, rules=[rule1, rule2]) + # Same group_key predicates should be ORed + assert len(predicates) == 1 + assert "(org_id = 1)" in predicates[0] + assert "(org_id = 2)" in predicates[0] + assert " OR " in predicates[0] + + +def test_get_rls_predicates_for_table_mixed_group_keys(): + """Test getting RLS predicates with mixed group_keys.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 1", "group_key": "org"}, + }, + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "org_id = 2", "group_key": "org"}, + }, + { + "database": "mydb", + "schema": "public", + "rls": {"predicate": "region = 'US'"}, # No group_key + }, + ], + "denied": [], + } + + predicates = get_rls_predicates_for_table(table, database, rules=[rule]) + # Should have: ungrouped predicate + ORed group predicate = 2 items + assert len(predicates) == 2 + + has_region = any("region = 'US'" in p for p in predicates) + has_org_group = any("org_id = 1" in p and "org_id = 2" in p for p in predicates) + assert has_region + assert has_org_group + + +# Tests for get_cls_rules_for_table +def test_get_cls_rules_for_table_no_rules(): + """Test getting CLS rules when there are none.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule = 
MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [{"database": "mydb", "schema": "public"}], + "denied": [], + } + + cls_rules = get_cls_rules_for_table(table, database, rules=[rule]) + assert cls_rules == {} + + +def test_get_cls_rules_for_table_with_rules(): + """Test getting CLS rules.""" + database = MagicMock() + database.database_name = "mydb" + table = Table(table="users", schema="public", catalog=None) + + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "mydb", + "schema": "public", + "table": "users", + "cls": {"email": "mask", "ssn": "hide"}, + } + ], + "denied": [], + } + + cls_rules = get_cls_rules_for_table(table, database, rules=[rule]) + assert cls_rules == {"email": CLSAction.MASK, "ssn": CLSAction.HIDE} + + +# Tests for get_all_group_keys +def test_get_all_group_keys_empty(app_context: None): + """Test getting group keys when none exist.""" + with patch("superset.data_access_rules.utils.db") as mock_db: + mock_db.session.query.return_value.all.return_value = [] + keys = get_all_group_keys() + assert keys == set() + + +def test_get_all_group_keys_with_keys(app_context: None): + """Test getting group keys from rules.""" + rule1 = MagicMock(spec=DataAccessRule) + rule1.rule_dict = { + "allowed": [ + {"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}}, + {"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}}, + ], + "denied": [], + } + + rule2 = MagicMock(spec=DataAccessRule) + rule2.rule_dict = { + "allowed": [ + {"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}}, + {"database": "mydb", "rls": {"predicate": "x=4"}}, # No group_key + ], + "denied": [], + } + + with patch("superset.data_access_rules.utils.db") as mock_db: + mock_db.session.query.return_value.all.return_value = [rule1, rule2] + keys = get_all_group_keys() + assert keys == {"key1", "key2"} + + +def test_get_all_group_keys_filtered_by_database(app_context: None): + """Test 
getting group keys filtered by database.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + {"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}}, + {"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}}, + ], + "denied": [], + } + + with patch("superset.data_access_rules.utils.db") as mock_db: + mock_db.session.query.return_value.all.return_value = [rule] + keys = get_all_group_keys(database_name="db1") + assert keys == {"key1"} + + +def test_get_all_group_keys_filtered_by_table(app_context: None): + """Test getting group keys filtered by table.""" + rule = MagicMock(spec=DataAccessRule) + rule.rule_dict = { + "allowed": [ + { + "database": "db1", + "schema": "public", + "table": "users", + "rls": {"predicate": "x=1", "group_key": "key1"}, + }, + { + "database": "db1", + "schema": "public", + "table": "orders", + "rls": {"predicate": "x=2", "group_key": "key2"}, + }, + ], + "denied": [], + } + + with patch("superset.data_access_rules.utils.db") as mock_db: + mock_db.session.query.return_value.all.return_value = [rule] + table = Table(table="users", schema="public", catalog=None) + keys = get_all_group_keys(database_name="db1", table=table) + assert keys == {"key1"}