Snowflake fixes

This commit is contained in:
Beto Dealmeida
2025-12-19 10:56:43 -05:00
parent 40787496e2
commit c07382d606
3 changed files with 143 additions and 9 deletions

View File

@@ -551,7 +551,7 @@ def apply_data_access_rules(
if table_cls:
cls_rules[qualified_table] = table_cls
# Apply CLS first (before RLS) so that hidden columns are removed
# Apply CLS first (before RLS) so that column transformations happen
# before RLS wraps the query in a subquery
if cls_rules:
# Build schema dict for sqlglot's qualify() to expand SELECT *
@@ -591,8 +591,10 @@ def apply_data_access_rules(
parsed_statement.apply_cls(cls_rules, schema=table_schemas if table_schemas else None)
# Apply RLS after CLS - RLS wraps the query in a subquery with SELECT *
# which will pick up the already-transformed columns from CLS
# Apply RLS after CLS
# Note: CLS runs qualify() which may change identifier case for some databases
# (e.g., Snowflake uppercases identifiers). The RLSTransformer normalizes
# table names to lowercase for case-insensitive matching.
if rls_predicates:
parsed_statement.apply_rls(catalog, schema, rls_predicates, method)

View File

@@ -295,16 +295,33 @@ class RLSTransformer:
) -> None:
self.catalog = catalog
self.schema = schema
self.rules = rules
# Normalize table keys to lowercase for case-insensitive matching
# This is needed because apply_cls calls qualify() which may change
# identifier case (e.g., Snowflake uppercases identifiers)
self.rules = {
self._normalize_table(table): predicates
for table, predicates in rules.items()
}
@staticmethod
def _normalize_table(table: Table) -> Table:
"""Normalize table to lowercase for case-insensitive matching."""
return Table(
table=table.table.lower() if table.table else table.table,
schema=table.schema.lower() if table.schema else table.schema,
catalog=table.catalog.lower() if table.catalog else table.catalog,
)
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
"""
Get the combined RLS predicate for a table.
"""
table = Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
table = self._normalize_table(
Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
)
)
if predicates := self.rules.get(table):
return sqlglot.and_(*predicates)
@@ -396,7 +413,19 @@ class RLSAsSubqueryTransformer(RLSTransformer):
if predicate := self.get_predicate(node):
if node.alias:
alias = node.alias
# After qualify(), alias might be a string instead of TableAlias
if isinstance(node.alias, str):
# Check if the original alias was quoted by looking at the SQL
# If the alias contains special chars or is a reserved word, quote it
needs_quoting = (
not node.alias.isidentifier()
or "." in node.alias
)
alias = exp.TableAlias(
this=exp.Identifier(this=node.alias, quoted=needs_quoting)
)
else:
alias = node.alias
else:
name = ".".join(
part

View File

@@ -4690,3 +4690,106 @@ def test_merge_cls_rules_complex_scenario() -> None:
"credit_card": CLSAction.HIDE,
},
}
def test_combined_rls_and_cls_snowflake() -> None:
"""
Test combined RLS and CLS application with Snowflake dialect.
This tests a real-world scenario where:
- A table has both RLS (row-level security) and CLS (column-level security) rules
- The column names in the rule are lowercase but Snowflake uppercases identifiers
- Both transformations should be applied correctly
Rule configuration:
{
"allowed": [
{"database": "Snowflake", "catalog": "SAMPLE_DATA", "schema": "tpcds_sf10tcl"},
{
"database": "Snowflake",
"table": "customer",
"catalog": "SAMPLE_DATA",
"schema": "tpcds_sf10tcl",
"rls": {"predicate": "C_BIRTH_COUNTRY = 'BRAZIL'"},
"cls": {"c_email_address": "hash"}
}
]
}
"""
sql = "SELECT C_BIRTH_COUNTRY, C_EMAIL_ADDRESS FROM customer LIMIT 10"
statement = SQLStatement(sql, engine="snowflake")
# CLS rules with lowercase column name (as stored in the rule)
cls_rules = {
Table("customer", "tpcds_sf10tcl", "SAMPLE_DATA"): {
"c_email_address": CLSAction.HASH
}
}
# RLS predicate
rls_predicate = statement.parse_predicate("C_BIRTH_COUNTRY = 'BRAZIL'")
rls_predicates = {
Table("customer", "tpcds_sf10tcl", "SAMPLE_DATA"): [rls_predicate]
}
# Schema for qualify() to expand SELECT *
table_schema = {
"SAMPLE_DATA": {
"tpcds_sf10tcl": {
"customer": {
"C_BIRTH_COUNTRY": "VARCHAR",
"C_EMAIL_ADDRESS": "VARCHAR",
}
}
}
}
# Apply CLS first (before RLS)
statement.apply_cls(cls_rules, schema=table_schema)
# Apply RLS after CLS
statement.apply_rls(
"SAMPLE_DATA", "tpcds_sf10tcl", rls_predicates, RLSMethod.AS_SUBQUERY
)
result = statement.format()
# Verify CLS is applied: c_email_address should be hashed with MD5
assert "MD5(TO_CHAR(" in result
assert "C_EMAIL_ADDRESS" in result
# Verify RLS is applied: should have subquery with WHERE clause
assert "WHERE" in result
assert "C_BIRTH_COUNTRY = 'BRAZIL'" in result
# Verify the structure: should be SELECT ... FROM (SELECT * FROM ... WHERE ...) AS ...
assert "SELECT" in result
assert "FROM (" in result
def test_combined_rls_and_cls_case_insensitive_matching() -> None:
"""
Test that RLS and CLS matching is case-insensitive.
Snowflake uppercases identifiers, so the rule column names (lowercase)
must match the query column names (uppercase after qualify()).
"""
sql = "SELECT email FROM users"
statement = SQLStatement(sql, engine="snowflake")
# Rules with lowercase names
cls_rules = {Table("users"): {"email": CLSAction.HASH}}
rls_predicate = statement.parse_predicate("active = TRUE")
rls_predicates = {Table("users"): [rls_predicate]}
table_schema = {"users": {"email": "VARCHAR", "active": "BOOLEAN"}}
# Apply CLS then RLS
statement.apply_cls(cls_rules, schema=table_schema)
statement.apply_rls(None, None, rls_predicates, RLSMethod.AS_SUBQUERY)
result = statement.format()
# Both should be applied despite case differences
assert "MD5" in result # CLS hash applied
assert "active = TRUE" in result # RLS predicate applied