mirror of
https://github.com/apache/superset.git
synced 2026-05-13 20:05:20 +00:00
Compare commits
19 Commits
rls-splice
...
engine-man
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
75a64b3062 | ||
|
|
9c5e6d187b | ||
|
|
3aad565eab | ||
|
|
b7b59dfb8a | ||
|
|
fa0d4e1c08 | ||
|
|
08df7d5178 | ||
|
|
9dc54d8f1b | ||
|
|
e2ce534148 | ||
|
|
b3f8831d34 | ||
|
|
1775cae220 | ||
|
|
26f0390bbb | ||
|
|
ea27cabfc6 | ||
|
|
f39367bffd | ||
|
|
11395531f2 | ||
|
|
ec018cd842 | ||
|
|
48d3f441b8 | ||
|
|
b3393c65f7 | ||
|
|
8776b651a5 | ||
|
|
ccd32920fc |
19
UPDATING.md
19
UPDATING.md
@@ -24,6 +24,24 @@ assists people when migrating to a new version.
|
||||
|
||||
## Next
|
||||
|
||||
### `SSH_TUNNEL_MANAGER_CLASS` replaced by `ENGINE_MANAGER_CLASS`
|
||||
|
||||
The `SSH_TUNNEL_MANAGER_CLASS` config setting, the `superset.extensions.ssh` module (containing `SSHManager` and `SSHManagerFactory`), and the `ssh_manager_factory` extension singleton have been removed. SQLAlchemy engine creation — including SSH tunnel construction and URL rewriting — is now centralized in `EngineManager` (`superset/engines/manager.py`), wired up via `EngineManagerExtension` (`superset/extensions/engine_manager.py`).
|
||||
|
||||
A new config setting, `ENGINE_MANAGER_CLASS` (default: `"superset.engines.manager.EngineManager"`), replaces `SSH_TUNNEL_MANAGER_CLASS` as the customization hook. Deployments that previously subclassed `SSHManager` (e.g. for bastion routing, audit logging, host-key policy, or custom credential handling) should subclass `EngineManager` instead and set `ENGINE_MANAGER_CLASS` to the dotted path of the subclass. Override the relevant methods:
|
||||
|
||||
| Old `SSHManager` method | New override point on `EngineManager` |
|
||||
|---|---|
|
||||
| `__init__(app)` reading `SSH_TUNNEL_*` configs | `__init__` — the same `SSH_TUNNEL_LOCAL_BIND_ADDRESS`, `SSH_TUNNEL_TIMEOUT_SEC`, and `SSH_TUNNEL_PACKET_TIMEOUT_SEC` configs are still loaded by `EngineManagerExtension.init_app` and passed in |
|
||||
| `create_tunnel(ssh_tunnel, uri)` | `_get_tunnel_kwargs(ssh_tunnel, uri)` for parameter construction and `_create_tunnel(ssh_tunnel, uri)` for the `sshtunnel.open_tunnel` + `start()` call |
|
||||
| `build_sqla_url(url, server)` | Inlined in `get_engine` as `uri.set(host=tunnel.local_bind_address[0], port=tunnel.local_bind_port)` |
|
||||
|
||||
**Behavioral note:** the old `SSHManager.create_tunnel` passed `debug_level=logging.getLogger("flask_appbuilder").level` to `sshtunnel.open_tunnel`. The new `_get_tunnel_kwargs` does not. Subclasses relying on that should add it back in their override.
|
||||
|
||||
### `Database.get_sqla_engine(nullpool=...)` deprecated
|
||||
|
||||
The `nullpool` keyword argument to `Database.get_sqla_engine` is deprecated and ignored — the engine manager always uses `NullPool`. The kwarg is still accepted (with a `DeprecationWarning`) so external callers passing `nullpool=False` won't fail with `TypeError`, but the resulting engine will use `NullPool` regardless. Remove the argument from your callers; it will be deleted in a future release.
|
||||
|
||||
### Granular Export Controls
|
||||
|
||||
A new feature flag `GRANULAR_EXPORT_CONTROLS` introduces three fine-grained permissions that replace the legacy `can_csv` permission:
|
||||
@@ -114,7 +132,6 @@ DISTRIBUTED_COORDINATION_CONFIG = {
|
||||
```
|
||||
|
||||
See `superset/config.py` for complete configuration options.
|
||||
|
||||
### WebSocket config for GAQ with Docker
|
||||
|
||||
[35896](https://github.com/apache/superset/pull/35896) and [37624](https://github.com/apache/superset/pull/37624) updated documentation on how to run and configure Superset with Docker. Specifically for the WebSocket configuration, a new `docker/superset-websocket/config.example.json` was added to the repo, so that users could copy it to create a `docker/superset-websocket/config.json` file. The existing `docker/superset-websocket/config.json` was removed and git-ignored, so if you're using GAQ / WebSocket make sure to:
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
* under the License.
|
||||
*/
|
||||
import { getExtensionsRegistry } from '@superset-ui/core';
|
||||
import type { ComponentType, ReactNode } from 'react';
|
||||
import { Provider as ReduxProvider } from 'react-redux';
|
||||
import { QueryParamProvider } from 'use-query-params';
|
||||
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
|
||||
@@ -64,7 +65,7 @@ export const EmbeddedContextProviders: React.FC<{
|
||||
}> = ({ children }) => {
|
||||
const RootContextProviderExtension = extensionsRegistry.get(
|
||||
'root.context.provider',
|
||||
);
|
||||
) as ComponentType<{ children?: ReactNode }> | undefined;
|
||||
|
||||
return (
|
||||
<SupersetThemeProvider themeController={themeController}>
|
||||
|
||||
@@ -188,7 +188,9 @@ function CollectionControl({
|
||||
// Two items can collide when keyAccessor returns falsy and the index
|
||||
// fallback is used — breaking dnd-kit reordering and React reconciliation.
|
||||
// Assign a stable nanoid per item ref when no key is available.
|
||||
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(new WeakMap());
|
||||
const generatedIdsRef = useRef<WeakMap<CollectionItem, string>>(
|
||||
new WeakMap(),
|
||||
);
|
||||
const itemIds = useMemo(
|
||||
() =>
|
||||
value.map(item => {
|
||||
|
||||
@@ -34,6 +34,7 @@ import {
|
||||
Tooltip,
|
||||
Row,
|
||||
type OnClickHandler,
|
||||
type ButtonProps as CoreButtonProps,
|
||||
} from '@superset-ui/core/components';
|
||||
import { Icons } from '@superset-ui/core/components/Icons';
|
||||
import { MenuObjectProps } from 'src/types/bootstrapTypes';
|
||||
@@ -148,7 +149,7 @@ export interface ButtonProps {
|
||||
'data-test'?: string;
|
||||
buttonStyle: 'primary' | 'secondary' | 'dashed' | 'link' | 'tertiary';
|
||||
loading?: boolean;
|
||||
icon?: ReactNode;
|
||||
icon?: CoreButtonProps['icon'];
|
||||
component?: ReactNode;
|
||||
}
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
*/
|
||||
|
||||
import { getExtensionsRegistry } from '@superset-ui/core';
|
||||
import type { ComponentType, ReactNode } from 'react';
|
||||
import { Provider as ReduxProvider } from 'react-redux';
|
||||
import { QueryParamProvider } from 'use-query-params';
|
||||
import { ReactRouter5Adapter } from 'use-query-params/adapters/react-router-5';
|
||||
@@ -39,7 +40,7 @@ export const RootContextProviders: React.FC<{ children?: React.ReactNode }> = ({
|
||||
}) => {
|
||||
const RootContextProviderExtension = extensionsRegistry.get(
|
||||
'root.context.provider',
|
||||
);
|
||||
) as ComponentType<{ children?: ReactNode }> | undefined;
|
||||
|
||||
return (
|
||||
<SupersetThemeProvider themeController={themeController}>
|
||||
|
||||
@@ -55,7 +55,11 @@ from superset.constants import CHANGE_ME_SECRET_KEY
|
||||
from superset.jinja_context import BaseTemplateProcessor
|
||||
from superset.key_value.types import JsonKeyValueCodec
|
||||
from superset.stats_logger import DummyStatsLogger
|
||||
from superset.superset_typing import CacheConfig
|
||||
from superset.superset_typing import (
|
||||
CacheConfig,
|
||||
DBConnectionMutator,
|
||||
EngineContextManager,
|
||||
)
|
||||
from superset.tasks.types import ExecutorType
|
||||
from superset.themes.types import Theme
|
||||
from superset.utils import core as utils
|
||||
@@ -831,7 +835,6 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = {
|
||||
# FIREWALL (only port 22 is open)
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
||||
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
||||
#: Timeout (seconds) for tunnel connection (open_channel timeout)
|
||||
SSH_TUNNEL_TIMEOUT_SEC = 10.0
|
||||
@@ -1720,7 +1723,14 @@ def engine_context_manager( # pylint: disable=unused-argument
|
||||
yield None
|
||||
|
||||
|
||||
ENGINE_CONTEXT_MANAGER = engine_context_manager
|
||||
ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager
|
||||
|
||||
# The class used to manage SQLAlchemy engine creation, including SSH tunnels
|
||||
# and connection details. Deployments that need custom behavior (e.g. bastion
|
||||
# routing, audit logging, host-key policy, custom credential handling) can
|
||||
# subclass `superset.engines.manager.EngineManager` and point this setting at
|
||||
# the subclass.
|
||||
ENGINE_MANAGER_CLASS = "superset.engines.manager.EngineManager"
|
||||
|
||||
# A callable that allows altering the database connection URL and params
|
||||
# on the fly, at runtime. This allows for things like impersonation or
|
||||
@@ -1737,7 +1747,7 @@ ENGINE_CONTEXT_MANAGER = engine_context_manager
|
||||
#
|
||||
# Note that the returned uri and params are passed directly to sqlalchemy's
|
||||
# as such `create_engine(url, **params)`
|
||||
DB_CONNECTION_MUTATOR = None
|
||||
DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None
|
||||
|
||||
|
||||
# A callable that is invoked for every invocation of DB Engine Specs
|
||||
|
||||
@@ -14,23 +14,3 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest.mock import Mock
|
||||
|
||||
import sshtunnel
|
||||
|
||||
from superset.extensions.ssh import SSHManagerFactory
|
||||
|
||||
|
||||
def test_ssh_tunnel_timeout_setting() -> None:
    """Initializing SSHManagerFactory propagates the timeout config values
    onto the module-level ``sshtunnel`` globals."""
    config = {
        "SSH_TUNNEL_MAX_RETRIES": 2,
        "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
        "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
        "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
        "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
    }
    app = Mock()
    app.config = config

    SSHManagerFactory().init_app(app)

    assert sshtunnel.TUNNEL_TIMEOUT == 123.0
    assert sshtunnel.SSH_TIMEOUT == 321.0
|
||||
203
superset/engines/manager.py
Normal file
203
superset/engines/manager.py
Normal file
@@ -0,0 +1,203 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from datetime import timedelta
|
||||
from io import StringIO
|
||||
from typing import Any, Iterator, TYPE_CHECKING
|
||||
|
||||
import sshtunnel
|
||||
from paramiko import RSAKey
|
||||
from sqlalchemy import create_engine, pool
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sshtunnel import SSHTunnelForwarder
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.superset_typing import DBConnectionMutator, EngineContextManager
|
||||
from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
class EngineManager:
    """Centralized SQLAlchemy engine creation for Superset.

    Replaces the legacy ``SSHManager``: engine construction, SSH tunnel
    setup, and the post-tunnel URL rewrite all live here. Deployments
    needing custom behavior can subclass this and point the
    ``ENGINE_MANAGER_CLASS`` config setting at the subclass.
    """

    def __init__(
        self,
        engine_context_manager: EngineContextManager,
        db_connection_mutator: DBConnectionMutator | None = None,
        local_bind_address: str = "127.0.0.1",
        tunnel_timeout: timedelta = timedelta(seconds=30),
        ssh_timeout: timedelta = timedelta(seconds=1),
    ) -> None:
        # Context manager wrapped around every engine checkout
        # (config: ENGINE_CONTEXT_MANAGER).
        self.engine_context_manager = engine_context_manager
        # Optional hook to rewrite (uri, kwargs) at runtime
        # (config: DB_CONNECTION_MUTATOR).
        self.db_connection_mutator = db_connection_mutator
        # Local address the SSH tunnel binds to
        # (config: SSH_TUNNEL_LOCAL_BIND_ADDRESS).
        self.local_bind_address = local_bind_address

        # NOTE: TUNNEL_TIMEOUT / SSH_TIMEOUT are module-level globals in
        # `sshtunnel`, so this affects the whole process; the last
        # EngineManager constructed wins.
        sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds()
        sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds()

    @contextmanager
    def get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Iterator[Engine]:
        """Context manager to get a SQLAlchemy engine.

        If the database has an SSH tunnel configured, the tunnel is opened
        first, the URI's host/port are rewritten to the tunnel's local bind
        end, and the tunnel is stopped when the context exits.
        """
        # Local import to avoid a circular import at module load time.
        from superset.utils.oauth2 import check_for_oauth2

        with self.engine_context_manager(database, catalog, schema):
            with check_for_oauth2(database):
                uri, kwargs = self._get_engine_args(
                    database,
                    catalog,
                    schema,
                    source,
                    get_user_id(),
                )

                if database.ssh_tunnel:
                    tunnel = self._create_tunnel(database.ssh_tunnel, uri)
                    try:
                        # Point the engine at the tunnel's local end instead
                        # of the remote host (replaces the old
                        # SSHManager.build_sqla_url).
                        uri = uri.set(
                            host=tunnel.local_bind_address[0],
                            port=tunnel.local_bind_port,
                        )
                        yield self._create_engine(database, uri, kwargs)
                    finally:
                        # Always tear the tunnel down, even if engine
                        # creation or the caller's block raised.
                        tunnel.stop()
                else:
                    yield self._create_engine(database, uri, kwargs)

    def _get_engine_args(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> tuple[URL, dict[str, Any]]:
        """Build SQLAlchemy URI and kwargs before engine creation.

        Order matters here: engine params from extras, catalog/schema
        adjustment, impersonation, encrypted extras, then the connection
        mutator, and finally URI validation.
        """
        # Local imports to avoid circular imports.
        from superset import is_feature_enabled
        from superset.extensions import security_manager

        uri = make_url_safe(database.sqlalchemy_uri_decrypted)
        extra = database.get_extra(source)
        # Copy so we never mutate the dict stored in the database extras.
        kwargs = dict(extra.get("engine_params", {}))
        # Engines from this manager always use NullPool; the legacy
        # `nullpool` kwarg on Database.get_sqla_engine is ignored.
        kwargs["poolclass"] = pool.NullPool

        connect_args = kwargs.setdefault("connect_args", {})
        # NOTE(review): if adjust_engine_params returns a *new* connect_args
        # dict rather than mutating in place, the returned value is never
        # written back into kwargs["connect_args"] — verify specs mutate
        # in place, or that this is intended.
        uri, connect_args = database.db_engine_spec.adjust_engine_params(
            uri,
            connect_args,
            catalog,
            schema,
        )

        username = database.get_effective_user(uri)
        if username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
            user = security_manager.find_user(username=username)
            # Impersonate using the email local-part when available.
            if user and user.email and "@" in user.email:
                username = user.email.split("@")[0]

        if database.impersonate_user:
            oauth2_config = database.get_oauth2_config()
            from superset.utils.oauth2 import get_oauth2_access_token

            # OAuth2 token is only fetched when the database is configured
            # for OAuth2 and we have a user to fetch it for.
            access_token = (
                get_oauth2_access_token(
                    oauth2_config,
                    database.id,
                    user_id,
                    database.db_engine_spec,
                )
                if oauth2_config and user_id
                else None
            )

            uri, kwargs = database.db_engine_spec.impersonate_user(
                database,
                username,
                access_token,
                uri,
                kwargs,
            )

        # Merge credentials/params stored in the encrypted extra field.
        database.update_params_from_encrypted_extra(kwargs)

        if self.db_connection_mutator:
            # Fall back to inferring the query source from the request when
            # the caller did not pass one explicitly.
            source = source or get_query_source_from_request()
            uri, kwargs = self.db_connection_mutator(
                uri,
                kwargs,
                username,
                security_manager,
                source,
            )

        # Validate last so the mutator's output is also checked.
        database.db_engine_spec.validate_database_uri(uri)
        return uri, kwargs

    def _create_engine(
        self,
        database: "Database",
        uri: URL,
        kwargs: dict[str, Any],
    ) -> Engine:
        # Map low-level driver exceptions to the spec's DB-API exceptions so
        # callers get a consistent error type across engines.
        try:
            return create_engine(uri, **kwargs)
        except Exception as ex:
            raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

    def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
        """Open and start an SSH tunnel for the given database URI.

        Subclass override point for custom tunnel construction; parameter
        building lives in `_get_tunnel_kwargs`.
        """
        kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
        tunnel = sshtunnel.open_tunnel(**kwargs)
        tunnel.start()
        return tunnel

    def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]:
        """Build the kwargs passed to ``sshtunnel.open_tunnel``.

        NOTE: unlike the legacy ``SSHManager.create_tunnel``, this does not
        pass ``debug_level``; subclasses relying on that must add it back.
        """
        # Local import to avoid a circular import.
        from superset.utils.ssh_tunnel import get_default_port

        backend = uri.get_backend_name()
        # Fall back to the backend's default port when the URI omits one;
        # without a port the tunnel's remote end cannot be bound.
        port = uri.port or get_default_port(backend)
        if not port:
            raise SSHTunnelDatabasePortError()

        kwargs = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (uri.host, port),
            # Port is omitted so sshtunnel picks a free local port.
            "local_bind_address": (self.local_bind_address,),
        }

        # Password auth takes precedence; otherwise fall back to a private
        # key (optionally encrypted with a passphrase).
        if ssh_tunnel.password:
            kwargs["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file,
                ssh_tunnel.private_key_password,
            )
            kwargs["ssh_pkey"] = private_key

        return kwargs
|
||||
@@ -42,7 +42,7 @@ from werkzeug.local import LocalProxy
|
||||
|
||||
from superset.async_events.async_query_manager import AsyncQueryManager
|
||||
from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory
|
||||
from superset.extensions.ssh import SSHManagerFactory
|
||||
from superset.extensions.engine_manager import EngineManagerExtension
|
||||
from superset.extensions.stats_logger import BaseStatsLoggerManager
|
||||
from superset.security.manager import SupersetSecurityManager
|
||||
from superset.utils.cache_manager import CacheManager
|
||||
@@ -146,6 +146,7 @@ cache_manager = CacheManager()
|
||||
celery_app = celery.Celery()
|
||||
csrf = CSRFProtect()
|
||||
db = get_sqla_class()()
|
||||
engine_manager_extension = EngineManagerExtension()
|
||||
_event_logger: dict[str, Any] = {}
|
||||
encrypted_field_factory = EncryptedFieldFactory()
|
||||
event_logger = LocalProxy(lambda: _event_logger.get("event_logger"))
|
||||
@@ -156,6 +157,5 @@ migrate = Migrate()
|
||||
profiling = ProfilingExtension()
|
||||
results_backend_manager = ResultsBackendManager()
|
||||
security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm)
|
||||
ssh_manager_factory = SSHManagerFactory()
|
||||
stats_logger_manager = BaseStatsLoggerManager()
|
||||
talisman = Talisman()
|
||||
|
||||
77
superset/extensions/engine_manager.py
Normal file
77
superset/extensions/engine_manager.py
Normal file
@@ -0,0 +1,77 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from superset.engines.manager import EngineManager
|
||||
from superset.utils.class_utils import load_class_from_name
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EngineManagerExtension:
    """
    Flask extension for managing SQLAlchemy engines in Superset.

    Follows the standard two-step Flask extension pattern: construct at
    import time, wire up via ``init_app``.
    """

    def __init__(self) -> None:
        # Populated by init_app(); None until then.
        self.engine_manager: EngineManager | None = None

    def init_app(self, app: Flask) -> None:
        """
        Initialize the EngineManager with Flask app configuration.
        """
        config = app.config
        tunnel_timeout = timedelta(seconds=config["SSH_TUNNEL_TIMEOUT_SEC"])
        ssh_timeout = timedelta(seconds=config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"])

        # ENGINE_MANAGER_CLASS may point at a deployment-specific subclass.
        manager_cls: type[EngineManager] = load_class_from_name(
            config["ENGINE_MANAGER_CLASS"]
        )
        self.engine_manager = manager_cls(
            config["ENGINE_CONTEXT_MANAGER"],
            config["DB_CONNECTION_MUTATOR"],
            config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"],
            tunnel_timeout,
            ssh_timeout,
        )

        logger.info(
            "Initialized EngineManager with tunnel_timeout=%s, ssh_timeout=%s",
            tunnel_timeout.total_seconds(),
            ssh_timeout.total_seconds(),
        )

    @property
    def manager(self) -> EngineManager:
        """
        Get the EngineManager instance.

        Raises:
            RuntimeError: If the extension hasn't been initialized with an app.
        """
        manager = self.engine_manager
        if manager is None:
            raise RuntimeError(
                "EngineManager extension not initialized. "
                "Call init_app() with a Flask app first."
            )
        return manager
|
||||
@@ -1,94 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from io import StringIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import sshtunnel
|
||||
from flask import Flask
|
||||
from paramiko import RSAKey
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.utils.class_utils import load_class_from_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
|
||||
class SSHManager:
    """Legacy SSH tunnel manager (superseded by ``EngineManager``).

    Reads the ``SSH_TUNNEL_*`` config values and builds
    ``sshtunnel.SSHTunnelForwarder`` instances for databases that connect
    through an SSH tunnel.
    """

    def __init__(self, app: Flask) -> None:
        super().__init__()
        # Local address the tunnel binds to (SSH_TUNNEL_LOCAL_BIND_ADDRESS).
        self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        # NOTE: these are module-level globals in `sshtunnel`, so they apply
        # process-wide, not per-manager.
        sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
        sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]

    def build_sqla_url(
        self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
    ) -> str:
        """Rewrite a SQLAlchemy URL to point at the tunnel's local end."""
        # override any ssh tunnel configuration object
        url = make_url_safe(sqlalchemy_url)
        return url.set(
            host=server.local_bind_address[0],
            port=server.local_bind_port,
        )

    def create_tunnel(
        self,
        ssh_tunnel: "SSHTunnel",
        sqlalchemy_database_uri: str,
    ) -> sshtunnel.SSHTunnelForwarder:
        """Build (but do not start) a tunnel forwarder for the given URI.

        Raises:
            SSHTunnelDatabasePortError: when neither the URI nor the backend
                provides a port for the remote bind address.
        """
        # Local import to avoid a circular import.
        from superset.utils.ssh_tunnel import get_default_port

        url = make_url_safe(sqlalchemy_database_uri)
        backend = url.get_backend_name()
        # Fall back to the backend's default port when the URI omits one.
        port = url.port or get_default_port(backend)
        if not port:
            raise SSHTunnelDatabasePortError()
        params = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": (url.host, port),
            # Port is omitted so sshtunnel picks a free local port.
            "local_bind_address": (self.local_bind_address,),
            # Mirror flask_appbuilder's log level in the tunnel's debug output.
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }

        # Password auth takes precedence; otherwise fall back to a private
        # key (optionally protected by a passphrase).
        if ssh_tunnel.password:
            params["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file, ssh_tunnel.private_key_password
            )
            params["ssh_pkey"] = private_key

        return sshtunnel.open_tunnel(**params)
|
||||
|
||||
|
||||
class SSHManagerFactory:
    """Instantiates the configured ``SSHManager`` implementation.

    The concrete class is resolved from the ``SSH_TUNNEL_MANAGER_CLASS``
    config setting during ``init_app``.
    """

    def __init__(self) -> None:
        # Resolved lazily in init_app(); None until then.
        self._ssh_manager = None

    def init_app(self, app: Flask) -> None:
        manager_cls = load_class_from_name(app.config["SSH_TUNNEL_MANAGER_CLASS"])
        self._ssh_manager = manager_cls(app)

    @property
    def instance(self) -> SSHManager:
        return self._ssh_manager  # type: ignore
|
||||
@@ -49,13 +49,13 @@ from superset.extensions import (
|
||||
csrf,
|
||||
db,
|
||||
encrypted_field_factory,
|
||||
engine_manager_extension,
|
||||
feature_flag_manager,
|
||||
machine_auth_provider_factory,
|
||||
manifest_processor,
|
||||
migrate,
|
||||
profiling,
|
||||
results_backend_manager,
|
||||
ssh_manager_factory,
|
||||
stats_logger_manager,
|
||||
talisman,
|
||||
)
|
||||
@@ -616,8 +616,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
self.configure_url_map_converters()
|
||||
self.configure_data_sources()
|
||||
self.configure_auth_provider()
|
||||
self.configure_engine_manager()
|
||||
self.configure_async_queries()
|
||||
self.configure_ssh_manager()
|
||||
self.configure_stats_manager()
|
||||
self.configure_task_manager()
|
||||
|
||||
@@ -793,8 +793,8 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods
|
||||
def configure_auth_provider(self) -> None:
|
||||
machine_auth_provider_factory.init_app(self.superset_app)
|
||||
|
||||
def configure_ssh_manager(self) -> None:
|
||||
ssh_manager_factory.init_app(self.superset_app)
|
||||
def configure_engine_manager(self) -> None:
|
||||
engine_manager_extension.init_app(self.superset_app)
|
||||
|
||||
def configure_stats_manager(self) -> None:
|
||||
stats_logger_manager.init_app(self.superset_app)
|
||||
|
||||
@@ -25,24 +25,22 @@ import builtins
|
||||
import logging
|
||||
import textwrap
|
||||
from ast import literal_eval
|
||||
from contextlib import closing, contextmanager, nullcontext, suppress
|
||||
from contextlib import closing, contextmanager, suppress
|
||||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from functools import lru_cache
|
||||
from inspect import signature
|
||||
from typing import Any, Callable, cast, Optional, TYPE_CHECKING
|
||||
from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING
|
||||
|
||||
import numpy
|
||||
import pandas as pd
|
||||
import sqlalchemy as sqla
|
||||
import sshtunnel
|
||||
from flask import current_app as app, g, has_app_context
|
||||
from flask_appbuilder import Model
|
||||
from marshmallow.exceptions import ValidationError
|
||||
from sqlalchemy import (
|
||||
Boolean,
|
||||
Column,
|
||||
create_engine,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
@@ -57,7 +55,6 @@ from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.exc import NoSuchModuleError
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.pool import NullPool
|
||||
from sqlalchemy.schema import UniqueConstraint
|
||||
from sqlalchemy.sql import ColumnElement, expression, Select
|
||||
from superset_core.common.models import Database as CoreDatabase
|
||||
@@ -72,7 +69,6 @@ from superset.extensions import (
|
||||
encrypted_field_factory,
|
||||
event_logger,
|
||||
security_manager,
|
||||
ssh_manager_factory,
|
||||
)
|
||||
from superset.models.helpers import AuditMixinNullable, ImportExportMixin, UUIDMixin
|
||||
from superset.result_set import SupersetResultSet
|
||||
@@ -84,10 +80,9 @@ from superset.superset_typing import (
|
||||
)
|
||||
from superset.utils import cache as cache_util, core as utils, json
|
||||
from superset.utils.backports import StrEnum
|
||||
from superset.utils.core import get_query_source_from_request, get_username
|
||||
from superset.utils.core import get_username
|
||||
from superset.utils.oauth2 import (
|
||||
check_for_oauth2,
|
||||
get_oauth2_access_token,
|
||||
OAuth2ClientConfigSchema,
|
||||
)
|
||||
|
||||
@@ -424,130 +419,46 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def get_sqla_engine( # pylint: disable=too-many-arguments
|
||||
def get_sqla_engine(
|
||||
self,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
nullpool: bool = True,
|
||||
source: utils.QuerySource | None = None,
|
||||
) -> Engine:
|
||||
nullpool: bool | None = None,
|
||||
) -> Iterator[Engine]:
|
||||
"""
|
||||
Context manager for a SQLAlchemy engine.
|
||||
|
||||
This method will return a context manager for a SQLAlchemy engine. Using the
|
||||
context manager (as opposed to the engine directly) is important because we need
|
||||
to potentially establish SSH tunnels before the connection is created, and clean
|
||||
them up once the engine is no longer used.
|
||||
This method will return a context manager for a SQLAlchemy engine. The engine
|
||||
manager handles engine creation, SSH tunnels, and connection details in a
|
||||
centralized place.
|
||||
|
||||
The ``nullpool`` argument is deprecated and ignored — the engine manager
|
||||
always uses ``NullPool``. It is kept temporarily for backwards compatibility
|
||||
with external callers and will be removed in a future release.
|
||||
"""
|
||||
if nullpool is not None:
|
||||
import warnings
|
||||
|
||||
sqlalchemy_uri = self.sqlalchemy_uri_decrypted
|
||||
|
||||
ssh_context_manager = (
|
||||
ssh_manager_factory.instance.create_tunnel(
|
||||
ssh_tunnel=self.ssh_tunnel,
|
||||
sqlalchemy_database_uri=sqlalchemy_uri,
|
||||
warnings.warn(
|
||||
"The `nullpool` argument to `Database.get_sqla_engine` is "
|
||||
"deprecated and ignored; the engine manager always uses NullPool.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if self.ssh_tunnel
|
||||
else nullcontext()
|
||||
)
|
||||
|
||||
with ssh_context_manager as ssh_context:
|
||||
if ssh_context:
|
||||
logger.info(
|
||||
"[SSH] Successfully created tunnel w/ %s tunnel_timeout + %s "
|
||||
"ssh_timeout at %s",
|
||||
sshtunnel.TUNNEL_TIMEOUT,
|
||||
sshtunnel.SSH_TIMEOUT,
|
||||
ssh_context.local_bind_address,
|
||||
)
|
||||
sqlalchemy_uri = ssh_manager_factory.instance.build_sqla_url(
|
||||
sqlalchemy_uri,
|
||||
ssh_context,
|
||||
)
|
||||
# Import here to avoid circular imports
|
||||
from superset.extensions import engine_manager_extension
|
||||
|
||||
engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"]
|
||||
with engine_context_manager(self, catalog, schema):
|
||||
with check_for_oauth2(self):
|
||||
yield self._get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
sqlalchemy_uri=sqlalchemy_uri,
|
||||
)
|
||||
|
||||
def _get_sqla_engine( # pylint: disable=too-many-locals # noqa: C901
|
||||
self,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
nullpool: bool = True,
|
||||
source: utils.QuerySource | None = None,
|
||||
sqlalchemy_uri: str | None = None,
|
||||
) -> Engine:
|
||||
sqlalchemy_url = make_url_safe(
|
||||
sqlalchemy_uri if sqlalchemy_uri else self.sqlalchemy_uri_decrypted
|
||||
)
|
||||
self.db_engine_spec.validate_database_uri(sqlalchemy_url)
|
||||
|
||||
extra = self.get_extra(source)
|
||||
engine_kwargs = extra.get("engine_params", {})
|
||||
if nullpool:
|
||||
engine_kwargs["poolclass"] = NullPool
|
||||
connect_args = engine_kwargs.setdefault("connect_args", {})
|
||||
|
||||
# modify URL/args for a specific catalog/schema
|
||||
sqlalchemy_url, connect_args = self.db_engine_spec.adjust_engine_params(
|
||||
uri=sqlalchemy_url,
|
||||
connect_args=connect_args,
|
||||
# Use the engine manager to get the engine
|
||||
engine_manager = engine_manager_extension.manager
|
||||
with engine_manager.get_engine(
|
||||
database=self,
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
)
|
||||
|
||||
effective_username = self.get_effective_user(sqlalchemy_url)
|
||||
if effective_username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
|
||||
user = security_manager.find_user(username=effective_username)
|
||||
if user and user.email:
|
||||
effective_username = user.email.split("@")[0]
|
||||
|
||||
oauth2_config = self.get_oauth2_config()
|
||||
access_token = (
|
||||
get_oauth2_access_token(
|
||||
oauth2_config,
|
||||
self.id,
|
||||
g.user.id,
|
||||
self.db_engine_spec,
|
||||
)
|
||||
if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
|
||||
else None
|
||||
)
|
||||
masked_url = self.get_password_masked_url(sqlalchemy_url)
|
||||
logger.debug("Database._get_sqla_engine(). Masked URL: %s", str(masked_url))
|
||||
|
||||
if self.impersonate_user:
|
||||
sqlalchemy_url, engine_kwargs = self.db_engine_spec.impersonate_user(
|
||||
self,
|
||||
effective_username,
|
||||
access_token,
|
||||
sqlalchemy_url,
|
||||
engine_kwargs,
|
||||
)
|
||||
|
||||
self.update_params_from_encrypted_extra(engine_kwargs)
|
||||
|
||||
if DB_CONNECTION_MUTATOR := app.config["DB_CONNECTION_MUTATOR"]: # noqa: N806
|
||||
source = source or get_query_source_from_request()
|
||||
|
||||
sqlalchemy_url, engine_kwargs = DB_CONNECTION_MUTATOR(
|
||||
sqlalchemy_url,
|
||||
engine_kwargs,
|
||||
effective_username,
|
||||
security_manager,
|
||||
source,
|
||||
)
|
||||
try:
|
||||
return create_engine(sqlalchemy_url, **engine_kwargs)
|
||||
except Exception as ex:
|
||||
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
|
||||
source=source,
|
||||
) as engine:
|
||||
yield engine
|
||||
|
||||
def add_database_to_signature(
|
||||
self,
|
||||
@@ -572,13 +483,11 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
|
||||
self,
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
nullpool: bool = True,
|
||||
source: utils.QuerySource | None = None,
|
||||
) -> Connection:
|
||||
with self.get_sqla_engine(
|
||||
catalog=catalog,
|
||||
schema=schema,
|
||||
nullpool=nullpool,
|
||||
source=source,
|
||||
) as engine:
|
||||
with check_for_oauth2(self):
|
||||
|
||||
@@ -18,14 +18,28 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Hashable, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
ContextManager,
|
||||
Literal,
|
||||
TYPE_CHECKING,
|
||||
TypeAlias,
|
||||
TypedDict,
|
||||
)
|
||||
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.sql.type_api import TypeEngine
|
||||
from typing_extensions import NotRequired
|
||||
from werkzeug.wrappers import Response
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.utils.core import GenericDataType, QueryObjectFilterClause
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import (
|
||||
GenericDataType,
|
||||
QueryObjectFilterClause,
|
||||
QuerySource,
|
||||
)
|
||||
|
||||
SQLType: TypeAlias = TypeEngine | type[TypeEngine]
|
||||
|
||||
@@ -70,6 +84,18 @@ class DatasetMetricData(TypedDict, total=False):
|
||||
verbose_name: str | None
|
||||
|
||||
|
||||
# Type alias for database connection mutator function
|
||||
DBConnectionMutator: TypeAlias = Callable[
|
||||
[URL, dict[str, Any], str | None, Any, "QuerySource | None"],
|
||||
tuple[URL, dict[str, Any]],
|
||||
]
|
||||
|
||||
# Type alias for engine context manager
|
||||
EngineContextManager: TypeAlias = Callable[
|
||||
["Database", str | None, str | None], ContextManager[None]
|
||||
]
|
||||
|
||||
|
||||
class LegacyMetric(TypedDict):
|
||||
label: str | None
|
||||
|
||||
|
||||
@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
|
||||
return self._db
|
||||
|
||||
def _load_lazy_data_to_decouple_from_session(self) -> None:
|
||||
self._db._get_sqla_engine() # type: ignore
|
||||
self._db.backend # type: ignore # noqa: B018
|
||||
|
||||
def remove(self) -> None:
|
||||
|
||||
@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
|
||||
|
||||
|
||||
class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_db_exception(
|
||||
@@ -906,19 +906,23 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure event_logger is called when an exception is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
|
||||
mock_get_sqla_engine.return_value.__enter__.side_effect = Exception(
|
||||
"An error has occurred!"
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012
|
||||
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == (
|
||||
"Unexpected error occurred, please check your logs for details"
|
||||
)
|
||||
# Exception wraps errors from db_engine_spec.extract_errors()
|
||||
assert (
|
||||
excinfo.value.errors[0].error_type
|
||||
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
|
||||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_do_ping_exception(
|
||||
@@ -927,9 +931,8 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure do_ping exceptions gets captured"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
|
||||
"An error has occurred!"
|
||||
)
|
||||
mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value
|
||||
mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!")
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
@@ -967,7 +970,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
== SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
|
||||
)
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_superset_security_connection(
|
||||
@@ -977,20 +980,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
connection exc is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = SupersetSecurityException(
|
||||
SupersetError(error_type=500, message="test", level="info")
|
||||
mock_get_sqla_engine.return_value.__enter__.side_effect = (
|
||||
SupersetSecurityException(
|
||||
SupersetError(error_type=500, message="test", level="info")
|
||||
)
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012
|
||||
with pytest.raises(DatabaseSecurityUnsafeError):
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == ("Stopped an unsafe database connection")
|
||||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@patch("superset.models.core.Database._get_sqla_engine")
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@patch("superset.commands.database.test_connection.event_logger.log_with_context")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_connection_db_api_exc(
|
||||
@@ -999,19 +1002,20 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure event_logger is called when DBAPIError is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.side_effect = DBAPIError(
|
||||
mock_get_sqla_engine.return_value.__enter__.side_effect = DBAPIError(
|
||||
statement="error", params={}, orig={}
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012
|
||||
with pytest.raises(SupersetErrorsException) as excinfo:
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == (
|
||||
"Connection failed, please check your connection settings"
|
||||
)
|
||||
|
||||
# Exception wraps errors from db_engine_spec.extract_errors()
|
||||
assert (
|
||||
excinfo.value.errors[0].error_type
|
||||
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
|
||||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
|
||||
@@ -1147,7 +1151,7 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
||||
|
||||
with pytest.raises(DatabaseNotFoundError) as excinfo: # noqa: PT012
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("Database not found.")
|
||||
assert str(excinfo.value) == ("Database not found.")
|
||||
|
||||
@patch("superset.daos.database.DatabaseDAO.find_by_id")
|
||||
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
|
||||
@@ -1166,26 +1170,35 @@ class TestTablesDatabaseCommand(SupersetTestCase):
|
||||
command = TablesDatabaseCommand(database.id, None, "main", False)
|
||||
with pytest.raises(SupersetException) as excinfo: # noqa: PT012
|
||||
command.run()
|
||||
assert str(excinfo.value) == "Test Error"
|
||||
assert str(excinfo.value) == "Test Error"
|
||||
|
||||
@patch("superset.daos.database.DatabaseDAO.find_by_id")
|
||||
@patch("superset.models.core.Database.get_all_materialized_view_names_in_schema")
|
||||
@patch("superset.models.core.Database.get_all_view_names_in_schema")
|
||||
@patch("superset.models.core.Database.get_all_table_names_in_schema")
|
||||
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
|
||||
@patch("superset.utils.core.g")
|
||||
def test_database_tables_exception(
|
||||
self, mock_g, mock_can_access_database, mock_find_by_id
|
||||
self,
|
||||
mock_g,
|
||||
mock_can_access_database,
|
||||
mock_get_tables,
|
||||
mock_get_views,
|
||||
mock_get_mvs,
|
||||
mock_find_by_id,
|
||||
):
|
||||
database = get_example_database()
|
||||
mock_find_by_id.return_value = database
|
||||
mock_get_tables.return_value = {("table1", "main", None)}
|
||||
mock_get_views.return_value = set()
|
||||
mock_get_mvs.return_value = []
|
||||
mock_can_access_database.side_effect = Exception("Test Error")
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
|
||||
command = TablesDatabaseCommand(database.id, None, "main", False)
|
||||
with pytest.raises(DatabaseTablesUnexpectedError) as excinfo: # noqa: PT012
|
||||
command.run()
|
||||
assert (
|
||||
str(excinfo.value)
|
||||
== "Unexpected error occurred, please check your logs for details"
|
||||
)
|
||||
assert str(excinfo.value) == "Test Error"
|
||||
|
||||
@patch("superset.daos.database.DatabaseDAO.find_by_id")
|
||||
@patch("superset.security.manager.SupersetSecurityManager.can_access_database")
|
||||
|
||||
@@ -145,7 +145,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
username = make_url(engine.url).username
|
||||
assert example_user.username != username
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
@mock.patch("superset.engines.manager.create_engine")
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
|
||||
)
|
||||
@@ -172,7 +172,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://gamma@localhost/"
|
||||
@@ -185,7 +186,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "presto://localhost/"
|
||||
@@ -199,13 +201,14 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("mysqlclient"), "mysqlclient not installed"
|
||||
)
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
@mock.patch("superset.engines.manager.create_engine")
|
||||
def test_adjust_engine_params_mysql(self, mocked_create_engine):
|
||||
model = Database(
|
||||
database_name="test_database1",
|
||||
sqlalchemy_uri="mysql://user:password@localhost",
|
||||
)
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "mysql://user:password@localhost"
|
||||
@@ -215,13 +218,14 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database2",
|
||||
sqlalchemy_uri="mysql+mysqlconnector://user:password@localhost",
|
||||
)
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "mysql+mysqlconnector://user:password@localhost"
|
||||
assert call_args[1]["connect_args"]["allow_local_infile"] == 0
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
@mock.patch("superset.engines.manager.create_engine")
|
||||
def test_impersonate_user_trino(self, mocked_create_engine):
|
||||
principal_user = security_manager.find_user(username="gamma")
|
||||
|
||||
@@ -230,7 +234,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri="trino://localhost"
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "trino://localhost/"
|
||||
@@ -242,7 +247,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
)
|
||||
|
||||
model.impersonate_user = True
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert (
|
||||
@@ -251,7 +257,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
)
|
||||
assert call_args[1]["connect_args"]["user"] == "gamma"
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
@mock.patch("superset.engines.manager.create_engine")
|
||||
@unittest.skipUnless(
|
||||
SupersetTestCase.is_module_installed("pyhive"), "pyhive not installed"
|
||||
)
|
||||
@@ -281,7 +287,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
database_name="test_database", sqlalchemy_uri=uri, extra=extra
|
||||
)
|
||||
model.impersonate_user = True
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
@@ -294,7 +301,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
}
|
||||
|
||||
model.impersonate_user = False
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
call_args = mocked_create_engine.call_args
|
||||
|
||||
assert str(call_args[0][0]) == "hive://localhost"
|
||||
@@ -376,7 +384,7 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
df = main_db.get_df("USE superset; SELECT ';';", None, None)
|
||||
assert df.iat[0, 0] == ";"
|
||||
|
||||
@mock.patch("superset.models.core.create_engine")
|
||||
@mock.patch("superset.engines.manager.create_engine")
|
||||
def test_get_sqla_engine(self, mocked_create_engine):
|
||||
model = Database(
|
||||
database_name="test_database",
|
||||
@@ -387,7 +395,8 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
)
|
||||
mocked_create_engine.side_effect = Exception()
|
||||
with self.assertRaises(SupersetException): # noqa: PT027
|
||||
model._get_sqla_engine()
|
||||
with model.get_sqla_engine():
|
||||
pass
|
||||
|
||||
|
||||
class TestSqlaTableModel(SupersetTestCase):
|
||||
|
||||
109
tests/unit_tests/engines/manager_test.py
Normal file
109
tests/unit_tests/engines/manager_test.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine.url import make_url
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.engines.manager import EngineManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def engine_manager() -> EngineManager:
|
||||
@contextmanager
|
||||
def dummy_context_manager(
|
||||
database: MagicMock,
|
||||
catalog: str | None,
|
||||
schema: str | None,
|
||||
):
|
||||
yield
|
||||
|
||||
return EngineManager(engine_context_manager=dummy_context_manager)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_database() -> MagicMock:
|
||||
database = MagicMock()
|
||||
database.id = 1
|
||||
database.sqlalchemy_uri_decrypted = "trino://"
|
||||
database.get_extra.return_value = {"engine_params": {"poolclass": "queue"}}
|
||||
database.get_effective_user.return_value = "alice"
|
||||
database.impersonate_user = False
|
||||
database.update_params_from_encrypted_extra = MagicMock()
|
||||
database.db_engine_spec = MagicMock()
|
||||
database.db_engine_spec.adjust_engine_params.return_value = (
|
||||
make_url("trino://"),
|
||||
{"source": "Apache Superset"},
|
||||
)
|
||||
database.db_engine_spec.validate_database_uri = MagicMock()
|
||||
return database
|
||||
|
||||
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_uses_null_pool(
|
||||
mock_make_url: MagicMock,
|
||||
engine_manager: EngineManager,
|
||||
mock_database: MagicMock,
|
||||
) -> None:
|
||||
mock_make_url.return_value = make_url("trino://")
|
||||
|
||||
_, kwargs = engine_manager._get_engine_args(mock_database, None, None, None, None)
|
||||
|
||||
assert kwargs["poolclass"] is pool.NullPool
|
||||
|
||||
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
def test_get_engine_args_with_impersonation(
|
||||
mock_make_url: MagicMock,
|
||||
engine_manager: EngineManager,
|
||||
mock_database: MagicMock,
|
||||
) -> None:
|
||||
mock_make_url.return_value = make_url("trino://")
|
||||
mock_database.impersonate_user = True
|
||||
mock_database.get_oauth2_config.return_value = None
|
||||
mock_database.db_engine_spec.impersonate_user.return_value = (
|
||||
make_url("trino://"),
|
||||
{"connect_args": {"user": "alice"}, "poolclass": pool.NullPool},
|
||||
)
|
||||
|
||||
engine_manager._get_engine_args(mock_database, None, None, None, None)
|
||||
|
||||
mock_database.db_engine_spec.impersonate_user.assert_called_once()
|
||||
|
||||
|
||||
def test_get_tunnel_kwargs_requires_database_port(
|
||||
engine_manager: EngineManager,
|
||||
) -> None:
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
ssh_tunnel.server_port = 22
|
||||
ssh_tunnel.username = "ssh_user"
|
||||
ssh_tunnel.password = None
|
||||
ssh_tunnel.private_key = None
|
||||
ssh_tunnel.private_key_password = None
|
||||
|
||||
uri = MagicMock()
|
||||
uri.port = None
|
||||
uri.get_backend_name.return_value = "unknown"
|
||||
|
||||
with patch("superset.utils.ssh_tunnel.get_default_port", return_value=None):
|
||||
with pytest.raises(SSHTunnelDatabasePortError):
|
||||
engine_manager._get_tunnel_kwargs(ssh_tunnel, uri)
|
||||
@@ -124,7 +124,7 @@ class TestSupersetAppInitializer:
|
||||
patch.object(app_initializer, "configure_data_sources"),
|
||||
patch.object(app_initializer, "configure_auth_provider"),
|
||||
patch.object(app_initializer, "configure_async_queries"),
|
||||
patch.object(app_initializer, "configure_ssh_manager"),
|
||||
patch.object(app_initializer, "configure_engine_manager"),
|
||||
patch.object(app_initializer, "configure_stats_manager"),
|
||||
patch.object(app_initializer, "init_views"),
|
||||
):
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from flask import current_app
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
@@ -29,7 +28,6 @@ from sqlalchemy import (
|
||||
Table as SqlalchemyTable,
|
||||
)
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.orm.session import Session
|
||||
from sqlalchemy.sql import Select
|
||||
|
||||
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
|
||||
assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
|
||||
|
||||
|
||||
def test_get_sqla_engine(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test `_get_sqla_engine`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test user impersonation in `_get_sqla_engine`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
impersonate_user=True,
|
||||
)
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"user": "alice", "source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_add_database_to_signature():
|
||||
args = ["param1", "param2"]
|
||||
|
||||
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
|
||||
assert args3 == ["param1", "param2", database]
|
||||
|
||||
|
||||
@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
|
||||
def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test user impersonation in `_get_sqla_engine` with `username_from_email`.
|
||||
"""
|
||||
from superset.models.core import Database
|
||||
|
||||
user = mocker.MagicMock()
|
||||
user.email = "alice.doe@example.org"
|
||||
mocker.patch(
|
||||
"superset.models.core.security_manager.find_user",
|
||||
return_value=user,
|
||||
)
|
||||
mocker.patch("superset.models.core.get_username", return_value="alice")
|
||||
|
||||
create_engine = mocker.patch("superset.models.core.create_engine")
|
||||
|
||||
database = Database(
|
||||
database_name="my_db",
|
||||
sqlalchemy_uri="trino://",
|
||||
impersonate_user=True,
|
||||
)
|
||||
database._get_sqla_engine(nullpool=False)
|
||||
|
||||
create_engine.assert_called_with(
|
||||
make_url("trino:///"),
|
||||
connect_args={"user": "alice.doe", "source": "Apache Superset"},
|
||||
)
|
||||
|
||||
|
||||
def test_is_oauth2_enabled() -> None:
|
||||
"""
|
||||
Test the `is_oauth2_enabled` method.
|
||||
@@ -775,8 +689,8 @@ def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
|
||||
encrypted_extra=json.dumps(oauth2_client_info),
|
||||
)
|
||||
database.db_engine_spec.oauth2_exception = OAuth2Error
|
||||
_get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
|
||||
_get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
create_engine = mocker.patch("superset.engines.manager.create_engine")
|
||||
create_engine.side_effect = OAuth2Error("OAuth2 required")
|
||||
|
||||
with pytest.raises(OAuth2RedirectError) as excinfo:
|
||||
with database.get_raw_connection() as conn:
|
||||
@@ -879,56 +793,6 @@ def test_get_schema_access_for_file_upload() -> None:
|
||||
assert database.get_schema_access_for_file_upload() == {"public"}
|
||||
|
||||
|
||||
def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
|
||||
"""
|
||||
Test the engine context manager.
|
||||
"""
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
engine_context_manager = MagicMock()
|
||||
mocker.patch.dict(
|
||||
current_app.config,
|
||||
{"ENGINE_CONTEXT_MANAGER": engine_context_manager},
|
||||
)
|
||||
_get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
|
||||
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
engine_context_manager.assert_called_once_with(database, "catalog", "schema")
|
||||
engine_context_manager().__enter__.assert_called_once()
|
||||
engine_context_manager().__exit__.assert_called_once_with(None, None, None)
|
||||
_get_sqla_engine.assert_called_once_with(
|
||||
catalog="catalog",
|
||||
schema="schema",
|
||||
nullpool=True,
|
||||
source=None,
|
||||
sqlalchemy_uri="trino://",
|
||||
)
|
||||
|
||||
|
||||
def test_engine_oauth2(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that we handle OAuth2 when `create_engine` fails.
|
||||
"""
|
||||
database = Database(database_name="my_db", sqlalchemy_uri="trino://")
|
||||
mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
|
||||
mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
|
||||
mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
|
||||
start_oauth2_dance = mocker.patch.object(
|
||||
database.db_engine_spec,
|
||||
"start_oauth2_dance",
|
||||
side_effect=OAuth2Error("OAuth2 required"),
|
||||
)
|
||||
|
||||
with pytest.raises(OAuth2Error):
|
||||
with database.get_sqla_engine("catalog", "schema"):
|
||||
pass
|
||||
|
||||
start_oauth2_dance.assert_called_with(database)
|
||||
|
||||
|
||||
def test_purge_oauth2_tokens(session: Session) -> None:
|
||||
"""
|
||||
Test the `purge_oauth2_tokens` method.
|
||||
|
||||
@@ -202,7 +202,6 @@ def setup_mock_raw_connection(
|
||||
def _raw_connection(
|
||||
catalog: str | None = None,
|
||||
schema: str | None = None,
|
||||
nullpool: bool = True,
|
||||
source: Any | None = None,
|
||||
):
|
||||
yield mock_connection
|
||||
|
||||
Reference in New Issue
Block a user