mirror of
https://github.com/apache/superset.git
synced 2026-05-01 14:04:21 +00:00
Compare commits
2 Commits
upgrade-sq
...
cls
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5c61c40704 | ||
|
|
57a210f7d6 |
@@ -129,6 +129,145 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
@enum.unique
class CLSAction(enum.Enum):
    """
    Transformations that Column-Level Security can apply to a sensitive column.

    Each member names one way a protected column is rewritten in a query.
    """

    # Pseudonymize the value by hashing it.
    HASH = enum.auto()
    # Substitute NULL for the value.
    NULLIFY = enum.auto()
    # Drop the column from the results entirely.
    HIDE = enum.auto()
    # Substitute the literal '****' for the value.
    MASK = enum.auto()
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
class Table:
    """
    A fully qualified SQL table conforming to [[catalog.]schema.]table.

    Equality and hashing are both based on the string form returned by
    ``__str__``, so two instances that compare equal can be used interchangeably
    as dictionary keys or set members.
    """

    table: str
    schema: str | None = None
    catalog: str | None = None

    def __str__(self) -> str:
        """
        Return the fully qualified SQL table name.

        Empty or missing parts are skipped, and each remaining part is
        percent-encoded (including dots) before the parts are joined with ".".

        Should not be used for SQL generation, only for logging and debugging, since the
        quoting is not engine-specific.
        """
        return ".".join(
            urllib.parse.quote(part, safe="").replace(".", "%2E")
            for part in [self.catalog, self.schema, self.table]
            if part
        )

    def __eq__(self, other: Any) -> bool:
        """
        Compare by string form; ``other`` may be any object, including a plain str.
        """
        return str(self) == str(other)

    def __hash__(self) -> int:
        """
        Hash the string form so that hashing agrees with ``__eq__``.

        Without this override the dataclass would hash the field tuple, and
        instances that compare equal (for example an empty-string schema versus
        ``None``) would land in different hash buckets.
        """
        return hash(str(self))

    def qualify(
        self,
        *,
        catalog: str | None = None,
        schema: str | None = None,
    ) -> Table:
        """
        Return a new Table with the given schema and/or catalog, if not already set.

        :param catalog: Catalog to use when this table has none
        :param schema: Schema to use when this table has none
        :returns: A new Table; values already set on this table take precedence
        """
        return Table(
            table=self.table,
            schema=self.schema or schema,
            catalog=self.catalog or catalog,
        )
|
||||
|
||||
|
||||
# CLS rule sets are nested mappings of the form {Table: {column_name: action}}.
CLSRules = dict[Table, dict[str, CLSAction]]

# Strictness ranking used to resolve conflicting actions. A larger number means a
# stricter action (less information revealed): HIDE > NULLIFY > MASK > HASH.
CLS_ACTION_PRECEDENCE: dict[CLSAction, int] = {
    action: rank
    for rank, action in enumerate(
        (CLSAction.HASH, CLSAction.MASK, CLSAction.NULLIFY, CLSAction.HIDE),
        start=1,
    )
}
|
||||
|
||||
|
||||
def merge_cls_rules(*rules_list: CLSRules) -> CLSRules:
    """
    Combine several CLS rule sets, keeping the strictest action on conflicts.

    If the same table/column pair appears in more than one rule set, the action
    with the higher CLS_ACTION_PRECEDENCE value is kept; on equal precedence the
    action seen first stays. Strictest to least strict: HIDE > NULLIFY > MASK > HASH.

    Args:
        *rules_list: Any number of CLSRules dicts to combine

    Returns:
        A new CLSRules dict; the input dicts are not modified

    Example:
        >>> rules1 = {Table("foo"): {"col1": CLSAction.HASH}}
        >>> rules2 = {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
        >>> merge_cls_rules(rules1, rules2)
        {Table("foo"): {"col1": CLSAction.NULLIFY, "col2": CLSAction.HIDE}}
    """
    combined: CLSRules = {}

    for rule_set in rules_list:
        for table, column_actions in rule_set.items():
            bucket = combined.setdefault(table, {})
            for column, candidate in column_actions.items():
                current = bucket.get(column)
                # Skip the candidate only when an action is already recorded and
                # the candidate is not strictly stricter than it.
                if (
                    current is not None
                    and CLS_ACTION_PRECEDENCE[candidate]
                    <= CLS_ACTION_PRECEDENCE[current]
                ):
                    continue
                bucket[column] = candidate

    return combined
