Compare commits

...

14 Commits

Author SHA1 Message Date
Beto Dealmeida
3bb4b5f3a6 fix: SSH tunnel and test connection error handling
- Use sshtunnel.open_tunnel() instead of SSHTunnelForwarder directly
  to properly handle debug_level parameter
- Fix keepalive parameter name (set_keepalive, not keepalive)
- Fix test assertions that were inside pytest.raises blocks and never
  executed - now check error_type instead of string messages
- Update SSH tunnel test mocks to patch open_tunnel

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 10:58:39 -05:00
Beto Dealmeida
c00fae53a5 Rebase 2026-02-04 09:58:15 -05:00
Beto Dealmeida
99935fc035 Fix tests 2026-02-04 09:44:54 -05:00
Beto Dealmeida
d06ccf5152 Fix poolclass check 2026-02-04 09:44:54 -05:00
Beto Dealmeida
62d2d82ed8 Fix more tests 2026-02-04 09:44:54 -05:00
Beto Dealmeida
688224c4c0 Simplify key generation 2026-02-04 09:44:54 -05:00
Beto Dealmeida
de8c250f86 Update existing tests 2026-02-04 09:44:54 -05:00
Beto Dealmeida
be31abeb7e Hash key 2026-02-04 09:42:43 -05:00
Beto Dealmeida
5f61bb8d76 Small improvements 2026-02-04 09:42:43 -05:00
Beto Dealmeida
bb5a15dc5a Cleanup 2026-02-04 09:42:42 -05:00
Beto Dealmeida
929b0337f4 Connecting 2026-02-04 09:42:42 -05:00
Beto Dealmeida
baf6e03d16 Add extension 2026-02-04 09:42:42 -05:00
Beto Dealmeida
e82e06891b Cleanup locks 2026-02-04 09:42:41 -05:00
Beto Dealmeida
5753dfbb6e feat: engine manager 2026-02-04 09:42:41 -05:00
17 changed files with 1433 additions and 463 deletions

View File

@@ -24,6 +24,41 @@ assists people when migrating to a new version.
## Next
### Engine Manager for Connection Pooling
A new `EngineManager` class has been introduced to centralize SQLAlchemy engine creation and management. This enables connection pooling for analytics databases and provides a more flexible architecture for engine configuration.
#### Breaking Changes
1. **Removed `SSH_TUNNEL_MANAGER_CLASS` config**: SSH tunnel handling is now integrated into the EngineManager. If you have custom SSH tunnel managers, you'll need to migrate to the new architecture.
2. **Removed `nullpool` parameter**: The `get_sqla_engine()` and `get_raw_connection()` methods on the `Database` model no longer accept a `nullpool` parameter. Pool configuration is now controlled through the engine manager.
3. **Removed `_get_sqla_engine()` method**: The private `_get_sqla_engine()` method has been removed from the `Database` model. All engine creation now goes through the `EngineManager`.
#### New Configuration Options
```python
# Engine manager mode:
# - EngineModes.NEW: Creates a new engine for every connection (default, original behavior)
# - EngineModes.SINGLETON: Reuses engines with connection pooling
from superset.engines.manager import EngineModes
ENGINE_MANAGER_MODE = EngineModes.NEW
# Cleanup interval for abandoned locks (default: 5 minutes)
from datetime import timedelta
ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
# Automatically start cleanup thread for SINGLETON mode (default: True)
ENGINE_MANAGER_AUTO_START_CLEANUP = True
```
#### Migration Guide
- If you were using the `nullpool` parameter, remove it from your calls
- If you had a custom `SSH_TUNNEL_MANAGER_CLASS`, refactor to use the new EngineManager architecture
- If you need connection pooling, set `ENGINE_MANAGER_MODE = EngineModes.SINGLETON` and configure the pool in your database's `extra` JSON field
### WebSocket config for GAQ with Docker
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:

View File

