mirror of
https://github.com/apache/superset.git
synced 2026-04-22 17:45:21 +00:00
feat: prevent Postgres connection to Redshift (#38693)
This commit is contained in:
@@ -82,6 +82,43 @@ COLUMN_DOES_NOT_EXIST_REGEX = re.compile(
|
|||||||
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
|
SYNTAX_ERROR_REGEX = re.compile('syntax error at or near "(?P<syntax_error>.*?)"')
|
||||||
|
|
||||||
|
|
||||||
|
def _check_not_redshift(dbapi_connection: Any, connection_record: Any) -> None:
|
||||||
|
"""
|
||||||
|
Event that checks if database is Amazon Redshift.
|
||||||
|
|
||||||
|
SQLAlchemy pool `connect` event that checks whether the database is actually
|
||||||
|
Amazon Redshift by running `SELECT version()`. Redshift returns a version string
|
||||||
|
containing `Redshift`, e.g.::
|
||||||
|
|
||||||
|
PostgreSQL 8.0.2 on ... Redshift 1.0.77467
|
||||||
|
|
||||||
|
If detected, a `ValueError` is raised so that the user is prompted to
|
||||||
|
switch to the `redshift+psycopg2://` driver, which ensures the correct
|
||||||
|
sqlglot dialect is used for SQL transpilation.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
try:
|
||||||
|
cursor.execute("SELECT version()")
|
||||||
|
version = cursor.fetchone()[0]
|
||||||
|
finally:
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
if "redshift" in str(version).lower():
|
||||||
|
raise ValueError(
|
||||||
|
"It looks like you're connecting to Amazon Redshift using the "
|
||||||
|
"PostgreSQL driver. Please use the Redshift driver instead "
|
||||||
|
"(redshift+psycopg2://) to ensure proper SQL dialect support."
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
raise
|
||||||
|
except Exception:
|
||||||
|
logger.debug(
|
||||||
|
"Failed to check database version for Redshift detection",
|
||||||
|
exc_info=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def parse_options(connect_args: dict[str, Any]) -> dict[str, str]:
|
def parse_options(connect_args: dict[str, Any]) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Parse ``options`` from ``connect_args`` into a dictionary.
|
Parse ``options`` from ``connect_args`` into a dictionary.
|
||||||
@@ -570,6 +607,15 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||||||
|
|
||||||
return uri, connect_args
|
return uri, connect_args
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def mutate_db_for_connection_test(database: Database) -> None:
|
||||||
|
"""
|
||||||
|
Flag the database so that the Redshift `SELECT version()` check
|
||||||
|
runs during `test_connection`. The actual check is injected as a
|
||||||
|
pool `connect` event inside `update_params_from_encrypted_extra`.
|
||||||
|
"""
|
||||||
|
database._check_redshift_version = True
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def update_params_from_encrypted_extra(
|
def update_params_from_encrypted_extra(
|
||||||
database: Database,
|
database: Database,
|
||||||
@@ -581,6 +627,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec):
|
|||||||
Handles AWS IAM authentication if configured, then merges any
|
Handles AWS IAM authentication if configured, then merges any
|
||||||
remaining encrypted_extra keys into params (standard behavior).
|
remaining encrypted_extra keys into params (standard behavior).
|
||||||
"""
|
"""
|
||||||
|
# During test_connection, inject a pool event to detect Redshift by
|
||||||
|
# checking SELECT version(). This catches cases where the hostname
|
||||||
|
# doesn't reveal Redshift (custom domains, private endpoints, etc.).
|
||||||
|
if getattr(database, "_check_redshift_version", False) is True:
|
||||||
|
pool_events = params.setdefault("pool_events", [])
|
||||||
|
pool_events.append((_check_not_redshift, "connect"))
|
||||||
|
|
||||||
if not database.encrypted_extra:
|
if not database.encrypted_extra:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,7 @@
|
|||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Optional
|
from typing import Any, Optional
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from pytest_mock import MockerFixture
|
from pytest_mock import MockerFixture
|
||||||
@@ -25,7 +26,10 @@ from sqlalchemy.dialects.postgresql import DOUBLE_PRECISION, ENUM, JSON
|
|||||||
from sqlalchemy.engine.interfaces import Dialect
|
from sqlalchemy.engine.interfaces import Dialect
|
||||||
from sqlalchemy.engine.url import make_url
|
from sqlalchemy.engine.url import make_url
|
||||||
|
|
||||||
from superset.db_engine_specs.postgres import PostgresEngineSpec as spec # noqa: N813
|
from superset.db_engine_specs.postgres import (
|
||||||
|
_check_not_redshift,
|
||||||
|
PostgresEngineSpec as spec, # noqa: N813
|
||||||
|
)
|
||||||
from superset.exceptions import SupersetSecurityException
|
from superset.exceptions import SupersetSecurityException
|
||||||
from superset.sql.parse import Table
|
from superset.sql.parse import Table
|
||||||
from superset.utils.core import GenericDataType
|
from superset.utils.core import GenericDataType
|
||||||
@@ -280,3 +284,82 @@ SELECT * \nFROM my_schema.my_table
|
|||||||
LIMIT :param_1
|
LIMIT :param_1
|
||||||
""".strip()
|
""".strip()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRedshiftDetection:
|
||||||
|
"""
|
||||||
|
Tests for detecting Redshift connections via the PostgreSQL dialect.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def test_check_not_redshift_detects_redshift(self) -> None:
|
||||||
|
"""
|
||||||
|
Pool connect event raises for a Redshift version string.
|
||||||
|
"""
|
||||||
|
cursor = MagicMock()
|
||||||
|
cursor.fetchone.return_value = (
|
||||||
|
"PostgreSQL 8.0.2 on i686-pc-linux-gnu, compiled by GCC gcc (GCC) "
|
||||||
|
"3.4.2 20041017 (Red Hat 3.4.2-6.fc3), Redshift 1.0.77467",
|
||||||
|
)
|
||||||
|
dbapi_conn = MagicMock()
|
||||||
|
dbapi_conn.cursor.return_value = cursor
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="Redshift"):
|
||||||
|
_check_not_redshift(dbapi_conn, None)
|
||||||
|
|
||||||
|
def test_check_not_redshift_allows_postgres(self) -> None:
|
||||||
|
"""
|
||||||
|
Pool connect event allows a regular PostgreSQL version string.
|
||||||
|
"""
|
||||||
|
cursor = MagicMock()
|
||||||
|
cursor.fetchone.return_value = (
|
||||||
|
"PostgreSQL 15.2 on x86_64-pc-linux-gnu, compiled by gcc",
|
||||||
|
)
|
||||||
|
dbapi_conn = MagicMock()
|
||||||
|
dbapi_conn.cursor.return_value = cursor
|
||||||
|
|
||||||
|
_check_not_redshift(dbapi_conn, None) # should not raise
|
||||||
|
|
||||||
|
def test_check_not_redshift_fails_open(self) -> None:
|
||||||
|
"""
|
||||||
|
If SELECT version() errors, the connection is still allowed.
|
||||||
|
"""
|
||||||
|
cursor = MagicMock()
|
||||||
|
cursor.execute.side_effect = Exception("permission denied")
|
||||||
|
dbapi_conn = MagicMock()
|
||||||
|
dbapi_conn.cursor.return_value = cursor
|
||||||
|
|
||||||
|
_check_not_redshift(dbapi_conn, None) # should not raise
|
||||||
|
|
||||||
|
def test_mutate_db_sets_flag(self) -> None:
|
||||||
|
"""
|
||||||
|
mutate_db_for_connection_test sets the check flag.
|
||||||
|
"""
|
||||||
|
database = MagicMock()
|
||||||
|
spec.mutate_db_for_connection_test(database)
|
||||||
|
assert database._check_redshift_version is True
|
||||||
|
|
||||||
|
def test_pool_event_injected_when_flag_set(self, mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Pool event is added during test_connection.
|
||||||
|
"""
|
||||||
|
database = mocker.MagicMock(
|
||||||
|
encrypted_extra=None,
|
||||||
|
_check_redshift_version=True,
|
||||||
|
)
|
||||||
|
params: dict[str, Any] = {}
|
||||||
|
spec.update_params_from_encrypted_extra(database, params)
|
||||||
|
|
||||||
|
assert "pool_events" in params
|
||||||
|
fns = [fn for fn, _ in params["pool_events"]]
|
||||||
|
assert _check_not_redshift in fns
|
||||||
|
|
||||||
|
def test_pool_event_not_injected_without_flag(self, mocker: MockerFixture) -> None:
|
||||||
|
"""
|
||||||
|
Pool event is NOT added during normal operation.
|
||||||
|
"""
|
||||||
|
database = mocker.MagicMock(encrypted_extra=None)
|
||||||
|
database._check_redshift_version = False
|
||||||
|
params: dict[str, Any] = {}
|
||||||
|
spec.update_params_from_encrypted_extra(database, params)
|
||||||
|
|
||||||
|
assert "pool_events" not in params
|
||||||
|
|||||||
Reference in New Issue
Block a user