|
||||
|
||||
|
||||
# SQL templates used by the HASH action, keyed by dialect. The placeholder {} is
# replaced with the rendered column. Values are cast to a string type first so that
# non-string columns can be hashed as well. Dialects missing from this table use
# the entry under None, which is a constant literal because no hash function is
# available across all databases.
#
# NOTE(review): all real templates are unsalted MD5; for low-entropy columns this
# is pseudonymization at best — confirm this matches the intended threat model.
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
    None: "'[HASHED]'",  # Universal fallback - no hash function available
    Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))",  # Generic SQL with MD5
    Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
    Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
    Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
    Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
    Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
    # MD5 takes VARBINARY in Presto/Trino, and only string values can be cast to
    # VARBINARY. Cast to VARCHAR first and encode with TO_UTF8 so that numeric,
    # date and other non-string columns can be hashed too.
    Dialects.PRESTO: "TO_HEX(MD5(TO_UTF8(CAST({} AS VARCHAR))))",
    Dialects.TRINO: "TO_HEX(MD5(TO_UTF8(CAST({} AS VARCHAR))))",
    # NOTE(review): HEX is a reversible encoding, not a hash, so the original
    # value can be recovered from the output. Kept as in the original placeholder;
    # consider the None fallback literal if reversibility is unacceptable.
    Dialects.SQLITE: "HEX({})",  # SQLite doesn't have MD5, use HEX as placeholder
    Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
    Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
    Dialects.TSQL: (
        "CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
    ),
    Dialects.HIVE: "MD5(CAST({} AS STRING))",
    Dialects.SPARK: "MD5(CAST({} AS STRING))",
    # ClickHouse MD5 returns the raw 16-byte digest (FixedString(16)); hex-encode
    # it so the result is a printable string like the other dialects produce.
    Dialects.CLICKHOUSE: "LOWER(HEX(MD5(toString({}))))",
    Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
    Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
    Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
    Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
    Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
    Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
    Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
}
|
||||
|
||||
|
||||
class CTASMethod(enum.Enum):
|
||||
TABLE = enum.auto()
|
||||
VIEW = enum.auto()
|
||||
@@ -279,47 +418,530 @@ class RLSAsSubqueryTransformer(RLSTransformer):
|
||||
return node
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
class CLSTransformer:
|
||||
"""
|
||||
A fully qualified SQL table conforming to [[catalog.]schema.]table.
|
||||
AST transformer to apply Column-Level Security rules.
|
||||
|
||||
This transformer modifies SELECT expressions and predicates to apply CLS actions:
|
||||
- HASH: Replace column with hash function (database-specific)
|
||||
- NULLIFY: Replace with NULL AS column_name (SELECT) or FALSE (predicates)
|
||||
- HIDE: Remove column from SELECT entirely, FALSE in predicates
|
||||
- MASK: Replace column with '****' AS column_name (SELECT) or FALSE (predicates)
|
||||
|
||||
Example:
|
||||
Given rules: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
Query: SELECT id, salary, name FROM my_table WHERE id = 1
|
||||
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
|
||||
WHERE MD5(CAST(id AS TEXT)) = 1
|
||||
|
||||
For predicates, HASH transforms the column to ensure filtered results also respect
|
||||
the security policy. NULLIFY/MASK/HIDE transform to FALSE to prevent information
|
||||
leakage through filtering.
|
||||
"""
|
||||
|
||||
table: str
|
||||
schema: str | None = None
|
||||
catalog: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
Return the fully qualified SQL table name.
|
||||
|
||||
Should not be used for SQL generation, only for logging and debugging, since the
|
||||
quoting is not engine-specific.
|
||||
"""
|
||||
return ".".join(
|
||||
urllib.parse.quote(part, safe="").replace(".", "%2E")
|
||||
for part in [self.catalog, self.schema, self.table]
|
||||
if part
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
return str(self) == str(other)
|
||||
|
||||
def qualify(
|
||||
def __init__(
    self,
    rules: CLSRules,
    dialect: Dialects | type[Dialect] | None,
) -> None:
    """
    Store the normalized rules, the dialect and the hash SQL template.

    :param rules: CLS rules; they are normalized before being stored
    :param dialect: Dialect used to render columns and to pick the hash template;
        dialects without an entry in CLS_HASH_FUNCTIONS use the entry under None
    """
    self.rules = self._normalize_rules(rules)
    self.dialect = dialect

    fallback_pattern = CLS_HASH_FUNCTIONS[None]
    self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, fallback_pattern)
|
||||
|
||||
def _normalize_rules(self, rules: CLSRules) -> dict[Table, dict[str, CLSAction]]:
    """
    Normalize table and column names to lowercase for case-insensitive matching.

    Empty schema/catalog values are normalized to None. When several input
    entries collapse to the same normalized table or column (for example
    ``Table("Users")`` and ``Table("users")``, or columns ``"SSN"`` and ``"ssn"``),
    their rules are merged and the stricter action per CLS_ACTION_PRECEDENCE is
    kept, so that normalization can never silently drop a protection.

    :param rules: CLS rules as supplied by the caller; not modified
    :returns: A new rules dict keyed by lowercased tables and columns
    """
    normalized: dict[Table, dict[str, CLSAction]] = {}

    for table, cols in rules.items():
        key = Table(
            table=table.table.lower(),
            schema=table.schema.lower() if table.schema else None,
            catalog=table.catalog.lower() if table.catalog else None,
        )
        bucket = normalized.setdefault(key, {})

        for col, action in cols.items():
            col_key = col.lower()
            current = bucket.get(col_key)
            # On a collision keep whichever action reveals less information.
            if (
                current is None
                or CLS_ACTION_PRECEDENCE[action] > CLS_ACTION_PRECEDENCE[current]
            ):
                bucket[col_key] = action

    return normalized
|
||||
|
||||
def _get_action(
    self,
    table_name: str | None,
    column_name: str,
    schema: str | None = None,
    catalog: str | None = None,
) -> CLSAction | None:
    """
    Get the CLS action for a column, if any.

    Matching logic:
    1. Try an exact match on table/schema/catalog. It wins whenever the matched
       rule actually lists the column.
    2. Otherwise fall back to matching by table name only, regardless of
       schema/catalog (the query may not carry schema info). If several rules
       share the table name and list the column, the strictest action per
       CLS_ACTION_PRECEDENCE is returned.

    Table and column names are compared in lowercase.

    :param table_name: Table the column was resolved to; None/empty means unknown
    :param column_name: Column to look up
    :param schema: Optional schema for the exact match
    :param catalog: Optional catalog for the exact match
    :returns: The action to apply, or None when no rule covers the column
    """
    if not table_name:
        return None

    column_key = column_name.lower()

    # Create a normalized Table for lookup
    lookup_table = Table(
        table=table_name.lower(),
        schema=schema.lower() if schema else None,
        catalog=catalog.lower() if catalog else None,
    )

    # Exact match first. If the matched rule does not list this column we keep
    # looking instead of returning None, so that a rule registered for the same
    # table name under a schema/catalog is not shadowed.
    exact_action = self.rules.get(lookup_table, {}).get(column_key)
    if exact_action is not None:
        return exact_action

    # Fallback: match by table name only and fail closed on ambiguity by
    # picking the strictest of the candidate actions.
    candidates = [
        cols[column_key]
        for rule_table, cols in self.rules.items()
        if rule_table.table == lookup_table.table and column_key in cols
    ]
    if not candidates:
        return None
    return max(candidates, key=CLS_ACTION_PRECEDENCE.__getitem__)
|
||||
|
||||
def _create_hash_expression(
    self,
    column: exp.Column,
    alias: str,
) -> exp.Expression:
    """
    Create a ``<hash of column> AS alias`` expression for the SELECT list.

    The hash itself is built by ``_create_hash_expression_no_alias``. The alias
    identifier is always quoted: the alias arrives as a plain string, so without
    quoting a name containing spaces or special characters would render as
    invalid SQL, and a case-sensitive name could be folded by the database.

    :param column: Column to hash
    :param alias: Output name for the hashed column
    :returns: An Alias node wrapping the hash expression
    """
    return exp.Alias(
        this=self._create_hash_expression_no_alias(column),
        alias=exp.Identifier(this=alias, quoted=True),
    )
|
||||
|
||||
def _create_null_expression(self, alias: str) -> exp.Expression:
    """
    Create a ``NULL AS alias`` expression.

    The alias identifier is always quoted: the alias arrives as a plain string,
    so without quoting a name containing spaces or special characters would
    render as invalid SQL, and a case-sensitive name could be folded by the
    database.

    :param alias: Output name for the nullified column
    :returns: An Alias node wrapping a NULL literal
    """
    return exp.Alias(
        this=exp.Null(),
        alias=exp.Identifier(this=alias, quoted=True),
    )
|
||||
|
||||
def _create_mask_expression(
    self,
    column: exp.Column,
    alias: str,
) -> exp.Expression:
    """
    Create a CASE expression that masks non-NULL values while preserving NULLs.

    Generates: CASE WHEN column IS NULL THEN NULL ELSE '****' END AS alias

    This preserves the semantic meaning of NULL (no value) vs masked (hidden value).

    The alias identifier is always quoted: the alias arrives as a plain string,
    so without quoting a name containing spaces or special characters would
    render as invalid SQL, and a case-sensitive name could be folded by the
    database.

    :param column: Column to mask; a copy is used inside the CASE condition
    :param alias: Output name for the masked column
    :returns: An Alias node wrapping the CASE expression
    """
    null_check = exp.If(
        this=exp.Is(this=column.copy(), expression=exp.Null()),
        true=exp.Null(),
    )
    return exp.Alias(
        this=exp.Case(
            ifs=[null_check],
            default=exp.Literal(this="****", is_string=True),
        ),
        alias=exp.Identifier(this=alias, quoted=True),
    )
|
||||
|
||||
def _create_hash_expression_no_alias(
    self,
    column: exp.Column,
) -> exp.Expression:
    """
    Build the hash expression for a column, without wrapping it in an alias.

    The column is rendered to SQL in the transformer's dialect, substituted into
    ``self.hash_pattern`` and the resulting SQL is parsed back into an expression.

    Used for transforming columns in predicates (WHERE, ON, etc.).
    """
    rendered_column = column.sql(dialect=self.dialect)
    return sqlglot.parse_one(
        self.hash_pattern.format(rendered_column),
        dialect=self.dialect,
    )
|
||||
|
||||
def _get_column_alias(self, expr: exp.Expression) -> str:
    """
    Work out the output name of a SELECT expression.

    An aliased expression yields its alias, a bare column yields its name, and
    anything else yields its SQL text in the transformer's dialect.
    """
    if isinstance(expr, exp.Alias):
        label = expr.alias
    elif isinstance(expr, exp.Column):
        label = expr.name
    else:
        label = expr.sql(dialect=self.dialect)
    return label
|
||||
|
||||
def _get_table_for_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Resolve which table a column belongs to.
|
||||
|
||||
Args:
|
||||
column: The column expression
|
||||
scope_tables: Map of alias/name to actual table name
|
||||
|
||||
Returns:
|
||||
The table name or None if cannot be resolved
|
||||
"""
|
||||
if column.table:
|
||||
# Column is qualified with table name/alias
|
||||
return scope_tables.get(column.table.lower(), column.table)
|
||||
|
||||
# For unqualified columns, if there's only one table in scope,
|
||||
# we can infer the column belongs to that table
|
||||
if len(scope_tables) == 1:
|
||||
return next(iter(scope_tables.values()))
|
||||
|
||||
# With multiple tables, check if any table in rules has this column
|
||||
# This is a best-effort match for unqualified columns
|
||||
col_lower = column.name.lower()
|
||||
for table_name in scope_tables.values():
|
||||
# Look for a rule matching this table
|
||||
for rule_table, cols in self.rules.items():
|
||||
if rule_table.table == table_name.lower() and col_lower in cols:
|
||||
return table_name
|
||||
|
||||
return None
|
||||
|
||||
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
|
||||
"""
|
||||
Extract table names and aliases from a SELECT statement's FROM clause.
|
||||
|
||||
Returns a dict mapping alias (or table name if no alias) to actual table name.
|
||||
"""
|
||||
tables: dict[str, str] = {}
|
||||
|
||||
if from_clause := select.args.get("from"):
|
||||
for table in from_clause.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
for join in select.args.get("joins") or []:
|
||||
for table in join.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
return tables
|
||||
|
||||
def _transform_nested_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Transform a nested column reference within a SELECT expression.
|
||||
|
||||
This handles columns inside CASE expressions, function arguments, etc.
|
||||
Unlike top-level columns, nested columns use NULL for blocking instead
|
||||
of FALSE (which works better in non-predicate contexts).
|
||||
|
||||
- HASH: Replace with hash function
|
||||
- NULLIFY/MASK/HIDE: Replace with NULL (blocks computation safely)
|
||||
"""
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return column
|
||||
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression_no_alias(column)
|
||||
|
||||
# NULLIFY/MASK/HIDE: Return NULL to safely block any computation
|
||||
# NULL propagates through expressions: UPPER(NULL)→NULL, 1+NULL→NULL, etc.
|
||||
return exp.Null()
|
||||
|
||||
def _transform_expression(
    self,
    expr: exp.Expression,
    scope_tables: dict[str, str],
) -> exp.Expression | None:
    """
    Apply CLS rules to one entry of the SELECT list.

    A plain (optionally aliased) column gets the full treatment and keeps its
    output name. Any other expression has every nested column reference
    rewritten by ``_transform_nested_column``.

    Returns:
        - The expression unchanged when no rule applies to a plain column
        - The rewritten expression otherwise
        - None if a top-level column should be hidden
    """
    # Look through an alias to the expression it names.
    target = expr.this if isinstance(expr, exp.Alias) else expr
    alias = self._get_column_alias(expr)

    if not isinstance(target, exp.Column):
        # Complex expression (CASE, function, arithmetic, etc.): rewrite every
        # column reference found inside it.
        def rewrite(node: exp.Expression) -> exp.Expression:
            if isinstance(node, exp.Column):
                return self._transform_nested_column(node, scope_tables)
            return node

        return expr.transform(rewrite)

    # Simple column reference.
    owner = self._get_table_for_column(target, scope_tables)
    action = self._get_action(owner, target.name)

    if action is None:
        return expr
    if action == CLSAction.HIDE:
        return None
    if action == CLSAction.HASH:
        return self._create_hash_expression(target, alias)
    if action == CLSAction.NULLIFY:
        return self._create_null_expression(alias)
    # Only MASK remains.
    return self._create_mask_expression(target, alias)
|
||||
|
||||
def _transform_star(
    self,
    star: exp.Star,
    scope_tables: dict[str, str],
) -> list[exp.Expression]:
    """
    Handle ``SELECT *`` in the SELECT list.

    The transformer has no schema information, so the star cannot be expanded
    into individual columns here. A warning is logged and the star is returned
    unchanged as a one-element list; ``scope_tables`` is not used.

    NOTE(review): because the star is passed through as-is, this path applies no
    CLS action to the columns it expands to — confirm callers expand stars (or
    reject such queries) before relying on CLS.
    """
    passthrough: list[exp.Expression] = [star]
    logger.warning(
        "CLS cannot fully process SELECT * without schema information. "
        "Consider using explicit column lists for queries with CLS rules."
    )
    return passthrough
|
||||
|
||||
def _transform_non_select_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Transform a column reference outside of SELECT list.
|
||||
|
||||
This is the SINGLE transformation function for ALL column references
|
||||
outside the SELECT list (WHERE, HAVING, ON, GROUP BY, ORDER BY,
|
||||
window functions, CASE expressions, function arguments, etc.)
|
||||
|
||||
- HASH: Replace with hash function
|
||||
- NULLIFY/MASK/HIDE: Replace with FALSE (blocks predicates, marked for
|
||||
removal in GROUP BY/ORDER BY)
|
||||
"""
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return column
|
||||
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression_no_alias(column)
|
||||
|
||||
# NULLIFY/MASK/HIDE: Return FALSE to block usage
|
||||
# For predicates: FALSE blocks the filter
|
||||
# For GROUP BY/ORDER BY: Will be cleaned up in post-processing
|
||||
return exp.false()
|
||||
|
||||
@staticmethod
def _is_blocked(node: exp.Expression) -> bool:
    """Tell whether an expression is one of the blocked-column sentinels.

    FALSE is the sentinel for blocked columns outside the SELECT list (Phase 2);
    NULL is the sentinel for blocked columns nested in SELECT expressions
    (Phase 1).
    """
    is_false_literal = isinstance(node, exp.Boolean) and not node.this
    return is_false_literal or isinstance(node, exp.Null)
