mirror of
https://github.com/apache/superset.git
synced 2026-05-11 19:05:24 +00:00
feat: column-level security
This commit is contained in:
@@ -17,6 +17,8 @@
|
||||
# pylint: disable=invalid-name, redefined-outer-name, too-many-lines
|
||||
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlglot import Dialects, exp, parse_one
|
||||
@@ -24,6 +26,10 @@ from sqlglot import Dialects, exp, parse_one
|
||||
from superset.exceptions import QueryClauseValidationException, SupersetParseError
|
||||
from superset.jinja_context import JinjaTemplateProcessor
|
||||
from superset.sql.parse import (
|
||||
apply_cls,
|
||||
CLS_HASH_FUNCTIONS,
|
||||
CLSAction,
|
||||
CLSTransformer,
|
||||
CTASMethod,
|
||||
extract_tables_from_statement,
|
||||
JinjaSQLResult,
|
||||
@@ -2977,3 +2983,783 @@ def test_has_subquery(sql: str, engine: str, expected: bool) -> None:
|
||||
Test the `has_subquery` method.
|
||||
"""
|
||||
assert SQLStatement(sql, engine).has_subquery() == expected
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Column-Level Security (CLS) Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_cls_action_enum() -> None:
    """
    Check that the four expected members are defined on the CLSAction enum.
    """
    # A missing member raises AttributeError, exactly like direct access would.
    for member_name in ("HASH", "NULLIFY", "HIDE", "MASK"):
        assert getattr(CLSAction, member_name) is not None
|
||||
|
||||
|
||||
def test_cls_hash_functions_mapping() -> None:
    """
    Check the fallback entry, the common dialect entries and the placeholder
    convention of CLS_HASH_FUNCTIONS.
    """
    fallback_pattern = "'[HASHED]'"

    # The fallback entry is keyed by None.
    assert None in CLS_HASH_FUNCTIONS
    assert CLS_HASH_FUNCTIONS[None] == fallback_pattern

    # Commonly used dialects must each have an entry.
    common_dialects = (
        Dialects.POSTGRES,
        Dialects.MYSQL,
        Dialects.BIGQUERY,
        Dialects.SNOWFLAKE,
    )
    for common in common_dialects:
        assert common in CLS_HASH_FUNCTIONS

    # Every pattern other than the fallback must carry a "{}" placeholder.
    for dialect, pattern in CLS_HASH_FUNCTIONS.items():
        if dialect is None or pattern == fallback_pattern:
            continue
        assert "{}" in pattern, f"Missing placeholder in {dialect} hash pattern"
|
||||
|
||||
|
||||
def test_apply_cls_empty_rules() -> None:
    """
    With an empty rule mapping, apply_cls must return the input SQL as-is.
    """
    query = "SELECT id, name FROM users"

    assert apply_cls(query, {}, engine="postgresql") == query
|
||||
|
||||
|
||||
def test_apply_cls_hash_action() -> None:
    """
    A column under CLSAction.HASH is rewritten to an MD5 expression on PostgreSQL.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT ssn, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_nullify_action() -> None:
    """
    A column under CLSAction.NULLIFY is replaced by NULL, keeping its name as alias.
    """
    cls_rules = {"users": {"salary": CLSAction.NULLIFY}}
    query = "SELECT salary, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            " NULL AS salary,",
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_hide_action() -> None:
    """
    A column under CLSAction.HIDE is dropped from the SELECT list.
    """
    cls_rules = {"users": {"password": CLSAction.HIDE}}
    query = "SELECT password, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_mask_action() -> None:
    """
    A column under CLSAction.MASK is replaced by the literal '****'.
    """
    cls_rules = {"users": {"phone": CLSAction.MASK}}
    query = "SELECT phone, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            " '****' AS phone,",
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_all_actions() -> None:
    """
    Exercise HASH, NULLIFY, HIDE and MASK together on a single SELECT.
    """
    cls_rules = {
        "employees": {
            "ssn": CLSAction.HASH,
            "salary": CLSAction.NULLIFY,
            "password": CLSAction.HIDE,
            "phone": CLSAction.MASK,
        }
    }
    query = "SELECT ssn, salary, password, phone, name FROM employees"
    # "password" is absent from the expected projection because it is hidden.
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,',
            " NULL AS salary,",
            " '****' AS phone,",
            ' "employees"."name" AS "name"',
            'FROM "employees" AS "employees"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_qualified_columns() -> None:
    """
    Rules apply when the query already spells columns as ``table.column``.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT users.ssn, users.name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_table_alias() -> None:
    """
    Rules keyed by the real table name apply to columns reached through an alias.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT u.ssn, u.name FROM users u"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("u"."ssn" AS TEXT)) AS ssn,',
            ' "u"."name" AS "name"',
            'FROM "users" AS "u"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_join() -> None:
    """
    Each side of a JOIN is rewritten according to the rules of its own table.
    """
    cls_rules = {
        "employees": {"ssn": CLSAction.HASH},
        "salaries": {"amount": CLSAction.NULLIFY},
    }
    query = (
        "\n"
        "SELECT e.ssn, e.name, s.amount\n"
        "FROM employees e\n"
        "JOIN salaries s\n"
        "ON e.id = s.employee_id\n"
    )
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("e"."ssn" AS TEXT)) AS ssn,',
            ' "e"."name" AS "name",',
            " NULL AS amount",
            'FROM "employees" AS "e"',
            'JOIN "salaries" AS "s"',
            ' ON "e"."id" = "s"."employee_id"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_case_insensitive() -> None:
    """
    Upper-case table/column names in the rules still match a lower-case query.
    """
    cls_rules = {"USERS": {"SSN": CLSAction.HASH}}
    query = "SELECT ssn, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_with_column_alias() -> None:
    """
    An explicit ``AS`` alias on a protected column survives the rewrite.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT ssn AS social_security, name FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS social_security,',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_no_matching_table() -> None:
    """
    Rules for a table that the query never touches leave its columns alone.
    """
    cls_rules = {"other_table": {"ssn": CLSAction.HASH}}
    query = "SELECT ssn, name FROM users"
    # No rule targets "users": the columns are only qualified, not transformed.
    expected = "\n".join(
        [
            "SELECT",
            ' "users"."ssn" AS "ssn",',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_non_column_expressions() -> None:
    """
    Literals and aggregate calls in the projection are not touched by CLS.
    """
    cls_rules = {"users": {"name": CLSAction.HASH}}
    query = "SELECT 1 AS one, 'test' AS str, COUNT(*) AS cnt FROM users"
    expected = "\n".join(
        [
            "SELECT",
            ' 1 AS "one",',
            " 'test' AS \"str\",",
            ' COUNT(*) AS "cnt"',
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_with_schema() -> None:
    """
    Pass a schema to apply_cls for a JOIN whose columns are written unqualified.

    Only the presence of "MD5" and "NULL" in the output is checked, not the
    full statement.
    """
    cls_rules = {
        "employees": {"ssn": CLSAction.HASH},
        "departments": {"budget": CLSAction.NULLIFY},
    }
    table_schema = {
        "employees": {
            "id": "INT",
            "ssn": "VARCHAR",
            "name": "VARCHAR",
            "dept_id": "INT",
        },
        "departments": {"id": "INT", "name": "VARCHAR", "budget": "DECIMAL"},
    }
    query = (
        "\n"
        "SELECT\n"
        "ssn, name, budget\n"
        "FROM employees e\n"
        "JOIN departments d\n"
        "ON e.dept_id = d.id\n"
    )

    rewritten = apply_cls(query, cls_rules, engine="postgresql", schema=table_schema)

    for marker in ("MD5", "NULL"):
        assert marker in rewritten
|
||||
|
||||
|
||||
def test_apply_cls_different_dialects() -> None:
    """
    The HASH rewrite must produce the expected SQL for each tested engine.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT ssn FROM users"

    # Insertion order matters: engines are checked PostgreSQL, MySQL, BigQuery.
    expected_by_engine = {
        "postgresql": "\n".join(
            [
                "SELECT",
                ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn',
                'FROM "users" AS "users"',
            ]
        ),
        "mysql": "\n".join(
            [
                "SELECT",
                " MD5(CAST(`users`.`ssn` AS CHAR)) AS ssn",
                "FROM `users` AS `users`",
            ]
        ),
        "bigquery": "\n".join(
            [
                "SELECT",
                " TO_HEX(MD5(CAST(`users`.`ssn` AS STRING))) AS ssn",
                "FROM `users` AS `users`",
            ]
        ),
    }

    for engine_name, expected in expected_by_engine.items():
        assert apply_cls(query, cls_rules, engine=engine_name) == expected
