Snowflake fixes

This commit is contained in:
Beto Dealmeida
2025-12-19 10:56:43 -05:00
parent 40787496e2
commit c07382d606
3 changed files with 143 additions and 9 deletions

View File

@@ -295,16 +295,33 @@ class RLSTransformer:
) -> None:
self.catalog = catalog
self.schema = schema
self.rules = rules
# Normalize table keys to lowercase for case-insensitive matching
# This is needed because apply_cls calls qualify() which may change
# identifier case (e.g., Snowflake uppercases identifiers)
self.rules = {
self._normalize_table(table): predicates
for table, predicates in rules.items()
}
@staticmethod
def _normalize_table(table: Table) -> Table:
"""Normalize table to lowercase for case-insensitive matching."""
return Table(
table=table.table.lower() if table.table else table.table,
schema=table.schema.lower() if table.schema else table.schema,
catalog=table.catalog.lower() if table.catalog else table.catalog,
)
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
"""
Get the combined RLS predicate for a table.
"""
table = Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
table = self._normalize_table(
Table(
table_node.name,
table_node.db if table_node.db else self.schema,
table_node.catalog if table_node.catalog else self.catalog,
)
)
if predicates := self.rules.get(table):
return sqlglot.and_(*predicates)
@@ -396,7 +413,19 @@ class RLSAsSubqueryTransformer(RLSTransformer):
if predicate := self.get_predicate(node):
if node.alias:
alias = node.alias
# After qualify(), alias might be a string instead of TableAlias
if isinstance(node.alias, str):
# Check if the original alias was quoted by looking at the SQL
# If the alias contains special chars or is a reserved word, quote it
needs_quoting = (
not node.alias.isidentifier()
or "." in node.alias
)
alias = exp.TableAlias(
this=exp.Identifier(this=node.alias, quoted=needs_quoting)
)
else:
alias = node.alias
else:
name = ".".join(
part