|
||||
|
||||
def _transform_all_non_select_columns(
    self,
    select: exp.Select,
    scope_tables: dict[str, str],
) -> None:
    """
    Transform column references outside the SELECT list, in place.

    Every column found in the following places is passed through
    ``_transform_non_select_column``:
    - WHERE, HAVING and QUALIFY clauses
    - JOIN ON conditions
    - GROUP BY expressions, including the arguments of ROLLUP / CUBE /
      GROUPING SETS style modifiers
    - ORDER BY expressions
    - PARTITION BY / ORDER BY of window functions found in the SELECT list

    Entries of GROUP BY, ORDER BY and window PARTITION BY / ORDER BY lists that
    turn into a blocked sentinel are removed. A GROUP BY is dropped only when it
    has neither remaining expressions nor any modifier (so ``GROUP BY ALL`` or
    ``GROUP BY ROLLUP (...)`` are kept); an ORDER BY is dropped when no entry
    remains.

    Clauses not listed above are not visited by this method.

    :param select: The SELECT to modify
    :param scope_tables: Map of lowercased alias/name to actual table name
    """

    def transform_column(node: exp.Expression) -> exp.Expression:
        if isinstance(node, exp.Column):
            return self._transform_non_select_column(node, scope_tables)
        return node

    def keep_unblocked(items: list[exp.Expression]) -> list[exp.Expression]:
        """Transform each item and drop those that became a blocked sentinel."""
        kept: list[exp.Expression] = []
        for item in items:
            transformed = item.transform(transform_column)
            # ORDER BY entries are wrapped in Ordered; the sentinel sits inside.
            inner = (
                transformed.this
                if isinstance(transformed, exp.Ordered)
                else transformed
            )
            if not self._is_blocked(inner):
                kept.append(transformed)
        return kept

    # Filtering clauses: WHERE, HAVING and QUALIFY share the same shape.
    for key, wrapper in (
        ("where", exp.Where),
        ("having", exp.Having),
        ("qualify", exp.Qualify),
    ):
        if clause := select.args.get(key):
            select.set(key, wrapper(this=clause.this.transform(transform_column)))

    # JOIN ... ON conditions
    for join in select.args.get("joins") or []:
        if on := join.args.get("on"):
            join.set("on", on.transform(transform_column))

    # GROUP BY: filter the plain expressions, and transform the arguments of any
    # modifier (ROLLUP, CUBE, GROUPING SETS, ...) stored under the other keys.
    if group := select.args.get("group"):
        kept_groups = keep_unblocked(group.expressions)
        has_modifiers = False
        for key, value in list(group.args.items()):
            if key == "expressions" or not value:
                continue
            has_modifiers = True
            if isinstance(value, exp.Expression):
                group.set(key, value.transform(transform_column))
            elif isinstance(value, list):
                group.set(
                    key,
                    [
                        item.transform(transform_column)
                        if isinstance(item, exp.Expression)
                        else item
                        for item in value
                    ],
                )
        if kept_groups or has_modifiers:
            group.set("expressions", kept_groups)
        else:
            select.set("group", None)

    # ORDER BY
    if order := select.args.get("order"):
        kept_order = keep_unblocked(order.expressions)
        if kept_order:
            order.set("expressions", kept_order)
        else:
            select.set("order", None)

    # Window functions inside SELECT expressions carry their own
    # PARTITION BY and ORDER BY lists.
    for expr in select.args.get("expressions", []):
        for window in expr.find_all(exp.Window):
            if partition_by := window.args.get("partition_by"):
                window.set("partition_by", keep_unblocked(partition_by) or None)

            if window_order := window.args.get("order"):
                kept_window_order = keep_unblocked(window_order.expressions)
                if kept_window_order:
                    window_order.set("expressions", kept_window_order)
                else:
                    window.set("order", None)