|
||||
|
||||
|
||||
def test_apply_cls_unknown_dialect_fallback() -> None:
    """
    An unrecognised engine name yields the constant '[HASHED]' for HASH columns.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT users.ssn FROM users"
    expected = "\n".join(
        [
            "SELECT",
            " '[HASHED]' AS ssn",
            'FROM "users" AS "users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="unknown_database") == expected
|
||||
|
||||
|
||||
def test_apply_cls_select_star_warning(caplog: pytest.LogCaptureFixture) -> None:
    """
    Test CLS logs a warning for SELECT * queries and keeps the star.

    :param caplog: pytest fixture capturing log records emitted during the call.
    """
    # Local import, hoisted to the top of the body instead of sitting between
    # statements.
    import logging

    rules = {"users": {"ssn": CLSAction.HASH}}
    sql = "SELECT * FROM users"

    with caplog.at_level(logging.WARNING):
        result = apply_cls(sql, rules, engine="postgresql")

    # The former second alternative, "CLS cannot fully process SELECT *",
    # already contains "SELECT *", so it could never change the outcome.
    assert "SELECT *" in caplog.text
    assert "*" in result  # Star should be preserved
|
||||
|
||||
|
||||
def test_sql_statement_apply_cls_method() -> None:
    """
    SQLStatement.apply_cls rewrites the statement so format() shows the hash.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    stmt = SQLStatement("SELECT ssn, name FROM users", engine="postgresql")
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            'FROM "users" AS "users"',
        ]
    )

    stmt.apply_cls(cls_rules)

    assert stmt.format() == expected
