Files
superset2/superset/data_access_rules/utils.py
Beto Dealmeida ca7635dfc2 UI progress
2025-12-17 19:34:52 -05:00

897 lines
28 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Data Access Rules utility functions.
This module provides functions for:
- Checking if a user has access to a table
- Collecting RLS predicates for a table
- Collecting CLS rules for a table
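- Applying RLS and CLS to SQL queries
- Listing the databases, schemas, and tables the current user is allowed to access
- Listing the RLS group keys defined across all rules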
"""
from __future__ import annotations
import logging
from collections import defaultdict
from dataclasses import dataclass
from enum import Enum
from typing import Any, TYPE_CHECKING
from flask import g
from superset import db, is_feature_enabled, security_manager
from superset.sql.parse import CLSAction, Table
if TYPE_CHECKING:
from superset.data_access_rules.models import DataAccessRule
from superset.models.core import Database
from superset.sql.parse import BaseSQLStatement
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class AccessCheckResult(Enum):
    """
    Result of an access check performed by ``check_table_access``.
    """

    # A matching allowed entry grants access and no matching denied entry
    # is at least as specific as it.
    ALLOWED = "allowed"
    # A matching denied entry is at least as specific as every matching
    # allowed entry (deny wins ties).
    DENIED = "denied"
    # The user has no rules, or no rule entry matches the table.
    NO_RULE = "no_rule"
@dataclass
class RLSPredicate:
    """
    An RLS predicate with an optional group_key.

    When predicates are combined by ``get_rls_predicates_for_table``, those
    without a group_key are ANDed together, those sharing a group_key are
    ORed together, and the resulting groups are ANDed.
    """

    # SQL filter expression, taken from the rule entry's ``rls.predicate``.
    predicate: str
    # Optional key used to OR this predicate with others sharing the same key.
    group_key: str | None = None
@dataclass
class TableAccessInfo:
    """
    Information about access to a specific table.

    Produced by ``check_table_access``; the RLS predicates and CLS rules are
    only populated when ``access`` is ``AccessCheckResult.ALLOWED``.
    """

    # Outcome of evaluating the allowed/denied rule entries for the table.
    access: AccessCheckResult
    # RLS predicates collected from every matching allowed entry.
    rls_predicates: list[RLSPredicate]
    # Column name -> strictest CLS action configured for that column.
    cls_rules: dict[str, CLSAction]
def get_user_rules() -> list[DataAccessRule]:
    """
    Load every data access rule attached to one of the current user's roles.

    Returns:
        The matching DataAccessRule rows, or an empty list when there is no
        logged-in user or the user has no roles.
    """
    from superset.data_access_rules.models import DataAccessRule

    # No authenticated user on the request context -> no rules apply.
    if not getattr(g, "user", None):
        return []

    role_ids = [role.id for role in security_manager.get_user_roles()]
    if not role_ids:
        return []

    rules_query = db.session.query(DataAccessRule).filter(
        DataAccessRule.role_id.in_(role_ids)
    )
    return rules_query.all()
def _matches_rule_entry(
entry: dict[str, Any],
database_name: str,
catalog: str | None,
schema: str | None,
table_name: str | None,
) -> bool:
"""
Check if a rule entry matches the given database/catalog/schema/table.
The rule entry can specify any level of the hierarchy:
- database only: matches all catalogs/schemas/tables in that database
- database + catalog: matches all schemas/tables in that catalog
- database + schema: matches all tables in that schema (for DBs without catalogs)
- database + catalog + schema: matches all tables in that schema
- database + schema + table: matches the specific table (for DBs without catalogs)
- database + catalog + schema + table: matches the specific table
Args:
entry: The rule entry dict
database_name: The database name to check
catalog: The catalog to check (None if DB doesn't support catalogs)
schema: The schema to check
table_name: The table name to check
Returns:
True if the entry matches, False otherwise.
"""
# Database must always match
if entry.get("database") != database_name:
return False
entry_catalog = entry.get("catalog")
entry_schema = entry.get("schema")
entry_table = entry.get("table")
# If the entry specifies a catalog, it must match (or catalog must be None/default)
if entry_catalog is not None:
if catalog is not None and entry_catalog != catalog:
return False
# If the entry specifies a schema, it must match
if entry_schema is not None:
if schema is not None and entry_schema != schema:
return False
# If the entry specifies a table, it must match
if entry_table is not None:
if table_name is not None and entry_table != table_name:
return False
# Check specificity: entry must be at least as specific as the query
# If querying a specific table, entry must specify that table or be broader
if table_name is not None and entry_table is not None and entry_table != table_name:
return False
return True
def _is_more_specific(entry: dict[str, Any], other: dict[str, Any]) -> bool:
"""
Check if 'entry' is more specific than 'other'.
More specific means it specifies more levels of the hierarchy.
"""
entry_specificity = sum(
[
entry.get("catalog") is not None,
entry.get("schema") is not None,
entry.get("table") is not None,
]
)
other_specificity = sum(
[
other.get("catalog") is not None,
other.get("schema") is not None,
other.get("table") is not None,
]
)
return entry_specificity > other_specificity
def _most_specific_entry(
    entries: list[dict[str, Any]],
) -> dict[str, Any] | None:
    """
    Return the most specific of ``entries`` (the first one wins on ties).

    Args:
        entries: Rule entries that matched the table being checked.

    Returns:
        The entry specifying the most hierarchy levels, or None if empty.
    """
    best: dict[str, Any] | None = None
    for entry in entries:
        if best is None or _is_more_specific(entry, best):
            best = entry
    return best


def _collect_rls_predicates(entries: list[dict[str, Any]]) -> list[RLSPredicate]:
    """
    Gather the RLS predicates configured on the given allowed entries.

    RLS is cumulative: every entry with an ``rls.predicate`` contributes one.
    """
    predicates: list[RLSPredicate] = []
    for entry in entries:
        rls_config = entry.get("rls")
        if rls_config and "predicate" in rls_config:
            predicates.append(
                RLSPredicate(
                    predicate=rls_config["predicate"],
                    group_key=rls_config.get("group_key"),
                )
            )
    return predicates


def _collect_cls_rules(entries: list[dict[str, Any]]) -> dict[str, CLSAction]:
    """
    Merge the CLS settings of the given allowed entries into one mapping.

    CLS is cumulative: when several entries configure the same column, the
    strictest action wins (hide > nullify > mask > hash). Unknown or
    non-string actions are logged and ignored.
    """
    cls_precedence = {
        CLSAction.HIDE: 4,
        CLSAction.NULLIFY: 3,
        CLSAction.MASK: 2,
        CLSAction.HASH: 1,
    }
    action_map = {
        "hide": CLSAction.HIDE,
        "nullify": CLSAction.NULLIFY,
        "mask": CLSAction.MASK,
        "hash": CLSAction.HASH,
    }

    cls_rules: dict[str, CLSAction] = {}
    for entry in entries:
        # "cls" may be absent or explicitly null in the rule JSON.
        cls_config = entry.get("cls") or {}
        for column, action_str in cls_config.items():
            action = (
                action_map.get(action_str.lower())
                if isinstance(action_str, str)
                else None
            )
            if action is None:
                logger.warning("Unknown CLS action: %s", action_str)
                continue
            existing = cls_rules.get(column)
            if existing is None or cls_precedence[action] > cls_precedence[existing]:
                cls_rules[column] = action
    return cls_rules


def check_table_access(
    database_name: str,
    table: Table,
    rules: list[DataAccessRule] | None = None,
) -> TableAccessInfo:
    """
    Check if the current user has access to a specific table.

    All ``allowed`` and ``denied`` entries matching the table are collected
    across the given rules, then:

    1. The most specific matching deny is compared with the most specific
       matching allow. The deny wins unless the allow is strictly more
       specific, so at equal specificity deny takes precedence.
    2. If access is allowed, RLS predicates are accumulated from every
       matching allowed entry.
    3. CLS rules are accumulated from every matching allowed entry, keeping
       the strictest action per column.

    Args:
        database_name: The database name
        table: The Table object with catalog, schema, and table name
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        TableAccessInfo with access result, RLS predicates, and CLS rules.
        The predicates and CLS rules are empty unless access is ALLOWED.
    """
    if rules is None:
        rules = get_user_rules()

    if not rules:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    # Collect all matching entries, split by kind.
    allowed_entries: list[dict[str, Any]] = []
    denied_entries: list[dict[str, Any]] = []
    for rule in rules:
        rule_dict = rule.rule_dict
        for entry in rule_dict.get("allowed", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                allowed_entries.append(entry)
        for entry in rule_dict.get("denied", []):
            if _matches_rule_entry(
                entry, database_name, table.catalog, table.schema, table.table
            ):
                denied_entries.append(entry)

    if not allowed_entries and not denied_entries:
        return TableAccessInfo(
            access=AccessCheckResult.NO_RULE,
            rls_predicates=[],
            cls_rules={},
        )

    most_specific_denied = _most_specific_entry(denied_entries)
    most_specific_allowed = _most_specific_entry(allowed_entries)

    # A deny applies unless an allow is strictly more specific than it;
    # at equal specificity the deny wins.
    if most_specific_denied is not None and (
        most_specific_allowed is None
        or not _is_more_specific(most_specific_allowed, most_specific_denied)
    ):
        return TableAccessInfo(
            access=AccessCheckResult.DENIED,
            rls_predicates=[],
            cls_rules={},
        )

    return TableAccessInfo(
        access=AccessCheckResult.ALLOWED,
        rls_predicates=_collect_rls_predicates(allowed_entries),
        cls_rules=_collect_cls_rules(allowed_entries),
    )
def get_rls_predicates_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> list[str]:
    """
    Build the RLS filter strings for a table from the Data Access Rules.

    Each predicate is wrapped in parentheses, then combined by group_key:
    - predicates without a group_key are returned individually (to be ANDed)
    - predicates sharing a group_key are ORed into a single string
    - a group with a single member is returned as that member alone

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        List of SQL predicate strings to be ANDed together; empty when access
        is not ALLOWED or no predicates apply.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    if info.access != AccessCheckResult.ALLOWED or not info.rls_predicates:
        return []

    standalone: list[str] = []
    grouped: dict[str, list[str]] = defaultdict(list)
    for rls in info.rls_predicates:
        wrapped = f"({rls.predicate})"
        if rls.group_key:
            grouped[rls.group_key].append(wrapped)
        else:
            standalone.append(wrapped)

    # Groups keep their first-seen order and follow the ungrouped predicates.
    combined = [
        members[0] if len(members) == 1 else f"({' OR '.join(members)})"
        for members in grouped.values()
    ]
    return standalone + combined
def get_cls_rules_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> dict[str, CLSAction]:
    """
    Look up the column-level security rules that apply to a table.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Mapping of column name to CLSAction; empty unless access is ALLOWED.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    return info.cls_rules if info.access == AccessCheckResult.ALLOWED else {}
def get_hidden_columns_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> set[str]:
    """
    Return the names of the table's columns whose CLS action is "hide".

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Set of column names that should be hidden.
    """
    column_actions = get_cls_rules_for_table(table, database, rules)
    return {
        column_name
        for column_name, action in column_actions.items()
        if action == CLSAction.HIDE
    }
def filter_columns_by_cls(
    columns: list[dict[str, Any]],
    table: Table,
    database: Database,
    column_name_key: str = "column_name",
) -> list[dict[str, Any]]:
    """
    Drop the columns hidden by CLS from a list of column metadata dicts.

    Intended for column metadata coming from database reflection or the
    dataset APIs. When the DATA_ACCESS_RULES feature is off, or no column is
    hidden, the input list itself is returned unchanged.

    Args:
        columns: List of column dictionaries
        table: The fully qualified Table object
        database: The Database object
        column_name_key: The key in the column dict that contains the column name

    Returns:
        Filtered list of columns with hidden columns removed.
    """
    if is_feature_enabled("DATA_ACCESS_RULES"):
        hidden = get_hidden_columns_for_table(table, database)
        if hidden:
            return [
                column
                for column in columns
                if column.get(column_name_key) not in hidden
            ]
    return columns
def _build_cls_table_schemas(
    database: Database,
    tables: list[Table],
) -> dict[str, Any]:
    """
    Build the nested column-type mapping sqlglot needs to expand ``SELECT *``.

    The mapping is ``{catalog: {schema: {table: {column: type}}}}``, with the
    catalog and/or schema levels left out for tables that do not have them.
    Tables whose columns cannot be fetched are logged and skipped, so CLS is
    still applied with whatever schema information is available.

    Args:
        database: The Database object used to fetch column metadata
        tables: The tables that have CLS rules

    Returns:
        The nested mapping (possibly empty).
    """
    table_schemas: dict[str, Any] = {}
    for table in tables:
        try:
            columns = database.get_columns(table)
            col_types = {
                col["column_name"]: str(col.get("type", "VARCHAR"))
                for col in columns
            }
            # Walk/create the nesting levels this table actually has.
            if table.catalog:
                container = table_schemas.setdefault(table.catalog, {})
                if table.schema:
                    container = container.setdefault(table.schema, {})
            elif table.schema:
                container = table_schemas.setdefault(table.schema, {})
            else:
                container = table_schemas
            container[table.table] = col_types
        except Exception as ex:  # best effort: missing schema only weakens * expansion
            logger.warning(
                "Could not fetch schema for table %s: %s",
                table,
                ex,
            )
    return table_schemas


def apply_data_access_rules(
    database: Database,
    catalog: str | None,
    schema: str,
    parsed_statement: BaseSQLStatement[Any],
) -> None:
    """
    Apply Data Access Rules (RLS and CLS) to a parsed SQL statement.

    This function:
    1. Returns immediately unless the DATA_ACCESS_RULES feature is enabled
       and the current user has rules
    2. For each table in the query, checks access and collects RLS/CLS rules
    3. Applies CLS first, so hidden columns are removed before RLS wraps the
       query in a subquery
    4. Applies RLS predicates using the engine's RLS method

    Args:
        database: The Database object
        catalog: The default catalog for the query
        schema: The default schema for the query
        parsed_statement: The parsed SQL statement to modify in place
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return

    from superset.sql.parse import CLSRules

    rules = get_user_rules()
    if not rules:
        return

    # Get the RLS method for this database
    method = database.db_engine_spec.get_rls_method()

    # Collect RLS predicates and CLS rules for all tables
    rls_predicates: dict[Table, list[Any]] = {}
    cls_rules: CLSRules = {}
    for table in parsed_statement.tables:
        qualified_table = table.qualify(catalog=catalog, schema=schema)
        access_info = check_table_access(
            database_name=database.database_name,
            table=qualified_table,
            rules=rules,
        )

        if access_info.access == AccessCheckResult.DENIED:
            # TODO: How should we handle denied access mid-query?
            # For now, log a warning. In the future, we might raise an exception.
            # NOTE: this function does not raise, so unless the caller
            # enforces access elsewhere, this table gets no DAR RLS/CLS.
            logger.warning(
                "Access denied to table %s for user %s",
                qualified_table,
                getattr(g, "user", "unknown"),
            )
            continue

        if access_info.access != AccessCheckResult.ALLOWED:
            # NO_RULE: nothing to apply for this table.
            continue

        predicates = get_rls_predicates_for_table(qualified_table, database, rules)
        if predicates:
            rls_predicates[qualified_table] = [
                parsed_statement.parse_predicate(pred) for pred in predicates if pred
            ]

        # The access check already computed the CLS rules; reuse them rather
        # than evaluating all rules again.
        if access_info.cls_rules:
            cls_rules[qualified_table] = access_info.cls_rules

    # Apply CLS first (before RLS) so that hidden columns are removed
    # before RLS wraps the query in a subquery
    if cls_rules:
        table_schemas = _build_cls_table_schemas(database, list(cls_rules))
        parsed_statement.apply_cls(cls_rules, schema=table_schemas or None)

    # Apply RLS after CLS - RLS wraps the query in a subquery with SELECT *
    # which will pick up the already-transformed columns from CLS
    if rls_predicates:
        parsed_statement.apply_rls(catalog, schema, rls_predicates, method)
def get_allowed_tables(
    database_name: str,
    schema: str | None = None,
    catalog: str | None = None,
) -> tuple[set[str], bool]:
    """
    Get all table names that the current user has access to via Data Access Rules
    for a specific database and schema.

    Only ``allowed`` entries are considered. An entry that names a table adds
    that table. An entry without a table grants every table when it is
    schema-level for the requested schema, or database/catalog-level (no
    schema specified). This mirrors ``get_allowed_schemas``, which treats
    schema-less entries as granting all schemas.

    Args:
        database_name: The database name to check
        schema: Optional schema name to filter by
        catalog: Optional catalog name to filter by

    Returns:
        Tuple of (set of table names, bool indicating if schema-level access is granted).
        If schema-level access is granted, the set may be empty but all tables are allowed.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set(), False

    rules = get_user_rules()
    if not rules:
        return set(), False

    table_names: set[str] = set()
    schema_level_access = False

    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # If both sides specify a catalog, they must match
            entry_catalog = entry.get("catalog")
            if (
                catalog is not None
                and entry_catalog is not None
                and entry_catalog != catalog
            ):
                continue

            # If both sides specify a schema, they must match
            entry_schema = entry.get("schema")
            if schema is not None and entry_schema is not None and entry_schema != schema:
                continue

            if table := entry.get("table"):
                table_names.add(table)
            elif entry_schema is None or entry_schema == schema:
                # Database/catalog-level access (no schema) or schema-level
                # access to the requested schema covers every table in it.
                schema_level_access = True

    return table_names, schema_level_access
def get_allowed_schemas(database_name: str, catalog: str | None = None) -> set[str]:
    """
    Get all schema names that the current user has access to via Data Access Rules
    for a specific database.

    Only ``allowed`` entries are considered. An entry that grants access to
    the whole database (or catalog) without naming a schema is reported as
    the wildcard marker ``"*"``; callers should treat its presence as "all
    schemas are allowed".

    Args:
        database_name: The database name to check
        catalog: Optional catalog name to filter by

    Returns:
        Set of schema names the user has access to, possibly including "*".
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    rules = get_user_rules()
    if not rules:
        return set()

    schema_names: set[str] = set()

    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # If both sides specify a catalog, they must match
            entry_catalog = entry.get("catalog")
            if (
                catalog is not None
                and entry_catalog is not None
                and entry_catalog != catalog
            ):
                continue

            # A schema-less entry grants every schema: record the "*" marker.
            schema_names.add(entry.get("schema") or "*")

    return schema_names
def get_allowed_databases() -> set[str]:
    """
    List the databases referenced by the current user's allowed rule entries.

    Used to populate database selectors in SQL Lab and elsewhere. Denied
    entries are not consulted.

    Returns:
        Set of database names the user has access to; empty when the
        DATA_ACCESS_RULES feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    rules = get_user_rules()
    if not rules:
        return set()

    return {
        db_name
        for rule in rules
        for entry in rule.rule_dict.get("allowed", [])
        if (db_name := entry.get("database"))
    }
@dataclass
class AllowedTable:
    """
    A specific table allowed by DAR, with its database context.

    Built by ``get_all_allowed_tables`` from allowed entries that name a table.
    """

    # Name of the database the table belongs to.
    database: str
    # Table name from the rule entry.
    table: str
    # Schema from the rule entry, or None if the entry does not specify one.
    schema: str | None = None
    # Catalog from the rule entry, or None if the entry does not specify one.
    catalog: str | None = None
@dataclass
class AllowedEntry:
    """
    An allowed entry from DAR at any level of the hierarchy.

    Fields may be None to indicate "all" at that level:
    - database only: all catalogs/schemas/tables in that database
    - database + catalog: all schemas/tables in that catalog
    - database + schema: all tables in that schema (for DBs without catalogs)
    - database + catalog + schema: all tables in that schema
    - database + schema + table: specific table (for DBs without catalogs)
    - database + catalog + schema + table: specific table
    """

    # Name of the database the entry applies to (always set).
    database: str
    # Catalog restriction, or None for all catalogs.
    catalog: str | None = None
    # Schema restriction, or None for all schemas.
    schema: str | None = None
    # Table restriction, or None for all tables.
    table: str | None = None
def get_all_allowed_entries() -> list[AllowedEntry]:
    """
    Return every allowed entry of the current user's rules, across databases.

    Entries at any hierarchy level (database, catalog, schema, table) are
    included, so callers can build whatever filter fits their use case.
    Entries without a database are skipped.

    Returns:
        List of AllowedEntry objects, in rule order; empty when the
        DATA_ACCESS_RULES feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    rules = get_user_rules()
    if not rules:
        return []

    return [
        AllowedEntry(
            database=db_name,
            catalog=entry.get("catalog"),
            schema=entry.get("schema"),
            table=entry.get("table"),
        )
        for rule in rules
        for entry in rule.rule_dict.get("allowed", [])
        if (db_name := entry.get("database"))
    ]
def get_all_allowed_tables() -> list[AllowedTable]:
    """
    Return the specific tables named by the current user's allowed entries.

    Used for dataset filtering. Only entries that name both a database and a
    table are returned; database- and schema-level grants are skipped, since
    their tables cannot be enumerated without querying the database.

    Returns:
        List of AllowedTable objects, in rule order; empty when the
        DATA_ACCESS_RULES feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    rules = get_user_rules()
    if not rules:
        return []

    found: list[AllowedTable] = []
    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            db_name = entry.get("database")
            table_name = entry.get("table")
            if db_name and table_name:
                found.append(
                    AllowedTable(
                        database=db_name,
                        table=table_name,
                        schema=entry.get("schema"),
                        catalog=entry.get("catalog"),
                    )
                )
    return found
def get_all_group_keys(
    database_name: str | None = None,
    table: Table | None = None,
) -> set[str]:
    """
    Get all distinct group_keys used in RLS rules.

    This is useful for UI discoverability - showing users what group_keys
    already exist so they can reuse them for consistent rule grouping.
    All stored rules are scanned (not only the current user's), looking at
    their ``allowed`` entries.

    Args:
        database_name: Optional filter by database
        table: Optional Table object to filter by catalog/schema/table; each
            level that is set on the Table must equal the entry's value

    Returns:
        Set of unique group_key values.
    """
    from superset.data_access_rules.models import DataAccessRule

    group_keys: set[str] = set()

    for rule in db.session.query(DataAccessRule).all():
        for entry in rule.rule_dict.get("allowed", []):
            # Apply filters if specified
            if database_name and entry.get("database") != database_name:
                continue
            if table is not None and (
                (table.catalog and entry.get("catalog") != table.catalog)
                or (table.schema and entry.get("schema") != table.schema)
                or (table.table and entry.get("table") != table.table)
            ):
                continue

            # "rls" may be absent or explicitly null in the stored JSON.
            rls_config = entry.get("rls") or {}
            if group_key := rls_config.get("group_key"):
                group_keys.add(group_key)

    return group_keys