@@ -52,10 +52,15 @@ from superset.advanced_data_type.plugins.internet_address import internet_addres
from superset.advanced_data_type.plugins.internet_port import internet_port
from superset.advanced_data_type.types import AdvancedDataType
from superset.constants import CHANGE_ME_SECRET_KEY
from superset.engines.manager import EngineModes
from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import JsonKeyValueCodec
from superset.stats_logger import DummyStatsLogger
from superset.superset_typing import CacheConfig
from superset.superset_typing import (
CacheConfig,
DBConnectionMutator,
EngineContextManager,
)
from superset.tasks.types import ExecutorType
from superset.themes.types import Theme
from superset.utils import core as utils
@@ -260,6 +265,22 @@ SQLALCHEMY_ENGINE_OPTIONS = {}
# SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
SQLALCHEMY_CUSTOM_PASSWORD_STORE = None
# ---------------------------------------------------------
# Engine Manager Configuration
# ---------------------------------------------------------
# Engine manager mode: "NEW" creates a new engine for every connection (default),
# "SINGLETON" reuses engines with connection pooling
ENGINE_MANAGER_MODE = EngineModes.NEW
# Cleanup interval for abandoned locks in seconds (default: 5 minutes)
ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)
# Automatically start cleanup thread for SINGLETON mode (default: True)
ENGINE_MANAGER_AUTO_START_CLEANUP = True
# ---------------------------------------------------------
#
# The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models
# which include sensitive fields that should be app-encrypted BEFORE sending
@@ -809,7 +830,6 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# FIREWALL (only port 22 is open)
# ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
#: Timeout (seconds) for tunnel connection (open_channel timeout)
SSH_TUNNEL_TIMEOUT_SEC = 10.0
@@ -1684,7 +1704,7 @@ def engine_context_manager( # pylint: disable=unused-argument
yield None
ENGINE_CONTEXT_MANAGER = engine_context_manager
ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager
# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or
@@ -1701,7 +1721,7 @@ ENGINE_CONTEXT_MANAGER = engine_context_manager
#
# Note that the returned uri and params are passed directly to sqlalchemy's
# as such `create_engine(url, **params)`
DB_CONNECTION_MUTATOR = None
DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None
# A callable that is invoked for every invocation of DB Engine Specs

View File

@@ -14,23 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import Mock
import sshtunnel
from superset.extensions.ssh import SSHManagerFactory
def test_ssh_tunnel_timeout_setting() -> None:
    """Initializing the factory should push timeouts into sshtunnel globals."""
    config = {
        "SSH_TUNNEL_MAX_RETRIES": 2,
        "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
        "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
        "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
        "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
    }
    app = Mock(config=config)
    SSHManagerFactory().init_app(app)
    assert sshtunnel.TUNNEL_TIMEOUT == 123.0
    assert sshtunnel.SSH_TIMEOUT == 321.0

630
superset/engines/manager.py Normal file
View File

@@ -0,0 +1,630 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import enum
import hashlib
import logging
import threading
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import Any, Iterator, TYPE_CHECKING
import sshtunnel
from paramiko import RSAKey
from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sshtunnel import SSHTunnelForwarder
from superset.databases.utils import make_url_safe
from superset.superset_typing import DBConnectionMutator, EngineContextManager
from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
logger = logging.getLogger(__name__)
class _LockManager:
"""
Manages per-key locks safely without defaultdict race conditions.
This class provides a thread-safe way to create and manage locks for specific keys,
avoiding the race conditions that occur when using defaultdict with threading.Lock.
The implementation uses a two-level locking strategy:
1. A meta-lock to protect the lock dictionary itself
2. Per-key locks to protect specific resources
This ensures that:
- Different keys can be locked concurrently (scalability)
- Lock creation is thread-safe (no race conditions)
- The same key always gets the same lock instance
"""
def __init__(self) -> None:
self._locks: dict[str, threading.RLock] = {}
self._meta_lock = threading.Lock()
def get_lock(self, key: str) -> threading.RLock:
"""
Get or create a lock for the given key.
This method uses double-checked locking to ensure thread safety:
1. First check without lock (fast path)
2. Acquire meta-lock if needed
3. Double-check inside the lock to prevent race conditions
This approach minimizes lock contention while ensuring correctness.
:param key: The key to get a lock for
:returns: An RLock instance for the given key
"""
if lock := self._locks.get(key):
return lock
with self._meta_lock:
# Double-check inside the lock
lock = self._locks.get(key)
if lock is None:
lock = threading.RLock()
self._locks[key] = lock
return lock
def cleanup(self, active_keys: set[str]) -> None:
"""
Remove locks for keys that are no longer in use.
This prevents memory leaks from accumulating locks for resources
that have been disposed.
:param active_keys: Set of keys that are still active
"""
with self._meta_lock:
# Find locks to remove
locks_to_remove = self._locks.keys() - active_keys
for key in locks_to_remove:
self._locks.pop(key, None)
EngineKey = str
TunnelKey = str
def _generate_cache_key(*args: Any) -> str:
"""
Generate a deterministic cache key from arbitrary arguments.
Uses repr() for serialization and SHA-256 for hashing. The resulting key
is a 32-character hex string that:
1. Is deterministic for the same inputs
2. Does not expose sensitive data (everything is hashed)
3. Has sufficient entropy to avoid collisions
:param args: Arguments to include in the cache key
:returns: 32-character hex string
"""
# Use repr() which works with most Python objects and is deterministic
serialized = repr(args).encode("utf-8")
return hashlib.sha256(serialized).hexdigest()[:32]
class EngineModes(enum.Enum):
    """Strategies for handing out SQLAlchemy engines."""

    # Reuse a cached engine when one exists; databases in this mode are
    # expected to have a connection pool configured.
    SINGLETON = enum.auto()

    # Build a fresh engine (with a NullPool) for every connection; this is
    # Superset's historical default behavior.
    NEW = enum.auto()
class EngineManager:
    """
    A manager for SQLAlchemy engines.

    This class handles the creation and management of SQLAlchemy engines, allowing
    them to be configured with connection pools and reused across requests. The
    default mode is the original behavior for Superset, where we create a new
    engine for every connection, using a NullPool. The `SINGLETON` mode, on the
    other hand, allows for reusing of the engines, as well as configuring the pool
    through the database settings.
    """

    def __init__(
        self,
        engine_context_manager: EngineContextManager,
        db_connection_mutator: DBConnectionMutator | None = None,
        mode: EngineModes = EngineModes.NEW,
        cleanup_interval: timedelta = timedelta(minutes=5),
        local_bind_address: str = "127.0.0.1",
        tunnel_timeout: timedelta = timedelta(seconds=30),
        ssh_timeout: timedelta = timedelta(seconds=1),
    ) -> None:
        self.engine_context_manager = engine_context_manager
        self.db_connection_mutator = db_connection_mutator
        self.mode = mode
        self.cleanup_interval = cleanup_interval
        self.local_bind_address = local_bind_address

        # NOTE: these are module-level globals in sshtunnel, shared by every
        # tunnel in the process; the last constructed manager wins.
        sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
        sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()

        self._engines: dict[EngineKey, Engine] = {}
        self._engine_locks = _LockManager()
        self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
        self._tunnel_locks = _LockManager()

        # Background cleanup thread management
        self._cleanup_thread: threading.Thread | None = None
        self._cleanup_stop_event = threading.Event()
        self._cleanup_thread_lock = threading.Lock()

    def __del__(self) -> None:
        """
        Ensure cleanup thread is stopped when the manager is destroyed.
        """
        try:
            self.stop_cleanup_thread()
        except Exception as ex:
            # Avoid exceptions during garbage collection, but log if possible
            try:
                logger.warning("Error stopping cleanup thread: %s", ex)
            except Exception:  # noqa: S110
                # If logging fails during destruction, we can't do anything
                pass

    @contextmanager
    def get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Iterator[Engine]:
        """
        Context manager to get a SQLAlchemy engine.
        """
        # users can wrap the engine in their own context manager for different
        # reasons
        with self.engine_context_manager(database, catalog, schema):
            # we need to check for errors indicating that OAuth2 is needed, and
            # return the proper exception so it starts the authentication flow
            from superset.utils.oauth2 import check_for_oauth2

            with check_for_oauth2(database):
                yield self._get_engine(database, catalog, schema, source)

    def _get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        ) -> Engine:
        """
        Get a specific engine, or create it if none exists.
        """
        source = source or get_query_source_from_request()
        user_id = get_user_id()

        if self.mode == EngineModes.NEW:
            return self._create_engine(
                database,
                catalog,
                schema,
                source,
                user_id,
            )

        engine_key = self._get_engine_key(
            database,
            catalog,
            schema,
            source,
            user_id,
        )
        if engine := self._engines.get(engine_key):
            return engine

        lock = self._engine_locks.get_lock(engine_key)
        with lock:
            # Double-check inside the lock
            if engine := self._engines.get(engine_key):
                return engine

            # Create and cache the engine
            engine = self._create_engine(
                database,
                catalog,
                schema,
                source,
                user_id,
            )
            self._engines[engine_key] = engine
            self._add_disposal_listener(engine, engine_key)

            return engine

    def _get_engine_key(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> EngineKey:
        """
        Generate a cache key for the engine.

        The key is a hash of all parameters that affect the engine, ensuring
        proper cache isolation without exposing sensitive data.

        :returns: 32-character hex string
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )
        return _generate_cache_key(
            database.id,
            catalog,
            schema,
            str(uri),
            source,
            user_id,
            kwargs,
        )

    def _get_engine_args(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> tuple[URL, dict[str, Any]]:
        """
        Build the almost final SQLAlchemy URI and engine kwargs.

        "Almost" final because we may still need to mutate the URI if an SSH tunnel
        is needed, since it needs to connect to the tunnel instead of the original
        DB. But that information is only available after the tunnel is created.
        """
        # Import here to avoid circular imports
        from superset.extensions import security_manager
        from superset.utils.feature_flag_manager import FeatureFlagManager

        uri = make_url_safe(database.sqlalchemy_uri_decrypted)
        extra = database.get_extra(source)

        # Make a copy to avoid mutating the original extra dict
        kwargs = dict(extra.get("engine_params", {}))

        # get pool class; the pool name is configured as a string inside
        # `engine_params` (e.g. "queue") and mapped to the actual class here
        if self.mode == EngineModes.NEW or "poolclass" not in kwargs:
            kwargs["poolclass"] = pool.NullPool
        else:
            pools = {
                "queue": pool.QueuePool,
                "singleton": pool.SingletonThreadPool,
                "assertion": pool.AssertionPool,
                "null": pool.NullPool,
                "static": pool.StaticPool,
            }
            # BUG FIX: the presence check above is on `kwargs` (engine_params),
            # but the value was previously read from `extra["poolclass"]`,
            # raising a KeyError whenever a pool was configured. Read it from
            # the same place we checked.
            kwargs["poolclass"] = pools.get(kwargs["poolclass"], pool.QueuePool)

        # update URI for specific catalog/schema
        connect_args = dict(extra.get("connect_args", {}))
        uri, connect_args = database.db_engine_spec.adjust_engine_params(
            uri,
            connect_args,
            catalog,
            schema,
        )

        # get effective username
        username = database.get_effective_user(uri)
        feature_flag_manager = FeatureFlagManager()
        if username and feature_flag_manager.is_feature_enabled(
            "IMPERSONATE_WITH_EMAIL_PREFIX"
        ):
            user = security_manager.find_user(username=username)
            if user and user.email and "@" in user.email:
                username = user.email.split("@")[0]

        # update URI/kwargs for user impersonation
        if database.impersonate_user:
            oauth2_config = database.get_oauth2_config()

            # Import here to avoid circular imports
            from superset.utils.oauth2 import get_oauth2_access_token

            access_token = (
                get_oauth2_access_token(
                    oauth2_config,
                    database.id,
                    user_id,
                    database.db_engine_spec,
                )
                if oauth2_config and user_id
                else None
            )
            uri, kwargs = database.db_engine_spec.impersonate_user(
                database,
                username,
                access_token,
                uri,
                kwargs,
            )

        # update kwargs from params stored encrypted at rest
        database.update_params_from_encrypted_extra(kwargs)

        # mutate URI
        if self.db_connection_mutator:
            source = source or get_query_source_from_request()
            uri, kwargs = self.db_connection_mutator(
                uri,
                kwargs,
                username,
                security_manager,
                source,
            )

        # validate final URI
        database.db_engine_spec.validate_database_uri(uri)

        return uri, kwargs

    def _create_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> Engine:
        """
        Create the actual engine.

        This should be the only place in Superset where a SQLAlchemy engine is
        created.
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        if database.ssh_tunnel:
            # repoint the URI at the local end of the tunnel
            tunnel = self._get_tunnel(database.ssh_tunnel, uri)
            uri = uri.set(
                host=tunnel.local_bind_address[0],
                port=tunnel.local_bind_port,
            )

        try:
            engine = create_engine(uri, **kwargs)
        except Exception as ex:
            raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

        return engine

    def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
        """
        Get a cached active tunnel, or create/replace one under its lock.
        """
        tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)

        tunnel = self._tunnels.get(tunnel_key)
        if tunnel is not None and tunnel.is_active:
            return tunnel

        lock = self._tunnel_locks.get_lock(tunnel_key)
        with lock:
            # Double-check inside the lock
            tunnel = self._tunnels.get(tunnel_key)
            if tunnel is not None and tunnel.is_active:
                return tunnel

            # Create or replace tunnel
            return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel)

    def _replace_tunnel(
        self,
        tunnel_key: str,
        ssh_tunnel: "SSHTunnel",
        uri: URL,
        old_tunnel: SSHTunnelForwarder | None,
    ) -> SSHTunnelForwarder:
        """
        Replace tunnel with proper cleanup.

        This function assumes caller holds lock.
        """
        if old_tunnel:
            try:
                old_tunnel.stop()
            except Exception:
                logger.exception("Error stopping old tunnel")

        try:
            new_tunnel = self._create_tunnel(ssh_tunnel, uri)
            self._tunnels[tunnel_key] = new_tunnel
        except Exception:
            # Remove failed tunnel from cache
            self._tunnels.pop(tunnel_key, None)
            logger.exception("Failed to create tunnel")
            raise

        return new_tunnel

    def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
        """
        Generate a cache key for the SSH tunnel.

        BUG FIX: the key was previously derived from the full kwargs dict built
        by `_get_tunnel_kwargs`, which may contain a freshly constructed
        `RSAKey`; its `repr()` embeds the object's memory address, so the key
        was never stable and private-key tunnels could never be reused (and
        stale forwarders accumulated in the cache). Derive the key from the
        stable connection/credential fields instead — they are hashed, so no
        sensitive data is exposed.

        :returns: 32-character hex string
        """
        return _generate_cache_key(
            ssh_tunnel.server_address,
            ssh_tunnel.server_port,
            ssh_tunnel.username,
            ssh_tunnel.password,
            ssh_tunnel.private_key,
            ssh_tunnel.private_key_password,
            uri.host,
            uri.port,
            uri.get_backend_name(),
        )

    def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
        """
        Build and start a new SSH tunnel forwarder.
        """
        kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
        # Use open_tunnel which handles debug_level properly
        tunnel = sshtunnel.open_tunnel(**kwargs)
        tunnel.start()
        return tunnel

    def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
        """
        Build the kwargs passed to ``sshtunnel.open_tunnel``.

        :raises SSHTunnelDatabasePortError: if no port can be determined for
            the database (neither in the URI nor a known default)
        """
        # Import here to avoid circular imports
        from superset.commands.database.ssh_tunnel.exceptions import (
            SSHTunnelDatabasePortError,
        )
        from superset.utils.ssh_tunnel import get_default_port

        backend = uri.get_backend_name()
        port = uri.port or get_default_port(backend)
        if not port:
            # BUG FIX: restore the explicit error the previous SSHManager
            # raised; without it the tunnel would be built with a None port
            raise SSHTunnelDatabasePortError()

        kwargs = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (uri.host, port),
            "local_bind_address": (self.local_bind_address,),
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }
        if ssh_tunnel.password:
            kwargs["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file,
                ssh_tunnel.private_key_password,
            )
            kwargs["ssh_pkey"] = private_key

        if self.mode == EngineModes.NEW:
            kwargs["set_keepalive"] = 0  # disable keepalive for one-time tunnels

        return kwargs

    def start_cleanup_thread(self) -> None:
        """
        Start the background cleanup thread.

        The thread will periodically clean up abandoned locks at the configured
        interval. This is safe to call multiple times - subsequent calls are
        no-ops.
        """
        with self._cleanup_thread_lock:
            if self._cleanup_thread is None or not self._cleanup_thread.is_alive():
                self._cleanup_stop_event.clear()
                self._cleanup_thread = threading.Thread(
                    target=self._cleanup_worker,
                    name=f"EngineManager-cleanup-{id(self)}",
                    daemon=True,
                )
                self._cleanup_thread.start()
                logger.info(
                    "Started cleanup thread with %ds interval",
                    self.cleanup_interval.total_seconds(),
                )

    def stop_cleanup_thread(self) -> None:
        """
        Stop the background cleanup thread gracefully.

        This will signal the thread to stop and wait for it to finish.
        Safe to call even if no thread is running.
        """
        with self._cleanup_thread_lock:
            if self._cleanup_thread is not None and self._cleanup_thread.is_alive():
                self._cleanup_stop_event.set()
                self._cleanup_thread.join(timeout=5.0)  # 5 second timeout
                if self._cleanup_thread.is_alive():
                    logger.warning("Cleanup thread did not stop within timeout")
                else:
                    logger.info("Cleanup thread stopped")
            self._cleanup_thread = None

    def _cleanup_worker(self) -> None:
        """
        Background thread worker that periodically cleans up abandoned locks.
        """
        while not self._cleanup_stop_event.is_set():
            try:
                self._cleanup_abandoned_locks()
            except Exception:
                logger.exception("Error during background cleanup")

            # Use wait() instead of sleep() to allow for immediate shutdown
            if self._cleanup_stop_event.wait(
                timeout=self.cleanup_interval.total_seconds()
            ):
                break  # Stop event was set

    def cleanup(self) -> None:
        """
        Public method to manually trigger cleanup of abandoned locks.

        This can be called periodically by external systems to prevent
        memory leaks from accumulating locks.
        """
        self._cleanup_abandoned_locks()

    def _cleanup_abandoned_locks(self) -> None:
        """
        Clean up locks for engines and tunnels that no longer exist.

        This prevents memory leaks from accumulating locks when engines/tunnels
        are disposed outside of normal cleanup paths.
        """
        # Clean up engine locks for inactive engines
        active_engine_keys = set(self._engines.keys())
        self._engine_locks.cleanup(active_engine_keys)

        # Clean up tunnel locks for inactive tunnels
        active_tunnel_keys = set(self._tunnels.keys())
        self._tunnel_locks.cleanup(active_tunnel_keys)

        # Log for debugging
        if active_engine_keys or active_tunnel_keys:
            logger.debug(
                "EngineManager resources - Engines: %d, Tunnels: %d",
                len(active_engine_keys),
                len(active_tunnel_keys),
            )

    def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None:
        """
        Evict the engine from the cache when SQLAlchemy disposes it.
        """

        @event.listens_for(engine, "engine_disposed")
        def on_engine_disposed(engine_instance: Engine) -> None:
            try:
                # Remove engine from cache - no per-key locks to clean up anymore
                if self._engines.pop(engine_key, None):
                    # Log only first 8 chars of hash for safety
                    # (still enough for debugging, but doesn't expose full key)
                    log_key = engine_key[:8] + "..."
                    logger.info("Engine disposed and removed from cache: %s", log_key)
            except Exception as ex:
                logger.error("Error during engine disposal cleanup: %s", str(ex))
                # Don't log engine_key to avoid exposing credential hash

View File

@@ -41,7 +41,7 @@ from werkzeug.local import LocalProxy
from superset.async_events.async_query_manager import AsyncQueryManager
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
from superset.extensions.ssh import SSHManagerFactory
from superset.extensions.engine_manager import EngineManagerExtension
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager
@@ -136,6 +136,7 @@ cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = get_sqla_class()()
engine_manager_extension = EngineManagerExtension()
_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
@@ -146,6 +147,5 @@ migrate = Migrate()
profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
ssh_manager_factory = SSHManagerFactory()
stats_logger_manager = BaseStatsLoggerManager()
talisman = Talisman()

View File

@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import timedelta
from flask import Flask
from superset.engines.manager import EngineManager, EngineModes
logger = logging.getLogger(__name__)
class EngineManagerExtension:
    """
    Flask extension for managing SQLAlchemy engines in Superset.

    This extension creates and configures an EngineManager instance based on
    Flask configuration, handling startup and shutdown of background cleanup
    threads as needed.
    """

    def __init__(self) -> None:
        # Populated by `init_app`; `None` until then.
        self.engine_manager: EngineManager | None = None

    def init_app(self, app: Flask) -> None:
        """
        Initialize the EngineManager with Flask app configuration.

        :param app: the Flask application whose config drives the manager
        """
        engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
        db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
        mode = app.config["ENGINE_MANAGER_MODE"]
        cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"]
        local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
        ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
        auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"]

        # Create the engine manager
        self.engine_manager = EngineManager(
            engine_context_manager,
            db_connection_mutator,
            mode,
            cleanup_interval,
            local_bind_address,
            tunnel_timeout,
            ssh_timeout,
        )

        # Start cleanup thread if requested and in SINGLETON mode
        if auto_start_cleanup and mode == EngineModes.SINGLETON:
            self.engine_manager.start_cleanup_thread()
            logger.info("Started EngineManager cleanup thread")

        # Stop the cleanup thread on interpreter shutdown.
        # BUG FIX: a no-op lambda was previously also appended to
        # `app.teardown_appcontext_funcs`; it did nothing and accumulated one
        # entry per init_app call, so it has been removed.
        def shutdown_engine_manager() -> None:
            if self.engine_manager:
                self.engine_manager.stop_cleanup_thread()

        # Register with atexit for clean shutdown
        import atexit

        atexit.register(shutdown_engine_manager)

        logger.info(
            "Initialized EngineManager with mode=%s, cleanup_interval=%ds",
            mode,
            cleanup_interval.total_seconds(),
        )

    @property
    def manager(self) -> EngineManager:
        """
        Get the EngineManager instance.

        Raises:
            RuntimeError: If the extension hasn't been initialized with an app.
        """
        if self.engine_manager is None:
            raise RuntimeError(
                "EngineManager extension not initialized. "
                "Call init_app() with a Flask app first."
            )
        return self.engine_manager

