mirror of
https://github.com/apache/superset.git
synced 2026-05-11 19:05:24 +00:00
refactor: keep engine manager focused on engine creation
This commit is contained in:
@@ -52,7 +52,6 @@ from superset.advanced_data_type.plugins.internet_address import internet_addres
|
||||
from superset.advanced_data_type.plugins.internet_port import internet_port
|
||||
from superset.advanced_data_type.types import AdvancedDataType
|
||||
from superset.constants import CHANGE_ME_SECRET_KEY
|
||||
from superset.engines.manager import EngineModes
|
||||
from superset.jinja_context import BaseTemplateProcessor
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
from superset.stats_logger import DummyStatsLogger
|
||||
@@ -265,22 +264,6 @@ SQLALCHEMY_ENGINE_OPTIONS = {}
|
||||
# SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
|
||||
SQLALCHEMY_CUSTOM_PASSWORD_STORE = None
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Engine Manager Configuration
|
||||
# ---------------------------------------------------------
|
||||
|
||||
# Engine manager mode: "NEW" creates a new engine for every connection (default),
|
||||
# "SINGLETON" reuses engines with connection pooling
|
||||
ENGINE_MANAGER_MODE = EngineModes.NEW
|
||||
|
||||
# Cleanup interval for abandoned locks (default: 5 minutes)
|
||||
ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
|
||||
|
||||
# Automatically start cleanup thread for SINGLETON mode (default: True)
|
||||
ENGINE_MANAGER_AUTO_START_CLEANUP = True
|
||||
|
||||
# ---------------------------------------------------------
|
||||
|
||||
#
|
||||
# The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models
|
||||
# which include sensitive fields that should be app-encrypted BEFORE sending
|
||||
|
||||
@@ -15,10 +15,6 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import enum
|
||||
import hashlib
|
||||
import logging
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from io import StringIO
|
||||
@@ -26,7 +22,7 @@ from typing import Any, Iterator, TYPE_CHECKING
|
||||
|
||||
import sshtunnel
|
||||
from paramiko import RSAKey
|
||||
from sqlalchemy import create_engine, event, pool
|
||||
from sqlalchemy import create_engine, pool
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
@@ -41,158 +37,24 @@ if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _LockManager:
|
||||
"""
|
||||
Manages per-key locks safely without defaultdict race conditions.
|
||||
|
||||
This class provides a thread-safe way to create and manage locks for specific keys,
|
||||
avoiding the race conditions that occur when using defaultdict with threading.Lock.
|
||||
|
||||
The implementation uses a two-level locking strategy:
|
||||
1. A meta-lock to protect the lock dictionary itself
|
||||
2. Per-key locks to protect specific resources
|
||||
|
||||
This ensures that:
|
||||
- Different keys can be locked concurrently (scalability)
|
||||
- Lock creation is thread-safe (no race conditions)
|
||||
- The same key always gets the same lock instance
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._locks: dict[str, threading.RLock] = {}
|
||||
self._meta_lock = threading.Lock()
|
||||
|
||||
def get_lock(self, key: str) -> threading.RLock:
|
||||
"""
|
||||
Get or create a lock for the given key.
|
||||
|
||||
This method uses double-checked locking to ensure thread safety:
|
||||
1. First check without lock (fast path)
|
||||
2. Acquire meta-lock if needed
|
||||
3. Double-check inside the lock to prevent race conditions
|
||||
|
||||
This approach minimizes lock contention while ensuring correctness.
|
||||
|
||||
:param key: The key to get a lock for
|
||||
:returns: An RLock instance for the given key
|
||||
"""
|
||||
if lock := self._locks.get(key):
|
||||
return lock
|
||||
|
||||
with self._meta_lock:
|
||||
# Double-check inside the lock
|
||||
lock = self._locks.get(key)
|
||||
if lock is None:
|
||||
lock = threading.RLock()
|
||||
self._locks[key] = lock
|
||||
return lock
|
||||
|
||||
def cleanup(self, active_keys: set[str]) -> None:
|
||||
"""
|
||||
Remove locks for keys that are no longer in use.
|
||||
|
||||
This prevents memory leaks from accumulating locks for resources
|
||||
that have been disposed.
|
||||
|
||||
:param active_keys: Set of keys that are still active
|
||||
"""
|
||||
with self._meta_lock:
|
||||
# Find locks to remove
|
||||
locks_to_remove = self._locks.keys() - active_keys
|
||||
for key in locks_to_remove:
|
||||
self._locks.pop(key, None)
|
||||
|
||||
|
||||
EngineKey = str
|
||||
TunnelKey = str
|
||||
|
||||
|
||||
def _generate_cache_key(*args: Any) -> str:
|
||||
"""
|
||||
Generate a deterministic cache key from arbitrary arguments.
|
||||
|
||||
Uses repr() for serialization and SHA-256 for hashing. The resulting key
|
||||
is a 32-character hex string that:
|
||||
1. Is deterministic for the same inputs
|
||||
2. Does not expose sensitive data (everything is hashed)
|
||||
3. Has sufficient entropy to avoid collisions
|
||||
|
||||
:param args: Arguments to include in the cache key
|
||||
:returns: 32-character hex string
|
||||
"""
|
||||
# Use repr() which works with most Python objects and is deterministic
|
||||
serialized = repr(args).encode("utf-8")
|
||||
return hashlib.sha256(serialized).hexdigest()[:32]
|
||||
|
||||
|
||||
class EngineModes(enum.Enum):
|
||||
# reuse existing engine if available, otherwise create a new one; this mode should
|
||||
# have a connection pool configured in the database
|
||||
SINGLETON = enum.auto()
|
||||
|
||||
# always create a new engine for every connection; this mode will use a NullPool
|
||||
# and is the default behavior for Superset
|
||||
NEW = enum.auto()
|
||||
|
||||
|
||||
class EngineManager:
|
||||
"""
|
||||
A manager for SQLAlchemy engines.
|
||||
|
||||
This class handles the creation and management of SQLAlchemy engines, allowing them
|
||||
to be configured with connection pools and reused across requests. The default mode
|
||||
is the original behavior for Superset, where we create a new engine for every
|
||||
connection, using a NullPool. The `SINGLETON` mode, on the other hand, allows for
|
||||
reusing of the engines, as well as configuring the pool through the database
|
||||
settings.
|
||||
"""
|
||||
"""Centralized SQLAlchemy engine creation for Superset."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine_context_manager: EngineContextManager,
|
||||
db_connection_mutator: DBConnectionMutator | None = None,
|
||||
mode: EngineModes = EngineModes.NEW,
|
||||
cleanup_interval: timedelta = timedelta(minutes=5),
|
||||
local_bind_address: str = "127.0.0.1",
|
||||
tunnel_timeout: timedelta = timedelta(seconds=30),
|
||||
ssh_timeout: timedelta = timedelta(seconds=1),
|
||||
) -> None:
|
||||
self.engine_context_manager = engine_context_manager
|
||||
self.db_connection_mutator = db_connection_mutator
|
||||
self.mode = mode
|
||||
self.cleanup_interval = cleanup_interval
|
||||
self.local_bind_address = local_bind_address
|
||||
|
||||
sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
|
||||
sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()
|
||||
|
||||
self._engines: dict[EngineKey, Engine] = {}
|
||||
self._engine_locks = _LockManager()
|
||||
self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
|
||||
self._tunnel_locks = _LockManager()
|
||||
|
||||
# Background cleanup thread management
|
||||
self._cleanup_thread: threading.Thread | None = None
|
||||
self._cleanup_stop_event = threading.Event()
|
||||
self._cleanup_thread_lock = threading.Lock()
|
||||
|
||||
def __del__(self) -> None:
|
||||
"""
|
||||
Ensure cleanup thread is stopped when the manager is destroyed.
|
||||
"""
|
||||
try:
|
||||
self.stop_cleanup_thread()
|
||||
except Exception as ex:
|
||||
# Avoid exceptions during garbage collection, but log if possible
|
||||
try:
|
||||
logger.warning("Error stopping cleanup thread: %s", ex)
|
||||
except Exception: # noqa: S110
|
||||
# If logging fails during destruction, we can't do anything
|
||||
pass
|
||||
|
||||
@contextmanager
|
||||
def get_engine(
|
||||
self,
|
||||
@@ -201,103 +63,31 @@ class EngineManager:
|
||||
schema: str | None,
|
||||
source: QuerySource | None,
|
||||
) -> Iterator[Engine]:
|
||||
"""
|
||||
Context manager to get a SQLAlchemy engine.
|
||||
"""
|
||||
# users can wrap the engine in their own context manager for different
|
||||
# reasons
|
||||
"""Context manager to get a SQLAlchemy engine."""
|
||||
from superset.utils.oauth2 import check_for_oauth2
|
||||
|
||||
with self.engine_context_manager(database, catalog, schema):
|
||||
# we need to check for errors indicating that OAuth2 is needed, and
|
||||
# return the proper exception so it starts the authentication flow
|
||||
from superset.utils.oauth2 import check_for_oauth2
|
||||
|
||||
with check_for_oauth2(database):
|
||||
yield self._get_engine(database, catalog, schema, source)
|
||||
uri, kwargs = self._get_engine_args(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
get_user_id(),
|
||||
)
|
||||
|
||||
def _get_engine(
|
||||
self,
|
||||
database: "Database",
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
source: QuerySource | None,
|
||||
) -> Engine:
|
||||
"""
|
||||
Get a specific engine, or create it if none exists.
|
||||
"""
|
||||
source = source or get_query_source_from_request()
|
||||
user_id = get_user_id()
|
||||
|
||||
if self.mode == EngineModes.NEW:
|
||||
return self._create_engine(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
user_id,
|
||||
)
|
||||
|
||||
engine_key = self._get_engine_key(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if engine := self._engines.get(engine_key):
|
||||
return engine
|
||||
|
||||
lock = self._engine_locks.get_lock(engine_key)
|
||||
with lock:
|
||||
# Double-check inside the lock
|
||||
if engine := self._engines.get(engine_key):
|
||||
return engine
|
||||
|
||||
# Create and cache the engine
|
||||
engine = self._create_engine(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
user_id,
|
||||
)
|
||||
self._engines[engine_key] = engine
|
||||
self._add_disposal_listener(engine, engine_key)
|
||||
return engine
|
||||
|
||||
def _get_engine_key(
|
||||
self,
|
||||
database: "Database",
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
source: QuerySource | None,
|
||||
user_id: int | None,
|
||||
) -> EngineKey:
|
||||
"""
|
||||
Generate a cache key for the engine.
|
||||
|
||||
The key is a hash of all parameters that affect the engine, ensuring
|
||||
proper cache isolation without exposing sensitive data.
|
||||
|
||||
:returns: 32-character hex string
|
||||
"""
|
||||
uri, kwargs = self._get_engine_args(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
user_id,
|
||||
)
|
||||
|
||||
return _generate_cache_key(
|
||||
database.id,
|
||||
catalog,
|
||||
schema,
|
||||
str(uri),
|
||||
source,
|
||||
user_id,
|
||||
kwargs,
|
||||
)
|
||||
if database.ssh_tunnel:
|
||||
tunnel = self._create_tunnel(database.ssh_tunnel, uri)
|
||||
try:
|
||||
uri = uri.set(
|
||||
host=tunnel.local_bind_address[0],
|
||||
port=tunnel.local_bind_port,
|
||||
)
|
||||
yield self._create_engine(database, uri, kwargs)
|
||||
finally:
|
||||
tunnel.stop()
|
||||
else:
|
||||
yield self._create_engine(database, uri, kwargs)
|
||||
|
||||
def _get_engine_args(
|
||||
self,
|
||||
@@ -307,39 +97,15 @@ class EngineManager:
|
||||
source: QuerySource | None,
|
||||
user_id: int | None,
|
||||
) -> tuple[URL, dict[str, Any]]:
|
||||
"""
|
||||
Build the almost final SQLAlchemy URI and engine kwargs.
|
||||
|
||||
"Almost" final because we may still need to mutate the URI if an SSH tunnel is
|
||||
needed, since it needs to connect to the tunnel instead of the original DB. But
|
||||
that information is only available after the tunnel is created.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
"""Build SQLAlchemy URI and kwargs before engine creation."""
|
||||
from superset.extensions import security_manager
|
||||
from superset.utils.feature_flag_manager import FeatureFlagManager
|
||||
|
||||
uri = make_url_safe(database.sqlalchemy_uri_decrypted)
|
||||
|
||||
extra = database.get_extra(source)
|
||||
# Make a copy to avoid mutating the original extra dict
|
||||
kwargs = dict(extra.get("engine_params", {}))
|
||||
kwargs["poolclass"] = pool.NullPool
|
||||
|
||||
# get pool class
|
||||
if self.mode == EngineModes.NEW or "poolclass" not in kwargs:
|
||||
kwargs["poolclass"] = pool.NullPool
|
||||
else:
|
||||
pools = {
|
||||
"queue": pool.QueuePool,
|
||||
"singleton": pool.SingletonThreadPool,
|
||||
"assertion": pool.AssertionPool,
|
||||
"null": pool.NullPool,
|
||||
"static": pool.StaticPool,
|
||||
}
|
||||
pool_name = kwargs["poolclass"]
|
||||
if isinstance(pool_name, str):
|
||||
kwargs["poolclass"] = pools.get(pool_name, pool.QueuePool)
|
||||
|
||||
# update URI for specific catalog/schema
|
||||
connect_args = kwargs.setdefault("connect_args", {})
|
||||
uri, connect_args = database.db_engine_spec.adjust_engine_params(
|
||||
uri,
|
||||
@@ -348,9 +114,7 @@ class EngineManager:
|
||||
schema,
|
||||
)
|
||||
|
||||
# get effective username
|
||||
username = database.get_effective_user(uri)
|
||||
|
||||
feature_flag_manager = FeatureFlagManager()
|
||||
if username and feature_flag_manager.is_feature_enabled(
|
||||
"IMPERSONATE_WITH_EMAIL_PREFIX"
|
||||
@@ -359,10 +123,8 @@ class EngineManager:
|
||||
if user and user.email and "@" in user.email:
|
||||
username = user.email.split("@")[0]
|
||||
|
||||
# update URI/kwargs for user impersonation
|
||||
if database.impersonate_user:
|
||||
oauth2_config = database.get_oauth2_config()
|
||||
# Import here to avoid circular imports
|
||||
from superset.utils.oauth2 import get_oauth2_access_token
|
||||
|
||||
access_token = (
|
||||
@@ -384,15 +146,10 @@ class EngineManager:
|
||||
kwargs,
|
||||
)
|
||||
|
||||
# update kwargs from params stored encrupted at rest
|
||||
database.update_params_from_encrypted_extra(kwargs)
|
||||
|
||||
# mutate URI
|
||||
if self.db_connection_mutator:
|
||||
source = source or get_query_source_from_request()
|
||||
# Import here to avoid circular imports
|
||||
from superset.extensions import security_manager
|
||||
|
||||
uri, kwargs = self.db_connection_mutator(
|
||||
uri,
|
||||
kwargs,
|
||||
@@ -401,111 +158,27 @@ class EngineManager:
|
||||
source,
|
||||
)
|
||||
|
||||
# validate final URI
|
||||
database.db_engine_spec.validate_database_uri(uri)
|
||||
|
||||
return uri, kwargs
|
||||
|
||||
def _create_engine(
|
||||
self,
|
||||
database: "Database",
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
source: QuerySource | None,
|
||||
user_id: int | None,
|
||||
uri: URL,
|
||||
kwargs: dict[str, Any],
|
||||
) -> Engine:
|
||||
"""
|
||||
Create the actual engine.
|
||||
|
||||
This should be the only place in Superset where a SQLAlchemy engine is created,
|
||||
"""
|
||||
uri, kwargs = self._get_engine_args(
|
||||
database,
|
||||
catalog,
|
||||
schema,
|
||||
source,
|
||||
user_id,
|
||||
)
|
||||
|
||||
if database.ssh_tunnel:
|
||||
tunnel = self._get_tunnel(database.ssh_tunnel, uri)
|
||||
uri = uri.set(
|
||||
host=tunnel.local_bind_address[0],
|
||||
port=tunnel.local_bind_port,
|
||||
)
|
||||
|
||||
try:
|
||||
engine = create_engine(uri, **kwargs)
|
||||
return create_engine(uri, **kwargs)
|
||||
except Exception as ex:
|
||||
raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
|
||||
return engine
|
||||
|
||||
def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
|
||||
tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)
|
||||
|
||||
tunnel = self._tunnels.get(tunnel_key)
|
||||
if tunnel is not None and tunnel.is_active:
|
||||
return tunnel
|
||||
|
||||
lock = self._tunnel_locks.get_lock(tunnel_key)
|
||||
with lock:
|
||||
# Double-check inside the lock
|
||||
tunnel = self._tunnels.get(tunnel_key)
|
||||
if tunnel is not None and tunnel.is_active:
|
||||
return tunnel
|
||||
|
||||
# Create or replace tunnel
|
||||
return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel)
|
||||
|
||||
def _replace_tunnel(
|
||||
self,
|
||||
tunnel_key: str,
|
||||
ssh_tunnel: "SSHTunnel",
|
||||
uri: URL,
|
||||
old_tunnel: SSHTunnelForwarder | None,
|
||||
) -> SSHTunnelForwarder:
|
||||
"""
|
||||
Replace tunnel with proper cleanup.
|
||||
|
||||
This function assumes caller holds lock.
|
||||
"""
|
||||
if old_tunnel:
|
||||
try:
|
||||
old_tunnel.stop()
|
||||
except Exception:
|
||||
logger.exception("Error stopping old tunnel")
|
||||
|
||||
try:
|
||||
new_tunnel = self._create_tunnel(ssh_tunnel, uri)
|
||||
self._tunnels[tunnel_key] = new_tunnel
|
||||
except Exception:
|
||||
# Remove failed tunnel from cache
|
||||
self._tunnels.pop(tunnel_key, None)
|
||||
logger.exception("Failed to create tunnel")
|
||||
raise
|
||||
|
||||
return new_tunnel
|
||||
|
||||
def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
|
||||
"""
|
||||
Generate a cache key for the SSH tunnel.
|
||||
|
||||
:returns: 32-character hex string
|
||||
"""
|
||||
tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
|
||||
return _generate_cache_key(tunnel_kwargs)
|
||||
|
||||
def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
|
||||
kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
|
||||
# Use open_tunnel which handles debug_level properly
|
||||
tunnel = sshtunnel.open_tunnel(**kwargs)
|
||||
tunnel.start()
|
||||
|
||||
return tunnel
|
||||
|
||||
def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
|
||||
# Import here to avoid circular imports
|
||||
from superset.utils.ssh_tunnel import get_default_port
|
||||
|
||||
backend = uri.get_backend_name()
|
||||
@@ -518,7 +191,6 @@ class EngineManager:
|
||||
"ssh_username": ssh_tunnel.username,
|
||||
"remote_bind_address": (uri.host, port),
|
||||
"local_bind_address": (self.local_bind_address,),
|
||||
"debug_level": logging.getLogger("flask_appbuilder").level,
|
||||
}
|
||||
|
||||
if ssh_tunnel.password:
|
||||
@@ -531,107 +203,4 @@ class EngineManager:
|
||||
)
|
||||
kwargs["ssh_pkey"] = private_key
|
||||
|
||||
if self.mode == EngineModes.NEW:
|
||||
kwargs["set_keepalive"] = 0 # disable keepalive for one-time tunnels
|
||||
|
||||
return kwargs
|
||||
|
||||
def start_cleanup_thread(self) -> None:
|
||||
"""
|
||||
Start the background cleanup thread.
|
||||
|
||||
The thread will periodically clean up abandoned locks at the configured
|
||||
interval. This is safe to call multiple times - subsequent calls are no-ops.
|
||||
"""
|
||||
with self._cleanup_thread_lock:
|
||||
if self._cleanup_thread is None or not self._cleanup_thread.is_alive():
|
||||
self._cleanup_stop_event.clear()
|
||||
self._cleanup_thread = threading.Thread(
|
||||
target=self._cleanup_worker,
|
||||
name=f"EngineManager-cleanup-{id(self)}",
|
||||
daemon=True,
|
||||
)
|
||||
self._cleanup_thread.start()
|
||||
logger.info(
|
||||
"Started cleanup thread with %ds interval",
|
||||
self.cleanup_interval.total_seconds(),
|
||||
)
|
||||
|
||||
def stop_cleanup_thread(self) -> None:
|
||||
"""
|
||||
Stop the background cleanup thread gracefully.
|
||||
|
||||
This will signal the thread to stop and wait for it to finish.
|
||||
Safe to call even if no thread is running.
|
||||
"""
|
||||
with self._cleanup_thread_lock:
|
||||
if self._cleanup_thread is not None and self._cleanup_thread.is_alive():
|
||||
self._cleanup_stop_event.set()
|
||||
self._cleanup_thread.join(timeout=5.0) # 5 second timeout
|
||||
if self._cleanup_thread.is_alive():
|
||||
logger.warning("Cleanup thread did not stop within timeout")
|
||||
else:
|
||||
logger.info("Cleanup thread stopped")
|
||||
self._cleanup_thread = None
|
||||
|
||||
def _cleanup_worker(self) -> None:
|
||||
"""
|
||||
Background thread worker that periodically cleans up abandoned locks.
|
||||
"""
|
||||
while not self._cleanup_stop_event.is_set():
|
||||
try:
|
||||
self._cleanup_abandoned_locks()
|
||||
except Exception:
|
||||
logger.exception("Error during background cleanup")
|
||||
|
||||
# Use wait() instead of sleep() to allow for immediate shutdown
|
||||
if self._cleanup_stop_event.wait(
|
||||
timeout=self.cleanup_interval.total_seconds()
|
||||
):
|
||||
break # Stop event was set
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""
|
||||
Public method to manually trigger cleanup of abandoned locks.
|
||||
|
||||
This can be called periodically by external systems to prevent
|
||||
memory leaks from accumulating locks.
|
||||
"""
|
||||
self._cleanup_abandoned_locks()
|
||||
|
||||
def _cleanup_abandoned_locks(self) -> None:
|
||||
"""
|
||||
Clean up locks for engines and tunnels that no longer exist.
|
||||
|
||||
This prevents memory leaks from accumulating locks when engines/tunnels
|
||||
are disposed outside of normal cleanup paths.
|
||||
"""
|
||||
# Clean up engine locks for inactive engines
|
||||
active_engine_keys = set(self._engines.keys())
|
||||
self._engine_locks.cleanup(active_engine_keys)
|
||||
|
||||
# Clean up tunnel locks for inactive tunnels
|
||||
active_tunnel_keys = set(self._tunnels.keys())
|
||||
self._tunnel_locks.cleanup(active_tunnel_keys)
|
||||
|
||||
# Log for debugging
|
||||
if active_engine_keys or active_tunnel_keys:
|
||||
logger.debug(
|
||||
"EngineManager resources - Engines: %d, Tunnels: %d",
|
||||
len(active_engine_keys),
|
||||
len(active_tunnel_keys),
|
||||
)
|
||||
|
||||
def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None:
|
||||
@event.listens_for(engine, "engine_disposed")
|
||||
def on_engine_disposed(engine_instance: Engine) -> None:
|
||||
try:
|
||||
# Remove engine from cache - no per-key locks to clean up anymore
|
||||
if self._engines.pop(engine_key, None):
|
||||
# Log only first 8 chars of hash for safety
|
||||
# (still enough for debugging, but doesn't expose full key)
|
||||
log_key = engine_key[:8] + "..."
|
||||
logger.info("Engine disposed and removed from cache: %s", log_key)
|
||||
except Exception as ex:
|
||||
logger.error("Error during engine disposal cleanup: %s", str(ex))
|
||||
# Don't log engine_key to avoid exposing credential hash
|
||||
|
||||
@@ -15,13 +15,12 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import atexit
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from superset.engines.manager import EngineManager, EngineModes
|
||||
from superset.engines.manager import EngineManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -29,10 +28,6 @@ logger = logging.getLogger(__name__)
|
||||
class EngineManagerExtension:
|
||||
"""
|
||||
Flask extension for managing SQLAlchemy engines in Superset.
|
||||
|
||||
This extension creates and configures an EngineManager instance based on
|
||||
Flask configuration, handling startup and shutdown of background cleanup
|
||||
threads as needed.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
@@ -44,47 +39,23 @@ class EngineManagerExtension:
|
||||
"""
|
||||
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
|
||||
db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
|
||||
mode = app.config["ENGINE_MANAGER_MODE"]
|
||||
cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"]
|
||||
local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
||||
tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
|
||||
ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
|
||||
auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"]
|
||||
|
||||
# Stop any existing manager's cleanup thread before creating a new one
|
||||
if self.engine_manager:
|
||||
self.engine_manager.stop_cleanup_thread()
|
||||
|
||||
# Create the engine manager
|
||||
self.engine_manager = EngineManager(
|
||||
engine_context_manager,
|
||||
db_connection_mutator,
|
||||
mode,
|
||||
cleanup_interval,
|
||||
local_bind_address,
|
||||
tunnel_timeout,
|
||||
ssh_timeout,
|
||||
)
|
||||
|
||||
# Start cleanup thread if requested and in SINGLETON mode
|
||||
if auto_start_cleanup and mode == EngineModes.SINGLETON:
|
||||
self.engine_manager.start_cleanup_thread()
|
||||
logger.info("Started EngineManager cleanup thread")
|
||||
|
||||
# Register shutdown handler
|
||||
def shutdown_engine_manager(exc: BaseException | None = None) -> None:
|
||||
if self.engine_manager:
|
||||
self.engine_manager.stop_cleanup_thread()
|
||||
|
||||
app.teardown_appcontext(shutdown_engine_manager)
|
||||
|
||||
# Register with atexit for clean shutdown
|
||||
atexit.register(shutdown_engine_manager)
|
||||
|
||||
logger.info(
|
||||
"Initialized EngineManager with mode=%s, cleanup_interval=%ds",
|
||||
mode,
|
||||
cleanup_interval.total_seconds(),
|
||||
"Initialized EngineManager with tunnel_timeout=%s, ssh_timeout=%s",
|
||||
tunnel_timeout.total_seconds(),
|
||||
ssh_timeout.total_seconds(),
|
||||
)
|
||||
|
||||
@property
|
||||
|
||||
@@ -429,8 +429,8 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
Context manager for a SQLAlchemy engine.
|
||||
|
||||
This method will return a context manager for a SQLAlchemy engine. The engine
|
||||
manager handles connection pooling, SSH tunnels, and other connection details
|
||||
based on the configured mode (NEW or SINGLETON).
|
||||
manager handles engine creation, SSH tunnels, and connection details in a
|
||||
centralized place.
|
||||
"""
|
||||
# Import here to avoid circular imports
|
||||
from superset.extensions import engine_manager_extension
|
||||
|
||||
@@ -15,513 +15,95 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for EngineManager."""
|
||||
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import _LockManager, EngineManager, EngineModes
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.engines.manager import EngineManager
|
||||
|
||||
|
||||
class TestLockManager:
|
||||
"""Test the _LockManager class."""
|
||||
|
||||
def test_get_lock_creates_new_lock(self):
|
||||
"""Test that get_lock creates a new lock when needed."""
|
||||
manager = _LockManager()
|
||||
lock1 = manager.get_lock("key1")
|
||||
|
||||
assert isinstance(lock1, type(threading.RLock()))
|
||||
assert lock1 is manager.get_lock("key1") # Same lock returned
|
||||
|
||||
def test_get_lock_different_keys_different_locks(self):
|
||||
"""Test that different keys get different locks."""
|
||||
manager = _LockManager()
|
||||
lock1 = manager.get_lock("key1")
|
||||
lock2 = manager.get_lock("key2")
|
||||
|
||||
assert lock1 is not lock2
|
||||
|
||||
def test_cleanup_removes_unused_locks(self):
|
||||
"""Test that cleanup removes locks for inactive keys."""
|
||||
manager = _LockManager()
|
||||
|
||||
# Create locks
|
||||
_ = manager.get_lock("key1") # noqa: F841
|
||||
lock2 = manager.get_lock("key2")
|
||||
|
||||
# Cleanup with only key1 active
|
||||
manager.cleanup({"key1"})
|
||||
|
||||
# key2 lock should be removed
|
||||
lock3 = manager.get_lock("key2")
|
||||
assert lock3 is not lock2 # New lock created
|
||||
|
||||
def test_concurrent_lock_creation(self):
|
||||
"""Test that concurrent lock creation doesn't create duplicates."""
|
||||
manager = _LockManager()
|
||||
locks_created = []
|
||||
exceptions = []
|
||||
|
||||
def create_lock():
|
||||
try:
|
||||
lock = manager.get_lock("concurrent_key")
|
||||
locks_created.append(lock)
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
|
||||
# Create multiple threads trying to get the same lock
|
||||
threads = [threading.Thread(target=create_lock) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(exceptions) == 0
|
||||
assert len(locks_created) == 10
|
||||
|
||||
# All should be the same lock
|
||||
first_lock = locks_created[0]
|
||||
for lock in locks_created[1:]:
|
||||
assert lock is first_lock
|
||||
|
||||
|
||||
class TestEngineManager:
|
||||
"""Test the EngineManager class."""
|
||||
|
||||
@pytest.fixture
|
||||
def engine_manager(self):
|
||||
"""Create a mock EngineManager instance."""
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def dummy_context_manager(
|
||||
database: MagicMock, catalog: str | None, schema: str | None
|
||||
) -> Iterator[None]:
|
||||
yield
|
||||
|
||||
return EngineManager(engine_context_manager=dummy_context_manager)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database(self):
|
||||
"""Create a mock database."""
|
||||
database = MagicMock()
|
||||
database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
|
||||
database.get_extra.return_value = {"engine_params": {}}
|
||||
database.get_effective_user.return_value = "test_user"
|
||||
database.impersonate_user = False
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
|
||||
database.db_engine_spec.impersonate_user = MagicMock(
|
||||
return_value=(MagicMock(), {})
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
database.ssh_tunnel = None
|
||||
return database
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_new_mode(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
@pytest.fixture
|
||||
def engine_manager() -> EngineManager:
|
||||
@contextmanager
|
||||
def dummy_context_manager(
|
||||
database: MagicMock,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
):
|
||||
"""Test getting an engine in NEW mode (no caching)."""
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_make_url.return_value = MagicMock()
|
||||
mock_engine1 = MagicMock()
|
||||
mock_engine2 = MagicMock()
|
||||
mock_create_engine.side_effect = [mock_engine1, mock_engine2]
|
||||
|
||||
result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
|
||||
|
||||
assert result is mock_engine1
|
||||
mock_create_engine.assert_called_once()
|
||||
|
||||
# Calling again should create a new engine (no caching)
|
||||
mock_create_engine.reset_mock()
|
||||
result2 = engine_manager._get_engine(mock_database, "catalog2", "schema2", None)
|
||||
|
||||
assert result2 is mock_engine2 # Different engine
|
||||
mock_create_engine.assert_called_once()
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_singleton_mode_caching(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test that engines are cached in SINGLETON mode."""
|
||||
engine_manager.mode = EngineModes.SINGLETON
|
||||
|
||||
# Use a real engine instead of MagicMock to avoid event listener issues
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
|
||||
mock_create_engine.return_value = real_engine
|
||||
mock_make_url.return_value = real_engine
|
||||
|
||||
# Call twice with same params - should be cached
|
||||
result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
|
||||
result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
|
||||
|
||||
assert result1 is result2 # Same engine returned (cached)
|
||||
mock_create_engine.assert_called_once() # Only created once
|
||||
|
||||
# Call with different params - should create new engine
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_concurrent_engine_creation(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test concurrent engine creation doesn't create duplicates."""
|
||||
engine_manager.mode = EngineModes.SINGLETON
|
||||
|
||||
# Use a real engine to avoid event listener issues with MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
|
||||
mock_make_url.return_value = real_engine
|
||||
|
||||
create_count = [0]
|
||||
|
||||
def counting_create_engine(*args, **kwargs):
|
||||
create_count[0] += 1
|
||||
return real_engine
|
||||
|
||||
mock_create_engine.side_effect = counting_create_engine
|
||||
|
||||
results = []
|
||||
exceptions = []
|
||||
|
||||
def get_engine_thread():
|
||||
try:
|
||||
engine = engine_manager._get_engine(
|
||||
mock_database, "catalog1", "schema1", None
|
||||
)
|
||||
results.append(engine)
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
|
||||
# Run multiple threads
|
||||
threads = [threading.Thread(target=get_engine_thread) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(exceptions) == 0
|
||||
assert len(results) == 10
|
||||
assert create_count[0] == 1 # Engine created only once
|
||||
|
||||
# All results should be the same engine
|
||||
for engine in results:
|
||||
assert engine is real_engine
|
||||
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager):
|
||||
"""Test SSH tunnel creation and caching."""
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
ssh_tunnel.server_port = 22
|
||||
ssh_tunnel.username = "ssh_user"
|
||||
ssh_tunnel.password = "ssh_pass" # noqa: S105
|
||||
ssh_tunnel.private_key = None
|
||||
ssh_tunnel.private_key_password = None
|
||||
|
||||
tunnel_instance = MagicMock()
|
||||
tunnel_instance.is_active = True
|
||||
tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
|
||||
mock_open_tunnel.return_value = tunnel_instance
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
uri.port = 5432
|
||||
uri.get_backend_name.return_value = "postgresql"
|
||||
|
||||
result = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result is tunnel_instance
|
||||
mock_open_tunnel.assert_called_once()
|
||||
tunnel_instance.start.assert_called_once()
|
||||
|
||||
# Getting same tunnel again should return cached version
|
||||
mock_open_tunnel.reset_mock()
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result2 is tunnel_instance
|
||||
mock_open_tunnel.assert_not_called()
|
||||
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_recreation_when_inactive(
|
||||
self, mock_open_tunnel, engine_manager
|
||||
):
|
||||
"""Test that inactive tunnels are replaced."""
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
ssh_tunnel.server_port = 22
|
||||
ssh_tunnel.username = "ssh_user"
|
||||
ssh_tunnel.password = "ssh_pass" # noqa: S105
|
||||
ssh_tunnel.private_key = None
|
||||
ssh_tunnel.private_key_password = None
|
||||
|
||||
# First tunnel is inactive
|
||||
inactive_tunnel = MagicMock()
|
||||
inactive_tunnel.is_active = False
|
||||
inactive_tunnel.local_bind_address = ("127.0.0.1", 12345)
|
||||
|
||||
# Second tunnel is active
|
||||
active_tunnel = MagicMock()
|
||||
active_tunnel.is_active = True
|
||||
active_tunnel.local_bind_address = ("127.0.0.1", 23456)
|
||||
|
||||
mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel]
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
uri.port = 5432
|
||||
uri.get_backend_name.return_value = "postgresql"
|
||||
|
||||
# First call creates inactive tunnel
|
||||
result1 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result1 is inactive_tunnel
|
||||
|
||||
# Second call should create new tunnel since first is inactive
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result2 is active_tunnel
|
||||
assert mock_open_tunnel.call_count == 2
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_basic(
|
||||
self, mock_make_url, mock_create_engine, engine_manager
|
||||
):
|
||||
"""Test _get_engine_args returns correct URI and kwargs."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = False
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
|
||||
|
||||
assert str(uri) == "trino://"
|
||||
assert "connect_args" in database.get_extra.return_value
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_user_impersonation(
|
||||
self, mock_make_url, mock_create_engine, engine_manager
|
||||
):
|
||||
"""Test user impersonation in _get_engine_args."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = True
|
||||
database.get_oauth2_config.return_value = None
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.impersonate_user.return_value = (
|
||||
mock_uri,
|
||||
{"connect_args": {"user": "alice", "source": "Apache Superset"}},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
|
||||
|
||||
# Verify impersonate_user was called
|
||||
database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
call_args = database.db_engine_spec.impersonate_user.call_args
|
||||
assert call_args[0][0] is database # database
|
||||
assert call_args[0][1] == "alice" # username
|
||||
assert call_args[0][2] is None # access_token (no OAuth2)
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_user_impersonation_email_prefix(
|
||||
self,
|
||||
mock_make_url,
|
||||
mock_create_engine,
|
||||
engine_manager,
|
||||
):
|
||||
"""Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.engines.manager import EngineModes
|
||||
|
||||
engine_manager.mode = EngineModes.NEW
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
# Mock user with email
|
||||
mock_user = MagicMock()
|
||||
mock_user.email = "alice.doe@example.org"
|
||||
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {
|
||||
"engine_params": {},
|
||||
"connect_args": {"source": "Apache Superset"},
|
||||
}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = True
|
||||
database.get_oauth2_config.return_value = None
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
mock_uri,
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.impersonate_user.return_value = (
|
||||
mock_uri,
|
||||
{"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
|
||||
return_value=True,
|
||||
),
|
||||
patch(
|
||||
"superset.extensions.security_manager.find_user",
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
uri, kwargs = engine_manager._get_engine_args(
|
||||
database, None, None, None, None
|
||||
)
|
||||
|
||||
# Verify impersonate_user was called with the email prefix
|
||||
database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
call_args = database.db_engine_spec.impersonate_user.call_args
|
||||
assert call_args[0][1] == "alice.doe" # username from email prefix
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_engine_context_manager_called(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test that the engine context manager is properly called."""
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
mock_engine = MagicMock()
|
||||
mock_create_engine.return_value = mock_engine
|
||||
|
||||
# Track context manager calls
|
||||
context_manager_calls = []
|
||||
|
||||
def tracking_context_manager(database, catalog, schema):
|
||||
from contextlib import contextmanager
|
||||
|
||||
@contextmanager
|
||||
def inner():
|
||||
context_manager_calls.append(("enter", database, catalog, schema))
|
||||
yield
|
||||
context_manager_calls.append(("exit", database, catalog, schema))
|
||||
|
||||
return inner()
|
||||
|
||||
engine_manager.engine_context_manager = tracking_context_manager
|
||||
|
||||
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
|
||||
pass
|
||||
|
||||
assert len(context_manager_calls) == 2
|
||||
assert context_manager_calls[0][0] == "enter"
|
||||
assert context_manager_calls[0][1] is mock_database
|
||||
assert context_manager_calls[0][2] == "catalog1"
|
||||
assert context_manager_calls[0][3] == "schema1"
|
||||
assert context_manager_calls[1][0] == "exit"
|
||||
|
||||
@patch("superset.utils.oauth2.check_for_oauth2")
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_engine_oauth2_error_handling(
|
||||
self,
|
||||
mock_make_url,
|
||||
mock_create_engine,
|
||||
mock_check_for_oauth2,
|
||||
engine_manager,
|
||||
mock_database,
|
||||
):
|
||||
"""Test that OAuth2 errors are properly propagated from get_engine."""
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
# Simulate OAuth2 error during engine creation
|
||||
class OAuth2TestError(Exception):
|
||||
pass
|
||||
|
||||
oauth_error = OAuth2TestError("OAuth2 required")
|
||||
mock_create_engine.side_effect = oauth_error
|
||||
|
||||
# Make get_dbapi_mapped_exception return the original exception
|
||||
mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
|
||||
oauth_error
|
||||
)
|
||||
|
||||
# Mock check_for_oauth2 to re-raise the exception
|
||||
@contextmanager
|
||||
def mock_oauth2_context(database):
|
||||
try:
|
||||
yield
|
||||
except OAuth2TestError:
|
||||
raise
|
||||
|
||||
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
|
||||
|
||||
with pytest.raises(OAuth2TestError, match="OAuth2 required"):
|
||||
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
|
||||
pass
|
||||
@pytest.fixture
def engine_manager() -> EngineManager:
    """Return an EngineManager wired with a no-op engine context manager.

    NOTE(review): reconstructed from two fragments of a mangled diff view;
    confirm against the upstream test module.
    """

    @contextmanager
    def dummy_context_manager(
        database: MagicMock,
        catalog: str | None,
        schema: str | None,
    ):
        # No setup/teardown needed for tests; just yield control.
        yield

    return EngineManager(engine_context_manager=dummy_context_manager)
|
||||
|
||||
|
||||
@pytest.fixture
def mock_database() -> MagicMock:
    """Return a mocked database with the attributes EngineManager reads."""
    database = MagicMock()
    database.id = 1
    database.sqlalchemy_uri_decrypted = "trino://"
    database.get_extra.return_value = {"engine_params": {"poolclass": "queue"}}
    database.get_effective_user.return_value = "alice"
    database.impersonate_user = False
    database.update_params_from_encrypted_extra = MagicMock()
    database.db_engine_spec = MagicMock()
    # adjust_engine_params returns the (possibly rewritten) URI plus kwargs.
    database.db_engine_spec.adjust_engine_params.return_value = (
        make_url("trino://"),
        {"source": "Apache Superset"},
    )
    database.db_engine_spec.validate_database_uri = MagicMock()
    return database
|
||||
|
||||
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_uses_null_pool(
|
||||
mock_make_url: MagicMock,
|
||||
engine_manager: EngineManager,
|
||||
mock_database: MagicMock,
|
||||
) -> None:
|
||||
mock_make_url.return_value = make_url("trino://")
|
||||
|
||||
_, kwargs = engine_manager._get_engine_args(mock_database, None, None, None, None)
|
||||
|
||||
assert kwargs["poolclass"] is pool.NullPool
|
||||
|
||||
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_with_impersonation(
|
||||
mock_make_url: MagicMock,
|
||||
engine_manager: EngineManager,
|
||||
mock_database: MagicMock,
|
||||
) -> None:
|
||||
mock_make_url.return_value = make_url("trino://")
|
||||
mock_database.impersonate_user = True
|
||||
mock_database.get_oauth2_config.return_value = None
|
||||
mock_database.db_engine_spec.impersonate_user.return_value = (
|
||||
make_url("trino://"),
|
||||
{"connect_args": {"user": "alice"}, "poolclass": pool.NullPool},
|
||||
)
|
||||
|
||||
engine_manager._get_engine_args(mock_database, None, None, None, None)
|
||||
|
||||
mock_database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tunnel_kwargs_requires_database_port(
    engine_manager: EngineManager,
) -> None:
    """Raise ``SSHTunnelDatabasePortError`` when no database port is known.

    The URI carries no port and the backend has no default, so tunnel
    construction cannot determine a remote bind port.
    """
    ssh_tunnel = MagicMock()
    ssh_tunnel.server_address = "ssh.example.com"
    ssh_tunnel.server_port = 22
    ssh_tunnel.username = "ssh_user"
    ssh_tunnel.password = None
    ssh_tunnel.private_key = None
    ssh_tunnel.private_key_password = None

    uri = MagicMock()
    uri.port = None
    uri.get_backend_name.return_value = "unknown"

    with patch("superset.utils.ssh_tunnel.get_default_port", return_value=None):
        with pytest.raises(SSHTunnelDatabasePortError):
            engine_manager._get_tunnel_kwargs(ssh_tunnel, uri)
|
||||
|
||||
@@ -689,13 +689,15 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
|
||||
encrypted_extra=json.dumps(oauth2_client_info),
|
||||
)
|
||||
database.db_engine_spec.oauth2_exception = OAuth2Error
|
||||
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
|
||||
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
create_engine = mocker.patch("superset.engines.manager.create_engine")
|
||||
create_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as excinfo:
|
||||
with database.get_raw_connection() as conn:
|
||||
conn.cursor()
|
||||
assert str(excinfo.value) == "You don't have permission to access the data."
|
||||
|
||||
|
||||
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we can start OAuth2 from `raw_connection()` errors.
|
||||
|
||||
Reference in New Issue
Block a user