mirror of https://github.com/apache/superset.git
synced 2026-04-28 12:34:23 +00:00

Compare commits: fix/postgr...engine-man

14 Commits:
3bb4b5f3a6
c00fae53a5
99935fc035
d06ccf5152
62d2d82ed8
688224c4c0
de8c250f86
be31abeb7e
5f61bb8d76
bb5a15dc5a
929b0337f4
baf6e03d16
e82e06891b
5753dfbb6e
UPDATING.md (35 changed lines)

@@ -24,6 +24,41 @@ assists people when migrating to a new version.
## Next

### Engine Manager for Connection Pooling

A new `EngineManager` class has been introduced to centralize SQLAlchemy engine creation and management. This enables connection pooling for analytics databases and provides a more flexible architecture for engine configuration.

#### Breaking Changes

1. **Removed `SSH_TUNNEL_MANAGER_CLASS` config**: SSH tunnel handling is now integrated into the EngineManager. If you have custom SSH tunnel managers, you'll need to migrate to the new architecture.

2. **Removed `nullpool` parameter**: The `get_sqla_engine()` and `get_raw_connection()` methods on the `Database` model no longer accept a `nullpool` parameter. Pool configuration is now controlled through the engine manager.

3. **Removed `_get_sqla_engine()` method**: The private `_get_sqla_engine()` method has been removed from the `Database` model. All engine creation now goes through the `EngineManager`.

#### New Configuration Options

```python
# Engine manager mode:
# - EngineModes.NEW: creates a new engine for every connection (default, original behavior)
# - EngineModes.SINGLETON: reuses engines with connection pooling
from superset.engines.manager import EngineModes

ENGINE_MANAGER_MODE = EngineModes.NEW

# Cleanup interval for abandoned locks (default: 5 minutes)
from datetime import timedelta

ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)

# Automatically start cleanup thread for SINGLETON mode (default: True)
ENGINE_MANAGER_AUTO_START_CLEANUP = True
```

#### Migration Guide

- If you were using the `nullpool` parameter, remove it from your calls to `get_sqla_engine()` and `get_raw_connection()` (see the sketch after this list)
- If you had a custom `SSH_TUNNEL_MANAGER_CLASS`, refactor it to use the new `EngineManager` architecture
- If you need connection pooling, set `ENGINE_MANAGER_MODE = EngineModes.SINGLETON` and configure the pool in your database's `extra` JSON field (example below)
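
A minimal sketch of the call-site change and one possible pool configuration. Here `database` stands for a `Database` model instance; the `poolclass` names follow the mapping in `superset/engines/manager.py`, and the pool sizes are illustrative, not recommendations:

```python
from superset.engines.manager import EngineModes

# Before (no longer accepted): nullpool was passed explicitly
# with database.get_sqla_engine(schema="main", nullpool=True) as engine:
#     ...

# After: no nullpool parameter; the engine manager decides pooling
with database.get_sqla_engine(schema="main") as engine:
    ...

# To opt into pooling, set the mode in superset_config.py:
ENGINE_MANAGER_MODE = EngineModes.SINGLETON

# and describe the pool in the database's `extra` JSON. As written,
# `_get_engine_args()` only consults the pool mapping when `engine_params`
# contains a `poolclass` key, and it maps the top-level `poolclass` name
# to a SQLAlchemy pool class:
# {
#     "poolclass": "queue",
#     "engine_params": {"poolclass": "queue", "pool_size": 10, "max_overflow": 5}
# }
```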

### WebSocket config for GAQ with Docker

[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated the documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo so that users can copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:
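At minimum, copy the new template into place, for example `cp docker/superset-websocket/config.example.json docker/superset-websocket/config.json`, and re-apply any local WebSocket customizations to the new file.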
superset/config.py

@@ -52,10 +52,15 @@ from superset.advanced_data_type.plugins.internet_address import internet_address
from superset.advanced_data_type.plugins.internet_port import internet_port
from superset.advanced_data_type.types import AdvancedDataType
from superset.constants import CHANGE_ME_SECRET_KEY
from superset.engines.manager import EngineModes
from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import JsonKeyValueCodec
from superset.stats_logger import DummyStatsLogger
from superset.superset_typing import CacheConfig
from superset.superset_typing import (
    CacheConfig,
    DBConnectionMutator,
    EngineContextManager,
)
from superset.tasks.types import ExecutorType
from superset.themes.types import Theme
from superset.utils import core as utils

@@ -260,6 +265,22 @@ SQLALCHEMY_ENGINE_OPTIONS = {}
# SQLALCHEMY_CUSTOM_PASSWORD_STORE = lookup_password
SQLALCHEMY_CUSTOM_PASSWORD_STORE = None

# ---------------------------------------------------------
# Engine Manager Configuration
# ---------------------------------------------------------

# Engine manager mode: "NEW" creates a new engine for every connection (default),
# "SINGLETON" reuses engines with connection pooling
ENGINE_MANAGER_MODE = EngineModes.NEW

# Cleanup interval for abandoned locks (default: 5 minutes)
ENGINE_MANAGER_CLEANUP_INTERVAL = timedelta(minutes=5)

# Automatically start cleanup thread for SINGLETON mode (default: True)
ENGINE_MANAGER_AUTO_START_CLEANUP = True

# ---------------------------------------------------------

#
# The EncryptedFieldTypeAdapter is used whenever we're building SqlAlchemy models
# which include sensitive fields that should be app-encrypted BEFORE sending

@@ -809,7 +830,6 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# FIREWALL (only port 22 is open)

# ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
#: Timeout (seconds) for tunnel connection (open_channel timeout)
SSH_TUNNEL_TIMEOUT_SEC = 10.0

@@ -1684,7 +1704,7 @@ def engine_context_manager(  # pylint: disable=unused-argument
    yield None


ENGINE_CONTEXT_MANAGER = engine_context_manager
ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager

# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or

@@ -1701,7 +1721,7 @@ ENGINE_CONTEXT_MANAGER = engine_context_manager
#
# Note that the returned uri and params are passed directly to SQLAlchemy,
# as in `create_engine(url, **params)`
DB_CONNECTION_MUTATOR = None
DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None


# A callable that is invoked for every invocation of DB Engine Specs
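
A minimal sketch of custom implementations for both hooks, matching the new `EngineContextManager` and `DBConnectionMutator` type aliases; the timing and `application_name` behavior here is illustrative, not part of Superset:

```python
import logging
import time
from contextlib import contextmanager
from typing import Any, Iterator

from sqlalchemy.engine.url import URL

logger = logging.getLogger(__name__)


@contextmanager
def timed_engine_context_manager(
    database: Any, catalog: str | None, schema: str | None
) -> Iterator[None]:
    # called around every engine checkout (same shape as the default above)
    start = time.monotonic()
    try:
        yield
    finally:
        logger.info("engine for %s held %.2fs", database, time.monotonic() - start)


def tag_connection_mutator(
    uri: URL,
    params: dict[str, Any],
    username: str | None,
    security_manager: Any,
    source: Any,
) -> tuple[URL, dict[str, Any]]:
    # hypothetical policy: tag outgoing connections with the effective user
    return uri.update_query_dict({"application_name": username or "superset"}), params


ENGINE_CONTEXT_MANAGER = timed_engine_context_manager
DB_CONNECTION_MUTATOR = tag_connection_mutator
```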
@@ -14,23 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import Mock

import sshtunnel

from superset.extensions.ssh import SSHManagerFactory


def test_ssh_tunnel_timeout_setting() -> None:
    app = Mock()
    app.config = {
        "SSH_TUNNEL_MAX_RETRIES": 2,
        "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
        "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
        "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
        "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
    }
    factory = SSHManagerFactory()
    factory.init_app(app)
    assert sshtunnel.TUNNEL_TIMEOUT == 123.0
    assert sshtunnel.SSH_TIMEOUT == 321.0
superset/engines/manager.py (new file, 630 lines)

@@ -0,0 +1,630 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import enum
import hashlib
import logging
import threading
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import Any, Iterator, TYPE_CHECKING

import sshtunnel
from paramiko import RSAKey
from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sshtunnel import SSHTunnelForwarder

from superset.databases.utils import make_url_safe
from superset.superset_typing import DBConnectionMutator, EngineContextManager
from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource

if TYPE_CHECKING:
    from superset.databases.ssh_tunnel.models import SSHTunnel
    from superset.models.core import Database


logger = logging.getLogger(__name__)


class _LockManager:
    """
    Manages per-key locks safely without defaultdict race conditions.

    This class provides a thread-safe way to create and manage locks for specific keys,
    avoiding the race conditions that occur when using defaultdict with threading.Lock.

    The implementation uses a two-level locking strategy:
    1. A meta-lock to protect the lock dictionary itself
    2. Per-key locks to protect specific resources

    This ensures that:
    - Different keys can be locked concurrently (scalability)
    - Lock creation is thread-safe (no race conditions)
    - The same key always gets the same lock instance
    """

    def __init__(self) -> None:
        self._locks: dict[str, threading.RLock] = {}
        self._meta_lock = threading.Lock()

    def get_lock(self, key: str) -> threading.RLock:
        """
        Get or create a lock for the given key.

        This method uses double-checked locking to ensure thread safety:
        1. First check without lock (fast path)
        2. Acquire meta-lock if needed
        3. Double-check inside the lock to prevent race conditions

        This approach minimizes lock contention while ensuring correctness.

        :param key: The key to get a lock for
        :returns: An RLock instance for the given key
        """
        if lock := self._locks.get(key):
            return lock

        with self._meta_lock:
            # Double-check inside the lock
            lock = self._locks.get(key)
            if lock is None:
                lock = threading.RLock()
                self._locks[key] = lock
            return lock

    def cleanup(self, active_keys: set[str]) -> None:
        """
        Remove locks for keys that are no longer in use.

        This prevents memory leaks from accumulating locks for resources
        that have been disposed.

        :param active_keys: Set of keys that are still active
        """
        with self._meta_lock:
            # Find locks to remove
            locks_to_remove = self._locks.keys() - active_keys
            for key in locks_to_remove:
                self._locks.pop(key, None)


EngineKey = str
TunnelKey = str


def _generate_cache_key(*args: Any) -> str:
    """
    Generate a deterministic cache key from arbitrary arguments.

    Uses repr() for serialization and SHA-256 for hashing. The resulting key
    is a 32-character hex string that:
    1. Is deterministic for the same inputs
    2. Does not expose sensitive data (everything is hashed)
    3. Has sufficient entropy to avoid collisions

    :param args: Arguments to include in the cache key
    :returns: 32-character hex string
    """
    # Use repr() which works with most Python objects and is deterministic
    serialized = repr(args).encode("utf-8")
    return hashlib.sha256(serialized).hexdigest()[:32]


class EngineModes(enum.Enum):
    # reuse existing engine if available, otherwise create a new one; this mode should
    # have a connection pool configured in the database
    SINGLETON = enum.auto()

    # always create a new engine for every connection; this mode will use a NullPool
    # and is the default behavior for Superset
    NEW = enum.auto()


class EngineManager:
    """
    A manager for SQLAlchemy engines.

    This class handles the creation and management of SQLAlchemy engines, allowing them
    to be configured with connection pools and reused across requests. The default mode
    is the original behavior for Superset, where we create a new engine for every
    connection, using a NullPool. The `SINGLETON` mode, on the other hand, allows
    engines to be reused, and the pool to be configured through the database settings.
    """

    def __init__(
        self,
        engine_context_manager: EngineContextManager,
        db_connection_mutator: DBConnectionMutator | None = None,
        mode: EngineModes = EngineModes.NEW,
        cleanup_interval: timedelta = timedelta(minutes=5),
        local_bind_address: str = "127.0.0.1",
        tunnel_timeout: timedelta = timedelta(seconds=30),
        ssh_timeout: timedelta = timedelta(seconds=1),
    ) -> None:
        self.engine_context_manager = engine_context_manager
        self.db_connection_mutator = db_connection_mutator
        self.mode = mode
        self.cleanup_interval = cleanup_interval
        self.local_bind_address = local_bind_address

        sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
        sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()

        self._engines: dict[EngineKey, Engine] = {}
        self._engine_locks = _LockManager()
        self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
        self._tunnel_locks = _LockManager()

        # Background cleanup thread management
        self._cleanup_thread: threading.Thread | None = None
        self._cleanup_stop_event = threading.Event()
        self._cleanup_thread_lock = threading.Lock()

    def __del__(self) -> None:
        """
        Ensure cleanup thread is stopped when the manager is destroyed.
        """
        try:
            self.stop_cleanup_thread()
        except Exception as ex:
            # Avoid exceptions during garbage collection, but log if possible
            try:
                logger.warning("Error stopping cleanup thread: %s", ex)
            except Exception:  # noqa: S110
                # If logging fails during destruction, we can't do anything
                pass

    @contextmanager
    def get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Iterator[Engine]:
        """
        Context manager to get a SQLAlchemy engine.
        """
        # users can wrap the engine in their own context manager for different
        # reasons
        with self.engine_context_manager(database, catalog, schema):
            # we need to check for errors indicating that OAuth2 is needed, and
            # return the proper exception so it starts the authentication flow
            from superset.utils.oauth2 import check_for_oauth2

            with check_for_oauth2(database):
                yield self._get_engine(database, catalog, schema, source)

    def _get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Engine:
        """
        Get a specific engine, or create it if none exists.
        """
        source = source or get_query_source_from_request()
        user_id = get_user_id()

        if self.mode == EngineModes.NEW:
            return self._create_engine(
                database,
                catalog,
                schema,
                source,
                user_id,
            )

        engine_key = self._get_engine_key(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        if engine := self._engines.get(engine_key):
            return engine

        lock = self._engine_locks.get_lock(engine_key)
        with lock:
            # Double-check inside the lock
            if engine := self._engines.get(engine_key):
                return engine

            # Create and cache the engine
            engine = self._create_engine(
                database,
                catalog,
                schema,
                source,
                user_id,
            )
            self._engines[engine_key] = engine
            self._add_disposal_listener(engine, engine_key)
            return engine

    def _get_engine_key(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> EngineKey:
        """
        Generate a cache key for the engine.

        The key is a hash of all parameters that affect the engine, ensuring
        proper cache isolation without exposing sensitive data.

        :returns: 32-character hex string
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        return _generate_cache_key(
            database.id,
            catalog,
            schema,
            str(uri),
            source,
            user_id,
            kwargs,
        )

    def _get_engine_args(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> tuple[URL, dict[str, Any]]:
        """
        Build the almost final SQLAlchemy URI and engine kwargs.

        "Almost" final because we may still need to mutate the URI if an SSH tunnel is
        needed, since it needs to connect to the tunnel instead of the original DB. But
        that information is only available after the tunnel is created.
        """
        # Import here to avoid circular imports
        from superset.extensions import security_manager
        from superset.utils.feature_flag_manager import FeatureFlagManager

        uri = make_url_safe(database.sqlalchemy_uri_decrypted)

        extra = database.get_extra(source)
        # Make a copy to avoid mutating the original extra dict
        kwargs = dict(extra.get("engine_params", {}))

        # get pool class
        if self.mode == EngineModes.NEW or "poolclass" not in kwargs:
            kwargs["poolclass"] = pool.NullPool
        else:
            pools = {
                "queue": pool.QueuePool,
                "singleton": pool.SingletonThreadPool,
                "assertion": pool.AssertionPool,
                "null": pool.NullPool,
                "static": pool.StaticPool,
            }
            kwargs["poolclass"] = pools.get(extra["poolclass"], pool.QueuePool)

        # update URI for specific catalog/schema
        connect_args = dict(extra.get("connect_args", {}))
        uri, connect_args = database.db_engine_spec.adjust_engine_params(
            uri,
            connect_args,
            catalog,
            schema,
        )

        # get effective username
        username = database.get_effective_user(uri)

        feature_flag_manager = FeatureFlagManager()
        if username and feature_flag_manager.is_feature_enabled(
            "IMPERSONATE_WITH_EMAIL_PREFIX"
        ):
            user = security_manager.find_user(username=username)
            if user and user.email and "@" in user.email:
                username = user.email.split("@")[0]

        # update URI/kwargs for user impersonation
        if database.impersonate_user:
            oauth2_config = database.get_oauth2_config()
            # Import here to avoid circular imports
            from superset.utils.oauth2 import get_oauth2_access_token

            access_token = (
                get_oauth2_access_token(
                    oauth2_config,
                    database.id,
                    user_id,
                    database.db_engine_spec,
                )
                if oauth2_config and user_id
                else None
            )

            uri, kwargs = database.db_engine_spec.impersonate_user(
                database,
                username,
                access_token,
                uri,
                kwargs,
            )

        # update kwargs from params stored encrypted at rest
        database.update_params_from_encrypted_extra(kwargs)

        # mutate URI
        if self.db_connection_mutator:
            source = source or get_query_source_from_request()
            # Import here to avoid circular imports
            from superset.extensions import security_manager

            uri, kwargs = self.db_connection_mutator(
                uri,
                kwargs,
                username,
                security_manager,
                source,
            )

        # validate final URI
        database.db_engine_spec.validate_database_uri(uri)

        return uri, kwargs

    def _create_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> Engine:
        """
        Create the actual engine.

        This should be the only place in Superset where a SQLAlchemy engine is created.
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        if database.ssh_tunnel:
            tunnel = self._get_tunnel(database.ssh_tunnel, uri)
            uri = uri.set(
                host=tunnel.local_bind_address[0],
                port=tunnel.local_bind_port,
            )

        try:
            engine = create_engine(uri, **kwargs)
        except Exception as ex:
            raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

        return engine

    def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
        tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)

        tunnel = self._tunnels.get(tunnel_key)
        if tunnel is not None and tunnel.is_active:
            return tunnel

        lock = self._tunnel_locks.get_lock(tunnel_key)
        with lock:
            # Double-check inside the lock
            tunnel = self._tunnels.get(tunnel_key)
            if tunnel is not None and tunnel.is_active:
                return tunnel

            # Create or replace tunnel
            return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel)

    def _replace_tunnel(
        self,
        tunnel_key: str,
        ssh_tunnel: "SSHTunnel",
        uri: URL,
        old_tunnel: SSHTunnelForwarder | None,
    ) -> SSHTunnelForwarder:
        """
        Replace tunnel with proper cleanup.

        This function assumes the caller holds the lock.
        """
        if old_tunnel:
            try:
                old_tunnel.stop()
            except Exception:
                logger.exception("Error stopping old tunnel")

        try:
            new_tunnel = self._create_tunnel(ssh_tunnel, uri)
            self._tunnels[tunnel_key] = new_tunnel
        except Exception:
            # Remove failed tunnel from cache
            self._tunnels.pop(tunnel_key, None)
            logger.exception("Failed to create tunnel")
            raise

        return new_tunnel

    def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
        """
        Generate a cache key for the SSH tunnel.

        :returns: 32-character hex string
        """
        tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
        return _generate_cache_key(tunnel_kwargs)

    def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
        kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
        # Use open_tunnel which handles debug_level properly
        tunnel = sshtunnel.open_tunnel(**kwargs)
        tunnel.start()

        return tunnel

    def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
        # Import here to avoid circular imports
        from superset.utils.ssh_tunnel import get_default_port

        backend = uri.get_backend_name()
        kwargs = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (uri.host, uri.port or get_default_port(backend)),
            "local_bind_address": (self.local_bind_address,),
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }

        if ssh_tunnel.password:
            kwargs["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file,
                ssh_tunnel.private_key_password,
            )
            kwargs["ssh_pkey"] = private_key

        if self.mode == EngineModes.NEW:
            kwargs["set_keepalive"] = 0  # disable keepalive for one-time tunnels

        return kwargs

    def start_cleanup_thread(self) -> None:
        """
        Start the background cleanup thread.

        The thread will periodically clean up abandoned locks at the configured
        interval. This is safe to call multiple times - subsequent calls are no-ops.
        """
        with self._cleanup_thread_lock:
            if self._cleanup_thread is None or not self._cleanup_thread.is_alive():
                self._cleanup_stop_event.clear()
                self._cleanup_thread = threading.Thread(
                    target=self._cleanup_worker,
                    name=f"EngineManager-cleanup-{id(self)}",
                    daemon=True,
                )
                self._cleanup_thread.start()
                logger.info(
                    "Started cleanup thread with %ds interval",
                    self.cleanup_interval.total_seconds(),
                )

    def stop_cleanup_thread(self) -> None:
        """
        Stop the background cleanup thread gracefully.

        This will signal the thread to stop and wait for it to finish.
        Safe to call even if no thread is running.
        """
        with self._cleanup_thread_lock:
            if self._cleanup_thread is not None and self._cleanup_thread.is_alive():
                self._cleanup_stop_event.set()
                self._cleanup_thread.join(timeout=5.0)  # 5 second timeout
                if self._cleanup_thread.is_alive():
                    logger.warning("Cleanup thread did not stop within timeout")
                else:
                    logger.info("Cleanup thread stopped")
            self._cleanup_thread = None

    def _cleanup_worker(self) -> None:
        """
        Background thread worker that periodically cleans up abandoned locks.
        """
        while not self._cleanup_stop_event.is_set():
            try:
                self._cleanup_abandoned_locks()
            except Exception:
                logger.exception("Error during background cleanup")

            # Use wait() instead of sleep() to allow for immediate shutdown
            if self._cleanup_stop_event.wait(
                timeout=self.cleanup_interval.total_seconds()
            ):
                break  # Stop event was set

    def cleanup(self) -> None:
        """
        Public method to manually trigger cleanup of abandoned locks.

        This can be called periodically by external systems to prevent
        memory leaks from accumulating locks.
        """
        self._cleanup_abandoned_locks()

    def _cleanup_abandoned_locks(self) -> None:
        """
        Clean up locks for engines and tunnels that no longer exist.

        This prevents memory leaks from accumulating locks when engines/tunnels
        are disposed outside of normal cleanup paths.
        """
        # Clean up engine locks for inactive engines
        active_engine_keys = set(self._engines.keys())
        self._engine_locks.cleanup(active_engine_keys)

        # Clean up tunnel locks for inactive tunnels
        active_tunnel_keys = set(self._tunnels.keys())
        self._tunnel_locks.cleanup(active_tunnel_keys)

        # Log for debugging
        if active_engine_keys or active_tunnel_keys:
            logger.debug(
                "EngineManager resources - Engines: %d, Tunnels: %d",
                len(active_engine_keys),
                len(active_tunnel_keys),
            )

    def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None:
        @event.listens_for(engine, "engine_disposed")
        def on_engine_disposed(engine_instance: Engine) -> None:
            try:
                # Remove engine from cache - no per-key locks to clean up anymore
                if self._engines.pop(engine_key, None):
                    # Log only first 8 chars of hash for safety
                    # (still enough for debugging, but doesn't expose full key)
                    log_key = engine_key[:8] + "..."
                    logger.info("Engine disposed and removed from cache: %s", log_key)
            except Exception as ex:
                logger.error("Error during engine disposal cleanup: %s", str(ex))
                # Don't log engine_key to avoid exposing credential hash
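
A standalone usage sketch of the manager above, assuming `database` is a `Database` model instance (inside Superset the shared instance comes from the `engine_manager_extension` shown below, not from direct construction):

```python
from contextlib import contextmanager
from typing import Iterator

from sqlalchemy import text

from superset.engines.manager import EngineManager, EngineModes


@contextmanager
def noop_engine_context_manager(database, catalog, schema) -> Iterator[None]:
    yield  # same shape as the default in superset/config.py


# SINGLETON mode caches engines keyed by database/catalog/schema/source/user
manager = EngineManager(noop_engine_context_manager, mode=EngineModes.SINGLETON)

with manager.get_engine(database, catalog=None, schema="public", source=None) as engine:
    with engine.connect() as conn:
        conn.execute(text("SELECT 1"))
```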
superset/extensions/__init__.py

@@ -41,7 +41,7 @@ from werkzeug.local import LocalProxy

from superset.async_events.async_query_manager import AsyncQueryManager
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
from superset.extensions.ssh import SSHManagerFactory
from superset.extensions.engine_manager import EngineManagerExtension
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager

@@ -136,6 +136,7 @@ cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = get_sqla_class()()
engine_manager_extension = EngineManagerExtension()
_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))

@@ -146,6 +147,5 @@ migrate = Migrate()
profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
ssh_manager_factory = SSHManagerFactory()
stats_logger_manager = BaseStatsLoggerManager()
talisman = Talisman()
superset/extensions/engine_manager.py (new file, 100 lines)

@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from datetime import timedelta

from flask import Flask

from superset.engines.manager import EngineManager, EngineModes

logger = logging.getLogger(__name__)


class EngineManagerExtension:
    """
    Flask extension for managing SQLAlchemy engines in Superset.

    This extension creates and configures an EngineManager instance based on
    Flask configuration, handling startup and shutdown of background cleanup
    threads as needed.
    """

    def __init__(self) -> None:
        self.engine_manager: EngineManager | None = None

    def init_app(self, app: Flask) -> None:
        """
        Initialize the EngineManager with Flask app configuration.
        """
        engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
        db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
        mode = app.config["ENGINE_MANAGER_MODE"]
        cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"]
        local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
        ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
        auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"]

        # Create the engine manager
        self.engine_manager = EngineManager(
            engine_context_manager,
            db_connection_mutator,
            mode,
            cleanup_interval,
            local_bind_address,
            tunnel_timeout,
            ssh_timeout,
        )

        # Start cleanup thread if requested and in SINGLETON mode
        if auto_start_cleanup and mode == EngineModes.SINGLETON:
            self.engine_manager.start_cleanup_thread()
            logger.info("Started EngineManager cleanup thread")

        # Register shutdown handler
        def shutdown_engine_manager() -> None:
            if self.engine_manager:
                self.engine_manager.stop_cleanup_thread()

        app.teardown_appcontext_funcs.append(lambda exc: None)

        # Register with atexit for clean shutdown
        import atexit

        atexit.register(shutdown_engine_manager)

        logger.info(
            "Initialized EngineManager with mode=%s, cleanup_interval=%ds",
            mode,
            cleanup_interval.total_seconds(),
        )

    @property
    def manager(self) -> EngineManager:
        """
        Get the EngineManager instance.

        Raises:
            RuntimeError: If the extension hasn't been initialized with an app.
        """
        if self.engine_manager is None:
            raise RuntimeError(
                "EngineManager extension not initialized. "
                "Call init_app() with a Flask app first."
            )
        return self.engine_manager
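
A short sketch of the corresponding access pattern from application code; it mirrors the call site added to `superset/models/core.py` below, with `database` assumed to be a `Database` model instance:

```python
from superset.extensions import engine_manager_extension

engine_manager = engine_manager_extension.manager  # RuntimeError before init_app()
with engine_manager.get_engine(database, catalog=None, schema=None, source=None) as engine:
    ...
```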
superset/extensions/ssh.py (deleted file, 94 lines)

@@ -1,94 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import logging
from io import StringIO
from typing import TYPE_CHECKING

import sshtunnel
from flask import Flask
from paramiko import RSAKey

from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.databases.utils import make_url_safe
from superset.utils.class_utils import load_class_from_name

if TYPE_CHECKING:
    from superset.databases.ssh_tunnel.models import SSHTunnel


class SSHManager:
    def __init__(self, app: Flask) -> None:
        super().__init__()
        self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
        sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]

    def build_sqla_url(
        self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
    ) -> str:
        # override any ssh tunnel configuration object
        url = make_url_safe(sqlalchemy_url)
        return url.set(
            host=server.local_bind_address[0],
            port=server.local_bind_port,
        )

    def create_tunnel(
        self,
        ssh_tunnel: "SSHTunnel",
        sqlalchemy_database_uri: str,
    ) -> sshtunnel.SSHTunnelForwarder:
        from superset.utils.ssh_tunnel import get_default_port

        url = make_url_safe(sqlalchemy_database_uri)
        backend = url.get_backend_name()
        port = url.port or get_default_port(backend)
        if not port:
            raise SSHTunnelDatabasePortError()
        params = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (url.host, port),
            "local_bind_address": (self.local_bind_address,),
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }

        if ssh_tunnel.password:
            params["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file, ssh_tunnel.private_key_password
            )
            params["ssh_pkey"] = private_key

        return sshtunnel.open_tunnel(**params)


class SSHManagerFactory:
    def __init__(self) -> None:
        self._ssh_manager = None

    def init_app(self, app: Flask) -> None:
        self._ssh_manager = load_class_from_name(
            app.config["SSH_TUNNEL_MANAGER_CLASS"]
        )(app)

    @property
    def instance(self) -> SSHManager:
        return self._ssh_manager  # type: ignore
@@ -49,13 +49,13 @@ from superset.extensions import (
    csrf,
    db,
    encrypted_field_factory,
    engine_manager_extension,
    feature_flag_manager,
    machine_auth_provider_factory,
    manifest_processor,
    migrate,
    profiling,
    results_backend_manager,
    ssh_manager_factory,
    stats_logger_manager,
    talisman,
)

@@ -585,8 +585,8 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
        self.configure_url_map_converters()
        self.configure_data_sources()
        self.configure_auth_provider()
        self.configure_engine_manager()
        self.configure_async_queries()
        self.configure_ssh_manager()
        self.configure_stats_manager()

        # Hook that provides administrators a handle on the Flask APP

@@ -761,8 +761,8 @@ class SupersetAppInitializer:  # pylint: disable=too-many-public-methods
    def configure_auth_provider(self) -> None:
        machine_auth_provider_factory.init_app(self.superset_app)

    def configure_ssh_manager(self) -> None:
        ssh_manager_factory.init_app(self.superset_app)
    def configure_engine_manager(self) -> None:
        engine_manager_extension.init_app(self.superset_app)

    def configure_stats_manager(self) -> None:
        stats_logger_manager.init_app(self.superset_app)
superset/models/core.py

@@ -25,24 +25,22 @@ import builtins
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext, suppress
from contextlib import closing, contextmanager, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, Optional, TYPE_CHECKING
from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING

import numpy
import pandas as pd
import sqlalchemy as sqla
import sshtunnel
from flask import current_app as app, g, has_app_context
from flask_appbuilder import Model
from marshmallow.exceptions import ValidationError
from sqlalchemy import (
    Boolean,
    Column,
    create_engine,
    DateTime,
    ForeignKey,
    Integer,

@@ -57,7 +55,6 @@ from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select
from superset_core.api.models import Database as CoreDatabase

@@ -72,7 +69,6 @@ from superset.extensions import (
    encrypted_field_factory,
    event_logger,
    security_manager,
    ssh_manager_factory,
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
from superset.result_set import SupersetResultSet

@@ -84,10 +80,9 @@ from superset.superset_typing import (
)
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_query_source_from_request, get_username
from superset.utils.core import get_username
from superset.utils.oauth2 import (
    check_for_oauth2,
    get_oauth2_access_token,
    OAuth2ClientConfigSchema,
)

@@ -424,130 +419,31 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
    )

    @contextmanager
    def get_sqla_engine(  # pylint: disable=too-many-arguments
    def get_sqla_engine(
        self,
        catalog: str | None = None,
        schema: str | None = None,
        nullpool: bool = True,
        source: utils.QuerySource | None = None,
    ) -> Engine:
    ) -> Iterator[Engine]:
        """
        Context manager for a SQLAlchemy engine.

        This method will return a context manager for a SQLAlchemy engine. Using the
        context manager (as opposed to the engine directly) is important because we need
        to potentially establish SSH tunnels before the connection is created, and clean
        them up once the engine is no longer used.
        This method will return a context manager for a SQLAlchemy engine. The engine
        manager handles connection pooling, SSH tunnels, and other connection details
        based on the configured mode (NEW or SINGLETON).
        """
        # Import here to avoid circular imports
        from superset.extensions import engine_manager_extension

        sqlalchemy_uri = self.sqlalchemy_uri_decrypted

        ssh_context_manager = (
            ssh_manager_factory.instance.create_tunnel(
                ssh_tunnel=self.ssh_tunnel,
                sqlalchemy_database_uri=sqlalchemy_uri,
            )
            if self.ssh_tunnel
            else nullcontext()
        )

        with ssh_context_manager as ssh_context:
            if ssh_context:
                logger.info(
                    "[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
                    "ssh_timeout at %s",
                    sshtunnel.TUNNEL_TIMEOUT,
                    sshtunnel.SSH_TIMEOUT,
                    ssh_context.local_bind_address,
                )
                sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
                    sqlalchemy_uri,
                    ssh_context,
                )

            engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
            with engine_context_manager(self, catalog, schema):
                with check_for_oauth2(self):
                    yield self._get_sqla_engine(
                        catalog=catalog,
                        schema=schema,
                        nullpool=nullpool,
                        source=source,
                        sqlalchemy_uri=sqlalchemy_uri,
                    )

    def _get_sqla_engine(  # pylint: disable=too-many-locals  # noqa: C901
        self,
        catalog: str | None = None,
        schema: str | None = None,
        nullpool: bool = True,
        source: utils.QuerySource | None = None,
        sqlalchemy_uri: str | None = None,
    ) -> Engine:
        sqlalchemy_url = make_url_safe(
            sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
        )
        self.db_engine_spec.validate_database_uri(sqlalchemy_url)

        extra = self.get_extra(source)
        engine_kwargs = extra.get("engine_params", {})
        if nullpool:
            engine_kwargs["poolclass"] = NullPool
        connect_args = engine_kwargs.setdefault("connect_args", {})

        # modify URL/args for a specific catalog/schema
        sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
            uri=sqlalchemy_url,
            connect_args=connect_args,
        # Use the engine manager to get the engine
        engine_manager = engine_manager_extension.manager
        with engine_manager.get_engine(
            database=self,
            catalog=catalog,
            schema=schema,
        )

        effective_username = self.get_effective_user(sqlalchemy_url)
        if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
            user = security_manager.find_user(username=effective_username)
            if user and user.email:
                effective_username = user.email.split("@")[0]

        oauth2_config = self.get_oauth2_config()
        access_token = (
            get_oauth2_access_token(
                oauth2_config,
                self.id,
                g.user.id,
                self.db_engine_spec,
            )
            if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
            else None
        )
        masked_url = self.get_password_masked_url(sqlalchemy_url)
        logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))

        if self.impersonate_user:
            sqlalchemy_url, engine_kwargs = self.db_engine_spec.impersonate_user(
                self,
                effective_username,
                access_token,
                sqlalchemy_url,
                engine_kwargs,
            )

        self.update_params_from_encrypted_extra(engine_kwargs)

        if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]:  # noqa: N806
            source = source or get_query_source_from_request()

            sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR(
                sqlalchemy_url,
                engine_kwargs,
                effective_username,
                security_manager,
                source,
            )
        try:
            return create_engine(sqlalchemy_url, **engine_kwargs)
        except Exception as ex:
            raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
            source=source,
        ) as engine:
            yield engine

    def add_database_to_signature(
        self,

@@ -572,13 +468,11 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
        self,
        catalog: str | None = None,
        schema: str | None = None,
        nullpool: bool = True,
        source: utils.QuerySource | None = None,
    ) -> Connection:
        with self.get_sqla_engine(
            catalog=catalog,
            schema=schema,
            nullpool=nullpool,
            source=source,
        ) as engine:
            with check_for_oauth2(self):
superset/superset_typing.py

@@ -18,17 +18,42 @@ from __future__ import annotations

from collections.abc import Hashable, Sequence
from datetime import datetime
from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict
from typing import (
    Any,
    Callable,
    ContextManager,
    Literal,
    TYPE_CHECKING,
    TypeAlias,
    TypedDict,
)

from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import NotRequired
from werkzeug.wrappers import Response

if TYPE_CHECKING:
    from superset.utils.core import GenericDataType, QueryObjectFilterClause
    from superset.models.core import Database
    from superset.utils.core import (
        GenericDataType,
        QueryObjectFilterClause,
        QuerySource,
    )

SQLType: TypeAlias = TypeEngine | type[TypeEngine]

# Type alias for database connection mutator function
DBConnectionMutator: TypeAlias = Callable[
    [URL, dict[str, Any], str | None, Any, "QuerySource | None"],
    tuple[URL, dict[str, Any]],
]

# Type alias for engine context manager
EngineContextManager: TypeAlias = Callable[
    ["Database", str | None, str | None], ContextManager[None]
]


class LegacyMetric(TypedDict):
    label: str | None
@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
            return self._db

        def _load_lazy_data_to_decouple_from_session(self) -> None:
            self._db._get_sqla_engine()  # type: ignore
            self._db.backend  # type: ignore  # noqa: B018

        def remove(self) -> None:
@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):


class TestTestConnectionDatabaseCommand(SupersetTestCase):
    @patch("superset.models.core.Database._get_sqla_engine")
    @patch("superset.models.core.Database.get_sqla_engine")
    @patch("superset.commands.database.test_connection.event_logger.log_with_context")
    @patch("superset.utils.core.g")
    def test_connection_db_exception(

@@ -906,19 +906,23 @@
        """Test to make sure event_logger is called when an exception is raised"""
        database = get_example_database()
        mock_g.user = security_manager.find_user("admin")
        mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
        mock_get_sqla_engine.return_value.__enter__.side_effect = Exception(
            "An error has occurred!"
        )
        db_uri = database.sqlalchemy_uri_decrypted
        json_payload = {"sqlalchemy_uri": db_uri}
        command_without_db_name = TestConnectionDatabaseCommand(json_payload)

        with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:  # noqa: PT012
        with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
            command_without_db_name.run()
        assert str(excinfo.value) == (
            "Unexpected error occurred, please check your logs for details"
        )
        # Exception wraps errors from db_engine_spec.extract_errors()
        assert (
            excinfo.value.errors[0].error_type
            == SupersetErrorType.GENERIC_DB_ENGINE_ERROR
        )
        mock_event_logger.assert_called()

    @patch("superset.models.core.Database._get_sqla_engine")
    @patch("superset.models.core.Database.get_sqla_engine")
    @patch("superset.commands.database.test_connection.event_logger.log_with_context")
    @patch("superset.utils.core.g")
    def test_connection_do_ping_exception(

@@ -927,9 +931,8 @@
        """Test to make sure do_ping exceptions get captured"""
        database = get_example_database()
        mock_g.user = security_manager.find_user("admin")
        mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
            "An error has occurred!"
        )
        mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value
        mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!")
        db_uri = database.sqlalchemy_uri_decrypted
        json_payload = {"sqlalchemy_uri": db_uri}
        command_without_db_name = TestConnectionDatabaseCommand(json_payload)

@@ -967,7 +970,7 @@
            == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
        )

    @patch("superset.models.core.Database._get_sqla_engine")
    @patch("superset.models.core.Database.get_sqla_engine")
    @patch("superset.commands.database.test_connection.event_logger.log_with_context")
    @patch("superset.utils.core.g")
    def test_connection_superset_security_connection(

@@ -977,20 +980,20 @@
        connection exc is raised"""
        database = get_example_database()
        mock_g.user = security_manager.find_user("admin")
        mock_get_sqla_engine.side_effect = SupersetSecurityException(
            SupersetError(error_type=500, message="test", level="info")
        mock_get_sqla_engine.return_value.__enter__.side_effect = (
            SupersetSecurityException(
                SupersetError(error_type=500, message="test", level="info")
            )
        )
        db_uri = database.sqlalchemy_uri_decrypted
        json_payload = {"sqlalchemy_uri": db_uri}
        command_without_db_name = TestConnectionDatabaseCommand(json_payload)

        with pytest.raises(DatabaseSecurityUnsafeError) as excinfo:  # noqa: PT012
        with pytest.raises(DatabaseSecurityUnsafeError):
            command_without_db_name.run()
        assert str(excinfo.value) == ("Stopped an unsafe database connection")

        mock_event_logger.assert_called()

    @patch("superset.models.core.Database._get_sqla_engine")
    @patch("superset.models.core.Database.get_sqla_engine")
    @patch("superset.commands.database.test_connection.event_logger.log_with_context")
    @patch("superset.utils.core.g")
    def test_connection_db_api_exc(

@@ -999,19 +1002,20 @@
        """Test to make sure event_logger is called when DBAPIError is raised"""
        database = get_example_database()
        mock_g.user = security_manager.find_user("admin")
        mock_get_sqla_engine.side_effect = DBAPIError(
        mock_get_sqla_engine.return_value.__enter__.side_effect = DBAPIError(
            statement="error", params={}, orig={}
        )
        db_uri = database.sqlalchemy_uri_decrypted
        json_payload = {"sqlalchemy_uri": db_uri}
        command_without_db_name = TestConnectionDatabaseCommand(json_payload)

        with pytest.raises(SupersetErrorsException) as excinfo:  # noqa: PT012
        with pytest.raises(SupersetErrorsException) as excinfo:
            command_without_db_name.run()
        assert str(excinfo.value) == (
            "Connection failed, please check your connection settings"
        )

        # Exception wraps errors from db_engine_spec.extract_errors()
        assert (
            excinfo.value.errors[0].error_type
            == SupersetErrorType.GENERIC_DB_ENGINE_ERROR
        )
        mock_event_logger.assert_called()

@@ -1147,7 +1151,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):

        with pytest.raises(DatabaseNotFoundError) as excinfo:  # noqa: PT012
            command.run()
            assert str(excinfo.value) == ("Database not found.")
        assert str(excinfo.value) == ("Database not found.")

    @patch("superset.daos.database.DatabaseDAO.find_by_id")
    @patch("superset.security.manager.SupersetSecurityManager.can_access_database")

@@ -1166,26 +1170,35 @@
        command = TablesDatabaseCommand(database.id, None, "main", False)
        with pytest.raises(SupersetException) as excinfo:  # noqa: PT012
            command.run()
            assert str(excinfo.value) == "Test Error"
        assert str(excinfo.value) == "Test Error"

    @patch("superset.daos.database.DatabaseDAO.find_by_id")
    @patch("superset.models.core.Database.get_all_materialized_view_names_in_schema")
    @patch("superset.models.core.Database.get_all_view_names_in_schema")
    @patch("superset.models.core.Database.get_all_table_names_in_schema")
    @patch("superset.security.manager.SupersetSecurityManager.can_access_database")
    @patch("superset.utils.core.g")
    def test_database_tables_exception(
        self, mock_g, mock_can_access_database, mock_find_by_id
        self,
        mock_g,
        mock_can_access_database,
        mock_get_tables,
        mock_get_views,
        mock_get_mvs,
        mock_find_by_id,
    ):
        database = get_example_database()
        mock_find_by_id.return_value = database
        mock_get_tables.return_value = {("table1", "main", None)}
        mock_get_views.return_value = set()
        mock_get_mvs.return_value = []
        mock_can_access_database.side_effect = Exception("Test Error")
        mock_g.user = security_manager.find_user("admin")

        command = TablesDatabaseCommand(database.id, None, "main", False)
        with pytest.raises(DatabaseTablesUnexpectedError) as excinfo:  # noqa: PT012
            command.run()
        assert (
            str(excinfo.value)
            == "Unexpected error occurred, please check your logs for details"
        )
        assert str(excinfo.value) == "Test Error"

    @patch("superset.daos.database.DatabaseDAO.find_by_id")
    @patch("superset.security.manager.SupersetSecurityManager.can_access_database")
|
||||
|
||||
@@ -145,7 +145,7 @@ class TestDatabaseModel(SupersetTestCase):
         username = make_url(engine.url).username
         assert example_user.username != username

-    @mock.patch("superset.models.core.create_engine")
+    @mock.patch("superset.engines.manager.create_engine")
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
     )
@@ -172,7 +172,8 @@ class TestDatabaseModel(SupersetTestCase):
             database_name="test_database", sqlalchemy_uri=uri, extra=extra
         )
         model.impersonate_user = True
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "presto://gamma@localhost/"
@@ -185,7 +186,8 @@ class TestDatabaseModel(SupersetTestCase):
         }

         model.impersonate_user = False
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "presto://localhost/"
@@ -199,13 +201,14 @@ class TestDatabaseModel(SupersetTestCase):
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("mysqlclient"), "mysqlclient not installed"
     )
-    @mock.patch("superset.models.core.create_engine")
+    @mock.patch("superset.engines.manager.create_engine")
     def test_adjust_engine_params_mysql(self, mocked_create_engine):
         model = Database(
             database_name="test_database1",
             sqlalchemy_uri="mysql://user:password@localhost",
         )
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "mysql://user:password@localhost"
@@ -215,13 +218,14 @@ class TestDatabaseModel(SupersetTestCase):
             database_name="test_database2",
             sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
         )
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
         assert call_args[1]["connect_args"]["allow_local_infile"] == 0

-    @mock.patch("superset.models.core.create_engine")
+    @mock.patch("superset.engines.manager.create_engine")
     def test_impersonate_user_trino(self, mocked_create_engine):
         principal_user = security_manager.find_user(username="gamma")

@@ -230,7 +234,8 @@ class TestDatabaseModel(SupersetTestCase):
             database_name="test_database", sqlalchemy_uri="trino://localhost"
         )
         model.impersonate_user = True
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "trino://localhost/"
@@ -242,7 +247,8 @@ class TestDatabaseModel(SupersetTestCase):
         )

         model.impersonate_user = True
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert (
@@ -251,7 +257,7 @@ class TestDatabaseModel(SupersetTestCase):
         )
         assert call_args[1]["connect_args"]["user"] == "gamma"

-    @mock.patch("superset.models.core.create_engine")
+    @mock.patch("superset.engines.manager.create_engine")
     @unittest.skipUnless(
         SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
     )
@@ -281,7 +287,8 @@ class TestDatabaseModel(SupersetTestCase):
             database_name="test_database", sqlalchemy_uri=uri, extra=extra
         )
         model.impersonate_user = True
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "hive://localhost"
@@ -294,7 +301,8 @@ class TestDatabaseModel(SupersetTestCase):
         }

         model.impersonate_user = False
-        model._get_sqla_engine()
+        with model.get_sqla_engine():
+            pass
         call_args = mocked_create_engine.call_args

         assert str(call_args[0][0]) == "hive://localhost"
@@ -376,7 +384,7 @@ class TestDatabaseModel(SupersetTestCase):
         df = main_db.get_df("USE superset; SELECT ';';", None, None)
         assert df.iat[0, 0] == ";"

-    @mock.patch("superset.models.core.create_engine")
+    @mock.patch("superset.engines.manager.create_engine")
     def test_get_sqla_engine(self, mocked_create_engine):
         model = Database(
             database_name="test_database",
@@ -387,7 +395,8 @@ class TestDatabaseModel(SupersetTestCase):
         )
         mocked_create_engine.side_effect = Exception()
         with self.assertRaises(SupersetException):  # noqa: PT027
-            model._get_sqla_engine()
+            with model.get_sqla_engine():
+                pass


 class TestSqlaTableModel(SupersetTestCase):
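Every hunk above applies the same mechanical migration: the removed private `_get_sqla_engine()` call becomes the `get_sqla_engine()` context manager, which only yields a live engine inside the `with` block. A minimal stand-in sketch of the calling convention (the function below is illustrative only, not Superset's implementation; it assumes SQLAlchemy and an in-memory SQLite URL):

```python
from contextlib import contextmanager

from sqlalchemy import create_engine, text


@contextmanager
def get_sqla_engine(catalog=None, schema=None):
    # Stand-in for Database.get_sqla_engine(): setup (engine lookup,
    # SSH tunnels, etc.) happens on entry, cleanup on exit.
    engine = create_engine("sqlite://")
    try:
        yield engine
    finally:
        engine.dispose()


with get_sqla_engine() as engine, engine.connect() as conn:
    print(conn.execute(text("SELECT 1")).scalar())
```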
tests/unit_tests/engines/manager_test.py (new file, 527 lines)
@@ -0,0 +1,527 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Unit tests for EngineManager."""

import threading
from collections.abc import Iterator
from unittest.mock import MagicMock, patch

import pytest

from superset.engines.manager import _LockManager, EngineManager, EngineModes


class TestLockManager:
    """Test the _LockManager class."""

    def test_get_lock_creates_new_lock(self):
        """Test that get_lock creates a new lock when needed."""
        manager = _LockManager()
        lock1 = manager.get_lock("key1")

        assert isinstance(lock1, type(threading.RLock()))
        assert lock1 is manager.get_lock("key1")  # Same lock returned

    def test_get_lock_different_keys_different_locks(self):
        """Test that different keys get different locks."""
        manager = _LockManager()
        lock1 = manager.get_lock("key1")
        lock2 = manager.get_lock("key2")

        assert lock1 is not lock2

    def test_cleanup_removes_unused_locks(self):
        """Test that cleanup removes locks for inactive keys."""
        manager = _LockManager()

        # Create locks
        _ = manager.get_lock("key1")  # noqa: F841
        lock2 = manager.get_lock("key2")

        # Cleanup with only key1 active
        manager.cleanup({"key1"})

        # key2 lock should be removed
        lock3 = manager.get_lock("key2")
        assert lock3 is not lock2  # New lock created

    def test_concurrent_lock_creation(self):
        """Test that concurrent lock creation doesn't create duplicates."""
        manager = _LockManager()
        locks_created = []
        exceptions = []

        def create_lock():
            try:
                lock = manager.get_lock("concurrent_key")
                locks_created.append(lock)
            except Exception as e:
                exceptions.append(e)

        # Create multiple threads trying to get the same lock
        threads = [threading.Thread(target=create_lock) for _ in range(10)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        assert len(exceptions) == 0
        assert len(locks_created) == 10

        # All should be the same lock
        first_lock = locks_created[0]
        for lock in locks_created[1:]:
            assert lock is first_lock

class TestEngineManager:
    """Test the EngineManager class."""

    @pytest.fixture
    def engine_manager(self):
        """Create a mock EngineManager instance."""
        from contextlib import contextmanager

        @contextmanager
        def dummy_context_manager(
            database: MagicMock, catalog: str | None, schema: str | None
        ) -> Iterator[None]:
            yield

        return EngineManager(engine_context_manager=dummy_context_manager)

    @pytest.fixture
    def mock_database(self):
        """Create a mock database."""
        database = MagicMock()
        database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
        database.get_extra.return_value = {"engine_params": {}}
        database.get_effective_user.return_value = "test_user"
        database.impersonate_user = False
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
        database.db_engine_spec.impersonate_user = MagicMock(
            return_value=(MagicMock(), {})
        )
        database.db_engine_spec.validate_database_uri = MagicMock()
        database.ssh_tunnel = None
        return database

    @patch("superset.engines.manager.create_engine")
    @patch("superset.engines.manager.make_url_safe")
    def test_get_engine_new_mode(
        self, mock_make_url, mock_create_engine, engine_manager, mock_database
    ):
        """Test getting an engine in NEW mode (no caching)."""
        engine_manager.mode = EngineModes.NEW

        mock_make_url.return_value = MagicMock()
        mock_engine1 = MagicMock()
        mock_engine2 = MagicMock()
        mock_create_engine.side_effect = [mock_engine1, mock_engine2]

        result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)

        assert result is mock_engine1
        mock_create_engine.assert_called_once()

        # Calling again should create a new engine (no caching)
        mock_create_engine.reset_mock()
        result2 = engine_manager._get_engine(mock_database, "catalog2", "schema2", None)

        assert result2 is mock_engine2  # Different engine
        mock_create_engine.assert_called_once()

@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_singleton_mode_caching(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test that engines are cached in SINGLETON mode."""
|
||||
engine_manager.mode = EngineModes.SINGLETON
|
||||
|
||||
# Use a real engine instead of MagicMock to avoid event listener issues
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
|
||||
mock_create_engine.return_value = real_engine
|
||||
mock_make_url.return_value = real_engine
|
||||
|
||||
# Call twice with same params - should be cached
|
||||
result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
|
||||
result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
|
||||
|
||||
assert result1 is result2 # Same engine returned (cached)
|
||||
mock_create_engine.assert_called_once() # Only created once
|
||||
|
||||
# Call with different params - should create new engine
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_concurrent_engine_creation(
|
||||
self, mock_make_url, mock_create_engine, engine_manager, mock_database
|
||||
):
|
||||
"""Test concurrent engine creation doesn't create duplicates."""
|
||||
engine_manager.mode = EngineModes.SINGLETON
|
||||
|
||||
# Use a real engine to avoid event listener issues with MagicMock
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
|
||||
mock_make_url.return_value = real_engine
|
||||
|
||||
create_count = [0]
|
||||
|
||||
def counting_create_engine(*args, **kwargs):
|
||||
create_count[0] += 1
|
||||
return real_engine
|
||||
|
||||
mock_create_engine.side_effect = counting_create_engine
|
||||
|
||||
results = []
|
||||
exceptions = []
|
||||
|
||||
def get_engine_thread():
|
||||
try:
|
||||
engine = engine_manager._get_engine(
|
||||
mock_database, "catalog1", "schema1", None
|
||||
)
|
||||
results.append(engine)
|
||||
except Exception as e:
|
||||
exceptions.append(e)
|
||||
|
||||
# Run multiple threads
|
||||
threads = [threading.Thread(target=get_engine_thread) for _ in range(10)]
|
||||
for t in threads:
|
||||
t.start()
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
assert len(exceptions) == 0
|
||||
assert len(results) == 10
|
||||
assert create_count[0] == 1 # Engine created only once
|
||||
|
||||
# All results should be the same engine
|
||||
for engine in results:
|
||||
assert engine is real_engine
|
||||
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager):
|
||||
"""Test SSH tunnel creation and caching."""
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
ssh_tunnel.server_port = 22
|
||||
ssh_tunnel.username = "ssh_user"
|
||||
ssh_tunnel.password = "ssh_pass" # noqa: S105
|
||||
ssh_tunnel.private_key = None
|
||||
ssh_tunnel.private_key_password = None
|
||||
|
||||
tunnel_instance = MagicMock()
|
||||
tunnel_instance.is_active = True
|
||||
tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
|
||||
mock_open_tunnel.return_value = tunnel_instance
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
uri.port = 5432
|
||||
uri.get_backend_name.return_value = "postgresql"
|
||||
|
||||
result = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result is tunnel_instance
|
||||
mock_open_tunnel.assert_called_once()
|
||||
tunnel_instance.start.assert_called_once()
|
||||
|
||||
# Getting same tunnel again should return cached version
|
||||
mock_open_tunnel.reset_mock()
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result2 is tunnel_instance
|
||||
mock_open_tunnel.assert_not_called()
|
||||
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_recreation_when_inactive(
|
||||
self, mock_open_tunnel, engine_manager
|
||||
):
|
||||
"""Test that inactive tunnels are replaced."""
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
ssh_tunnel.server_port = 22
|
||||
ssh_tunnel.username = "ssh_user"
|
||||
ssh_tunnel.password = "ssh_pass" # noqa: S105
|
||||
ssh_tunnel.private_key = None
|
||||
ssh_tunnel.private_key_password = None
|
||||
|
||||
# First tunnel is inactive
|
||||
inactive_tunnel = MagicMock()
|
||||
inactive_tunnel.is_active = False
|
||||
inactive_tunnel.local_bind_address = ("127.0.0.1", 12345)
|
||||
|
||||
# Second tunnel is active
|
||||
active_tunnel = MagicMock()
|
||||
active_tunnel.is_active = True
|
||||
active_tunnel.local_bind_address = ("127.0.0.1", 23456)
|
||||
|
||||
mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel]
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
uri.port = 5432
|
||||
uri.get_backend_name.return_value = "postgresql"
|
||||
|
||||
# First call creates inactive tunnel
|
||||
result1 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result1 is inactive_tunnel
|
||||
|
||||
# Second call should create new tunnel since first is inactive
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result2 is active_tunnel
|
||||
assert mock_open_tunnel.call_count == 2
|
||||
|
||||
    @patch("superset.engines.manager.create_engine")
    @patch("superset.engines.manager.make_url_safe")
    def test_get_engine_args_basic(
        self, mock_make_url, mock_create_engine, engine_manager
    ):
        """Test _get_engine_args returns correct URI and kwargs."""
        from sqlalchemy.engine.url import make_url

        from superset.engines.manager import EngineModes

        engine_manager.mode = EngineModes.NEW

        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri

        database = MagicMock()
        database.id = 1
        database.sqlalchemy_uri_decrypted = "trino://"
        database.get_extra.return_value = {
            "engine_params": {},
            "connect_args": {"source": "Apache Superset"},
        }
        database.get_effective_user.return_value = "alice"
        database.impersonate_user = False
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (
            mock_uri,
            {"source": "Apache Superset"},
        )
        database.db_engine_spec.validate_database_uri = MagicMock()

        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)

        assert str(uri) == "trino://"
        assert "connect_args" in database.get_extra.return_value

    @patch("superset.engines.manager.create_engine")
    @patch("superset.engines.manager.make_url_safe")
    def test_get_engine_args_user_impersonation(
        self, mock_make_url, mock_create_engine, engine_manager
    ):
        """Test user impersonation in _get_engine_args."""
        from sqlalchemy.engine.url import make_url

        from superset.engines.manager import EngineModes

        engine_manager.mode = EngineModes.NEW

        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri

        database = MagicMock()
        database.id = 1
        database.sqlalchemy_uri_decrypted = "trino://"
        database.get_extra.return_value = {
            "engine_params": {},
            "connect_args": {"source": "Apache Superset"},
        }
        database.get_effective_user.return_value = "alice"
        database.impersonate_user = True
        database.get_oauth2_config.return_value = None
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (
            mock_uri,
            {"source": "Apache Superset"},
        )
        database.db_engine_spec.impersonate_user.return_value = (
            mock_uri,
            {"connect_args": {"user": "alice", "source": "Apache Superset"}},
        )
        database.db_engine_spec.validate_database_uri = MagicMock()

        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)

        # Verify impersonate_user was called
        database.db_engine_spec.impersonate_user.assert_called_once()
        call_args = database.db_engine_spec.impersonate_user.call_args
        assert call_args[0][0] is database  # database
        assert call_args[0][1] == "alice"  # username
        assert call_args[0][2] is None  # access_token (no OAuth2)

    @patch("superset.engines.manager.create_engine")
    @patch("superset.engines.manager.make_url_safe")
    def test_get_engine_args_user_impersonation_email_prefix(
        self,
        mock_make_url,
        mock_create_engine,
        engine_manager,
    ):
        """Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
        from sqlalchemy.engine.url import make_url

        from superset.engines.manager import EngineModes

        engine_manager.mode = EngineModes.NEW

        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri

        # Mock user with email
        mock_user = MagicMock()
        mock_user.email = "alice.doe@example.org"

        database = MagicMock()
        database.id = 1
        database.sqlalchemy_uri_decrypted = "trino://"
        database.get_extra.return_value = {
            "engine_params": {},
            "connect_args": {"source": "Apache Superset"},
        }
        database.get_effective_user.return_value = "alice"
        database.impersonate_user = True
        database.get_oauth2_config.return_value = None
        database.update_params_from_encrypted_extra = MagicMock()
        database.db_engine_spec = MagicMock()
        database.db_engine_spec.adjust_engine_params.return_value = (
            mock_uri,
            {"source": "Apache Superset"},
        )
        database.db_engine_spec.impersonate_user.return_value = (
            mock_uri,
            {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
        )
        database.db_engine_spec.validate_database_uri = MagicMock()

        with (
            patch(
                "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
                return_value=True,
            ),
            patch(
                "superset.extensions.security_manager.find_user",
                return_value=mock_user,
            ),
        ):
            uri, kwargs = engine_manager._get_engine_args(
                database, None, None, None, None
            )

        # Verify impersonate_user was called with the email prefix
        database.db_engine_spec.impersonate_user.assert_called_once()
        call_args = database.db_engine_spec.impersonate_user.call_args
        assert call_args[0][1] == "alice.doe"  # username from email prefix

    @patch("superset.engines.manager.create_engine")
    @patch("superset.engines.manager.make_url_safe")
    def test_engine_context_manager_called(
        self, mock_make_url, mock_create_engine, engine_manager, mock_database
    ):
        """Test that the engine context manager is properly called."""
        from sqlalchemy.engine.url import make_url

        mock_uri = make_url("trino://")
        mock_make_url.return_value = mock_uri
        mock_engine = MagicMock()
        mock_create_engine.return_value = mock_engine

        # Track context manager calls
        context_manager_calls = []

        def tracking_context_manager(database, catalog, schema):
            from contextlib import contextmanager

            @contextmanager
            def inner():
                context_manager_calls.append(("enter", database, catalog, schema))
                yield
                context_manager_calls.append(("exit", database, catalog, schema))

            return inner()

        engine_manager.engine_context_manager = tracking_context_manager

        with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
            pass

        assert len(context_manager_calls) == 2
        assert context_manager_calls[0][0] == "enter"
        assert context_manager_calls[0][1] is mock_database
        assert context_manager_calls[0][2] == "catalog1"
        assert context_manager_calls[0][3] == "schema1"
        assert context_manager_calls[1][0] == "exit"

@patch("superset.utils.oauth2.check_for_oauth2")
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_engine_oauth2_error_handling(
|
||||
self,
|
||||
mock_make_url,
|
||||
mock_create_engine,
|
||||
mock_check_for_oauth2,
|
||||
engine_manager,
|
||||
mock_database,
|
||||
):
|
||||
"""Test that OAuth2 errors are properly propagated from get_engine."""
|
||||
from contextlib import contextmanager
|
||||
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
mock_uri = make_url("trino://")
|
||||
mock_make_url.return_value = mock_uri
|
||||
|
||||
# Simulate OAuth2 error during engine creation
|
||||
class OAuth2TestError(Exception):
|
||||
pass
|
||||
|
||||
oauth_error = OAuth2TestError("OAuth2 required")
|
||||
mock_create_engine.side_effect = oauth_error
|
||||
|
||||
# Make get_dbapi_mapped_exception return the original exception
|
||||
mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
|
||||
oauth_error
|
||||
)
|
||||
|
||||
# Mock check_for_oauth2 to re-raise the exception
|
||||
@contextmanager
|
||||
def mock_oauth2_context(database):
|
||||
try:
|
||||
yield
|
||||
except OAuth2TestError:
|
||||
raise
|
||||
|
||||
mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
|
||||
|
||||
with pytest.raises(OAuth2TestError, match="OAuth2 required"):
|
||||
with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
|
||||
pass
|
||||
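The `TestLockManager` cases at the top of this file pin down the lock registry's contract: one `RLock` per key, stable across calls, droppable via `cleanup()`, and race-free under concurrent creation. A minimal sketch that would satisfy those tests (the real `_LockManager` in `superset.engines.manager` may be implemented differently):

```python
import threading


class LockManagerSketch:
    """Per-key lock registry; one RLock per cache key."""

    def __init__(self) -> None:
        self._locks: dict[str, threading.RLock] = {}
        self._registry_lock = threading.Lock()  # guards the dict itself

    def get_lock(self, key: str) -> threading.RLock:
        # setdefault under the registry lock, so concurrent callers
        # always receive the same lock object for a given key.
        with self._registry_lock:
            return self._locks.setdefault(key, threading.RLock())

    def cleanup(self, active_keys: set[str]) -> None:
        # Drop locks for abandoned keys; a later get_lock() for a
        # dropped key returns a brand-new lock.
        with self._registry_lock:
            for key in list(self._locks):
                if key not in active_keys:
                    del self._locks[key]
```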
@@ -123,7 +123,7 @@ class TestSupersetAppInitializer:
             patch.object(app_initializer, "configure_data_sources"),
             patch.object(app_initializer, "configure_auth_provider"),
             patch.object(app_initializer, "configure_async_queries"),
-            patch.object(app_initializer, "configure_ssh_manager"),
+            patch.object(app_initializer, "configure_engine_manager"),
             patch.object(app_initializer, "configure_stats_manager"),
             patch.object(app_initializer, "init_views"),
         ):

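The initializer hunk swaps the `configure_ssh_manager` hook for `configure_engine_manager`, consistent with SSH tunnel handling being folded into the engine manager. A hypothetical sketch of such a hook; everything beyond `EngineManager`, its `engine_context_manager` argument, and the `ENGINE_CONTEXT_MANAGER` config key (all visible elsewhere in this diff) is an assumption:

```python
from types import SimpleNamespace


class EngineManagerStub:
    # Stand-in for superset.engines.manager.EngineManager.
    def __init__(self, engine_context_manager=None):
        self.engine_context_manager = engine_context_manager


class AppInitializerSketch:
    # Hypothetical shape of the hook patched above; the real
    # SupersetAppInitializer.configure_engine_manager may differ.
    def __init__(self, app):
        self.app = app

    def configure_engine_manager(self):
        # Wire a single app-wide manager, honoring the configurable
        # engine context manager hook.
        self.app.engine_manager = EngineManagerStub(
            engine_context_manager=self.app.config.get("ENGINE_CONTEXT_MANAGER"),
        )


app = SimpleNamespace(config={})
AppInitializerSketch(app).configure_engine_manager()
assert app.engine_manager.engine_context_manager is None
```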
@@ -19,7 +19,6 @@
 from datetime import datetime

 import pytest
-from flask import current_app
 from pytest_mock import MockerFixture
 from sqlalchemy import (
     Column,
@@ -29,7 +28,6 @@ from sqlalchemy import (
     Table as SqlalchemyTable,
 )
 from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.engine.url import make_url
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import Select

@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
     assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT


-def test_get_sqla_engine(mocker: MockerFixture) -> None:
-    """
-    Test `_get_sqla_engine`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "alice.doe@example.org"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"source": "Apache Superset"},
-    )
-
-
-def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
-    """
-    Test user impersonation in `_get_sqla_engine`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "alice.doe@example.org"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(
-        database_name="my_db",
-        sqlalchemy_uri="trino://",
-        impersonate_user=True,
-    )
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"user": "alice", "source": "Apache Superset"},
-    )
-
-
 def test_add_database_to_signature():
     args = ["param1", "param2"]
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
     assert args3 == ["param1", "param2", database]


-@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
-def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
-    """
-    Test user impersonation in `_get_sqla_engine` with `username_from_email`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "alice.doe@example.org"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(
-        database_name="my_db",
-        sqlalchemy_uri="trino://",
-        impersonate_user=True,
-    )
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"user": "alice.doe", "source": "Apache Superset"},
-    )
-
-
 def test_is_oauth2_enabled() -> None:
     """
     Test the `is_oauth2_enabled` method.
@@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config(
     assert config["redirect_uri"] == custom_redirect_uri


-def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
-    """
-    Test that we can start OAuth2 from `raw_connection()` errors.
-
-    With OAuth2, some databases will raise an exception when the engine is first created
-    (eg, BigQuery). Others, like, Snowflake, when the connection is created. And
-    finally, GSheets will raise an exception when the query is executed.
-
-    This tests verifies that when calling `raw_connection()` the OAuth2 flow is
-    triggered when the engine is created.
-    """
-    g = mocker.patch("superset.db_engine_specs.base.g")
-    g.user = mocker.MagicMock()
-    g.user.id = 42
-
-    database = Database(
-        id=1,
-        database_name="my_db",
-        sqlalchemy_uri="sqlite://",
-        encrypted_extra=json.dumps(oauth2_client_info),
-    )
-    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
-    _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
-    _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
-
-    with pytest.raises(OAuth2RedirectError) as excinfo:
-        with database.get_raw_connection() as conn:
-            conn.cursor()
-    assert str(excinfo.value) == "You don't have permission to access the data."
-
-
 def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
     """
     Test that we can start OAuth2 from `raw_connection()` errors.
@@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None:
     assert database.get_schema_access_for_file_upload() == {"public"}


-def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
-    """
-    Test the engine context manager.
-    """
-    from unittest.mock import MagicMock
-
-    engine_context_manager = MagicMock()
-    mocker.patch.dict(
-        current_app.config,
-        {"ENGINE_CONTEXT_MANAGER": engine_context_manager},
-    )
-    _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
-
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    with database.get_sqla_engine("catalog", "schema"):
-        pass
-
-    engine_context_manager.assert_called_once_with(database, "catalog", "schema")
-    engine_context_manager().__enter__.assert_called_once()
-    engine_context_manager().__exit__.assert_called_once_with(None, None, None)
-    _get_sqla_engine.assert_called_once_with(
-        catalog="catalog",
-        schema="schema",
-        nullpool=True,
-        source=None,
-        sqlalchemy_uri="trino://",
-    )
-
-
-def test_engine_oauth2(mocker: MockerFixture) -> None:
-    """
-    Test that we handle OAuth2 when `create_engine` fails.
-    """
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
-    mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
-    mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
-    start_oauth2_dance = mocker.patch.object(
-        database.db_engine_spec,
-        "start_oauth2_dance",
-        side_effect=OAuth2Error("OAuth2 required"),
-    )
-
-    with pytest.raises(OAuth2Error):
-        with database.get_sqla_engine("catalog", "schema"):
-            pass
-
-    start_oauth2_dance.assert_called_with(database)


 def test_purge_oauth2_tokens(session: Session) -> None:
     """
     Test the `purge_oauth2_tokens` method.
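The tests deleted above exercised OAuth2 triggering through the removed `_get_sqla_engine`; their replacement, `test_engine_oauth2_error_handling` in the new `manager_test.py`, patches `superset.utils.oauth2.check_for_oauth2` around engine creation instead. A self-contained sketch of that error-path shape; the stub bodies here are assumptions, not Superset's implementation:

```python
from contextlib import contextmanager


class OAuth2Error(Exception):
    pass


@contextmanager
def check_for_oauth2(database):
    # Stand-in for superset.utils.oauth2.check_for_oauth2: in Superset this
    # can start the OAuth2 dance; here it simply re-raises.
    try:
        yield
    except OAuth2Error:
        raise


def get_engine(database):
    # Engine creation wrapped so OAuth2-triggering failures surface to
    # the caller instead of being swallowed.
    with check_for_oauth2(database):
        raise OAuth2Error("OAuth2 required")


try:
    get_engine(database=None)
except OAuth2Error as exc:
    print(exc)  # OAuth2 required
```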
@@ -202,7 +202,6 @@ def setup_mock_raw_connection(
     def _raw_connection(
         catalog: str | None = None,
         schema: str | None = None,
-        nullpool: bool = True,
         source: Any | None = None,
     ):
         yield mock_connection