mirror of
https://github.com/apache/superset.git
synced 2026-05-12 03:15:55 +00:00
Snowflake fixes
This commit is contained in:
@@ -295,16 +295,33 @@ class RLSTransformer:
|
||||
) -> None:
|
||||
self.catalog = catalog
|
||||
self.schema = schema
|
||||
self.rules = rules
|
||||
# Normalize table keys to lowercase for case-insensitive matching
|
||||
# This is needed because apply_cls calls qualify() which may change
|
||||
# identifier case (e.g., Snowflake uppercases identifiers)
|
||||
self.rules = {
|
||||
self._normalize_table(table): predicates
|
||||
for table, predicates in rules.items()
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _normalize_table(table: Table) -> Table:
|
||||
"""Normalize table to lowercase for case-insensitive matching."""
|
||||
return Table(
|
||||
table=table.table.lower() if table.table else table.table,
|
||||
schema=table.schema.lower() if table.schema else table.schema,
|
||||
catalog=table.catalog.lower() if table.catalog else table.catalog,
|
||||
)
|
||||
|
||||
def get_predicate(self, table_node: exp.Table) -> exp.Expression | None:
|
||||
"""
|
||||
Get the combined RLS predicate for a table.
|
||||
"""
|
||||
table = Table(
|
||||
table_node.name,
|
||||
table_node.db if table_node.db else self.schema,
|
||||
table_node.catalog if table_node.catalog else self.catalog,
|
||||
table = self._normalize_table(
|
||||
Table(
|
||||
table_node.name,
|
||||
table_node.db if table_node.db else self.schema,
|
||||
table_node.catalog if table_node.catalog else self.catalog,
|
||||
)
|
||||
)
|
||||
if predicates := self.rules.get(table):
|
||||
return sqlglot.and_(*predicates)
|
||||
@@ -396,7 +413,19 @@ class RLSAsSubqueryTransformer(RLSTransformer):
|
||||
|
||||
if predicate := self.get_predicate(node):
|
||||
if node.alias:
|
||||
alias = node.alias
|
||||
# After qualify(), alias might be a string instead of TableAlias
|
||||
if isinstance(node.alias, str):
|
||||
# Check if the original alias was quoted by looking at the SQL
|
||||
# If the alias contains special chars or is a reserved word, quote it
|
||||
needs_quoting = (
|
||||
not node.alias.isidentifier()
|
||||
or "." in node.alias
|
||||
)
|
||||
alias = exp.TableAlias(
|
||||
this=exp.Identifier(this=node.alias, quoted=needs_quoting)
|
||||
)
|
||||
else:
|
||||
alias = node.alias
|
||||
else:
|
||||
name = ".".join(
|
||||
part
|
||||
|
||||
Reference in New Issue
Block a user