Compare commits

...

2 Commits

Author SHA1 Message Date
Beto Dealmeida
5c61c40704 Support filters 2025-12-16 11:31:35 -05:00
Beto Dealmeida
57a210f7d6 feat: column-level security 2025-12-15 17:01:11 -05:00
2 changed files with 2404 additions and 32 deletions

View File

@@ -129,6 +129,145 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class CLSAction(enum.Enum):
"""
Column-Level Security actions.
These actions determine how sensitive columns are transformed in queries.
"""
HASH = enum.auto() # Pseudonymization via hashing
NULLIFY = enum.auto() # Replace with NULL
HIDE = enum.auto() # Remove from results entirely
MASK = enum.auto() # Replace with '****'
@dataclass(eq=True, frozen=True)
class Table:
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
"""
table: str
schema: str | None = None
catalog: str | None = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
Should not be used for SQL generation, only for logging and debugging, since the
quoting is not engine-specific.
"""
return ".".join(
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def qualify(
self,
*,
catalog: str | None = None,
schema: str | None = None,
) -> Table:
"""
Return a new Table with the given schema and/or catalog, if not already set.
"""
return Table(
table=self.table,
schema=self.schema or schema,
catalog=self.catalog or catalog,
)
# Type alias for CLS rules: {Table: {column_name: action}}
CLSRules = dict[Table, dict[str, CLSAction]]
# CLS action precedence: higher value = stricter (less information revealed)
# HIDE > NULLIFY > MASK > HASH
CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = {
CLSAction.HASH: 1,
CLSAction.MASK: 2,
CLSAction.NULLIFY: 3,
CLSAction.HIDE: 4,
}
def merge_cls_rules(*rules_list: CLSRules) -> CLSRules:
"""
Merge multiple CLS rule sets into one, using the stricter action when conflicts occur.
When multiple rules specify actions for the same table/column, the stricter action
is kept. Precedence (strictest to least strict): HIDE > NULLIFY > MASK > HASH
Args:
*rules_list: Variable number of CLSRules dicts to merge
Returns:
A merged CLSRules dict with the strictest action for each table/column
Example:
>>> rules1 = {Table("foo"): {"col1": CLSAction.HASH}}
>>> rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
>>> merge_cls_rules(rules1, rules2)
{Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
"""
merged: CLSRules = {}
for rules in rules_list:
for table, columns in rules.items():
if table not in merged:
merged[table] = {}
for column, action in columns.items():
existing_action = merged[table].get(column)
if existing_action is None:
merged[table][column] = action
else:
# Keep the stricter action (higher precedence value)
if CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[existing_action]:
merged[table][column] = action
return merged
# Hash function patterns by dialect. The placeholder {} will be replaced with the
# column. Some databases need casting for non-string types, so we cast to string/text.
# The fallback uses a literal since there's no universal hash function across all
# databases.
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
None: "'[HASHED]'", # Universal fallback - no hash function available
Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5
Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder
Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
Dialects.TSQL: (
"CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
),
Dialects.HIVE: "MD5(CAST({} AS STRING))",
Dialects.SPARK: "MD5(CAST({} AS STRING))",
Dialects.CLICKHOUSE: "MD5(toString({}))",
Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
}
class CTASMethod(enum.Enum):
TABLE = enum.auto()
VIEW = enum.auto()
@@ -279,47 +418,530 @@ class RLSAsSubqueryTransformer(RLSTransformer):
return node
@dataclass(eq=True, frozen=True)
class Table:
class CLSTransformer:
"""
A fully qualified SQL table conforming to [[catalog.]schema.]table.
AST transformer to apply Column-Level Security rules.
This transformer modifies SELECT expressions and predicates to apply CLS actions:
- HASH: Replace column with hash function (database-specific)
- NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates)
- HIDE: Remove column from SELECT entirely, FALSE in predicates
- MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates)
Example:
Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Query: SELECT id, salary, name FROM my_table WHERE id = 1
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
WHERE MD5(CAST(id AS TEXT)) = 1
For predicates, HASH transforms the column to ensure filtered results also respect
the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information
leakage through filtering.
"""
table: str
schema: str | None = None
catalog: str | None = None
def __str__(self) -> str:
"""
Return the fully qualified SQL table name.
Should not be used for SQL generation, only for logging and debugging, since the
quoting is not engine-specific.
"""
return ".".join(
urllib.parse.quote(part, safe="").replace(".", "%2E")
for part in [self.catalog, self.schema, self.table]
if part
)
def __eq__(self, other: Any) -> bool:
return str(self) == str(other)
def qualify(
def __init__(
self,
*,
catalog: str | None = None,
rules: CLSRules,
dialect: Dialects | type[Dialect] | None,
) -> None:
self.rules = self._normalize_rules(rules)
self.dialect = dialect
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
"""
Normalize table and column names to lowercase for case-insensitive matching.
"""
return {
Table(
table=table.table.lower(),
schema=table.schema.lower() if table.schema else None,
catalog=table.catalog.lower() if table.catalog else None,
): {col.lower(): action for col, action in cols.items()}
for table, cols in rules.items()
}
def _get_action(
self,
table_name: str | None,
column_name: str,
schema: str | None = None,
) -> Table:
catalog: str | None = None,
) -> CLSAction | None:
"""
Return a new Table with the given schema and/or catalog, if not already set.
Get the CLS action for a column, if any.
Matching logic:
1. First try exact match with schema/catalog if provided
2. Fallback to table name match - if table names match, apply the rule
regardless of schema/catalog (since query may not have schema info)
"""
return Table(
table=self.table,
schema=self.schema or schema,
catalog=self.catalog or catalog,
if not table_name:
return None
# Create a normalized Table for lookup
lookup_table = Table(
table=table_name.lower(),
schema=schema.lower() if schema else None,
catalog=catalog.lower() if catalog else None,
)
# First try exact match with schema/catalog
table_rules = self.rules.get(lookup_table)
if table_rules:
return table_rules.get(column_name.lower())
# Fallback: match by table name only
# This handles cases where the rule has schema/catalog but the query doesn't
for rule_table, cols in self.rules.items():
if rule_table.table == lookup_table.table:
action = cols.get(column_name.lower())
if action:
return action
return None
def _create_hash_expression(
self,
column: exp.Column,
alias: str,
) -> exp.Expression:
"""
Create a hash expression for a column.
"""
# Generate the column SQL without any alias
col_sql = column.sql(dialect=self.dialect)
hash_sql = self.hash_pattern.format(col_sql)
hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect)
return exp.Alias(
this=hash_expr,
alias=exp.Identifier(this=alias),
)
def _create_null_expression(self, alias: str) -> exp.Expression:
"""
Create a NULL AS alias expression.
"""
return exp.Alias(
this=exp.Null(),
alias=exp.Identifier(this=alias),
)
def _create_mask_expression(
self,
column: exp.Column,
alias: str,
) -> exp.Expression:
"""
Create a CASE expression that masks non-NULL values while preserving NULLs.
Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias
This preserves the semantic meaning of NULL (no value) vs masked (hidden value).
"""
return exp.Alias(
this=exp.Case(
ifs=[
exp.If(
this=exp.Is(this=column.copy(), expression=exp.Null()),
true=exp.Null(),
)
],
default=exp.Literal(this="****", is_string=True),
),
alias=exp.Identifier(this=alias),
)
def _create_hash_expression_no_alias(
self,
column: exp.Column,
) -> exp.Expression:
"""
Create a hash expression for a column without an alias.
Used for transforming columns in predicates (WHERE, ON, etc.).
"""
col_sql = column.sql(dialect=self.dialect)
hash_sql = self.hash_pattern.format(col_sql)
return sqlglot.parse_one(hash_sql, dialect=self.dialect)
def _get_column_alias(self, expr: exp.Expression) -> str:
"""
Get the alias for a column expression.
"""
if isinstance(expr, exp.Alias):
return expr.alias
if isinstance(expr, exp.Column):
return expr.name
return expr.sql(dialect=self.dialect)
def _get_table_for_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> str | None:
"""
Resolve which table a column belongs to.
Args:
column: The column expression
scope_tables: Map of alias/name to actual table name
Returns:
The table name or None if cannot be resolved
"""
if column.table:
# Column is qualified with table name/alias
return scope_tables.get(column.table.lower(), column.table)
# For unqualified columns, if there's only one table in scope,
# we can infer the column belongs to that table
if len(scope_tables) == 1:
return next(iter(scope_tables.values()))
# With multiple tables, check if any table in rules has this column
# This is a best-effort match for unqualified columns
col_lower = column.name.lower()
for table_name in scope_tables.values():
# Look for a rule matching this table
for rule_table, cols in self.rules.items():
if rule_table.table == table_name.lower() and col_lower in cols:
return table_name
return None
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
"""
Extract table names and aliases from a SELECT statement's FROM clause.
Returns a dict mapping alias (or table name if no alias) to actual table name.
"""
tables: dict[str, str] = {}
if from_clause := select.args.get("from"):
for table in from_clause.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
for join in select.args.get("joins") or []:
for table in join.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
return tables
def _transform_nested_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> exp.Expression:
"""
Transform a nested column reference within a SELECT expression.
This handles columns inside CASE expressions, function arguments, etc.
Unlike top-level columns, nested columns use NULL for blocking instead
of FALSE (which works better in non-predicate contexts).
- HASH: Replace with hash function
- NULLIFY/MASK/HIDE: Replace with NULL (blocks computation safely)
"""
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return column
if action == CLSAction.HASH:
return self._create_hash_expression_no_alias(column)
# NULLIFY/MASK/HIDE: Return NULL to safely block any computation
# NULL propagates through expressions: UPPER(NULL)→NULL, 1+NULL→NULL, etc.
return exp.Null()
def _transform_expression(
self,
expr: exp.Expression,
scope_tables: dict[str, str],
) -> exp.Expression | None:
"""
Transform a single SELECT expression based on CLS rules.
For simple column references: apply full transformation with alias.
For complex expressions: transform all nested column references.
Returns:
- Transformed expression
- None if a top-level column should be hidden
"""
# Get the underlying column (handle aliases)
column = expr.this if isinstance(expr, exp.Alias) else expr
alias = self._get_column_alias(expr)
if isinstance(column, exp.Column):
# Simple column reference - apply full transformation with alias
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return expr
if action == CLSAction.HIDE:
return None
if action == CLSAction.HASH:
return self._create_hash_expression(column, alias)
if action == CLSAction.NULLIFY:
return self._create_null_expression(alias)
# action == CLSAction.MASK
return self._create_mask_expression(column, alias)
# Complex expression (CASE, function, arithmetic, etc.)
# Transform ALL nested column references within it
def transform_nested(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Column):
return self._transform_nested_column(node, scope_tables)
return node
return expr.transform(transform_nested)
def _transform_star(
self,
star: exp.Star,
scope_tables: dict[str, str],
) -> list[exp.Expression]:
"""
Transform SELECT * by expanding hidden columns conceptually.
Since we don't have schema information, we cannot truly expand *.
We return the star as-is but log a warning.
"""
# Without schema information, we cannot expand SELECT *
# In a real implementation, you would need to query the database schema
logger.warning(
"CLS cannot fully process SELECT * without schema information. "
"Consider using explicit column lists for queries with CLS rules."
)
return [star]
def _transform_non_select_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> exp.Expression:
"""
Transform a column reference outside of SELECT list.
This is the SINGLE transformation function for ALL column references
outside the SELECT list (WHERE, HAVING, ON, GROUP BY, ORDER BY,
window functions, CASE expressions, function arguments, etc.)
- HASH: Replace with hash function
- NULLIFY/MASK/HIDE: Replace with FALSE (blocks predicates, marked for
removal in GROUP BY/ORDER BY)
"""
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return column
if action == CLSAction.HASH:
return self._create_hash_expression_no_alias(column)
# NULLIFY/MASK/HIDE: Return FALSE to block usage
# For predicates: FALSE blocks the filter
# For GROUP BY/ORDER BY: Will be cleaned up in post-processing
return exp.false()
@staticmethod
def _is_blocked(node: exp.Expression) -> bool:
"""Check if an expression is a blocked column (FALSE or NULL sentinel)."""
# FALSE is used for blocked columns in predicates (Phase 2)
# NULL is used for blocked columns in nested expressions (Phase 1)
if isinstance(node, exp.Boolean) and not node.this:
return True
if isinstance(node, exp.Null):
return True
return False
def _transform_all_non_select_columns(
self,
select: exp.Select,
scope_tables: dict[str, str],
) -> None:
"""
Transform ALL column references outside the SELECT list.
This uses sqlglot's transform() to recursively walk through the entire
expression tree, ensuring we catch columns in:
- WHERE clauses
- HAVING clauses
- JOIN ON conditions
- GROUP BY clauses
- ORDER BY clauses
- Window function PARTITION BY / ORDER BY
- CASE expressions
- Function arguments
- Any other nested expression
This is the security-critical function that ensures NO column reference
is missed.
"""
def transform_column(node: exp.Expression) -> exp.Expression:
if isinstance(node, exp.Column):
return self._transform_non_select_column(node, scope_tables)
return node
# Transform WHERE
if where := select.args.get("where"):
transformed = where.this.transform(transform_column)
select.set("where", exp.Where(this=transformed))
# Transform HAVING
if having := select.args.get("having"):
transformed = having.this.transform(transform_column)
select.set("having", exp.Having(this=transformed))
# Transform all JOINs (ON conditions)
for join in select.args.get("joins") or []:
if on := join.args.get("on"):
transformed = on.transform(transform_column)
join.set("on", transformed)
# Transform GROUP BY and remove blocked (FALSE) expressions
if group := select.args.get("group"):
new_exprs = []
for expr in group.expressions:
transformed = expr.transform(transform_column)
if not self._is_blocked(transformed):
new_exprs.append(transformed)
if new_exprs:
group.set("expressions", new_exprs)
else:
select.set("group", None)
# Transform ORDER BY and remove blocked (FALSE) expressions
if order := select.args.get("order"):
new_exprs = []
for ordered in order.expressions:
transformed = ordered.transform(transform_column)
# Check the inner expression (Ordered wraps the actual expr)
inner = transformed.this if isinstance(transformed, exp.Ordered) else transformed
if not self._is_blocked(inner):
new_exprs.append(transformed)
if new_exprs:
order.set("expressions", new_exprs)
else:
select.set("order", None)
# Transform Window functions within SELECT expressions
# Window functions have their own PARTITION BY and ORDER BY clauses
for expr in select.args.get("expressions", []):
for window in expr.find_all(exp.Window):
# Transform PARTITION BY
if partition_by := window.args.get("partition_by"):
new_partition = []
for part_expr in partition_by:
transformed = part_expr.transform(transform_column)
if not self._is_blocked(transformed):
new_partition.append(transformed)
window.set("partition_by", new_partition if new_partition else None)
# Transform ORDER BY within window
if window_order := window.args.get("order"):
new_order_exprs = []
for ordered in window_order.expressions:
transformed = ordered.transform(transform_column)
inner = (
transformed.this
if isinstance(transformed, exp.Ordered)
else transformed
)
if not self._is_blocked(inner):
new_order_exprs.append(transformed)
if new_order_exprs:
window_order.set("expressions", new_order_exprs)
else:
window.set("order", None)
def transform_select(self, select: exp.Select) -> exp.Select:
"""
Transform a SELECT statement by applying CLS rules.
This is the main entry point for CLS transformation. It:
1. Extracts table scope for column resolution
2. Transforms SELECT list expressions (with HIDE removal and aliases)
3. Transforms ALL other column references in the query
"""
scope_tables = self._extract_scope_tables(select)
# Phase 1: Transform SELECT list expressions
# This handles HASH/NULLIFY/MASK with aliases, and removes HIDE columns
expressions = select.args.get("expressions", [])
new_expressions: list[exp.Expression] = []
for expr in expressions:
if isinstance(expr, exp.Star):
new_expressions.extend(self._transform_star(expr, scope_tables))
else:
transformed = self._transform_expression(expr, scope_tables)
if transformed is not None:
new_expressions.append(transformed)
select.set("expressions", new_expressions)
# Phase 2: Transform ALL other column references
# This is the security-critical phase that catches every column reference
self._transform_all_non_select_columns(select, scope_tables)
return select
def __call__(self, node: exp.Expression) -> exp.Expression:
"""
Transform callback for sqlglot's transform method.
"""
if isinstance(node, exp.Select):
return self.transform_select(node)
return node
def apply_cls(
sql: str,
rules: CLSRules,
engine: str = "base",
schema: dict[str, dict[str, str]] | None = None,
) -> str:
"""
Apply Column-Level Security rules to a SQL query.
This function transforms a SQL query by applying CLS actions to sensitive columns
in both SELECT expressions and predicates (WHERE, ON, HAVING):
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
Args:
sql: The SQL query to transform
rules: CLS rules mapping Table objects to column actions
Example: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Tables can include schema/catalog for fully qualified matching.
engine: The database engine (used for dialect-specific hash functions)
schema: Optional schema for column qualification. Required for JOINs with
ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}
Returns:
The transformed SQL query
"""
if not rules:
return sql
statement = SQLStatement(sql, engine)
statement.apply_cls(rules, schema=schema)
return statement.format(comments=True)
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
# an "internal representation", which is the AST of the SQL statement. For most of the
@@ -887,6 +1509,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
def apply_cls(
self,
rules: CLSRules,
schema: dict[str, dict[str, str]] | None = None,
) -> None:
"""
Apply Column-Level Security rules to the statement inplace.
CLS rules transform sensitive columns in SELECT statements and predicates:
- HASH: Pseudonymize using database-specific hash function (both SELECT and predicates)
- NULLIFY: Replace with NULL (SELECT), FALSE in predicates to block filtering
- HIDE: Remove from SELECT results, FALSE in predicates to block filtering
- MASK: Replace with '****' (SELECT), FALSE in predicates to block filtering
:param rules: CLS rules mapping Table objects to column actions
Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
:param schema: Optional schema for column qualification. Required for JOINs
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
"""
if not rules:
return
# Always attempt to qualify columns for better CLS resolution.
# With schema: full qualification of all columns.
# Without schema: qualifies single-table queries, partial for JOINs.
from sqlglot.optimizer.qualify import qualify
self._parsed = qualify(
self._parsed,
schema=schema,
dialect=self._dialect,
validate_qualify_columns=False,
)
transformer = CLSTransformer(rules, self._dialect)
self._parsed = self._parsed.transform(transformer)
class KQLSplitState(enum.Enum):
"""

File diff suppressed because it is too large Load Diff