|
||||
|
||||
|
||||
def test_sql_statement_apply_cls_empty_rules() -> None:
    """
    SQLStatement.apply_cls with no rules leaves the statement as it was.
    """
    stmt = SQLStatement("SELECT ssn, name FROM users", engine="postgresql")

    stmt.apply_cls({})

    # Nothing was rewritten; the output only reflects format()'s own layout.
    expected = "\n".join(["SELECT", " ssn,", " name", "FROM users"])
    assert stmt.format() == expected
|
||||
|
||||
|
||||
def test_sql_statement_apply_cls_with_schema() -> None:
    """
    SQLStatement.apply_cls accepts a ``schema`` keyword and still hashes the column.
    """
    cls_rules = {"employees": {"ssn": CLSAction.HASH}}
    table_schema = {"employees": {"id": "INT", "ssn": "VARCHAR", "name": "VARCHAR"}}
    stmt = SQLStatement("SELECT ssn, name FROM employees", engine="postgresql")
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,',
            ' "employees"."name" AS "name"',
            'FROM "employees" AS "employees"',
        ]
    )

    stmt.apply_cls(cls_rules, schema=table_schema)

    assert stmt.format() == expected
|
||||
|
||||
|
||||
def test_cls_transformer_normalize_rules() -> None:
    """
    CLSTransformer exposes its rules with lower-cased table and column keys.
    """
    mixed_case_rules = {"USERS": {"SSN": CLSAction.HASH, "Name": CLSAction.MASK}}

    normalized = CLSTransformer(mixed_case_rules, Dialects.POSTGRES).rules

    assert "users" in normalized
    for column_name in ("ssn", "name"):
        assert column_name in normalized["users"]
