From aa69ce43d9724b06f82a48496bc8772eb5f7fbeb Mon Sep 17 00:00:00 2001 From: "Michael S. Molina" <70410625+michael-s-molina@users.noreply.github.com> Date: Fri, 22 Aug 2025 14:39:14 -0300 Subject: [PATCH] fix: User-provided Jinja template parameters causing SQL parsing errors (#34802) (cherry picked from commit e1234b226498919b2b88944a5ad6b3d9cf77ca67) --- superset/commands/sql_lab/execute.py | 14 +++++---- superset/db_engine_specs/base.py | 2 ++ superset/db_engine_specs/postgres.py | 9 +++--- superset/models/core.py | 10 +++++-- superset/models/sql_lab.py | 6 ++-- superset/security/manager.py | 12 ++++++-- superset/sql_parse.py | 39 ++++++++++++++++++------- superset/sqllab/validators.py | 8 +++-- tests/unit_tests/models/sql_lab_test.py | 4 +-- tests/unit_tests/sql_parse_tests.py | 10 +++---- 10 files changed, 75 insertions(+), 39 deletions(-) diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py index 001d5609db4..a9aa2040485 100644 --- a/superset/commands/sql_lab/execute.py +++ b/superset/commands/sql_lab/execute.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.exc import SQLAlchemyError @@ -148,7 +148,7 @@ class ExecuteSqlCommand(BaseCommand): # Necessary to check access before rendering the Jinjafied query as the # some Jinja macros execute statements upon rendering. - self._validate_access(query) + self._validate_access(query, self._execution_context.template_params) self._execution_context.set_query(query) rendered_query = self._sql_query_render.render(self._execution_context) self._set_query_limit_if_required(rendered_query) @@ -204,9 +204,11 @@ class ExecuteSqlCommand(BaseCommand): db.session.commit() # pylint: disable=consider-using-transaction - def _validate_access(self, query: Query) -> None: + def _validate_access( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: try: - self._access_validator.validate(query) + self._access_validator.validate(query, template_params) except Exception as ex: raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex @@ -242,7 +244,9 @@ class ExecuteSqlCommand(BaseCommand): class CanAccessQueryValidator: - def validate(self, query: Query) -> None: + def validate( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: raise NotImplementedError() diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 128910b05d6..7cc187d55a0 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -29,6 +29,7 @@ from typing import ( cast, ContextManager, NamedTuple, + Optional, TYPE_CHECKING, TypedDict, Union, @@ -698,6 +699,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, query: Query, + template_params: Optional[dict[str, Any]] = None, ) -> str | None: """ Return the default schema for a given query. diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index a2e6a3fe129..66e03a28fed 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -21,7 +21,7 @@ import logging import re from datetime import datetime from re import Pattern -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON @@ -35,7 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.sql_lab import Query -from superset.sql.parse import SQLScript +from superset.sql_parse import process_jinja_sql from superset.utils import core as utils, json from superset.utils.core import GenericDataType @@ -280,6 +280,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): cls, database: Database, query: Query, + template_params: Optional[dict[str, Any]] = None, ) -> str | None: """ Return the default schema for a given query. @@ -287,7 +288,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): This method simply uses the parent method after checking that there are no malicious path setting in the query. """ - script = SQLScript(query.sql, engine=cls.engine) + script = process_jinja_sql(query.sql, database, template_params).script settings = script.get_settings() if "search_path" in settings: raise SupersetSecurityException( @@ -300,7 +301,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): ) ) - return super().get_default_schema_for_query(database, query) + return super().get_default_schema_for_query(database, query, template_params) @classmethod def adjust_engine_params( diff --git a/superset/models/core.py b/superset/models/core.py index 6264761143d..4db7b04d561 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -30,7 +30,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache from inspect import signature -from typing import Any, Callable, cast, TYPE_CHECKING +from typing import Any, Callable, cast, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -613,7 +613,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable """ return self.db_engine_spec.get_default_schema(self, catalog) - def get_default_schema_for_query(self, query: Query) -> str | None: + def get_default_schema_for_query( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> str | None: """ Return the default schema for a given query. @@ -627,7 +629,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable default schema is defined in the SQLAlchemy URI; and in others the default schema might be determined by the database itself (like `public` for Postgres). """ # noqa: E501 - return self.db_engine_spec.get_default_schema_for_query(self, query) + return self.db_engine_spec.get_default_schema_for_query( + self, query, template_params + ) @staticmethod def post_process_df(df: pd.DataFrame) -> pd.DataFrame: diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 7d803cb4c19..a5e46af6634 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -56,7 +56,7 @@ from superset.models.helpers import ( ExtraJSONMixin, ImportExportMixin, ) -from superset.sql_parse import CtasMethod, extract_tables_from_jinja_sql, Table +from superset.sql_parse import CtasMethod, process_jinja_sql, Table from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import json from superset.utils.core import ( @@ -80,10 +80,10 @@ class SqlTablesMixin: # pylint: disable=too-few-public-methods def sql_tables(self) -> list[Table]: try: return list( - extract_tables_from_jinja_sql( + process_jinja_sql( self.sql, # type: ignore self.database, # type: ignore - ) + ).tables ) except (SupersetSecurityException, SupersetParseError, TemplateError): return [] diff --git a/superset/security/manager.py b/superset/security/manager.py index 40a44161ab2..9fdc2e7ace5 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -66,7 +66,7 @@ from superset.security.guest_token import ( GuestTokenUser, GuestUser, ) -from superset.sql_parse import extract_tables_from_jinja_sql, Table +from superset.sql_parse import process_jinja_sql, Table from superset.tasks.utils import get_current_user from superset.utils import json from superset.utils.core import ( @@ -2162,6 +2162,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods sql: Optional[str] = None, catalog: Optional[str] = None, schema: Optional[str] = None, + template_params: Optional[dict[str, Any]] = None, ) -> None: """ Raise an exception if the user cannot access the resource. @@ -2175,6 +2176,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param sql: The SQL string (requires database) :param catalog: Optional catalog name :param schema: Optional schema name + :param template_params: Optional template parameters for Jinja templating :raises SupersetSecurityException: If the user cannot access the resource """ # pylint: disable=import-outside-toplevel @@ -2214,14 +2216,18 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # If the DB engine spec doesn't implement the logic the schema is read # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy # inspector to read it. - default_schema = database.get_default_schema_for_query(query) + default_schema = database.get_default_schema_for_query( + query, template_params + ) tables = { Table( table_.table, table_.schema or default_schema, table_.catalog or query.catalog or default_catalog, ) - for table_ in extract_tables_from_jinja_sql(query.sql, database) + for table_ in process_jinja_sql( + query.sql, database, template_params + ).tables } elif table: # Make sure table has the default catalog, if not specified. diff --git a/superset/sql_parse.py b/superset/sql_parse.py index 0e4aa637571..5df84d971f6 100644 --- a/superset/sql_parse.py +++ b/superset/sql_parse.py @@ -22,7 +22,8 @@ from __future__ import annotations import logging import re from collections.abc import Iterator -from typing import Any, cast, TYPE_CHECKING +from dataclasses import dataclass +from typing import Any, cast, Optional, TYPE_CHECKING import sqlparse from flask_babel import gettext as __ @@ -919,9 +920,23 @@ def extract_table_references( # noqa: C901 } -def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: +@dataclass +class JinjaSQLResult: """ - Extract all table references in the Jinjafied SQL statement. + Result of processing Jinja SQL. + + Contains the processed SQL script and extracted table references. + """ + + script: SQLScript + tables: set[Table] + + +def process_jinja_sql( + sql: str, database: Database, template_params: Optional[dict[str, Any]] = None +) -> JinjaSQLResult: + """ + Process Jinja-templated SQL and extract table references. Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL statement may represent invalid SQL which is non-parsable by SQLGlot. @@ -933,7 +948,8 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: :param sql: The Jinjafied SQL statement :param database: The database associated with the SQL statement - :returns: The set of tables referenced in the SQL statement + :param template_params: Optional template parameters for Jinja templating + :returns: JinjaSQLResult containing the processed script and table references :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement :raises jinja2.exceptions.TemplateError: If the Jinjafied SQL could not be rendered """ @@ -974,12 +990,13 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: # re-render template back into a string code = processor.env.compile(ast) template = Template.from_code(processor.env, code, globals=processor.env.globals) - rendered_sql = template.render(processor.get_context()) + rendered_sql = template.render(processor.get_context(), **(template_params or {})) - return ( - tables - | ParsedQuery( - sql_statement=processor.process_template(rendered_sql), - engine=database.db_engine_spec.engine, - ).tables + parsed_script = SQLScript( + processor.process_template(rendered_sql), + engine=database.db_engine_spec.engine, ) + for parsed_statement in parsed_script.statements: + tables |= parsed_statement.tables + + return JinjaSQLResult(script=parsed_script, tables=tables) diff --git a/superset/sqllab/validators.py b/superset/sqllab/validators.py index b79789da4cc..3627b40ba5b 100644 --- a/superset/sqllab/validators.py +++ b/superset/sqllab/validators.py @@ -17,7 +17,7 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from superset import security_manager from superset.commands.sql_lab.execute import CanAccessQueryValidator @@ -27,5 +27,7 @@ if TYPE_CHECKING: class CanAccessQueryValidatorImpl(CanAccessQueryValidator): - def validate(self, query: Query) -> None: - security_manager.raise_for_access(query=query) + def validate( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: + security_manager.raise_for_access(query=query, template_params=template_params) diff --git a/tests/unit_tests/models/sql_lab_test.py b/tests/unit_tests/models/sql_lab_test.py index e9fa7733701..c359e4f630b 100644 --- a/tests/unit_tests/models/sql_lab_test.py +++ b/tests/unit_tests/models/sql_lab_test.py @@ -56,7 +56,7 @@ def test_sql_tables_mixin_sql_tables_exception( mocker: MockerFixture, ) -> None: mocker.patch( - "superset.models.sql_lab.extract_tables_from_jinja_sql", + "superset.models.sql_lab.process_jinja_sql", side_effect=exception, ) @@ -87,7 +87,7 @@ def test_sql_tables_mixin_invalid_sql_returns_empty_list( ) -> None: """Test that SqlTablesMixin returns empty list when SQL parsing fails.""" mocker.patch( - "superset.models.sql_lab.extract_tables_from_jinja_sql", + "superset.models.sql_lab.process_jinja_sql", side_effect=SupersetParseError( sql=invalid_sql or "INVALID SQL", message=f"Failed to parse SQL: {invalid_sql}", diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index 23aa6b0b125..fef3a07ac50 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -35,12 +35,12 @@ from superset.sql_parse import ( add_table_name, check_sql_functions_exist, extract_table_references, - extract_tables_from_jinja_sql, get_rls_for_table, has_table_query, insert_rls_as_subquery, insert_rls_in_predicate, ParsedQuery, + process_jinja_sql, sanitize_clause, strip_comments_from_sql, ) @@ -1925,10 +1925,10 @@ def test_extract_tables_from_jinja_sql( expected: set[Table], ) -> None: assert ( - extract_tables_from_jinja_sql( + process_jinja_sql( sql=f"'{{{{ {engine}.{macro} }}}}'", database=mocker.Mock(), - ) + ).tables == expected ) @@ -1945,7 +1945,7 @@ def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None: database = mocker.Mock() database.db_engine_spec.engine = "mssql" - assert extract_tables_from_jinja_sql( + assert process_jinja_sql( sql="SELECT 1 FROM t", database=database, - ) == {Table("t")} + ).tables == {Table("t")}