Support filters

This commit is contained in:
Beto Dealmeida
2025-12-16 10:58:16 -05:00
parent 57a210f7d6
commit 5c61c40704
2 changed files with 1381 additions and 152 deletions

View File

@@ -142,8 +142,97 @@ class CLSAction(enum.Enum):
MASK = enum.auto() # Replace with '****'
# Type alias for CLS rules: {table_name: {column_name: action}}
CLSRules = dict[str, dict[str, CLSAction]]
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Equality and hashing are both derived from ``str(self)``, so two tables
    (or a table and a plain string) that compare equal are guaranteed to hash
    equal. This matters because ``Table`` instances are used as dictionary
    keys.
    """

    # Bare table name.
    table: str
    # Optional schema qualifier; falsy values are treated as "not set".
    schema: str | None = None
    # Optional catalog qualifier; falsy values are treated as "not set".
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Should not be used for SQL generation, only for logging and debugging,
        since the quoting is not engine-specific.
        """
        # Percent-encode every part (including dots) so that the "." used as
        # the separator can never be confused with a dot inside a name.
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, other: Any) -> bool:
        """Compare by the encoded fully qualified name."""
        return str(self) == str(other)

    def __hash__(self) -> int:
        """
        Hash the same value ``__eq__`` compares.

        Without this, the dataclass machinery generates a hash from the raw
        field tuple, so objects that compare equal (e.g. ``schema=""`` vs
        ``schema=None``) could land in different hash buckets.
        """
        return hash(str(self))

    def qualify(
        self,
        *,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> Table:
        """
        Return a new Table with the given schema and/or catalog, if not already set.
        """
        return Table(
            table=self.table,
            schema=self.schema or schema,
            catalog=self.catalog or catalog,
        )
# Type alias for CLS rules: {Table: {column_name: action}}
CLSRules = dict[Table, dict[str, CLSAction]]

# Strictness ranking used to resolve conflicting CLS actions. A larger number
# means the action reveals less information: HIDE > NULLIFY > MASK > HASH.
CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = {
    action: rank
    for rank, action in enumerate(
        # Listed from least strict to strictest; ranks start at 1.
        (CLSAction.HASH, CLSAction.MASK, CLSAction.NULLIFY, CLSAction.HIDE),
        start=1,
    )
}
def merge_cls_rules(*rules_list: CLSRules) -> CLSRules:
    """
    Combine several CLS rule sets into a single one.

    Whenever two rule sets disagree about the same table/column, the stricter
    action wins, as ranked by ``CLS_ACTION_PRECEDENCE``
    (HIDE > NULLIFY > MASK > HASH). The input dicts are not modified; a new
    dict with new per-table dicts is returned.

    Args:
        *rules_list: Any number of CLSRules dicts to merge.

    Returns:
        A CLSRules dict holding, for each table/column seen in any input,
        the strictest action requested for it.

    Example:
        Merging ``{Table("foo"): {"col1": CLSAction.HASH}}`` with
        ``{Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}``
        yields
        ``{Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}``.
    """
    combined: CLSRules = {}
    for rule_set in rules_list:
        for table, column_actions in rule_set.items():
            bucket = combined.setdefault(table, {})
            for column, candidate in column_actions.items():
                current = bucket.get(column)
                # First sighting of the column, or a stricter action than the
                # one already recorded: take the candidate.
                if (
                    current is None
                    or CLS_ACTION_PRECEDENCE[candidate]
                    > CLS_ACTION_PRECEDENCE[current]
                ):
                    bucket[column] = candidate
    return combined
# Hash function patterns by dialect. The placeholder {} will be replaced with the
@@ -333,16 +422,21 @@ class CLSTransformer:
"""
AST transformer to apply Column-Level Security rules.
This transformer modifies SELECT expressions to apply CLS actions:
This transformer modifies SELECT expressions and predicates to apply CLS actions:
- HASH: Replace column with hash function (database-specific)
- NULLIFY: Replace column with NULL AS column_name
- HIDE: Remove column from SELECT entirely
- MASK: Replace column with '****' AS column_name
- NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates)
- HIDE: Remove column from SELECT entirely, FALSE in predicates
- MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates)
Example:
Given rules: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Query: SELECT id, salary, name FROM my_table
Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Query: SELECT id, salary, name FROM my_table WHERE id = 1
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
WHERE MD5(CAST(id AS TEXT)) = 1
For predicates, HASH transforms the column to ensure filtered results also respect
the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information
leakage through filtering.
"""
def __init__(
@@ -354,12 +448,16 @@ class CLSTransformer:
self.dialect = dialect
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
    """
    Normalize table and column names to lowercase for case-insensitive matching.

    Table, schema and catalog names as well as column names are lowercased.
    If two input entries collapse onto the same normalized table/column
    (e.g. ``Users`` and ``users``), their rules are merged and the stricter
    action (higher ``CLS_ACTION_PRECEDENCE``) is kept, so that normalization
    can never silently drop or weaken a rule.

    :param rules: CLS rules keyed by ``Table``
    :return: a new dict keyed by lowercased ``Table`` objects
    """
    normalized: dict[Table, dict[str, CLSAction]] = {}
    for table, cols in rules.items():
        key = Table(
            table=table.table.lower(),
            schema=table.schema.lower() if table.schema else None,
            catalog=table.catalog.lower() if table.catalog else None,
        )
        bucket = normalized.setdefault(key, {})
        for col, action in cols.items():
            name = col.lower()
            current = bucket.get(name)
            # Keep the stricter action when names collide after lowercasing.
            if (
                current is None
                or CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[current]
            ):
                bucket[name] = action
    return normalized
@@ -367,18 +465,41 @@ class CLSTransformer:
self,
table_name: str | None,
column_name: str,
schema: str | None = None,
catalog: str | None = None,
) -> CLSAction | None:
"""
Get the CLS action for a column, if any.
Matching logic:
1. First try exact match with schema/catalog if provided
2. Fallback to table name match - if table names match, apply the rule
regardless of schema/catalog (since query may not have schema info)
"""
if not table_name:
return None
table_rules = self.rules.get(table_name.lower())
if not table_rules:
return None
# Create a normalized Table for lookup
lookup_table = Table(
table=table_name.lower(),
schema=schema.lower() if schema else None,
catalog=catalog.lower() if catalog else None,
)
return table_rules.get(column_name.lower())
# First try exact match with schema/catalog
table_rules = self.rules.get(lookup_table)
if table_rules:
return table_rules.get(column_name.lower())
# Fallback: match by table name only
# This handles cases where the rule has schema/catalog but the query doesn't
for rule_table, cols in self.rules.items():
if rule_table.table == lookup_table.table:
action = cols.get(column_name.lower())
if action:
return action
return None
def _create_hash_expression(
self,
@@ -406,15 +527,44 @@ class CLSTransformer:
alias=exp.Identifier(this=alias),
)
def _create_mask_expression(
    self,
    column: exp.Column,
    alias: str,
) -> exp.Expression:
    """
    Create a CASE expression that masks non-NULL values while preserving NULLs.

    Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias

    This preserves the semantic meaning of NULL (no value) vs masked (hidden value).

    :param column: the column being masked; a copy is embedded in the CASE
        condition so the caller's node is left untouched
    :param alias: output name for the masked expression
    :return: an aliased CASE expression
    """
    return exp.Alias(
        this=exp.Case(
            ifs=[
                exp.If(
                    this=exp.Is(this=column.copy(), expression=exp.Null()),
                    true=exp.Null(),
                )
            ],
            default=exp.Literal(this="****", is_string=True),
        ),
        alias=exp.Identifier(this=alias),
    )
def _create_hash_expression_no_alias(self, column: exp.Column) -> exp.Expression:
    """
    Build the dialect-specific hash of ``column`` with no ``AS`` alias.

    Intended for places where an alias is not allowed, such as predicates
    (WHERE, ON, etc.). The column is rendered to SQL for the configured
    dialect, substituted into ``self.hash_pattern``, and the resulting SQL
    fragment is parsed back into an expression.
    """
    rendered = self.hash_pattern.format(column.sql(dialect=self.dialect))
    return sqlglot.parse_one(rendered, dialect=self.dialect)
def _get_column_alias(self, expr: exp.Expression) -> str:
"""
Get the alias for a column expression.
@@ -453,9 +603,10 @@ class CLSTransformer:
# This is a best-effort match for unqualified columns
col_lower = column.name.lower()
for table_name in scope_tables.values():
table_rules = self.rules.get(table_name.lower())
if table_rules and col_lower in table_rules:
return table_name
# Look for a rule matching this table
for rule_table, cols in self.rules.items():
if rule_table.table == table_name.lower() and col_lower in cols:
return table_name
return None
@@ -481,6 +632,34 @@ class CLSTransformer:
return tables
def _transform_nested_column(
    self,
    column: exp.Column,
    scope_tables: dict[str, str],
) -> exp.Expression:
    """
    Rewrite a column reference that sits inside a larger SELECT expression.

    Used for columns nested in CASE expressions, function arguments and the
    like. The outcome depends on the CLS action found for the column:

    - no action: the column is returned unchanged
    - HASH: the column is replaced with its (un-aliased) hash expression
    - NULLIFY/MASK/HIDE: the column is replaced with NULL, which propagates
      through most expressions (UPPER(NULL) -> NULL, 1 + NULL -> NULL)
    """
    owner = self._get_table_for_column(column, scope_tables)
    rule = self._get_action(owner, column.name)

    if rule is None:
        replacement: exp.Expression = column
    elif rule == CLSAction.HASH:
        replacement = self._create_hash_expression_no_alias(column)
    else:
        # Every remaining action blocks the value outright.
        replacement = exp.Null()
    return replacement
def _transform_expression(
self,
expr: exp.Expression,
@@ -489,32 +668,42 @@ class CLSTransformer:
"""
Transform a single SELECT expression based on CLS rules.
For simple column references: apply full transformation with alias.
For complex expressions: transform all nested column references.
Returns:
- Transformed expression
- None if the column should be hidden
- None if a top-level column should be hidden
"""
# Get the underlying column (handle aliases)
column = expr.this if isinstance(expr, exp.Alias) else expr
alias = self._get_column_alias(expr)
if not isinstance(column, exp.Column):
# Not a simple column reference (could be a function, literal, etc.)
return expr
if isinstance(column, exp.Column):
# Simple column reference - apply full transformation with alias
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return expr
if action is None:
return expr
if action == CLSAction.HIDE:
return None
if action == CLSAction.HASH:
return self._create_hash_expression(column, alias)
if action == CLSAction.NULLIFY:
return self._create_null_expression(alias)
# action == CLSAction.MASK
return self._create_mask_expression(column, alias)
if action == CLSAction.HIDE:
return None
if action == CLSAction.HASH:
return self._create_hash_expression(column, alias)
if action == CLSAction.NULLIFY:
return self._create_null_expression(alias)
# action == CLSAction.MASK
return self._create_mask_expression(alias)
# Complex expression (CASE, function, arithmetic, etc.)
# Transform ALL nested column references within it
def transform_nested(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Column):
return self._transform_nested_column(node, scope_tables)
return node
return expr.transform(transform_nested)
def _transform_star(
self,
@@ -535,13 +724,162 @@ class CLSTransformer:
)
return [star]
def _transform_non_select_column(
    self,
    column: exp.Column,
    scope_tables: dict[str, str],
) -> exp.Expression:
    """
    Rewrite a column reference that appears outside the SELECT list.

    This is the one place that decides how a column is rewritten in WHERE,
    HAVING, ON, GROUP BY, ORDER BY and window clauses:

    - no action: the column is returned unchanged
    - HASH: the column is replaced with its (un-aliased) hash expression
    - NULLIFY/MASK/HIDE: the column is replaced with FALSE. In a predicate
      this blocks the filter; in GROUP BY / ORDER BY the FALSE acts as a
      marker that the caller strips out afterwards.
    """
    owner = self._get_table_for_column(column, scope_tables)
    rule = self._get_action(owner, column.name)

    if rule is None:
        replacement: exp.Expression = column
    elif rule == CLSAction.HASH:
        replacement = self._create_hash_expression_no_alias(column)
    else:
        # Every remaining action blocks usage of the column.
        replacement = exp.false()
    return replacement
@staticmethod
def _is_blocked(node: exp.Expression) -> bool:
    """
    Tell whether ``node`` is one of the "blocked column" sentinels.

    Two sentinels exist: NULL, produced for blocked columns nested inside
    SELECT expressions (Phase 1), and the boolean literal FALSE, produced
    for blocked columns in predicates and other clauses (Phase 2).
    """
    if isinstance(node, exp.Null):
        return True
    # Only a FALSE literal counts; TRUE is an ordinary expression.
    return isinstance(node, exp.Boolean) and not node.this
def _transform_all_non_select_columns(
    self,
    select: exp.Select,
    scope_tables: dict[str, str],
) -> None:
    """
    Transform ALL column references outside the SELECT list, in place.

    Every column found in the clauses below is passed through
    ``_transform_non_select_column``:

    - WHERE and HAVING conditions
    - JOIN ... ON conditions
    - GROUP BY and ORDER BY items
    - PARTITION BY / ORDER BY of window functions in the SELECT list

    For the list-like clauses (GROUP BY, ORDER BY, window PARTITION BY and
    window ORDER BY), items that end up as a blocked sentinel (see
    ``_is_blocked``) are dropped, and the clause itself is removed once
    nothing is left.

    A GROUP BY that has no plain expressions to begin with (for example one
    consisting only of ROLLUP/CUBE/GROUPING SETS or ``GROUP BY ALL``) is
    left untouched instead of being deleted, and a GROUP BY that still
    carries such constructs is kept even if all of its plain expressions
    were blocked.

    NOTE(review): columns inside ROLLUP/CUBE/GROUPING SETS are not rewritten
    here -- confirm whether they need CLS handling.
    NOTE(review): window clauses live inside SELECT expressions, which
    appear to be rewritten in Phase 1 already; confirm HASH columns there
    are not hashed a second time by this pass.

    :param select: the SELECT statement to rewrite (mutated in place)
    :param scope_tables: mapping used to resolve which table a column
        belongs to
    """

    def transform_column(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column):
            return self._transform_non_select_column(node, scope_tables)
        return node

    def keep_unblocked(items: list[exp.Expression]) -> list[exp.Expression]:
        """Transform each item and drop those that became a blocked sentinel."""
        kept: list[exp.Expression] = []
        for item in items:
            transformed = item.transform(transform_column)
            # ORDER BY items are wrapped in Ordered; inspect what they wrap.
            inner = (
                transformed.this
                if isinstance(transformed, exp.Ordered)
                else transformed
            )
            if not self._is_blocked(inner):
                kept.append(transformed)
        return kept

    # WHERE
    if where := select.args.get("where"):
        select.set("where", exp.Where(this=where.this.transform(transform_column)))

    # HAVING
    if having := select.args.get("having"):
        select.set(
            "having",
            exp.Having(this=having.this.transform(transform_column)),
        )

    # JOIN ... ON
    for join in select.args.get("joins") or []:
        if on := join.args.get("on"):
            join.set("on", on.transform(transform_column))

    # GROUP BY: only touch the clause when it has plain expressions; a bare
    # ROLLUP/CUBE/GROUPING SETS/ALL clause has none and must not be deleted.
    if group := select.args.get("group"):
        plain = group.expressions
        if plain:
            kept = keep_unblocked(plain)
            has_other_grouping = any(
                value for key, value in group.args.items() if key != "expressions"
            )
            if kept or has_other_grouping:
                group.set("expressions", kept)
            else:
                select.set("group", None)

    # ORDER BY
    if order := select.args.get("order"):
        kept = keep_unblocked(order.expressions)
        if kept:
            order.set("expressions", kept)
        else:
            select.set("order", None)

    # Window functions inside SELECT expressions carry their own
    # PARTITION BY and ORDER BY clauses.
    for expr in select.args.get("expressions") or []:
        for window in expr.find_all(exp.Window):
            if partition_by := window.args.get("partition_by"):
                window.set("partition_by", keep_unblocked(partition_by) or None)

            if window_order := window.args.get("order"):
                kept = keep_unblocked(window_order.expressions)
                if kept:
                    window_order.set("expressions", kept)
                else:
                    window.set("order", None)
def transform_select(self, select: exp.Select) -> exp.Select:
"""
Transform a SELECT statement by applying CLS rules to its expressions.
Transform a SELECT statement by applying CLS rules.
This is the main entry point for CLS transformation. It:
1. Extracts table scope for column resolution
2. Transforms SELECT list expressions (with HIDE removal and aliases)
3. Transforms ALL other column references in the query
"""
scope_tables = self._extract_scope_tables(select)
expressions = select.args.get("expressions", [])
# Phase 1: Transform SELECT list expressions
# This handles HASH/NULLIFY/MASK with aliases, and removes HIDE columns
expressions = select.args.get("expressions", [])
new_expressions: list[exp.Expression] = []
for expr in expressions:
if isinstance(expr, exp.Star):
@@ -550,9 +888,12 @@ class CLSTransformer:
transformed = self._transform_expression(expr, scope_tables)
if transformed is not None:
new_expressions.append(transformed)
# Create a new SELECT with transformed expressions
select.set("expressions", new_expressions)
# Phase 2: Transform ALL other column references
# This is the security-critical phase that catches every column reference
self._transform_all_non_select_columns(select, scope_tables)
return select
def __call__(self, node: exp.Expression) -> exp.Expression:
@@ -573,16 +914,19 @@ def apply_cls(
"""
Apply Column-Level Security rules to a SQL query.
This function transforms a SQL query by applying CLS actions to sensitive columns:
- HASH: Pseudonymize using database-specific hash function
- NULLIFY: Replace with NULL
- HIDE: Remove from SELECT results
- MASK: Replace with '****'
This function transforms a SQL query by applying CLS actions to sensitive columns
in both SELECT expressions and predicates (WHERE, ON, HAVING):
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
Args:
sql: The SQL query to transform
rules: CLS rules mapping table names to column actions
Example: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
rules: CLS rules mapping Table objects to column actions
Example: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Tables can include schema/catalog for fully qualified matching.
engine: The database engine (used for dialect-specific hash functions)
schema: Optional schema for column qualification. Required for JOINs with
ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}
@@ -599,48 +943,6 @@ def apply_cls(
return statement.format(comments=True)
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Equality and hashing are both derived from ``str(self)``, so two tables
    (or a table and a plain string) that compare equal are guaranteed to hash
    equal.
    """

    # Bare table name.
    table: str
    # Optional schema qualifier; falsy values are treated as "not set".
    schema: str | None = None
    # Optional catalog qualifier; falsy values are treated as "not set".
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Should not be used for SQL generation, only for logging and debugging,
        since the quoting is not engine-specific.
        """
        # Percent-encode every part (including dots) so that the "." used as
        # the separator can never be confused with a dot inside a name.
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, other: Any) -> bool:
        """Compare by the encoded fully qualified name."""
        return str(self) == str(other)

    def __hash__(self) -> int:
        """
        Hash the same value ``__eq__`` compares.

        Without this, the dataclass machinery generates a hash from the raw
        field tuple, so objects that compare equal (e.g. ``schema=""`` vs
        ``schema=None``) could land in different hash buckets.
        """
        return hash(str(self))

    def qualify(
        self,
        *,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> Table:
        """
        Return a new Table with the given schema and/or catalog, if not already set.
        """
        return Table(
            table=self.table,
            schema=self.schema or schema,
            catalog=self.catalog or catalog,
        )
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
# engines supported by Superset this is `sqlglot.exp.Expression`, but there is a special
@@ -1215,14 +1517,14 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
"""
Apply Column-Level Security rules to the statement inplace.
CLS rules transform sensitive columns in SELECT statements:
- HASH: Pseudonymize using database-specific hash function
- NULLIFY: Replace with NULL
- HIDE: Remove from SELECT results
- MASK: Replace with '****'
CLS rules transform sensitive columns in SELECT statements and predicates:
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
:param rules: CLS rules mapping table names to column actions
Example: {"my_table": {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
:param rules: CLS rules mapping Table objects to column actions
Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
:param schema: Optional schema for column qualification. Required for JOINs
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
"""