mirror of
https://github.com/apache/superset.git
synced 2026-05-07 08:54:23 +00:00
Initial rules
This commit is contained in:
542
tests/unit_tests/data_access_rules/utils_test.py
Normal file
542
tests/unit_tests/data_access_rules/utils_test.py
Normal file
@@ -0,0 +1,542 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Unit tests for Data Access Rules utility functions.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.data_access_rules.models import DataAccessRule
|
||||
from superset.data_access_rules.utils import (
|
||||
_is_more_specific,
|
||||
_matches_rule_entry,
|
||||
AccessCheckResult,
|
||||
check_table_access,
|
||||
get_all_group_keys,
|
||||
get_cls_rules_for_table,
|
||||
get_rls_predicates_for_table,
|
||||
)
|
||||
from superset.sql.parse import CLSAction, Table
|
||||
|
||||
|
||||
# Tests for _matches_rule_entry
|
||||
def test_matches_rule_entry_database_only():
    """A database-only entry matches any path under that database name."""
    rule_entry = {"database": "mydb"}
    # (database, catalog, schema, table) -> expected match result
    cases = (
        (("mydb", None, None, None), True),
        (("mydb", "catalog1", "schema1", "table1"), True),
        (("otherdb", None, None, None), False),
    )
    for path, expected in cases:
        assert _matches_rule_entry(rule_entry, *path) is expected
|
||||
|
||||
|
||||
def test_matches_rule_entry_database_and_catalog():
    """An entry with database and catalog requires both values to agree."""
    rule_entry = {"database": "mydb", "catalog": "cat1"}
    # (database, catalog, schema, table) -> expected match result
    cases = (
        (("mydb", "cat1", None, None), True),
        (("mydb", "cat1", "schema1", "table1"), True),
        (("mydb", "cat2", None, None), False),
        (("otherdb", "cat1", None, None), False),
    )
    for path, expected in cases:
        assert _matches_rule_entry(rule_entry, *path) is expected
|
||||
|
||||
|
||||
def test_matches_rule_entry_database_and_schema():
    """An entry with database and schema (no catalog) matches on schema name."""
    rule_entry = {"database": "mydb", "schema": "public"}
    # (database, catalog, schema, table) -> expected match result
    cases = (
        (("mydb", None, "public", None), True),
        (("mydb", None, "public", "table1"), True),
        (("mydb", None, "other", None), False),
    )
    for path, expected in cases:
        assert _matches_rule_entry(rule_entry, *path) is expected
|
||||
|
||||
|
||||
def test_matches_rule_entry_full_table():
    """An entry naming database, schema and table matches only that table."""
    rule_entry = {"database": "mydb", "schema": "public", "table": "users"}
    # (database, catalog, schema, table) -> expected match result
    cases = (
        (("mydb", None, "public", "users"), True),
        (("mydb", None, "public", "orders"), False),
        (("mydb", None, "other", "users"), False),
    )
    for path, expected in cases:
        assert _matches_rule_entry(rule_entry, *path) is expected
|
||||
|
||||
|
||||
def test_matches_rule_entry_with_catalog():
    """A fully qualified entry (with catalog) rejects a different catalog."""
    rule_entry = dict(
        database="mydb",
        catalog="main",
        schema="public",
        table="users",
    )
    # Same database/schema/table in both cases; only the catalog varies.
    for catalog_name, expected in (("main", True), ("other", False)):
        matched = _matches_rule_entry(
            rule_entry, "mydb", catalog_name, "public", "users"
        )
        assert matched is expected
|
||||
|
||||
|
||||
# Tests for _is_more_specific
|
||||
def test_is_more_specific():
    """Check how ``_is_more_specific`` ranks entries of different depth.

    Covers three situations: a deeper entry compared with a shallower one
    (expected ``True``), the reverse comparison (expected ``False``), and two
    entries with the same number of levels (expected ``False`` both ways).
    """
    db_only = {"database": "mydb"}
    db_schema = {"database": "mydb", "schema": "public"}
    db_table = {"database": "mydb", "schema": "public", "table": "users"}
    db_catalog = {"database": "mydb", "catalog": "main"}
    # NOTE: an unused ``db_catalog_schema`` fixture was dropped here; it was
    # assigned but never asserted on.

    # More specific should win
    assert _is_more_specific(db_schema, db_only) is True
    assert _is_more_specific(db_table, db_schema) is True
    assert _is_more_specific(db_table, db_only) is True

    # Less specific should lose
    assert _is_more_specific(db_only, db_schema) is False
    assert _is_more_specific(db_schema, db_table) is False

    # Same specificity: neither side is "more" specific than the other
    assert _is_more_specific(db_schema, db_catalog) is False
    assert _is_more_specific(db_catalog, db_schema) is False
