diff --git a/superset/sql/parse.py b/superset/sql/parse.py index ba4d288a972..f8b50f2d3b9 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -129,6 +129,56 @@ class LimitMethod(enum.Enum): FETCH_MANY = enum.auto() +class CLSAction(enum.Enum): + """ + Column-Level Security actions. + + These actions determine how sensitive columns are transformed in queries. + """ + + HASH = enum.auto() # Pseudonymization via hashing + NULLIFY = enum.auto() # Replace with NULL + HIDE = enum.auto() # Remove from results entirely + MASK = enum.auto() # Replace with '****' + + +# Type alias for CLS rules: {table_name: {column_name: action}} +CLSRules = dict[str, dict[str, CLSAction]] + + +# Hash function patterns by dialect. The placeholder {} will be replaced with the +# column. Some databases need casting for non-string types, so we cast to string/text. +# The fallback uses a literal since there's no universal hash function across all +# databases. +CLS_HASH_FUNCTIONS: dict[Dialects | type[Dialect] | None, str] = { + None: "'[HASHED]'", # Universal fallback - no hash function available + Dialects.DIALECT: "MD5(CAST({} AS VARCHAR))", # Generic SQL with MD5 + Dialects.POSTGRES: "MD5(CAST({} AS TEXT))", + Dialects.MYSQL: "MD5(CAST({} AS CHAR))", + Dialects.BIGQUERY: "TO_HEX(MD5(CAST({} AS STRING)))", + Dialects.SNOWFLAKE: "MD5(TO_VARCHAR({}))", + Dialects.REDSHIFT: "MD5(CAST({} AS VARCHAR))", + Dialects.PRESTO: "TO_HEX(MD5(CAST({} AS VARBINARY)))", + Dialects.TRINO: "TO_HEX(MD5(CAST({} AS VARBINARY)))", + Dialects.SQLITE: "HEX({})", # SQLite doesn't have MD5, use HEX as placeholder + Dialects.DUCKDB: "MD5(CAST({} AS VARCHAR))", + Dialects.ORACLE: "STANDARD_HASH(TO_CHAR({}), 'MD5')", + Dialects.TSQL: ( + "CONVERT(VARCHAR(32), HASHBYTES('MD5', CAST({} AS VARCHAR(MAX))), 2)" + ), + Dialects.HIVE: "MD5(CAST({} AS STRING))", + Dialects.SPARK: "MD5(CAST({} AS STRING))", + Dialects.CLICKHOUSE: "MD5(toString({}))", + Dialects.DATABRICKS: "MD5(CAST({} AS STRING))", + Dialects.DORIS: "MD5(CAST({} AS VARCHAR))", + Dialects.STARROCKS: "MD5(CAST({} AS VARCHAR))", + Dialects.DRILL: "MD5(CAST({} AS VARCHAR))", + Dialects.DRUID: "MD5(CAST({} AS VARCHAR))", + Dialects.TERADATA: "HASH_MD5(CAST({} AS VARCHAR(10000)))", + Dialects.RISINGWAVE: "MD5(CAST({} AS VARCHAR))", +} + + class CTASMethod(enum.Enum): TABLE = enum.auto() VIEW = enum.auto() @@ -279,6 +329,276 @@ class RLSAsSubqueryTransformer(RLSTransformer): return node +class CLSTransformer: + """ + AST transformer to apply Column-Level Security rules. + + This transformer modifies SELECT expressions to apply CLS actions: + - HASH: Replace column with hash function (database-specific) + - NULLIFY: Replace column with NULL AS column_name + - HIDE: Remove column from SELECT entirely + - MASK: Replace column with '****' AS column_name + + Example: + Given rules: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + Query: SELECT id, salary, name FROM my_table + Result: SELECT MD5(CAST(id AS TEXT)), NULL AS salary, name FROM my_table + """ + + def __init__( + self, + rules: CLSRules, + dialect: Dialects | type[Dialect] | None, + ) -> None: + self.rules = self._normalize_rules(rules) + self.dialect = dialect + self.hash_pattern = CLS_HASH_FUNCTIONS.get(dialect, CLS_HASH_FUNCTIONS[None]) + + def _normalize_rules(self, rules: CLSRules) -> CLSRules: + """ + Normalize table and column names to lowercase for case-insensitive matching. + """ + return { + table.lower(): {col.lower(): action for col, action in cols.items()} + for table, cols in rules.items() + } + + def _get_action( + self, + table_name: str | None, + column_name: str, + ) -> CLSAction | None: + """ + Get the CLS action for a column, if any. + """ + if not table_name: + return None + + table_rules = self.rules.get(table_name.lower()) + if not table_rules: + return None + + return table_rules.get(column_name.lower()) + + def _create_hash_expression( + self, + column: exp.Column, + alias: str, + ) -> exp.Expression: + """ + Create a hash expression for a column. + """ + # Generate the column SQL without any alias + col_sql = column.sql(dialect=self.dialect) + hash_sql = self.hash_pattern.format(col_sql) + hash_expr = sqlglot.parse_one(hash_sql, dialect=self.dialect) + return exp.Alias( + this=hash_expr, + alias=exp.Identifier(this=alias), + ) + + def _create_null_expression(self, alias: str) -> exp.Expression: + """ + Create a NULL AS alias expression. + """ + return exp.Alias( + this=exp.Null(), + alias=exp.Identifier(this=alias), + ) + + def _create_mask_expression(self, alias: str) -> exp.Expression: + """ + Create a '****' AS alias expression. + """ + return exp.Alias( + this=exp.Literal(this="****", is_string=True), + alias=exp.Identifier(this=alias), + ) + + def _get_column_alias(self, expr: exp.Expression) -> str: + """ + Get the alias for a column expression. + """ + if isinstance(expr, exp.Alias): + return expr.alias + if isinstance(expr, exp.Column): + return expr.name + return expr.sql(dialect=self.dialect) + + def _get_table_for_column( + self, + column: exp.Column, + scope_tables: dict[str, str], + ) -> str | None: + """ + Resolve which table a column belongs to. + + Args: + column: The column expression + scope_tables: Map of alias/name to actual table name + + Returns: + The table name or None if cannot be resolved + """ + if column.table: + # Column is qualified with table name/alias + return scope_tables.get(column.table.lower(), column.table) + + # For unqualified columns, if there's only one table in scope, + # we can infer the column belongs to that table + if len(scope_tables) == 1: + return next(iter(scope_tables.values())) + + # With multiple tables, check if any table in rules has this column + # This is a best-effort match for unqualified columns + col_lower = column.name.lower() + for table_name in scope_tables.values(): + table_rules = self.rules.get(table_name.lower()) + if table_rules and col_lower in table_rules: + return table_name + + return None + + def _extract_scope_tables(self, select: exp.Select) -> dict[str, str]: + """ + Extract table names and aliases from a SELECT statement's FROM clause. + + Returns a dict mapping alias (or table name if no alias) to actual table name. + """ + tables: dict[str, str] = {} + + if from_clause := select.args.get("from"): + for table in from_clause.find_all(exp.Table): + table_name = table.name + alias = table.alias if table.alias else table_name + tables[alias.lower()] = table_name + + for join in select.args.get("joins") or []: + for table in join.find_all(exp.Table): + table_name = table.name + alias = table.alias if table.alias else table_name + tables[alias.lower()] = table_name + + return tables + + def _transform_expression( + self, + expr: exp.Expression, + scope_tables: dict[str, str], + ) -> exp.Expression | None: + """ + Transform a single SELECT expression based on CLS rules. + + Returns: + - Transformed expression + - None if the column should be hidden + """ + # Get the underlying column (handle aliases) + column = expr.this if isinstance(expr, exp.Alias) else expr + alias = self._get_column_alias(expr) + + if not isinstance(column, exp.Column): + # Not a simple column reference (could be a function, literal, etc.) + return expr + + table_name = self._get_table_for_column(column, scope_tables) + action = self._get_action(table_name, column.name) + + if action is None: + return expr + + if action == CLSAction.HIDE: + return None + if action == CLSAction.HASH: + return self._create_hash_expression(column, alias) + if action == CLSAction.NULLIFY: + return self._create_null_expression(alias) + # action == CLSAction.MASK + return self._create_mask_expression(alias) + + def _transform_star( + self, + star: exp.Star, + scope_tables: dict[str, str], + ) -> list[exp.Expression]: + """ + Transform SELECT * by expanding hidden columns conceptually. + + Since we don't have schema information, we cannot truly expand *. + We return the star as-is but log a warning. + """ + # Without schema information, we cannot expand SELECT * + # In a real implementation, you would need to query the database schema + logger.warning( + "CLS cannot fully process SELECT * without schema information. " + "Consider using explicit column lists for queries with CLS rules." + ) + return [star] + + def transform_select(self, select: exp.Select) -> exp.Select: + """ + Transform a SELECT statement by applying CLS rules to its expressions. + """ + scope_tables = self._extract_scope_tables(select) + expressions = select.args.get("expressions", []) + + new_expressions: list[exp.Expression] = [] + for expr in expressions: + if isinstance(expr, exp.Star): + new_expressions.extend(self._transform_star(expr, scope_tables)) + else: + transformed = self._transform_expression(expr, scope_tables) + if transformed is not None: + new_expressions.append(transformed) + + # Create a new SELECT with transformed expressions + select.set("expressions", new_expressions) + return select + + def __call__(self, node: exp.Expression) -> exp.Expression: + """ + Transform callback for sqlglot's transform method. + """ + if isinstance(node, exp.Select): + return self.transform_select(node) + return node + + +def apply_cls( + sql: str, + rules: CLSRules, + engine: str = "base", + schema: dict[str, dict[str, str]] | None = None, +) -> str: + """ + Apply Column-Level Security rules to a SQL query. + + This function transforms a SQL query by applying CLS actions to sensitive columns: + - HASH: Pseudonymize using database-specific hash function + - NULLIFY: Replace with NULL + - HIDE: Remove from SELECT results + - MASK: Replace with '****' + + Args: + sql: The SQL query to transform + rules: CLS rules mapping table names to column actions + Example: {"my_table": {"id": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + engine: The database engine (used for dialect-specific hash functions) + schema: Optional schema for column qualification. Required for JOINs with + ambiguous column names. Format: {"table": {"column": "TYPE", ...}, ...} + + Returns: + The transformed SQL query + """ + if not rules: + return sql + + statement = SQLStatement(sql, engine) + statement.apply_cls(rules, schema=schema) + + return statement.format(comments=True) + + @dataclass(eq=True, frozen=True) class Table: """ @@ -887,6 +1207,43 @@ class SQLStatement(BaseSQLStatement[exp.Expression]): transformer = transformers[method](catalog, schema, predicates) self._parsed = self._parsed.transform(transformer) + def apply_cls( + self, + rules: CLSRules, + schema: dict[str, dict[str, str]] | None = None, + ) -> None: + """ + Apply Column-Level Security rules to the statement inplace. + + CLS rules transform sensitive columns in SELECT statements: + - HASH: Pseudonymize using database-specific hash function + - NULLIFY: Replace with NULL + - HIDE: Remove from SELECT results + - MASK: Replace with '****' + + :param rules: CLS rules mapping table names to column actions + Example: {"my_table": {"ssn": CLSAction.HASH, "salary": CLSAction.NULLIFY}} + :param schema: Optional schema for column qualification. Required for JOINs + with ambiguous column names. Format: {"table": {"column": "TYPE", ...}} + """ + if not rules: + return + + # Always attempt to qualify columns for better CLS resolution. + # With schema: full qualification of all columns. + # Without schema: qualifies single-table queries, partial for JOINs. + from sqlglot.optimizer.qualify import qualify + + self._parsed = qualify( + self._parsed, + schema=schema, + dialect=self._dialect, + validate_qualify_columns=False, + ) + + transformer = CLSTransformer(rules, self._dialect) + self._parsed = self._parsed.transform(transformer) + class KQLSplitState(enum.Enum): """ diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 0af93696595..f5a10a3c702 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -17,6 +17,8 @@ # pylint: disable=invalid-name, redefined-outer-name, too-many-lines +from typing import Any + import pytest from pytest_mock import MockerFixture from sqlglot import Dialects, exp, parse_one @@ -24,6 +26,10 @@ from sqlglot import Dialects, exp, parse_one from superset.exceptions import QueryClauseValidationException, SupersetParseError from superset.jinja_context import JinjaTemplateProcessor from superset.sql.parse import ( + apply_cls, + CLS_HASH_FUNCTIONS, + CLSAction, + CLSTransformer, CTASMethod, extract_tables_from_statement, JinjaSQLResult, @@ -2977,3 +2983,783 @@ def test_has_subquery(sql: str, engine: str, expected: bool) -> None: Test the `has_subquery` method. """ assert SQLStatement(sql, engine).has_subquery() == expected + + +# ============================================================================= +# Column-Level Security (CLS) Tests +# ============================================================================= + + +def test_cls_action_enum() -> None: + """ + Test CLSAction enum values exist. + """ + assert CLSAction.HASH is not None + assert CLSAction.NULLIFY is not None + assert CLSAction.HIDE is not None + assert CLSAction.MASK is not None + + +def test_cls_hash_functions_mapping() -> None: + """ + Test that CLS_HASH_FUNCTIONS has entries for common dialects. + """ + # Check fallback exists + assert None in CLS_HASH_FUNCTIONS + assert CLS_HASH_FUNCTIONS[None] == "'[HASHED]'" + + # Check common dialects + assert Dialects.POSTGRES in CLS_HASH_FUNCTIONS + assert Dialects.MYSQL in CLS_HASH_FUNCTIONS + assert Dialects.BIGQUERY in CLS_HASH_FUNCTIONS + assert Dialects.SNOWFLAKE in CLS_HASH_FUNCTIONS + + # Verify hash patterns contain placeholder + for dialect, pattern in CLS_HASH_FUNCTIONS.items(): + if dialect is not None and pattern != "'[HASHED]'": + assert "{}" in pattern, f"Missing placeholder in {dialect} hash pattern" + + +def test_apply_cls_empty_rules() -> None: + """ + Test that apply_cls returns original SQL when rules are empty. + """ + sql = "SELECT id, name FROM users" + result = apply_cls(sql, {}, engine="postgresql") + assert result == sql + + +def test_apply_cls_hash_action() -> None: + """ + Test CLSAction.HASH transforms column with hash function. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT ssn, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_nullify_action() -> None: + """ + Test CLSAction.NULLIFY transforms column to NULL. + """ + rules = {"users": {"salary": CLSAction.NULLIFY}} + sql = "SELECT salary, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + 'SELECT\n NULL AS salary,\n "users"."name" AS "name"\nFROM "users" AS "users"' + ) + + +def test_apply_cls_hide_action() -> None: + """ + Test CLSAction.HIDE removes column from SELECT. + """ + rules = {"users": {"password": CLSAction.HIDE}} + sql = "SELECT password, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ('SELECT\n "users"."name" AS "name"\nFROM "users" AS "users"') + + +def test_apply_cls_mask_action() -> None: + """ + Test CLSAction.MASK transforms column to '****'. + """ + rules = {"users": {"phone": CLSAction.MASK}} + sql = "SELECT phone, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + " '****' AS phone,\n" + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_all_actions() -> None: + """ + Test all CLS actions in a single query. + """ + rules = { + "employees": { + "ssn": CLSAction.HASH, + "salary": CLSAction.NULLIFY, + "password": CLSAction.HIDE, + "phone": CLSAction.MASK, + } + } + sql = "SELECT ssn, salary, password, phone, name FROM employees" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,\n' + " NULL AS salary,\n" + " '****' AS phone,\n" + ' "employees"."name" AS "name"\n' + 'FROM "employees" AS "employees"' + ) + + +def test_apply_cls_qualified_columns() -> None: + """ + Test CLS with fully qualified column names (table.column). + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT users.ssn, users.name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_table_alias() -> None: + """ + Test CLS with table aliases. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT u.ssn, u.name FROM users u" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("u"."ssn" AS TEXT)) AS ssn,\n' + ' "u"."name" AS "name"\n' + 'FROM "users" AS "u"' + ) + + +def test_apply_cls_join() -> None: + """ + Test CLS with JOIN queries. + """ + rules = { + "employees": {"ssn": CLSAction.HASH}, + "salaries": {"amount": CLSAction.NULLIFY}, + } + sql = """ +SELECT e.ssn, e.name, s.amount +FROM employees e +JOIN salaries s +ON e.id = s.employee_id + """ + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("e"."ssn" AS TEXT)) AS ssn,\n' + ' "e"."name" AS "name",\n' + " NULL AS amount\n" + 'FROM "employees" AS "e"\n' + 'JOIN "salaries" AS "s"\n' + ' ON "e"."id" = "s"."employee_id"' + ) + + +def test_apply_cls_case_insensitive() -> None: + """ + Test CLS rules are case-insensitive for table and column names. + """ + rules = {"USERS": {"SSN": CLSAction.HASH}} + sql = "SELECT ssn, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_with_column_alias() -> None: + """ + Test CLS preserves existing column aliases. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT ssn AS social_security, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS social_security,\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_no_matching_table() -> None: + """ + Test CLS leaves columns unchanged when table doesn't match rules. + """ + rules = {"other_table": {"ssn": CLSAction.HASH}} + sql = "SELECT ssn, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Table doesn't match rules, so columns are unchanged (just qualified) + assert result == ( + "SELECT\n" + ' "users"."ssn" AS "ssn",\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_non_column_expressions() -> None: + """ + Test CLS leaves non-column expressions unchanged. + """ + rules = {"users": {"name": CLSAction.HASH}} + sql = "SELECT 1 AS one, 'test' AS str, COUNT(*) AS cnt FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' 1 AS "one",\n' + " 'test' AS \"str\",\n" + ' COUNT(*) AS "cnt"\n' + 'FROM "users" AS "users"' + ) + + +def test_apply_cls_with_schema() -> None: + """ + Test CLS with schema for column qualification. + """ + rules = { + "employees": {"ssn": CLSAction.HASH}, + "departments": {"budget": CLSAction.NULLIFY}, + } + schema = { + "employees": { + "id": "INT", + "ssn": "VARCHAR", + "name": "VARCHAR", + "dept_id": "INT", + }, + "departments": {"id": "INT", "name": "VARCHAR", "budget": "DECIMAL"}, + } + sql = """ +SELECT + ssn, name, budget +FROM employees e +JOIN departments d +ON e.dept_id = d.id + """ + result = apply_cls(sql, rules, engine="postgresql", schema=schema) + + assert "MD5" in result + assert "NULL" in result + + +def test_apply_cls_different_dialects() -> None: + """ + Test CLS uses correct hash function for different database dialects. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users" + + # PostgreSQL + result_pg = apply_cls(sql, rules, engine="postgresql") + assert result_pg == ( + 'SELECT\n MD5(CAST("users"."ssn" AS TEXT)) AS ssn\nFROM "users" AS "users"' + ) + + # MySQL + result_mysql = apply_cls(sql, rules, engine="mysql") + assert result_mysql == ( + "SELECT\n MD5(CAST(`users`.`ssn` AS CHAR)) AS ssn\nFROM `users` AS `users`" + ) + + # BigQuery + result_bq = apply_cls(sql, rules, engine="bigquery") + assert result_bq == ( + "SELECT\n" + " TO_HEX(MD5(CAST(`users`.`ssn` AS STRING))) AS ssn\n" + "FROM `users` AS `users`" + ) + + +def test_apply_cls_unknown_dialect_fallback() -> None: + """ + Test CLS uses fallback for unknown database dialects. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT users.ssn FROM users" + result = apply_cls(sql, rules, engine="unknown_database") + + assert result == ('SELECT\n \'[HASHED]\' AS ssn\nFROM "users" AS "users"') + + +def test_apply_cls_select_star_warning(caplog: pytest.LogCaptureFixture) -> None: + """ + Test CLS logs warning for SELECT * queries. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT * FROM users" + + import logging + + with caplog.at_level(logging.WARNING): + result = apply_cls(sql, rules, engine="postgresql") + + assert ( + "SELECT *" in caplog.text or "CLS cannot fully process SELECT *" in caplog.text + ) + assert "*" in result # Star should be preserved + + +def test_sql_statement_apply_cls_method() -> None: + """ + Test SQLStatement.apply_cls method. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + statement = SQLStatement("SELECT ssn, name FROM users", engine="postgresql") + statement.apply_cls(rules) + result = statement.format() + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + 'FROM "users" AS "users"' + ) + + +def test_sql_statement_apply_cls_empty_rules() -> None: + """ + Test SQLStatement.apply_cls with empty rules returns unchanged statement. + """ + original_sql = "SELECT ssn, name FROM users" + statement = SQLStatement(original_sql, engine="postgresql") + statement.apply_cls({}) + result = statement.format() + + # Empty rules, so original query is preserved (just formatted) + assert result == ("SELECT\n ssn,\n name\nFROM users") + + +def test_sql_statement_apply_cls_with_schema() -> None: + """ + Test SQLStatement.apply_cls with schema parameter. + """ + rules = {"employees": {"ssn": CLSAction.HASH}} + schema = {"employees": {"id": "INT", "ssn": "VARCHAR", "name": "VARCHAR"}} + statement = SQLStatement("SELECT ssn, name FROM employees", engine="postgresql") + statement.apply_cls(rules, schema=schema) + result = statement.format() + + assert result == ( + "SELECT\n" + ' MD5(CAST("employees"."ssn" AS TEXT)) AS ssn,\n' + ' "employees"."name" AS "name"\n' + 'FROM "employees" AS "employees"' + ) + + +def test_cls_transformer_normalize_rules() -> None: + """ + Test CLSTransformer normalizes table and column names to lowercase. + """ + rules = {"USERS": {"SSN": CLSAction.HASH, "Name": CLSAction.MASK}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + assert "users" in transformer.rules + assert "ssn" in transformer.rules["users"] + assert "name" in transformer.rules["users"] + + +def test_cls_transformer_get_action() -> None: + """ + Test CLSTransformer._get_action method. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Valid table and column + assert transformer._get_action("users", "ssn") == CLSAction.HASH + + # Case insensitive + assert transformer._get_action("USERS", "SSN") == CLSAction.HASH + + # No matching column + assert transformer._get_action("users", "name") is None + + # No matching table + assert transformer._get_action("other", "ssn") is None + + # None table + assert transformer._get_action(None, "ssn") is None + + +def test_cls_transformer_extract_scope_tables() -> None: + """ + Test CLSTransformer._extract_scope_tables method. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Single table + select = parse_one("SELECT * FROM users") + tables = transformer._extract_scope_tables(select) + assert "users" in tables + assert tables["users"] == "users" + + # Table with alias + select = parse_one("SELECT * FROM users u") + tables = transformer._extract_scope_tables(select) + assert "u" in tables + assert tables["u"] == "users" + + # JOIN + select = parse_one("SELECT * FROM users u JOIN orders o ON u.id = o.user_id") + tables = transformer._extract_scope_tables(select) + assert "u" in tables + assert "o" in tables + assert tables["u"] == "users" + assert tables["o"] == "orders" + + +def test_cls_transformer_get_table_for_column_qualified() -> None: + """ + Test CLSTransformer._get_table_for_column with qualified columns. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + scope_tables = {"u": "users", "o": "orders"} + + # Qualified with alias + column = parse_one("u.ssn").find(exp.Column) + result = transformer._get_table_for_column(column, scope_tables) + assert result == "users" + + # Qualified with unknown alias (returns as-is) + column = parse_one("x.ssn").find(exp.Column) + result = transformer._get_table_for_column(column, scope_tables) + assert result == "x" + + +def test_cls_transformer_get_table_for_column_single_table() -> None: + """ + Test CLSTransformer._get_table_for_column infers single table. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + scope_tables = {"users": "users"} + + # Unqualified column with single table in scope + column = parse_one("ssn").find(exp.Column) + result = transformer._get_table_for_column(column, scope_tables) + assert result == "users" + + +def test_cls_transformer_get_table_for_column_multi_table_rules_match() -> None: + """ + Test CLSTransformer._get_table_for_column matches against rules. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + scope_tables = {"users": "users", "orders": "orders"} + + # Unqualified column that only exists in rules for one table + column = parse_one("ssn").find(exp.Column) + result = transformer._get_table_for_column(column, scope_tables) + assert result == "users" + + +def test_cls_transformer_get_table_for_column_no_match() -> None: + """ + Test CLSTransformer._get_table_for_column returns None when no match. + """ + rules = {"other": {"col": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + scope_tables = {"users": "users", "orders": "orders"} + + # Unqualified column with no matching rule + column = parse_one("ssn").find(exp.Column) + result = transformer._get_table_for_column(column, scope_tables) + assert result is None + + +def test_cls_transformer_get_column_alias() -> None: + """ + Test CLSTransformer._get_column_alias method. + """ + transformer = CLSTransformer({}, Dialects.POSTGRES) + + # Column expression + column = parse_one("ssn").find(exp.Column) + assert transformer._get_column_alias(column) == "ssn" + + # Alias expression + alias = parse_one("ssn AS social").find(exp.Alias) + assert transformer._get_column_alias(alias) == "social" + + # Other expression (literal) + literal = parse_one("'test'").find(exp.Literal) + assert transformer._get_column_alias(literal) == "'test'" + + +def test_cls_transformer_create_expressions() -> None: + """ + Test CLSTransformer expression creation methods. + """ + transformer = CLSTransformer({}, Dialects.POSTGRES) + + # Hash expression + column = parse_one("ssn").find(exp.Column) + hash_expr = transformer._create_hash_expression(column, "ssn") + assert isinstance(hash_expr, exp.Alias) + assert hash_expr.alias == "ssn" + + # Null expression + null_expr = transformer._create_null_expression("salary") + assert isinstance(null_expr, exp.Alias) + assert null_expr.alias == "salary" + assert isinstance(null_expr.this, exp.Null) + + # Mask expression + mask_expr = transformer._create_mask_expression("phone") + assert isinstance(mask_expr, exp.Alias) + assert mask_expr.alias == "phone" + assert isinstance(mask_expr.this, exp.Literal) + assert mask_expr.this.this == "****" + + +def test_cls_transformer_call_non_select() -> None: + """ + Test CLSTransformer.__call__ returns non-SELECT nodes unchanged. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + + # Non-SELECT node should be returned unchanged + table = parse_one("users").find(exp.Column) + result = transformer(table) + assert result == table + + +def test_cls_transformer_transform_expression_non_column() -> None: + """ + Test CLSTransformer._transform_expression returns non-column expressions unchanged. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + transformer = CLSTransformer(rules, Dialects.POSTGRES) + scope_tables = {"users": "users"} + + # Literal expression should be unchanged + literal = parse_one("'test'") + result = transformer._transform_expression(literal, scope_tables) + assert result == literal + + # Function expression should be unchanged + func = parse_one("COUNT(*)") + result = transformer._transform_expression(func, scope_tables) + assert result == func + + +@pytest.mark.parametrize( + "sql,rules,engine,expected", + [ + # Basic HASH + ( + "SELECT t.id FROM t", + {"t": {"id": CLSAction.HASH}}, + "postgresql", + 'SELECT\n MD5(CAST("t"."id" AS TEXT)) AS id\nFROM "t" AS "t"', + ), + # Basic NULLIFY + ( + "SELECT t.salary FROM t", + {"t": {"salary": CLSAction.NULLIFY}}, + "postgresql", + 'SELECT\n NULL AS salary\nFROM "t" AS "t"', + ), + # Basic HIDE + ( + "SELECT t.secret, t.public FROM t", + {"t": {"secret": CLSAction.HIDE}}, + "postgresql", + 'SELECT\n "t"."public" AS "public"\nFROM "t" AS "t"', + ), + # Basic MASK + ( + "SELECT t.phone FROM t", + {"t": {"phone": CLSAction.MASK}}, + "postgresql", + 'SELECT\n \'****\' AS phone\nFROM "t" AS "t"', + ), + # Multiple tables with different rules + ( + "SELECT a.ssn, b.amount FROM users a JOIN payments b ON a.id = b.user_id", + { + "users": {"ssn": CLSAction.HASH}, + "payments": {"amount": CLSAction.NULLIFY}, + }, + "postgresql", + ( + "SELECT\n" + ' MD5(CAST("a"."ssn" AS TEXT)) AS ssn,\n' + " NULL AS amount\n" + 'FROM "users" AS "a"\n' + 'JOIN "payments" AS "b"\n' + ' ON "a"."id" = "b"."user_id"' + ), + ), + # Snowflake dialect + ( + "SELECT t.col FROM t", + {"t": {"col": CLSAction.HASH}}, + "snowflake", + 'SELECT\n MD5(TO_CHAR("T"."COL")) AS COL\nFROM "T" AS "T"', + ), + # ClickHouse dialect + ( + "SELECT t.col FROM t", + {"t": {"col": CLSAction.HASH}}, + "clickhouse", + 'SELECT\n MD5(toString("t"."col")) AS col\nFROM "t" AS "t"', + ), + ], +) +def test_apply_cls_parametrized( + sql: str, + rules: dict[str, Any], + engine: str, + expected: str, +) -> None: + """ + Parametrized tests for apply_cls function. + """ + result = apply_cls(sql, rules, engine=engine) + assert result == expected + + +def test_apply_cls_subquery() -> None: + """ + Test CLS applies to subqueries. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT * FROM (SELECT ssn, name FROM users) AS subq" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' "subq"."ssn" AS "ssn",\n' + ' "subq"."name" AS "name"\n' + "FROM (\n" + " SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + ' FROM "users" AS "users"\n' + ') AS "subq"' + ) + + +def test_apply_cls_cte() -> None: + """ + Test CLS applies to CTEs. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "WITH cte AS (SELECT ssn, name FROM users) SELECT * FROM cte" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + 'WITH "cte" AS (\n' + " SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn,\n' + ' "users"."name" AS "name"\n' + ' FROM "users" AS "users"\n' + ")\n" + "SELECT\n" + ' "cte"."ssn" AS "ssn",\n' + ' "cte"."name" AS "name"\n' + 'FROM "cte" AS "cte"' + ) + + +def test_apply_cls_union() -> None: + """ + Test CLS applies to UNION queries. + """ + rules = {"users": {"ssn": CLSAction.HASH}} + sql = "SELECT ssn FROM users UNION SELECT ssn FROM archived_users" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + "SELECT\n" + ' MD5(CAST("users"."ssn" AS TEXT)) AS ssn\n' + 'FROM "users" AS "users"\n' + "UNION\n" + "SELECT\n" + ' "archived_users"."ssn" AS "ssn"\n' + 'FROM "archived_users" AS "archived_users"' + ) + + +def test_cls_hide_all_columns() -> None: + """ + Test CLS HIDE action when all columns are hidden. + """ + rules = {"users": {"id": CLSAction.HIDE, "name": CLSAction.HIDE}} + sql = "SELECT id, name FROM users" + result = apply_cls(sql, rules, engine="postgresql") + + # Both columns should be hidden, resulting in empty SELECT + assert result == 'SELECT\nFROM "users" AS "users"' + + +def test_cls_transformer_extract_scope_tables_no_from() -> None: + """ + Test CLSTransformer._extract_scope_tables with no FROM clause. + """ + transformer = CLSTransformer({}, Dialects.POSTGRES) + select = parse_one("SELECT 1") + tables = transformer._extract_scope_tables(select) + assert tables == {} + + +def test_cls_transformer_extract_scope_tables_no_joins() -> None: + """ + Test CLSTransformer._extract_scope_tables with FROM but no JOINs. + """ + transformer = CLSTransformer({}, Dialects.POSTGRES) + select = parse_one("SELECT * FROM users") + tables = transformer._extract_scope_tables(select) + assert "users" in tables + assert len(tables) == 1 + + +def test_apply_cls_aliased_column_preserves_alias() -> None: + """ + Test that CLS preserves the alias when column has AS clause. + """ + rules = {"t": {"col": CLSAction.HASH}} + sql = "SELECT t.col AS my_alias FROM t" + result = apply_cls(sql, rules, engine="postgresql") + + assert result == ( + 'SELECT\n MD5(CAST("t"."col" AS TEXT)) AS my_alias\nFROM "t" AS "t"' + ) + + +def test_cls_transformer_hash_pattern_fallback() -> None: + """ + Test CLSTransformer uses fallback hash pattern for unknown dialect. + """ + # Use None as dialect to trigger fallback + transformer = CLSTransformer({"t": {"col": CLSAction.HASH}}, None) + assert transformer.hash_pattern == "'[HASHED]'"