diff --git a/superset/sql/parse.py b/superset/sql/parse.py index f8b50f2d3b9..3d4f55010b9 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -142,8 +142,97 @@ class CLSAction(enum.Enum): MASK = enum.auto() # Replace with '****' -# Type alias for CLS rules: {table_name: {column_name: action}} -CLSRules = dict[str, dict[str, CLSAction]] +@dataclass(eq=True, frozen=True) +class Table: + """ + A fully qualified SQL table conforming to [[catalog.]schema.]table. + """ + + table: str + schema: str | None = None + catalog: str | None = None + + def __str__(self) -> str: + """ + Return the fully qualified SQL table name. + + Should not be used for SQL generation, only for logging and debugging, since the + quoting is not engine-specific. + """ + return ".".join( + urllib.parse.quote(part, safe="").replace(".", "%2E") + for part in [self.catalog, self.schema, self.table] + if part + ) + + def __eq__(self, other: Any) -> bool: + return str(self) == str(other) + + def qualify( + self, + *, + catalog: str | None = None, + schema: str | None = None, + ) -> Table: + """ + Return a new Table with the given schema and/or catalog, if not already set. + """ + return Table( + table=self.table, + schema=self.schema or schema, + catalog=self.catalog or catalog, + ) + + +# Type alias for CLS rules: {Table: {column_name: action}} +CLSRules = dict[Table, dict[str, CLSAction]] + +# CLS action precedence: higher value = stricter (less information revealed) +# HIDE > NULLIFY > MASK > HASH +CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = { + CLSAction.HASH: 1, + CLSAction.MASK: 2, + CLSAction.NULLIFY: 3, + CLSAction.HIDE: 4, +} + + +def merge_cls_rules(*rules_list: CLSRules) -> CLSRules: + """ + Merge multiple CLS rule sets into one, using the stricter action when conflicts occur. + + When multiple rules specify actions for the same table/column, the stricter action + is kept. Precedence (strictest to least strict): HIDE > NULLIFY > MASK > HASH + + Args: + *rules_list: Variable number of CLSRules dicts to merge + + Returns: + A merged CLSRules dict with the strictest action for each table/column + + Example: + >>> rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + >>> rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}} + >>> merge_cls_rules(rules1, rules2) + {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}} + """ + merged: CLSRules = {} + + for rules in rules_list: + for table, columns in rules.items(): + if table not in merged: + merged[table] = {} + + for column, action in columns.items(): + existing_action = merged[table].get(column) + if existing_action is None: + merged[table][column] = action + else: + # Keep the stricter action (higher precedence value) + if CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[existing_action]: + merged[table][column] = action + + return merged # Hash function patterns by dialect. The placeholder {} will be replaced with the @@ -333,16 +422,21 @@ class CLSTransformer: """ AST transformer to apply Column-Level Security rules. - This transformer modifies SELECT expressions to apply CLS actions: + This transformer modifies SELECT expressions and predicates to apply CLS actions: - HASH: Replace column with hash function (database-specific) - - NULLIFY: Replace column with NULL AS column_name - - HIDE: Remove column from SELECT entirely - - MASK: Replace column with '****' AS column_name + - NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates) + - HIDE: Remove column from SELECT entirely, FALSE in predicates + - MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates) Example: - Given rules: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} - Query: SELECT id, salary, name FROM my_table + Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + Query: SELECT id, salary, name FROM my_table WHERE id = 1 Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table + WHERE MD5(CAST(id AS TEXT)) = 1 + + For predicates, HASH transforms the column to ensure filtered results also respect + the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information + leakage through filtering. """ def __init__( @@ -354,12 +448,16 @@ class CLSTransformer: self.dialect = dialect self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None]) - def _normalize_rules(self, rules: CLSRules) -> CLSRules: + def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]: """ Normalize table and column names to lowercase for case-insensitive matching. """ return { - table.lower(): {col.lower(): action for col, action in cols.items()} + Table( + table=table.table.lower(), + schema=table.schema.lower() if table.schema else None, + catalog=table.catalog.lower() if table.catalog else None, + ): {col.lower(): action for col, action in cols.items()} for table, cols in rules.items() } @@ -367,18 +465,41 @@ class CLSTransformer: self, table_name: str | None, column_name: str, + schema: str | None = None, + catalog: str | None = None, ) -> CLSAction | None: """ Get the CLS action for a column, if any. + + Matching logic: + 1. First try exact match with schema/catalog if provided + 2. Fallback to table name match - if table names match, apply the rule + regardless of schema/catalog (since query may not have schema info) """ if not table_name: return None - table_rules = self.rules.get(table_name.lower()) - if not table_rules: - return None + # Create a normalized Table for lookup + lookup_table = Table( + table=table_name.lower(), + schema=schema.lower() if schema else None, + catalog=catalog.lower() if catalog else None, + ) - return table_rules.get(column_name.lower()) + # First try exact match with schema/catalog + table_rules = self.rules.get(lookup_table) + if table_rules: + return table_rules.get(column_name.lower()) + + # Fallback: match by table name only + # This handles cases where the rule has schema/catalog but the query doesn't + for rule_table, cols in self.rules.items(): + if rule_table.table == lookup_table.table: + action = cols.get(column_name.lower()) + if action: + return action + + return None def _create_hash_expression( self, @@ -406,15 +527,44 @@ class CLSTransformer: alias=exp.Identifier(this=alias), ) - def _create_mask_expression(self, alias: str) -> exp.Expression: + def _create_mask_expression( + self, + column: exp.Column, + alias: str, + ) -> exp.Expression: """ - Create a '****' AS alias expression. + Create a CASE expression that masks non-NULL values while preserving NULLs. + + Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias + + This preserves the semantic meaning of NULL (no value) vs masked (hidden value). """ return exp.Alias( - this=exp.Literal(this="****", is_string=True), + this=exp.Case( + ifs=[ + exp.If( + this=exp.Is(this=column.copy(), expression=exp.Null()), + true=exp.Null(), + ) + ], + default=exp.Literal(this="****", is_string=True), + ), alias=exp.Identifier(this=alias), ) + def _create_hash_expression_no_alias( + self, + column: exp.Column, + ) -> exp.Expression: + """ + Create a hash expression for a column without an alias. + + Used for transforming columns in predicates (WHERE, ON, etc.). + """ + col_sql = column.sql(dialect=self.dialect) + hash_sql = self.hash_pattern.format(col_sql) + return sqlglot.parse_one(hash_sql, dialect=self.dialect) + def _get_column_alias(self, expr: exp.Expression) -> str: """ Get the alias for a column expression. @@ -453,9 +603,10 @@ class CLSTransformer: # This is a best-effort match for unqualified columns col_lower = column.name.lower() for table_name in scope_tables.values(): - table_rules = self.rules.get(table_name.lower()) - if table_rules and col_lower in table_rules: - return table_name + # Look for a rule matching this table + for rule_table, cols in self.rules.items(): + if rule_table.table == table_name.lower() and col_lower in cols: + return table_name return None @@ -481,6 +632,34 @@ class CLSTransformer: return tables + def _transform_nested_column( + self, + column: exp.Column, + scope_tables: dict[str, str], + ) -> exp.Expression: + """ + Transform a nested column reference within a SELECT expression. + + This handles columns inside CASE expressions, function arguments, etc. + Unlike top-level columns, nested columns use NULL for blocking instead + of FALSE (which works better in non-predicate contexts). + + - HASH: Replace with hash function + - NULLIFY/MASK/HIDE: Replace with NULL (blocks computation safely) + """ + table_name = self._get_table_for_column(column, scope_tables) + action = self._get_action(table_name, column.name) + + if action is None: + return column + + if action == CLSAction.HASH: + return self._create_hash_expression_no_alias(column) + + # NULLIFY/MASK/HIDE: Return NULL to safely block any computation + # NULL propagates through expressions: UPPER(NULL)→NULL, 1+NULL→NULL, etc. + return exp.Null() + def _transform_expression( self, expr: exp.Expression, @@ -489,32 +668,42 @@ class CLSTransformer: """ Transform a single SELECT expression based on CLS rules. + For simple column references: apply full transformation with alias. + For complex expressions: transform all nested column references. + Returns: - Transformed expression - - None if the column should be hidden + - None if a top-level column should be hidden """ # Get the underlying column (handle aliases) column = expr.this if isinstance(expr, exp.Alias) else expr alias = self._get_column_alias(expr) - if not isinstance(column, exp.Column): - # Not a simple column reference (could be a function, literal, etc.) - return expr + if isinstance(column, exp.Column): + # Simple column reference - apply full transformation with alias + table_name = self._get_table_for_column(column, scope_tables) + action = self._get_action(table_name, column.name) - table_name = self._get_table_for_column(column, scope_tables) - action = self._get_action(table_name, column.name) + if action is None: + return expr - if action is None: - return expr + if action == CLSAction.HIDE: + return None + if action == CLSAction.HASH: + return self._create_hash_expression(column, alias) + if action == CLSAction.NULLIFY: + return self._create_null_expression(alias) + # action == CLSAction.MASK + return self._create_mask_expression(column, alias) - if action == CLSAction.HIDE: - return None - if action == CLSAction.HASH: - return self._create_hash_expression(column, alias) - if action == CLSAction.NULLIFY: - return self._create_null_expression(alias) - # action == CLSAction.MASK - return self._create_mask_expression(alias) + # Complex expression (CASE, function, arithmetic, etc.) + # Transform ALL nested column references within it + def transform_nested(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Column): + return self._transform_nested_column(node, scope_tables) + return node + + return expr.transform(transform_nested) def _transform_star( self, @@ -535,13 +724,162 @@ class CLSTransformer: ) return [star] + def _transform_non_select_column( + self, + column: exp.Column, + scope_tables: dict[str, str], + ) -> exp.Expression: + """ + Transform a column reference outside of SELECT list. + + This is the SINGLE transformation function for ALL column references + outside the SELECT list (WHERE, HAVING, ON, GROUP BY, ORDER BY, + window functions, CASE expressions, function arguments, etc.) + + - HASH: Replace with hash function + - NULLIFY/MASK/HIDE: Replace with FALSE (blocks predicates, marked for + removal in GROUP BY/ORDER BY) + """ + table_name = self._get_table_for_column(column, scope_tables) + action = self._get_action(table_name, column.name) + + if action is None: + return column + + if action == CLSAction.HASH: + return self._create_hash_expression_no_alias(column) + + # NULLIFY/MASK/HIDE: Return FALSE to block usage + # For predicates: FALSE blocks the filter + # For GROUP BY/ORDER BY: Will be cleaned up in post-processing + return exp.false() + + @staticmethod + def _is_blocked(node: exp.Expression) -> bool: + """Check if an expression is a blocked column (FALSE or NULL sentinel).""" + # FALSE is used for blocked columns in predicates (Phase 2) + # NULL is used for blocked columns in nested expressions (Phase 1) + if isinstance(node, exp.Boolean) and not node.this: + return True + if isinstance(node, exp.Null): + return True + return False + + def _transform_all_non_select_columns( + self, + select: exp.Select, + scope_tables: dict[str, str], + ) -> None: + """ + Transform ALL column references outside the SELECT list. + + This uses sqlglot's transform() to recursively walk through the entire + expression tree, ensuring we catch columns in: + - WHERE clauses + - HAVING clauses + - JOIN ON conditions + - GROUP BY clauses + - ORDER BY clauses + - Window function PARTITION BY / ORDER BY + - CASE expressions + - Function arguments + - Any other nested expression + + This is the security-critical function that ensures NO column reference + is missed. + """ + + def transform_column(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Column): + return self._transform_non_select_column(node, scope_tables) + return node + + # Transform WHERE + if where := select.args.get("where"): + transformed = where.this.transform(transform_column) + select.set("where", exp.Where(this=transformed)) + + # Transform HAVING + if having := select.args.get("having"): + transformed = having.this.transform(transform_column) + select.set("having", exp.Having(this=transformed)) + + # Transform all JOINs (ON conditions) + for join in select.args.get("joins") or []: + if on := join.args.get("on"): + transformed = on.transform(transform_column) + join.set("on", transformed) + + # Transform GROUP BY and remove blocked (FALSE) expressions + if group := select.args.get("group"): + new_exprs = [] + for expr in group.expressions: + transformed = expr.transform(transform_column) + if not self._is_blocked(transformed): + new_exprs.append(transformed) + if new_exprs: + group.set("expressions", new_exprs) + else: + select.set("group", None) + + # Transform ORDER BY and remove blocked (FALSE) expressions + if order := select.args.get("order"): + new_exprs = [] + for ordered in order.expressions: + transformed = ordered.transform(transform_column) + # Check the inner expression (Ordered wraps the actual expr) + inner = transformed.this if isinstance(transformed, exp.Ordered) else transformed + if not self._is_blocked(inner): + new_exprs.append(transformed) + if new_exprs: + order.set("expressions", new_exprs) + else: + select.set("order", None) + + # Transform Window functions within SELECT expressions + # Window functions have their own PARTITION BY and ORDER BY clauses + for expr in select.args.get("expressions", []): + for window in expr.find_all(exp.Window): + # Transform PARTITION BY + if partition_by := window.args.get("partition_by"): + new_partition = [] + for part_expr in partition_by: + transformed = part_expr.transform(transform_column) + if not self._is_blocked(transformed): + new_partition.append(transformed) + window.set("partition_by", new_partition if new_partition else None) + + # Transform ORDER BY within window + if window_order := window.args.get("order"): + new_order_exprs = [] + for ordered in window_order.expressions: + transformed = ordered.transform(transform_column) + inner = ( + transformed.this + if isinstance(transformed, exp.Ordered) + else transformed + ) + if not self._is_blocked(inner): + new_order_exprs.append(transformed) + if new_order_exprs: + window_order.set("expressions", new_order_exprs) + else: + window.set("order", None) + def transform_select(self, select: exp.Select) -> exp.Select: """ - Transform a SELECT statement by applying CLS rules to its expressions. + Transform a SELECT statement by applying CLS rules. + + This is the main entry point for CLS transformation. It: + 1. Extracts table scope for column resolution + 2. Transforms SELECT list expressions (with HIDE removal and aliases) + 3. Transforms ALL other column references in the query """ scope_tables = self._extract_scope_tables(select) - expressions = select.args.get("expressions", []) + # Phase 1: Transform SELECT list expressions + # This handles HASH/NULLIFY/MASK with aliases, and removes HIDE columns + expressions = select.args.get("expressions", []) new_expressions: list[exp.Expression] = [] for expr in expressions: if isinstance(expr, exp.Star): @@ -550,9 +888,12 @@ class CLSTransformer: transformed = self._transform_expression(expr, scope_tables) if transformed is not None: new_expressions.append(transformed) - - # Create a new SELECT with transformed expressions select.set("expressions", new_expressions) + + # Phase 2: Transform ALL other column references + # This is the security-critical phase that catches every column reference + self._transform_all_non_select_columns(select, scope_tables) + return select def __call__(self, node: exp.Expression) -> exp.Expression: @@ -573,16 +914,19 @@ def apply_cls( """ Apply Column-Level Security rules to a SQL query. - This function transforms a SQL query by applying CLS actions to sensitive columns: - - HASH: Pseudonymize using database-specific hash function - - NULLIFY: Replace with NULL - - HIDE: Remove from SELECT results - - MASK: Replace with '****' + This function transforms a SQL query by applying CLS actions to sensitive columns + in both SELECT expressions and predicates (WHERE, ON, HAVING): + + - HASH: Pseudonymize using database-specific hash function (both SELECT and predicates) + - NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering + - HIDE: Remove from SELECT results, FALSE in predicates to block filtering + - MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering Args: sql: The SQL query to transform - rules: CLS rules mapping table names to column actions - Example: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + rules: CLS rules mapping Table objects to column actions + Example: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + Tables can include schema/catalog for fully qualified matching. engine: The database engine (used for dialect-specific hash functions) schema: Optional schema for column qualification. Required for JOINs with ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...} @@ -599,48 +943,6 @@ def apply_cls( return statement.format(comments=True) -@dataclass(eq=True, frozen=True) -class Table: - """ - A fully qualified SQL table conforming to [[catalog.]schema.]table. - """ - - table: str - schema: str | None = None - catalog: str | None = None - - def __str__(self) -> str: - """ - Return the fully qualified SQL table name. - - Should not be used for SQL generation, only for logging and debugging, since the - quoting is not engine-specific. - """ - return ".".join( - urllib.parse.quote(part, safe="").replace(".", "%2E") - for part in [self.catalog, self.schema, self.table] - if part - ) - - def __eq__(self, other: Any) -> bool: - return str(self) == str(other) - - def qualify( - self, - *, - catalog: str | None = None, - schema: str | None = None, - ) -> Table: - """ - Return a new Table with the given schema and/or catalog, if not already set. - """ - return Table( - table=self.table, - schema=self.schema or schema, - catalog=self.catalog or catalog, - ) - - # To avoid unnecessary parsing/formatting of queries, the statement has the concept of # an "internal representation", which is the AST of the SQL statement. For most of the # engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special @@ -1215,14 +1517,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): """ Apply Column-Level Security rules to the statement inplace. - CLS rules transform sensitive columns in SELECT statements: - - HASH: Pseudonymize using database-specific hash function - - NULLIFY: Replace with NULL - - HIDE: Remove from SELECT results - - MASK: Replace with '****' + CLS rules transform sensitive columns in SELECT statements and predicates: + - HASH: Pseudonymize using database-specific hash function (both SELECT and predicates) + - NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering + - HIDE: Remove from SELECT results, FALSE in predicates to block filtering + - MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering - :param rules: CLS rules mapping table names to column actions - Example: {"my_table": {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + :param rules: CLS rules mapping Table objects to column actions + Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}} :param schema: Optional schema for column qualification. Required for JOINs with ambiguous column names. Format: {"table": {"column": "TYPE", ...}} """ diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index f5a10a3c702..61063f8f882 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -27,6 +27,7 @@ from superset.exceptions import QueryClauseValidationException, SupersetParseErr from superset.jinja_context import JinjaTemplateProcessor from superset.sql.parse import ( apply_cls, + CLS_ACTION_PRECEDENCE, CLS_HASH_FUNCTIONS, CLSAction, CLSTransformer, @@ -36,6 +37,7 @@ from superset.sql.parse import ( KQLTokenType, KustoKQLStatement, LimitMethod, + merge_cls_rules, process_jinja_sql, remove_quotes, RLSMethod, @@ -3033,7 +3035,7 @@ def test_apply_cls_hash_action() -> None: """ Test CLSAction.HASH transforms column with hash function. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT ssn, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3049,7 +3051,7 @@ def test_apply_cls_nullify_action() -> None: """ Test CLSAction.NULLIFY transforms column to NULL. """ - rules = {"users": {"salary": CLSAction.NULLIFY}} + rules = {Table("users"): {"salary": CLSAction.NULLIFY}} sql = "SELECT salary, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3062,7 +3064,7 @@ def test_apply_cls_hide_action() -> None: """ Test CLSAction.HIDE removes column from SELECT. """ - rules = {"users": {"password": CLSAction.HIDE}} + rules = {Table("users"): {"password": CLSAction.HIDE}} sql = "SELECT password, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3071,26 +3073,43 @@ def test_apply_cls_hide_action() -> None: def test_apply_cls_mask_action() -> None: """ - Test CLSAction.MASK transforms column to '****'. + Test CLSAction.MASK transforms column to CASE expression preserving NULLs. """ - rules = {"users": {"phone": CLSAction.MASK}} + rules = {Table("users"): {"phone": CLSAction.MASK}} sql = "SELECT phone, name FROM users" result = apply_cls(sql, rules, engine="postgresql") assert result == ( "SELECT\n" - " '****' AS phone,\n" + " CASE WHEN \"users\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone,\n" ' "users"."name" AS "name"\n' 'FROM "users" AS "users"' ) +def test_apply_cls_mask_preserves_null() -> None: + """ + Test CLSAction.MASK preserves NULL values using CASE expression. + + MASK generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END + This preserves the semantic meaning of NULL (no value) vs masked (hidden value). + """ + rules = {Table("users"): {"email": CLSAction.MASK}} + sql = "SELECT email FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # The CASE expression should check for NULL and preserve it + assert "CASE WHEN" in result + assert "IS NULL THEN NULL" in result + assert "ELSE '****'" in result + + def test_apply_cls_all_actions() -> None: """ Test all CLS actions in a single query. """ rules = { - "employees": { + Table("employees"): { "ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY, "password": CLSAction.HIDE, @@ -3104,7 +3123,7 @@ def test_apply_cls_all_actions() -> None: "SELECT\n" ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,\n' " NULL AS salary,\n" - " '****' AS phone,\n" + " CASE WHEN \"employees\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone,\n" ' "employees"."name" AS "name"\n' 'FROM "employees" AS "employees"' ) @@ -3114,7 +3133,7 @@ def test_apply_cls_qualified_columns() -> None: """ Test CLS with fully qualified column names (table.column). """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT users.ssn, users.name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3130,7 +3149,7 @@ def test_apply_cls_table_alias() -> None: """ Test CLS with table aliases. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT u.ssn, u.name FROM users u" result = apply_cls(sql, rules, engine="postgresql") @@ -3147,8 +3166,8 @@ def test_apply_cls_join() -> None: Test CLS with JOIN queries. """ rules = { - "employees": {"ssn": CLSAction.HASH}, - "salaries": {"amount": CLSAction.NULLIFY}, + Table("employees"): {"ssn": CLSAction.HASH}, + Table("salaries"): {"amount": CLSAction.NULLIFY}, } sql = """ SELECT e.ssn, e.name, s.amount @@ -3173,7 +3192,7 @@ def test_apply_cls_case_insensitive() -> None: """ Test CLS rules are case-insensitive for table and column names. """ - rules = {"USERS": {"SSN": CLSAction.HASH}} + rules = {Table("USERS"): {"SSN": CLSAction.HASH}} sql = "SELECT ssn, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3189,7 +3208,7 @@ def test_apply_cls_with_column_alias() -> None: """ Test CLS preserves existing column aliases. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT ssn AS social_security, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3205,7 +3224,7 @@ def test_apply_cls_no_matching_table() -> None: """ Test CLS leaves columns unchanged when table doesn't match rules. """ - rules = {"other_table": {"ssn": CLSAction.HASH}} + rules = {Table("other_table"): {"ssn": CLSAction.HASH}} sql = "SELECT ssn, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3222,7 +3241,7 @@ def test_apply_cls_non_column_expressions() -> None: """ Test CLS leaves non-column expressions unchanged. """ - rules = {"users": {"name": CLSAction.HASH}} + rules = {Table("users"): {"name": CLSAction.HASH}} sql = "SELECT 1 AS one, 'test' AS str, COUNT(*) AS cnt FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3240,8 +3259,8 @@ def test_apply_cls_with_schema() -> None: Test CLS with schema for column qualification. """ rules = { - "employees": {"ssn": CLSAction.HASH}, - "departments": {"budget": CLSAction.NULLIFY}, + Table("employees"): {"ssn": CLSAction.HASH}, + Table("departments"): {"budget": CLSAction.NULLIFY}, } schema = { "employees": { @@ -3269,7 +3288,7 @@ def test_apply_cls_different_dialects() -> None: """ Test CLS uses correct hash function for different database dialects. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT ssn FROM users" # PostgreSQL @@ -3297,7 +3316,7 @@ def test_apply_cls_unknown_dialect_fallback() -> None: """ Test CLS uses fallback for unknown database dialects. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT users.ssn FROM users" result = apply_cls(sql, rules, engine="unknown_database") @@ -3308,7 +3327,7 @@ def test_apply_cls_select_star_warning(caplog: pytest.LogCaptureFixture) -> None """ Test CLS logs warning for SELECT * queries. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT * FROM users" import logging @@ -3326,7 +3345,7 @@ def test_sql_statement_apply_cls_method() -> None: """ Test SQLStatement.apply_cls method. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} statement = SQLStatement("SELECT ssn, name FROM users", engine="postgresql") statement.apply_cls(rules) result = statement.format() @@ -3356,7 +3375,7 @@ def test_sql_statement_apply_cls_with_schema() -> None: """ Test SQLStatement.apply_cls with schema parameter. """ - rules = {"employees": {"ssn": CLSAction.HASH}} + rules = {Table("employees"): {"ssn": CLSAction.HASH}} schema = {"employees": {"id": "INT", "ssn": "VARCHAR", "name": "VARCHAR"}} statement = SQLStatement("SELECT ssn, name FROM employees", engine="postgresql") statement.apply_cls(rules, schema=schema) @@ -3374,19 +3393,21 @@ def test_cls_transformer_normalize_rules() -> None: """ Test CLSTransformer normalizes table and column names to lowercase. """ - rules = {"USERS": {"SSN": CLSAction.HASH, "Name": CLSAction.MASK}} + rules = {Table("USERS"): {"SSN": CLSAction.HASH, "Name": CLSAction.MASK}} transformer = CLSTransformer(rules, Dialects.POSTGRES) - assert "users" in transformer.rules - assert "ssn" in transformer.rules["users"] - assert "name" in transformer.rules["users"] + # Check that a normalized Table key exists + normalized_key = Table("users") + assert normalized_key in transformer.rules + assert "ssn" in transformer.rules[normalized_key] + assert "name" in transformer.rules[normalized_key] def test_cls_transformer_get_action() -> None: """ Test CLSTransformer._get_action method. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) # Valid table and column @@ -3409,7 +3430,7 @@ def test_cls_transformer_extract_scope_tables() -> None: """ Test CLSTransformer._extract_scope_tables method. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) # Single table @@ -3437,7 +3458,7 @@ def test_cls_transformer_get_table_for_column_qualified() -> None: """ Test CLSTransformer._get_table_for_column with qualified columns. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) scope_tables = {"u": "users", "o": "orders"} @@ -3456,7 +3477,7 @@ def test_cls_transformer_get_table_for_column_single_table() -> None: """ Test CLSTransformer._get_table_for_column infers single table. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) scope_tables = {"users": "users"} @@ -3470,7 +3491,7 @@ def test_cls_transformer_get_table_for_column_multi_table_rules_match() -> None: """ Test CLSTransformer._get_table_for_column matches against rules. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) scope_tables = {"users": "users", "orders": "orders"} @@ -3484,7 +3505,7 @@ def test_cls_transformer_get_table_for_column_no_match() -> None: """ Test CLSTransformer._get_table_for_column returns None when no match. """ - rules = {"other": {"col": CLSAction.HASH}} + rules = {Table("other"): {"col": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) scope_tables = {"users": "users", "orders": "orders"} @@ -3531,19 +3552,23 @@ def test_cls_transformer_create_expressions() -> None: assert null_expr.alias == "salary" assert isinstance(null_expr.this, exp.Null) - # Mask expression - mask_expr = transformer._create_mask_expression("phone") + # Mask expression (CASE expression that preserves NULLs) + phone_column = parse_one("phone").find(exp.Column) + mask_expr = transformer._create_mask_expression(phone_column, "phone") assert isinstance(mask_expr, exp.Alias) assert mask_expr.alias == "phone" - assert isinstance(mask_expr.this, exp.Literal) - assert mask_expr.this.this == "****" + assert isinstance(mask_expr.this, exp.Case) + # The CASE should have a default of '****' + case_default = mask_expr.this.args.get("default") + assert isinstance(case_default, exp.Literal) + assert case_default.this == "****" def test_cls_transformer_call_non_select() -> None: """ Test CLSTransformer.__call__ returns non-SELECT nodes unchanged. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) # Non-SELECT node should be returned unchanged @@ -3556,7 +3581,7 @@ def test_cls_transformer_transform_expression_non_column() -> None: """ Test CLSTransformer._transform_expression returns non-column expressions unchanged. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} transformer = CLSTransformer(rules, Dialects.POSTGRES) scope_tables = {"users": "users"} @@ -3577,37 +3602,37 @@ def test_cls_transformer_transform_expression_non_column() -> None: # Basic HASH ( "SELECT t.id FROM t", - {"t": {"id": CLSAction.HASH}}, + {Table("t"): {"id": CLSAction.HASH}}, "postgresql", 'SELECT\n MD5(CAST("t"."id" AS TEXT)) AS id\nFROM "t" AS "t"', ), # Basic NULLIFY ( "SELECT t.salary FROM t", - {"t": {"salary": CLSAction.NULLIFY}}, + {Table("t"): {"salary": CLSAction.NULLIFY}}, "postgresql", 'SELECT\n NULL AS salary\nFROM "t" AS "t"', ), # Basic HIDE ( "SELECT t.secret, t.public FROM t", - {"t": {"secret": CLSAction.HIDE}}, + {Table("t"): {"secret": CLSAction.HIDE}}, "postgresql", 'SELECT\n "t"."public" AS "public"\nFROM "t" AS "t"', ), - # Basic MASK + # Basic MASK (preserves NULLs) ( "SELECT t.phone FROM t", - {"t": {"phone": CLSAction.MASK}}, + {Table("t"): {"phone": CLSAction.MASK}}, "postgresql", - 'SELECT\n \'****\' AS phone\nFROM "t" AS "t"', + "SELECT\n CASE WHEN \"t\".\"phone\" IS NULL THEN NULL ELSE '****' END AS phone\nFROM \"t\" AS \"t\"", ), # Multiple tables with different rules ( "SELECT a.ssn, b.amount FROM users a JOIN payments b ON a.id = b.user_id", { - "users": {"ssn": CLSAction.HASH}, - "payments": {"amount": CLSAction.NULLIFY}, + Table("users"): {"ssn": CLSAction.HASH}, + Table("payments"): {"amount": CLSAction.NULLIFY}, }, "postgresql", ( @@ -3622,14 +3647,14 @@ def test_cls_transformer_transform_expression_non_column() -> None: # Snowflake dialect ( "SELECT t.col FROM t", - {"t": {"col": CLSAction.HASH}}, + {Table("t"): {"col": CLSAction.HASH}}, "snowflake", 'SELECT\n MD5(TO_CHAR("T"."COL")) AS COL\nFROM "T" AS "T"', ), # ClickHouse dialect ( "SELECT t.col FROM t", - {"t": {"col": CLSAction.HASH}}, + {Table("t"): {"col": CLSAction.HASH}}, "clickhouse", 'SELECT\n MD5(toString("t"."col")) AS col\nFROM "t" AS "t"', ), @@ -3637,7 +3662,7 @@ def test_cls_transformer_transform_expression_non_column() -> None: ) def test_apply_cls_parametrized( sql: str, - rules: dict[str, Any], + rules: dict[Table, Any], engine: str, expected: str, ) -> None: @@ -3652,7 +3677,7 @@ def test_apply_cls_subquery() -> None: """ Test CLS applies to subqueries. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT * FROM (SELECT ssn, name FROM users) AS subq" result = apply_cls(sql, rules, engine="postgresql") @@ -3673,7 +3698,7 @@ def test_apply_cls_cte() -> None: """ Test CLS applies to CTEs. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "WITH cte AS (SELECT ssn, name FROM users) SELECT * FROM cte" result = apply_cls(sql, rules, engine="postgresql") @@ -3695,7 +3720,7 @@ def test_apply_cls_union() -> None: """ Test CLS applies to UNION queries. """ - rules = {"users": {"ssn": CLSAction.HASH}} + rules = {Table("users"): {"ssn": CLSAction.HASH}} sql = "SELECT ssn FROM users UNION SELECT ssn FROM archived_users" result = apply_cls(sql, rules, engine="postgresql") @@ -3714,7 +3739,7 @@ def test_cls_hide_all_columns() -> None: """ Test CLS HIDE action when all columns are hidden. """ - rules = {"users": {"id": CLSAction.HIDE, "name": CLSAction.HIDE}} + rules = {Table("users"): {"id": CLSAction.HIDE, "name": CLSAction.HIDE}} sql = "SELECT id, name FROM users" result = apply_cls(sql, rules, engine="postgresql") @@ -3747,7 +3772,7 @@ def test_apply_cls_aliased_column_preserves_alias() -> None: """ Test that CLS preserves the alias when column has AS clause. """ - rules = {"t": {"col": CLSAction.HASH}} + rules = {Table("t"): {"col": CLSAction.HASH}} sql = "SELECT t.col AS my_alias FROM t" result = apply_cls(sql, rules, engine="postgresql") @@ -3761,5 +3786,907 @@ def test_cls_transformer_hash_pattern_fallback() -> None: Test CLSTransformer uses fallback hash pattern for unknown dialect. """ # Use None as dialect to trigger fallback - transformer = CLSTransformer({"t": {"col": CLSAction.HASH}}, None) + transformer = CLSTransformer({Table("t"): {"col": CLSAction.HASH}}, None) assert transformer.hash_pattern == "'[HASHED]'" + + +# Tests for CLS predicate transformation +def test_apply_cls_where_clause_hash() -> None: + """ + Test CLS HASH transforms columns in WHERE clause predicates. + + This prevents information leakage by ensuring filter conditions also + get hashed, so queries like "WHERE role='CEO'" won't match unless + the user knows the hash value. + """ + rules = {Table("payroll"): {"role": CLSAction.HASH}} + sql = "SELECT MAX(salary) FROM payroll WHERE role='CEO'" + result = apply_cls(sql, rules, engine="postgresql") + + # The WHERE clause should have the column hashed + assert "MD5" in result + assert 'WHERE' in result + assert "MD5(CAST" in result + + +def test_apply_cls_where_clause_nullify() -> None: + """ + Test CLS NULLIFY in WHERE clause becomes FALSE to prevent filtering. + """ + rules = {Table("payroll"): {"salary": CLSAction.NULLIFY}} + sql = "SELECT name FROM payroll WHERE salary > 100000" + result = apply_cls(sql, rules, engine="postgresql") + + # The WHERE clause column becomes FALSE to block filtering + assert "FALSE" in result + + +def test_apply_cls_where_clause_mask() -> None: + """ + Test CLS MASK in WHERE clause becomes FALSE to prevent filtering. + """ + rules = {Table("users"): {"phone": CLSAction.MASK}} + sql = "SELECT name FROM users WHERE phone = '555-1234'" + result = apply_cls(sql, rules, engine="postgresql") + + # The WHERE clause column becomes FALSE to block filtering + assert "FALSE" in result + + +def test_apply_cls_where_clause_hide() -> None: + """ + Test CLS HIDE in WHERE clause becomes FALSE to prevent filtering. + """ + rules = {Table("users"): {"secret_code": CLSAction.HIDE}} + sql = "SELECT name FROM users WHERE secret_code = 'ADMIN'" + result = apply_cls(sql, rules, engine="postgresql") + + # The WHERE clause column becomes FALSE to block filtering + assert "FALSE" in result + + +def test_apply_cls_where_clause_multiple_conditions() -> None: + """ + Test CLS transforms multiple conditions in WHERE clause. + """ + rules = {Table("users"): {"role": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + sql = "SELECT name FROM users WHERE role = 'admin' AND salary > 50000" + result = apply_cls(sql, rules, engine="postgresql") + + # One condition should be hashed, other becomes FALSE + assert "MD5" in result + assert "FALSE" in result + + +def test_apply_cls_join_on_clause_hash() -> None: + """ + Test CLS HASH transforms columns in JOIN ON clause. + """ + rules = {Table("users"): {"user_id": CLSAction.HASH}} + sql = """ +SELECT o.order_id +FROM orders o +JOIN users u ON o.customer_id = u.user_id + """ + result = apply_cls(sql, rules, engine="postgresql") + + # The ON clause should have the column hashed + assert "MD5" in result + + +def test_apply_cls_cross_join_no_on_clause() -> None: + """ + Test CLS handles CROSS JOIN (no ON clause) without error. + """ + rules = {Table("users"): {"ssn": CLSAction.HASH}} + sql = "SELECT u.ssn, p.name FROM users u CROSS JOIN products p" + result = apply_cls(sql, rules, engine="postgresql") + + # SSN in SELECT should be hashed, CROSS JOIN has no ON to transform + assert "MD5" in result + assert "CROSS JOIN" in result + + +def test_apply_cls_having_clause_hash() -> None: + """ + Test CLS HASH transforms columns in HAVING clause. + """ + rules = {Table("sales"): {"region": CLSAction.HASH}} + sql = "SELECT COUNT(*) FROM sales GROUP BY region HAVING region = 'North'" + result = apply_cls(sql, rules, engine="postgresql") + + # The HAVING clause should have the column hashed + assert "HAVING" in result + assert "MD5" in result + + +def test_apply_cls_group_by_hash() -> None: + """ + Test CLS HASH transforms columns in GROUP BY clause. + """ + rules = {Table("users"): {"department": CLSAction.HASH}} + sql = "SELECT COUNT(*) FROM users GROUP BY department" + result = apply_cls(sql, rules, engine="postgresql") + + # GROUP BY should have the column hashed + assert "GROUP BY" in result + assert "MD5" in result + + +def test_apply_cls_group_by_hide() -> None: + """ + Test CLS HIDE removes column from GROUP BY clause. + """ + rules = {Table("users"): {"ssn": CLSAction.HIDE}} + sql = "SELECT COUNT(*) FROM users GROUP BY ssn" + result = apply_cls(sql, rules, engine="postgresql") + + # GROUP BY should be removed (no columns left) + assert "GROUP BY" not in result + + +def test_apply_cls_group_by_nullify() -> None: + """ + Test CLS NULLIFY removes column from GROUP BY clause. + """ + rules = {Table("users"): {"salary": CLSAction.NULLIFY}} + sql = "SELECT COUNT(*) FROM users GROUP BY salary" + result = apply_cls(sql, rules, engine="postgresql") + + # GROUP BY should be removed (no columns left) + assert "GROUP BY" not in result + + +def test_apply_cls_group_by_mask() -> None: + """ + Test CLS MASK removes column from GROUP BY clause. + """ + rules = {Table("users"): {"phone": CLSAction.MASK}} + sql = "SELECT COUNT(*) FROM users GROUP BY phone" + result = apply_cls(sql, rules, engine="postgresql") + + # GROUP BY should be removed (no columns left) + assert "GROUP BY" not in result + + +def test_apply_cls_group_by_multiple_columns() -> None: + """ + Test CLS with multiple columns in GROUP BY - partial removal. + """ + rules = {Table("users"): {"ssn": CLSAction.HIDE}} + sql = "SELECT COUNT(*) FROM users GROUP BY department, ssn" + result = apply_cls(sql, rules, engine="postgresql") + + # GROUP BY should keep department but remove ssn + assert "GROUP BY" in result + assert "department" in result.lower() + assert "ssn" not in result.lower() + + +def test_apply_cls_order_by_hash() -> None: + """ + Test CLS HASH transforms columns in ORDER BY clause. + """ + rules = {Table("users"): {"salary": CLSAction.HASH}} + sql = "SELECT name FROM users ORDER BY salary DESC" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should have the column hashed + assert "ORDER BY" in result + assert "MD5" in result + assert "DESC" in result + + +def test_apply_cls_order_by_hide() -> None: + """ + Test CLS HIDE removes column from ORDER BY clause. + """ + rules = {Table("users"): {"salary": CLSAction.HIDE}} + sql = "SELECT name FROM users ORDER BY salary" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should be removed (no columns left) + assert "ORDER BY" not in result + + +def test_apply_cls_order_by_nullify() -> None: + """ + Test CLS NULLIFY removes column from ORDER BY clause. + """ + rules = {Table("users"): {"salary": CLSAction.NULLIFY}} + sql = "SELECT name FROM users ORDER BY salary" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should be removed + assert "ORDER BY" not in result + + +def test_apply_cls_order_by_mask() -> None: + """ + Test CLS MASK removes column from ORDER BY clause. + """ + rules = {Table("users"): {"salary": CLSAction.MASK}} + sql = "SELECT name FROM users ORDER BY salary" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should be removed + assert "ORDER BY" not in result + + +def test_apply_cls_order_by_multiple_columns() -> None: + """ + Test CLS with multiple columns in ORDER BY - partial removal. + """ + rules = {Table("users"): {"salary": CLSAction.HIDE}} + sql = "SELECT name FROM users ORDER BY name, salary DESC" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should keep name but remove salary + assert "ORDER BY" in result + assert "name" in result.lower() + # salary should be removed + assert "salary" not in result.lower() + + +def test_apply_cls_case_expression_hide() -> None: + """ + Test CLS HIDE in CASE expression replaces column with NULL. + """ + rules = {Table("users"): {"status": CLSAction.HIDE}} + sql = "SELECT name, CASE WHEN status = 'active' THEN 'yes' ELSE 'no' END as is_active FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # status should be replaced with NULL in the CASE + assert "NULL = 'active'" in result + # Original column name should not appear + assert "status" not in result.lower() or "null" in result.lower() + + +def test_apply_cls_case_expression_hash() -> None: + """ + Test CLS HASH in CASE expression transforms the column. + """ + rules = {Table("users"): {"status": CLSAction.HASH}} + sql = "SELECT name, CASE WHEN status = 'active' THEN 'yes' ELSE 'no' END as is_active FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # status should be hashed in the CASE + assert "MD5" in result + assert "CASE WHEN" in result + + +def test_apply_cls_function_argument_hide() -> None: + """ + Test CLS HIDE in function argument replaces column with NULL. + """ + rules = {Table("users"): {"email": CLSAction.HIDE}} + sql = "SELECT UPPER(email) as upper_email FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # email should be replaced with NULL + assert "UPPER(NULL)" in result + + +def test_apply_cls_function_argument_hash() -> None: + """ + Test CLS HASH in function argument transforms the column. + """ + rules = {Table("users"): {"email": CLSAction.HASH}} + sql = "SELECT UPPER(email) as upper_email FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # email should be hashed inside UPPER + assert "UPPER(MD5" in result + + +def test_apply_cls_window_partition_by_hide() -> None: + """ + Test CLS HIDE removes column from window PARTITION BY. + """ + rules = {Table("employees"): {"department": CLSAction.HIDE}} + sql = "SELECT name, ROW_NUMBER() OVER (PARTITION BY department ORDER BY name) as rn FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # department should be removed from PARTITION BY + assert "PARTITION BY" not in result or "department" not in result.lower() + # ORDER BY name should remain + assert "ORDER BY" in result + + +def test_apply_cls_window_partition_by_hash() -> None: + """ + Test CLS HASH transforms column in window PARTITION BY. + """ + rules = {Table("employees"): {"department": CLSAction.HASH}} + sql = "SELECT name, ROW_NUMBER() OVER (PARTITION BY department ORDER BY name) as rn FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # department should be hashed in PARTITION BY + assert "PARTITION BY" in result + assert "MD5" in result + + +def test_apply_cls_window_order_by_hide() -> None: + """ + Test CLS HIDE removes column from window ORDER BY. + """ + rules = {Table("employees"): {"salary": CLSAction.HIDE}} + sql = "SELECT name, RANK() OVER (ORDER BY salary DESC) as rank FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # salary should be removed from ORDER BY, leaving empty or no ORDER + assert "salary" not in result.lower() + + +def test_apply_cls_window_order_by_hash() -> None: + """ + Test CLS HASH transforms column in window ORDER BY. + """ + rules = {Table("employees"): {"salary": CLSAction.HASH}} + sql = "SELECT name, RANK() OVER (ORDER BY salary DESC) as rank FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # salary should be hashed in window ORDER BY + assert "MD5" in result + assert "DESC" in result + + +def test_apply_cls_window_partition_only_no_order() -> None: + """ + Test CLS with window that has PARTITION BY but no ORDER BY. + Covers branch 847->836 (no window_order). + """ + rules = {Table("employees"): {"department": CLSAction.HASH}} + sql = "SELECT name, COUNT(*) OVER (PARTITION BY department) as cnt FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # department should be hashed in PARTITION BY, no ORDER BY + assert "PARTITION BY" in result + assert "MD5" in result + assert "ORDER BY" not in result.split("OVER")[1].split(")")[0] # No ORDER BY in window + + +def test_apply_cls_window_partition_all_blocked() -> None: + """ + Test CLS removes PARTITION BY when all columns are blocked. + Covers branch 842->840 (_is_blocked returns True in partition loop). + """ + rules = {Table("employees"): {"department": CLSAction.HIDE}} + sql = "SELECT name, COUNT(*) OVER (PARTITION BY department) as cnt FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # PARTITION BY should be removed entirely + assert "PARTITION BY" not in result + assert "OVER ()" in result + + +def test_apply_cls_window_order_all_blocked() -> None: + """ + Test CLS removes window ORDER BY when all columns are blocked. + Covers line 861 (window.set("order", None)) and branch 856->849. + """ + rules = {Table("employees"): {"salary": CLSAction.HIDE}} + sql = "SELECT name, RANK() OVER (ORDER BY salary) as rank FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # ORDER BY should be removed from window + assert "ORDER BY" not in result + assert "OVER ()" in result + + +def test_apply_cls_arithmetic_expression_hide() -> None: + """ + Test CLS HIDE in arithmetic expression replaces column with NULL. + """ + rules = {Table("products"): {"price": CLSAction.HIDE}} + sql = "SELECT name, price * 1.1 as price_with_tax FROM products" + result = apply_cls(sql, rules, engine="postgresql") + + # price should be replaced with NULL + assert "NULL * 1.1" in result or "NULL" in result + + +def test_apply_cls_concat_function_hide() -> None: + """ + Test CLS HIDE in CONCAT function replaces column with NULL. + """ + rules = {Table("users"): {"ssn": CLSAction.HIDE}} + sql = "SELECT CONCAT('SSN: ', ssn) as display FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # ssn should be replaced with NULL + assert "NULL" in result + + +def test_apply_cls_where_no_rules_unchanged() -> None: + """ + Test that WHERE clause without matching rules remains unchanged. + """ + rules = {Table("other_table"): {"col": CLSAction.HASH}} + sql = "SELECT name FROM users WHERE active = true" + result = apply_cls(sql, rules, engine="postgresql") + + # No transformation should occur - just column qualification + assert "WHERE" in result + assert 'MD5' not in result + assert 'FALSE' not in result + + +def test_apply_cls_table_with_schema() -> None: + """ + Test CLS rules with Table containing schema. + + Rules with schema specified require the table in scope to match. + Since queries often resolve to just the table name without schema, + we match by table name when the rule doesn't specify schema. + """ + # Rule without schema matches any table with that name + rules = {Table("users"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM public.users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should apply the hash since table name matches + assert "MD5" in result + + +def test_apply_cls_table_key_matching() -> None: + """ + Test CLS rules match by table name when schema is not in query scope. + + The scope_tables dict maps aliases to table names. When the query + has a schema-qualified table, the transformer still uses the table name. + """ + rules = {Table("users"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should apply the hash since table name matches + assert "MD5" in result + + +def test_apply_cls_table_with_schema_rule() -> None: + """ + Test CLS rules with Table(table, schema) pattern. + """ + rules = {Table("users", schema="public"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Rule has schema but query doesn't - should match by table name fallback + assert "MD5" in result + + +def test_apply_cls_table_with_schema_case_insensitive() -> None: + """ + Test CLS rules with Table(table, schema) are case-insensitive. + """ + rules = {Table("USERS", schema="PUBLIC"): {"SSN": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should match despite case differences + assert "MD5" in result + + +def test_apply_cls_table_with_schema_multiple_tables() -> None: + """ + Test CLS rules with Table(table, schema) for multiple tables. + """ + rules = { + Table("users", schema="public"): {"ssn": CLSAction.HASH}, + Table("accounts", schema="finance"): {"balance": CLSAction.NULLIFY}, + } + sql = """ + SELECT u.ssn, a.balance + FROM users u + JOIN accounts a ON u.id = a.user_id + """ + result = apply_cls(sql, rules, engine="postgresql") + + # Both rules should be applied + assert "MD5" in result + assert "NULL" in result + + +def test_apply_cls_table_with_catalog_and_schema_rule() -> None: + """ + Test CLS rules with Table(table, schema, catalog) pattern. + """ + rules = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Rule has catalog/schema but query doesn't - should match by table name fallback + assert "MD5" in result + + +def test_apply_cls_table_with_catalog_and_schema_case_insensitive() -> None: + """ + Test CLS rules with Table(table, schema, catalog) are case-insensitive. + """ + rules = {Table("USERS", schema="PUBLIC", catalog="MYDB"): {"SSN": CLSAction.MASK}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should match despite case differences + assert "'****'" in result + + +def test_apply_cls_table_with_catalog_schema_multiple_actions() -> None: + """ + Test CLS rules with Table(table, schema, catalog) and multiple actions. + """ + rules = { + Table("employees", schema="hr", catalog="corp"): { + "ssn": CLSAction.HASH, + "salary": CLSAction.NULLIFY, + "phone": CLSAction.MASK, + "password": CLSAction.HIDE, + } + } + sql = "SELECT ssn, salary, phone, password, name FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + # Verify all actions applied + assert "MD5" in result # HASH + assert "NULL" in result # NULLIFY + assert "'****'" in result # MASK + assert "password" not in result.lower() or "password" not in result.split("SELECT")[1].split("FROM")[0] # HIDE + + +def test_apply_cls_table_with_schema_in_predicate() -> None: + """ + Test CLS rules with Table(table, schema) also transform predicates. + """ + rules = {Table("payroll", schema="hr"): {"role": CLSAction.HASH}} + sql = "SELECT MAX(salary) FROM payroll WHERE role = 'CEO'" + result = apply_cls(sql, rules, engine="postgresql") + + # Both SELECT and WHERE should have the column hashed + assert "MD5" in result + assert "WHERE" in result + + +def test_apply_cls_table_with_catalog_schema_in_predicate() -> None: + """ + Test CLS rules with Table(table, schema, catalog) also transform predicates. + """ + rules = {Table("payroll", schema="hr", catalog="corp"): {"salary": CLSAction.NULLIFY}} + sql = "SELECT salary, name FROM payroll WHERE salary > 100000" + result = apply_cls(sql, rules, engine="postgresql") + + # SELECT should have NULL (for salary), WHERE should have FALSE + assert "NULL" in result + assert "FALSE" in result + + +def test_apply_cls_table_schema_no_match_different_table() -> None: + """ + Test CLS rules with Table(table, schema) don't match different table names. + """ + rules = {Table("employees", schema="hr"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should NOT apply - table name doesn't match + assert "MD5" not in result + + +def test_apply_cls_table_catalog_schema_no_match_different_table() -> None: + """ + Test CLS rules with Table(table, schema, catalog) don't match different table names. + """ + rules = {Table("employees", schema="hr", catalog="corp"): {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Should NOT apply - table name doesn't match + assert "MD5" not in result + + +def test_apply_cls_mixed_table_rules() -> None: + """ + Test CLS with a mix of Table rules: some with schema/catalog, some without. + """ + rules = { + Table("users"): {"email": CLSAction.MASK}, # No schema/catalog + Table("employees", schema="hr"): {"ssn": CLSAction.HASH}, # With schema + Table("payroll", schema="finance", catalog="corp"): {"salary": CLSAction.NULLIFY}, # With both + } + sql = """ + SELECT u.email, e.ssn, p.salary + FROM users u + JOIN employees e ON u.id = e.user_id + JOIN payroll p ON e.id = p.employee_id + """ + result = apply_cls(sql, rules, engine="postgresql") + + # All three rules should be applied + assert "'****'" in result # MASK for email + assert "MD5" in result # HASH for ssn + assert "NULL" in result # NULLIFY for salary + + +def test_cls_transformer_normalize_rules_with_schema() -> None: + """ + Test CLSTransformer normalizes Table with schema to lowercase. + """ + rules = {Table("USERS", schema="PUBLIC"): {"SSN": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Check that normalized Table key exists + normalized_key = Table("users", schema="public") + assert normalized_key in transformer.rules + assert "ssn" in transformer.rules[normalized_key] + + +def test_cls_transformer_normalize_rules_with_catalog_and_schema() -> None: + """ + Test CLSTransformer normalizes Table with catalog and schema to lowercase. + """ + rules = {Table("USERS", schema="PUBLIC", catalog="MYDB"): {"SSN": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Check that normalized Table key exists + normalized_key = Table("users", schema="public", catalog="mydb") + assert normalized_key in transformer.rules + assert "ssn" in transformer.rules[normalized_key] + + +def test_cls_transformer_get_action_with_schema() -> None: + """ + Test CLSTransformer._get_action with Table(table, schema). + """ + rules = {Table("users", schema="public"): {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Should match by table name (fallback behavior) + assert transformer._get_action("users", "ssn") == CLSAction.HASH + + # Case insensitive + assert transformer._get_action("USERS", "SSN") == CLSAction.HASH + + # Different table should not match + assert transformer._get_action("employees", "ssn") is None + + +def test_cls_transformer_get_action_with_catalog_and_schema() -> None: + """ + Test CLSTransformer._get_action with Table(table, schema, catalog). + """ + rules = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Should match by table name (fallback behavior) + assert transformer._get_action("users", "ssn") == CLSAction.HASH + + # Case insensitive + assert transformer._get_action("USERS", "SSN") == CLSAction.HASH + + # Different table should not match + assert transformer._get_action("employees", "ssn") is None + + +# Tests for merge_cls_rules and CLS_ACTION_PRECEDENCE +def test_cls_action_precedence() -> None: + """ + Test CLS_ACTION_PRECEDENCE has correct ordering: HIDE > NULLIFY > MASK > HASH. + """ + assert CLS_ACTION_PRECEDENCE[CLSAction.HIDE] > CLS_ACTION_PRECEDENCE[CLSAction.NULLIFY] + assert CLS_ACTION_PRECEDENCE[CLSAction.NULLIFY] > CLS_ACTION_PRECEDENCE[CLSAction.MASK] + assert CLS_ACTION_PRECEDENCE[CLSAction.MASK] > CLS_ACTION_PRECEDENCE[CLSAction.HASH] + + +def test_merge_cls_rules_empty() -> None: + """ + Test merge_cls_rules with no arguments returns empty dict. + """ + result = merge_cls_rules() + assert result == {} + + +def test_merge_cls_rules_single() -> None: + """ + Test merge_cls_rules with single rule set returns it unchanged. + """ + rules = {Table("foo"): {"col1": CLSAction.HASH}} + result = merge_cls_rules(rules) + assert result == rules + + +def test_merge_cls_rules_no_conflict() -> None: + """ + Test merge_cls_rules with non-conflicting rules. + """ + rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + rules2 = {Table("foo"): {"col2": CLSAction.HIDE}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}} + + +def test_merge_cls_rules_different_tables() -> None: + """ + Test merge_cls_rules with different tables. + """ + rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + rules2 = {Table("bar"): {"col1": CLSAction.NULLIFY}} + result = merge_cls_rules(rules1, rules2) + + assert result == { + Table("foo"): {"col1": CLSAction.HASH}, + Table("bar"): {"col1": CLSAction.NULLIFY}, + } + + +def test_merge_cls_rules_conflict_nullify_over_hash() -> None: + """ + Test merge_cls_rules keeps NULLIFY over HASH (stricter). + """ + rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.NULLIFY}} + + +def test_merge_cls_rules_conflict_hide_over_nullify() -> None: + """ + Test merge_cls_rules keeps HIDE over NULLIFY (stricter). + """ + rules1 = {Table("foo"): {"col1": CLSAction.NULLIFY}} + rules2 = {Table("foo"): {"col1": CLSAction.HIDE}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.HIDE}} + + +def test_merge_cls_rules_conflict_mask_over_hash() -> None: + """ + Test merge_cls_rules keeps MASK over HASH (stricter). + """ + rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + rules2 = {Table("foo"): {"col1": CLSAction.MASK}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.MASK}} + + +def test_merge_cls_rules_conflict_nullify_over_mask() -> None: + """ + Test merge_cls_rules keeps NULLIFY over MASK (stricter). + """ + rules1 = {Table("foo"): {"col1": CLSAction.MASK}} + rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.NULLIFY}} + + +def test_merge_cls_rules_keeps_stricter_regardless_of_order() -> None: + """ + Test merge_cls_rules keeps stricter action regardless of input order. + """ + rules1 = {Table("foo"): {"col1": CLSAction.NULLIFY}} + rules2 = {Table("foo"): {"col1": CLSAction.HASH}} + + # NULLIFY is stricter than HASH, should be kept + result = merge_cls_rules(rules1, rules2) + assert result == {Table("foo"): {"col1": CLSAction.NULLIFY}} + + # Reverse order should produce same result + result = merge_cls_rules(rules2, rules1) + assert result == {Table("foo"): {"col1": CLSAction.NULLIFY}} + + +def test_merge_cls_rules_user_example() -> None: + """ + Test merge_cls_rules with the user's example from requirements. + + Given: + {Table("foo"): {"col1": CLSAction.NULLIFY}} + {Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}} + + Should produce: + {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}} + """ + rules1 = {Table("foo"): {"col1": CLSAction.NULLIFY}} + rules2 = {Table("foo"): {"col1": CLSAction.HASH, "col2": CLSAction.HIDE}} + result = merge_cls_rules(rules1, rules2) + + assert result == {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}} + + +def test_merge_cls_rules_multiple_rule_sets() -> None: + """ + Test merge_cls_rules with more than two rule sets. + """ + rules1 = {Table("foo"): {"col1": CLSAction.HASH}} + rules2 = {Table("foo"): {"col1": CLSAction.MASK, "col2": CLSAction.HASH}} + rules3 = {Table("foo"): {"col1": CLSAction.NULLIFY, "col3": CLSAction.HIDE}} + result = merge_cls_rules(rules1, rules2, rules3) + + assert result == { + Table("foo"): { + "col1": CLSAction.NULLIFY, # NULLIFY is strictest + "col2": CLSAction.HASH, + "col3": CLSAction.HIDE, + } + } + + +def test_merge_cls_rules_with_schema() -> None: + """ + Test merge_cls_rules with Table containing schema. + """ + rules1 = {Table("users", schema="public"): {"ssn": CLSAction.HASH}} + rules2 = {Table("users", schema="public"): {"ssn": CLSAction.HIDE, "email": CLSAction.MASK}} + result = merge_cls_rules(rules1, rules2) + + assert result == { + Table("users", schema="public"): {"ssn": CLSAction.HIDE, "email": CLSAction.MASK} + } + + +def test_merge_cls_rules_with_catalog_and_schema() -> None: + """ + Test merge_cls_rules with Table containing catalog and schema. + """ + rules1 = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.MASK}} + rules2 = {Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.NULLIFY}} + result = merge_cls_rules(rules1, rules2) + + assert result == { + Table("users", schema="public", catalog="mydb"): {"ssn": CLSAction.NULLIFY} + } + + +def test_merge_cls_rules_different_schemas_same_table() -> None: + """ + Test merge_cls_rules treats tables with different schemas as distinct. + """ + rules1 = {Table("users", schema="public"): {"ssn": CLSAction.HASH}} + rules2 = {Table("users", schema="private"): {"ssn": CLSAction.HIDE}} + result = merge_cls_rules(rules1, rules2) + + # Should be two separate entries, not merged + assert result == { + Table("users", schema="public"): {"ssn": CLSAction.HASH}, + Table("users", schema="private"): {"ssn": CLSAction.HIDE}, + } + + +def test_merge_cls_rules_complex_scenario() -> None: + """ + Test merge_cls_rules with a complex real-world scenario. + """ + # Organization-wide rules (less strict) + org_rules = { + Table("employees"): {"ssn": CLSAction.HASH, "salary": CLSAction.MASK}, + Table("customers"): {"email": CLSAction.MASK}, + } + + # Department-specific rules (stricter for certain columns) + dept_rules = { + Table("employees"): {"ssn": CLSAction.HIDE, "phone": CLSAction.MASK}, + Table("customers"): {"email": CLSAction.HASH, "credit_card": CLSAction.HIDE}, + } + + # User-specific rules (even stricter) + user_rules = { + Table("employees"): {"salary": CLSAction.HIDE}, + } + + result = merge_cls_rules(org_rules, dept_rules, user_rules) + + assert result == { + Table("employees"): { + "ssn": CLSAction.HIDE, # HIDE > HASH + "salary": CLSAction.HIDE, # HIDE > MASK + "phone": CLSAction.MASK, + }, + Table("customers"): { + "email": CLSAction.MASK, # MASK > HASH + "credit_card": CLSAction.HIDE, + }, + }