|
||||
|
||||
|
||||
# Tests for check_table_access
|
||||
def test_check_table_access_no_rules():
    """An empty rule list yields NO_RULE with no RLS predicates or CLS rules."""
    outcome = check_table_access(
        "mydb",
        Table(table="users", schema="public", catalog=None),
        rules=[],
    )

    assert outcome.access == AccessCheckResult.NO_RULE
    assert outcome.rls_predicates == []
    assert outcome.cls_rules == {}
|
||||
|
||||
|
||||
def test_check_table_access_allowed():
    """A schema-level allow entry grants access to a table in that schema."""
    policy = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])

    assert outcome.access == AccessCheckResult.ALLOWED
    # The entry carries no "rls"/"cls" keys, so nothing should be collected.
    assert outcome.rls_predicates == []
    assert outcome.cls_rules == {}
|
||||
|
||||
|
||||
def test_check_table_access_denied():
    """A schema-level deny entry blocks a table in that schema."""
    policy = {
        "allowed": [],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])

    assert outcome.access == AccessCheckResult.DENIED
|
||||
|
||||
|
||||
def test_check_table_access_denied_more_specific():
    """A schema-level deny overrides a database-level allow for that schema."""
    policy = {
        "allowed": [{"database": "mydb"}],  # Less specific
        "denied": [{"database": "mydb", "schema": "secret"}],  # More specific
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # First a table outside the denied schema, then one inside it.
    scenarios = (
        ("users", "public", AccessCheckResult.ALLOWED),
        ("data", "secret", AccessCheckResult.DENIED),
    )
    for table_name, schema_name, expected in scenarios:
        target = Table(table=table_name, schema=schema_name, catalog=None)
        outcome = check_table_access("mydb", target, rules=[access_rule])
        assert outcome.access == expected
|
||||
|
||||
|
||||
def test_check_table_access_allowed_more_specific():
    """A table-level allow overrides a schema-level deny for that one table."""
    policy = {
        "allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
        "denied": [{"database": "mydb", "schema": "public"}],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # "users" is explicitly allowed; "orders" only falls under the schema deny.
    scenarios = (
        ("users", AccessCheckResult.ALLOWED),
        ("orders", AccessCheckResult.DENIED),
    )
    for table_name, expected in scenarios:
        target = Table(table=table_name, schema="public", catalog=None)
        outcome = check_table_access("mydb", target, rules=[access_rule])
        assert outcome.access == expected
|
||||
|
||||
|
||||
def test_check_table_access_same_specificity_deny_wins():
    """When allow and deny entries are identical, the result must be DENIED."""
    schema_entry = {"database": "mydb", "schema": "public"}
    policy = {
        # Separate copies so the two lists do not share one dict object.
        "allowed": [dict(schema_entry)],
        "denied": [dict(schema_entry)],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])

    assert outcome.access == AccessCheckResult.DENIED
|
||||
|
||||
|
||||
def test_check_table_access_with_rls():
    """An RLS clause on a matching allow entry is returned with its group key."""
    policy = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "rls": {"predicate": "org_id = 123", "group_key": "org"},
            }
        ],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])

    assert outcome.access == AccessCheckResult.ALLOWED
    assert len(outcome.rls_predicates) == 1

    sole_predicate = outcome.rls_predicates[0]
    assert sole_predicate.predicate == "org_id = 123"
    assert sole_predicate.group_key == "org"
|
||||
|
||||
|
||||
def test_check_table_access_multiple_rls():
    """RLS predicates from two separate matching rules are both collected."""
    policies = (
        {
            "allowed": [
                {
                    "database": "mydb",
                    "schema": "public",
                    "rls": {"predicate": "org_id = 123", "group_key": "org"},
                }
            ],
            "denied": [],
        },
        {
            "allowed": [
                {
                    "database": "mydb",
                    "schema": "public",
                    "rls": {"predicate": "region = 'US'"},
                }
            ],
            "denied": [],
        },
    )
    # One mocked rule per policy, in the same order as declared above.
    access_rules = []
    for policy in policies:
        mocked_rule = MagicMock(spec=DataAccessRule)
        mocked_rule.rule_dict = policy
        access_rules.append(mocked_rule)

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=access_rules)

    assert outcome.access == AccessCheckResult.ALLOWED
    assert len(outcome.rls_predicates) == 2

    collected = [item.predicate for item in outcome.rls_predicates]
    assert "org_id = 123" in collected
    assert "region = 'US'" in collected