|
||||
|
||||
|
||||
def test_cls_transformer_get_action() -> None:
    """
    CLSTransformer._get_action returns the configured action or None.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)

    # Lookups that must resolve to the HASH rule: exact match, then upper-case.
    for table_name, column_name in (("users", "ssn"), ("USERS", "SSN")):
        assert transformer._get_action(table_name, column_name) == CLSAction.HASH

    # Lookups with no rule: unknown column, unknown table, no table at all.
    for table_name, column_name in (("users", "name"), ("other", "ssn"), (None, "ssn")):
        assert transformer._get_action(table_name, column_name) is None
|
||||
|
||||
|
||||
def test_cls_transformer_extract_scope_tables() -> None:
    """
    CLSTransformer._extract_scope_tables maps each alias (or bare name) to its table.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    transformer = CLSTransformer(cls_rules, Dialects.POSTGRES)

    # Un-aliased table: keyed by its own name.
    scope = transformer._extract_scope_tables(parse_one("SELECT * FROM users"))
    assert "users" in scope
    assert scope["users"] == "users"

    # Aliased table: keyed by the alias.
    scope = transformer._extract_scope_tables(parse_one("SELECT * FROM users u"))
    assert "u" in scope
    assert scope["u"] == "users"

    # Two aliased tables joined together.
    joined = parse_one("SELECT * FROM users u JOIN orders o ON u.id = o.user_id")
    scope = transformer._extract_scope_tables(joined)
    assert "u" in scope
    assert "o" in scope
    assert scope["u"] == "users"
    assert scope["o"] == "orders"
|
||||
|
||||
|
||||
def test_cls_transformer_get_table_for_column_qualified() -> None:
    """
    CLSTransformer._get_table_for_column resolves a qualifier through the scope map.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    alias_map = {"u": "users", "o": "orders"}

    # A known alias resolves to the real table name.
    known = parse_one("u.ssn").find(exp.Column)
    assert transformer._get_table_for_column(known, alias_map) == "users"

    # An alias missing from the scope map comes back unchanged.
    unknown = parse_one("x.ssn").find(exp.Column)
    assert transformer._get_table_for_column(unknown, alias_map) == "x"
|
||||
|
||||
|
||||
def test_cls_transformer_get_table_for_column_single_table() -> None:
    """
    An unqualified column is attributed to the only table in scope.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    bare_column = parse_one("ssn").find(exp.Column)

    resolved = transformer._get_table_for_column(bare_column, {"users": "users"})

    assert resolved == "users"
|
||||
|
||||
|
||||
def test_cls_transformer_get_table_for_column_multi_table_rules_match() -> None:
    """
    With several tables in scope, an unqualified column resolves to the table
    whose rules mention it.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    tables_in_scope = {"users": "users", "orders": "orders"}
    bare_column = parse_one("ssn").find(exp.Column)

    # Only "users" has a rule for "ssn".
    resolved = transformer._get_table_for_column(bare_column, tables_in_scope)

    assert resolved == "users"
|
||||
|
||||
|
||||
def test_cls_transformer_get_table_for_column_no_match() -> None:
    """
    With several tables in scope and no applicable rule, the lookup returns None.
    """
    transformer = CLSTransformer({"other": {"col": CLSAction.HASH}}, Dialects.POSTGRES)
    tables_in_scope = {"users": "users", "orders": "orders"}
    bare_column = parse_one("ssn").find(exp.Column)

    # Neither table in scope has a rule for "ssn".
    resolved = transformer._get_table_for_column(bare_column, tables_in_scope)

    assert resolved is None
|
||||
|
||||
|
||||
def test_cls_transformer_get_column_alias() -> None:
    """
    CLSTransformer._get_column_alias names columns, aliases and other nodes.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    # Plain column: its own name.
    plain_column = parse_one("ssn").find(exp.Column)
    assert transformer._get_column_alias(plain_column) == "ssn"

    # Aliased expression: the alias.
    aliased = parse_one("ssn AS social").find(exp.Alias)
    assert transformer._get_column_alias(aliased) == "social"

    # Any other node (here a string literal): the quoted text as returned.
    string_literal = parse_one("'test'").find(exp.Literal)
    assert transformer._get_column_alias(string_literal) == "'test'"
