Compare commits

...

19 Commits

Author SHA1 Message Date
Beto Dealmeida
75a64b3062 Address comments 2026-05-11 15:57:19 -04:00
Beto Dealmeida
9c5e6d187b Fix lint 2026-05-05 20:11:56 -04:00
Beto Dealmeida
3aad565eab refactor: keep engine manager focused on engine creation 2026-05-05 18:22:36 -04:00
Beto Dealmeida
b7b59dfb8a fix: Add port validation for SSH tunnels
Raise SSHTunnelDatabasePortError when the database URI has no port and
there's no default port for the database backend. This matches the
original behavior from the removed ssh.py module.

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-05-05 18:08:58 -04:00
Beto Dealmeida
fa0d4e1c08 Small fixes 2026-05-05 18:08:58 -04:00
Beto Dealmeida
08df7d5178 fix: SSH tunnel and test connection error handling
- Use sshtunnel.open_tunnel() instead of SSHTunnelForwarder directly
  to properly handle debug_level parameter
- Fix keepalive parameter name (set_keepalive, not keepalive)
- Fix test assertions that were inside pytest.raises blocks and never
  executed - now check error_type instead of string messages
- Update SSH tunnel test mocks to patch open_tunnel

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
2026-05-05 18:08:57 -04:00
Beto Dealmeida
9dc54d8f1b Rebase 2026-05-05 18:08:57 -04:00
Beto Dealmeida
e2ce534148 Fix tests 2026-05-05 18:08:43 -04:00
Beto Dealmeida
b3f8831d34 Fix poolclass check 2026-05-05 18:08:43 -04:00
Beto Dealmeida
1775cae220 Fix more tests 2026-05-05 18:08:42 -04:00
Beto Dealmeida
26f0390bbb Simplify key generation 2026-05-05 18:08:42 -04:00
Beto Dealmeida
ea27cabfc6 Update existing tests 2026-05-05 18:08:42 -04:00
Beto Dealmeida
f39367bffd Hash key 2026-05-05 18:08:32 -04:00
Beto Dealmeida
11395531f2 Small improvements 2026-05-05 18:08:31 -04:00
Beto Dealmeida
ec018cd842 Cleanup 2026-05-05 18:02:45 -04:00
Beto Dealmeida
48d3f441b8 Connecting 2026-05-05 18:02:45 -04:00
Beto Dealmeida
b3393c65f7 Add extension 2026-05-05 18:02:45 -04:00
Beto Dealmeida
8776b651a5 Cleanup locks 2026-05-05 18:02:45 -04:00
Beto Dealmeida
ccd32920fc feat: engine manager 2026-05-05 18:02:45 -04:00
21 changed files with 562 additions and 436 deletions

View File

