fix: User-provided Jinja template parameters causing SQL parsing errors (#34802)

This commit is contained in:
Michael S. Molina
2025-08-22 14:39:14 -03:00
committed by GitHub
parent 75af53dc3d
commit e1234b2264
10 changed files with 131 additions and 36 deletions

View File

@@ -24,7 +24,7 @@ import re
import urllib.parse
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Any, Generic, TYPE_CHECKING, TypeVar
from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar
import sqlglot
from jinja2 import nodes, Template
@@ -1380,6 +1380,18 @@ def is_cte(source: exp.Table, scope: Scope) -> bool:
T = TypeVar("T", str, None)
@dataclass
class JinjaSQLResult:
"""
Result of processing Jinja SQL.
Contains the processed SQL script and extracted table references.
"""
script: SQLScript
tables: set[Table]
def remove_quotes(val: T) -> T:
"""
Helper that removes surrounding quotes from strings.
@@ -1393,9 +1405,11 @@ def remove_quotes(val: T) -> T:
return val
def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
def process_jinja_sql(
sql: str, database: Database, template_params: Optional[dict[str, Any]] = None
) -> JinjaSQLResult:
"""
Extract all table references in the Jinjafied SQL statement.
Process Jinja-templated SQL and extract table references.
Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
statement may represent invalid SQL which is non-parsable by SQLGlot.
@@ -1407,7 +1421,8 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
:param sql: The Jinjafied SQL statement
:param database: The database associated with the SQL statement
:returns: The set of tables referenced in the SQL statement
:param template_params: Optional template parameters for Jinja templating
:returns: JinjaSQLResult containing the processed script and table references
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
:raises jinja2.exceptions.TemplateError: If the Jinjafied SQL could not be rendered
"""
@@ -1448,7 +1463,7 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
# re-render template back into a string
code = processor.env.compile(ast)
template = Template.from_code(processor.env, code, globals=processor.env.globals)
rendered_sql = template.render(processor.get_context())
rendered_sql = template.render(processor.get_context(), **(template_params or {}))
parsed_script = SQLScript(
processor.process_template(rendered_sql),
@@ -1457,7 +1472,7 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]:
for parsed_statement in parsed_script.statements:
tables |= parsed_statement.tables
return tables
return JinjaSQLResult(script=parsed_script, tables=tables)
def sanitize_clause(clause: str, engine: str) -> str: