diff --git a/superset/config.py b/superset/config.py index c81e15ac0b4..a2df28e1019 100644 --- a/superset/config.py +++ b/superset/config.py @@ -52,7 +52,6 @@ from superset.advanced_data_type.plugins.internet_address import internet_addres from superset.advanced_data_type.plugins.internet_port import internet_port from superset.advanced_data_type.types import AdvancedDataType from superset.constants import CHANGE_ME_SECRET_KEY -from superset.engines.manager import EngineModes from superset.jinja_context import BaseTemplateProcessor from superset.key_value.types import JsonKeyValueCodec from superset.stats_logger import DummyStatsLogger @@ -265,22 +264,6 @@ SQLALCHEMY_ENGINE_OPTIONS = {} # SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password SQLALCHEMY_CUSTOM_PASSWORD_STORE = None -# --------------------------------------------------------- -# Engine Manager Configuration -# --------------------------------------------------------- - -# Engine manager mode: "NEW" creates a new engine for every connection (default), -# "SINGLETON" reuses engines with connection pooling -ENGINE_MANAGER_MODE = EngineModes.NEW - -# Cleanup interval for abandoned locks (default: 5 minutes) -ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5) - -# Automatically start cleanup thread for SINGLETON mode (default: True) -ENGINE_MANAGER_AUTO_START_CLEANUP = True - -# --------------------------------------------------------- - # # The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models # which include sensitive fields that should be app-encrypted BEFORE sending diff --git a/superset/engines/manager.py b/superset/engines/manager.py index c5b539c22d7..52597a260ad 100644 --- a/superset/engines/manager.py +++ b/superset/engines/manager.py @@ -15,10 +15,6 @@ # specific language governing permissions and limitations # under the License. -import enum -import hashlib -import logging -import threading from contextlib import contextmanager from datetime import timedelta from io import StringIO @@ -26,7 +22,7 @@ from typing import Any, Iterator, TYPE_CHECKING import sshtunnel from paramiko import RSAKey -from sqlalchemy import create_engine, event, pool +from sqlalchemy import create_engine, pool from sqlalchemy.engine import Engine from sqlalchemy.engine.url import URL from sshtunnel import SSHTunnelForwarder @@ -41,158 +37,24 @@ if TYPE_CHECKING: from superset.models.core import Database -logger = logging.getLogger(__name__) - - -class _LockManager: - """ - Manages per-key locks safely without defaultdict race conditions. - - This class provides a thread-safe way to create and manage locks for specific keys, - avoiding the race conditions that occur when using defaultdict with threading.Lock. - - The implementation uses a two-level locking strategy: - 1. A meta-lock to protect the lock dictionary itself - 2. Per-key locks to protect specific resources - - This ensures that: - - Different keys can be locked concurrently (scalability) - - Lock creation is thread-safe (no race conditions) - - The same key always gets the same lock instance - """ - - def __init__(self) -> None: - self._locks: dict[str, threading.RLock] = {} - self._meta_lock = threading.Lock() - - def get_lock(self, key: str) -> threading.RLock: - """ - Get or create a lock for the given key. - - This method uses double-checked locking to ensure thread safety: - 1. First check without lock (fast path) - 2. Acquire meta-lock if needed - 3. Double-check inside the lock to prevent race conditions - - This approach minimizes lock contention while ensuring correctness. - - :param key: The key to get a lock for - :returns: An RLock instance for the given key - """ - if lock := self._locks.get(key): - return lock - - with self._meta_lock: - # Double-check inside the lock - lock = self._locks.get(key) - if lock is None: - lock = threading.RLock() - self._locks[key] = lock - return lock - - def cleanup(self, active_keys: set[str]) -> None: - """ - Remove locks for keys that are no longer in use. - - This prevents memory leaks from accumulating locks for resources - that have been disposed. - - :param active_keys: Set of keys that are still active - """ - with self._meta_lock: - # Find locks to remove - locks_to_remove = self._locks.keys() - active_keys - for key in locks_to_remove: - self._locks.pop(key, None) - - -EngineKey = str -TunnelKey = str - - -def _generate_cache_key(*args: Any) -> str: - """ - Generate a deterministic cache key from arbitrary arguments. - - Uses repr() for serialization and SHA-256 for hashing. The resulting key - is a 32-character hex string that: - 1. Is deterministic for the same inputs - 2. Does not expose sensitive data (everything is hashed) - 3. Has sufficient entropy to avoid collisions - - :param args: Arguments to include in the cache key - :returns: 32-character hex string - """ - # Use repr() which works with most Python objects and is deterministic - serialized = repr(args).encode("utf-8") - return hashlib.sha256(serialized).hexdigest()[:32] - - -class EngineModes(enum.Enum): - # reuse existing engine if available, otherwise create a new one; this mode should - # have a connection pool configured in the database - SINGLETON = enum.auto() - - # always create a new engine for every connection; this mode will use a NullPool - # and is the default behavior for Superset - NEW = enum.auto() - - class EngineManager: - """ - A manager for SQLAlchemy engines. - - This class handles the creation and management of SQLAlchemy engines, allowing them - to be configured with connection pools and reused across requests. The default mode - is the original behavior for Superset, where we create a new engine for every - connection, using a NullPool. The `SINGLETON` mode, on the other hand, allows for - reusing of the engines, as well as configuring the pool through the database - settings. - """ + """Centralized SQLAlchemy engine creation for Superset.""" def __init__( self, engine_context_manager: EngineContextManager, db_connection_mutator: DBConnectionMutator | None = None, - mode: EngineModes = EngineModes.NEW, - cleanup_interval: timedelta = timedelta(minutes=5), local_bind_address: str = "127.0.0.1", tunnel_timeout: timedelta = timedelta(seconds=30), ssh_timeout: timedelta = timedelta(seconds=1), ) -> None: self.engine_context_manager = engine_context_manager self.db_connection_mutator = db_connection_mutator - self.mode = mode - self.cleanup_interval = cleanup_interval self.local_bind_address = local_bind_address sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds() sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds() - self._engines: dict[EngineKey, Engine] = {} - self._engine_locks = _LockManager() - self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {} - self._tunnel_locks = _LockManager() - - # Background cleanup thread management - self._cleanup_thread: threading.Thread | None = None - self._cleanup_stop_event = threading.Event() - self._cleanup_thread_lock = threading.Lock() - - def __del__(self) -> None: - """ - Ensure cleanup thread is stopped when the manager is destroyed. - """ - try: - self.stop_cleanup_thread() - except Exception as ex: - # Avoid exceptions during garbage collection, but log if possible - try: - logger.warning("Error stopping cleanup thread: %s", ex) - except Exception: # noqa: S110 - # If logging fails during destruction, we can't do anything - pass - @contextmanager def get_engine( self, @@ -201,103 +63,31 @@ class EngineManager: schema: str | None, source: QuerySource | None, ) -> Iterator[Engine]: - """ - Context manager to get a SQLAlchemy engine. - """ - # users can wrap the engine in their own context manager for different - # reasons + """Context manager to get a SQLAlchemy engine.""" + from superset.utils.oauth2 import check_for_oauth2 + with self.engine_context_manager(database, catalog, schema): - # we need to check for errors indicating that OAuth2 is needed, and - # return the proper exception so it starts the authentication flow - from superset.utils.oauth2 import check_for_oauth2 - with check_for_oauth2(database): - yield self._get_engine(database, catalog, schema, source) + uri, kwargs = self._get_engine_args( + database, + catalog, + schema, + source, + get_user_id(), + ) - def _get_engine( - self, - database: "Database", - catalog: str | None, - schema: str | None, - source: QuerySource | None, - ) -> Engine: - """ - Get a specific engine, or create it if none exists. - """ - source = source or get_query_source_from_request() - user_id = get_user_id() - - if self.mode == EngineModes.NEW: - return self._create_engine( - database, - catalog, - schema, - source, - user_id, - ) - - engine_key = self._get_engine_key( - database, - catalog, - schema, - source, - user_id, - ) - - if engine := self._engines.get(engine_key): - return engine - - lock = self._engine_locks.get_lock(engine_key) - with lock: - # Double-check inside the lock - if engine := self._engines.get(engine_key): - return engine - - # Create and cache the engine - engine = self._create_engine( - database, - catalog, - schema, - source, - user_id, - ) - self._engines[engine_key] = engine - self._add_disposal_listener(engine, engine_key) - return engine - - def _get_engine_key( - self, - database: "Database", - catalog: str | None, - schema: str | None, - source: QuerySource | None, - user_id: int | None, - ) -> EngineKey: - """ - Generate a cache key for the engine. - - The key is a hash of all parameters that affect the engine, ensuring - proper cache isolation without exposing sensitive data. - - :returns: 32-character hex string - """ - uri, kwargs = self._get_engine_args( - database, - catalog, - schema, - source, - user_id, - ) - - return _generate_cache_key( - database.id, - catalog, - schema, - str(uri), - source, - user_id, - kwargs, - ) + if database.ssh_tunnel: + tunnel = self._create_tunnel(database.ssh_tunnel, uri) + try: + uri = uri.set( + host=tunnel.local_bind_address[0], + port=tunnel.local_bind_port, + ) + yield self._create_engine(database, uri, kwargs) + finally: + tunnel.stop() + else: + yield self._create_engine(database, uri, kwargs) def _get_engine_args( self, @@ -307,39 +97,15 @@ class EngineManager: source: QuerySource | None, user_id: int | None, ) -> tuple[URL, dict[str, Any]]: - """ - Build the almost final SQLAlchemy URI and engine kwargs. - - "Almost" final because we may still need to mutate the URI if an SSH tunnel is - needed, since it needs to connect to the tunnel instead of the original DB. But - that information is only available after the tunnel is created. - """ - # Import here to avoid circular imports + """Build SQLAlchemy URI and kwargs before engine creation.""" from superset.extensions import security_manager from superset.utils.feature_flag_manager import FeatureFlagManager uri = make_url_safe(database.sqlalchemy_uri_decrypted) - extra = database.get_extra(source) - # Make a copy to avoid mutating the original extra dict kwargs = dict(extra.get("engine_params", {})) + kwargs["poolclass"] = pool.NullPool - # get pool class - if self.mode == EngineModes.NEW or "poolclass" not in kwargs: - kwargs["poolclass"] = pool.NullPool - else: - pools = { - "queue": pool.QueuePool, - "singleton": pool.SingletonThreadPool, - "assertion": pool.AssertionPool, - "null": pool.NullPool, - "static": pool.StaticPool, - } - pool_name = kwargs["poolclass"] - if isinstance(pool_name, str): - kwargs["poolclass"] = pools.get(pool_name, pool.QueuePool) - - # update URI for specific catalog/schema connect_args = kwargs.setdefault("connect_args", {}) uri, connect_args = database.db_engine_spec.adjust_engine_params( uri, @@ -348,9 +114,7 @@ class EngineManager: schema, ) - # get effective username username = database.get_effective_user(uri) - feature_flag_manager = FeatureFlagManager() if username and feature_flag_manager.is_feature_enabled( "IMPERSONATE_WITH_EMAIL_PREFIX" @@ -359,10 +123,8 @@ class EngineManager: if user and user.email and "@" in user.email: username = user.email.split("@")[0] - # update URI/kwargs for user impersonation if database.impersonate_user: oauth2_config = database.get_oauth2_config() - # Import here to avoid circular imports from superset.utils.oauth2 import get_oauth2_access_token access_token = ( @@ -384,15 +146,10 @@ class EngineManager: kwargs, ) - # update kwargs from params stored encrupted at rest database.update_params_from_encrypted_extra(kwargs) - # mutate URI if self.db_connection_mutator: source = source or get_query_source_from_request() - # Import here to avoid circular imports - from superset.extensions import security_manager - uri, kwargs = self.db_connection_mutator( uri, kwargs, @@ -401,111 +158,27 @@ class EngineManager: source, ) - # validate final URI database.db_engine_spec.validate_database_uri(uri) - return uri, kwargs def _create_engine( self, database: "Database", - catalog: str | None, - schema: str | None, - source: QuerySource | None, - user_id: int | None, + uri: URL, + kwargs: dict[str, Any], ) -> Engine: - """ - Create the actual engine. - - This should be the only place in Superset where a SQLAlchemy engine is created, - """ - uri, kwargs = self._get_engine_args( - database, - catalog, - schema, - source, - user_id, - ) - - if database.ssh_tunnel: - tunnel = self._get_tunnel(database.ssh_tunnel, uri) - uri = uri.set( - host=tunnel.local_bind_address[0], - port=tunnel.local_bind_port, - ) - try: - engine = create_engine(uri, **kwargs) + return create_engine(uri, **kwargs) except Exception as ex: raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex - return engine - - def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: - tunnel_key = self._get_tunnel_key(ssh_tunnel, uri) - - tunnel = self._tunnels.get(tunnel_key) - if tunnel is not None and tunnel.is_active: - return tunnel - - lock = self._tunnel_locks.get_lock(tunnel_key) - with lock: - # Double-check inside the lock - tunnel = self._tunnels.get(tunnel_key) - if tunnel is not None and tunnel.is_active: - return tunnel - - # Create or replace tunnel - return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel) - - def _replace_tunnel( - self, - tunnel_key: str, - ssh_tunnel: "SSHTunnel", - uri: URL, - old_tunnel: SSHTunnelForwarder | None, - ) -> SSHTunnelForwarder: - """ - Replace tunnel with proper cleanup. - - This function assumes caller holds lock. - """ - if old_tunnel: - try: - old_tunnel.stop() - except Exception: - logger.exception("Error stopping old tunnel") - - try: - new_tunnel = self._create_tunnel(ssh_tunnel, uri) - self._tunnels[tunnel_key] = new_tunnel - except Exception: - # Remove failed tunnel from cache - self._tunnels.pop(tunnel_key, None) - logger.exception("Failed to create tunnel") - raise - - return new_tunnel - - def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey: - """ - Generate a cache key for the SSH tunnel. - - :returns: 32-character hex string - """ - tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri) - return _generate_cache_key(tunnel_kwargs) - def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri) - # Use open_tunnel which handles debug_level properly tunnel = sshtunnel.open_tunnel(**kwargs) tunnel.start() - return tunnel def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]: - # Import here to avoid circular imports from superset.utils.ssh_tunnel import get_default_port backend = uri.get_backend_name() @@ -518,7 +191,6 @@ class EngineManager: "ssh_username": ssh_tunnel.username, "remote_bind_address": (uri.host, port), "local_bind_address": (self.local_bind_address,), - "debug_level": logging.getLogger("flask_appbuilder").level, } if ssh_tunnel.password: @@ -531,107 +203,4 @@ class EngineManager: ) kwargs["ssh_pkey"] = private_key - if self.mode == EngineModes.NEW: - kwargs["set_keepalive"] = 0 # disable keepalive for one-time tunnels - return kwargs - - def start_cleanup_thread(self) -> None: - """ - Start the background cleanup thread. - - The thread will periodically clean up abandoned locks at the configured - interval. This is safe to call multiple times - subsequent calls are no-ops. - """ - with self._cleanup_thread_lock: - if self._cleanup_thread is None or not self._cleanup_thread.is_alive(): - self._cleanup_stop_event.clear() - self._cleanup_thread = threading.Thread( - target=self._cleanup_worker, - name=f"EngineManager-cleanup-{id(self)}", - daemon=True, - ) - self._cleanup_thread.start() - logger.info( - "Started cleanup thread with %ds interval", - self.cleanup_interval.total_seconds(), - ) - - def stop_cleanup_thread(self) -> None: - """ - Stop the background cleanup thread gracefully. - - This will signal the thread to stop and wait for it to finish. - Safe to call even if no thread is running. - """ - with self._cleanup_thread_lock: - if self._cleanup_thread is not None and self._cleanup_thread.is_alive(): - self._cleanup_stop_event.set() - self._cleanup_thread.join(timeout=5.0) # 5 second timeout - if self._cleanup_thread.is_alive(): - logger.warning("Cleanup thread did not stop within timeout") - else: - logger.info("Cleanup thread stopped") - self._cleanup_thread = None - - def _cleanup_worker(self) -> None: - """ - Background thread worker that periodically cleans up abandoned locks. - """ - while not self._cleanup_stop_event.is_set(): - try: - self._cleanup_abandoned_locks() - except Exception: - logger.exception("Error during background cleanup") - - # Use wait() instead of sleep() to allow for immediate shutdown - if self._cleanup_stop_event.wait( - timeout=self.cleanup_interval.total_seconds() - ): - break # Stop event was set - - def cleanup(self) -> None: - """ - Public method to manually trigger cleanup of abandoned locks. - - This can be called periodically by external systems to prevent - memory leaks from accumulating locks. - """ - self._cleanup_abandoned_locks() - - def _cleanup_abandoned_locks(self) -> None: - """ - Clean up locks for engines and tunnels that no longer exist. - - This prevents memory leaks from accumulating locks when engines/tunnels - are disposed outside of normal cleanup paths. - """ - # Clean up engine locks for inactive engines - active_engine_keys = set(self._engines.keys()) - self._engine_locks.cleanup(active_engine_keys) - - # Clean up tunnel locks for inactive tunnels - active_tunnel_keys = set(self._tunnels.keys()) - self._tunnel_locks.cleanup(active_tunnel_keys) - - # Log for debugging - if active_engine_keys or active_tunnel_keys: - logger.debug( - "EngineManager resources - Engines: %d, Tunnels: %d", - len(active_engine_keys), - len(active_tunnel_keys), - ) - - def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None: - @event.listens_for(engine, "engine_disposed") - def on_engine_disposed(engine_instance: Engine) -> None: - try: - # Remove engine from cache - no per-key locks to clean up anymore - if self._engines.pop(engine_key, None): - # Log only first 8 chars of hash for safety - # (still enough for debugging, but doesn't expose full key) - log_key = engine_key[:8] + "..." - logger.info("Engine disposed and removed from cache: %s", log_key) - except Exception as ex: - logger.error("Error during engine disposal cleanup: %s", str(ex)) - # Don't log engine_key to avoid exposing credential hash diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py index f81afc28242..d775d071f16 100644 --- a/superset/extensions/engine_manager.py +++ b/superset/extensions/engine_manager.py @@ -15,13 +15,12 @@ # specific language governing permissions and limitations # under the License. -import atexit import logging from datetime import timedelta from flask import Flask -from superset.engines.manager import EngineManager, EngineModes +from superset.engines.manager import EngineManager logger = logging.getLogger(__name__) @@ -29,10 +28,6 @@ logger = logging.getLogger(__name__) class EngineManagerExtension: """ Flask extension for managing SQLAlchemy engines in Superset. - - This extension creates and configures an EngineManager instance based on - Flask configuration, handling startup and shutdown of background cleanup - threads as needed. """ def __init__(self) -> None: @@ -44,47 +39,23 @@ class EngineManagerExtension: """ engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"] db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"] - mode = app.config["ENGINE_MANAGER_MODE"] - cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"] local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"]) ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]) - auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"] - - # Stop any existing manager's cleanup thread before creating a new one - if self.engine_manager: - self.engine_manager.stop_cleanup_thread() # Create the engine manager self.engine_manager = EngineManager( engine_context_manager, db_connection_mutator, - mode, - cleanup_interval, local_bind_address, tunnel_timeout, ssh_timeout, ) - # Start cleanup thread if requested and in SINGLETON mode - if auto_start_cleanup and mode == EngineModes.SINGLETON: - self.engine_manager.start_cleanup_thread() - logger.info("Started EngineManager cleanup thread") - - # Register shutdown handler - def shutdown_engine_manager(exc: BaseException | None = None) -> None: - if self.engine_manager: - self.engine_manager.stop_cleanup_thread() - - app.teardown_appcontext(shutdown_engine_manager) - - # Register with atexit for clean shutdown - atexit.register(shutdown_engine_manager) - logger.info( - "Initialized EngineManager with mode=%s, cleanup_interval=%ds", - mode, - cleanup_interval.total_seconds(), + "Initialized EngineManager with tunnel_timeout=%s, ssh_timeout=%s", + tunnel_timeout.total_seconds(), + ssh_timeout.total_seconds(), ) @property diff --git a/superset/models/core.py b/superset/models/core.py index f0f85b32cec..59626e083a4 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -429,8 +429,8 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: Context manager for a SQLAlchemy engine. This method will return a context manager for a SQLAlchemy engine. The engine - manager handles connection pooling, SSH tunnels, and other connection details - based on the configured mode (NEW or SINGLETON). + manager handles engine creation, SSH tunnels, and connection details in a + centralized place. """ # Import here to avoid circular imports from superset.extensions import engine_manager_extension diff --git a/tests/unit_tests/engines/manager_test.py b/tests/unit_tests/engines/manager_test.py index e0abf88f57e..11382b4b37f 100644 --- a/tests/unit_tests/engines/manager_test.py +++ b/tests/unit_tests/engines/manager_test.py @@ -15,513 +15,95 @@ # specific language governing permissions and limitations # under the License. -"""Unit tests for EngineManager.""" - -import threading -from collections.abc import Iterator +from contextlib import contextmanager from unittest.mock import MagicMock, patch import pytest +from sqlalchemy import pool +from sqlalchemy.engine.url import make_url -from superset.engines.manager import _LockManager, EngineManager, EngineModes +from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError +from superset.engines.manager import EngineManager -class TestLockManager: - """Test the _LockManager class.""" - - def test_get_lock_creates_new_lock(self): - """Test that get_lock creates a new lock when needed.""" - manager = _LockManager() - lock1 = manager.get_lock("key1") - - assert isinstance(lock1, type(threading.RLock())) - assert lock1 is manager.get_lock("key1") # Same lock returned - - def test_get_lock_different_keys_different_locks(self): - """Test that different keys get different locks.""" - manager = _LockManager() - lock1 = manager.get_lock("key1") - lock2 = manager.get_lock("key2") - - assert lock1 is not lock2 - - def test_cleanup_removes_unused_locks(self): - """Test that cleanup removes locks for inactive keys.""" - manager = _LockManager() - - # Create locks - _ = manager.get_lock("key1") # noqa: F841 - lock2 = manager.get_lock("key2") - - # Cleanup with only key1 active - manager.cleanup({"key1"}) - - # key2 lock should be removed - lock3 = manager.get_lock("key2") - assert lock3 is not lock2 # New lock created - - def test_concurrent_lock_creation(self): - """Test that concurrent lock creation doesn't create duplicates.""" - manager = _LockManager() - locks_created = [] - exceptions = [] - - def create_lock(): - try: - lock = manager.get_lock("concurrent_key") - locks_created.append(lock) - except Exception as e: - exceptions.append(e) - - # Create multiple threads trying to get the same lock - threads = [threading.Thread(target=create_lock) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(exceptions) == 0 - assert len(locks_created) == 10 - - # All should be the same lock - first_lock = locks_created[0] - for lock in locks_created[1:]: - assert lock is first_lock - - -class TestEngineManager: - """Test the EngineManager class.""" - - @pytest.fixture - def engine_manager(self): - """Create a mock EngineManager instance.""" - from contextlib import contextmanager - - @contextmanager - def dummy_context_manager( - database: MagicMock, catalog: str | None, schema: str | None - ) -> Iterator[None]: - yield - - return EngineManager(engine_context_manager=dummy_context_manager) - - @pytest.fixture - def mock_database(self): - """Create a mock database.""" - database = MagicMock() - database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test" - database.get_extra.return_value = {"engine_params": {}} - database.get_effective_user.return_value = "test_user" - database.impersonate_user = False - database.update_params_from_encrypted_extra = MagicMock() - database.db_engine_spec = MagicMock() - database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {}) - database.db_engine_spec.impersonate_user = MagicMock( - return_value=(MagicMock(), {}) - ) - database.db_engine_spec.validate_database_uri = MagicMock() - database.ssh_tunnel = None - return database - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_get_engine_new_mode( - self, mock_make_url, mock_create_engine, engine_manager, mock_database +@pytest.fixture +def engine_manager() -> EngineManager: + @contextmanager + def dummy_context_manager( + database: MagicMock, + catalog: str | None, + schema: str | None, ): - """Test getting an engine in NEW mode (no caching).""" - engine_manager.mode = EngineModes.NEW - - mock_make_url.return_value = MagicMock() - mock_engine1 = MagicMock() - mock_engine2 = MagicMock() - mock_create_engine.side_effect = [mock_engine1, mock_engine2] - - result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None) - - assert result is mock_engine1 - mock_create_engine.assert_called_once() - - # Calling again should create a new engine (no caching) - mock_create_engine.reset_mock() - result2 = engine_manager._get_engine(mock_database, "catalog2", "schema2", None) - - assert result2 is mock_engine2 # Different engine - mock_create_engine.assert_called_once() - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_get_engine_singleton_mode_caching( - self, mock_make_url, mock_create_engine, engine_manager, mock_database - ): - """Test that engines are cached in SINGLETON mode.""" - engine_manager.mode = EngineModes.SINGLETON - - # Use a real engine instead of MagicMock to avoid event listener issues - from sqlalchemy import create_engine - from sqlalchemy.pool import StaticPool - - real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool) - mock_create_engine.return_value = real_engine - mock_make_url.return_value = real_engine - - # Call twice with same params - should be cached - result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None) - result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None) - - assert result1 is result2 # Same engine returned (cached) - mock_create_engine.assert_called_once() # Only created once - - # Call with different params - should create new engine - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_concurrent_engine_creation( - self, mock_make_url, mock_create_engine, engine_manager, mock_database - ): - """Test concurrent engine creation doesn't create duplicates.""" - engine_manager.mode = EngineModes.SINGLETON - - # Use a real engine to avoid event listener issues with MagicMock - from sqlalchemy import create_engine - from sqlalchemy.pool import StaticPool - - real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool) - mock_make_url.return_value = real_engine - - create_count = [0] - - def counting_create_engine(*args, **kwargs): - create_count[0] += 1 - return real_engine - - mock_create_engine.side_effect = counting_create_engine - - results = [] - exceptions = [] - - def get_engine_thread(): - try: - engine = engine_manager._get_engine( - mock_database, "catalog1", "schema1", None - ) - results.append(engine) - except Exception as e: - exceptions.append(e) - - # Run multiple threads - threads = [threading.Thread(target=get_engine_thread) for _ in range(10)] - for t in threads: - t.start() - for t in threads: - t.join() - - assert len(exceptions) == 0 - assert len(results) == 10 - assert create_count[0] == 1 # Engine created only once - - # All results should be the same engine - for engine in results: - assert engine is real_engine - - @patch("superset.engines.manager.sshtunnel.open_tunnel") - def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager): - """Test SSH tunnel creation and caching.""" - ssh_tunnel = MagicMock() - ssh_tunnel.server_address = "ssh.example.com" - ssh_tunnel.server_port = 22 - ssh_tunnel.username = "ssh_user" - ssh_tunnel.password = "ssh_pass" # noqa: S105 - ssh_tunnel.private_key = None - ssh_tunnel.private_key_password = None - - tunnel_instance = MagicMock() - tunnel_instance.is_active = True - tunnel_instance.local_bind_address = ("127.0.0.1", 12345) - mock_open_tunnel.return_value = tunnel_instance - - uri = MagicMock() - uri.host = "db.example.com" - uri.port = 5432 - uri.get_backend_name.return_value = "postgresql" - - result = engine_manager._get_tunnel(ssh_tunnel, uri) - - assert result is tunnel_instance - mock_open_tunnel.assert_called_once() - tunnel_instance.start.assert_called_once() - - # Getting same tunnel again should return cached version - mock_open_tunnel.reset_mock() - result2 = engine_manager._get_tunnel(ssh_tunnel, uri) - - assert result2 is tunnel_instance - mock_open_tunnel.assert_not_called() - - @patch("superset.engines.manager.sshtunnel.open_tunnel") - def test_ssh_tunnel_recreation_when_inactive( - self, mock_open_tunnel, engine_manager - ): - """Test that inactive tunnels are replaced.""" - ssh_tunnel = MagicMock() - ssh_tunnel.server_address = "ssh.example.com" - ssh_tunnel.server_port = 22 - ssh_tunnel.username = "ssh_user" - ssh_tunnel.password = "ssh_pass" # noqa: S105 - ssh_tunnel.private_key = None - ssh_tunnel.private_key_password = None - - # First tunnel is inactive - inactive_tunnel = MagicMock() - inactive_tunnel.is_active = False - inactive_tunnel.local_bind_address = ("127.0.0.1", 12345) - - # Second tunnel is active - active_tunnel = MagicMock() - active_tunnel.is_active = True - active_tunnel.local_bind_address = ("127.0.0.1", 23456) - - mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel] - - uri = MagicMock() - uri.host = "db.example.com" - uri.port = 5432 - uri.get_backend_name.return_value = "postgresql" - - # First call creates inactive tunnel - result1 = engine_manager._get_tunnel(ssh_tunnel, uri) - assert result1 is inactive_tunnel - - # Second call should create new tunnel since first is inactive - result2 = engine_manager._get_tunnel(ssh_tunnel, uri) - assert result2 is active_tunnel - assert mock_open_tunnel.call_count == 2 - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_get_engine_args_basic( - self, mock_make_url, mock_create_engine, engine_manager - ): - """Test _get_engine_args returns correct URI and kwargs.""" - from sqlalchemy.engine.url import make_url - - from superset.engines.manager import EngineModes - - engine_manager.mode = EngineModes.NEW - - mock_uri = make_url("trino://") - mock_make_url.return_value = mock_uri - - database = MagicMock() - database.id = 1 - database.sqlalchemy_uri_decrypted = "trino://" - database.get_extra.return_value = { - "engine_params": {}, - "connect_args": {"source": "Apache Superset"}, - } - database.get_effective_user.return_value = "alice" - database.impersonate_user = False - database.update_params_from_encrypted_extra = MagicMock() - database.db_engine_spec = MagicMock() - database.db_engine_spec.adjust_engine_params.return_value = ( - mock_uri, - {"source": "Apache Superset"}, - ) - database.db_engine_spec.validate_database_uri = MagicMock() - - uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) - - assert str(uri) == "trino://" - assert "connect_args" in database.get_extra.return_value - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_get_engine_args_user_impersonation( - self, mock_make_url, mock_create_engine, engine_manager - ): - """Test user impersonation in _get_engine_args.""" - from sqlalchemy.engine.url import make_url - - from superset.engines.manager import EngineModes - - engine_manager.mode = EngineModes.NEW - - mock_uri = make_url("trino://") - mock_make_url.return_value = mock_uri - - database = MagicMock() - database.id = 1 - database.sqlalchemy_uri_decrypted = "trino://" - database.get_extra.return_value = { - "engine_params": {}, - "connect_args": {"source": "Apache Superset"}, - } - database.get_effective_user.return_value = "alice" - database.impersonate_user = True - database.get_oauth2_config.return_value = None - database.update_params_from_encrypted_extra = MagicMock() - database.db_engine_spec = MagicMock() - database.db_engine_spec.adjust_engine_params.return_value = ( - mock_uri, - {"source": "Apache Superset"}, - ) - database.db_engine_spec.impersonate_user.return_value = ( - mock_uri, - {"connect_args": {"user": "alice", "source": "Apache Superset"}}, - ) - database.db_engine_spec.validate_database_uri = MagicMock() - - uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None) - - # Verify impersonate_user was called - database.db_engine_spec.impersonate_user.assert_called_once() - call_args = database.db_engine_spec.impersonate_user.call_args - assert call_args[0][0] is database # database - assert call_args[0][1] == "alice" # username - assert call_args[0][2] is None # access_token (no OAuth2) - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_get_engine_args_user_impersonation_email_prefix( - self, - mock_make_url, - mock_create_engine, - engine_manager, - ): - """Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag.""" - from sqlalchemy.engine.url import make_url - - from superset.engines.manager import EngineModes - - engine_manager.mode = EngineModes.NEW - - mock_uri = make_url("trino://") - mock_make_url.return_value = mock_uri - - # Mock user with email - mock_user = MagicMock() - mock_user.email = "alice.doe@example.org" - - database = MagicMock() - database.id = 1 - database.sqlalchemy_uri_decrypted = "trino://" - database.get_extra.return_value = { - "engine_params": {}, - "connect_args": {"source": "Apache Superset"}, - } - database.get_effective_user.return_value = "alice" - database.impersonate_user = True - database.get_oauth2_config.return_value = None - database.update_params_from_encrypted_extra = MagicMock() - database.db_engine_spec = MagicMock() - database.db_engine_spec.adjust_engine_params.return_value = ( - mock_uri, - {"source": "Apache Superset"}, - ) - database.db_engine_spec.impersonate_user.return_value = ( - mock_uri, - {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}}, - ) - database.db_engine_spec.validate_database_uri = MagicMock() - - with ( - patch( - "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled", - return_value=True, - ), - patch( - "superset.extensions.security_manager.find_user", - return_value=mock_user, - ), - ): - uri, kwargs = engine_manager._get_engine_args( - database, None, None, None, None - ) - - # Verify impersonate_user was called with the email prefix - database.db_engine_spec.impersonate_user.assert_called_once() - call_args = database.db_engine_spec.impersonate_user.call_args - assert call_args[0][1] == "alice.doe" # username from email prefix - - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_engine_context_manager_called( - self, mock_make_url, mock_create_engine, engine_manager, mock_database - ): - """Test that the engine context manager is properly called.""" - from sqlalchemy.engine.url import make_url - - mock_uri = make_url("trino://") - mock_make_url.return_value = mock_uri - mock_engine = MagicMock() - mock_create_engine.return_value = mock_engine - - # Track context manager calls - context_manager_calls = [] - - def tracking_context_manager(database, catalog, schema): - from contextlib import contextmanager - - @contextmanager - def inner(): - context_manager_calls.append(("enter", database, catalog, schema)) - yield - context_manager_calls.append(("exit", database, catalog, schema)) - - return inner() - - engine_manager.engine_context_manager = tracking_context_manager - - with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): - pass - - assert len(context_manager_calls) == 2 - assert context_manager_calls[0][0] == "enter" - assert context_manager_calls[0][1] is mock_database - assert context_manager_calls[0][2] == "catalog1" - assert context_manager_calls[0][3] == "schema1" - assert context_manager_calls[1][0] == "exit" - - @patch("superset.utils.oauth2.check_for_oauth2") - @patch("superset.engines.manager.create_engine") - @patch("superset.engines.manager.make_url_safe") - def test_engine_oauth2_error_handling( - self, - mock_make_url, - mock_create_engine, - mock_check_for_oauth2, - engine_manager, - mock_database, - ): - """Test that OAuth2 errors are properly propagated from get_engine.""" - from contextlib import contextmanager - - from sqlalchemy.engine.url import make_url - - mock_uri = make_url("trino://") - mock_make_url.return_value = mock_uri - - # Simulate OAuth2 error during engine creation - class OAuth2TestError(Exception): - pass - - oauth_error = OAuth2TestError("OAuth2 required") - mock_create_engine.side_effect = oauth_error - - # Make get_dbapi_mapped_exception return the original exception - mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = ( - oauth_error - ) - - # Mock check_for_oauth2 to re-raise the exception - @contextmanager - def mock_oauth2_context(database): - try: - yield - except OAuth2TestError: - raise - - mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database) - - with pytest.raises(OAuth2TestError, match="OAuth2 required"): - with engine_manager.get_engine(mock_database, "catalog1", "schema1", None): - pass + yield + + return EngineManager(engine_context_manager=dummy_context_manager) + + +@pytest.fixture +def mock_database() -> MagicMock: + database = MagicMock() + database.id = 1 + database.sqlalchemy_uri_decrypted = "trino://" + database.get_extra.return_value = {"engine_params": {"poolclass": "queue"}} + database.get_effective_user.return_value = "alice" + database.impersonate_user = False + database.update_params_from_encrypted_extra = MagicMock() + database.db_engine_spec = MagicMock() + database.db_engine_spec.adjust_engine_params.return_value = ( + make_url("trino://"), + {"source": "Apache Superset"}, + ) + database.db_engine_spec.validate_database_uri = MagicMock() + return database + + +@patch("superset.engines.manager.make_url_safe") +def test_get_engine_args_uses_null_pool( + mock_make_url: MagicMock, + engine_manager: EngineManager, + mock_database: MagicMock, +) -> None: + mock_make_url.return_value = make_url("trino://") + + _, kwargs = engine_manager._get_engine_args(mock_database, None, None, None, None) + + assert kwargs["poolclass"] is pool.NullPool + + +@patch("superset.engines.manager.make_url_safe") +def test_get_engine_args_with_impersonation( + mock_make_url: MagicMock, + engine_manager: EngineManager, + mock_database: MagicMock, +) -> None: + mock_make_url.return_value = make_url("trino://") + mock_database.impersonate_user = True + mock_database.get_oauth2_config.return_value = None + mock_database.db_engine_spec.impersonate_user.return_value = ( + make_url("trino://"), + {"connect_args": {"user": "alice"}, "poolclass": pool.NullPool}, + ) + + engine_manager._get_engine_args(mock_database, None, None, None, None) + + mock_database.db_engine_spec.impersonate_user.assert_called_once() + + +def test_get_tunnel_kwargs_requires_database_port( + engine_manager: EngineManager, +) -> None: + ssh_tunnel = MagicMock() + ssh_tunnel.server_address = "ssh.example.com" + ssh_tunnel.server_port = 22 + ssh_tunnel.username = "ssh_user" + ssh_tunnel.password = None + ssh_tunnel.private_key = None + ssh_tunnel.private_key_password = None + + uri = MagicMock() + uri.port = None + uri.get_backend_name.return_value = "unknown" + + with patch("superset.utils.ssh_tunnel.get_default_port", return_value=None): + with pytest.raises(SSHTunnelDatabasePortError): + engine_manager._get_tunnel_kwargs(ssh_tunnel, uri) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index e2190164e95..12da1b489a2 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -689,13 +689,15 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None: encrypted_extra=json.dumps(oauth2_client_info), ) database.db_engine_spec.oauth2_exception = OAuth2Error - _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine") - _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required") + create_engine = mocker.patch("superset.engines.manager.create_engine") + create_engine.side_effect = OAuth2Error("OAuth2 required") with pytest.raises(OAuth2RedirectError) as excinfo: with database.get_raw_connection() as conn: conn.cursor() assert str(excinfo.value) == "You don't have permission to access the data." + + def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None: """ Test that we can start OAuth2 from `raw_connection()` errors.