Files
superset2/tests/unit_tests/data_access_rules/utils_test.py
2025-12-17 11:41:09 -05:00

769 lines
25 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for Data Access Rules utility functions.
"""
from unittest.mock import MagicMock, patch
from superset.data_access_rules.models import DataAccessRule
from superset.data_access_rules.utils import (
_is_more_specific,
_matches_rule_entry,
AccessCheckResult,
check_table_access,
get_all_group_keys,
get_cls_rules_for_table,
get_rls_predicates_for_table,
)
from superset.sql.parse import CLSAction, Table
# Tests for _matches_rule_entry
def test_matches_rule_entry_database_only():
"""Test matching when rule specifies only database."""
entry = {"database": "mydb"}
assert _matches_rule_entry(entry, "mydb", None, None, None) is True
assert _matches_rule_entry(entry, "mydb", "catalog1", "schema1", "table1") is True
assert _matches_rule_entry(entry, "otherdb", None, None, None) is False
def test_matches_rule_entry_database_and_catalog():
"""Test matching when rule specifies database and catalog."""
entry = {"database": "mydb", "catalog": "cat1"}
assert _matches_rule_entry(entry, "mydb", "cat1", None, None) is True
assert _matches_rule_entry(entry, "mydb", "cat1", "schema1", "table1") is True
assert _matches_rule_entry(entry, "mydb", "cat2", None, None) is False
assert _matches_rule_entry(entry, "otherdb", "cat1", None, None) is False
def test_matches_rule_entry_database_and_schema():
"""Test matching when rule specifies database and schema (no catalog)."""
entry = {"database": "mydb", "schema": "public"}
assert _matches_rule_entry(entry, "mydb", None, "public", None) is True
assert _matches_rule_entry(entry, "mydb", None, "public", "table1") is True
assert _matches_rule_entry(entry, "mydb", None, "other", None) is False
def test_matches_rule_entry_full_table():
"""Test matching when rule specifies full table path."""
entry = {"database": "mydb", "schema": "public", "table": "users"}
assert _matches_rule_entry(entry, "mydb", None, "public", "users") is True
assert _matches_rule_entry(entry, "mydb", None, "public", "orders") is False
assert _matches_rule_entry(entry, "mydb", None, "other", "users") is False
def test_matches_rule_entry_with_catalog():
"""Test matching with catalog in the path."""
entry = {
"database": "mydb",
"catalog": "main",
"schema": "public",
"table": "users",
}
assert _matches_rule_entry(entry, "mydb", "main", "public", "users") is True
assert _matches_rule_entry(entry, "mydb", "other", "public", "users") is False
# Tests for _is_more_specific
def test_is_more_specific():
"""Test specificity comparison between entries."""
db_only = {"database": "mydb"}
db_schema = {"database": "mydb", "schema": "public"}
db_table = {"database": "mydb", "schema": "public", "table": "users"}
db_catalog = {"database": "mydb", "catalog": "main"}
db_catalog_schema = {"database": "mydb", "catalog": "main", "schema": "public"}
# More specific should win
assert _is_more_specific(db_schema, db_only) is True
assert _is_more_specific(db_table, db_schema) is True
assert _is_more_specific(db_table, db_only) is True
# Less specific should lose
assert _is_more_specific(db_only, db_schema) is False
assert _is_more_specific(db_schema, db_table) is False
# Same specificity
assert _is_more_specific(db_schema, db_catalog) is False
assert _is_more_specific(db_catalog, db_schema) is False
# Tests for check_table_access
def test_check_table_access_no_rules():
"""Test access check when no rules are provided."""
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[])
assert result.access == AccessCheckResult.NO_RULE
assert result.rls_predicates == []
assert result.cls_rules == {}
def test_check_table_access_allowed():
"""Test access check when table is allowed."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb", "schema": "public"}],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule])
assert result.access == AccessCheckResult.ALLOWED
assert result.rls_predicates == []
assert result.cls_rules == {}
def test_check_table_access_denied():
"""Test access check when table is denied."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [],
"denied": [{"database": "mydb", "schema": "public"}],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule])
assert result.access == AccessCheckResult.DENIED
def test_check_table_access_denied_more_specific():
"""Test that more specific deny wins over less specific allow."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb"}], # Less specific
"denied": [{"database": "mydb", "schema": "secret"}], # More specific
}
# Table in non-denied schema should be allowed
table_public = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table_public, rules=[rule])
assert result.access == AccessCheckResult.ALLOWED
# Table in denied schema should be denied
table_secret = Table(table="data", schema="secret", catalog=None)
result = check_table_access("mydb", table_secret, rules=[rule])
assert result.access == AccessCheckResult.DENIED
def test_check_table_access_allowed_more_specific():
"""Test that more specific allow wins over less specific deny."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
"denied": [{"database": "mydb", "schema": "public"}],
}
# The specific table is allowed despite schema being denied
table_users = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table_users, rules=[rule])
assert result.access == AccessCheckResult.ALLOWED
# Other tables in the schema are still denied
table_orders = Table(table="orders", schema="public", catalog=None)
result = check_table_access("mydb", table_orders, rules=[rule])
assert result.access == AccessCheckResult.DENIED
def test_check_table_access_same_specificity_deny_wins():
"""Test that deny wins when rules have same specificity."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb", "schema": "public"}],
"denied": [{"database": "mydb", "schema": "public"}],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule])
assert result.access == AccessCheckResult.DENIED
def test_check_table_access_with_rls():
"""Test access check collects RLS predicates."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"rls": {"predicate": "org_id = 123", "group_key": "org"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule])
assert result.access == AccessCheckResult.ALLOWED
assert len(result.rls_predicates) == 1
assert result.rls_predicates[0].predicate == "org_id = 123"
assert result.rls_predicates[0].group_key == "org"
def test_check_table_access_multiple_rls():
"""Test access check collects multiple RLS predicates from different rules."""
rule1 = MagicMock(spec=DataAccessRule)
rule1.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 123", "group_key": "org"},
}
],
"denied": [],
}
rule2 = MagicMock(spec=DataAccessRule)
rule2.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "region = 'US'"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule1, rule2])
assert result.access == AccessCheckResult.ALLOWED
assert len(result.rls_predicates) == 2
predicates = [p.predicate for p in result.rls_predicates]
assert "org_id = 123" in predicates
assert "region = 'US'" in predicates
def test_check_table_access_with_cls():
"""Test access check collects CLS rules."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"email": "mask", "ssn": "hide", "name": "hash"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule])
assert result.access == AccessCheckResult.ALLOWED
assert result.cls_rules == {
"email": CLSAction.MASK,
"ssn": CLSAction.HIDE,
"name": CLSAction.HASH,
}
def test_check_table_access_cls_strictest_wins():
"""Test that strictest CLS action wins when multiple rules apply."""
rule1 = MagicMock(spec=DataAccessRule)
rule1.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"cls": {"email": "mask"}, # Less strict
}
],
"denied": [],
}
rule2 = MagicMock(spec=DataAccessRule)
rule2.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"cls": {"email": "hide"}, # More strict - should win
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
result = check_table_access("mydb", table, rules=[rule1, rule2])
assert result.access == AccessCheckResult.ALLOWED
assert result.cls_rules["email"] == CLSAction.HIDE
# Tests for get_rls_predicates_for_table
def test_get_rls_predicates_for_table_no_predicates():
"""Test getting RLS predicates when there are none."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb", "schema": "public"}],
"denied": [],
}
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
assert predicates == []
def test_get_rls_predicates_for_table_with_predicates():
"""Test getting RLS predicates."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 123"},
}
],
"denied": [],
}
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
assert predicates == ["(org_id = 123)"]
def test_get_rls_predicates_for_table_with_group_key():
"""Test getting RLS predicates with group_key combines with OR."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule1 = MagicMock(spec=DataAccessRule)
rule1.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 1", "group_key": "org"},
}
],
"denied": [],
}
rule2 = MagicMock(spec=DataAccessRule)
rule2.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 2", "group_key": "org"},
}
],
"denied": [],
}
predicates = get_rls_predicates_for_table(table, database, rules=[rule1, rule2])
# Same group_key predicates should be ORed
assert len(predicates) == 1
assert "(org_id = 1)" in predicates[0]
assert "(org_id = 2)" in predicates[0]
assert " OR " in predicates[0]
def test_get_rls_predicates_for_table_mixed_group_keys():
"""Test getting RLS predicates with mixed group_keys."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 1", "group_key": "org"},
},
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "org_id = 2", "group_key": "org"},
},
{
"database": "mydb",
"schema": "public",
"rls": {"predicate": "region = 'US'"}, # No group_key
},
],
"denied": [],
}
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
# Should have: ungrouped predicate + ORed group predicate = 2 items
assert len(predicates) == 2
has_region = any("region = 'US'" in p for p in predicates)
has_org_group = any("org_id = 1" in p and "org_id = 2" in p for p in predicates)
assert has_region
assert has_org_group
# Tests for get_cls_rules_for_table
def test_get_cls_rules_for_table_no_rules():
"""Test getting CLS rules when there are none."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [{"database": "mydb", "schema": "public"}],
"denied": [],
}
cls_rules = get_cls_rules_for_table(table, database, rules=[rule])
assert cls_rules == {}
def test_get_cls_rules_for_table_with_rules():
"""Test getting CLS rules."""
database = MagicMock()
database.database_name = "mydb"
table = Table(table="users", schema="public", catalog=None)
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"email": "mask", "ssn": "hide"},
}
],
"denied": [],
}
cls_rules = get_cls_rules_for_table(table, database, rules=[rule])
assert cls_rules == {"email": CLSAction.MASK, "ssn": CLSAction.HIDE}
# Tests for get_all_group_keys
def test_get_all_group_keys_empty(app_context: None):
"""Test getting group keys when none exist."""
with patch("superset.data_access_rules.utils.db") as mock_db:
mock_db.session.query.return_value.all.return_value = []
keys = get_all_group_keys()
assert keys == set()
def test_get_all_group_keys_with_keys(app_context: None):
"""Test getting group keys from rules."""
rule1 = MagicMock(spec=DataAccessRule)
rule1.rule_dict = {
"allowed": [
{"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}},
{"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}},
],
"denied": [],
}
rule2 = MagicMock(spec=DataAccessRule)
rule2.rule_dict = {
"allowed": [
{"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}},
{"database": "mydb", "rls": {"predicate": "x=4"}}, # No group_key
],
"denied": [],
}
with patch("superset.data_access_rules.utils.db") as mock_db:
mock_db.session.query.return_value.all.return_value = [rule1, rule2]
keys = get_all_group_keys()
assert keys == {"key1", "key2"}
def test_get_all_group_keys_filtered_by_database(app_context: None):
"""Test getting group keys filtered by database."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}},
{"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}},
],
"denied": [],
}
with patch("superset.data_access_rules.utils.db") as mock_db:
mock_db.session.query.return_value.all.return_value = [rule]
keys = get_all_group_keys(database_name="db1")
assert keys == {"key1"}
def test_get_all_group_keys_filtered_by_table(app_context: None):
"""Test getting group keys filtered by table."""
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "db1",
"schema": "public",
"table": "users",
"rls": {"predicate": "x=1", "group_key": "key1"},
},
{
"database": "db1",
"schema": "public",
"table": "orders",
"rls": {"predicate": "x=2", "group_key": "key2"},
},
],
"denied": [],
}
with patch("superset.data_access_rules.utils.db") as mock_db:
mock_db.session.query.return_value.all.return_value = [rule]
table = Table(table="users", schema="public", catalog=None)
keys = get_all_group_keys(database_name="db1", table=table)
assert keys == {"key1"}
# Tests for get_hidden_columns_for_table
def test_get_hidden_columns_for_table_no_hidden(app_context: None):
"""Test getting hidden columns when no columns are hidden."""
from superset.data_access_rules.utils import get_hidden_columns_for_table
database = MagicMock()
database.database_name = "mydb"
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"email": "mask", "phone": "hash"}, # No "hide" actions
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
assert hidden == set()
def test_get_hidden_columns_for_table_with_hidden(app_context: None):
"""Test getting hidden columns when some columns are hidden."""
from superset.data_access_rules.utils import get_hidden_columns_for_table
database = MagicMock()
database.database_name = "mydb"
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"email": "mask", "ssn": "hide", "password": "hide"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
assert hidden == {"ssn", "password"}
def test_get_hidden_columns_for_table_denied_access(app_context: None):
"""Test that denied access returns no hidden columns."""
from superset.data_access_rules.utils import get_hidden_columns_for_table
database = MagicMock()
database.database_name = "mydb"
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [],
"denied": [
{
"database": "mydb",
"schema": "public",
"table": "users",
}
],
}
table = Table(table="users", schema="public", catalog=None)
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
# Denied access means no CLS rules are returned
assert hidden == set()
# Tests for filter_columns_by_cls
def test_filter_columns_by_cls_no_hidden(app_context: None):
"""Test filtering columns when no columns are hidden."""
from superset.data_access_rules.utils import filter_columns_by_cls
database = MagicMock()
database.database_name = "mydb"
columns = [
{"column_name": "id", "type": "INTEGER"},
{"column_name": "name", "type": "VARCHAR"},
{"column_name": "email", "type": "VARCHAR"},
]
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{"database": "mydb", "schema": "public", "table": "users"}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
with patch(
"superset.data_access_rules.utils.is_feature_enabled",
return_value=True,
):
with patch(
"superset.data_access_rules.utils.get_user_rules",
return_value=[rule],
):
filtered = filter_columns_by_cls(columns, table, database)
assert len(filtered) == 3
assert filtered == columns
def test_filter_columns_by_cls_with_hidden(app_context: None):
"""Test filtering columns when some columns are hidden."""
from superset.data_access_rules.utils import filter_columns_by_cls
database = MagicMock()
database.database_name = "mydb"
columns = [
{"column_name": "id", "type": "INTEGER"},
{"column_name": "name", "type": "VARCHAR"},
{"column_name": "email", "type": "VARCHAR"},
{"column_name": "ssn", "type": "VARCHAR"},
]
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"ssn": "hide"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
with patch(
"superset.data_access_rules.utils.is_feature_enabled",
return_value=True,
):
with patch(
"superset.data_access_rules.utils.get_user_rules",
return_value=[rule],
):
filtered = filter_columns_by_cls(columns, table, database)
assert len(filtered) == 3
column_names = [c["column_name"] for c in filtered]
assert "ssn" not in column_names
assert "id" in column_names
assert "name" in column_names
assert "email" in column_names
def test_filter_columns_by_cls_feature_disabled(app_context: None):
"""Test that filtering is skipped when feature flag is disabled."""
from superset.data_access_rules.utils import filter_columns_by_cls
database = MagicMock()
database.database_name = "mydb"
columns = [
{"column_name": "id", "type": "INTEGER"},
{"column_name": "ssn", "type": "VARCHAR"},
]
table = Table(table="users", schema="public", catalog=None)
with patch(
"superset.data_access_rules.utils.is_feature_enabled",
return_value=False,
):
# Even if there would be hidden columns, they are not filtered
filtered = filter_columns_by_cls(columns, table, database)
assert len(filtered) == 2
assert filtered == columns
def test_filter_columns_by_cls_custom_key(app_context: None):
"""Test filtering columns with custom column name key."""
from superset.data_access_rules.utils import filter_columns_by_cls
database = MagicMock()
database.database_name = "mydb"
# Columns with different key structure (like from SQL Lab table metadata)
columns = [
{"name": "id", "type": "INTEGER"},
{"name": "ssn", "type": "VARCHAR"},
]
rule = MagicMock(spec=DataAccessRule)
rule.rule_dict = {
"allowed": [
{
"database": "mydb",
"schema": "public",
"table": "users",
"cls": {"ssn": "hide"},
}
],
"denied": [],
}
table = Table(table="users", schema="public", catalog=None)
with patch(
"superset.data_access_rules.utils.is_feature_enabled",
return_value=True,
):
with patch(
"superset.data_access_rules.utils.get_user_rules",
return_value=[rule],
):
filtered = filter_columns_by_cls(
columns, table, database, column_name_key="name"
)
assert len(filtered) == 1
assert filtered[0]["name"] == "id"