|
||||
|
||||
|
||||
def test_check_table_access_with_cls():
    """String CLS actions on a matching entry come back as CLSAction members."""
    policy = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide", "name": "hash"},
            }
        ],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=[access_rule])

    expected_cls = {
        "email": CLSAction.MASK,
        "ssn": CLSAction.HIDE,
        "name": CLSAction.HASH,
    }
    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.cls_rules == expected_cls
|
||||
|
||||
|
||||
def test_check_table_access_cls_strictest_wins():
    """If two rules give one column different CLS actions, HIDE beats MASK."""
    # Same schema-level entry in both rules; only the action for "email" differs.
    actions_per_rule = ("mask", "hide")  # less strict first, stricter second
    access_rules = []
    for action in actions_per_rule:
        mocked_rule = MagicMock(spec=DataAccessRule)
        mocked_rule.rule_dict = {
            "allowed": [
                {
                    "database": "mydb",
                    "schema": "public",
                    "cls": {"email": action},
                }
            ],
            "denied": [],
        }
        access_rules.append(mocked_rule)

    target = Table(table="users", schema="public", catalog=None)
    outcome = check_table_access("mydb", target, rules=access_rules)

    assert outcome.access == AccessCheckResult.ALLOWED
    assert outcome.cls_rules["email"] == CLSAction.HIDE
|
||||
|
||||
|
||||
# Tests for get_rls_predicates_for_table
|
||||
def test_get_rls_predicates_for_table_no_predicates():
    """A matching rule without an "rls" key produces an empty predicate list."""
    policy = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # Only the database name is configured on the stand-in database object.
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    found = get_rls_predicates_for_table(target, fake_database, rules=[access_rule])
    assert found == []
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_with_predicates():
    """A single RLS predicate is returned wrapped in parentheses."""
    policy = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "rls": {"predicate": "org_id = 123"},
            }
        ],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # Only the database name is configured on the stand-in database object.
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    found = get_rls_predicates_for_table(target, fake_database, rules=[access_rule])
    assert found == ["(org_id = 123)"]
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_with_group_key():
    """Predicates sharing a group_key are merged into one OR-joined clause."""
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    # Two rules that differ only in the predicate; both use group_key "org".
    access_rules = []
    for clause in ("org_id = 1", "org_id = 2"):
        mocked_rule = MagicMock(spec=DataAccessRule)
        mocked_rule.rule_dict = {
            "allowed": [
                {
                    "database": "mydb",
                    "schema": "public",
                    "rls": {"predicate": clause, "group_key": "org"},
                }
            ],
            "denied": [],
        }
        access_rules.append(mocked_rule)

    found = get_rls_predicates_for_table(target, fake_database, rules=access_rules)

    # Same group_key predicates should be ORed
    assert len(found) == 1
    merged = found[0]
    assert "(org_id = 1)" in merged
    assert "(org_id = 2)" in merged
    assert " OR " in merged