|
||||
|
||||
def transform_select(self, select: exp.Select) -> exp.Select:
    """
    Transform a SELECT statement by applying CLS rules.

    This is the main entry point for CLS transformation. It:
    1. Extracts table scope for column resolution
    2. Transforms SELECT list expressions (with HIDE removal and aliases);
       star entries, both ``*`` and qualified ``t.*``, go through
       ``_transform_star``
    3. Transforms the other column references in the query

    :param select: The SELECT to transform; it is modified in place
    :returns: The same SELECT object
    """
    scope_tables = self._extract_scope_tables(select)

    # Phase 1: SELECT list
    new_expressions: list[exp.Expression] = []
    for expr in select.args.get("expressions", []):
        # is_star is true for a bare "*" and for a qualified "t.*", so both
        # take the star path instead of being treated as a regular column.
        if expr.is_star:
            new_expressions.extend(self._transform_star(expr, scope_tables))
            continue

        transformed = self._transform_expression(expr, scope_tables)
        # None means the column is hidden and must be left out.
        if transformed is not None:
            new_expressions.append(transformed)
    select.set("expressions", new_expressions)

    # Phase 2: column references outside the SELECT list
    self._transform_all_non_select_columns(select, scope_tables)

    return select
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
    """
    Node callback for sqlglot's ``transform``.

    SELECT nodes are handed to ``transform_select``; every other node is
    returned unchanged.
    """
    return self.transform_select(node) if isinstance(node, exp.Select) else node
