diff --git a/superset/commands/sql_lab/execute.py b/superset/commands/sql_lab/execute.py index 20cba40cfc0..a3d29fe330f 100644 --- a/superset/commands/sql_lab/execute.py +++ b/superset/commands/sql_lab/execute.py @@ -18,7 +18,7 @@ from __future__ import annotations import logging -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.exc import SQLAlchemyError @@ -148,7 +148,7 @@ class ExecuteSqlCommand(BaseCommand): # Necessary to check access before rendering the Jinjafied query as the # some Jinja macros execute statements upon rendering. - self._validate_access(query) + self._validate_access(query, self._execution_context.template_params) self._execution_context.set_query(query) rendered_query = self._sql_query_render.render(self._execution_context) self._set_query_limit_if_required(rendered_query) @@ -204,9 +204,11 @@ class ExecuteSqlCommand(BaseCommand): db.session.commit() # pylint: disable=consider-using-transaction - def _validate_access(self, query: Query) -> None: + def _validate_access( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: try: - self._access_validator.validate(query) + self._access_validator.validate(query, template_params) except Exception as ex: raise QueryIsForbiddenToAccessException(self._execution_context, ex) from ex @@ -242,7 +244,9 @@ class ExecuteSqlCommand(BaseCommand): class CanAccessQueryValidator: - def validate(self, query: Query) -> None: + def validate( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: raise NotImplementedError() diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 07e787df119..4eb30f40e1b 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -30,6 +30,7 @@ from typing import ( cast, ContextManager, NamedTuple, + Optional, TYPE_CHECKING, TypedDict, Union, @@ -707,6 +708,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods cls, database: Database, query: Query, + template_params: Optional[dict[str, Any]] = None, ) -> str | None: """ Return the default schema for a given query. diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 3e95ecc7901..b259164e4fb 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -21,7 +21,7 @@ import logging import re from datetime import datetime from re import Pattern -from typing import Any, TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from flask_babel import gettext as __ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON @@ -35,7 +35,7 @@ from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException, SupersetSecurityException from superset.models.sql_lab import Query -from superset.sql.parse import SQLScript +from superset.sql.parse import process_jinja_sql from superset.utils import core as utils, json from superset.utils.core import GenericDataType, QuerySource @@ -281,6 +281,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): cls, database: Database, query: Query, + template_params: Optional[dict[str, Any]] = None, ) -> str | None: """ Return the default schema for a given query. @@ -288,7 +289,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): This method simply uses the parent method after checking that there are no malicious path setting in the query. """ - script = SQLScript(query.sql, engine=cls.engine) + script = process_jinja_sql(query.sql, database, template_params).script settings = script.get_settings() if "search_path" in settings: raise SupersetSecurityException( @@ -301,7 +302,7 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): ) ) - return super().get_default_schema_for_query(database, query) + return super().get_default_schema_for_query(database, query, template_params) @classmethod def adjust_engine_params( diff --git a/superset/models/core.py b/superset/models/core.py index 603b0799d90..68726d9da0c 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -30,7 +30,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache from inspect import signature -from typing import Any, Callable, cast, TYPE_CHECKING +from typing import Any, Callable, cast, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -607,7 +607,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable """ return self.db_engine_spec.get_default_schema(self, catalog) - def get_default_schema_for_query(self, query: Query) -> str | None: + def get_default_schema_for_query( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> str | None: """ Return the default schema for a given query. @@ -621,7 +623,9 @@ class Database(Model, AuditMixinNullable, ImportExportMixin): # pylint: disable default schema is defined in the SQLAlchemy URI; and in others the default schema might be determined by the database itself (like `public` for Postgres). """ # noqa: E501 - return self.db_engine_spec.get_default_schema_for_query(self, query) + return self.db_engine_spec.get_default_schema_for_query( + self, query, template_params + ) @staticmethod def post_process_df(df: pd.DataFrame) -> pd.DataFrame: diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index d1603af44b5..53224f24721 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -56,7 +56,11 @@ from superset.models.helpers import ( ExtraJSONMixin, ImportExportMixin, ) -from superset.sql.parse import CTASMethod, extract_tables_from_jinja_sql, Table +from superset.sql.parse import ( + CTASMethod, + process_jinja_sql, + Table, +) from superset.sqllab.limiting_factor import LimitingFactor from superset.utils import json from superset.utils.core import ( @@ -80,10 +84,10 @@ class SqlTablesMixin: # pylint: disable=too-few-public-methods def sql_tables(self) -> list[Table]: try: return list( - extract_tables_from_jinja_sql( + process_jinja_sql( self.sql, # type: ignore self.database, # type: ignore - ) + ).tables ) except (SupersetSecurityException, SupersetParseError, TemplateError): return [] diff --git a/superset/security/manager.py b/superset/security/manager.py index 339bbbe2ab3..7ac989e0b4c 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -68,7 +68,7 @@ from superset.security.guest_token import ( GuestTokenUser, GuestUser, ) -from superset.sql.parse import extract_tables_from_jinja_sql, Table +from superset.sql.parse import process_jinja_sql, Table from superset.tasks.utils import get_current_user from superset.utils import json from superset.utils.core import ( @@ -2287,6 +2287,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods sql: Optional[str] = None, catalog: Optional[str] = None, schema: Optional[str] = None, + template_params: Optional[dict[str, Any]] = None, ) -> None: """ Raise an exception if the user cannot access the resource. @@ -2300,6 +2301,7 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods :param sql: The SQL string (requires database) :param catalog: Optional catalog name :param schema: Optional schema name + :param template_params: Optional template parameters for Jinja templating :raises SupersetSecurityException: If the user cannot access the resource """ # pylint: disable=import-outside-toplevel @@ -2339,14 +2341,18 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods # If the DB engine spec doesn't implement the logic the schema is read # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy # inspector to read it. - default_schema = database.get_default_schema_for_query(query) + default_schema = database.get_default_schema_for_query( + query, template_params + ) tables = { Table( table_.table, table_.schema or default_schema, table_.catalog or query.catalog or default_catalog, ) - for table_ in extract_tables_from_jinja_sql(query.sql, database) + for table_ in process_jinja_sql( + query.sql, database, template_params + ).tables } elif table: # Make sure table has the default catalog, if not specified. diff --git a/superset/sql/parse.py b/superset/sql/parse.py index bcc6b32382b..7c430677efe 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -24,7 +24,7 @@ import re import urllib.parse from collections.abc import Iterable from dataclasses import dataclass -from typing import Any, Generic, TYPE_CHECKING, TypeVar +from typing import Any, Generic, Optional, TYPE_CHECKING, TypeVar import sqlglot from jinja2 import nodes, Template @@ -1380,6 +1380,18 @@ def is_cte(source: exp.Table, scope: Scope) -> bool: T = TypeVar("T", str, None) +@dataclass +class JinjaSQLResult: + """ + Result of processing Jinja SQL. + + Contains the processed SQL script and extracted table references. + """ + + script: SQLScript + tables: set[Table] + + def remove_quotes(val: T) -> T: """ Helper that removes surrounding quotes from strings. @@ -1393,9 +1405,11 @@ def remove_quotes(val: T) -> T: return val -def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: +def process_jinja_sql( + sql: str, database: Database, template_params: Optional[dict[str, Any]] = None +) -> JinjaSQLResult: """ - Extract all table references in the Jinjafied SQL statement. + Process Jinja-templated SQL and extract table references. Due to Jinja templating, a multiphase approach is necessary as the Jinjafied SQL statement may represent invalid SQL which is non-parsable by SQLGlot. @@ -1407,7 +1421,8 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: :param sql: The Jinjafied SQL statement :param database: The database associated with the SQL statement - :returns: The set of tables referenced in the SQL statement + :param template_params: Optional template parameters for Jinja templating + :returns: JinjaSQLResult containing the processed script and table references :raises SupersetSecurityException: If SQLGlot is unable to parse the SQL statement :raises jinja2.exceptions.TemplateError: If the Jinjafied SQL could not be rendered """ @@ -1448,7 +1463,7 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: # re-render template back into a string code = processor.env.compile(ast) template = Template.from_code(processor.env, code, globals=processor.env.globals) - rendered_sql = template.render(processor.get_context()) + rendered_sql = template.render(processor.get_context(), **(template_params or {})) parsed_script = SQLScript( processor.process_template(rendered_sql), @@ -1457,7 +1472,7 @@ def extract_tables_from_jinja_sql(sql: str, database: Database) -> set[Table]: for parsed_statement in parsed_script.statements: tables |= parsed_statement.tables - return tables + return JinjaSQLResult(script=parsed_script, tables=tables) def sanitize_clause(clause: str, engine: str) -> str: diff --git a/superset/sqllab/validators.py b/superset/sqllab/validators.py index b79789da4cc..3627b40ba5b 100644 --- a/superset/sqllab/validators.py +++ b/superset/sqllab/validators.py @@ -17,7 +17,7 @@ # pylint: disable=too-few-public-methods from __future__ import annotations -from typing import TYPE_CHECKING +from typing import Any, Optional, TYPE_CHECKING from superset import security_manager from superset.commands.sql_lab.execute import CanAccessQueryValidator @@ -27,5 +27,7 @@ if TYPE_CHECKING: class CanAccessQueryValidatorImpl(CanAccessQueryValidator): - def validate(self, query: Query) -> None: - security_manager.raise_for_access(query=query) + def validate( + self, query: Query, template_params: Optional[dict[str, Any]] = None + ) -> None: + security_manager.raise_for_access(query=query, template_params=template_params) diff --git a/tests/unit_tests/models/sql_lab_test.py b/tests/unit_tests/models/sql_lab_test.py index e9fa7733701..c359e4f630b 100644 --- a/tests/unit_tests/models/sql_lab_test.py +++ b/tests/unit_tests/models/sql_lab_test.py @@ -56,7 +56,7 @@ def test_sql_tables_mixin_sql_tables_exception( mocker: MockerFixture, ) -> None: mocker.patch( - "superset.models.sql_lab.extract_tables_from_jinja_sql", + "superset.models.sql_lab.process_jinja_sql", side_effect=exception, ) @@ -87,7 +87,7 @@ def test_sql_tables_mixin_invalid_sql_returns_empty_list( ) -> None: """Test that SqlTablesMixin returns empty list when SQL parsing fails.""" mocker.patch( - "superset.models.sql_lab.extract_tables_from_jinja_sql", + "superset.models.sql_lab.process_jinja_sql", side_effect=SupersetParseError( sql=invalid_sql or "INVALID SQL", message=f"Failed to parse SQL: {invalid_sql}", diff --git a/tests/unit_tests/sql/parse_tests.py b/tests/unit_tests/sql/parse_tests.py index 6e4db3c8880..f87f5613c30 100644 --- a/tests/unit_tests/sql/parse_tests.py +++ b/tests/unit_tests/sql/parse_tests.py @@ -25,11 +25,12 @@ from superset.exceptions import QueryClauseValidationException, SupersetParseErr from superset.jinja_context import JinjaTemplateProcessor from superset.sql.parse import ( CTASMethod, - extract_tables_from_jinja_sql, extract_tables_from_statement, + JinjaSQLResult, KQLTokenType, KustoKQLStatement, LimitMethod, + process_jinja_sql, remove_quotes, RLSMethod, sanitize_clause, @@ -2661,10 +2662,10 @@ def test_extract_tables_from_jinja_sql( expected: set[Table], ) -> None: assert ( - extract_tables_from_jinja_sql( + process_jinja_sql( sql=f"'{{{{ {engine}.{macro} }}}}'", database=mocker.MagicMock(backend=engine), - ) + ).tables == expected ) @@ -2677,10 +2678,10 @@ def test_extract_tables_from_jinja_sql_disabled(mocker: MockerFixture) -> None: database = mocker.MagicMock() database.db_engine_spec.engine = "mssql" - assert extract_tables_from_jinja_sql( + assert process_jinja_sql( sql="SELECT 1 FROM t", database=database, - ) == {Table("t")} + ).tables == {Table("t")} def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) -> None: @@ -2696,10 +2697,66 @@ def test_extract_tables_from_jinja_sql_invalid_function(mocker: MockerFixture) - return_value=processor, ) - assert extract_tables_from_jinja_sql( + assert process_jinja_sql( sql="SELECT * FROM {{ my_table() }}", database=database, - ) == {Table("t")} + ).tables == {Table("t")} + + +def test_process_jinja_sql_result_object_structure(mocker: MockerFixture) -> None: + """ + Test that process_jinja_sql returns a proper JinjaSQLResult object + with correct script and tables properties. + """ + database = mocker.MagicMock() + database.db_engine_spec.engine = "postgresql" + + result = process_jinja_sql( + sql="SELECT id FROM users WHERE active = true", + database=database, + ) + + # Test that result is the correct type + assert isinstance(result, JinjaSQLResult) + + # Test that script property returns a SQLScript + assert hasattr(result, "script") + assert isinstance(result.script, SQLScript) + + # Test that tables property returns a set of Tables + assert hasattr(result, "tables") + assert isinstance(result.tables, set) + assert result.tables == {Table("users")} + + # Test that the script contains the expected SQL + formatted_sql = result.script.format() + assert "users" in formatted_sql + assert "active = TRUE" in formatted_sql + + +def test_process_jinja_sql_template_params_parameter(mocker: MockerFixture) -> None: + """ + Test that the template_params parameter is properly handled. + """ + database = mocker.MagicMock() + database.db_engine_spec.engine = "postgresql" + + processor = JinjaTemplateProcessor(database) + mocker.patch( + "superset.jinja_context.get_template_processor", + return_value=processor, + ) + + # Test that template_params parameter is accepted and passed through + result = process_jinja_sql( + sql="SELECT * FROM table_name", + database=database, + template_params={"param1": "value1"}, + ) + + # Verify the function accepts the parameter without error + assert isinstance(result, JinjaSQLResult) + assert result.tables == {Table("table_name")} @pytest.mark.parametrize(