diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py
index b744b7b440b..b1b265c78eb 100644
--- a/superset/db_engine_specs/postgres.py
+++ b/superset/db_engine_specs/postgres.py
@@ -82,6 +82,51 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
 SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P.*?)"')
 
 
+def _check_not_redshift(dbapi_connection: Any, connection_record: Any) -> None:
+    """
+    Event that checks if database is Amazon Redshift.
+
+    SQLAlchemy pool `connect` event that checks whether the database is actually
+    Amazon Redshift by running `SELECT version()`. Redshift returns a version string
+    containing `Redshift`, e.g.::
+
+        PostgreSQL 8.0.2 on ... Redshift 1.0.77467
+
+    If detected, a `ValueError` is raised so that the user is prompted to
+    switch to the `redshift+psycopg2://` driver, which ensures the correct
+    sqlglot dialect is used for SQL transpilation.
+
+    The probe fails open: if the version cannot be read for any reason, the
+    failure is logged at debug level and the connection is allowed.
+
+    :param dbapi_connection: DBAPI connection passed to the event listener
+    :param connection_record: second event argument; not used by this check
+    :raises ValueError: if the reported version string mentions Redshift
+    """
+    try:
+        cursor = dbapi_connection.cursor()
+        try:
+            cursor.execute("SELECT version()")
+            version = cursor.fetchone()[0]
+        finally:
+            cursor.close()
+    except Exception:
+        # Any error raised while probing, including a driver-raised ValueError,
+        # must not block the connection; only the explicit error below may.
+        logger.debug(
+            "Failed to check database version for Redshift detection",
+            exc_info=True,
+        )
+        return
+
+    if "redshift" in str(version).lower():
+        raise ValueError(
+            "It looks like you're connecting to Amazon Redshift using the "
+            "PostgreSQL driver. Please use the Redshift driver instead "
+            "(redshift+psycopg2://) to ensure proper SQL dialect support."
+        )
+
+
 def parse_options(connect_args: dict[str, Any]) -> dict[str, str]:
     """
     Parse ``options`` from ``connect_args`` into a dictionary.
@@ -570,6 +615,15 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
 
         return uri, connect_args
 
+    @staticmethod
+    def mutate_db_for_connection_test(database: Database) -> None:
+        """
+        Flag the database so that the Redshift `SELECT version()` check
+        runs during `test_connection`. The actual check is injected as a
+        pool `connect` event inside `update_params_from_encrypted_extra`.
+        """
+        database._check_redshift_version = True
+
     @staticmethod
     def update_params_from_encrypted_extra(
         database: Database,
@@ -581,6 +635,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
         Handles AWS IAM authentication if configured, then merges any
         remaining encrypted_extra keys into params (standard behavior).
         """
+        # During test_connection, inject a pool event to detect Redshift by
+        # checking SELECT version(). This catches cases where the hostname
+        # doesn't reveal Redshift (custom domains, private endpoints, etc.).
+        if getattr(database, "_check_redshift_version", False) is True:
+            pool_events = params.setdefault("pool_events", [])
+            pool_events.append((_check_not_redshift, "connect"))
+
         if not database.encrypted_extra:
             return
 
diff --git a/tests/unit_tests/db_engine_specs/test_postgres.py b/tests/unit_tests/db_engine_specs/test_postgres.py
index ff11dd9aa42..88ce789e131 100644
--- a/tests/unit_tests/db_engine_specs/test_postgres.py
+++ b/tests/unit_tests/db_engine_specs/test_postgres.py
@@ -17,6 +17,7 @@
 
 from datetime import datetime
 from typing import Any, Optional
+from unittest.mock import MagicMock
 
 import pytest
 from pytest_mock import MockerFixture
@@ -25,7 +26,10 @@ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
 from sqlalchemy.engine.interfaces import Dialect
 from sqlalchemy.engine.url import make_url
 
-from superset.db_engine_specs.postgres import PostgresEngineSpec as spec  # noqa: N813
+from superset.db_engine_specs.postgres import (
+    _check_not_redshift,
+    PostgresEngineSpec as spec,  # noqa: N813
+)
 from superset.exceptions import SupersetSecurityException
 from superset.sql.parse import Table
 from superset.utils.core import GenericDataType
@@ -280,3 +284,94 @@ SELECT * \nFROM my_schema.my_table
 LIMIT :param_1
     """.strip()
     )
+
+
+class TestRedshiftDetection:
+    """
+    Tests for detecting Redshift connections via the PostgreSQL dialect.
+    """
+
+    def test_check_not_redshift_detects_redshift(self) -> None:
+        """
+        Pool connect event raises for a Redshift version string.
+        """
+        cursor = MagicMock()
+        cursor.fetchone.return_value = (
+            "PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) "
+            "3.4.2 20041017 (Red Hat 3.4.2-6.fc3), Redshift 1.0.77467",
+        )
+        dbapi_conn = MagicMock()
+        dbapi_conn.cursor.return_value = cursor
+
+        with pytest.raises(ValueError, match="Redshift"):
+            _check_not_redshift(dbapi_conn, None)
+
+    def test_check_not_redshift_allows_postgres(self) -> None:
+        """
+        Pool connect event allows a regular PostgreSQL version string.
+        """
+        cursor = MagicMock()
+        cursor.fetchone.return_value = (
+            "PostgreSQL 15.2 on x86_64-pc-linux-gnu, compiled by gcc",
+        )
+        dbapi_conn = MagicMock()
+        dbapi_conn.cursor.return_value = cursor
+
+        _check_not_redshift(dbapi_conn, None)  # should not raise
+
+    def test_check_not_redshift_fails_open(self) -> None:
+        """
+        If SELECT version() errors, the connection is still allowed.
+        """
+        cursor = MagicMock()
+        cursor.execute.side_effect = Exception("permission denied")
+        dbapi_conn = MagicMock()
+        dbapi_conn.cursor.return_value = cursor
+
+        _check_not_redshift(dbapi_conn, None)  # should not raise
+
+    def test_check_not_redshift_fails_open_on_driver_value_error(self) -> None:
+        """
+        A ValueError raised by the driver itself does not block the connection.
+        """
+        cursor = MagicMock()
+        cursor.execute.side_effect = ValueError("bad cursor state")
+        dbapi_conn = MagicMock()
+        dbapi_conn.cursor.return_value = cursor
+
+        _check_not_redshift(dbapi_conn, None)  # should not raise
+        cursor.close.assert_called_once()
+
+    def test_mutate_db_sets_flag(self) -> None:
+        """
+        mutate_db_for_connection_test sets the check flag.
+        """
+        database = MagicMock()
+        spec.mutate_db_for_connection_test(database)
+        assert database._check_redshift_version is True
+
+    def test_pool_event_injected_when_flag_set(self, mocker: MockerFixture) -> None:
+        """
+        Pool event is added during test_connection.
+        """
+        database = mocker.MagicMock(
+            encrypted_extra=None,
+            _check_redshift_version=True,
+        )
+        params: dict[str, Any] = {}
+        spec.update_params_from_encrypted_extra(database, params)
+
+        assert "pool_events" in params
+        fns = [fn for fn, _ in params["pool_events"]]
+        assert _check_not_redshift in fns
+
+    def test_pool_event_not_injected_without_flag(self, mocker: MockerFixture) -> None:
+        """
+        Pool event is NOT added during normal operation.
+        """
+        database = mocker.MagicMock(encrypted_extra=None)
+        database._check_redshift_version = False
+        params: dict[str, Any] = {}
+        spec.update_params_from_encrypted_extra(database, params)
+
+        assert "pool_events" not in params