mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
fix(sql_parse): Ensure table extraction handles Jinja templating (#27470)
This commit is contained in:
committed by
Michael S. Molina
parent
4ff331a66c
commit
7c14968e6d
@@ -16,16 +16,19 @@
|
||||
# under the License.
|
||||
|
||||
# pylint: disable=too-many-lines
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
import urllib.parse
|
||||
from collections.abc import Iterable, Iterator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, cast, Optional
|
||||
from typing import Any, cast
|
||||
from unittest.mock import Mock
|
||||
|
||||
import sqlparse
|
||||
from flask_babel import gettext as __
|
||||
from jinja2 import nodes
|
||||
from sqlalchemy import and_
|
||||
from sqlglot import exp, parse, parse_one
|
||||
from sqlglot.dialects import Dialects
|
||||
@@ -142,7 +145,7 @@ class CtasMethod(StrEnum):
|
||||
VIEW = "VIEW"
|
||||
|
||||
|
||||
def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
||||
def _extract_limit_from_query(statement: TokenList) -> int | None:
|
||||
"""
|
||||
Extract limit clause from SQL statement.
|
||||
|
||||
@@ -163,9 +166,7 @@ def _extract_limit_from_query(statement: TokenList) -> Optional[int]:
|
||||
return None
|
||||
|
||||
|
||||
def extract_top_from_query(
|
||||
statement: TokenList, top_keywords: set[str]
|
||||
) -> Optional[int]:
|
||||
def extract_top_from_query(statement: TokenList, top_keywords: set[str]) -> int | None:
|
||||
"""
|
||||
Extract top clause value from SQL statement.
|
||||
|
||||
@@ -189,7 +190,7 @@ def extract_top_from_query(
|
||||
return top
|
||||
|
||||
|
||||
def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||
def get_cte_remainder_query(sql: str) -> tuple[str | None, str]:
|
||||
"""
|
||||
parse the SQL and return the CTE and rest of the block to the caller
|
||||
|
||||
@@ -197,7 +198,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||
:return: CTE and remainder block to the caller
|
||||
|
||||
"""
|
||||
cte: Optional[str] = None
|
||||
cte: str | None = None
|
||||
remainder = sql
|
||||
stmt = sqlparse.parse(sql)[0]
|
||||
|
||||
@@ -215,7 +216,7 @@ def get_cte_remainder_query(sql: str) -> tuple[Optional[str], str]:
|
||||
return cte, remainder
|
||||
|
||||
|
||||
def strip_comments_from_sql(statement: str, engine: Optional[str] = None) -> str:
|
||||
def strip_comments_from_sql(statement: str, engine: str | None = None) -> str:
|
||||
"""
|
||||
Strips comments from a SQL statement, does a simple test first
|
||||
to avoid always instantiating the expensive ParsedQuery constructor
|
||||
@@ -239,8 +240,8 @@ class Table:
|
||||
"""
|
||||
|
||||
table: str
|
||||
schema: Optional[str] = None
|
||||
catalog: Optional[str] = None
|
||||
schema: str | None = None
|
||||
catalog: str | None = None
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""
|
||||
@@ -262,7 +263,7 @@ class ParsedQuery:
|
||||
self,
|
||||
sql_statement: str,
|
||||
strip_comments: bool = False,
|
||||
engine: Optional[str] = None,
|
||||
engine: str | None = None,
|
||||
):
|
||||
if strip_comments:
|
||||
sql_statement = sqlparse.format(sql_statement, strip_comments=True)
|
||||
@@ -271,7 +272,7 @@ class ParsedQuery:
|
||||
self._dialect = SQLGLOT_DIALECTS.get(engine) if engine else None
|
||||
self._tables: set[Table] = set()
|
||||
self._alias_names: set[str] = set()
|
||||
self._limit: Optional[int] = None
|
||||
self._limit: int | None = None
|
||||
|
||||
logger.debug("Parsing with sqlparse statement: %s", self.sql)
|
||||
self._parsed = sqlparse.parse(self.stripped())
|
||||
@@ -382,7 +383,7 @@ class ParsedQuery:
|
||||
return source.name in ctes_in_scope
|
||||
|
||||
@property
|
||||
def limit(self) -> Optional[int]:
|
||||
def limit(self) -> int | None:
|
||||
return self._limit
|
||||
|
||||
def _get_cte_tables(self, parsed: dict[str, Any]) -> list[dict[str, Any]]:
|
||||
@@ -463,7 +464,7 @@ class ParsedQuery:
|
||||
|
||||
return True
|
||||
|
||||
def get_inner_cte_expression(self, tokens: TokenList) -> Optional[TokenList]:
|
||||
def get_inner_cte_expression(self, tokens: TokenList) -> TokenList | None:
|
||||
for token in tokens:
|
||||
if self._is_identifier(token):
|
||||
for identifier_token in token.tokens:
|
||||
@@ -527,7 +528,7 @@ class ParsedQuery:
|
||||
return statements
|
||||
|
||||
@staticmethod
|
||||
def get_table(tlist: TokenList) -> Optional[Table]:
|
||||
def get_table(tlist: TokenList) -> Table | None:
|
||||
"""
|
||||
Return the table if valid, i.e., conforms to the [[catalog.]schema.]table
|
||||
construct.
|
||||
@@ -563,7 +564,7 @@ class ParsedQuery:
|
||||
def as_create_table(
|
||||
self,
|
||||
table_name: str,
|
||||
schema_name: Optional[str] = None,
|
||||
schema_name: str | None = None,
|
||||
overwrite: bool = False,
|
||||
method: CtasMethod = CtasMethod.TABLE,
|
||||
) -> str:
|
||||
@@ -723,8 +724,8 @@ def add_table_name(rls: TokenList, table: str) -> None:
|
||||
def get_rls_for_table(
|
||||
candidate: Token,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
) -> Optional[TokenList]:
|
||||
default_schema: str | None,
|
||||
) -> TokenList | None:
|
||||
"""
|
||||
Given a table name, return any associated RLS predicates.
|
||||
"""
|
||||
@@ -770,7 +771,7 @@ def get_rls_for_table(
|
||||
def insert_rls_as_subquery(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
default_schema: str | None,
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
@@ -786,7 +787,7 @@ def insert_rls_as_subquery(
|
||||
This method is safer than ``insert_rls_in_predicate``, but doesn't work in all
|
||||
databases.
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
rls: TokenList | None = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
@@ -862,7 +863,7 @@ def insert_rls_as_subquery(
|
||||
def insert_rls_in_predicate(
|
||||
token_list: TokenList,
|
||||
database_id: int,
|
||||
default_schema: Optional[str],
|
||||
default_schema: str | None,
|
||||
) -> TokenList:
|
||||
"""
|
||||
Update a statement inplace applying any associated RLS predicates.
|
||||
@@ -873,7 +874,7 @@ def insert_rls_in_predicate(
|
||||
after: SELECT * FROM some_table WHERE ( 1=1) AND some_table.id=42
|
||||
|
||||
"""
|
||||
rls: Optional[TokenList] = None
|
||||
rls: TokenList | None = None
|
||||
state = InsertRLSState.SCANNING
|
||||
for token in token_list.tokens:
|
||||
# Recurse into child token list
|
||||
@@ -1007,7 +1008,7 @@ RE_JINJA_BLOCK = re.compile(r"\{[%#][^\{\}%#]+[%#]\}")
|
||||
|
||||
def extract_table_references(
|
||||
sql_text: str, sqla_dialect: str, show_warning: bool = True
|
||||
) -> set["Table"]:
|
||||
) -> set[Table]:
|
||||
"""
|
||||
Return all the dependencies from a SQL sql_text.
|
||||
"""
|
||||
@@ -1051,3 +1052,61 @@ def extract_table_references(
|
||||
Table(*[part["value"] for part in table["name"][::-1]])
|
||||
for table in find_nodes_by_key(tree, "Table")
|
||||
}
|
||||
|
||||
|
||||
def extract_tables_from_jinja_sql(sql: str, engine: str | None = None) -> set[Table]:
|
||||
"""
|
||||
Extract all table references in the Jinjafied SQL statement.
|
||||
|
||||
Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL
|
||||
statement may represent invalid SQL which is non-parsable by SQLGlot.
|
||||
|
||||
Firstly, we extract any tables referenced within the confines of specific Jinja
|
||||
macros. Secondly, we replace these non-SQL Jinja calls with a pseudo-benign SQL
|
||||
expression to help ensure that the resulting SQL statements are parsable by
|
||||
SQLGlot.
|
||||
|
||||
:param sql: The Jinjafied SQL statement
|
||||
:param engine: The associated database engine
|
||||
:returns: The set of tables referenced in the SQL statement
|
||||
:raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement
|
||||
"""
|
||||
|
||||
from superset.jinja_context import ( # pylint: disable=import-outside-toplevel
|
||||
get_template_processor,
|
||||
)
|
||||
|
||||
# Mock the required database as the processor signature is exposed publically.
|
||||
processor = get_template_processor(database=Mock(backend=engine))
|
||||
template = processor.env.parse(sql)
|
||||
|
||||
tables = set()
|
||||
|
||||
for node in template.find_all(nodes.Call):
|
||||
if isinstance(node.node, nodes.Getattr) and node.node.attr in (
|
||||
"latest_partition",
|
||||
"latest_sub_partition",
|
||||
):
|
||||
# Extract the table referenced in the macro.
|
||||
tables.add(
|
||||
Table(
|
||||
*[
|
||||
remove_quotes(part)
|
||||
for part in node.args[0].value.split(".")[::-1]
|
||||
if len(node.args) == 1
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
# Replace the potentially problematic Jinja macro with some benign SQL.
|
||||
node.__class__ = nodes.TemplateData
|
||||
node.fields = nodes.TemplateData.fields
|
||||
node.data = "NULL"
|
||||
|
||||
return (
|
||||
tables
|
||||
| ParsedQuery(
|
||||
sql_statement=processor.process_template(template),
|
||||
engine=engine,
|
||||
).tables
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user