|
||||
|
||||
|
||||
def test_get_rls_predicates_for_table_mixed_group_keys():
    """Grouped predicates collapse to one item; an ungrouped one stays separate."""
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    grouped_first = {
        "database": "mydb",
        "schema": "public",
        "rls": {"predicate": "org_id = 1", "group_key": "org"},
    }
    grouped_second = {
        "database": "mydb",
        "schema": "public",
        "rls": {"predicate": "org_id = 2", "group_key": "org"},
    }
    ungrouped = {
        "database": "mydb",
        "schema": "public",
        "rls": {"predicate": "region = 'US'"},  # No group_key
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = {
        "allowed": [grouped_first, grouped_second, ungrouped],
        "denied": [],
    }

    found = get_rls_predicates_for_table(target, fake_database, rules=[access_rule])

    # Should have: ungrouped predicate + ORed group predicate = 2 items
    assert len(found) == 2
    assert any("region = 'US'" in clause for clause in found)
    assert any("org_id = 1" in clause and "org_id = 2" in clause for clause in found)
|
||||
|
||||
|
||||
# Tests for get_cls_rules_for_table
|
||||
def test_get_cls_rules_for_table_no_rules():
    """A matching rule without a "cls" key produces an empty CLS mapping."""
    policy = {
        "allowed": [{"database": "mydb", "schema": "public"}],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # Only the database name is configured on the stand-in database object.
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    column_rules = get_cls_rules_for_table(target, fake_database, rules=[access_rule])
    assert column_rules == {}
|
||||
|
||||
|
||||
def test_get_cls_rules_for_table_with_rules():
    """CLS actions declared on a table-level entry are returned per column."""
    policy = {
        "allowed": [
            {
                "database": "mydb",
                "schema": "public",
                "table": "users",
                "cls": {"email": "mask", "ssn": "hide"},
            }
        ],
        "denied": [],
    }
    access_rule = MagicMock(spec=DataAccessRule)
    access_rule.rule_dict = policy

    # Only the database name is configured on the stand-in database object.
    fake_database = MagicMock()
    fake_database.database_name = "mydb"
    target = Table(table="users", schema="public", catalog=None)

    column_rules = get_cls_rules_for_table(target, fake_database, rules=[access_rule])

    expected = {"email": CLSAction.MASK, "ssn": CLSAction.HIDE}
    assert column_rules == expected
|
||||
|
||||
|
||||
# Tests for get_all_group_keys
|
||||
def test_get_all_group_keys_empty(app_context: None):
    """With no stored rules, the set of group keys is empty."""
    with patch("superset.data_access_rules.utils.db") as db_stub:
        # Stub the session.query(...).all() chain on the patched db to return no rules.
        query_stub = db_stub.session.query.return_value
        query_stub.all.return_value = []

        group_keys = get_all_group_keys()
        assert group_keys == set()
|
||||
|
||||
|
||||
def test_get_all_group_keys_with_keys(app_context: None):
    """Group keys are gathered across rules, de-duplicated, skipping keyless RLS."""
    policies = (
        {
            "allowed": [
                {"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}},
                {"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}},
            ],
            "denied": [],
        },
        {
            "allowed": [
                {"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}},
                {"database": "mydb", "rls": {"predicate": "x=4"}},  # No group_key
            ],
            "denied": [],
        },
    )
    stored_rules = []
    for policy in policies:
        mocked_rule = MagicMock(spec=DataAccessRule)
        mocked_rule.rule_dict = policy
        stored_rules.append(mocked_rule)

    with patch("superset.data_access_rules.utils.db") as db_stub:
        # Stub the session.query(...).all() chain to return both mocked rules.
        query_stub = db_stub.session.query.return_value
        query_stub.all.return_value = stored_rules

        group_keys = get_all_group_keys()
        assert group_keys == {"key1", "key2"}
|
||||
|
||||
|
||||
def test_get_all_group_keys_filtered_by_database(app_context: None):
    """Passing database_name keeps only keys from entries for that database."""
    policy = {
        "allowed": [
            {"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}},
            {"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}},
        ],
        "denied": [],
    }
    stored_rule = MagicMock(spec=DataAccessRule)
    stored_rule.rule_dict = policy

    with patch("superset.data_access_rules.utils.db") as db_stub:
        # Stub the session.query(...).all() chain to return the single mocked rule.
        query_stub = db_stub.session.query.return_value
        query_stub.all.return_value = [stored_rule]

        group_keys = get_all_group_keys(database_name="db1")
        assert group_keys == {"key1"}
|
||||
|
||||
|
||||
def test_get_all_group_keys_filtered_by_table(app_context: None):
    """Passing a table keeps only keys from entries that apply to that table."""
    users_entry = {
        "database": "db1",
        "schema": "public",
        "table": "users",
        "rls": {"predicate": "x=1", "group_key": "key1"},
    }
    orders_entry = {
        "database": "db1",
        "schema": "public",
        "table": "orders",
        "rls": {"predicate": "x=2", "group_key": "key2"},
    }
    stored_rule = MagicMock(spec=DataAccessRule)
    stored_rule.rule_dict = {
        "allowed": [users_entry, orders_entry],
        "denied": [],
    }

    with patch("superset.data_access_rules.utils.db") as db_stub:
        # Stub the session.query(...).all() chain to return the single mocked rule.
        query_stub = db_stub.session.query.return_value
        query_stub.all.return_value = [stored_rule]

        target = Table(table="users", schema="public", catalog=None)
        group_keys = get_all_group_keys(database_name="db1", table=target)
        assert group_keys == {"key1"}
|
||||
Reference in New Issue
Block a user