mirror of
https://github.com/apache/superset.git
synced 2026-05-07 17:04:58 +00:00
897 lines
28 KiB
Python
897 lines
28 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
"""
|
|
Data Access Rules utility functions.
|
|
|
|
This module provides functions for:
|
|
- Checking if a user has access to a table
|
|
- Collecting RLS predicates for a table
|
|
- Collecting CLS rules for a table
|
|
- Applying RLS and CLS to SQL queries
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
from enum import Enum
|
|
from typing import Any, TYPE_CHECKING
|
|
|
|
from flask import g
|
|
|
|
from superset import db, is_feature_enabled, security_manager
|
|
from superset.sql.parse import CLSAction, Table
|
|
|
|
if TYPE_CHECKING:
|
|
from superset.data_access_rules.models import DataAccessRule
|
|
from superset.models.core import Database
|
|
from superset.sql.parse import BaseSQLStatement
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AccessCheckResult(Enum):
    """Verdict produced when data access rules are evaluated for a table."""

    # A matching "allowed" entry applies and no matching "denied" entry overrides it.
    ALLOWED = "allowed"

    # A matching "denied" entry takes precedence.
    DENIED = "denied"

    # No rule entry matches the table at all.
    NO_RULE = "no_rule"
|
|
|
|
|
|
@dataclass
class RLSPredicate:
    """A single row-level-security predicate, optionally tagged with a group key."""

    # SQL boolean expression text taken from a rule entry's "rls" config.
    predicate: str

    # Grouping label from the same config; None when the predicate is ungrouped.
    group_key: str | None = None
|
|
|
|
|
|
@dataclass
class TableAccessInfo:
    """Result of an access check for one table: verdict plus RLS/CLS payload."""

    # Overall verdict for the table.
    access: AccessCheckResult

    # Row-level predicates collected from the matching allowed entries.
    rls_predicates: list[RLSPredicate]

    # Column name -> column-level-security action to apply to that column.
    cls_rules: dict[str, CLSAction]
|
|
|
|
|
|
def get_user_rules() -> list[DataAccessRule]:
    """
    Fetch the data access rules attached to the current user's roles.

    Returns:
        Every DataAccessRule whose ``role_id`` is among the roles reported by the
        security manager. An empty list is returned when ``flask.g`` carries no
        (truthy) ``user`` or when the user has no roles.
    """
    # Imported lazily; at module level the model is only imported for type checking.
    from superset.data_access_rules.models import DataAccessRule

    # Without a user on the request context there is nothing to look up.
    if not getattr(g, "user", None):
        return []

    role_ids = [role.id for role in security_manager.get_user_roles()]
    if not role_ids:
        return []

    rules_query = db.session.query(DataAccessRule).filter(
        DataAccessRule.role_id.in_(role_ids)
    )
    return rules_query.all()
|
|
|
|
|
|
def _matches_rule_entry(
|
|
entry: dict[str, Any],
|
|
database_name: str,
|
|
catalog: str | None,
|
|
schema: str | None,
|
|
table_name: str | None,
|
|
) -> bool:
|
|
"""
|
|
Check if a rule entry matches the given database/catalog/schema/table.
|
|
|
|
The rule entry can specify any level of the hierarchy:
|
|
- database only: matches all catalogs/schemas/tables in that database
|
|
- database + catalog: matches all schemas/tables in that catalog
|
|
- database + schema: matches all tables in that schema (for DBs without catalogs)
|
|
- database + catalog + schema: matches all tables in that schema
|
|
- database + schema + table: matches the specific table (for DBs without catalogs)
|
|
- database + catalog + schema + table: matches the specific table
|
|
|
|
Args:
|
|
entry: The rule entry dict
|
|
database_name: The database name to check
|
|
catalog: The catalog to check (None if DB doesn't support catalogs)
|
|
schema: The schema to check
|
|
table_name: The table name to check
|
|
|
|
Returns:
|
|
True if the entry matches, False otherwise.
|
|
"""
|
|
# Database must always match
|
|
if entry.get("database") != database_name:
|
|
return False
|
|
|
|
entry_catalog = entry.get("catalog")
|
|
entry_schema = entry.get("schema")
|
|
entry_table = entry.get("table")
|
|
|
|
# If the entry specifies a catalog, it must match (or catalog must be None/default)
|
|
if entry_catalog is not None:
|
|
if catalog is not None and entry_catalog != catalog:
|
|
return False
|
|
|
|
# If the entry specifies a schema, it must match
|
|
if entry_schema is not None:
|
|
if schema is not None and entry_schema != schema:
|
|
return False
|
|
|
|
# If the entry specifies a table, it must match
|
|
if entry_table is not None:
|
|
if table_name is not None and entry_table != table_name:
|
|
return False
|
|
|
|
# Check specificity: entry must be at least as specific as the query
|
|
# If querying a specific table, entry must specify that table or be broader
|
|
if table_name is not None and entry_table is not None and entry_table != table_name:
|
|
return False
|
|
|
|
return True
|
|
|
|
|
|
def _is_more_specific(entry: dict[str, Any], other: dict[str, Any]) -> bool:
|
|
"""
|
|
Check if 'entry' is more specific than 'other'.
|
|
|
|
More specific means it specifies more levels of the hierarchy.
|
|
"""
|
|
entry_specificity = sum(
|
|
[
|
|
entry.get("catalog") is not None,
|
|
entry.get("schema") is not None,
|
|
entry.get("table") is not None,
|
|
]
|
|
)
|
|
other_specificity = sum(
|
|
[
|
|
other.get("catalog") is not None,
|
|
other.get("schema") is not None,
|
|
other.get("table") is not None,
|
|
]
|
|
)
|
|
return entry_specificity > other_specificity
|
|
|
|
|
|
def _empty_access_info(access: AccessCheckResult) -> TableAccessInfo:
    """Build a TableAccessInfo that carries a verdict but no RLS/CLS payload."""
    return TableAccessInfo(access=access, rls_predicates=[], cls_rules={})


def _most_specific_entry(entries: list[dict[str, Any]]) -> dict[str, Any] | None:
    """Return the first entry that no later entry beats in specificity."""
    best: dict[str, Any] | None = None
    for entry in entries:
        if best is None or _is_more_specific(entry, best):
            best = entry
    return best


def _collect_rls_predicates(entries: list[dict[str, Any]]) -> list[RLSPredicate]:
    """Gather the RLS predicate of every entry that defines one, in entry order."""
    predicates: list[RLSPredicate] = []
    for entry in entries:
        rls_config = entry.get("rls")
        if rls_config and "predicate" in rls_config:
            predicates.append(
                RLSPredicate(
                    predicate=rls_config["predicate"],
                    group_key=rls_config.get("group_key"),
                )
            )
    return predicates


def _collect_cls_rules(entries: list[dict[str, Any]]) -> dict[str, CLSAction]:
    """Merge per-column CLS actions, keeping the strictest action per column."""
    # Higher number == stricter action.
    precedence = {
        CLSAction.HIDE: 4,
        CLSAction.NULLIFY: 3,
        CLSAction.MASK: 2,
        CLSAction.HASH: 1,
    }
    action_map = {
        "hide": CLSAction.HIDE,
        "nullify": CLSAction.NULLIFY,
        "mask": CLSAction.MASK,
        "hash": CLSAction.HASH,
    }

    merged: dict[str, CLSAction] = {}
    for entry in entries:
        # "cls" may be missing or explicitly null; both mean "no column rules".
        cls_config = entry.get("cls") or {}
        for column, action_str in cls_config.items():
            # A non-string action cannot be lower-cased; treat it as unknown
            # instead of failing the whole access check.
            action = (
                action_map.get(action_str.lower())
                if isinstance(action_str, str)
                else None
            )
            if action is None:
                logger.warning("Unknown CLS action: %s", action_str)
                continue

            existing = merged.get(column)
            if existing is None or precedence[action] > precedence[existing]:
                merged[column] = action
    return merged


def check_table_access(
    database_name: str,
    table: Table,
    rules: list[DataAccessRule] | None = None,
) -> TableAccessInfo:
    """
    Check if the current user has access to a specific table.

    The function evaluates all rules for the user's roles and determines:
    1. Whether access is allowed, denied, or no rule applies
    2. Any RLS predicates that should be applied
    3. Any CLS rules for column masking/hiding

    Denied rules take precedence over allowed rules when at the same specificity
    level. More specific rules take precedence over less specific rules.

    When access is allowed, RLS predicates and CLS rules are gathered from ALL
    matching allowed entries (not only the most specific one); for CLS the
    strictest action per column wins (hide > nullify > mask > hash).

    Args:
        database_name: The database name
        table: The Table object with catalog, schema, and table name
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        TableAccessInfo with access result, RLS predicates, and CLS rules.
        The RLS/CLS payload is empty unless the result is ALLOWED.
    """
    if rules is None:
        rules = get_user_rules()

    if not rules:
        return _empty_access_info(AccessCheckResult.NO_RULE)

    # Collect every allowed / denied entry that matches this table.
    allowed_entries: list[dict[str, Any]] = []
    denied_entries: list[dict[str, Any]] = []
    for rule in rules:
        rule_dict = rule.rule_dict
        for key, bucket in (("allowed", allowed_entries), ("denied", denied_entries)):
            for entry in rule_dict.get(key, []):
                if _matches_rule_entry(
                    entry, database_name, table.catalog, table.schema, table.table
                ):
                    bucket.append(entry)

    if not allowed_entries and not denied_entries:
        return _empty_access_info(AccessCheckResult.NO_RULE)

    most_specific_denied = _most_specific_entry(denied_entries)
    most_specific_allowed = _most_specific_entry(allowed_entries)

    # A deny wins unless an allow is strictly more specific. This covers
    # "deny only", "deny more specific" and "same specificity" in one test.
    if most_specific_denied is not None and (
        most_specific_allowed is None
        or not _is_more_specific(most_specific_allowed, most_specific_denied)
    ):
        return _empty_access_info(AccessCheckResult.DENIED)

    # At this point at least one allowed entry matched and it is not overridden.
    return TableAccessInfo(
        access=AccessCheckResult.ALLOWED,
        rls_predicates=_collect_rls_predicates(allowed_entries),
        cls_rules=_collect_cls_rules(allowed_entries),
    )
|
|
|
|
|
|
def get_rls_predicates_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> list[str]:
    """
    Build the RLS predicate strings for a table from the Data Access Rules.

    Each predicate is wrapped in parentheses. Combination follows the group_key
    of each predicate:
    - predicates with no (or an empty) group_key are emitted one by one
    - predicates sharing a group_key are joined with OR into one parenthesised
      expression (a group of one is emitted as-is)
    - ungrouped predicates come first, then one item per group in first-seen order

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        List of SQL predicate strings to be ANDed together; empty when access is
        not ALLOWED or no predicates apply.
    """
    access_info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )

    if (
        access_info.access != AccessCheckResult.ALLOWED
        or not access_info.rls_predicates
    ):
        return []

    standalone: list[str] = []
    by_group: dict[str, list[str]] = {}
    for rls in access_info.rls_predicates:
        wrapped = f"({rls.predicate})"
        if rls.group_key:
            by_group.setdefault(rls.group_key, []).append(wrapped)
        else:
            standalone.append(wrapped)

    # One output item per group: the lone member, or an OR of all members.
    grouped = [
        members[0] if len(members) == 1 else f"({' OR '.join(members)})"
        for members in by_group.values()
    ]

    return standalone + grouped
|
|
|
|
|
|
def get_cls_rules_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> dict[str, CLSAction]:
    """
    Look up the column-level-security rules that apply to a table.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Dict mapping column names to CLSAction values; empty unless the access
        check for the table comes back ALLOWED.
    """
    info = check_table_access(
        database_name=database.database_name,
        table=table,
        rules=rules,
    )
    return info.cls_rules if info.access == AccessCheckResult.ALLOWED else {}
|
|
|
|
|
|
def get_hidden_columns_for_table(
    table: Table,
    database: Database,
    rules: list[DataAccessRule] | None = None,
) -> set[str]:
    """
    Return the names of the columns whose CLS action is "hide" for a table.

    Args:
        table: The fully qualified Table object
        database: The Database object
        rules: Optional list of rules to check (defaults to current user's rules)

    Returns:
        Set of column names that should be hidden.
    """
    return {
        name
        for name, action in get_cls_rules_for_table(table, database, rules).items()
        if action == CLSAction.HIDE
    }
|
|
|
|
|
|
def filter_columns_by_cls(
    columns: list[dict[str, Any]],
    table: Table,
    database: Database,
    column_name_key: str = "column_name",
) -> list[dict[str, Any]]:
    """
    Drop hidden columns from a list of column dictionaries.

    Intended for column metadata such as that returned by database reflection
    or dataset APIs.

    Args:
        columns: List of column dictionaries
        table: The fully qualified Table object
        database: The Database object
        column_name_key: The key in the column dict that contains the column name

    Returns:
        The input list itself when the DATA_ACCESS_RULES feature is off or nothing
        is hidden; otherwise a new list without the hidden columns.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return columns

    hidden = get_hidden_columns_for_table(table, database)
    if hidden:
        return [
            column for column in columns if column.get(column_name_key) not in hidden
        ]

    return columns
|
|
|
|
|
|
def apply_data_access_rules(
    database: Database,
    catalog: str | None,
    schema: str,
    parsed_statement: BaseSQLStatement[Any],
) -> None:
    """
    Apply Data Access Rules (RLS and CLS) to a parsed SQL statement.

    This function:
    1. Checks if the DATA_ACCESS_RULES feature is enabled
    2. For each table in the query, checks access and collects RLS/CLS rules
    3. Applies CLS rules using the existing infrastructure
    4. Applies RLS predicates using the existing infrastructure

    NOTE: a table whose access check returns DENIED is only logged and skipped;
    the statement is not rejected here (see the TODO in the loop).

    Args:
        database: The Database object
        catalog: The default catalog for the query
        schema: The default schema for the query
        parsed_statement: The parsed SQL statement to modify in place
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return

    from superset.sql.parse import CLSRules

    rules = get_user_rules()
    if not rules:
        return

    # Get the RLS method for this database
    method = database.db_engine_spec.get_rls_method()

    # Collect RLS predicates and CLS rules for all tables
    rls_predicates: dict[Table, list[Any]] = {}
    cls_rules: CLSRules = {}

    for table in parsed_statement.tables:
        qualified_table = table.qualify(catalog=catalog, schema=schema)

        access_info = check_table_access(
            database_name=database.database_name,
            table=qualified_table,
            rules=rules,
        )

        if access_info.access == AccessCheckResult.DENIED:
            # TODO: How should we handle denied access mid-query?
            # For now, log a warning. In the future, we might raise an exception.
            logger.warning(
                "Access denied to table %s for user %s",
                qualified_table,
                getattr(g, "user", "unknown"),
            )
            continue

        if access_info.access != AccessCheckResult.ALLOWED:
            # NO_RULE: the RLS/CLS helpers return nothing for such a table, so
            # there is no point evaluating the rules again.
            continue

        predicates = get_rls_predicates_for_table(qualified_table, database, rules)
        if predicates:
            rls_predicates[qualified_table] = [
                parsed_statement.parse_predicate(pred) for pred in predicates if pred
            ]

        # Reuse the CLS rules from the access check above instead of running
        # the whole rule evaluation a further time through the helper.
        if access_info.cls_rules:
            cls_rules[qualified_table] = access_info.cls_rules

    # Apply CLS first (before RLS) so that hidden columns are removed
    # before RLS wraps the query in a subquery
    if cls_rules:
        # Build schema dict for sqlglot's qualify() to expand SELECT *
        # sqlglot expects nested format: {catalog: {schema: {table: {col: type}}}}
        # or {schema: {table: {col: type}}} without catalog
        table_schemas: dict[str, Any] = {}
        for table in cls_rules:
            try:
                columns = database.get_columns(table)
                col_types = {
                    col["column_name"]: str(col.get("type", "VARCHAR"))
                    for col in columns
                }
            except Exception as ex:
                # Best effort: without column metadata the table is simply
                # left out of the schema hint passed to apply_cls.
                logger.warning(
                    "Could not fetch schema for table %s: %s",
                    table,
                    ex,
                )
                continue

            # Walk/create the nesting levels that this table actually has.
            scope = table_schemas
            if table.catalog:
                scope = scope.setdefault(table.catalog, {})
            if table.schema:
                scope = scope.setdefault(table.schema, {})
            scope[table.table] = col_types

        parsed_statement.apply_cls(cls_rules, schema=table_schemas or None)

    # Apply RLS after CLS - RLS wraps the query in a subquery with SELECT *
    # which will pick up the already-transformed columns from CLS
    if rls_predicates:
        parsed_statement.apply_rls(catalog, schema, rls_predicates, method)
|
|
|
|
|
|
def get_allowed_tables(
    database_name: str,
    schema: str | None = None,
    catalog: str | None = None,
) -> tuple[set[str], bool]:
    """
    Get all table names that the current user has access to via Data Access Rules
    for a specific database and schema.

    Only the "allowed" entries of the user's rules are inspected; "denied"
    entries are not evaluated by this function.

    Args:
        database_name: The database name to check
        schema: Optional schema name to filter by
        catalog: Optional catalog name to filter by

    Returns:
        Tuple of (set of table names, bool indicating if schema-level access is
        granted). If schema-level access is granted, the set may be empty but all
        tables are allowed. The flag is set by any matching entry without a table
        that either names no schema (database/catalog-level grant) or names
        exactly the requested schema.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set(), False

    rules = get_user_rules()
    if not rules:
        return set(), False

    table_names: set[str] = set()
    schema_level_access = False

    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # A catalog/schema filter only excludes entries that name a
            # different catalog/schema; entries that leave the level unset pass.
            entry_catalog = entry.get("catalog")
            if (
                catalog is not None
                and entry_catalog is not None
                and entry_catalog != catalog
            ):
                continue

            entry_schema = entry.get("schema")
            if (
                schema is not None
                and entry_schema is not None
                and entry_schema != schema
            ):
                continue

            if table := entry.get("table"):
                table_names.add(table)
            elif entry_schema is None or entry_schema == schema:
                # Schema-level or database-level access without table means all
                # tables. An entry with no schema covers every schema, so it
                # must also grant access when a specific schema is requested.
                schema_level_access = True

    return table_names, schema_level_access
|
|
|
|
|
|
def get_allowed_schemas(database_name: str, catalog: str | None = None) -> set[str]:
    """
    Get all schema names that the current user has access to via Data Access Rules
    for a specific database.

    Only the "allowed" entries of the user's rules are inspected.

    Args:
        database_name: The database name to check
        catalog: Optional catalog name to filter by

    Returns:
        Set of schema names the user has access to. The set additionally contains
        the marker ``"*"`` when a matching entry names no schema (database- or
        catalog-level grant), which callers should interpret as "all schemas".
        Empty when the feature is disabled or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    rules = get_user_rules()
    if not rules:
        return set()

    schema_names: set[str] = set()

    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            if entry.get("database") != database_name:
                continue

            # If both the caller and the entry name a catalog, they must agree.
            entry_catalog = entry.get("catalog")
            if (
                catalog is not None
                and entry_catalog is not None
                and entry_catalog != catalog
            ):
                continue

            if schema := entry.get("schema"):
                schema_names.add(schema)
            else:
                # The database already matched above, so an entry without a
                # schema is a database-level grant: record the "all schemas"
                # marker for the caller to check.
                schema_names.add("*")

    return schema_names
|
|
|
|
|
|
def get_allowed_databases() -> set[str]:
    """
    Collect the database names referenced by the current user's allowed entries.

    Used to populate database selectors in SQL Lab and elsewhere.

    Returns:
        Set of database names the user has access to; empty when the
        DATA_ACCESS_RULES feature is off or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return set()

    user_rules = get_user_rules()
    if not user_rules:
        return set()

    # Entries with a missing or empty "database" value contribute nothing.
    return {
        name
        for rule in user_rules
        for entry in rule.rule_dict.get("allowed", [])
        if (name := entry.get("database"))
    }
|
|
|
|
|
|
@dataclass
class AllowedTable:
    """One specific table granted by Data Access Rules, with its location."""

    # Name of the database the table lives in.
    database: str

    # Name of the table itself.
    table: str

    # Schema of the table; None when the rule entry did not specify one.
    schema: str | None = None

    # Catalog of the table; None when the rule entry did not specify one.
    catalog: str | None = None
|
|
|
|
|
|
@dataclass
class AllowedEntry:
    """
    One allowed grant from Data Access Rules, at any depth of the hierarchy.

    A field left as None means "everything" at that level:
    - database only: all catalogs/schemas/tables in that database
    - database + catalog: all schemas/tables in that catalog
    - database + schema: all tables in that schema (for DBs without catalogs)
    - database + catalog + schema: all tables in that schema
    - database + schema + table: specific table (for DBs without catalogs)
    - database + catalog + schema + table: specific table
    """

    # The database is the only mandatory level.
    database: str

    # Optional narrowing levels, from broadest to most specific.
    catalog: str | None = None
    schema: str | None = None
    table: str | None = None
|
|
|
|
|
|
def get_all_allowed_entries() -> list[AllowedEntry]:
    """
    List every allowed entry the current user has, across all databases.

    Entries at every hierarchy level (database, schema, table) are returned so
    that callers can build whatever filter suits their use case. Entries whose
    "database" value is missing or empty are skipped.

    Returns:
        List of AllowedEntry objects representing allowed access; empty when the
        DATA_ACCESS_RULES feature is off or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    user_rules = get_user_rules()
    if not user_rules:
        return []

    return [
        AllowedEntry(
            database=db_name,
            catalog=raw.get("catalog"),
            schema=raw.get("schema"),
            table=raw.get("table"),
        )
        for rule in user_rules
        for raw in rule.rule_dict.get("allowed", [])
        if (db_name := raw.get("database"))
    ]
|
|
|
|
|
|
def get_all_allowed_tables() -> list[AllowedTable]:
    """
    List every explicitly named table the current user is allowed to access.

    Used for dataset filtering, where the specific tables must be known. Allowed
    entries that grant a whole database or schema (no "table" value) are left
    out, because their tables cannot be enumerated without querying the database.

    Returns:
        List of AllowedTable objects representing allowed tables; empty when the
        DATA_ACCESS_RULES feature is off or the user has no rules.
    """
    if not is_feature_enabled("DATA_ACCESS_RULES"):
        return []

    user_rules = get_user_rules()
    if not user_rules:
        return []

    found: list[AllowedTable] = []
    for rule in user_rules:
        for raw in rule.rule_dict.get("allowed", []):
            db_name = raw.get("database")
            tbl_name = raw.get("table")
            # Both a database and a table name are required for a concrete table.
            if db_name and tbl_name:
                found.append(
                    AllowedTable(
                        database=db_name,
                        table=tbl_name,
                        schema=raw.get("schema"),
                        catalog=raw.get("catalog"),
                    )
                )

    return found
|
|
|
|
|
|
def get_all_group_keys(
    database_name: str | None = None,
    table: Table | None = None,
) -> set[str]:
    """
    Get all distinct group_keys used in RLS rules.

    This is useful for UI discoverability - showing users what group_keys
    already exist so they can reuse them for consistent rule grouping.

    Every DataAccessRule stored in the database is scanned (not only the rules
    of the current user), and only "allowed" entries are inspected.

    Args:
        database_name: Optional filter by database
        table: Optional Table object to filter by catalog/schema/table; each level
            that is set on the table must equal the entry's value exactly

    Returns:
        Set of unique group_key values.
    """
    from superset.data_access_rules.models import DataAccessRule

    rules = db.session.query(DataAccessRule).all()

    group_keys: set[str] = set()

    for rule in rules:
        for entry in rule.rule_dict.get("allowed", []):
            # Apply filters if specified
            if database_name and entry.get("database") != database_name:
                continue
            if table is not None:
                if table.catalog and entry.get("catalog") != table.catalog:
                    continue
                if table.schema and entry.get("schema") != table.schema:
                    continue
                if table.table and entry.get("table") != table.table:
                    continue

            # "rls" may be absent or explicitly null; neither carries a key.
            rls_config = entry.get("rls") or {}
            if group_key := rls_config.get("group_key"):
                group_keys.add(group_key)

    return group_keys
|