mirror of
https://github.com/apache/superset.git
synced 2026-05-06 16:34:32 +00:00
fix: SSH tunnel and test connection error handling
- Use sshtunnel.open_tunnel() instead of SSHTunnelForwarder directly to properly handle debug_level parameter - Fix keepalive parameter name (set_keepalive, not keepalive) - Fix test assertions that were inside pytest.raises blocks and never executed - now check error_type instead of string messages - Update SSH tunnel test mocks to patch open_tunnel Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -495,7 +495,8 @@ class EngineManager:
|
||||
|
||||
def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder:
|
||||
kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
|
||||
tunnel = SSHTunnelForwarder(**kwargs)
|
||||
# Use open_tunnel which handles debug_level properly
|
||||
tunnel = sshtunnel.open_tunnel(**kwargs)
|
||||
tunnel.start()
|
||||
|
||||
return tunnel
|
||||
@@ -524,7 +525,7 @@ class EngineManager:
|
||||
kwargs["ssh_pkey"] = private_key
|
||||
|
||||
if self.mode == EngineModes.NEW:
|
||||
kwargs["keepalive"] = 0 # disable
|
||||
kwargs["set_keepalive"] = 0 # disable keepalive for one-time tunnels
|
||||
|
||||
return kwargs
|
||||
|
||||
|
||||
@@ -913,10 +913,12 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012
|
||||
with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == (
|
||||
"Unexpected error occurred, please check your logs for details"
|
||||
# Exception wraps errors from db_engine_spec.extract_errors()
|
||||
assert (
|
||||
excinfo.value.errors[0].error_type
|
||||
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
|
||||
)
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@@ -929,9 +931,8 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
"""Test to make sure do_ping exceptions gets captured"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.return_value.__enter__.return_value.dialect.do_ping.side_effect = Exception(
|
||||
"An error has occurred!"
|
||||
)
|
||||
mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value
|
||||
mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!")
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
@@ -979,17 +980,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
connection exc is raised"""
|
||||
database = get_example_database()
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
mock_get_sqla_engine.return_value.__enter__.side_effect = SupersetSecurityException(
|
||||
SupersetError(error_type=500, message="test", level="info")
|
||||
mock_get_sqla_engine.return_value.__enter__.side_effect = (
|
||||
SupersetSecurityException(
|
||||
SupersetError(error_type=500, message="test", level="info")
|
||||
)
|
||||
)
|
||||
db_uri = database.sqlalchemy_uri_decrypted
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012
|
||||
with pytest.raises(DatabaseSecurityUnsafeError):
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == ("Stopped an unsafe database connection")
|
||||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
@patch("superset.models.core.Database.get_sqla_engine")
|
||||
@@ -1008,12 +1009,13 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
|
||||
json_payload = {"sqlalchemy_uri": db_uri}
|
||||
command_without_db_name = TestConnectionDatabaseCommand(json_payload)
|
||||
|
||||
with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012
|
||||
with pytest.raises(SupersetErrorsException) as excinfo:
|
||||
command_without_db_name.run()
|
||||
assert str(excinfo.value) == (
|
||||
"Connection failed, please check your connection settings"
|
||||
# Exception wraps errors from db_engine_spec.extract_errors()
|
||||
assert (
|
||||
excinfo.value.errors[0].error_type
|
||||
== SupersetErrorType.GENERIC_DB_ENGINE_ERROR
|
||||
)
|
||||
|
||||
mock_event_logger.assert_called()
|
||||
|
||||
|
||||
|
||||
@@ -223,8 +223,8 @@ class TestEngineManager:
|
||||
for engine in results:
|
||||
assert engine is real_engine
|
||||
|
||||
@patch("superset.engines.manager.SSHTunnelForwarder")
|
||||
def test_ssh_tunnel_creation(self, mock_tunnel_class, engine_manager):
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager):
|
||||
"""Test SSH tunnel creation and caching."""
|
||||
ssh_tunnel = MagicMock()
|
||||
ssh_tunnel.server_address = "ssh.example.com"
|
||||
@@ -237,7 +237,7 @@ class TestEngineManager:
|
||||
tunnel_instance = MagicMock()
|
||||
tunnel_instance.is_active = True
|
||||
tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
|
||||
mock_tunnel_class.return_value = tunnel_instance
|
||||
mock_open_tunnel.return_value = tunnel_instance
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
@@ -247,18 +247,19 @@ class TestEngineManager:
|
||||
result = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result is tunnel_instance
|
||||
mock_tunnel_class.assert_called_once()
|
||||
mock_open_tunnel.assert_called_once()
|
||||
tunnel_instance.start.assert_called_once()
|
||||
|
||||
# Getting same tunnel again should return cached version
|
||||
mock_tunnel_class.reset_mock()
|
||||
mock_open_tunnel.reset_mock()
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
|
||||
assert result2 is tunnel_instance
|
||||
mock_tunnel_class.assert_not_called()
|
||||
mock_open_tunnel.assert_not_called()
|
||||
|
||||
@patch("superset.engines.manager.SSHTunnelForwarder")
|
||||
@patch("superset.engines.manager.sshtunnel.open_tunnel")
|
||||
def test_ssh_tunnel_recreation_when_inactive(
|
||||
self, mock_tunnel_class, engine_manager
|
||||
self, mock_open_tunnel, engine_manager
|
||||
):
|
||||
"""Test that inactive tunnels are replaced."""
|
||||
ssh_tunnel = MagicMock()
|
||||
@@ -279,7 +280,7 @@ class TestEngineManager:
|
||||
active_tunnel.is_active = True
|
||||
active_tunnel.local_bind_address = ("127.0.0.1", 23456)
|
||||
|
||||
mock_tunnel_class.side_effect = [inactive_tunnel, active_tunnel]
|
||||
mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel]
|
||||
|
||||
uri = MagicMock()
|
||||
uri.host = "db.example.com"
|
||||
@@ -293,7 +294,7 @@ class TestEngineManager:
|
||||
# Second call should create new tunnel since first is inactive
|
||||
result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
|
||||
assert result2 is active_tunnel
|
||||
assert mock_tunnel_class.call_count == 2
|
||||
assert mock_open_tunnel.call_count == 2
|
||||
|
||||
@patch("superset.engines.manager.create_engine")
|
||||
@patch("superset.engines.manager.make_url_safe")
|
||||
|
||||
Reference in New Issue
Block a user