refactor: keep engine manager focused on engine creation

Beto Dealmeida
2026-05-05 18:22:36 -04:00
parent b7b59dfb8a
commit 3aad565eab
6 changed files with 124 additions and 1017 deletions

View File

@@ -52,7 +52,6 @@ from superset.advanced_data_type.plugins.internet_address import internet_addres
from superset.advanced_data_type.plugins.internet_port import internet_port
from superset.advanced_data_type.types import AdvancedDataType
from superset.constants import CHANGE_ME_SECRET_KEY
from superset.engines.manager import EngineModes
from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import JsonKeyValueCodec
from superset.stats_logger import DummyStatsLogger
@@ -265,22 +264,6 @@ SQLALCHEMY_ENGINE_OPTIONS = {}
# SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
SQLALCHEMY_CUSTOM_PASSWORD_STORE = None
# ---------------------------------------------------------
# Engine Manager Configuration
# ---------------------------------------------------------
# Engine manager mode: "NEW" creates a new engine for every connection (default),
# "SINGLETON" reuses engines with connection pooling
ENGINE_MANAGER_MODE = EngineModes.NEW
# Cleanup interval for abandoned locks (default: 5 minutes)
ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
# Automatically start cleanup thread for SINGLETON mode (default: True)
ENGINE_MANAGER_AUTO_START_CLEANUP = True
# ---------------------------------------------------------
#
# The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models
# which include sensitive fields that should be app-encrypted BEFORE sending

View File

@@ -15,10 +15,6 @@
# specific language governing permissions and limitations
# under the License.
import enum
import hashlib
import logging
import threading
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
@@ -26,7 +22,7 @@ from typing import Any, Iterator, TYPE_CHECKING
import sshtunnel
from paramiko import RSAKey
from sqlalchemy import create_engine, event, pool
from sqlalchemy import create_engine, pool
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sshtunnel import SSHTunnelForwarder
@@ -41,158 +37,24 @@ if TYPE_CHECKING:
from superset.models.core import Database
logger = logging.getLogger(__name__)
class _LockManager:
"""
Manages per-key locks safely without defaultdict race conditions.
This class provides a thread-safe way to create and manage locks for specific keys,
avoiding the race conditions that occur when using defaultdict with threading.Lock.
The implementation uses a two-level locking strategy:
1. A meta-lock to protect the lock dictionary itself
2. Per-key locks to protect specific resources
This ensures that:
- Different keys can be locked concurrently (scalability)
- Lock creation is thread-safe (no race conditions)
- The same key always gets the same lock instance
"""
def __init__(self) -> None:
self._locks: dict[str, threading.RLock] = {}
self._meta_lock = threading.Lock()
def get_lock(self, key: str) -> threading.RLock:
"""
Get or create a lock for the given key.
This method uses double-checked locking to ensure thread safety:
1. First check without lock (fast path)
2. Acquire meta-lock if needed
3. Double-check inside the lock to prevent race conditions
This approach minimizes lock contention while ensuring correctness.
:param key: The key to get a lock for
:returns: An RLock instance for the given key
"""
if lock := self._locks.get(key):
return lock
with self._meta_lock:
# Double-check inside the lock
lock = self._locks.get(key)
if lock is None:
lock = threading.RLock()
self._locks[key] = lock
return lock
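For contrast, here is a minimal standalone sketch of the defaultdict pattern this class avoids (an illustration, not code from the diff):

import threading
from collections import defaultdict

# Racy pattern avoided by _LockManager: if two threads miss the same key
# concurrently, each can run the Lock() factory, and one store overwrites
# the other, so the threads end up holding different lock objects for the
# same key and no longer exclude each other.
locks: dict[str, threading.Lock] = defaultdict(threading.Lock)
lock = locks["some-key"]  # implicit create-on-miss is not atomic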
def cleanup(self, active_keys: set[str]) -> None:
"""
Remove locks for keys that are no longer in use.
This prevents memory leaks from accumulating locks for resources
that have been disposed.
:param active_keys: Set of keys that are still active
"""
with self._meta_lock:
# Find locks to remove
locks_to_remove = self._locks.keys() - active_keys
for key in locks_to_remove:
self._locks.pop(key, None)
EngineKey = str
TunnelKey = str
def _generate_cache_key(*args: Any) -> str:
"""
Generate a deterministic cache key from arbitrary arguments.
Uses repr() for serialization and SHA-256 for hashing. The resulting key
is a 32-character hex string that:
1. Is deterministic for the same inputs
2. Does not expose sensitive data (everything is hashed)
3. Has sufficient entropy to avoid collisions
:param args: Arguments to include in the cache key
:returns: 32-character hex string
"""
# Use repr() which works with most Python objects and is deterministic
serialized = repr(args).encode("utf-8")
return hashlib.sha256(serialized).hexdigest()[:32]
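A quick standalone illustration of the properties listed above (not part of the diff):

# Same inputs always produce the same key, and because the output is a hex
# digest, raw values such as passwords can never appear in it.
key = _generate_cache_key(1, "main", "postgresql://user:secret@host/db")
assert key == _generate_cache_key(1, "main", "postgresql://user:secret@host/db")
assert len(key) == 32 and "secret" not in key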
class EngineModes(enum.Enum):
# reuse existing engine if available, otherwise create a new one; this mode should
# have a connection pool configured in the database
SINGLETON = enum.auto()
# always create a new engine for every connection; this mode will use a NullPool
# and is the default behavior for Superset
NEW = enum.auto()
class EngineManager:
"""
A manager for SQLAlchemy engines.
This class handles the creation and management of SQLAlchemy engines, allowing them
to be configured with connection pools and reused across requests. The default mode
is the original behavior for Superset, where we create a new engine for every
connection, using a NullPool. The `SINGLETON` mode, on the other hand, allows
engines to be reused and the pool to be configured through the database
settings.
"""
"""Centralized SQLAlchemy engine creation for Superset."""
def __init__(
self,
engine_context_manager: EngineContextManager,
db_connection_mutator: DBConnectionMutator | None = None,
mode: EngineModes = EngineModes.NEW,
cleanup_interval: timedelta = timedelta(minutes=5),
local_bind_address: str = "127.0.0.1",
tunnel_timeout: timedelta = timedelta(seconds=30),
ssh_timeout: timedelta = timedelta(seconds=1),
) -> None:
self.engine_context_manager = engine_context_manager
self.db_connection_mutator = db_connection_mutator
self.mode = mode
self.cleanup_interval = cleanup_interval
self.local_bind_address = local_bind_address
sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()
self._engines: dict[EngineKey, Engine] = {}
self._engine_locks = _LockManager()
self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
self._tunnel_locks = _LockManager()
# Background cleanup thread management
self._cleanup_thread: threading.Thread | None = None
self._cleanup_stop_event = threading.Event()
self._cleanup_thread_lock = threading.Lock()
def __del__(self) -> None:
"""
Ensure cleanup thread is stopped when the manager is destroyed.
"""
try:
self.stop_cleanup_thread()
except Exception as ex:
# Avoid exceptions during garbage collection, but log if possible
try:
logger.warning("Error stopping cleanup thread: %s", ex)
except Exception: # noqa: S110
# If logging fails during destruction, we can't do anything
pass
@contextmanager
def get_engine(
self,
@@ -201,103 +63,31 @@ class EngineManager:
schema: str | None,
source: QuerySource | None,
) -> Iterator[Engine]:
"""
Context manager to get a SQLAlchemy engine.
"""
# users can wrap engine usage in their own context manager to add custom
# behavior
"""Context manager to get a SQLAlchemy engine."""
from superset.utils.oauth2 import check_for_oauth2
with self.engine_context_manager(database, catalog, schema):
# we need to check for errors indicating that OAuth2 is needed, and
# raise the proper exception so that it starts the authentication flow
from superset.utils.oauth2 import check_for_oauth2
with check_for_oauth2(database):
yield self._get_engine(database, catalog, schema, source)
uri, kwargs = self._get_engine_args(
database,
catalog,
schema,
source,
get_user_id(),
)
def _get_engine(
self,
database: "Database",
catalog: str | None,
schema: str | None,
source: QuerySource | None,
) -> Engine:
"""
Get a specific engine, or create it if none exists.
"""
source = source or get_query_source_from_request()
user_id = get_user_id()
if self.mode == EngineModes.NEW:
return self._create_engine(
database,
catalog,
schema,
source,
user_id,
)
engine_key = self._get_engine_key(
database,
catalog,
schema,
source,
user_id,
)
if engine := self._engines.get(engine_key):
return engine
lock = self._engine_locks.get_lock(engine_key)
with lock:
# Double-check inside the lock
if engine := self._engines.get(engine_key):
return engine
# Create and cache the engine
engine = self._create_engine(
database,
catalog,
schema,
source,
user_id,
)
self._engines[engine_key] = engine
self._add_disposal_listener(engine, engine_key)
return engine
def _get_engine_key(
self,
database: "Database",
catalog: str | None,
schema: str | None,
source: QuerySource | None,
user_id: int | None,
) -> EngineKey:
"""
Generate a cache key for the engine.
The key is a hash of all parameters that affect the engine, ensuring
proper cache isolation without exposing sensitive data.
:returns: 32-character hex string
"""
uri, kwargs = self._get_engine_args(
database,
catalog,
schema,
source,
user_id,
)
return _generate_cache_key(
database.id,
catalog,
schema,
str(uri),
source,
user_id,
kwargs,
)
if database.ssh_tunnel:
tunnel = self._create_tunnel(database.ssh_tunnel, uri)
try:
uri = uri.set(
host=tunnel.local_bind_address[0],
port=tunnel.local_bind_port,
)
yield self._create_engine(database, uri, kwargs)
finally:
tunnel.stop()
else:
yield self._create_engine(database, uri, kwargs)
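After this refactor the flow is linear: build the URI and kwargs, optionally reroute through an SSH tunnel, and yield a freshly created engine. A hedged usage sketch, assuming an `engine_manager` and `database` obtained from the application context:

from sqlalchemy import text

# Sketch of the typical caller flow after the refactor.
with engine_manager.get_engine(database, catalog=None, schema=None, source=None) as engine:
    with engine.connect() as connection:
        connection.execute(text("SELECT 1"))
# NullPool means the underlying connection is closed, not pooled, on exit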
def _get_engine_args(
self,
@@ -307,39 +97,15 @@ class EngineManager:
source: QuerySource | None,
user_id: int | None,
) -> tuple[URL, dict[str, Any]]:
"""
Build the almost final SQLAlchemy URI and engine kwargs.
"Almost" final because we may still need to mutate the URI if an SSH tunnel is
needed, since it needs to connect to the tunnel instead of the original DB. But
that information is only available after the tunnel is created.
"""
# Import here to avoid circular imports
"""Build SQLAlchemy URI and kwargs before engine creation."""
from superset.extensions import security_manager
from superset.utils.feature_flag_manager import FeatureFlagManager
uri = make_url_safe(database.sqlalchemy_uri_decrypted)
extra = database.get_extra(source)
# Make a copy to avoid mutating the original extra dict
kwargs = dict(extra.get("engine_params", {}))
kwargs["poolclass"] = pool.NullPool
# get pool class
if self.mode == EngineModes.NEW or "poolclass" not in kwargs:
kwargs["poolclass"] = pool.NullPool
else:
pools = {
"queue": pool.QueuePool,
"singleton": pool.SingletonThreadPool,
"assertion": pool.AssertionPool,
"null": pool.NullPool,
"static": pool.StaticPool,
}
pool_name = kwargs["poolclass"]
if isinstance(pool_name, str):
kwargs["poolclass"] = pools.get(pool_name, pool.QueuePool)
# update URI for specific catalog/schema
connect_args = kwargs.setdefault("connect_args", {})
uri, connect_args = database.db_engine_spec.adjust_engine_params(
uri,
@@ -348,9 +114,7 @@ class EngineManager:
schema,
)
# get effective username
username = database.get_effective_user(uri)
feature_flag_manager = FeatureFlagManager()
if username and feature_flag_manager.is_feature_enabled(
"IMPERSONATE_WITH_EMAIL_PREFIX"
@@ -359,10 +123,8 @@ class EngineManager:
if user and user.email and "@" in user.email:
username = user.email.split("@")[0]
# update URI/kwargs for user impersonation
if database.impersonate_user:
oauth2_config = database.get_oauth2_config()
# Import here to avoid circular imports
from superset.utils.oauth2 import get_oauth2_access_token
access_token = (
@@ -384,15 +146,10 @@ class EngineManager:
kwargs,
)
# update kwargs from params stored encrypted at rest
database.update_params_from_encrypted_extra(kwargs)
# mutate URI
if self.db_connection_mutator:
source = source or get_query_source_from_request()
# Import here to avoid circular imports
from superset.extensions import security_manager
uri, kwargs = self.db_connection_mutator(
uri,
kwargs,
@@ -401,111 +158,27 @@ class EngineManager:
source,
)
# validate final URI
database.db_engine_spec.validate_database_uri(uri)
return uri, kwargs
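For reference, the `engine_params` merged into `kwargs` above come from the database's `extra` JSON; a sketch with illustrative values:

# Illustrative database "extra" payload; "engine_params" is copied into the
# kwargs above, after which "poolclass" is forced to NullPool by this change.
extra = {
    "engine_params": {
        "connect_args": {"source": "Apache Superset"},
        "isolation_level": "AUTOCOMMIT",
    }
}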
def _create_engine(
self,
database: "Database",
catalog: str | None,
schema: str | None,
source: QuerySource | None,
user_id: int | None,
uri: URL,
kwargs: dict[str, Any],
) -> Engine:
"""
Create the actual engine.
This should be the only place in Superset where a SQLAlchemy engine is created.
"""
uri, kwargs = self._get_engine_args(
database,
catalog,
schema,
source,
user_id,
)
if database.ssh_tunnel:
tunnel = self._get_tunnel(database.ssh_tunnel, uri)
uri = uri.set(
host=tunnel.local_bind_address[0],
port=tunnel.local_bind_port,
)
try:
engine = create_engine(uri, **kwargs)
return create_engine(uri, **kwargs)
except Exception as ex:
raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
return engine
def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)
tunnel = self._tunnels.get(tunnel_key)
if tunnel is not None and tunnel.is_active:
return tunnel
lock = self._tunnel_locks.get_lock(tunnel_key)
with lock:
# Double-check inside the lock
tunnel = self._tunnels.get(tunnel_key)
if tunnel is not None and tunnel.is_active:
return tunnel
# Create or replace tunnel
return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel)
def _replace_tunnel(
self,
tunnel_key: str,
ssh_tunnel: "SSHTunnel",
uri: URL,
old_tunnel: SSHTunnelForwarder | None,
) -> SSHTunnelForwarder:
"""
Replace a tunnel, with proper cleanup of any old one.
This method assumes the caller already holds the per-key lock.
"""
if old_tunnel:
try:
old_tunnel.stop()
except Exception:
logger.exception("Error stopping old tunnel")
try:
new_tunnel = self._create_tunnel(ssh_tunnel, uri)
self._tunnels[tunnel_key] = new_tunnel
except Exception:
# Remove failed tunnel from cache
self._tunnels.pop(tunnel_key, None)
logger.exception("Failed to create tunnel")
raise
return new_tunnel
def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
"""
Generate a cache key for the SSH tunnel.
:returns: 32-character hex string
"""
tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
return _generate_cache_key(tunnel_kwargs)
def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
# Use open_tunnel which handles debug_level properly
tunnel = sshtunnel.open_tunnel(**kwargs)
tunnel.start()
return tunnel
def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
# Import here to avoid circular imports
from superset.utils.ssh_tunnel import get_default_port
backend = uri.get_backend_name()
@@ -518,7 +191,6 @@ class EngineManager:
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (uri.host, port),
"local_bind_address": (self.local_bind_address,),
"debug_level": logging.getLogger("flask_appbuilder").level,
}
if ssh_tunnel.password:
@@ -531,107 +203,4 @@ class EngineManager:
)
kwargs["ssh_pkey"] = private_key
if self.mode == EngineModes.NEW:
kwargs["set_keepalive"] = 0 # disable keepalive for one-time tunnels
return kwargs
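Assembling the pieces from this method, the kwargs handed to `sshtunnel.open_tunnel` end up looking roughly like this for a password-based tunnel (values illustrative, first key name assumed from the sshtunnel API):

# Rough shape of the tunnel kwargs for a password-based tunnel:
kwargs = {
    "ssh_address_or_host": ("ssh.example.com", 22),  # key name assumed
    "ssh_username": "ssh_user",
    "remote_bind_address": ("db.example.com", 5432),
    "local_bind_address": ("127.0.0.1",),
    "ssh_password": "ssh_pass",
}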
def start_cleanup_thread(self) -> None:
"""
Start the background cleanup thread.
The thread will periodically clean up abandoned locks at the configured
interval. This is safe to call multiple times - subsequent calls are no-ops.
"""
with self._cleanup_thread_lock:
if self._cleanup_thread is None or not self._cleanup_thread.is_alive():
self._cleanup_stop_event.clear()
self._cleanup_thread = threading.Thread(
target=self._cleanup_worker,
name=f"EngineManager-cleanup-{id(self)}",
daemon=True,
)
self._cleanup_thread.start()
logger.info(
"Started cleanup thread with %ds interval",
self.cleanup_interval.total_seconds(),
)
def stop_cleanup_thread(self) -> None:
"""
Stop the background cleanup thread gracefully.
This will signal the thread to stop and wait for it to finish.
Safe to call even if no thread is running.
"""
with self._cleanup_thread_lock:
if self._cleanup_thread is not None and self._cleanup_thread.is_alive():
self._cleanup_stop_event.set()
self._cleanup_thread.join(timeout=5.0) # 5 second timeout
if self._cleanup_thread.is_alive():
logger.warning("Cleanup thread did not stop within timeout")
else:
logger.info("Cleanup thread stopped")
self._cleanup_thread = None
def _cleanup_worker(self) -> None:
"""
Background thread worker that periodically cleans up abandoned locks.
"""
while not self._cleanup_stop_event.is_set():
try:
self._cleanup_abandoned_locks()
except Exception:
logger.exception("Error during background cleanup")
# Use wait() instead of sleep() to allow for immediate shutdown
if self._cleanup_stop_event.wait(
timeout=self.cleanup_interval.total_seconds()
):
break # Stop event was set
def cleanup(self) -> None:
"""
Public method to manually trigger cleanup of abandoned locks.
This can be called periodically by external systems to prevent
memory leaks from accumulating locks.
"""
self._cleanup_abandoned_locks()
def _cleanup_abandoned_locks(self) -> None:
"""
Clean up locks for engines and tunnels that no longer exist.
This prevents memory leaks from accumulating locks when engines/tunnels
are disposed outside of normal cleanup paths.
"""
# Clean up engine locks for inactive engines
active_engine_keys = set(self._engines.keys())
self._engine_locks.cleanup(active_engine_keys)
# Clean up tunnel locks for inactive tunnels
active_tunnel_keys = set(self._tunnels.keys())
self._tunnel_locks.cleanup(active_tunnel_keys)
# Log for debugging
if active_engine_keys or active_tunnel_keys:
logger.debug(
"EngineManager resources - Engines: %d, Tunnels: %d",
len(active_engine_keys),
len(active_tunnel_keys),
)
def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None:
@event.listens_for(engine, "engine_disposed")
def on_engine_disposed(engine_instance: Engine) -> None:
try:
# Remove engine from cache - no per-key locks to clean up anymore
if self._engines.pop(engine_key, None):
# Log only first 8 chars of hash for safety
# (still enough for debugging, but doesn't expose full key)
log_key = engine_key[:8] + "..."
logger.info("Engine disposed and removed from cache: %s", log_key)
except Exception as ex:
logger.error("Error during engine disposal cleanup: %s", str(ex))
# Don't log engine_key to avoid exposing credential hash

View File

@@ -15,13 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import atexit
import logging
from datetime import timedelta
from flask import Flask
from superset.engines.manager import EngineManager, EngineModes
from superset.engines.manager import EngineManager
logger = logging.getLogger(__name__)
@@ -29,10 +28,6 @@ logger = logging.getLogger(__name__)
class EngineManagerExtension:
"""
Flask extension for managing SQLAlchemy engines in Superset.
This extension creates and configures an EngineManager instance based on
Flask configuration, handling startup and shutdown of background cleanup
threads as needed.
"""
def __init__(self) -> None:
@@ -44,47 +39,23 @@ class EngineManagerExtension:
"""
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
mode = app.config["ENGINE_MANAGER_MODE"]
cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"]
local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"]
# Stop any existing manager's cleanup thread before creating a new one
if self.engine_manager:
self.engine_manager.stop_cleanup_thread()
# Create the engine manager
self.engine_manager = EngineManager(
engine_context_manager,
db_connection_mutator,
mode,
cleanup_interval,
local_bind_address,
tunnel_timeout,
ssh_timeout,
)
# Start cleanup thread if requested and in SINGLETON mode
if auto_start_cleanup and mode == EngineModes.SINGLETON:
self.engine_manager.start_cleanup_thread()
logger.info("Started EngineManager cleanup thread")
# Register shutdown handler
def shutdown_engine_manager(exc: BaseException | None = None) -> None:
if self.engine_manager:
self.engine_manager.stop_cleanup_thread()
app.teardown_appcontext(shutdown_engine_manager)
# Register with atexit for clean shutdown
atexit.register(shutdown_engine_manager)
logger.info(
"Initialized EngineManager with mode=%s, cleanup_interval=%ds",
mode,
cleanup_interval.total_seconds(),
"Initialized EngineManager with tunnel_timeout=%s, ssh_timeout=%s",
tunnel_timeout.total_seconds(),
ssh_timeout.total_seconds(),
)
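For reference, the Flask config keys the extension still reads; a minimal sketch of `superset_config.py` entries with illustrative values:

# Illustrative superset_config.py entries read by EngineManagerExtension:
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
SSH_TUNNEL_TIMEOUT_SEC = 30  # becomes tunnel_timeout
SSH_TUNNEL_PACKET_TIMEOUT_SEC = 1  # becomes ssh_timeout
# ENGINE_CONTEXT_MANAGER and DB_CONNECTION_MUTATOR remain configurable as before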
@property

View File

@@ -429,8 +429,8 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
Context manager for a SQLAlchemy engine.
This method will return a context manager for a SQLAlchemy engine. The engine
manager handles connection pooling, SSH tunnels, and other connection details
based on the configured mode (NEW or SINGLETON).
manager handles engine creation, SSH tunnels, and connection details in a
centralized place.
"""
# Import here to avoid circular imports
from superset.extensions import engine_manager_extension

View File

@@ -15,513 +15,95 @@
# specific language governing permissions and limitations
# under the License.
"""Unit tests for EngineManager."""
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import pool
from sqlalchemy.engine.url import make_url
from superset.engines.manager import _LockManager, EngineManager, EngineModes
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.engines.manager import EngineManager
class TestLockManager:
"""Test the _LockManager class."""
def test_get_lock_creates_new_lock(self):
"""Test that get_lock creates a new lock when needed."""
manager = _LockManager()
lock1 = manager.get_lock("key1")
assert isinstance(lock1, type(threading.RLock()))
assert lock1 is manager.get_lock("key1") # Same lock returned
def test_get_lock_different_keys_different_locks(self):
"""Test that different keys get different locks."""
manager = _LockManager()
lock1 = manager.get_lock("key1")
lock2 = manager.get_lock("key2")
assert lock1 is not lock2
def test_cleanup_removes_unused_locks(self):
"""Test that cleanup removes locks for inactive keys."""
manager = _LockManager()
# Create locks
_ = manager.get_lock("key1") # noqa: F841
lock2 = manager.get_lock("key2")
# Cleanup with only key1 active
manager.cleanup({"key1"})
# key2 lock should be removed
lock3 = manager.get_lock("key2")
assert lock3 is not lock2 # New lock created
def test_concurrent_lock_creation(self):
"""Test that concurrent lock creation doesn't create duplicates."""
manager = _LockManager()
locks_created = []
exceptions = []
def create_lock():
try:
lock = manager.get_lock("concurrent_key")
locks_created.append(lock)
except Exception as e:
exceptions.append(e)
# Create multiple threads trying to get the same lock
threads = [threading.Thread(target=create_lock) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(exceptions) == 0
assert len(locks_created) == 10
# All should be the same lock
first_lock = locks_created[0]
for lock in locks_created[1:]:
assert lock is first_lock
class TestEngineManager:
"""Test the EngineManager class."""
@pytest.fixture
def engine_manager(self):
"""Create a mock EngineManager instance."""
from contextlib import contextmanager
@contextmanager
def dummy_context_manager(
database: MagicMock, catalog: str | None, schema: str | None
) -> Iterator[None]:
yield
return EngineManager(engine_context_manager=dummy_context_manager)
@pytest.fixture
def mock_database(self):
"""Create a mock database."""
database = MagicMock()
database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
database.get_extra.return_value = {"engine_params": {}}
database.get_effective_user.return_value = "test_user"
database.impersonate_user = False
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
database.db_engine_spec.impersonate_user = MagicMock(
return_value=(MagicMock(), {})
)
database.db_engine_spec.validate_database_uri = MagicMock()
database.ssh_tunnel = None
return database
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_new_mode(
self, mock_make_url, mock_create_engine, engine_manager, mock_database
@pytest.fixture
def engine_manager() -> EngineManager:
@contextmanager
def dummy_context_manager(
database: MagicMock,
catalog: str | None,
schema: str | None,
):
"""Test getting an engine in NEW mode (no caching)."""
engine_manager.mode = EngineModes.NEW
mock_make_url.return_value = MagicMock()
mock_engine1 = MagicMock()
mock_engine2 = MagicMock()
mock_create_engine.side_effect = [mock_engine1, mock_engine2]
result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
assert result is mock_engine1
mock_create_engine.assert_called_once()
# Calling again should create a new engine (no caching)
mock_create_engine.reset_mock()
result2 = engine_manager._get_engine(mock_database, "catalog2", "schema2", None)
assert result2 is mock_engine2 # Different engine
mock_create_engine.assert_called_once()
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_singleton_mode_caching(
self, mock_make_url, mock_create_engine, engine_manager, mock_database
):
"""Test that engines are cached in SINGLETON mode."""
engine_manager.mode = EngineModes.SINGLETON
# Use a real engine instead of MagicMock to avoid event listener issues
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
mock_create_engine.return_value = real_engine
mock_make_url.return_value = real_engine
# Call twice with same params - should be cached
result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
assert result1 is result2 # Same engine returned (cached)
mock_create_engine.assert_called_once() # Only created once
# Call with different params - should create new engine
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_concurrent_engine_creation(
self, mock_make_url, mock_create_engine, engine_manager, mock_database
):
"""Test concurrent engine creation doesn't create duplicates."""
engine_manager.mode = EngineModes.SINGLETON
# Use a real engine to avoid event listener issues with MagicMock
from sqlalchemy import create_engine
from sqlalchemy.pool import StaticPool
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
mock_make_url.return_value = real_engine
create_count = [0]
def counting_create_engine(*args, **kwargs):
create_count[0] += 1
return real_engine
mock_create_engine.side_effect = counting_create_engine
results = []
exceptions = []
def get_engine_thread():
try:
engine = engine_manager._get_engine(
mock_database, "catalog1", "schema1", None
)
results.append(engine)
except Exception as e:
exceptions.append(e)
# Run multiple threads
threads = [threading.Thread(target=get_engine_thread) for _ in range(10)]
for t in threads:
t.start()
for t in threads:
t.join()
assert len(exceptions) == 0
assert len(results) == 10
assert create_count[0] == 1 # Engine created only once
# All results should be the same engine
for engine in results:
assert engine is real_engine
@patch("superset.engines.manager.sshtunnel.open_tunnel")
def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager):
"""Test SSH tunnel creation and caching."""
ssh_tunnel = MagicMock()
ssh_tunnel.server_address = "ssh.example.com"
ssh_tunnel.server_port = 22
ssh_tunnel.username = "ssh_user"
ssh_tunnel.password = "ssh_pass" # noqa: S105
ssh_tunnel.private_key = None
ssh_tunnel.private_key_password = None
tunnel_instance = MagicMock()
tunnel_instance.is_active = True
tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
mock_open_tunnel.return_value = tunnel_instance
uri = MagicMock()
uri.host = "db.example.com"
uri.port = 5432
uri.get_backend_name.return_value = "postgresql"
result = engine_manager._get_tunnel(ssh_tunnel, uri)
assert result is tunnel_instance
mock_open_tunnel.assert_called_once()
tunnel_instance.start.assert_called_once()
# Getting same tunnel again should return cached version
mock_open_tunnel.reset_mock()
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
assert result2 is tunnel_instance
mock_open_tunnel.assert_not_called()
@patch("superset.engines.manager.sshtunnel.open_tunnel")
def test_ssh_tunnel_recreation_when_inactive(
self, mock_open_tunnel, engine_manager
):
"""Test that inactive tunnels are replaced."""
ssh_tunnel = MagicMock()
ssh_tunnel.server_address = "ssh.example.com"
ssh_tunnel.server_port = 22
ssh_tunnel.username = "ssh_user"
ssh_tunnel.password = "ssh_pass" # noqa: S105
ssh_tunnel.private_key = None
ssh_tunnel.private_key_password = None
# First tunnel is inactive
inactive_tunnel = MagicMock()
inactive_tunnel.is_active = False
inactive_tunnel.local_bind_address = ("127.0.0.1", 12345)
# Second tunnel is active
active_tunnel = MagicMock()
active_tunnel.is_active = True
active_tunnel.local_bind_address = ("127.0.0.1", 23456)
mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel]
uri = MagicMock()
uri.host = "db.example.com"
uri.port = 5432
uri.get_backend_name.return_value = "postgresql"
# First call creates inactive tunnel
result1 = engine_manager._get_tunnel(ssh_tunnel, uri)
assert result1 is inactive_tunnel
# Second call should create new tunnel since first is inactive
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
assert result2 is active_tunnel
assert mock_open_tunnel.call_count == 2
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_basic(
self, mock_make_url, mock_create_engine, engine_manager
):
"""Test _get_engine_args returns correct URI and kwargs."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = False
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.validate_database_uri = MagicMock()
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
assert str(uri) == "trino://"
assert "connect_args" in database.get_extra.return_value
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_user_impersonation(
self, mock_make_url, mock_create_engine, engine_manager
):
"""Test user impersonation in _get_engine_args."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = True
database.get_oauth2_config.return_value = None
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.impersonate_user.return_value = (
mock_uri,
{"connect_args": {"user": "alice", "source": "Apache Superset"}},
)
database.db_engine_spec.validate_database_uri = MagicMock()
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
# Verify impersonate_user was called
database.db_engine_spec.impersonate_user.assert_called_once()
call_args = database.db_engine_spec.impersonate_user.call_args
assert call_args[0][0] is database # database
assert call_args[0][1] == "alice" # username
assert call_args[0][2] is None # access_token (no OAuth2)
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_user_impersonation_email_prefix(
self,
mock_make_url,
mock_create_engine,
engine_manager,
):
"""Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
from sqlalchemy.engine.url import make_url
from superset.engines.manager import EngineModes
engine_manager.mode = EngineModes.NEW
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
# Mock user with email
mock_user = MagicMock()
mock_user.email = "alice.doe@example.org"
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {
"engine_params": {},
"connect_args": {"source": "Apache Superset"},
}
database.get_effective_user.return_value = "alice"
database.impersonate_user = True
database.get_oauth2_config.return_value = None
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
mock_uri,
{"source": "Apache Superset"},
)
database.db_engine_spec.impersonate_user.return_value = (
mock_uri,
{"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
)
database.db_engine_spec.validate_database_uri = MagicMock()
with (
patch(
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
return_value=True,
),
patch(
"superset.extensions.security_manager.find_user",
return_value=mock_user,
),
):
uri, kwargs = engine_manager._get_engine_args(
database, None, None, None, None
)
# Verify impersonate_user was called with the email prefix
database.db_engine_spec.impersonate_user.assert_called_once()
call_args = database.db_engine_spec.impersonate_user.call_args
assert call_args[0][1] == "alice.doe" # username from email prefix
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_context_manager_called(
self, mock_make_url, mock_create_engine, engine_manager, mock_database
):
"""Test that the engine context manager is properly called."""
from sqlalchemy.engine.url import make_url
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
mock_engine = MagicMock()
mock_create_engine.return_value = mock_engine
# Track context manager calls
context_manager_calls = []
def tracking_context_manager(database, catalog, schema):
from contextlib import contextmanager
@contextmanager
def inner():
context_manager_calls.append(("enter", database, catalog, schema))
yield
context_manager_calls.append(("exit", database, catalog, schema))
return inner()
engine_manager.engine_context_manager = tracking_context_manager
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
pass
assert len(context_manager_calls) == 2
assert context_manager_calls[0][0] == "enter"
assert context_manager_calls[0][1] is mock_database
assert context_manager_calls[0][2] == "catalog1"
assert context_manager_calls[0][3] == "schema1"
assert context_manager_calls[1][0] == "exit"
@patch("superset.utils.oauth2.check_for_oauth2")
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_oauth2_error_handling(
self,
mock_make_url,
mock_create_engine,
mock_check_for_oauth2,
engine_manager,
mock_database,
):
"""Test that OAuth2 errors are properly propagated from get_engine."""
from contextlib import contextmanager
from sqlalchemy.engine.url import make_url
mock_uri = make_url("trino://")
mock_make_url.return_value = mock_uri
# Simulate OAuth2 error during engine creation
class OAuth2TestError(Exception):
pass
oauth_error = OAuth2TestError("OAuth2 required")
mock_create_engine.side_effect = oauth_error
# Make get_dbapi_mapped_exception return the original exception
mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
oauth_error
)
# Mock check_for_oauth2 to re-raise the exception
@contextmanager
def mock_oauth2_context(database):
try:
yield
except OAuth2TestError:
raise
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
with pytest.raises(OAuth2TestError, match="OAuth2 required"):
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
pass
yield
return EngineManager(engine_context_manager=dummy_context_manager)
@pytest.fixture
def mock_database() -> MagicMock:
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {"engine_params": {"poolclass": "queue"}}
database.get_effective_user.return_value = "alice"
database.impersonate_user = False
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
make_url("trino://"),
{"source": "Apache Superset"},
)
database.db_engine_spec.validate_database_uri = MagicMock()
return database
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_uses_null_pool(
mock_make_url: MagicMock,
engine_manager: EngineManager,
mock_database: MagicMock,
) -> None:
mock_make_url.return_value = make_url("trino://")
_, kwargs = engine_manager._get_engine_args(mock_database, None, None, None, None)
assert kwargs["poolclass"] is pool.NullPool
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_with_impersonation(
mock_make_url: MagicMock,
engine_manager: EngineManager,
mock_database: MagicMock,
) -> None:
mock_make_url.return_value = make_url("trino://")
mock_database.impersonate_user = True
mock_database.get_oauth2_config.return_value = None
mock_database.db_engine_spec.impersonate_user.return_value = (
make_url("trino://"),
{"connect_args": {"user": "alice"}, "poolclass": pool.NullPool},
)
engine_manager._get_engine_args(mock_database, None, None, None, None)
mock_database.db_engine_spec.impersonate_user.assert_called_once()
def test_get_tunnel_kwargs_requires_database_port(
engine_manager: EngineManager,
) -> None:
ssh_tunnel = MagicMock()
ssh_tunnel.server_address = "ssh.example.com"
ssh_tunnel.server_port = 22
ssh_tunnel.username = "ssh_user"
ssh_tunnel.password = None
ssh_tunnel.private_key = None
ssh_tunnel.private_key_password = None
uri = MagicMock()
uri.port = None
uri.get_backend_name.return_value = "unknown"
with patch("superset.utils.ssh_tunnel.get_default_port", return_value=None):
with pytest.raises(SSHTunnelDatabasePortError):
engine_manager._get_tunnel_kwargs(ssh_tunnel, uri)

View File

@@ -689,13 +689,15 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
create_engine = mocker.patch("superset.engines.manager.create_engine")
create_engine.side_effect = OAuth2Error("OAuth2 required")
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
conn.cursor()
assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.