fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)

This commit is contained in:
John Bodley
2024-03-22 13:39:28 +13:00
committed by GitHub
parent a8c01f4cad
commit b25dd0c055
7 changed files with 141 additions and 30 deletions

View File

@@ -25,10 +25,12 @@ import urllib.parse
from collections.abc import Iterable, Iterator
from dataclasses import dataclass
from typing import Any, cast
from unittest.mock import Mock
import sqlglot
import sqlparse
from flask_babel import gettext as __
from jinja2 import nodes
from sqlalchemy import and_
from sqlglot import exp, parse, parse_one
from sqlglot.dialects.dialect import Dialect, Dialects
@@ -1232,3 +1234,61 @@ def extract_table_references(
Table(*[part["value"] for part in table["name"][::-1]])
for table in find_nodes_by_key(tree, "Table")
}
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
"""
Extract all table references in the Jinjafied SQL statement.
Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
statement may represent invalid SQL which is non-parsable by SQLGlot.
Firstly, we extract any tables referenced within the confines of specific Jinja
macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
expression to help ensure that the resulting SQL statements are parsable by
SQLGlot.
:param sql: The Jinjafied SQL statement
:param engine: The associated database engine
:returns: The set of tables referenced in the SQL statement
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
"""
from superset.jinja_context import ( # pylint: disable=import-outside-toplevel
get_template_processor,
)
# Mock the required database as the processor signature is exposed publically.
processor = get_template_processor(database=Mock(backend=engine))
template = processor.env.parse(sql)
tables = set()
for node in template.find_all(nodes.Call):
if isinstance(node.node, nodes.Getattr) and node.node.attr in (
"latest_partition",
"latest_sub_partition",
):
# Extract the table referenced in the macro.
tables.add(
Table(
*[
remove_quotes(part)
for part in node.args[0].value.split(".")[::-1]
if len(node.args) == 1
]
)
)
# Replace the potentially problematic Jinja macro with some benign SQL.
node.__class__ = nodes.TemplateData
node.fields = nodes.TemplateData.fields
node.data = "NULL"
return (
tables
| ParsedQuery(
sql_statement=processor.process_template(template),
engine=engine,
).tables
)