View File

@@ -1,94 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from io import StringIO
from typing import TYPE_CHECKING
import sshtunnel
from flask import Flask
from paramiko import RSAKey
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.databases.utils import make_url_safe
from superset.utils.class_utils import load_class_from_name
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
class SSHManager:
    """
    Creates SSH tunnels for databases reachable only through a bastion host.
    """

    def __init__(self, app: Flask) -> None:
        super().__init__()
        self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        # sshtunnel exposes these timeouts as module-level globals
        sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
        sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]

    def build_sqla_url(
        self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
    ) -> str:
        """Point the SQLAlchemy URL at the local end of the tunnel."""
        # override any ssh tunnel configuration object
        safe_url = make_url_safe(sqlalchemy_url)
        local_host = server.local_bind_address[0]
        return safe_url.set(host=local_host, port=server.local_bind_port)

    def create_tunnel(
        self,
        ssh_tunnel: "SSHTunnel",
        sqlalchemy_database_uri: str,
    ) -> sshtunnel.SSHTunnelForwarder:
        """Open a tunnel to the database described by the URI."""
        from superset.utils.ssh_tunnel import get_default_port

        url = make_url_safe(sqlalchemy_database_uri)
        bind_port = url.port or get_default_port(url.get_backend_name())
        if not bind_port:
            raise SSHTunnelDatabasePortError()

        params = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (url.host, bind_port),
            "local_bind_address": (self.local_bind_address,),
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }
        if ssh_tunnel.password:
            params["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            key_source = StringIO(ssh_tunnel.private_key)
            params["ssh_pkey"] = RSAKey.from_private_key(
                key_source, ssh_tunnel.private_key_password
            )
        return sshtunnel.open_tunnel(**params)
class SSHManagerFactory:
    """Lazily instantiates the SSH manager class named in the app config."""

    def __init__(self) -> None:
        # Set by `init_app`; `None` until the app is configured.
        self._ssh_manager = None

    def init_app(self, app: Flask) -> None:
        manager_cls = load_class_from_name(app.config["SSH_TUNNEL_MANAGER_CLASS"])
        self._ssh_manager = manager_cls(app)

    @property
    def instance(self) -> SSHManager:
        return self._ssh_manager  # type: ignore

View File

@@ -49,13 +49,13 @@ from superset.extensions import (
csrf,
db,
encrypted_field_factory,
engine_manager_extension,
feature_flag_manager,
machine_auth_provider_factory,
manifest_processor,
migrate,
profiling,
results_backend_manager,
ssh_manager_factory,
stats_logger_manager,
talisman,
)
@@ -585,8 +585,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.configure_url_map_converters()
self.configure_data_sources()
self.configure_auth_provider()
self.configure_engine_manager()
self.configure_async_queries()
self.configure_ssh_manager()
self.configure_stats_manager()
# Hook that provides administrators a handle on the Flask APP
@@ -761,8 +761,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
def configure_auth_provider(self) -> None:
machine_auth_provider_factory.init_app(self.superset_app)
def configure_ssh_manager(self) -> None:
ssh_manager_factory.init_app(self.superset_app)
def configure_engine_manager(self) -> None:
engine_manager_extension.init_app(self.superset_app)
def configure_stats_manager(self) -> None:
stats_logger_manager.init_app(self.superset_app)

View File

@@ -25,24 +25,22 @@ import builtins
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext, suppress
from contextlib import closing, contextmanager, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, Optional, TYPE_CHECKING
from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING
import numpy
import pandas as pd
import sqlalchemy as sqla
import sshtunnel
from flask import current_app as app, g, has_app_context
from flask_appbuilder import Model
from marshmallow.exceptions import ValidationError
from sqlalchemy import (
Boolean,
Column,
create_engine,
DateTime,
ForeignKey,
Integer,
@@ -57,7 +55,6 @@ from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select
from superset_core.api.models import Database as CoreDatabase
@@ -72,7 +69,6 @@ from superset.extensions import (
encrypted_field_factory,
event_logger,
security_manager,
ssh_manager_factory,
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
from superset.result_set import SupersetResultSet
@@ -84,10 +80,9 @@ from superset.superset_typing import (
)
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_query_source_from_request, get_username
from superset.utils.core import get_username
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
OAuth2ClientConfigSchema,
)
@@ -424,130 +419,31 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
)
@contextmanager
def get_sqla_engine( # pylint: disable=too-many-arguments
def get_sqla_engine(
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Engine:
) -> Iterator[Engine]:
"""
Context manager for a SQLAlchemy engine.
This method will return a context manager for a SQLAlchemy engine. Using the
context manager (as opposed to the engine directly) is important because we need
to potentially establish SSH tunnels before the connection is created, and clean
them up once the engine is no longer used.
This method will return a context manager for a SQLAlchemy engine. The engine
manager handles connection pooling, SSH tunnels, and other connection details
based on the configured mode (NEW or SINGLETON).
"""
# Import here to avoid circular imports
from superset.extensions import engine_manager_extension
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
ssh_context_manager = (
ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=self.ssh_tunnel,
sqlalchemy_database_uri=sqlalchemy_uri,
)
if self.ssh_tunnel
else nullcontext()
)
with ssh_context_manager as ssh_context:
if ssh_context:
logger.info(
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
"ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
ssh_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri,
ssh_context,
)
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
with check_for_oauth2(self):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
sqlalchemy_uri: str | None = None,
) -> Engine:
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
extra = self.get_extra(source)
engine_kwargs = extra.get("engine_params", {})
if nullpool:
engine_kwargs["poolclass"] = NullPool
connect_args = engine_kwargs.setdefault("connect_args", {})
# modify URL/args for a specific catalog/schema
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
# Use the engine manager to get the engine
engine_manager = engine_manager_extension.manager
with engine_manager.get_engine(
database=self,
catalog=catalog,
schema=schema,
)
effective_username = self.get_effective_user(sqlalchemy_url)
if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
user = security_manager.find_user(username=effective_username)
if user and user.email:
effective_username = user.email.split("@")[0]
oauth2_config = self.get_oauth2_config()
access_token = (
get_oauth2_access_token(
oauth2_config,
self.id,
g.user.id,
self.db_engine_spec,
)
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
else None
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
if self.impersonate_user:
sqlalchemy_url, engine_kwargs = self.db_engine_spec.impersonate_user(
self,
effective_username,
access_token,
sqlalchemy_url,
engine_kwargs,
)
self.update_params_from_encrypted_extra(engine_kwargs)
if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]: # noqa: N806
source = source or get_query_source_from_request()
sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR(
sqlalchemy_url,
engine_kwargs,
effective_username,
security_manager,
source,
)
try:
return create_engine(sqlalchemy_url, **engine_kwargs)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
source=source,
) as engine:
yield engine
def add_database_to_signature(
self,
@@ -572,13 +468,11 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
) as engine:
with check_for_oauth2(self):

