From 2e7a2b1f2dc45b42e41cd95624741063e2109600 Mon Sep 17 00:00:00 2001 From: Shaitan <105581038+sha174n@users.noreply.github.com> Date: Fri, 15 May 2026 14:48:38 +0100 Subject: [PATCH] fix: escape SQL identifiers in db engine spec prequeries and metadata queries (#39840) Co-authored-by: Claude Sonnet 4.6 --- superset/db_engine_specs/bigquery.py | 27 +++--- superset/db_engine_specs/databricks.py | 8 +- superset/db_engine_specs/db2.py | 5 +- superset/db_engine_specs/gsheets.py | 3 +- superset/db_engine_specs/hive.py | 27 ++++-- superset/db_engine_specs/postgres.py | 5 +- superset/db_engine_specs/starrocks.py | 3 +- .../db_engine_specs/test_databricks.py | 11 ++- tests/unit_tests/db_engine_specs/test_db2.py | 3 + tests/unit_tests/db_engine_specs/test_hive.py | 84 +++++++++++++++++++ .../db_engine_specs/test_postgres.py | 3 + .../db_engine_specs/test_starrocks.py | 5 ++ 12 files changed, 152 insertions(+), 32 deletions(-) diff --git a/superset/db_engine_specs/bigquery.py b/superset/db_engine_specs/bigquery.py index de13b1f2666..8a82623e164 100644 --- a/superset/db_engine_specs/bigquery.py +++ b/superset/db_engine_specs/bigquery.py @@ -883,6 +883,14 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met # We will return the original exception return exception + @staticmethod + def _information_schema_ref(schema: str, catalog: str | None) -> str: + escaped_schema = schema.replace("`", "``") + if catalog: + escaped_catalog = catalog.replace("`", "``") + return f"`{escaped_catalog}.{escaped_schema}.INFORMATION_SCHEMA.TABLES`" + return f"`{escaped_schema}.INFORMATION_SCHEMA.TABLES`" + @classmethod def get_materialized_view_names( cls, @@ -899,14 +907,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met if not schema: return set() - # Construct the query to get materialized views from INFORMATION_SCHEMA - if catalog := database.get_default_catalog(): - information_schema = f"`{catalog}.{schema}.INFORMATION_SCHEMA.TABLES`" - else: - information_schema = f"`{schema}.INFORMATION_SCHEMA.TABLES`" - - # Use string formatting for the table name since it's not user input - # The catalog and schema are from trusted sources (database configuration) + catalog = database.get_default_catalog() + information_schema = cls._information_schema_ref(schema, catalog) query = f""" SELECT table_name FROM {information_schema} @@ -945,15 +947,8 @@ class BigQueryEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-met if not schema: return set() - # Construct the query to get regular views from INFORMATION_SCHEMA catalog = database.get_default_catalog() - if catalog: - information_schema = f"`{catalog}.{schema}.INFORMATION_SCHEMA.TABLES`" - else: - information_schema = f"`{schema}.INFORMATION_SCHEMA.TABLES`" - - # Use string formatting for the table name since it's not user input - # The catalog and schema are from trusted sources (database configuration) + information_schema = cls._information_schema_ref(schema, catalog) query = f""" SELECT table_name FROM {information_schema} diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 980c297d106..3bc9c0a258a 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -569,11 +569,11 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): ) -> list[str]: prequeries = [] if catalog: - catalog = f"`{catalog}`" if not catalog.startswith("`") else catalog - prequeries.append(f"USE CATALOG {catalog}") + escaped_catalog = catalog.replace("`", "``") + prequeries.append(f"USE CATALOG `{escaped_catalog}`") if schema: - schema = f"`{schema}`" if not schema.startswith("`") else schema - prequeries.append(f"USE SCHEMA {schema}") + escaped_schema = schema.replace("`", "``") + prequeries.append(f"USE SCHEMA `{escaped_schema}`") return prequeries @classmethod diff --git a/superset/db_engine_specs/db2.py b/superset/db_engine_specs/db2.py index 8994113a740..0e24d341513 100644 --- a/superset/db_engine_specs/db2.py +++ b/superset/db_engine_specs/db2.py @@ -162,4 +162,7 @@ class Db2EngineSpec(BaseEngineSpec): be anything, and we would have to block users from running any queries referencing tables without an explicit schema. """ - return [f'set current_schema "{schema}"'] if schema else [] + if not schema: + return [] + escaped = schema.replace('"', '""') + return [f'set current_schema "{escaped}"'] diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 5eda47e8e6a..138c03f3861 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -268,7 +268,8 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): schema=table.schema, ) as conn: cursor = conn.cursor() - cursor.execute(f'SELECT GET_METADATA("{table.table}")') + escaped_table = table.table.replace('"', '""') + cursor.execute(f'SELECT GET_METADATA("{escaped_table}")') results = cursor.fetchone()[0] try: metadata = json.loads(results) diff --git a/superset/db_engine_specs/hive.py b/superset/db_engine_specs/hive.py index 61ef59272a7..3dffab83e9b 100644 --- a/superset/db_engine_specs/hive.py +++ b/superset/db_engine_specs/hive.py @@ -206,13 +206,24 @@ class HiveEngineSpec(PrestoEngineSpec): if to_sql_kwargs["if_exists"] == "fail": # Ensure table doesn't already exist. + escaped_table = ( + table.table.replace("\\", "\\\\") + .replace("'", "\\'") + .replace("%", "\\%") + .replace("_", "\\_") + ) + # Hive LIKE uses backslash as the escape character. Python needs \\\\ + # to produce the two-character SQL literal \\ (a single backslash). + escape_clause = " ESCAPE '\\\\'" if table.schema: + escaped_schema = table.schema.replace("`", "``") table_exists = not database.get_df( - f"SHOW TABLES IN {table.schema} LIKE '{table.table}'" + f"SHOW TABLES IN `{escaped_schema}`" + f" LIKE '{escaped_table}'{escape_clause}" ).empty else: table_exists = not database.get_df( - f"SHOW TABLES LIKE '{table.table}'" + f"SHOW TABLES LIKE '{escaped_table}'{escape_clause}" ).empty if table_exists: @@ -498,9 +509,12 @@ class HiveEngineSpec(PrestoEngineSpec): order_by: list[tuple[str, bool]] | None = None, filters: dict[Any, Any] | None = None, ) -> str: - full_table_name = ( - f"{table.schema}.{table.table}" if table.schema else table.table - ) + escaped_table = table.table.replace("`", "``") + if table.schema: + escaped_schema = table.schema.replace("`", "``") + full_table_name = f"`{escaped_schema}`.`{escaped_table}`" + else: + full_table_name = f"`{escaped_table}`" return f"SHOW PARTITIONS {full_table_name}" @classmethod @@ -628,7 +642,8 @@ class HiveEngineSpec(PrestoEngineSpec): sql = "SHOW VIEWS" if schema: - sql += f" IN `{schema}`" + escaped_schema = schema.replace("`", "``") + sql += f" IN `{escaped_schema}`" with database.get_raw_connection(schema=schema) as conn: cursor = conn.cursor() diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index b1b265c78eb..4d62b5dd242 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -694,7 +694,10 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): be anything, and we would have to block users from running any queries referencing tables without an explicit schema. """ - return [f'set search_path = "{schema}"'] if schema else [] + if not schema: + return [] + escaped = schema.replace('"', '""') + return [f'set search_path = "{escaped}"'] @classmethod def get_allow_cost_estimate(cls, extra: dict[str, Any]) -> bool: diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py index c485ce216fe..ff00eeb91fb 100644 --- a/superset/db_engine_specs/starrocks.py +++ b/superset/db_engine_specs/starrocks.py @@ -413,6 +413,7 @@ class StarRocksEngineSpec(MySQLEngineSpec): username = database.get_effective_user(database.url_object) if username: - return [f'EXECUTE AS "{username}" WITH NO REVERT;'] + escaped = username.replace('"', '""') + return [f'EXECUTE AS "{escaped}" WITH NO REVERT;'] return [] diff --git a/tests/unit_tests/db_engine_specs/test_databricks.py b/tests/unit_tests/db_engine_specs/test_databricks.py index a5850020097..61366136e6a 100644 --- a/tests/unit_tests/db_engine_specs/test_databricks.py +++ b/tests/unit_tests/db_engine_specs/test_databricks.py @@ -281,6 +281,13 @@ def test_get_prequeries(mocker: MockerFixture) -> None: assert DatabricksNativeEngineSpec.get_prequeries( database, catalog="`escaped-hyphen`", schema="`hyphen-escaped`" ) == [ - "USE CATALOG `escaped-hyphen`", - "USE SCHEMA `hyphen-escaped`", + "USE CATALOG ```escaped-hyphen```", + "USE SCHEMA ```hyphen-escaped```", + ] + + assert DatabricksNativeEngineSpec.get_prequeries( + database, catalog="evil` USE CATALOG bad", schema="evil` USE SCHEMA bad" + ) == [ + "USE CATALOG `evil`` USE CATALOG bad`", + "USE SCHEMA `evil`` USE SCHEMA bad`", ] diff --git a/tests/unit_tests/db_engine_specs/test_db2.py b/tests/unit_tests/db_engine_specs/test_db2.py index 1ae4471f9ff..6f469fb7a20 100644 --- a/tests/unit_tests/db_engine_specs/test_db2.py +++ b/tests/unit_tests/db_engine_specs/test_db2.py @@ -81,6 +81,9 @@ def test_get_prequeries(mocker: MockerFixture) -> None: assert Db2EngineSpec.get_prequeries(database, schema="my_schema") == [ 'set current_schema "my_schema"' ] + assert Db2EngineSpec.get_prequeries(database, schema='evil"; SELECT 1--') == [ + 'set current_schema "evil""; SELECT 1--"' + ] @pytest.mark.parametrize( diff --git a/tests/unit_tests/db_engine_specs/test_hive.py b/tests/unit_tests/db_engine_specs/test_hive.py index ec307154618..f070b805a33 100644 --- a/tests/unit_tests/db_engine_specs/test_hive.py +++ b/tests/unit_tests/db_engine_specs/test_hive.py @@ -99,3 +99,87 @@ SELECT * \nFROM my_schema.my_table LIMIT :param_1 """.strip() ) + + +def test_get_view_names_escapes_schema(mocker: MockerFixture) -> None: + """ + Test that ``get_view_names`` correctly escapes backticks in schema names + within the SHOW VIEWS statement. + """ + from superset.db_engine_specs.hive import HiveEngineSpec + + database = mocker.MagicMock() + inspector = mocker.MagicMock() + + conn = mocker.MagicMock() + cursor = mocker.MagicMock() + cursor.fetchall.return_value = [] + conn.__enter__ = mocker.MagicMock(return_value=conn) + conn.__exit__ = mocker.MagicMock(return_value=False) + conn.cursor.return_value = cursor + database.get_raw_connection.return_value = conn + + HiveEngineSpec.get_view_names(database, inspector, schema="evil` UNION SELECT 1--") + cursor.execute.assert_called_once() + sql = cursor.execute.call_args[0][0] + assert "IN `evil`` UNION SELECT 1--`" in sql + + +def test_df_to_sql_escapes_like_wildcards(mocker: MockerFixture) -> None: + """ + Test that ``df_to_sql`` escapes ``%`` and ``_`` wildcard characters in the + SHOW TABLES LIKE pattern used to detect table existence. + """ + import pandas as pd + + from superset.db_engine_specs.hive import HiveEngineSpec + from superset.exceptions import SupersetException + from superset.sql.parse import Table + + database = mocker.MagicMock() + # Simulate an existing table so df_to_sql raises before reaching the upload path + database.get_df.return_value = pd.DataFrame({"name": ["sales_%_2024"]}) + + with pytest.raises(SupersetException, match="Table already exists"): + HiveEngineSpec.df_to_sql( + database=database, + table=Table("sales_%_2024", "my_schema"), + df=pd.DataFrame({"a": [1]}), + to_sql_kwargs={"if_exists": "fail"}, + ) + + database.get_df.assert_called_once() + sql = database.get_df.call_args[0][0] + assert r"\%" in sql + assert r"\_" in sql + assert "ESCAPE" in sql + + +def test_partition_query_escapes_identifiers() -> None: + """ + Test that ``_partition_query`` correctly backtick-quotes table and schema names + in the SHOW PARTITIONS statement. + """ + from superset.db_engine_specs.hive import HiveEngineSpec + from superset.sql.parse import Table + + result = HiveEngineSpec._partition_query( + table=Table("my_table", "my_schema"), + indexes=[], + database=None, # type: ignore + ) + assert result == "SHOW PARTITIONS `my_schema`.`my_table`" + + result = HiveEngineSpec._partition_query( + table=Table("evil`tbl", "evil`schema"), + indexes=[], + database=None, # type: ignore + ) + assert result == "SHOW PARTITIONS `evil``schema`.`evil``tbl`" + + result = HiveEngineSpec._partition_query( + table=Table("no_schema_tbl"), + indexes=[], + database=None, # type: ignore + ) + assert result == "SHOW PARTITIONS `no_schema_tbl`" diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py index 88ce789e131..c043bae3ef8 100644 --- a/tests/unit_tests/db_engine_specs/test_postgres.py +++ b/tests/unit_tests/db_engine_specs/test_postgres.py @@ -147,6 +147,9 @@ def test_get_prequeries(mocker: MockerFixture) -> None: assert spec.get_prequeries(database) == [] assert spec.get_prequeries(database, schema="test") == ['set search_path = "test"'] + assert spec.get_prequeries(database, schema='evil"; SELECT 1--') == [ + 'set search_path = "evil""; SELECT 1--"' + ] def test_get_default_schema_for_query(mocker: MockerFixture) -> None: diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py index 257fc6341f3..2951974cf13 100644 --- a/tests/unit_tests/db_engine_specs/test_starrocks.py +++ b/tests/unit_tests/db_engine_specs/test_starrocks.py @@ -169,6 +169,11 @@ def test_impersonation_username(mocker: MockerFixture) -> None: 'EXECUTE AS "alice" WITH NO REVERT;' ] + database.get_effective_user.return_value = 'evil" WITH NO REVERT; DROP TABLE x--' + assert StarRocksEngineSpec.get_prequeries(database) == [ + 'EXECUTE AS "evil"" WITH NO REVERT; DROP TABLE x--" WITH NO REVERT;' + ] + def test_impersonation_disabled(mocker: MockerFixture) -> None: """