|
||||
|
||||
|
||||
def test_cls_transformer_create_expressions() -> None:
    """
    The hash, null and mask builders of CLSTransformer each return an exp.Alias
    carrying the requested alias.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    # HASH builder
    source_column = parse_one("ssn").find(exp.Column)
    hashed = transformer._create_hash_expression(source_column, "ssn")
    assert isinstance(hashed, exp.Alias)
    assert hashed.alias == "ssn"

    # NULLIFY builder: the aliased node is a NULL
    nullified = transformer._create_null_expression("salary")
    assert isinstance(nullified, exp.Alias)
    assert nullified.alias == "salary"
    assert isinstance(nullified.this, exp.Null)

    # MASK builder: the aliased node is the literal "****"
    masked = transformer._create_mask_expression("phone")
    assert isinstance(masked, exp.Alias)
    assert masked.alias == "phone"
    assert isinstance(masked.this, exp.Literal)
    assert masked.this.this == "****"
|
||||
|
||||
|
||||
def test_cls_transformer_call_non_select() -> None:
    """
    Calling the transformer on a node that is not a SELECT returns an equal node.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)

    # parse_one("users") is searched for an exp.Column, which is not a SELECT.
    non_select_node = parse_one("users").find(exp.Column)

    assert transformer(non_select_node) == non_select_node
|
||||
|
||||
|
||||
def test_cls_transformer_transform_expression_non_column() -> None:
    """
    CLSTransformer._transform_expression hands back non-column nodes unchanged.
    """
    transformer = CLSTransformer({"users": {"ssn": CLSAction.HASH}}, Dialects.POSTGRES)
    tables_in_scope = {"users": "users"}

    # First a string literal, then a function call.
    for snippet in ("'test'", "COUNT(*)"):
        node = parse_one(snippet)
        assert transformer._transform_expression(node, tables_in_scope) == node
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    ("sql", "rules", "engine", "expected"),
    [
        # HASH on a qualified column
        (
            "SELECT t.id FROM t",
            {"t": {"id": CLSAction.HASH}},
            "postgresql",
            "\n".join(
                [
                    "SELECT",
                    ' MD5(CAST("t"."id" AS TEXT)) AS id',
                    'FROM "t" AS "t"',
                ]
            ),
        ),
        # NULLIFY on a qualified column
        (
            "SELECT t.salary FROM t",
            {"t": {"salary": CLSAction.NULLIFY}},
            "postgresql",
            "\n".join(["SELECT", " NULL AS salary", 'FROM "t" AS "t"']),
        ),
        # HIDE removes one of two columns
        (
            "SELECT t.secret, t.public FROM t",
            {"t": {"secret": CLSAction.HIDE}},
            "postgresql",
            "\n".join(["SELECT", ' "t"."public" AS "public"', 'FROM "t" AS "t"']),
        ),
        # MASK on a qualified column
        (
            "SELECT t.phone FROM t",
            {"t": {"phone": CLSAction.MASK}},
            "postgresql",
            "\n".join(["SELECT", " '****' AS phone", 'FROM "t" AS "t"']),
        ),
        # Two joined tables, each with its own rule
        (
            "SELECT a.ssn, b.amount FROM users a JOIN payments b ON a.id = b.user_id",
            {
                "users": {"ssn": CLSAction.HASH},
                "payments": {"amount": CLSAction.NULLIFY},
            },
            "postgresql",
            "\n".join(
                [
                    "SELECT",
                    ' MD5(CAST("a"."ssn" AS TEXT)) AS ssn,',
                    " NULL AS amount",
                    'FROM "users" AS "a"',
                    'JOIN "payments" AS "b"',
                    ' ON "a"."id" = "b"."user_id"',
                ]
            ),
        ),
        # Snowflake engine
        (
            "SELECT t.col FROM t",
            {"t": {"col": CLSAction.HASH}},
            "snowflake",
            "\n".join(
                [
                    "SELECT",
                    ' MD5(TO_CHAR("T"."COL")) AS COL',
                    'FROM "T" AS "T"',
                ]
            ),
        ),
        # ClickHouse engine
        (
            "SELECT t.col FROM t",
            {"t": {"col": CLSAction.HASH}},
            "clickhouse",
            "\n".join(
                [
                    "SELECT",
                    ' MD5(toString("t"."col")) AS col',
                    'FROM "t" AS "t"',
                ]
            ),
        ),
    ],
)
def test_apply_cls_parametrized(
    sql: str,
    rules: dict[str, Any],
    engine: str,
    expected: str,
) -> None:
    """
    Run apply_cls over a table of (sql, rules, engine) cases and compare the
    rewritten SQL with the expected text.
    """
    assert apply_cls(sql, rules, engine=engine) == expected