@@ -24,6 +24,24 @@ assists people when migrating to a new version.
## Next
### `SSH_TUNNEL_MANAGER_CLASS` replaced by `ENGINE_MANAGER_CLASS`
The `SSH_TUNNEL_MANAGER_CLASS` config setting, the `superset.extensions.ssh` module (containing `SSHManager` and `SSHManagerFactory`), and the `ssh_manager_factory` extension singleton have been removed. SQLAlchemy engine creation — including SSH tunnel construction and URL rewriting — is now centralized in `EngineManager` (`superset/engines/manager.py`), wired up via `EngineManagerExtension` (`superset/extensions/engine_manager.py`).
A new config setting, `ENGINE_MANAGER_CLASS` (default: `"superset.engines.manager.EngineManager"`), replaces `SSH_TUNNEL_MANAGER_CLASS` as the customization hook. Deployments that previously subclassed `SSHManager` (e.g. for bastion routing, audit logging, host-key policy, or custom credential handling) should subclass `EngineManager` instead and set `ENGINE_MANAGER_CLASS` to the dotted path of the subclass. Override the relevant methods:
| Old `SSHManager` method | New override point on `EngineManager` |
|---|---|
| `__init__(app)` reading `SSH_TUNNEL_*` configs | `__init__` — the same `SSH_TUNNEL_LOCAL_BIND_ADDRESS`, `SSH_TUNNEL_TIMEOUT_SEC`, and `SSH_TUNNEL_PACKET_TIMEOUT_SEC` configs are still loaded by `EngineManagerExtension.init_app` and passed in |
| `create_tunnel(ssh_tunnel, uri)` | `_get_tunnel_kwargs(ssh_tunnel, uri)` for parameter construction and `_create_tunnel(ssh_tunnel, uri)` for the `sshtunnel.open_tunnel` + `start()` call |
| `build_sqla_url(url, server)` | Inlined in `get_engine` as `uri.set(host=tunnel.local_bind_address[0], port=tunnel.local_bind_port)` |
**Behavioral note:** the old `SSHManager.create_tunnel` passed `debug_level=logging.getLogger("flask_appbuilder").level` to `sshtunnel.open_tunnel`. The new `_get_tunnel_kwargs` does not. Subclasses relying on that should add it back in their override.
### `Database.get_sqla_engine(nullpool=...)` deprecated
The `nullpool` keyword argument to `Database.get_sqla_engine` is deprecated and ignored — the engine manager always uses `NullPool`. The kwarg is still accepted (with a `DeprecationWarning`) so external callers passing `nullpool=False` won't fail with `TypeError`, but the resulting engine will use `NullPool` regardless. Remove the argument from your callers; it will be deleted in a future release.
### Granular Export Controls
A new feature flag `GRANULAR_EXPORT_CONTROLS` introduces three fine-grained permissions that replace the legacy `can_csv` permission:
@@ -114,7 +132,6 @@ DISTRIBUTED_COORDINATION_CONFIG = {
```
See `superset/config.py` for complete configuration options.
### WebSocket config for GAQ with Docker
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:

View File

@@ -17,6 +17,7 @@
* under the License.
*/
import { getExtensionsRegistry } from '@superset-ui/core';
import type { ComponentType, ReactNode } from 'react';
import { Provider as ReduxProvider } from 'react-redux';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
@@ -64,7 +65,7 @@ export const EmbeddedContextProviders: React.FC<{
}> = ({ children }) => {
const RootContextProviderExtension = extensionsRegistry.get(
'root.context.provider',
);
) as ComponentType<{ children?: ReactNode }> | undefined;
return (
<SupersetThemeProvider themeController={themeController}>

View File

@@ -188,7 +188,9 @@ function CollectionControl({
// Two items can collide when keyAccessor returns falsy and the index
// fallback is used — breaking dnd-kit reordering and React reconciliation.
// Assign a stable nanoid per item ref when no key is available.
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(new WeakMap());
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(
new WeakMap(),
);
const itemIds = useMemo(
() =>
value.map(item => {

View File

@@ -34,6 +34,7 @@ import {
Tooltip,
Row,
type OnClickHandler,
type ButtonProps as CoreButtonProps,
} from '@superset-ui/core/components';
import { Icons } from '@superset-ui/core/components/Icons';
import { MenuObjectProps } from 'src/types/bootstrapTypes';
@@ -148,7 +149,7 @@ export interface ButtonProps {
'data-test'?: string;
buttonStyle: 'primary' | 'secondary' | 'dashed' | 'link' | 'tertiary';
loading?: boolean;
icon?: ReactNode;
icon?: CoreButtonProps['icon'];
component?: ReactNode;
}

View File

@@ -18,6 +18,7 @@
*/
import { getExtensionsRegistry } from '@superset-ui/core';
import type { ComponentType, ReactNode } from 'react';
import { Provider as ReduxProvider } from 'react-redux';
import { QueryParamProvider } from 'use-query-params';
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
@@ -39,7 +40,7 @@ export const RootContextProviders: React.FC<{ children?: React.ReactNode }> = ({
}) => {
const RootContextProviderExtension = extensionsRegistry.get(
'root.context.provider',
);
) as ComponentType<{ children?: ReactNode }> | undefined;
return (
<SupersetThemeProvider themeController={themeController}>

View File

@@ -55,7 +55,11 @@ from superset.constants import CHANGE_ME_SECRET_KEY
from superset.jinja_context import BaseTemplateProcessor
from superset.key_value.types import JsonKeyValueCodec
from superset.stats_logger import DummyStatsLogger
from superset.superset_typing import CacheConfig
from superset.superset_typing import (
CacheConfig,
DBConnectionMutator,
EngineContextManager,
)
from superset.tasks.types import ExecutorType
from superset.themes.types import Theme
from superset.utils import core as utils
@@ -831,7 +835,6 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
# FIREWALL (only port 22 is open)
# ----------------------------------------------------------------------
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
#: Timeout (seconds) for tunnel connection (open_channel timeout)
SSH_TUNNEL_TIMEOUT_SEC = 10.0
@@ -1720,7 +1723,14 @@ def engine_context_manager( # pylint: disable=unused-argument
yield None
ENGINE_CONTEXT_MANAGER = engine_context_manager
ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager
# The class used to manage SQLAlchemy engine creation, including SSH tunnels
# and connection details. Deployments that need custom behavior (e.g. bastion
# routing, audit logging, host-key policy, custom credential handling) can
# subclass `superset.engines.manager.EngineManager` and point this setting at
# the subclass.
ENGINE_MANAGER_CLASS = "superset.engines.manager.EngineManager"
# A callable that allows altering the database connection URL and params
# on the fly, at runtime. This allows for things like impersonation or
@@ -1737,7 +1747,7 @@ ENGINE_CONTEXT_MANAGER = engine_context_manager
#
# Note that the returned uri and params are passed directly to sqlalchemy's
# as such `create_engine(url, **params)`
DB_CONNECTION_MUTATOR = None
DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None
# A callable that is invoked for every invocation of DB Engine Specs

View File

@@ -14,23 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import Mock
import sshtunnel
from superset.extensions.ssh import SSHManagerFactory
def test_ssh_tunnel_timeout_setting() -> None:
app = Mock()
app.config = {
"SSH_TUNNEL_MAX_RETRIES": 2,
"SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
"SSH_TUNNEL_TIMEOUT_SEC": 123.0,
"SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
"SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
}
factory = SSHManagerFactory()
factory.init_app(app)
assert sshtunnel.TUNNEL_TIMEOUT == 123.0
assert sshtunnel.SSH_TIMEOUT == 321.0

203
superset/engines/manager.py Normal file
View File

@@ -0,0 +1,203 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from contextlib import contextmanager
from datetime import timedelta
from io import StringIO
from typing import Any, Iterator, TYPE_CHECKING
import sshtunnel
from paramiko import RSAKey
from sqlalchemy import create_engine, pool
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sshtunnel import SSHTunnelForwarder
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.databases.utils import make_url_safe
from superset.superset_typing import DBConnectionMutator, EngineContextManager
from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
class EngineManager:
"""Centralized SQLAlchemy engine creation for Superset."""
def __init__(
self,
engine_context_manager: EngineContextManager,
db_connection_mutator: DBConnectionMutator | None = None,
local_bind_address: str = "127.0.0.1",
tunnel_timeout: timedelta = timedelta(seconds=30),
ssh_timeout: timedelta = timedelta(seconds=1),
) -> None:
self.engine_context_manager = engine_context_manager
self.db_connection_mutator = db_connection_mutator
self.local_bind_address = local_bind_address
sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()
@contextmanager
def get_engine(
self,
database: "Database",
catalog: str | None,
schema: str | None,
source: QuerySource | None,
) -> Iterator[Engine]:
"""Context manager to get a SQLAlchemy engine."""
from superset.utils.oauth2 import check_for_oauth2
with self.engine_context_manager(database, catalog, schema):
with check_for_oauth2(database):
uri, kwargs = self._get_engine_args(
database,
catalog,
schema,
source,
get_user_id(),
)
if database.ssh_tunnel:
tunnel = self._create_tunnel(database.ssh_tunnel, uri)
try:
uri = uri.set(
host=tunnel.local_bind_address[0],
port=tunnel.local_bind_port,
)
yield self._create_engine(database, uri, kwargs)
finally:
tunnel.stop()
else:
yield self._create_engine(database, uri, kwargs)
def _get_engine_args(
self,
database: "Database",
catalog: str | None,
schema: str | None,
source: QuerySource | None,
user_id: int | None,
) -> tuple[URL, dict[str, Any]]:
"""Build SQLAlchemy URI and kwargs before engine creation."""
from superset import is_feature_enabled
from superset.extensions import security_manager
uri = make_url_safe(database.sqlalchemy_uri_decrypted)
extra = database.get_extra(source)
kwargs = dict(extra.get("engine_params", {}))
kwargs["poolclass"] = pool.NullPool
connect_args = kwargs.setdefault("connect_args", {})
uri, connect_args = database.db_engine_spec.adjust_engine_params(
uri,
connect_args,
catalog,
schema,
)
username = database.get_effective_user(uri)
if username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
user = security_manager.find_user(username=username)
if user and user.email and "@" in user.email:
username = user.email.split("@")[0]
if database.impersonate_user:
oauth2_config = database.get_oauth2_config()
from superset.utils.oauth2 import get_oauth2_access_token
access_token = (
get_oauth2_access_token(
oauth2_config,
database.id,
user_id,
database.db_engine_spec,
)
if oauth2_config and user_id
else None
)
uri, kwargs = database.db_engine_spec.impersonate_user(
database,
username,
access_token,
uri,
kwargs,
)
database.update_params_from_encrypted_extra(kwargs)
if self.db_connection_mutator:
source = source or get_query_source_from_request()
uri, kwargs = self.db_connection_mutator(
uri,
kwargs,
username,
security_manager,
source,
)
database.db_engine_spec.validate_database_uri(uri)
return uri, kwargs
def _create_engine(
self,
database: "Database",
uri: URL,
kwargs: dict[str, Any],
) -> Engine:
try:
return create_engine(uri, **kwargs)
except Exception as ex:
raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
tunnel = sshtunnel.open_tunnel(**kwargs)
tunnel.start()
return tunnel
def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
from superset.utils.ssh_tunnel import get_default_port
backend = uri.get_backend_name()
port = uri.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
kwargs = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (uri.host, port),
"local_bind_address": (self.local_bind_address,),
}
if ssh_tunnel.password:
kwargs["ssh_password"] = ssh_tunnel.password
elif ssh_tunnel.private_key:
private_key_file = StringIO(ssh_tunnel.private_key)
private_key = RSAKey.from_private_key(
private_key_file,
ssh_tunnel.private_key_password,
)
kwargs["ssh_pkey"] = private_key
return kwargs

View File

@@ -42,7 +42,7 @@ from werkzeug.local import LocalProxy
from superset.async_events.async_query_manager import AsyncQueryManager
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
from superset.extensions.ssh import SSHManagerFactory
from superset.extensions.engine_manager import EngineManagerExtension
from superset.extensions.stats_logger import BaseStatsLoggerManager
from superset.security.manager import SupersetSecurityManager
from superset.utils.cache_manager import CacheManager
@@ -146,6 +146,7 @@ cache_manager = CacheManager()
celery_app = celery.Celery()
csrf = CSRFProtect()
db = get_sqla_class()()
engine_manager_extension = EngineManagerExtension()
_event_logger: dict[str, Any] = {}
encrypted_field_factory = EncryptedFieldFactory()
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
@@ -156,6 +157,5 @@ migrate = Migrate()
profiling = ProfilingExtension()
results_backend_manager = ResultsBackendManager()
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
ssh_manager_factory = SSHManagerFactory()
stats_logger_manager = BaseStatsLoggerManager()
talisman = Talisman()

View File

@@ -0,0 +1,77 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import timedelta
from flask import Flask
from superset.engines.manager import EngineManager
from superset.utils.class_utils import load_class_from_name
logger = logging.getLogger(__name__)
class EngineManagerExtension:
"""
Flask extension for managing SQLAlchemy engines in Superset.
"""
def __init__(self) -> None:
self.engine_manager: EngineManager | None = None
def init_app(self, app: Flask) -> None:
"""
Initialize the EngineManager with Flask app configuration.
"""
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"]
local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"])
ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])
engine_manager_class: type[EngineManager] = load_class_from_name(
app.config["ENGINE_MANAGER_CLASS"]
)
self.engine_manager = engine_manager_class(
engine_context_manager,
db_connection_mutator,
local_bind_address,
tunnel_timeout,
ssh_timeout,
)
logger.info(
"Initialized EngineManager with tunnel_timeout=%s, ssh_timeout=%s",
tunnel_timeout.total_seconds(),
ssh_timeout.total_seconds(),
)
@property
def manager(self) -> EngineManager:
"""
Get the EngineManager instance.
Raises:
RuntimeError: If the extension hasn't been initialized with an app.
"""
if self.engine_manager is None:
raise RuntimeError(
"EngineManager extension not initialized. "
"Call init_app() with a Flask app first."
)
return self.engine_manager

View File

@@ -1,94 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from io import StringIO
from typing import TYPE_CHECKING
import sshtunnel
from flask import Flask
from paramiko import RSAKey
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.databases.utils import make_url_safe
from superset.utils.class_utils import load_class_from_name
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
class SSHManager:
def __init__(self, app: Flask) -> None:
super().__init__()
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]
def build_sqla_url(
self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
) -> str:
# override any ssh tunnel configuration object
url = make_url_safe(sqlalchemy_url)
return url.set(
host=server.local_bind_address[0],
port=server.local_bind_port,
)
def create_tunnel(
self,
ssh_tunnel: "SSHTunnel",
sqlalchemy_database_uri: str,
) -> sshtunnel.SSHTunnelForwarder:
from superset.utils.ssh_tunnel import get_default_port
url = make_url_safe(sqlalchemy_database_uri)
backend = url.get_backend_name()
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
params = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
"ssh_username": ssh_tunnel.username,
"remote_bind_address": (url.host, port),
"local_bind_address": (self.local_bind_address,),
"debug_level": logging.getLogger("flask_appbuilder").level,
}
if ssh_tunnel.password:
params["ssh_password"] = ssh_tunnel.password
elif ssh_tunnel.private_key:
private_key_file = StringIO(ssh_tunnel.private_key)
private_key = RSAKey.from_private_key(
private_key_file, ssh_tunnel.private_key_password
)
params["ssh_pkey"] = private_key
return sshtunnel.open_tunnel(**params)
class SSHManagerFactory:
def __init__(self) -> None:
self._ssh_manager = None
def init_app(self, app: Flask) -> None:
self._ssh_manager = load_class_from_name(
app.config["SSH_TUNNEL_MANAGER_CLASS"]
)(app)
@property
def instance(self) -> SSHManager:
return self._ssh_manager # type: ignore

View File

@@ -49,13 +49,13 @@ from superset.extensions import (
csrf,
db,
encrypted_field_factory,
engine_manager_extension,
feature_flag_manager,
machine_auth_provider_factory,
manifest_processor,
migrate,
profiling,
results_backend_manager,
ssh_manager_factory,
stats_logger_manager,
talisman,
)
@@ -616,8 +616,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
self.configure_url_map_converters()
self.configure_data_sources()
self.configure_auth_provider()
self.configure_engine_manager()
self.configure_async_queries()
self.configure_ssh_manager()
self.configure_stats_manager()
self.configure_task_manager()
@@ -793,8 +793,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
def configure_auth_provider(self) -> None:
machine_auth_provider_factory.init_app(self.superset_app)
def configure_ssh_manager(self) -> None:
ssh_manager_factory.init_app(self.superset_app)
def configure_engine_manager(self) -> None:
engine_manager_extension.init_app(self.superset_app)
def configure_stats_manager(self) -> None:
stats_logger_manager.init_app(self.superset_app)

View File

@@ -25,24 +25,22 @@ import builtins
import logging
import textwrap
from ast import literal_eval
from contextlib import closing, contextmanager, nullcontext, suppress
from contextlib import closing, contextmanager, suppress
from copy import deepcopy
from datetime import datetime
from functools import lru_cache
from inspect import signature
from typing import Any, Callable, cast, Optional, TYPE_CHECKING
from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING
import numpy
import pandas as pd
import sqlalchemy as sqla
import sshtunnel
from flask import current_app as app, g, has_app_context
from flask_appbuilder import Model
from marshmallow.exceptions import ValidationError
from sqlalchemy import (
Boolean,
Column,
create_engine,
DateTime,
ForeignKey,
Integer,
@@ -57,7 +55,6 @@ from sqlalchemy.engine.url import URL
from sqlalchemy.exc import NoSuchModuleError
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import relationship
from sqlalchemy.pool import NullPool
from sqlalchemy.schema import UniqueConstraint
from sqlalchemy.sql import ColumnElement, expression, Select
from superset_core.common.models import Database as CoreDatabase
@@ -72,7 +69,6 @@ from superset.extensions import (
encrypted_field_factory,
event_logger,
security_manager,
ssh_manager_factory,
)
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
from superset.result_set import SupersetResultSet
@@ -84,10 +80,9 @@ from superset.superset_typing import (
)
from superset.utils import cache as cache_util, core as utils, json
from superset.utils.backports import StrEnum
from superset.utils.core import get_query_source_from_request, get_username
from superset.utils.core import get_username
from superset.utils.oauth2 import (
check_for_oauth2,
get_oauth2_access_token,
OAuth2ClientConfigSchema,
)
@@ -424,130 +419,46 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
)
@contextmanager
def get_sqla_engine( # pylint: disable=too-many-arguments
def get_sqla_engine(
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Engine:
nullpool: bool | None = None,
) -> Iterator[Engine]:
"""
Context manager for a SQLAlchemy engine.
This method will return a context manager for a SQLAlchemy engine. Using the
context manager (as opposed to the engine directly) is important because we need
to potentially establish SSH tunnels before the connection is created, and clean
them up once the engine is no longer used.
This method will return a context manager for a SQLAlchemy engine. The engine
manager handles engine creation, SSH tunnels, and connection details in a
centralized place.
The ``nullpool`` argument is deprecated and ignored — the engine manager
always uses ``NullPool``. It is kept temporarily for backwards compatibility
with external callers and will be removed in a future release.
"""
if nullpool is not None:
import warnings
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
ssh_context_manager = (
ssh_manager_factory.instance.create_tunnel(
ssh_tunnel=self.ssh_tunnel,
sqlalchemy_database_uri=sqlalchemy_uri,
warnings.warn(
"The `nullpool` argument to `Database.get_sqla_engine` is "
"deprecated and ignored; the engine manager always uses NullPool.",
DeprecationWarning,
stacklevel=2,
)
if self.ssh_tunnel
else nullcontext()
)
with ssh_context_manager as ssh_context:
if ssh_context:
logger.info(
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
"ssh_timeout at %s",
sshtunnel.TUNNEL_TIMEOUT,
sshtunnel.SSH_TIMEOUT,
ssh_context.local_bind_address,
)
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
sqlalchemy_uri,
ssh_context,
)
# Import here to avoid circular imports
from superset.extensions import engine_manager_extension
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
with engine_context_manager(self, catalog, schema):
with check_for_oauth2(self):
yield self._get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
sqlalchemy_uri=sqlalchemy_uri,
)
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
sqlalchemy_uri: str | None = None,
) -> Engine:
sqlalchemy_url = make_url_safe(
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
)
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
extra = self.get_extra(source)
engine_kwargs = extra.get("engine_params", {})
if nullpool:
engine_kwargs["poolclass"] = NullPool
connect_args = engine_kwargs.setdefault("connect_args", {})
# modify URL/args for a specific catalog/schema
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
uri=sqlalchemy_url,
connect_args=connect_args,
# Use the engine manager to get the engine
engine_manager = engine_manager_extension.manager
with engine_manager.get_engine(
database=self,
catalog=catalog,
schema=schema,
)
effective_username = self.get_effective_user(sqlalchemy_url)
if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
user = security_manager.find_user(username=effective_username)
if user and user.email:
effective_username = user.email.split("@")[0]
oauth2_config = self.get_oauth2_config()
access_token = (
get_oauth2_access_token(
oauth2_config,
self.id,
g.user.id,
self.db_engine_spec,
)
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
else None
)
masked_url = self.get_password_masked_url(sqlalchemy_url)
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
if self.impersonate_user:
sqlalchemy_url, engine_kwargs = self.db_engine_spec.impersonate_user(
self,
effective_username,
access_token,
sqlalchemy_url,
engine_kwargs,
)
self.update_params_from_encrypted_extra(engine_kwargs)
if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]: # noqa: N806
source = source or get_query_source_from_request()
sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR(
sqlalchemy_url,
engine_kwargs,
effective_username,
security_manager,
source,
)
try:
return create_engine(sqlalchemy_url, **engine_kwargs)
except Exception as ex:
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
source=source,
) as engine:
yield engine
def add_database_to_signature(
self,
@@ -572,13 +483,11 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
self,
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: utils.QuerySource | None = None,
) -> Connection:
with self.get_sqla_engine(
catalog=catalog,
schema=schema,
nullpool=nullpool,
source=source,
) as engine:
with check_for_oauth2(self):

View File

@@ -18,14 +18,28 @@ from __future__ import annotations
from collections.abc import Hashable, Sequence
from datetime import datetime
from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict
from typing import (
Any,
Callable,
ContextManager,
Literal,
TYPE_CHECKING,
TypeAlias,
TypedDict,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.sql.type_api import TypeEngine
from typing_extensions import NotRequired
from werkzeug.wrappers import Response
if TYPE_CHECKING:
from superset.utils.core import GenericDataType, QueryObjectFilterClause
from superset.models.core import Database
from superset.utils.core import (
GenericDataType,
QueryObjectFilterClause,
QuerySource,
)
SQLType: TypeAlias = TypeEngine | type[TypeEngine]
@@ -70,6 +84,18 @@ class DatasetMetricData(TypedDict, total=False):
verbose_name: str | None
# Type alias for database connection mutator function
DBConnectionMutator: TypeAlias = Callable[
[URL, dict[str, Any], str | None, Any, "QuerySource | None"],
tuple[URL, dict[str, Any]],
]
# Type alias for engine context manager
EngineContextManager: TypeAlias = Callable[
["Database", str | None, str | None], ContextManager[None]
]
class LegacyMetric(TypedDict):
label: str | None

View File

@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
return self._db
def _load_lazy_data_to_decouple_from_session(self) -> None:
self._db._get_sqla_engine() # type: ignore
self._db.backend # type: ignore # noqa: B018
def remove(self) -> None:

View File

@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
class TestTestConnectionDatabaseCommand(SupersetTestCase):
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_db_exception(
@@ -906,19 +906,23 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure event_logger is called when an exception is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
mock_get_sqla_engine.return_value.__enter__.side_effect = Exception(
"An error has occurred!"
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Unexpected error occurred, please check your logs for details"
)
# Exception wraps errors from db_engine_spec.extract_errors()
assert (
excinfo.value.errors[0].error_type
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
)
mock_event_logger.assert_called()
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_do_ping_exception(
@@ -927,9 +931,8 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure do_ping exceptions gets captured"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
"An error has occurred!"
)
mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value
mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!")
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
@@ -967,7 +970,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
)
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_superset_security_connection(
@@ -977,20 +980,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
connection exc is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info")
mock_get_sqla_engine.return_value.__enter__.side_effect = (
SupersetSecurityException(
SupersetError(error_type=500, message="test", level="info")
)
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012
with pytest.raises(DatabaseSecurityUnsafeError):
command_without_db_name.run()
assert str(excinfo.value) == ("Stopped an unsafe database connection")
mock_event_logger.assert_called()
@patch("superset.models.core.Database._get_sqla_engine")
@patch("superset.models.core.Database.get_sqla_engine")
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
@patch("superset.utils.core.g")
def test_connection_db_api_exc(
@@ -999,19 +1002,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
"""Test to make sure event_logger is called when DBAPIError is raised"""
database = get_example_database()
mock_g.user = security_manager.find_user("admin")
mock_get_sqla_engine.side_effect = DBAPIError(
mock_get_sqla_engine.return_value.__enter__.side_effect = DBAPIError(
statement="error", params={}, orig={}
)
db_uri = database.sqlalchemy_uri_decrypted
json_payload = {"sqlalchemy_uri": db_uri}
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012
with pytest.raises(SupersetErrorsException) as excinfo:
command_without_db_name.run()
assert str(excinfo.value) == (
"Connection failed, please check your connection settings"
)
# Exception wraps errors from db_engine_spec.extract_errors()
assert (
excinfo.value.errors[0].error_type
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
)
mock_event_logger.assert_called()
@@ -1147,7 +1151,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
with pytest.raises(DatabaseNotFoundError) as excinfo: # noqa: PT012
command.run()
assert str(excinfo.value) == ("Database not found.")
assert str(excinfo.value) == ("Database not found.")
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@@ -1166,26 +1170,35 @@ class TestTablesDatabaseCommand(SupersetTestCase):
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(SupersetException) as excinfo: # noqa: PT012
command.run()
assert str(excinfo.value) == "Test Error"
assert str(excinfo.value) == "Test Error"
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.models.core.Database.get_all_materialized_view_names_in_schema")
@patch("superset.models.core.Database.get_all_view_names_in_schema")
@patch("superset.models.core.Database.get_all_table_names_in_schema")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@patch("superset.utils.core.g")
def test_database_tables_exception(
self, mock_g, mock_can_access_database, mock_find_by_id
self,
mock_g,
mock_can_access_database,
mock_get_tables,
mock_get_views,
mock_get_mvs,
mock_find_by_id,
):
database = get_example_database()
mock_find_by_id.return_value = database
mock_get_tables.return_value = {("table1", "main", None)}
mock_get_views.return_value = set()
mock_get_mvs.return_value = []
mock_can_access_database.side_effect = Exception("Test Error")
mock_g.user = security_manager.find_user("admin")
command = TablesDatabaseCommand(database.id, None, "main", False)
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo: # noqa: PT012
command.run()
assert (
str(excinfo.value)
== "Unexpected error occurred, please check your logs for details"
)
assert str(excinfo.value) == "Test Error"
@patch("superset.daos.database.DatabaseDAO.find_by_id")
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")

View File

@@ -145,7 +145,7 @@ class TestDatabaseModel(SupersetTestCase):
username = make_url(engine.url).username
assert example_user.username != username
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
@unittest.skipUnless(
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
@@ -172,7 +172,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://gamma@localhost/"
@@ -185,7 +186,8 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "presto://localhost/"
@@ -199,13 +201,14 @@ class TestDatabaseModel(SupersetTestCase):
@unittest.skipUnless(
SupersetTestCase.is_module_installed("mysqlclient"), "mysqlclient not installed"
)
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_adjust_engine_params_mysql(self, mocked_create_engine):
model = Database(
database_name="test_database1",
sqlalchemy_uri="mysql://user:password@localhost",
)
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "mysql://user:password@localhost"
@@ -215,13 +218,14 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database2",
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
)
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
assert call_args[1]["connect_args"]["allow_local_infile"] == 0
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_impersonate_user_trino(self, mocked_create_engine):
principal_user = security_manager.find_user(username="gamma")
@@ -230,7 +234,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri="trino://localhost"
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "trino://localhost/"
@@ -242,7 +247,8 @@ class TestDatabaseModel(SupersetTestCase):
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert (
@@ -251,7 +257,7 @@ class TestDatabaseModel(SupersetTestCase):
)
assert call_args[1]["connect_args"]["user"] == "gamma"
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
@unittest.skipUnless(
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
)
@@ -281,7 +287,8 @@ class TestDatabaseModel(SupersetTestCase):
database_name="test_database", sqlalchemy_uri=uri, extra=extra
)
model.impersonate_user = True
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -294,7 +301,8 @@ class TestDatabaseModel(SupersetTestCase):
}
model.impersonate_user = False
model._get_sqla_engine()
with model.get_sqla_engine():
pass
call_args = mocked_create_engine.call_args
assert str(call_args[0][0]) == "hive://localhost"
@@ -376,7 +384,7 @@ class TestDatabaseModel(SupersetTestCase):
df = main_db.get_df("USE superset; SELECT ';';", None, None)
assert df.iat[0, 0] == ";"
@mock.patch("superset.models.core.create_engine")
@mock.patch("superset.engines.manager.create_engine")
def test_get_sqla_engine(self, mocked_create_engine):
model = Database(
database_name="test_database",
@@ -387,7 +395,8 @@ class TestDatabaseModel(SupersetTestCase):
)
mocked_create_engine.side_effect = Exception()
with self.assertRaises(SupersetException): # noqa: PT027
model._get_sqla_engine()
with model.get_sqla_engine():
pass
class TestSqlaTableModel(SupersetTestCase):

View File

@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from contextlib import contextmanager
from unittest.mock import MagicMock, patch
import pytest
from sqlalchemy import pool
from sqlalchemy.engine.url import make_url
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.engines.manager import EngineManager
@pytest.fixture
def engine_manager() -> EngineManager:
@contextmanager
def dummy_context_manager(
database: MagicMock,
catalog: str | None,
schema: str | None,
):
yield
return EngineManager(engine_context_manager=dummy_context_manager)
@pytest.fixture
def mock_database() -> MagicMock:
database = MagicMock()
database.id = 1
database.sqlalchemy_uri_decrypted = "trino://"
database.get_extra.return_value = {"engine_params": {"poolclass": "queue"}}
database.get_effective_user.return_value = "alice"
database.impersonate_user = False
database.update_params_from_encrypted_extra = MagicMock()
database.db_engine_spec = MagicMock()
database.db_engine_spec.adjust_engine_params.return_value = (
make_url("trino://"),
{"source": "Apache Superset"},
)
database.db_engine_spec.validate_database_uri = MagicMock()
return database
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_uses_null_pool(
mock_make_url: MagicMock,
engine_manager: EngineManager,
mock_database: MagicMock,
) -> None:
mock_make_url.return_value = make_url("trino://")
_, kwargs = engine_manager._get_engine_args(mock_database, None, None, None, None)
assert kwargs["poolclass"] is pool.NullPool
@patch("superset.engines.manager.make_url_safe")
def test_get_engine_args_with_impersonation(
mock_make_url: MagicMock,
engine_manager: EngineManager,
mock_database: MagicMock,
) -> None:
mock_make_url.return_value = make_url("trino://")
mock_database.impersonate_user = True
mock_database.get_oauth2_config.return_value = None
mock_database.db_engine_spec.impersonate_user.return_value = (
make_url("trino://"),
{"connect_args": {"user": "alice"}, "poolclass": pool.NullPool},
)
engine_manager._get_engine_args(mock_database, None, None, None, None)
mock_database.db_engine_spec.impersonate_user.assert_called_once()
def test_get_tunnel_kwargs_requires_database_port(
engine_manager: EngineManager,
) -> None:
ssh_tunnel = MagicMock()
ssh_tunnel.server_address = "ssh.example.com"
ssh_tunnel.server_port = 22
ssh_tunnel.username = "ssh_user"
ssh_tunnel.password = None
ssh_tunnel.private_key = None
ssh_tunnel.private_key_password = None
uri = MagicMock()
uri.port = None
uri.get_backend_name.return_value = "unknown"
with patch("superset.utils.ssh_tunnel.get_default_port", return_value=None):
with pytest.raises(SSHTunnelDatabasePortError):
engine_manager._get_tunnel_kwargs(ssh_tunnel, uri)

View File

@@ -124,7 +124,7 @@ class TestSupersetAppInitializer:
patch.object(app_initializer, "configure_data_sources"),
patch.object(app_initializer, "configure_auth_provider"),
patch.object(app_initializer, "configure_async_queries"),
patch.object(app_initializer, "configure_ssh_manager"),
patch.object(app_initializer, "configure_engine_manager"),
patch.object(app_initializer, "configure_stats_manager"),
patch.object(app_initializer, "init_views"),
):

View File

@@ -19,7 +19,6 @@
from datetime import datetime
import pytest
from flask import current_app
from pytest_mock import MockerFixture
from sqlalchemy import (
Column,
@@ -29,7 +28,6 @@ from sqlalchemy import (
Table as SqlalchemyTable,
)
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.engine.url import make_url
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import Select
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
def test_get_sqla_engine(mocker: MockerFixture) -> None:
"""
Test `_get_sqla_engine`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"source": "Apache Superset"},
)
def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
"""
Test user impersonation in `_get_sqla_engine`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
impersonate_user=True,
)
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"user": "alice", "source": "Apache Superset"},
)
def test_add_database_to_signature():
args = ["param1", "param2"]
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
assert args3 == ["param1", "param2", database]
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
"""
Test user impersonation in `_get_sqla_engine` with `username_from_email`.
"""
from superset.models.core import Database
user = mocker.MagicMock()
user.email = "alice.doe@example.org"
mocker.patch(
"superset.models.core.security_manager.find_user",
return_value=user,
)
mocker.patch("superset.models.core.get_username", return_value="alice")
create_engine = mocker.patch("superset.models.core.create_engine")
database = Database(
database_name="my_db",
sqlalchemy_uri="trino://",
impersonate_user=True,
)
database._get_sqla_engine(nullpool=False)
create_engine.assert_called_with(
make_url("trino:///"),
connect_args={"user": "alice.doe", "source": "Apache Superset"},
)
def test_is_oauth2_enabled() -> None:
"""
Test the `is_oauth2_enabled` method.
@@ -775,8 +689,8 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
encrypted_extra=json.dumps(oauth2_client_info),
)
database.db_engine_spec.oauth2_exception = OAuth2Error
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
create_engine = mocker.patch("superset.engines.manager.create_engine")
create_engine.side_effect = OAuth2Error("OAuth2 required")
with pytest.raises(OAuth2RedirectError) as excinfo:
with database.get_raw_connection() as conn:
@@ -879,56 +793,6 @@ def test_get_schema_access_for_file_upload() -> None:
assert database.get_schema_access_for_file_upload() == {"public"}
def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
"""
Test the engine context manager.
"""
from unittest.mock import MagicMock
engine_context_manager = MagicMock()
mocker.patch.dict(
current_app.config,
{"ENGINE_CONTEXT_MANAGER": engine_context_manager},
)
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
with database.get_sqla_engine("catalog", "schema"):
pass
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
engine_context_manager().__enter__.assert_called_once()
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
_get_sqla_engine.assert_called_once_with(
catalog="catalog",
schema="schema",
nullpool=True,
source=None,
sqlalchemy_uri="trino://",
)
def test_engine_oauth2(mocker: MockerFixture) -> None:
"""
Test that we handle OAuth2 when `create_engine` fails.
"""
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
start_oauth2_dance = mocker.patch.object(
database.db_engine_spec,
"start_oauth2_dance",
side_effect=OAuth2Error("OAuth2 required"),
)
with pytest.raises(OAuth2Error):
with database.get_sqla_engine("catalog", "schema"):
pass
start_oauth2_dance.assert_called_with(database)
def test_purge_oauth2_tokens(session: Session) -> None:
"""
Test the `purge_oauth2_tokens` method.

View File

@@ -202,7 +202,6 @@ def setup_mock_raw_connection(
def _raw_connection(
catalog: str | None = None,
schema: str | None = None,
nullpool: bool = True,
source: Any | None = None,
):
yield mock_connection