mirror of
https://github.com/apache/superset.git
synced 2026-04-30 05:24:31 +00:00
769 lines
25 KiB
Python
769 lines
25 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
"""
|
|
Unit tests for Data Access Rules utility functions.
|
|
"""
|
|
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
from superset.data_access_rules.models import DataAccessRule
|
|
from superset.data_access_rules.utils import (
|
|
_is_more_specific,
|
|
_matches_rule_entry,
|
|
AccessCheckResult,
|
|
check_table_access,
|
|
get_all_group_keys,
|
|
get_cls_rules_for_table,
|
|
get_rls_predicates_for_table,
|
|
)
|
|
from superset.sql.parse import CLSAction, Table
|
|
|
|
|
|
# Tests for _matches_rule_entry
|
|
def test_matches_rule_entry_database_only():
|
|
"""Test matching when rule specifies only database."""
|
|
entry = {"database": "mydb"}
|
|
assert _matches_rule_entry(entry, "mydb", None, None, None) is True
|
|
assert _matches_rule_entry(entry, "mydb", "catalog1", "schema1", "table1") is True
|
|
assert _matches_rule_entry(entry, "otherdb", None, None, None) is False
|
|
|
|
|
|
def test_matches_rule_entry_database_and_catalog():
|
|
"""Test matching when rule specifies database and catalog."""
|
|
entry = {"database": "mydb", "catalog": "cat1"}
|
|
assert _matches_rule_entry(entry, "mydb", "cat1", None, None) is True
|
|
assert _matches_rule_entry(entry, "mydb", "cat1", "schema1", "table1") is True
|
|
assert _matches_rule_entry(entry, "mydb", "cat2", None, None) is False
|
|
assert _matches_rule_entry(entry, "otherdb", "cat1", None, None) is False
|
|
|
|
|
|
def test_matches_rule_entry_database_and_schema():
|
|
"""Test matching when rule specifies database and schema (no catalog)."""
|
|
entry = {"database": "mydb", "schema": "public"}
|
|
assert _matches_rule_entry(entry, "mydb", None, "public", None) is True
|
|
assert _matches_rule_entry(entry, "mydb", None, "public", "table1") is True
|
|
assert _matches_rule_entry(entry, "mydb", None, "other", None) is False
|
|
|
|
|
|
def test_matches_rule_entry_full_table():
|
|
"""Test matching when rule specifies full table path."""
|
|
entry = {"database": "mydb", "schema": "public", "table": "users"}
|
|
assert _matches_rule_entry(entry, "mydb", None, "public", "users") is True
|
|
assert _matches_rule_entry(entry, "mydb", None, "public", "orders") is False
|
|
assert _matches_rule_entry(entry, "mydb", None, "other", "users") is False
|
|
|
|
|
|
def test_matches_rule_entry_with_catalog():
|
|
"""Test matching with catalog in the path."""
|
|
entry = {
|
|
"database": "mydb",
|
|
"catalog": "main",
|
|
"schema": "public",
|
|
"table": "users",
|
|
}
|
|
assert _matches_rule_entry(entry, "mydb", "main", "public", "users") is True
|
|
assert _matches_rule_entry(entry, "mydb", "other", "public", "users") is False
|
|
|
|
|
|
# Tests for _is_more_specific
|
|
def test_is_more_specific():
|
|
"""Test specificity comparison between entries."""
|
|
db_only = {"database": "mydb"}
|
|
db_schema = {"database": "mydb", "schema": "public"}
|
|
db_table = {"database": "mydb", "schema": "public", "table": "users"}
|
|
db_catalog = {"database": "mydb", "catalog": "main"}
|
|
db_catalog_schema = {"database": "mydb", "catalog": "main", "schema": "public"}
|
|
|
|
# More specific should win
|
|
assert _is_more_specific(db_schema, db_only) is True
|
|
assert _is_more_specific(db_table, db_schema) is True
|
|
assert _is_more_specific(db_table, db_only) is True
|
|
|
|
# Less specific should lose
|
|
assert _is_more_specific(db_only, db_schema) is False
|
|
assert _is_more_specific(db_schema, db_table) is False
|
|
|
|
# Same specificity
|
|
assert _is_more_specific(db_schema, db_catalog) is False
|
|
assert _is_more_specific(db_catalog, db_schema) is False
|
|
|
|
|
|
# Tests for check_table_access
|
|
def test_check_table_access_no_rules():
|
|
"""Test access check when no rules are provided."""
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[])
|
|
assert result.access == AccessCheckResult.NO_RULE
|
|
assert result.rls_predicates == []
|
|
assert result.cls_rules == {}
|
|
|
|
|
|
def test_check_table_access_allowed():
|
|
"""Test access check when table is allowed."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb", "schema": "public"}],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
assert result.rls_predicates == []
|
|
assert result.cls_rules == {}
|
|
|
|
|
|
def test_check_table_access_denied():
|
|
"""Test access check when table is denied."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [],
|
|
"denied": [{"database": "mydb", "schema": "public"}],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule])
|
|
assert result.access == AccessCheckResult.DENIED
|
|
|
|
|
|
def test_check_table_access_denied_more_specific():
|
|
"""Test that more specific deny wins over less specific allow."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb"}], # Less specific
|
|
"denied": [{"database": "mydb", "schema": "secret"}], # More specific
|
|
}
|
|
|
|
# Table in non-denied schema should be allowed
|
|
table_public = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table_public, rules=[rule])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
|
|
# Table in denied schema should be denied
|
|
table_secret = Table(table="data", schema="secret", catalog=None)
|
|
result = check_table_access("mydb", table_secret, rules=[rule])
|
|
assert result.access == AccessCheckResult.DENIED
|
|
|
|
|
|
def test_check_table_access_allowed_more_specific():
|
|
"""Test that more specific allow wins over less specific deny."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb", "schema": "public", "table": "users"}],
|
|
"denied": [{"database": "mydb", "schema": "public"}],
|
|
}
|
|
|
|
# The specific table is allowed despite schema being denied
|
|
table_users = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table_users, rules=[rule])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
|
|
# Other tables in the schema are still denied
|
|
table_orders = Table(table="orders", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table_orders, rules=[rule])
|
|
assert result.access == AccessCheckResult.DENIED
|
|
|
|
|
|
def test_check_table_access_same_specificity_deny_wins():
|
|
"""Test that deny wins when rules have same specificity."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb", "schema": "public"}],
|
|
"denied": [{"database": "mydb", "schema": "public"}],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule])
|
|
assert result.access == AccessCheckResult.DENIED
|
|
|
|
|
|
def test_check_table_access_with_rls():
|
|
"""Test access check collects RLS predicates."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"rls": {"predicate": "org_id = 123", "group_key": "org"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
assert len(result.rls_predicates) == 1
|
|
assert result.rls_predicates[0].predicate == "org_id = 123"
|
|
assert result.rls_predicates[0].group_key == "org"
|
|
|
|
|
|
def test_check_table_access_multiple_rls():
|
|
"""Test access check collects multiple RLS predicates from different rules."""
|
|
rule1 = MagicMock(spec=DataAccessRule)
|
|
rule1.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 123", "group_key": "org"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
rule2 = MagicMock(spec=DataAccessRule)
|
|
rule2.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "region = 'US'"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule1, rule2])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
assert len(result.rls_predicates) == 2
|
|
|
|
predicates = [p.predicate for p in result.rls_predicates]
|
|
assert "org_id = 123" in predicates
|
|
assert "region = 'US'" in predicates
|
|
|
|
|
|
def test_check_table_access_with_cls():
|
|
"""Test access check collects CLS rules."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"email": "mask", "ssn": "hide", "name": "hash"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
assert result.cls_rules == {
|
|
"email": CLSAction.MASK,
|
|
"ssn": CLSAction.HIDE,
|
|
"name": CLSAction.HASH,
|
|
}
|
|
|
|
|
|
def test_check_table_access_cls_strictest_wins():
|
|
"""Test that strictest CLS action wins when multiple rules apply."""
|
|
rule1 = MagicMock(spec=DataAccessRule)
|
|
rule1.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"cls": {"email": "mask"}, # Less strict
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
rule2 = MagicMock(spec=DataAccessRule)
|
|
rule2.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"cls": {"email": "hide"}, # More strict - should win
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
result = check_table_access("mydb", table, rules=[rule1, rule2])
|
|
assert result.access == AccessCheckResult.ALLOWED
|
|
assert result.cls_rules["email"] == CLSAction.HIDE
|
|
|
|
|
|
# Tests for get_rls_predicates_for_table
|
|
def test_get_rls_predicates_for_table_no_predicates():
|
|
"""Test getting RLS predicates when there are none."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb", "schema": "public"}],
|
|
"denied": [],
|
|
}
|
|
|
|
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
|
|
assert predicates == []
|
|
|
|
|
|
def test_get_rls_predicates_for_table_with_predicates():
|
|
"""Test getting RLS predicates."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 123"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
|
|
assert predicates == ["(org_id = 123)"]
|
|
|
|
|
|
def test_get_rls_predicates_for_table_with_group_key():
|
|
"""Test getting RLS predicates with group_key combines with OR."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule1 = MagicMock(spec=DataAccessRule)
|
|
rule1.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 1", "group_key": "org"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
rule2 = MagicMock(spec=DataAccessRule)
|
|
rule2.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 2", "group_key": "org"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
predicates = get_rls_predicates_for_table(table, database, rules=[rule1, rule2])
|
|
# Same group_key predicates should be ORed
|
|
assert len(predicates) == 1
|
|
assert "(org_id = 1)" in predicates[0]
|
|
assert "(org_id = 2)" in predicates[0]
|
|
assert " OR " in predicates[0]
|
|
|
|
|
|
def test_get_rls_predicates_for_table_mixed_group_keys():
|
|
"""Test getting RLS predicates with mixed group_keys."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 1", "group_key": "org"},
|
|
},
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "org_id = 2", "group_key": "org"},
|
|
},
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"rls": {"predicate": "region = 'US'"}, # No group_key
|
|
},
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
predicates = get_rls_predicates_for_table(table, database, rules=[rule])
|
|
# Should have: ungrouped predicate + ORed group predicate = 2 items
|
|
assert len(predicates) == 2
|
|
|
|
has_region = any("region = 'US'" in p for p in predicates)
|
|
has_org_group = any("org_id = 1" in p and "org_id = 2" in p for p in predicates)
|
|
assert has_region
|
|
assert has_org_group
|
|
|
|
|
|
# Tests for get_cls_rules_for_table
|
|
def test_get_cls_rules_for_table_no_rules():
|
|
"""Test getting CLS rules when there are none."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [{"database": "mydb", "schema": "public"}],
|
|
"denied": [],
|
|
}
|
|
|
|
cls_rules = get_cls_rules_for_table(table, database, rules=[rule])
|
|
assert cls_rules == {}
|
|
|
|
|
|
def test_get_cls_rules_for_table_with_rules():
|
|
"""Test getting CLS rules."""
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"email": "mask", "ssn": "hide"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
cls_rules = get_cls_rules_for_table(table, database, rules=[rule])
|
|
assert cls_rules == {"email": CLSAction.MASK, "ssn": CLSAction.HIDE}
|
|
|
|
|
|
# Tests for get_all_group_keys
|
|
def test_get_all_group_keys_empty(app_context: None):
|
|
"""Test getting group keys when none exist."""
|
|
with patch("superset.data_access_rules.utils.db") as mock_db:
|
|
mock_db.session.query.return_value.all.return_value = []
|
|
keys = get_all_group_keys()
|
|
assert keys == set()
|
|
|
|
|
|
def test_get_all_group_keys_with_keys(app_context: None):
|
|
"""Test getting group keys from rules."""
|
|
rule1 = MagicMock(spec=DataAccessRule)
|
|
rule1.rule_dict = {
|
|
"allowed": [
|
|
{"database": "mydb", "rls": {"predicate": "x=1", "group_key": "key1"}},
|
|
{"database": "mydb", "rls": {"predicate": "x=2", "group_key": "key2"}},
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
rule2 = MagicMock(spec=DataAccessRule)
|
|
rule2.rule_dict = {
|
|
"allowed": [
|
|
{"database": "mydb", "rls": {"predicate": "x=3", "group_key": "key1"}},
|
|
{"database": "mydb", "rls": {"predicate": "x=4"}}, # No group_key
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
with patch("superset.data_access_rules.utils.db") as mock_db:
|
|
mock_db.session.query.return_value.all.return_value = [rule1, rule2]
|
|
keys = get_all_group_keys()
|
|
assert keys == {"key1", "key2"}
|
|
|
|
|
|
def test_get_all_group_keys_filtered_by_database(app_context: None):
|
|
"""Test getting group keys filtered by database."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{"database": "db1", "rls": {"predicate": "x=1", "group_key": "key1"}},
|
|
{"database": "db2", "rls": {"predicate": "x=2", "group_key": "key2"}},
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
with patch("superset.data_access_rules.utils.db") as mock_db:
|
|
mock_db.session.query.return_value.all.return_value = [rule]
|
|
keys = get_all_group_keys(database_name="db1")
|
|
assert keys == {"key1"}
|
|
|
|
|
|
def test_get_all_group_keys_filtered_by_table(app_context: None):
|
|
"""Test getting group keys filtered by table."""
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "db1",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"rls": {"predicate": "x=1", "group_key": "key1"},
|
|
},
|
|
{
|
|
"database": "db1",
|
|
"schema": "public",
|
|
"table": "orders",
|
|
"rls": {"predicate": "x=2", "group_key": "key2"},
|
|
},
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
with patch("superset.data_access_rules.utils.db") as mock_db:
|
|
mock_db.session.query.return_value.all.return_value = [rule]
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
keys = get_all_group_keys(database_name="db1", table=table)
|
|
assert keys == {"key1"}
|
|
|
|
|
|
# Tests for get_hidden_columns_for_table
|
|
def test_get_hidden_columns_for_table_no_hidden(app_context: None):
|
|
"""Test getting hidden columns when no columns are hidden."""
|
|
from superset.data_access_rules.utils import get_hidden_columns_for_table
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"email": "mask", "phone": "hash"}, # No "hide" actions
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
|
|
assert hidden == set()
|
|
|
|
|
|
def test_get_hidden_columns_for_table_with_hidden(app_context: None):
|
|
"""Test getting hidden columns when some columns are hidden."""
|
|
from superset.data_access_rules.utils import get_hidden_columns_for_table
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"email": "mask", "ssn": "hide", "password": "hide"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
|
|
assert hidden == {"ssn", "password"}
|
|
|
|
|
|
def test_get_hidden_columns_for_table_denied_access(app_context: None):
|
|
"""Test that denied access returns no hidden columns."""
|
|
from superset.data_access_rules.utils import get_hidden_columns_for_table
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [],
|
|
"denied": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
}
|
|
],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
hidden = get_hidden_columns_for_table(table, database, rules=[rule])
|
|
# Denied access means no CLS rules are returned
|
|
assert hidden == set()
|
|
|
|
|
|
# Tests for filter_columns_by_cls
|
|
def test_filter_columns_by_cls_no_hidden(app_context: None):
|
|
"""Test filtering columns when no columns are hidden."""
|
|
from superset.data_access_rules.utils import filter_columns_by_cls
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
columns = [
|
|
{"column_name": "id", "type": "INTEGER"},
|
|
{"column_name": "name", "type": "VARCHAR"},
|
|
{"column_name": "email", "type": "VARCHAR"},
|
|
]
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{"database": "mydb", "schema": "public", "table": "users"}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
with patch(
|
|
"superset.data_access_rules.utils.is_feature_enabled",
|
|
return_value=True,
|
|
):
|
|
with patch(
|
|
"superset.data_access_rules.utils.get_user_rules",
|
|
return_value=[rule],
|
|
):
|
|
filtered = filter_columns_by_cls(columns, table, database)
|
|
assert len(filtered) == 3
|
|
assert filtered == columns
|
|
|
|
|
|
def test_filter_columns_by_cls_with_hidden(app_context: None):
|
|
"""Test filtering columns when some columns are hidden."""
|
|
from superset.data_access_rules.utils import filter_columns_by_cls
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
columns = [
|
|
{"column_name": "id", "type": "INTEGER"},
|
|
{"column_name": "name", "type": "VARCHAR"},
|
|
{"column_name": "email", "type": "VARCHAR"},
|
|
{"column_name": "ssn", "type": "VARCHAR"},
|
|
]
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"ssn": "hide"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
with patch(
|
|
"superset.data_access_rules.utils.is_feature_enabled",
|
|
return_value=True,
|
|
):
|
|
with patch(
|
|
"superset.data_access_rules.utils.get_user_rules",
|
|
return_value=[rule],
|
|
):
|
|
filtered = filter_columns_by_cls(columns, table, database)
|
|
assert len(filtered) == 3
|
|
column_names = [c["column_name"] for c in filtered]
|
|
assert "ssn" not in column_names
|
|
assert "id" in column_names
|
|
assert "name" in column_names
|
|
assert "email" in column_names
|
|
|
|
|
|
def test_filter_columns_by_cls_feature_disabled(app_context: None):
|
|
"""Test that filtering is skipped when feature flag is disabled."""
|
|
from superset.data_access_rules.utils import filter_columns_by_cls
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
columns = [
|
|
{"column_name": "id", "type": "INTEGER"},
|
|
{"column_name": "ssn", "type": "VARCHAR"},
|
|
]
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
with patch(
|
|
"superset.data_access_rules.utils.is_feature_enabled",
|
|
return_value=False,
|
|
):
|
|
# Even if there would be hidden columns, they are not filtered
|
|
filtered = filter_columns_by_cls(columns, table, database)
|
|
assert len(filtered) == 2
|
|
assert filtered == columns
|
|
|
|
|
|
def test_filter_columns_by_cls_custom_key(app_context: None):
|
|
"""Test filtering columns with custom column name key."""
|
|
from superset.data_access_rules.utils import filter_columns_by_cls
|
|
|
|
database = MagicMock()
|
|
database.database_name = "mydb"
|
|
|
|
# Columns with different key structure (like from SQL Lab table metadata)
|
|
columns = [
|
|
{"name": "id", "type": "INTEGER"},
|
|
{"name": "ssn", "type": "VARCHAR"},
|
|
]
|
|
|
|
rule = MagicMock(spec=DataAccessRule)
|
|
rule.rule_dict = {
|
|
"allowed": [
|
|
{
|
|
"database": "mydb",
|
|
"schema": "public",
|
|
"table": "users",
|
|
"cls": {"ssn": "hide"},
|
|
}
|
|
],
|
|
"denied": [],
|
|
}
|
|
|
|
table = Table(table="users", schema="public", catalog=None)
|
|
|
|
with patch(
|
|
"superset.data_access_rules.utils.is_feature_enabled",
|
|
return_value=True,
|
|
):
|
|
with patch(
|
|
"superset.data_access_rules.utils.get_user_rules",
|
|
return_value=[rule],
|
|
):
|
|
filtered = filter_columns_by_cls(
|
|
columns, table, database, column_name_key="name"
|
|
)
|
|
assert len(filtered) == 1
|
|
assert filtered[0]["name"] == "id"
|