From 08df7d5178b2ffd09f66968a9ef98acc8c6a80ca Mon Sep 17 00:00:00 2001 From: Beto Dealmeida Date: Wed, 4 Feb 2026 10:58:39 -0500 Subject: [PATCH] fix: SSH tunnel and test connection error handling - Use sshtunnel.open_tunnel() instead of SSHTunnelForwarder directly to properly handle debug_level parameter - Fix keepalive parameter name (set_keepalive, not keepalive) - Fix test assertions that were inside pytest.raises blocks and never executed - now check error_type instead of string messages - Update SSH tunnel test mocks to patch open_tunnel Co-Authored-By: Claude Opus 4.5 --- superset/engines/manager.py | 5 +-- .../databases/commands_tests.py | 32 ++++++++++--------- tests/unit_tests/engines/manager_test.py | 21 ++++++------ 3 files changed, 31 insertions(+), 27 deletions(-) diff --git a/superset/engines/manager.py b/superset/engines/manager.py index 1beb41929af..9f3722d5eb8 100644 --- a/superset/engines/manager.py +++ b/superset/engines/manager.py @@ -495,7 +495,8 @@ class EngineManager: def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri) - tunnel = SSHTunnelForwarder(**kwargs) + # Use open_tunnel which handles debug_level properly + tunnel = sshtunnel.open_tunnel(**kwargs) tunnel.start() return tunnel @@ -524,7 +525,7 @@ class EngineManager: kwargs["ssh_pkey"] = private_key if self.mode == EngineModes.NEW: - kwargs["keepalive"] = 0 # disable + kwargs["set_keepalive"] = 0 # disable keepalive for one-time tunnels return kwargs diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py index f300aa45604..21939386abb 100644 --- a/tests/integration_tests/databases/commands_tests.py +++ b/tests/integration_tests/databases/commands_tests.py @@ -913,10 +913,12 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: # noqa: PT012 + with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo: command_without_db_name.run() - assert str(excinfo.value) == ( - "Unexpected error occurred, please check your logs for details" + # Exception wraps errors from db_engine_spec.extract_errors() + assert ( + excinfo.value.errors[0].error_type + == SupersetErrorType.GENERIC_DB_ENGINE_ERROR ) mock_event_logger.assert_called() @@ -929,9 +931,8 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): """Test to make sure do_ping exceptions gets captured""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.return_value.__enter__.return_value.dialect.do_ping.side_effect = Exception( - "An error has occurred!" - ) + mock_engine = mock_get_sqla_engine.return_value.__enter__.return_value + mock_engine.dialect.do_ping.side_effect = Exception("An error has occurred!") db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) @@ -979,17 +980,17 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): connection exc is raised""" database = get_example_database() mock_g.user = security_manager.find_user("admin") - mock_get_sqla_engine.return_value.__enter__.side_effect = SupersetSecurityException( - SupersetError(error_type=500, message="test", level="info") + mock_get_sqla_engine.return_value.__enter__.side_effect = ( + SupersetSecurityException( + SupersetError(error_type=500, message="test", level="info") + ) ) db_uri = database.sqlalchemy_uri_decrypted json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(DatabaseSecurityUnsafeError) as excinfo: # noqa: PT012 + with pytest.raises(DatabaseSecurityUnsafeError): command_without_db_name.run() - assert str(excinfo.value) == ("Stopped an unsafe database connection") - mock_event_logger.assert_called() @patch("superset.models.core.Database.get_sqla_engine") @@ -1008,12 +1009,13 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase): json_payload = {"sqlalchemy_uri": db_uri} command_without_db_name = TestConnectionDatabaseCommand(json_payload) - with pytest.raises(SupersetErrorsException) as excinfo: # noqa: PT012 + with pytest.raises(SupersetErrorsException) as excinfo: command_without_db_name.run() - assert str(excinfo.value) == ( - "Connection failed, please check your connection settings" + # Exception wraps errors from db_engine_spec.extract_errors() + assert ( + excinfo.value.errors[0].error_type + == SupersetErrorType.GENERIC_DB_ENGINE_ERROR ) - mock_event_logger.assert_called() diff --git a/tests/unit_tests/engines/manager_test.py b/tests/unit_tests/engines/manager_test.py index dc5248d663a..e0abf88f57e 100644 --- a/tests/unit_tests/engines/manager_test.py +++ b/tests/unit_tests/engines/manager_test.py @@ -223,8 +223,8 @@ class TestEngineManager: for engine in results: assert engine is real_engine - @patch("superset.engines.manager.SSHTunnelForwarder") - def test_ssh_tunnel_creation(self, mock_tunnel_class, engine_manager): + @patch("superset.engines.manager.sshtunnel.open_tunnel") + def test_ssh_tunnel_creation(self, mock_open_tunnel, engine_manager): """Test SSH tunnel creation and caching.""" ssh_tunnel = MagicMock() ssh_tunnel.server_address = "ssh.example.com" @@ -237,7 +237,7 @@ class TestEngineManager: tunnel_instance = MagicMock() tunnel_instance.is_active = True tunnel_instance.local_bind_address = ("127.0.0.1", 12345) - mock_tunnel_class.return_value = tunnel_instance + mock_open_tunnel.return_value = tunnel_instance uri = MagicMock() uri.host = "db.example.com" @@ -247,18 +247,19 @@ class TestEngineManager: result = engine_manager._get_tunnel(ssh_tunnel, uri) assert result is tunnel_instance - mock_tunnel_class.assert_called_once() + mock_open_tunnel.assert_called_once() + tunnel_instance.start.assert_called_once() # Getting same tunnel again should return cached version - mock_tunnel_class.reset_mock() + mock_open_tunnel.reset_mock() result2 = engine_manager._get_tunnel(ssh_tunnel, uri) assert result2 is tunnel_instance - mock_tunnel_class.assert_not_called() + mock_open_tunnel.assert_not_called() - @patch("superset.engines.manager.SSHTunnelForwarder") + @patch("superset.engines.manager.sshtunnel.open_tunnel") def test_ssh_tunnel_recreation_when_inactive( - self, mock_tunnel_class, engine_manager + self, mock_open_tunnel, engine_manager ): """Test that inactive tunnels are replaced.""" ssh_tunnel = MagicMock() @@ -279,7 +280,7 @@ class TestEngineManager: active_tunnel.is_active = True active_tunnel.local_bind_address = ("127.0.0.1", 23456) - mock_tunnel_class.side_effect = [inactive_tunnel, active_tunnel] + mock_open_tunnel.side_effect = [inactive_tunnel, active_tunnel] uri = MagicMock() uri.host = "db.example.com" @@ -293,7 +294,7 @@ class TestEngineManager: # Second call should create new tunnel since first is inactive result2 = engine_manager._get_tunnel(ssh_tunnel, uri) assert result2 is active_tunnel - assert mock_tunnel_class.call_count == 2 + assert mock_open_tunnel.call_count == 2 @patch("superset.engines.manager.create_engine") @patch("superset.engines.manager.make_url_safe")