|
||||
|
||||
|
||||
def test_apply_cls_subquery() -> None:
    """
    Rules are applied inside a derived table; the outer star is expanded.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT * FROM (SELECT ssn, name FROM users) AS subq"
    expected = "\n".join(
        [
            "SELECT",
            ' "subq"."ssn" AS "ssn",',
            ' "subq"."name" AS "name"',
            "FROM (",
            " SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            ' FROM "users" AS "users"',
            ') AS "subq"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_cte() -> None:
    """
    Rules are applied inside a CTE body; the outer star is expanded.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "WITH cte AS (SELECT ssn, name FROM users) SELECT * FROM cte"
    expected = "\n".join(
        [
            'WITH "cte" AS (',
            " SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,',
            ' "users"."name" AS "name"',
            ' FROM "users" AS "users"',
            ")",
            "SELECT",
            ' "cte"."ssn" AS "ssn",',
            ' "cte"."name" AS "name"',
            'FROM "cte" AS "cte"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_apply_cls_union() -> None:
    """
    In a UNION, only the branch reading a table with rules is rewritten.
    """
    cls_rules = {"users": {"ssn": CLSAction.HASH}}
    query = "SELECT ssn FROM users UNION SELECT ssn FROM archived_users"
    # "archived_users" has no rule, so its branch is only qualified.
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn',
            'FROM "users" AS "users"',
            "UNION",
            "SELECT",
            ' "archived_users"."ssn" AS "ssn"',
            'FROM "archived_users" AS "archived_users"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_cls_hide_all_columns() -> None:
    """
    Hiding every selected column leaves a SELECT with an empty projection.
    """
    cls_rules = {"users": {"id": CLSAction.HIDE, "name": CLSAction.HIDE}}
    query = "SELECT id, name FROM users"

    # NOTE(review): this pins a SELECT with no projected columns as the
    # expected output; confirm that is the intended behaviour.
    expected = "\n".join(["SELECT", 'FROM "users" AS "users"'])

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_cls_transformer_extract_scope_tables_no_from() -> None:
    """
    A SELECT without a FROM clause yields an empty scope mapping.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    scope = transformer._extract_scope_tables(parse_one("SELECT 1"))

    assert scope == {}
|
||||
|
||||
|
||||
def test_cls_transformer_extract_scope_tables_no_joins() -> None:
    """
    A SELECT with a FROM clause and no JOIN yields exactly one scope entry.
    """
    transformer = CLSTransformer({}, Dialects.POSTGRES)

    scope = transformer._extract_scope_tables(parse_one("SELECT * FROM users"))

    assert "users" in scope
    assert len(scope) == 1
|
||||
|
||||
|
||||
def test_apply_cls_aliased_column_preserves_alias() -> None:
    """
    A qualified, aliased column keeps its ``AS`` name after being hashed.
    """
    cls_rules = {"t": {"col": CLSAction.HASH}}
    query = "SELECT t.col AS my_alias FROM t"
    expected = "\n".join(
        [
            "SELECT",
            ' MD5(CAST("t"."col" AS TEXT)) AS my_alias',
            'FROM "t" AS "t"',
        ]
    )

    assert apply_cls(query, cls_rules, engine="postgresql") == expected
|
||||
|
||||
|
||||
def test_cls_transformer_hash_pattern_fallback() -> None:
    """
    Constructing CLSTransformer with a None dialect selects the fallback pattern.
    """
    cls_rules = {"t": {"col": CLSAction.HASH}}

    # Passing None as the dialect is what triggers the fallback.
    pattern = CLSTransformer(cls_rules, None).hash_pattern

    assert pattern == "'[HASHED]'"
|
||||
|
||||
Reference in New Issue
Block a user