View File

@@ -18,17 +18,42 @@ from __future__ import annotations
from collections.abc import Hashable, Sequence
from datetime import datetime
from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict
from typing import (
Any,
Callable,
ContextManager,
Literal,
TYPE_CHECKING,
TypeAlias,
TypedDict,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import NotRequired
from werkzeug.wrappers import Response
if TYPE_CHECKING:
from superset.utils.core import GenericDataType, QueryObjectFilterClause
from superset.models.core import Database
from superset.utils.core import (
GenericDataType,
QueryObjectFilterClause,
QuerySource,
)
SQLType: TypeAlias = TypeEngine | type[TypeEngine]
# Type alias for database connection mutator function
DBConnectionMutator: TypeAlias = Callable[
[URL, dict[str, Any], str | None, Any, "QuerySource | None"],
tuple[URL, dict[str, Any]],
]
# Type alias for engine context manager
EngineContextManager: TypeAlias = Callable[
["Database", str | None, str | None], ContextManager[None]
]
class LegacyMetric(TypedDict):
label: str | None

View File

@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db._get_sqla_engine() # type: ignore
self._db.backend # type: ignore # noqa: B018
def remove(self) -> None:

View File

@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
class TestTestConnectionDatabaseCommand(SupersetTestCase):
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_db_exception(
@@ -906,19 +906,23 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure event_logger is called when an exception is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
mock_get_sqla_engine.return_value.__enter__.side_effect = Exception(
"An error has occurred!"
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Unexpected error occurred, please check your logs for details"
)
# Exception wraps errors from db_engine_spec.extract_errors()
assert (
excinfo.value.errors[0].error_type
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
)
mock_event_logger.assert_called()
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_do_ping_exception(
@@ -927,9 +931,8 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure do_ping exceptions gets captured"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
"An error has occurred!"
)
mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value
mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!")
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
@@ -967,7 +970,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
)
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_superset_security_connection(
@@ -977,20 +980,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
connection exc is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info")
mock_get_sqla_engine.return_value.__enter__.side_effect = (
SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info")
)
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012
with pytest.raises(DatabaseSecurityUnsafeError):
command_without_db_name.run()
assert str(excinfo.value) == ("Stopped an unsafe database connection")
mock_event_logger.assert_called()
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_db_api_exc(
@@ -999,19 +1002,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure event_logger is called when DBAPIError is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = DBAPIError(
mock_get_sqla_engine.return_value.__enter__.side_effect = DBAPIError(
statement="error", params={}, orig={}
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012
with pytest.raises(SupersetErrorsException) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Connection failed, please check your connection settings"
)
# Exception wraps errors from db_engine_spec.extract_errors()
assert (
excinfo.value.errors[0].error_type
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
)
mock_event_logger.assert_called()
@@ -1147,7 +1151,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
with pytest.raises(DatabaseNotFoundError) as excinfo: # noqa: PT012
command.run()
assert str(excinfo.value) == ("Database not found.")
assert str(excinfo.value) == ("Database not found.")
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@@ -1166,26 +1170,35 @@ class TestTablesDatabaseCommand(SupersetTestCase):
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(SupersetException) as excinfo: # noqa: PT012
command.run()
assert str(excinfo.value) == "Test Error"
assert str(excinfo.value) == "Test Error"
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.models.core.Database.get_all_materialized_view_names_in_schema")
@patch("superset.models.core.Database.get_all_view_names_in_schema")
@patch("superset.models.core.Database.get_all_table_names_in_schema")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@patch("superset.utils.core.g")
def test_database_tables_exception(
self, mock_g, mock_can_access_database, mock_find_by_id
self,
mock_g,
mock_can_access_database,
mock_get_tables,
mock_get_views,
mock_get_mvs,
mock_find_by_id,
):
database = get_example_database()
mock_find_by_id.return_value = database
mock_get_tables.return_value = {("table1", "main", None)}
mock_get_views.return_value = set()
mock_get_mvs.return_value = []
mock_can_access_database.side_effect = Exception("Test Error")
mock_g.user = security_manager.find_user("admin")
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo: # noqa: PT012
command.run()
assert (
str(excinfo.value)
== "Unexpected error occurred, please check your logs for details"
)
assert str(excinfo.value) == "Test Error"
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")

View File

@@ -145,7 +145,7 @@ class TestDatabaseModel(SupersetTestCase):
username = make_url(engine.url).username
assert example_user.username != username
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
@unittest.skipUnless(
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
@@ -172,7 +172,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost/"
@@ -185,7 +186,8 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost/"
@@ -199,13 +201,14 @@ class TestDatabaseModel(SupersetTestCase):
@unittest.skipUnless(
SupersetTestCase.is_module_installed("mysqlclient"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "mysql://user:password@localhost"
@@ -215,13 +218,14 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database2",
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
)
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
assert call_args[1]["connect_args"]["allow_local_infile"] == 0
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
@@ -230,7 +234,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost/"
@@ -242,7 +247,8 @@ class TestDatabaseModel(SupersetTestCase):
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert (
@@ -251,7 +257,7 @@ class TestDatabaseModel(SupersetTestCase):
)
assert call_args[1]["connect_args"]["user"] == "gamma"
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
@unittest.skipUnless(
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
@@ -281,7 +287,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -294,7 +301,8 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -376,7 +384,7 @@ class TestDatabaseModel(SupersetTestCase):
df = main_db.get_df("USE superset; SELECT ';';", None, None)
assert df.iat[0, 0] == ";"
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(
database_name="test_database",
@@ -387,7 +395,8 @@ class TestDatabaseModel(SupersetTestCase):
)
mocked_create_engine.side_effect = Exception()
with self.assertRaises(SupersetException): # noqa: PT027
model._get_sqla_engine()
with model.get_sqla_engine():
pass
class TestSqlaTableModel(SupersetTestCase):

View File

@@ -0,0 +1,527 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for EngineManager."""
import threading
from collections.abc import Iterator
from unittest.mock import MagicMock, patch
import pytest
from superset.engines.manager import _LockManager, EngineManager, EngineModes
class TestLockManager:
    """Test the _LockManager class."""

    def test_get_lock_creates_new_lock(self):
        """A fresh key yields an RLock, and repeat lookups return the same one."""
        manager = _LockManager()
        first = manager.get_lock("key1")
        assert isinstance(first, type(threading.RLock()))
        assert manager.get_lock("key1") is first  # Same lock returned

    def test_get_lock_different_keys_different_locks(self):
        """Distinct keys must never share a lock object."""
        manager = _LockManager()
        assert manager.get_lock("key1") is not manager.get_lock("key2")

    def test_cleanup_removes_unused_locks(self):
        """cleanup() drops the locks of keys that are no longer active."""
        manager = _LockManager()

        # Create locks for two keys.
        _ = manager.get_lock("key1")  # noqa: F841
        stale = manager.get_lock("key2")

        # Cleanup with only key1 active; key2's lock should be evicted.
        manager.cleanup({"key1"})

        # A brand-new lock object is handed out for the evicted key.
        assert manager.get_lock("key2") is not stale

    def test_concurrent_lock_creation(self):
        """Racing threads asking for one key must all receive the same lock."""
        manager = _LockManager()
        seen = []
        errors = []

        def worker():
            try:
                seen.append(manager.get_lock("concurrent_key"))
            except Exception as exc:
                errors.append(exc)

        # Hammer the same key from several threads at once.
        workers = [threading.Thread(target=worker) for _ in range(10)]
        for thread in workers:
            thread.start()
        for thread in workers:
            thread.join()

        assert len(errors) == 0
        assert len(seen) == 10
        # Every thread must have observed the exact same lock object.
        reference = seen[0]
        assert all(lock is reference for lock in seen[1:])
class TestEngineManager:
"""Test the EngineManager class."""
@pytest.fixture
    def engine_manager(self):
        """Create a mock EngineManager instance.

        Uses a no-op context manager so tests don't depend on the
        application's configured ENGINE_CONTEXT_MANAGER.
        """
        from contextlib import contextmanager

        @contextmanager
        def dummy_context_manager(
            database: MagicMock, catalog: str | None, schema: str | None
        ) -> Iterator[None]:
            # No setup/teardown — just yield control to the caller.
            yield

        return EngineManager(engine_context_manager=dummy_context_manager)
@pytest.fixture
    def mock_database(self):
        """Create a mock database.

        The mock mirrors the attributes EngineManager reads from a real
        Database: decrypted URI, extras, impersonation flags, engine spec
        hooks, and the (absent) SSH tunnel.
        """
        database = MagicMock()
        database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
        database.get_extra.return_value = {"engine_params": {}}
        database.get_effective_user.return_value = "test_user"
        database.impersonate_user = False
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        # adjust_engine_params returns (url, connect_args); use placeholders.
        database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
        database.db_engine_spec.impersonate_user = MagicMock(
            return_value=(MagicMock(), {})
        )
        database.db_engine_spec.validate_database_uri = MagicMock()
        # No SSH tunnel configured by default.
        database.ssh_tunnel = None
        return database
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
    def test_get_engine_new_mode(
        self, mock_make_url, mock_create_engine, engine_manager, mock_database
    ):
        """Test getting an engine in NEW mode (no caching)."""
        engine_manager.mode = EngineModes.NEW
        mock_make_url.return_value = MagicMock()
        # Two distinct engines so we can tell whether a second call re-creates.
        mock_engine1 = MagicMock()
        mock_engine2 = MagicMock()
        mock_create_engine.side_effect = [mock_engine1, mock_engine2]

        result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
        assert result is mock_engine1
        mock_create_engine.assert_called_once()

        # Calling again should create a new engine (no caching)
        mock_create_engine.reset_mock()
        result2 = engine_manager._get_engine(mock_database, "catalog2", "schema2", None)
        assert result2 is mock_engine2  # Different engine
        mock_create_engine.assert_called_once()
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
    def test_get_engine_singleton_mode_caching(
        self, mock_make_url, mock_create_engine, engine_manager, mock_database
    ):
        """Test that engines are cached in SINGLETON mode."""
        engine_manager.mode = EngineModes.SINGLETON

        # Use a real engine instead of MagicMock to avoid event listener issues
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool

        real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
        mock_create_engine.return_value = real_engine
        mock_make_url.return_value = real_engine

        # Call twice with same params - should be cached
        result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
        result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)

        assert result1 is result2  # Same engine returned (cached)
        mock_create_engine.assert_called_once()  # Only created once

        # TODO(review): the "call with different params creates a new engine"
        # case announced here was never written — add the assertions or drop
        # this note.
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
    def test_concurrent_engine_creation(
        self, mock_make_url, mock_create_engine, engine_manager, mock_database
    ):
        """Test concurrent engine creation doesn't create duplicates."""
        engine_manager.mode = EngineModes.SINGLETON

        # Use a real engine to avoid event listener issues with MagicMock
        from sqlalchemy import create_engine
        from sqlalchemy.pool import StaticPool

        real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
        mock_make_url.return_value = real_engine

        # Mutable cell so the closure below can count calls.
        create_count = [0]

        def counting_create_engine(*args, **kwargs):
            create_count[0] += 1
            return real_engine

        mock_create_engine.side_effect = counting_create_engine

        results = []
        exceptions = []

        def get_engine_thread():
            try:
                engine = engine_manager._get_engine(
                    mock_database, "catalog1", "schema1", None
                )
                results.append(engine)
            except Exception as e:
                exceptions.append(e)

        # Run multiple threads
        threads = [threading.Thread(target=get_engine_thread) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(exceptions) == 0
        assert len(results) == 10
        assert create_count[0] == 1  # Engine created only once

        # All results should be the same engine
        for engine in results:
            assert engine is real_engine
@patch("superset.engines.manager.sshtunnel.open_tunnel")
    def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager):
        """Test SSH tunnel creation and caching."""
        ssh_tunnel = MagicMock()
        ssh_tunnel.server_address = "ssh.example.com"
        ssh_tunnel.server_port = 22
        ssh_tunnel.username = "ssh_user"
        ssh_tunnel.password = "ssh_pass"  # noqa: S105
        ssh_tunnel.private_key = None
        ssh_tunnel.private_key_password = None

        # An active tunnel: _get_tunnel should cache and reuse it.
        tunnel_instance = MagicMock()
        tunnel_instance.is_active = True
        tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
        mock_open_tunnel.return_value = tunnel_instance

        uri = MagicMock()
        uri.host = "db.example.com"
        uri.port = 5432
        uri.get_backend_name.return_value = "postgresql"

        result = engine_manager._get_tunnel(ssh_tunnel, uri)

        assert result is tunnel_instance
        mock_open_tunnel.assert_called_once()
        tunnel_instance.start.assert_called_once()

        # Getting same tunnel again should return cached version
        mock_open_tunnel.reset_mock()
        result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
        assert result2 is tunnel_instance
        mock_open_tunnel.assert_not_called()
@patch("superset.engines.manager.sshtunnel.open_tunnel")
    def test_ssh_tunnel_recreation_when_inactive(
        self, mock_open_tunnel, engine_manager
    ):
        """Test that inactive tunnels are replaced."""
        ssh_tunnel = MagicMock()
        ssh_tunnel.server_address = "ssh.example.com"
        ssh_tunnel.server_port = 22
        ssh_tunnel.username = "ssh_user"
        ssh_tunnel.password = "ssh_pass"  # noqa: S105
        ssh_tunnel.private_key = None
        ssh_tunnel.private_key_password = None

        # First tunnel is inactive
        inactive_tunnel = MagicMock()
        inactive_tunnel.is_active = False
        inactive_tunnel.local_bind_address = ("127.0.0.1", 12345)

        # Second tunnel is active
        active_tunnel = MagicMock()
        active_tunnel.is_active = True
        active_tunnel.local_bind_address = ("127.0.0.1", 23456)

        mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel]

        uri = MagicMock()
        uri.host = "db.example.com"
        uri.port = 5432
        uri.get_backend_name.return_value = "postgresql"

        # First call creates inactive tunnel
        result1 = engine_manager._get_tunnel(ssh_tunnel, uri)
        assert result1 is inactive_tunnel

        # Second call should create new tunnel since first is inactive
        result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
        assert result2 is active_tunnel
        assert mock_open_tunnel.call_count == 2
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
    def test_get_engine_args_basic(
        self, mock_make_url, mock_create_engine, engine_manager
    ):
        """Test _get_engine_args returns correct URI and kwargs."""
        from sqlalchemy.engine.url import make_url

        from superset.engines.manager import EngineModes

        engine_manager.mode = EngineModes.NEW
        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri

        database = MagicMock()
        database.id = 1
        database.sqlalchemy_uri_decrypted = "trino://"
        database.get_extra.return_value = {
            "engine_params": {},
            "connect_args": {"source": "Apache Superset"},
        }
        database.get_effective_user.return_value = "alice"
        database.impersonate_user = False
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (
            mock_uri,
            {"source": "Apache Superset"},
        )
        database.db_engine_spec.validate_database_uri = MagicMock()

        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)

        assert str(uri) == "trino://"
        # NOTE(review): this asserts on the mock's own input dict, so it is
        # always true and `kwargs` is never checked — consider asserting on
        # kwargs["connect_args"] instead (TODO confirm its expected shape).
        assert "connect_args" in database.get_extra.return_value
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
    def test_get_engine_args_user_impersonation(
        self, mock_make_url, mock_create_engine, engine_manager
    ):
        """Test user impersonation in _get_engine_args."""
        from sqlalchemy.engine.url import make_url

        from superset.engines.manager import EngineModes

        engine_manager.mode = EngineModes.NEW
        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri

        database = MagicMock()
        database.id = 1
        database.sqlalchemy_uri_decrypted = "trino://"
        database.get_extra.return_value = {
            "engine_params": {},
            "connect_args": {"source": "Apache Superset"},
        }
        database.get_effective_user.return_value = "alice"
        # Impersonation enabled; no OAuth2 config, so no access token.
        database.impersonate_user = True
        database.get_oauth2_config.return_value = None
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (
            mock_uri,
            {"source": "Apache Superset"},
        )
        database.db_engine_spec.impersonate_user.return_value = (
            mock_uri,
            {"connect_args": {"user": "alice", "source": "Apache Superset"}},
        )
        database.db_engine_spec.validate_database_uri = MagicMock()

        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)

        # Verify impersonate_user was called
        database.db_engine_spec.impersonate_user.assert_called_once()
        call_args = database.db_engine_spec.impersonate_user.call_args
        assert call_args[0][0] is database  # database
        assert call_args[0][1] == "alice"  # username
        assert call_args[0][2] is None  # access_token (no OAuth2)
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_user_impersonation_email_prefix(
    self,
    mock_make_url,
    mock_create_engine,
    engine_manager,
):
    """
    With the IMPERSONATE_WITH_EMAIL_PREFIX feature flag enabled, the
    username handed to ``impersonate_user`` should be the local part of
    the user's email address rather than the login name.
    """
    from sqlalchemy.engine.url import make_url
    from superset.engines.manager import EngineModes

    engine_manager.mode = EngineModes.NEW
    trino_uri = make_url("trino://")
    mock_make_url.return_value = trino_uri

    # The security manager resolves the effective user to this mock.
    user = MagicMock()
    user.email = "alice.doe@example.org"

    db = MagicMock()
    db.id = 1
    db.sqlalchemy_uri_decrypted = "trino://"
    db.get_extra.return_value = {
        "engine_params": {},
        "connect_args": {"source": "Apache Superset"},
    }
    db.get_effective_user.return_value = "alice"
    db.impersonate_user = True
    db.get_oauth2_config.return_value = None
    db.update_params_from_encrypted_extra = MagicMock()
    db.db_engine_spec = MagicMock()
    db.db_engine_spec.adjust_engine_params.return_value = (
        trino_uri,
        {"source": "Apache Superset"},
    )
    db.db_engine_spec.impersonate_user.return_value = (
        trino_uri,
        {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
    )
    db.db_engine_spec.validate_database_uri = MagicMock()

    with (
        patch(
            "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
            return_value=True,
        ),
        patch(
            "superset.extensions.security_manager.find_user",
            return_value=user,
        ),
    ):
        uri, kwargs = engine_manager._get_engine_args(
            db, None, None, None, None
        )

    # The email prefix, not the login name, is used for impersonation.
    db.db_engine_spec.impersonate_user.assert_called_once()
    assert db.db_engine_spec.impersonate_user.call_args[0][1] == "alice.doe"
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_context_manager_called(
    self, mock_make_url, mock_create_engine, engine_manager, mock_database
):
    """
    The configured engine context manager must be entered before the
    engine is yielded and exited afterwards, receiving the database,
    catalog and schema it was asked about.
    """
    from sqlalchemy.engine.url import make_url

    mock_make_url.return_value = make_url("trino://")
    mock_create_engine.return_value = MagicMock()

    # Record every enter/exit of the context manager, with arguments.
    events = []

    def spying_manager(database, catalog, schema):
        from contextlib import contextmanager

        @contextmanager
        def _spy():
            events.append(("enter", database, catalog, schema))
            yield
            events.append(("exit", database, catalog, schema))

        return _spy()

    engine_manager.engine_context_manager = spying_manager

    with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
        pass

    assert len(events) == 2
    tag, db_arg, catalog_arg, schema_arg = events[0]
    assert tag == "enter"
    assert db_arg is mock_database
    assert catalog_arg == "catalog1"
    assert schema_arg == "schema1"
    assert events[1][0] == "exit"
@patch("superset.utils.oauth2.check_for_oauth2")
@patch("superset.engines.manager.create_engine")
@patch("superset.engines.manager.make_url_safe")
def test_engine_oauth2_error_handling(
    self,
    mock_make_url,
    mock_create_engine,
    mock_check_for_oauth2,
    engine_manager,
    mock_database,
):
    """
    Errors raised while creating the engine must escape ``get_engine``
    unchanged when the OAuth2 helper decides not to intercept them.
    """
    from contextlib import contextmanager
    from sqlalchemy.engine.url import make_url

    mock_make_url.return_value = make_url("trino://")

    # Simulate an OAuth2-style failure at engine-creation time.
    class OAuth2TestError(Exception):
        pass

    failure = OAuth2TestError("OAuth2 required")
    mock_create_engine.side_effect = failure
    # The spec maps the DBAPI exception back to the original error.
    mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
        failure
    )

    @contextmanager
    def passthrough(database):
        # Behaves like the real helper when no OAuth2 dance is started:
        # the error is re-raised untouched.
        try:
            yield
        except OAuth2TestError:
            raise

    mock_check_for_oauth2.return_value = passthrough(mock_database)

    with pytest.raises(OAuth2TestError, match="OAuth2 required"):
        with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
            pass

View File

@@ -123,7 +123,7 @@ class TestSupersetAppInitializer:
patch.object(app_initializer, "configure_data_sources"),
patch.object(app_initializer, "configure_auth_provider"),
patch.object(app_initializer, "configure_async_queries"),
patch.object(app_initializer, "configure_ssh_manager"),
patch.object(app_initializer, "configure_engine_manager"),
patch.object(app_initializer, "configure_stats_manager"),
patch.object(app_initializer, "init_views"),
):

View File

@@ -19,7 +19,6 @@
from datetime import datetime
import pytest
from flask import current_app
from pytest_mock import MockerFixture
from sqlalchemy import (
Column,
@@ -29,7 +28,6 @@ from sqlalchemy import (
Table as SqlalchemyTable,
)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
def test_get_sqla_engine(mocker: MockerFixture) -> None:
    """
    ``_get_sqla_engine`` should build the engine from the database URI
    and the default connect args, without any user impersonation.
    """
    from superset.models.core import Database

    mock_user = mocker.MagicMock()
    mock_user.email = "alice.doe@example.org"
    mocker.patch(
        "superset.models.core.security_manager.find_user",
        return_value=mock_user,
    )
    mocker.patch("superset.models.core.get_username", return_value="alice")
    mocked_create_engine = mocker.patch("superset.models.core.create_engine")

    db = Database(database_name="my_db", sqlalchemy_uri="trino://")
    db._get_sqla_engine(nullpool=False)

    # Impersonation is off, so only the default source tag is passed.
    mocked_create_engine.assert_called_with(
        make_url("trino:///"),
        connect_args={"source": "Apache Superset"},
    )
def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
    """
    With ``impersonate_user`` enabled, ``_get_sqla_engine`` should add
    the effective username to the connect args.
    """
    from superset.models.core import Database

    mock_user = mocker.MagicMock()
    mock_user.email = "alice.doe@example.org"
    mocker.patch(
        "superset.models.core.security_manager.find_user",
        return_value=mock_user,
    )
    mocker.patch("superset.models.core.get_username", return_value="alice")
    mocked_create_engine = mocker.patch("superset.models.core.create_engine")

    db = Database(
        database_name="my_db",
        sqlalchemy_uri="trino://",
        impersonate_user=True,
    )
    db._get_sqla_engine(nullpool=False)

    # The login name ("alice") is injected as the connection user.
    mocked_create_engine.assert_called_with(
        make_url("trino:///"),
        connect_args={"user": "alice", "source": "Apache Superset"},
    )
def test_add_database_to_signature():
args = ["param1", "param2"]
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
assert args3 == ["param1", "param2", database]
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
    """
    With IMPERSONATE_WITH_EMAIL_PREFIX enabled, impersonation should use
    the email prefix ("alice.doe") instead of the login name ("alice").
    """
    from superset.models.core import Database

    mock_user = mocker.MagicMock()
    mock_user.email = "alice.doe@example.org"
    mocker.patch(
        "superset.models.core.security_manager.find_user",
        return_value=mock_user,
    )
    mocker.patch("superset.models.core.get_username", return_value="alice")
    mocked_create_engine = mocker.patch("superset.models.core.create_engine")

    db = Database(
        database_name="my_db",
        sqlalchemy_uri="trino://",
        impersonate_user=True,
    )
    db._get_sqla_engine(nullpool=False)

    mocked_create_engine.assert_called_with(
        make_url("trino:///"),
        connect_args={"user": "alice.doe", "source": "Apache Superset"},
    )
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.
@@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config(
assert config["redirect_uri"] == custom_redirect_uri
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
    """
    Verify that `raw_connection()` errors can start the OAuth2 flow.

    Databases raise OAuth2 errors at different points: when the engine is
    first created (eg, BigQuery), when the connection is created (eg,
    Snowflake), or when the query runs (eg, GSheets). Here the error is
    raised at engine creation and must be converted into an
    `OAuth2RedirectError`.
    """
    flask_g = mocker.patch("superset.db_engine_specs.base.g")
    flask_g.user = mocker.MagicMock()
    flask_g.user.id = 42

    db = Database(
        id=1,
        database_name="my_db",
        sqlalchemy_uri="sqlite://",
        encrypted_extra=json.dumps(oauth2_client_info),
    )
    db.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore

    engine_factory = mocker.patch.object(db, "_get_sqla_engine")
    engine_factory.side_effect = OAuth2Error("OAuth2 required")

    with pytest.raises(OAuth2RedirectError) as excinfo:
        with db.get_raw_connection() as conn:
            conn.cursor()
    assert str(excinfo.value) == "You don't have permission to access the data."
def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.
@@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None:
assert database.get_schema_access_for_file_upload() == {"public"}
def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
    """
    The configured ``ENGINE_CONTEXT_MANAGER`` must wrap the engine
    produced by ``get_sqla_engine``.
    """
    from unittest.mock import MagicMock

    ctx_manager = MagicMock()
    mocker.patch.dict(
        current_app.config,
        {"ENGINE_CONTEXT_MANAGER": ctx_manager},
    )
    mocked_get_engine = mocker.patch.object(Database, "_get_sqla_engine")

    db = Database(database_name="my_db", sqlalchemy_uri="trino://")
    with db.get_sqla_engine("catalog", "schema"):
        pass

    # The context manager is built once and entered/exited around the
    # engine creation.
    ctx_manager.assert_called_once_with(db, "catalog", "schema")
    ctx_manager().__enter__.assert_called_once()
    ctx_manager().__exit__.assert_called_once_with(None, None, None)
    mocked_get_engine.assert_called_once_with(
        catalog="catalog",
        schema="schema",
        nullpool=True,
        source=None,
        sqlalchemy_uri="trino://",
    )
def test_engine_oauth2(mocker: MockerFixture) -> None:
    """
    A failure in engine creation should kick off the OAuth2 dance when
    the database has OAuth2 enabled and the spec asks for it.
    """
    db = Database(database_name="my_db", sqlalchemy_uri="trino://")
    mocker.patch.object(db, "_get_sqla_engine", side_effect=Exception)
    mocker.patch.object(db, "is_oauth2_enabled", return_value=True)
    mocker.patch.object(db.db_engine_spec, "needs_oauth2", return_value=True)
    dance = mocker.patch.object(
        db.db_engine_spec,
        "start_oauth2_dance",
        side_effect=OAuth2Error("OAuth2 required"),
    )

    with pytest.raises(OAuth2Error):
        with db.get_sqla_engine("catalog", "schema"):
            pass

    dance.assert_called_with(db)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.

View File

@@ -202,7 +202,6 @@ def setup_mock_raw_connection(
def _raw_connection(
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: Any | None = None,
):
yield mock_connection