diff --git a/superset/sqllab/api.py b/superset/sqllab/api.py index f81031a601c..0947bdf17d1 100644 --- a/superset/sqllab/api.py +++ b/superset/sqllab/api.py @@ -237,31 +237,36 @@ class SqlLabRestApi(BaseSupersetApi): sql = model["sql"] template_params = model.get("template_params") database_id = model.get("database_id") + database_engine = None - # Process Jinja templates if template_params and database_id are provided - if template_params and database_id is not None: + # Process Jinja templates if template_params are provided + if database_id is not None: database = DatabaseDAO.find_by_id(database_id) - if database: - try: - template_params = ( - json.loads(template_params) - if isinstance(template_params, str) - else template_params - ) - if template_params: - template_processor = get_template_processor( - database=database - ) - sql = template_processor.process_template( - sql, **template_params - ) - except json.JSONDecodeError: - logger.warning( - "Invalid template parameter %s. Skipping processing", - str(template_params), - ) - result = SQLScript(sql, model.get("engine")).format() + if database: + database_engine = database.db_engine_spec.engine + + if template_params: + try: + template_params = ( + json.loads(template_params) + if isinstance(template_params, str) + else template_params + ) + if template_params: + template_processor = get_template_processor( + database=database + ) + sql = template_processor.process_template( + sql, **template_params + ) + except json.JSONDecodeError: + logger.warning( + "Invalid template parameter %s. Skipping processing", + str(template_params), + ) + + result = SQLScript(sql, model.get("engine", database_engine)).format() return self.response(200, result=result) except ValidationError as error: return self.response_400(message=error.messages) diff --git a/tests/integration_tests/sql_lab/api_tests.py b/tests/integration_tests/sql_lab/api_tests.py index 9454e7c9b10..4e9f50f1310 100644 --- a/tests/integration_tests/sql_lab/api_tests.py +++ b/tests/integration_tests/sql_lab/api_tests.py @@ -39,6 +39,7 @@ from superset.utils.database import ( from superset.utils import core as utils, json from superset.models.sql_lab import Query +from superset.sql.parse import SQLScript from tests.integration_tests.base_tests import SupersetTestCase from tests.integration_tests.constants import ( ADMIN_USERNAME, @@ -288,10 +289,38 @@ class TestSqlLabApi(SupersetTestCase): self.assertDictEqual(resp_data, success_resp) # noqa: PT009 assert rv.status_code == 200 + def test_format_sql_request_with_db_id(self): + self.login(ADMIN_USERNAME) + example_db = get_example_database() + + # IIF is normalized differently per dialect: + # SQLite preserves IIF(), Postgres/base converts to CASE WHEN. + # Compute the expected result from the actual engine so the test is + # environment-independent. + sql = "select IIF(score > 0, 'positive', 'negative') from my_table" + engine = example_db.db_engine_spec.engine + expected = SQLScript(sql, engine).format() + + data = {"sql": sql, "database_id": example_db.id} + rv = self.client.post( + "/api/v1/sqllab/format_sql/", + json=data, + ) + resp_data = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 + assert resp_data["result"] == expected + def test_format_sql_request_with_jinja(self): self.login(ADMIN_USERNAME) example_db = get_example_database() + # Quoted identifier formatting varies by dialect (e.g., MySQL uses backticks). + # Compute the expected result from the actual engine so the test is + # environment-independent. + rendered_sql = 'select * from "Vehicle Sales"' + engine = example_db.db_engine_spec.engine + expected = SQLScript(rendered_sql, engine).format() + data = { "sql": "select * from {{tbl}}", "database_id": example_db.id, @@ -302,10 +331,10 @@ class TestSqlLabApi(SupersetTestCase): json=data, ) resp_data = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 # Verify that Jinja template was processed before formatting assert "{{tbl}}" not in resp_data["result"] - assert '"Vehicle Sales"' in resp_data["result"] - assert rv.status_code == 200 + assert resp_data["result"] == expected @mock.patch("superset.commands.sql_lab.results.results_backend_use_msgpack", False) def test_execute_required_params(self):