|
||||
|
||||
|
||||
def apply_cls(
    sql: str,
    rules: CLSRules,
    engine: str = "base",
    schema: dict[str, dict[str, str]] | None = None,
) -> str:
    """
    Rewrite a SQL query so that it honours Column-Level Security rules.

    With empty rules the input string is returned untouched. Otherwise the query
    is parsed into a SQLStatement, the rules are applied through its
    ``apply_cls`` method and the statement is formatted back to SQL with
    comments preserved.

    Actions, as applied to SELECT expressions and predicates (WHERE, ON, HAVING):

    - HASH: pseudonymize with a database-specific hash function (SELECT and predicates)
    - NULLIFY: NULL in SELECT, FALSE in predicates to block filtering
    - HIDE: removed from SELECT results, FALSE in predicates to block filtering
    - MASK: masked in SELECT, FALSE in predicates to block filtering

    Args:
        sql: The SQL query to transform
        rules: CLS rules mapping Table objects to column actions
            Example: {Table("my_table"): {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
            Tables can include schema/catalog for fully qualified matching.
        engine: The database engine (used for dialect-specific hash functions)
        schema: Optional schema for column qualification. Required for JOINs with
            ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}

    Returns:
        The transformed SQL query
    """
    if rules:
        statement = SQLStatement(sql, engine)
        statement.apply_cls(rules, schema=schema)
        return statement.format(comments=True)

    # Nothing to enforce: hand the query back without parsing it.
    return sql
