mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
chore(ssh): Allow users to set TUNNEL_TIMEOUT from config (#24202)
This commit is contained in:
@@ -515,6 +515,7 @@ DEFAULT_FEATURE_FLAGS: Dict[str, bool] = {
|
|||||||
# ----------------------------------------------------------------------
|
# ----------------------------------------------------------------------
|
||||||
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager"
|
||||||
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1"
|
||||||
|
SSH_TUNNEL_TIMEOUT_SEC = 10.0
|
||||||
|
|
||||||
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
||||||
DEFAULT_FEATURE_FLAGS.update(
|
DEFAULT_FEATURE_FLAGS.update(
|
||||||
|
|||||||
@@ -20,9 +20,9 @@ import logging
|
|||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import sshtunnel
|
||||||
from flask import Flask
|
from flask import Flask
|
||||||
from paramiko import RSAKey
|
from paramiko import RSAKey
|
||||||
from sshtunnel import open_tunnel, SSHTunnelForwarder
|
|
||||||
|
|
||||||
from superset.databases.utils import make_url_safe
|
from superset.databases.utils import make_url_safe
|
||||||
|
|
||||||
@@ -34,9 +34,10 @@ class SSHManager:
|
|||||||
def __init__(self, app: Flask) -> None:
|
def __init__(self, app: Flask) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
|
||||||
|
sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
|
||||||
|
|
||||||
def build_sqla_url( # pylint: disable=no-self-use
|
def build_sqla_url( # pylint: disable=no-self-use
|
||||||
self, sqlalchemy_url: str, server: SSHTunnelForwarder
|
self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder
|
||||||
) -> str:
|
) -> str:
|
||||||
# override any ssh tunnel configuration object
|
# override any ssh tunnel configuration object
|
||||||
url = make_url_safe(sqlalchemy_url)
|
url = make_url_safe(sqlalchemy_url)
|
||||||
@@ -49,7 +50,7 @@ class SSHManager:
|
|||||||
self,
|
self,
|
||||||
ssh_tunnel: "SSHTunnel",
|
ssh_tunnel: "SSHTunnel",
|
||||||
sqlalchemy_database_uri: str,
|
sqlalchemy_database_uri: str,
|
||||||
) -> SSHTunnelForwarder:
|
) -> sshtunnel.SSHTunnelForwarder:
|
||||||
url = make_url_safe(sqlalchemy_database_uri)
|
url = make_url_safe(sqlalchemy_database_uri)
|
||||||
params = {
|
params = {
|
||||||
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
|
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
|
||||||
@@ -68,7 +69,7 @@ class SSHManager:
|
|||||||
)
|
)
|
||||||
params["ssh_pkey"] = private_key
|
params["ssh_pkey"] = private_key
|
||||||
|
|
||||||
return open_tunnel(**params)
|
return sshtunnel.open_tunnel(**params)
|
||||||
|
|
||||||
|
|
||||||
class SSHManagerFactory:
|
class SSHManagerFactory:
|
||||||
|
|||||||
36
tests/unit_tests/extensions/ssh_test.py
Normal file
36
tests/unit_tests/extensions/ssh_test.py
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import sshtunnel
|
||||||
|
|
||||||
|
from superset.extensions.ssh import SSHManagerFactory
|
||||||
|
|
||||||
|
|
||||||
|
def test_ssh_tunnel_timeout_setting() -> None:
|
||||||
|
app = Mock()
|
||||||
|
app.config = {
|
||||||
|
"SSH_TUNNEL_MAX_RETRIES": 2,
|
||||||
|
"SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test",
|
||||||
|
"SSH_TUNNEL_TIMEOUT_SEC": 123.0,
|
||||||
|
"SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
|
||||||
|
}
|
||||||
|
factory = SSHManagerFactory()
|
||||||
|
factory.init_app(app)
|
||||||
|
assert sshtunnel.TUNNEL_TIMEOUT == 123.0
|
||||||
Reference in New Issue
Block a user