mirror of
https://github.com/apache/superset.git
synced 2026-05-12 11:25:56 +00:00
feat: column-level security
This commit is contained in:
@@ -129,6 +129,56 @@ class LimitMethod(enum.Enum):
|
||||
FETCH_MANY = enum.auto()
|
||||
|
||||
|
||||
class CLSAction(enum.Enum):
|
||||
"""
|
||||
Column-Level Security actions.
|
||||
|
||||
These actions determine how sensitive columns are transformed in queries.
|
||||
"""
|
||||
|
||||
HASH = enum.auto() # Pseudonymization via hashing
|
||||
NULLIFY = enum.auto() # Replace with NULL
|
||||
HIDE = enum.auto() # Remove from results entirely
|
||||
MASK = enum.auto() # Replace with '****'
|
||||
|
||||
|
||||
# Type alias for CLS rules: {table_name: {column_name: action}}
|
||||
CLSRules = dict[str, dict[str, CLSAction]]
|
||||
|
||||
|
||||
# Hash function patterns by dialect. The placeholder {} will be replaced with the
|
||||
# column. Some databases need casting for non-string types, so we cast to string/text.
|
||||
# The fallback uses a literal since there's no universal hash function across all
|
||||
# databases.
|
||||
CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = {
|
||||
None: "'[HASHED]'", # Universal fallback - no hash function available
|
||||
Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5
|
||||
Dialects.POSTGRES: "MD5(CAST({} AS TEXT))",
|
||||
Dialects.MYSQL: "MD5(CAST({} AS CHAR))",
|
||||
Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))",
|
||||
Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))",
|
||||
Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
|
||||
Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))",
|
||||
Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder
|
||||
Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')",
|
||||
Dialects.TSQL: (
|
||||
"CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)"
|
||||
),
|
||||
Dialects.HIVE: "MD5(CAST({} AS STRING))",
|
||||
Dialects.SPARK: "MD5(CAST({} AS STRING))",
|
||||
Dialects.CLICKHOUSE: "MD5(toString({}))",
|
||||
Dialects.DATABRICKS: "MD5(CAST({} AS STRING))",
|
||||
Dialects.DORIS: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.DRILL: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.DRUID: "MD5(CAST({} AS VARCHAR))",
|
||||
Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))",
|
||||
Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))",
|
||||
}
|
||||
|
||||
|
||||
class CTASMethod(enum.Enum):
|
||||
TABLE = enum.auto()
|
||||
VIEW = enum.auto()
|
||||
@@ -279,6 +329,276 @@ class RLSAsSubqueryTransformer(RLSTransformer):
|
||||
return node
|
||||
|
||||
|
||||
class CLSTransformer:
|
||||
"""
|
||||
AST transformer to apply Column-Level Security rules.
|
||||
|
||||
This transformer modifies SELECT expressions to apply CLS actions:
|
||||
- HASH: Replace column with hash function (database-specific)
|
||||
- NULLIFY: Replace column with NULL AS column_name
|
||||
- HIDE: Remove column from SELECT entirely
|
||||
- MASK: Replace column with '****' AS column_name
|
||||
|
||||
Example:
|
||||
Given rules: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
Query: SELECT id, salary, name FROM my_table
|
||||
Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rules: CLSRules,
|
||||
dialect: Dialects | type[Dialect] | None,
|
||||
) -> None:
|
||||
self.rules = self._normalize_rules(rules)
|
||||
self.dialect = dialect
|
||||
self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None])
|
||||
|
||||
def _normalize_rules(self, rules: CLSRules) -> CLSRules:
|
||||
"""
|
||||
Normalize table and column names to lowercase for case-insensitive matching.
|
||||
"""
|
||||
return {
|
||||
table.lower(): {col.lower(): action for col, action in cols.items()}
|
||||
for table, cols in rules.items()
|
||||
}
|
||||
|
||||
def _get_action(
|
||||
self,
|
||||
table_name: str | None,
|
||||
column_name: str,
|
||||
) -> CLSAction | None:
|
||||
"""
|
||||
Get the CLS action for a column, if any.
|
||||
"""
|
||||
if not table_name:
|
||||
return None
|
||||
|
||||
table_rules = self.rules.get(table_name.lower())
|
||||
if not table_rules:
|
||||
return None
|
||||
|
||||
return table_rules.get(column_name.lower())
|
||||
|
||||
def _create_hash_expression(
|
||||
self,
|
||||
column: exp.Column,
|
||||
alias: str,
|
||||
) -> exp.Expression:
|
||||
"""
|
||||
Create a hash expression for a column.
|
||||
"""
|
||||
# Generate the column SQL without any alias
|
||||
col_sql = column.sql(dialect=self.dialect)
|
||||
hash_sql = self.hash_pattern.format(col_sql)
|
||||
hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect)
|
||||
return exp.Alias(
|
||||
this=hash_expr,
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _create_null_expression(self, alias: str) -> exp.Expression:
|
||||
"""
|
||||
Create a NULL AS alias expression.
|
||||
"""
|
||||
return exp.Alias(
|
||||
this=exp.Null(),
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _create_mask_expression(self, alias: str) -> exp.Expression:
|
||||
"""
|
||||
Create a '****' AS alias expression.
|
||||
"""
|
||||
return exp.Alias(
|
||||
this=exp.Literal(this="****", is_string=True),
|
||||
alias=exp.Identifier(this=alias),
|
||||
)
|
||||
|
||||
def _get_column_alias(self, expr: exp.Expression) -> str:
|
||||
"""
|
||||
Get the alias for a column expression.
|
||||
"""
|
||||
if isinstance(expr, exp.Alias):
|
||||
return expr.alias
|
||||
if isinstance(expr, exp.Column):
|
||||
return expr.name
|
||||
return expr.sql(dialect=self.dialect)
|
||||
|
||||
def _get_table_for_column(
|
||||
self,
|
||||
column: exp.Column,
|
||||
scope_tables: dict[str, str],
|
||||
) -> str | None:
|
||||
"""
|
||||
Resolve which table a column belongs to.
|
||||
|
||||
Args:
|
||||
column: The column expression
|
||||
scope_tables: Map of alias/name to actual table name
|
||||
|
||||
Returns:
|
||||
The table name or None if cannot be resolved
|
||||
"""
|
||||
if column.table:
|
||||
# Column is qualified with table name/alias
|
||||
return scope_tables.get(column.table.lower(), column.table)
|
||||
|
||||
# For unqualified columns, if there's only one table in scope,
|
||||
# we can infer the column belongs to that table
|
||||
if len(scope_tables) == 1:
|
||||
return next(iter(scope_tables.values()))
|
||||
|
||||
# With multiple tables, check if any table in rules has this column
|
||||
# This is a best-effort match for unqualified columns
|
||||
col_lower = column.name.lower()
|
||||
for table_name in scope_tables.values():
|
||||
table_rules = self.rules.get(table_name.lower())
|
||||
if table_rules and col_lower in table_rules:
|
||||
return table_name
|
||||
|
||||
return None
|
||||
|
||||
def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]:
|
||||
"""
|
||||
Extract table names and aliases from a SELECT statement's FROM clause.
|
||||
|
||||
Returns a dict mapping alias (or table name if no alias) to actual table name.
|
||||
"""
|
||||
tables: dict[str, str] = {}
|
||||
|
||||
if from_clause := select.args.get("from"):
|
||||
for table in from_clause.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
for join in select.args.get("joins") or []:
|
||||
for table in join.find_all(exp.Table):
|
||||
table_name = table.name
|
||||
alias = table.alias if table.alias else table_name
|
||||
tables[alias.lower()] = table_name
|
||||
|
||||
return tables
|
||||
|
||||
def _transform_expression(
|
||||
self,
|
||||
expr: exp.Expression,
|
||||
scope_tables: dict[str, str],
|
||||
) -> exp.Expression | None:
|
||||
"""
|
||||
Transform a single SELECT expression based on CLS rules.
|
||||
|
||||
Returns:
|
||||
- Transformed expression
|
||||
- None if the column should be hidden
|
||||
"""
|
||||
# Get the underlying column (handle aliases)
|
||||
column = expr.this if isinstance(expr, exp.Alias) else expr
|
||||
alias = self._get_column_alias(expr)
|
||||
|
||||
if not isinstance(column, exp.Column):
|
||||
# Not a simple column reference (could be a function, literal, etc.)
|
||||
return expr
|
||||
|
||||
table_name = self._get_table_for_column(column, scope_tables)
|
||||
action = self._get_action(table_name, column.name)
|
||||
|
||||
if action is None:
|
||||
return expr
|
||||
|
||||
if action == CLSAction.HIDE:
|
||||
return None
|
||||
if action == CLSAction.HASH:
|
||||
return self._create_hash_expression(column, alias)
|
||||
if action == CLSAction.NULLIFY:
|
||||
return self._create_null_expression(alias)
|
||||
# action == CLSAction.MASK
|
||||
return self._create_mask_expression(alias)
|
||||
|
||||
def _transform_star(
|
||||
self,
|
||||
star: exp.Star,
|
||||
scope_tables: dict[str, str],
|
||||
) -> list[exp.Expression]:
|
||||
"""
|
||||
Transform SELECT * by expanding hidden columns conceptually.
|
||||
|
||||
Since we don't have schema information, we cannot truly expand *.
|
||||
We return the star as-is but log a warning.
|
||||
"""
|
||||
# Without schema information, we cannot expand SELECT *
|
||||
# In a real implementation, you would need to query the database schema
|
||||
logger.warning(
|
||||
"CLS cannot fully process SELECT * without schema information. "
|
||||
"Consider using explicit column lists for queries with CLS rules."
|
||||
)
|
||||
return [star]
|
||||
|
||||
def transform_select(self, select: exp.Select) -> exp.Select:
|
||||
"""
|
||||
Transform a SELECT statement by applying CLS rules to its expressions.
|
||||
"""
|
||||
scope_tables = self._extract_scope_tables(select)
|
||||
expressions = select.args.get("expressions", [])
|
||||
|
||||
new_expressions: list[exp.Expression] = []
|
||||
for expr in expressions:
|
||||
if isinstance(expr, exp.Star):
|
||||
new_expressions.extend(self._transform_star(expr, scope_tables))
|
||||
else:
|
||||
transformed = self._transform_expression(expr, scope_tables)
|
||||
if transformed is not None:
|
||||
new_expressions.append(transformed)
|
||||
|
||||
# Create a new SELECT with transformed expressions
|
||||
select.set("expressions", new_expressions)
|
||||
return select
|
||||
|
||||
def __call__(self, node: exp.Expression) -> exp.Expression:
|
||||
"""
|
||||
Transform callback for sqlglot's transform method.
|
||||
"""
|
||||
if isinstance(node, exp.Select):
|
||||
return self.transform_select(node)
|
||||
return node
|
||||
|
||||
|
||||
def apply_cls(
|
||||
sql: str,
|
||||
rules: CLSRules,
|
||||
engine: str = "base",
|
||||
schema: dict[str, dict[str, str]] | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Apply Column-Level Security rules to a SQL query.
|
||||
|
||||
This function transforms a SQL query by applying CLS actions to sensitive columns:
|
||||
- HASH: Pseudonymize using database-specific hash function
|
||||
- NULLIFY: Replace with NULL
|
||||
- HIDE: Remove from SELECT results
|
||||
- MASK: Replace with '****'
|
||||
|
||||
Args:
|
||||
sql: The SQL query to transform
|
||||
rules: CLS rules mapping table names to column actions
|
||||
Example: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
engine: The database engine (used for dialect-specific hash functions)
|
||||
schema: Optional schema for column qualification. Required for JOINs with
|
||||
ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...}
|
||||
|
||||
Returns:
|
||||
The transformed SQL query
|
||||
"""
|
||||
if not rules:
|
||||
return sql
|
||||
|
||||
statement = SQLStatement(sql, engine)
|
||||
statement.apply_cls(rules, schema=schema)
|
||||
|
||||
return statement.format(comments=True)
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Table:
|
||||
"""
|
||||
@@ -887,6 +1207,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]):
|
||||
transformer = transformers[method](catalog, schema, predicates)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
def apply_cls(
|
||||
self,
|
||||
rules: CLSRules,
|
||||
schema: dict[str, dict[str, str]] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Apply Column-Level Security rules to the statement inplace.
|
||||
|
||||
CLS rules transform sensitive columns in SELECT statements:
|
||||
- HASH: Pseudonymize using database-specific hash function
|
||||
- NULLIFY: Replace with NULL
|
||||
- HIDE: Remove from SELECT results
|
||||
- MASK: Replace with '****'
|
||||
|
||||
:param rules: CLS rules mapping table names to column actions
|
||||
Example: {"my_table": {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}}
|
||||
:param schema: Optional schema for column qualification. Required for JOINs
|
||||
with ambiguous column names. Format: {"table": {"column": "TYPE", ...}}
|
||||
"""
|
||||
if not rules:
|
||||
return
|
||||
|
||||
# Always attempt to qualify columns for better CLS resolution.
|
||||
# With schema: full qualification of all columns.
|
||||
# Without schema: qualifies single-table queries, partial for JOINs.
|
||||
from sqlglot.optimizer.qualify import qualify
|
||||
|
||||
self._parsed = qualify(
|
||||
self._parsed,
|
||||
schema=schema,
|
||||
dialect=self._dialect,
|
||||
validate_qualify_columns=False,
|
||||
)
|
||||
|
||||
transformer = CLSTransformer(rules, self._dialect)
|
||||
self._parsed = self._parsed.transform(transformer)
|
||||
|
||||
|
||||
class KQLSplitState(enum.Enum):
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user