diff --git a/superset/common/query_object.py b/superset/common/query_object.py index 2e1e00526a3..e56d795ebb6 100644 --- a/superset/common/query_object.py +++ b/superset/common/query_object.py @@ -360,11 +360,18 @@ class QueryObject: # pylint: disable=too-many-instance-attributes engine = database.db_engine_spec.engine if needs_transpilation: - clause = transpile_to_dialect(clause, engine) + # source_engine=engine ensures idempotency: this + # method can run more than once (validate() is called + # from both raise_for_access and get_df_payload), so + # the second pass must be able to re-parse the + # dialect-specific output (e.g. BigQuery backticks) + # produced by the first pass. + clause = transpile_to_dialect( + clause, engine, source_engine=engine, identify=True + ) sanitized_clause = sanitize_clause(clause, engine) - if sanitized_clause != clause: - self.extras[param] = sanitized_clause + self.extras[param] = sanitized_clause except QueryClauseValidationException as ex: raise QueryObjectValidationError(ex.message) from ex diff --git a/superset/sql/parse.py b/superset/sql/parse.py index 6f07c15c164..bb3ef5e1c4b 100644 --- a/superset/sql/parse.py +++ b/superset/sql/parse.py @@ -1568,6 +1568,7 @@ def transpile_to_dialect( sql: str, target_engine: str, source_engine: str | None = None, + identify: bool = False, ) -> str: """ Transpile SQL from one database dialect to another using SQLGlot. @@ -1576,6 +1577,7 @@ def transpile_to_dialect( sql: The SQL query to transpile target_engine: The target database engine (e.g., "mysql", "postgresql") source_engine: The source database engine. If None, uses generic SQL dialect. + identify: If True, quote all identifiers per the target dialect. Returns: The transpiled SQL string @@ -1598,6 +1600,7 @@ def transpile_to_dialect( copy=True, comments=False, pretty=False, + identify=identify, ) except ParseError as ex: raise QueryClauseValidationException(f"Cannot parse SQL clause: {sql}") from ex diff --git a/tests/unit_tests/sql/transpile_to_dialect_test.py b/tests/unit_tests/sql/transpile_to_dialect_test.py index 5a11e501fad..5c6c35df8b3 100644 --- a/tests/unit_tests/sql/transpile_to_dialect_test.py +++ b/tests/unit_tests/sql/transpile_to_dialect_test.py @@ -396,3 +396,127 @@ def test_transpile_unknown_source_engine_uses_generic() -> None: "SELECT * FROM orders", "postgresql", "unknown_engine" ) assert result == "SELECT * FROM orders" + + +# Tests for identify=True (identifier quoting) +@pytest.mark.parametrize( + "sql,dialect,expected", + [ + # PostgreSQL - double-quoted identifiers + ( + "STATE ILIKE '%AL%'", + "postgresql", + "\"STATE\" ILIKE '%AL%'", + ), + # MySQL - backtick-quoted identifiers, ILIKE transpiled + ( + "STATE ILIKE '%AL%'", + "mysql", + "LOWER(`STATE`) LIKE LOWER('%AL%')", + ), + # BigQuery - backtick-quoted identifiers, ILIKE transpiled + ( + "STATE ILIKE '%AL%'", + "bigquery", + "LOWER(`STATE`) LIKE LOWER('%AL%')", + ), + # Snowflake - double-quoted identifiers + ( + "STATE ILIKE '%AL%'", + "snowflake", + "\"STATE\" ILIKE '%AL%'", + ), + # MSSQL - bracket-quoted identifiers + ( + "STATE = 'CA'", + "mssql", + "[STATE] = 'CA'", + ), + # Compound filter with multiple identifiers + ( + "STATE = 'CA' AND AIRLINE = 'Delta'", + "postgresql", + "\"STATE\" = 'CA' AND \"AIRLINE\" = 'Delta'", + ), + # Lowercase identifiers also get quoted + ( + "name = 'test'", + "postgresql", + "\"name\" = 'test'", + ), + ], +) +def test_identify_quotes_identifiers(sql: str, dialect: str, expected: str) -> None: + """Test that identify=True quotes identifiers per target dialect.""" + assert transpile_to_dialect(sql, dialect, identify=True) == expected + + +def test_identify_unknown_engine_returns_unchanged() -> None: + """Test that identify=True has no effect on unknown engines.""" + sql = "STATE = 'CA'" + assert transpile_to_dialect(sql, "unknown_engine", identify=True) == sql + + +@pytest.mark.parametrize( + "sql,engine,expected", + [ + ( + "STATE ILIKE '%AL%'", + "postgresql", + "\"STATE\" ILIKE '%AL%'", + ), + ( + "country ILIKE '%Italy%'", + "bigquery", + "LOWER(`country`) LIKE LOWER('%Italy%')", + ), + ], +) +def test_identify_with_source_engine(sql: str, engine: str, expected: str) -> None: + """Test identify=True with source_engine matching target engine.""" + result = transpile_to_dialect(sql, engine, source_engine=engine, identify=True) + assert result == expected + + +@pytest.mark.parametrize( + "engine", + ["postgresql", "bigquery", "mysql", "snowflake"], +) +def test_identify_transpilation_is_idempotent(engine: str) -> None: + """Test that transpiling twice produces the same result (idempotent). + + This matters because _sanitize_filters() can be called multiple times + via validate(). + """ + clause = "STATE ILIKE '%AL%'" + pass1 = transpile_to_dialect(clause, engine, source_engine=engine, identify=True) + pass2 = transpile_to_dialect(pass1, engine, source_engine=engine, identify=True) + assert pass1 == pass2 + + +def test_sanitize_filters_writes_back_transpiled_clause() -> None: + """Test that _sanitize_filters always persists the transpiled clause. + + Regression test: a previous conditional `if sanitized_clause != clause` + skipped the write-back when transpile_to_dialect had already modified + the clause, leaving the original unquoted value in extras. + """ + from unittest.mock import MagicMock + + from superset.common.query_object import QueryObject + + mock_datasource = MagicMock() + mock_datasource.database.db_engine_spec.engine = "postgresql" + + query_obj = QueryObject( + datasource=mock_datasource, + columns=["STATE"], + metrics=[], + extras={ + "where": "STATE = 'CA'", + "transpile_to_dialect": True, + }, + ) + query_obj.validate() + + assert '"STATE"' in query_obj.extras["where"]