feat: column-level security

This commit is contained in:
Beto Dealmeida
2025-12-15 17:01:11 -05:00
parent 57ec3b5a6d
commit 57a210f7d6
2 changed files with 1143 additions and 0 deletions

View File

@@ -129,6 +129,56 @@ class LimitMethod(enum.Enum):
FETCH_MANY = enum.auto()
class CLSAction(enum.Enum):
"""
Column-Level Security actions.
These actions determine how sensitive columns are transformed in queries.
"""
HASH = enum.auto() # Pseudonymization via hashing
NULLIFY = enum.auto() # Replace with NULL
HIDE = enum.auto() # Remove from results entirely
MASK = enum.auto() # Replace with '****'
# Type alias for CLS rules: {table_name: {column_name: action}}
CLSRules = dict[str, dict[str, CLSAction]]
# Hash function patterns by dialect. The placeholder {} will be replaced with the
# column. Some databases need casting for non-string types, so we cast to string/text.
# The fallback uses a literal since there's no universal hash function across all
# databases.
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
None: "'[HASHED]'", # Universal fallback - no hash function available
Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5
Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder
Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
Dialects.TSQL: (
"CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
),
Dialects.HIVE: "MD5(CAST({} AS STRING))",
Dialects.SPARK: "MD5(CAST({} AS STRING))",
Dialects.CLICKHOUSE: "MD5(toString({}))",
Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
}
class CTASMethod(enum.Enum):
TABLE = enum.auto()
VIEW = enum.auto()
@@ -279,6 +329,276 @@ class RLSAsSubqueryTransformer(RLSTransformer):
return node
class CLSTransformer:
"""
AST transformer to apply Column-Level Security rules.
This transformer modifies SELECT expressions to apply CLS actions:
- HASH: Replace column with hash function (database-specific)
- NULLIFY: Replace column with NULL AS column_name
- HIDE: Remove column from SELECT entirely
- MASK: Replace column with '****' AS column_name
Example:
Given rules: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
Query: SELECT id, salary, name FROM my_table
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
"""
def __init__(
self,
rules: CLSRules,
dialect: Dialects | type[Dialect] | None,
) -> None:
self.rules = self._normalize_rules(rules)
self.dialect = dialect
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
def _normalize_rules(self, rules: CLSRules) -> CLSRules:
"""
Normalize table and column names to lowercase for case-insensitive matching.
"""
return {
table.lower(): {col.lower(): action for col, action in cols.items()}
for table, cols in rules.items()
}
def _get_action(
self,
table_name: str | None,
column_name: str,
) -> CLSAction | None:
"""
Get the CLS action for a column, if any.
"""
if not table_name:
return None
table_rules = self.rules.get(table_name.lower())
if not table_rules:
return None
return table_rules.get(column_name.lower())
def _create_hash_expression(
self,
column: exp.Column,
alias: str,
) -> exp.Expression:
"""
Create a hash expression for a column.
"""
# Generate the column SQL without any alias
col_sql = column.sql(dialect=self.dialect)
hash_sql = self.hash_pattern.format(col_sql)
hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect)
return exp.Alias(
this=hash_expr,
alias=exp.Identifier(this=alias),
)
def _create_null_expression(self, alias: str) -> exp.Expression:
"""
Create a NULL AS alias expression.
"""
return exp.Alias(
this=exp.Null(),
alias=exp.Identifier(this=alias),
)
def _create_mask_expression(self, alias: str) -> exp.Expression:
"""
Create a '****' AS alias expression.
"""
return exp.Alias(
this=exp.Literal(this="****", is_string=True),
alias=exp.Identifier(this=alias),
)
def _get_column_alias(self, expr: exp.Expression) -> str:
"""
Get the alias for a column expression.
"""
if isinstance(expr, exp.Alias):
return expr.alias
if isinstance(expr, exp.Column):
return expr.name
return expr.sql(dialect=self.dialect)
def _get_table_for_column(
self,
column: exp.Column,
scope_tables: dict[str, str],
) -> str | None:
"""
Resolve which table a column belongs to.
Args:
column: The column expression
scope_tables: Map of alias/name to actual table name
Returns:
The table name or None if cannot be resolved
"""
if column.table:
# Column is qualified with table name/alias
return scope_tables.get(column.table.lower(), column.table)
# For unqualified columns, if there's only one table in scope,
# we can infer the column belongs to that table
if len(scope_tables) == 1:
return next(iter(scope_tables.values()))
# With multiple tables, check if any table in rules has this column
# This is a best-effort match for unqualified columns
col_lower = column.name.lower()
for table_name in scope_tables.values():
table_rules = self.rules.get(table_name.lower())
if table_rules and col_lower in table_rules:
return table_name
return None
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
"""
Extract table names and aliases from a SELECT statement's FROM clause.
Returns a dict mapping alias (or table name if no alias) to actual table name.
"""
tables: dict[str, str] = {}
if from_clause := select.args.get("from"):
for table in from_clause.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
for join in select.args.get("joins") or []:
for table in join.find_all(exp.Table):
table_name = table.name
alias = table.alias if table.alias else table_name
tables[alias.lower()] = table_name
return tables
def _transform_expression(
self,
expr: exp.Expression,
scope_tables: dict[str, str],
) -> exp.Expression | None:
"""
Transform a single SELECT expression based on CLS rules.
Returns:
- Transformed expression
- None if the column should be hidden
"""
# Get the underlying column (handle aliases)
column = expr.this if isinstance(expr, exp.Alias) else expr
alias = self._get_column_alias(expr)
if not isinstance(column, exp.Column):
# Not a simple column reference (could be a function, literal, etc.)
return expr
table_name = self._get_table_for_column(column, scope_tables)
action = self._get_action(table_name, column.name)
if action is None:
return expr
if action == CLSAction.HIDE:
return None
if action == CLSAction.HASH:
return self._create_hash_expression(column, alias)
if action == CLSAction.NULLIFY:
return self._create_null_expression(alias)
# action == CLSAction.MASK
return self._create_mask_expression(alias)
def _transform_star(
self,
star: exp.Star,
scope_tables: dict[str, str],
) -> list[exp.Expression]:
"""
Transform SELECT * by expanding hidden columns conceptually.
Since we don't have schema information, we cannot truly expand *.
We return the star as-is but log a warning.
"""
# Without schema information, we cannot expand SELECT *
# In a real implementation, you would need to query the database schema
logger.warning(
"CLS cannot fully process SELECT * without schema information. "
"Consider using explicit column lists for queries with CLS rules."
)
return [star]
def transform_select(self, select: exp.Select) -> exp.Select:
"""
Transform a SELECT statement by applying CLS rules to its expressions.
"""
scope_tables = self._extract_scope_tables(select)
expressions = select.args.get("expressions", [])
new_expressions: list[exp.Expression] = []
for expr in expressions:
if isinstance(expr, exp.Star):
new_expressions.extend(self._transform_star(expr, scope_tables))
else:
transformed = self._transform_expression(expr, scope_tables)
if transformed is not None:
new_expressions.append(transformed)
# Create a new SELECT with transformed expressions
select.set("expressions", new_expressions)
return select
def __call__(self, node: exp.Expression) -> exp.Expression:
"""
Transform callback for sqlglot's transform method.
"""
if isinstance(node, exp.Select):
return self.transform_select(node)
return node
def apply_cls(
sql: str,
rules: CLSRules,
engine: str = "base",
schema: dict[str, dict[str, str]] | None = None,
) -> str:
"""
Apply Column-Level Security rules to a SQL query.
This function transforms a SQL query by applying CLS actions to sensitive columns:
- HASH: Pseudonymize using database-specific hash function
- NULLIFY: Replace with NULL
- HIDE: Remove from SELECT results
- MASK: Replace with '****'
Args:
sql: The SQL query to transform
rules: CLS rules mapping table names to column actions
Example: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
engine: The database engine (used for dialect-specific hash functions)
schema: Optional schema for column qualification. Required for JOINs with
ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}
Returns:
The transformed SQL query
"""
if not rules:
return sql
statement = SQLStatement(sql, engine)
statement.apply_cls(rules, schema=schema)
return statement.format(comments=True)
@dataclass(eq=True, frozen=True)
class Table:
"""
@@ -887,6 +1207,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
transformer = transformers[method](catalog, schema, predicates)
self._parsed = self._parsed.transform(transformer)
def apply_cls(
self,
rules: CLSRules,
schema: dict[str, dict[str, str]] | None = None,
) -> None:
"""
Apply Column-Level Security rules to the statement inplace.
CLS rules transform sensitive columns in SELECT statements:
- HASH: Pseudonymize using database-specific hash function
- NULLIFY: Replace with NULL
- HIDE: Remove from SELECT results
- MASK: Replace with '****'
:param rules: CLS rules mapping table names to column actions
Example: {"my_table": {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
:param schema: Optional schema for column qualification. Required for JOINs
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
"""
if not rules:
return
# Always attempt to qualify columns for better CLS resolution.
# With schema: full qualification of all columns.
# Without schema: qualifies single-table queries, partial for JOINs.
from sqlglot.optimizer.qualify import qualify
self._parsed = qualify(
self._parsed,
schema=schema,
dialect=self._dialect,
validate_qualify_columns=False,
)
transformer = CLSTransformer(rules, self._dialect)
self._parsed = self._parsed.transform(transformer)
class KQLSplitState(enum.Enum):
"""