|
||||
|
||||
|
||||
# To avoid unnecessary parsing/formatting of queries, the statement has the concept of
|
||||
# an "internal representation", which is the AST of the SQL statement. For most of the
|
||||
@@ -887,6 +1509,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
def apply_cls(
    self,
    rules: CLSRules,
    schema: dict[str, dict[str, str]] | None = None,
) -> None:
    """
    Apply Column-Level Security rules to the statement inplace.

    Does nothing when ``rules`` is empty. Otherwise the parsed statement is
    first run through sqlglot's ``qualify`` and then rewritten by a
    CLSTransformer built from the rules and the statement's dialect:

    - HASH: pseudonymize with a database-specific hash function (SELECT and predicates)
    - NULLIFY: NULL in SELECT, FALSE in predicates to block filtering
    - HIDE: removed from SELECT results, FALSE in predicates to block filtering
    - MASK: masked in SELECT, FALSE in predicates to block filtering

    :param rules: CLS rules mapping Table objects to column actions
        Example: {Table("my_table"): {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
    :param schema: Optional schema for column qualification. Required for JOINs
        with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
    """
    if not rules:
        return

    from sqlglot.optimizer.qualify import qualify as qualify_columns

    # Qualification is always attempted because it improves column-to-table
    # resolution: complete when a schema is given, best effort otherwise.
    # Column validation is turned off.
    qualified = qualify_columns(
        self._parsed,
        schema=schema,
        dialect=self._dialect,
        validate_qualify_columns=False,
    )
    self._parsed = qualified

    self._parsed = self._parsed.transform(CLSTransformer(rules, self._dialect))
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user