diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 89e607ba67a..9e9161ea51f 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -33,6 +33,7 @@ from superset.databases.utils import make_url_safe from superset.extensions import event_logger from superset.models.core import Database from superset.utils.decorators import on_error, transaction +from superset.utils.ssh_tunnel import get_default_port logger = logging.getLogger(__name__) @@ -72,7 +73,9 @@ class CreateSSHTunnelCommand(BaseCommand): "private_key_password" ) url = make_url_safe(self._database.sqlalchemy_uri) - if not url.port: + backend = url.get_backend_name() + port = url.port or get_default_port(backend) + if not port: raise SSHTunnelDatabasePortError() if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) diff --git a/superset/commands/database/ssh_tunnel/update.py b/superset/commands/database/ssh_tunnel/update.py index b2fa416bd59..763d36e89a0 100644 --- a/superset/commands/database/ssh_tunnel/update.py +++ b/superset/commands/database/ssh_tunnel/update.py @@ -32,6 +32,7 @@ from superset.daos.database import SSHTunnelDAO from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe from superset.utils.decorators import on_error, transaction +from superset.utils.ssh_tunnel import get_default_port logger = logging.getLogger(__name__) @@ -75,5 +76,7 @@ class UpdateSSHTunnelCommand(BaseCommand): raise SSHTunnelInvalidError( exceptions=[SSHTunnelRequiredFieldValidationError("private_key")] ) - if not url.port: + backend = url.get_backend_name() + port = url.port or get_default_port(backend) + if not port: raise SSHTunnelDatabasePortError() diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py index 09840cc38bc..a3c015b9ebc 100644 --- a/superset/extensions/ssh.py +++ b/superset/extensions/ssh.py @@ -52,11 +52,14 @@ class SSHManager: ssh_tunnel: "SSHTunnel", sqlalchemy_database_uri: str, ) -> sshtunnel.SSHTunnelForwarder: + from superset.utils.ssh_tunnel import get_default_port + url = make_url_safe(sqlalchemy_database_uri) + backend = url.get_backend_name() params = { "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), "ssh_username": ssh_tunnel.username, - "remote_bind_address": (url.host, url.port), + "remote_bind_address": (url.host, url.port or get_default_port(backend)), "local_bind_address": (self.local_bind_address,), "debug_level": logging.getLogger("flask_appbuilder").level, } diff --git a/superset/utils/ssh_tunnel.py b/superset/utils/ssh_tunnel.py index 8421350f8c1..1471d54f4b1 100644 --- a/superset/utils/ssh_tunnel.py +++ b/superset/utils/ssh_tunnel.py @@ -20,6 +20,13 @@ from typing import Any from superset.constants import PASSWORD_MASK from superset.databases.ssh_tunnel.models import SSHTunnel +DEFAULT_PORTS: dict[str, int] = { + "postgresql": 5432, + "mysql": 3306, + "oracle": 1521, + "mssql": 1433, +} + def mask_password_info(ssh_tunnel: dict[str, Any]) -> dict[str, Any]: if ssh_tunnel.pop("password", None) is not None: @@ -41,3 +48,10 @@ def unmask_password_info( if ssh_tunnel.get("private_key_password") == PASSWORD_MASK: ssh_tunnel["private_key_password"] = model.private_key_password return ssh_tunnel + + +def get_default_port(backend: str) -> int | None: + """ + Get the default port for the given backend. + """ + return DEFAULT_PORTS.get(backend) diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 03327e0e267..1a37394d773 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -345,7 +345,59 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch("superset.commands.database.create.is_feature_enabled") @mock.patch("superset.models.core.Database.get_all_catalog_names") @mock.patch("superset.models.core.Database.get_all_schema_names") - def test_create_database_with_missing_port_raises_error( + def test_create_database_with_ssh_tunnel_no_port( + self, + mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, + ): + """ + Database API: Test create with SSH Tunnel + """ + mock_create_is_feature_enabled.return_value = True + self.login(ADMIN_USERNAME) + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": modified_sqlalchemy_uri, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data_with_ssh_tunnel) + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 201 + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response.get("id")) + .one() + ) + assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX" # noqa: S105 + assert model_ssh_tunnel.database_id == response.get("id") + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch("superset.commands.database.create.is_feature_enabled") + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") + def test_create_database_with_ssh_tunnel_no_port_no_default( self, mock_get_all_schema_names, mock_get_all_catalog_names, @@ -361,7 +413,7 @@ class TestDatabaseApi(SupersetTestCase): if example_db.backend == "sqlite": return - modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db" ssh_tunnel_properties = { "server_address": "123.132.123.1", @@ -369,13 +421,6 @@ class TestDatabaseApi(SupersetTestCase): "username": "foo", "password": "bar", } - - database_data_with_ssh_tunnel = { - "database_name": "test-db-with-ssh-tunnel", - "sqlalchemy_uri": modified_sqlalchemy_uri, - "ssh_tunnel": ssh_tunnel_properties, - } - database_data_with_ssh_tunnel = { "database_name": "test-db-with-ssh-tunnel", "sqlalchemy_uri": modified_sqlalchemy_uri, @@ -459,7 +504,71 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch("superset.commands.database.update.is_feature_enabled") @mock.patch("superset.models.core.Database.get_all_catalog_names") @mock.patch("superset.models.core.Database.get_all_schema_names") - def test_update_database_with_missing_port_raises_error( + def test_update_database_with_ssh_tunnel_no_port( + self, + mock_get_all_schema_names, + mock_get_all_catalog_names, + mock_update_is_feature_enabled, + mock_create_is_feature_enabled, + mock_test_connection_database_command_run, + ): + """ + Database API: Test update Database with SSH Tunnel + """ + mock_create_is_feature_enabled.return_value = True + mock_update_is_feature_enabled.return_value = True + self.login(ADMIN_USERNAME) + example_db = get_example_database() + if example_db.backend == "sqlite": + return + + modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + + ssh_tunnel_properties = { + "server_address": "123.132.123.1", + "server_port": 8080, + "username": "foo", + "password": "bar", + } + database_data = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted, + } + database_data_with_ssh_tunnel = { + "database_name": "test-db-with-ssh-tunnel", + "sqlalchemy_uri": modified_sqlalchemy_uri, + "ssh_tunnel": ssh_tunnel_properties, + } + + uri = "api/v1/database/" + rv = self.client.post(uri, json=database_data) + response = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 201 + + uri = "api/v1/database/{}".format(response.get("id")) + rv = self.client.put(uri, json=database_data_with_ssh_tunnel) + response_update = json.loads(rv.data.decode("utf-8")) + assert rv.status_code == 200 + + model_ssh_tunnel = ( + db.session.query(SSHTunnel) + .filter(SSHTunnel.database_id == response_update.get("id")) + .one() + ) + assert model_ssh_tunnel.database_id == response_update.get("id") + # Cleanup + model = db.session.query(Database).get(response.get("id")) + db.session.delete(model) + db.session.commit() + + @mock.patch( + "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", + ) + @mock.patch("superset.commands.database.create.is_feature_enabled") + @mock.patch("superset.commands.database.update.is_feature_enabled") + @mock.patch("superset.models.core.Database.get_all_catalog_names") + @mock.patch("superset.models.core.Database.get_all_schema_names") + def test_update_database_no_port_no_default( self, mock_get_all_schema_names, mock_get_all_catalog_names, @@ -477,7 +586,7 @@ class TestDatabaseApi(SupersetTestCase): if example_db.backend == "sqlite": return - modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db" + modified_sqlalchemy_uri = "weird+db://foo:bar@localhost/test-db" ssh_tunnel_properties = { "server_address": "123.132.123.1", diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 9b9393d3a73..168c0fc3d53 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -17,6 +17,7 @@ import pytest +from sqlalchemy.orm.session import Session from superset.commands.database.ssh_tunnel.exceptions import ( SSHTunnelDatabasePortError, @@ -24,13 +25,16 @@ from superset.commands.database.ssh_tunnel.exceptions import ( ) -def test_create_ssh_tunnel_command() -> None: +def test_create_ssh_tunnel_command(session: Session) -> None: + from superset import db from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database + engine = db.session.get_bind() + Database.metadata.create_all(engine) # pylint: disable=no-member + database = Database( - id=1, database_name="my_database", sqlalchemy_uri="postgresql://u:p@localhost:5432/db", ) @@ -49,12 +53,15 @@ def test_create_ssh_tunnel_command() -> None: assert isinstance(result, SSHTunnel) -def test_create_ssh_tunnel_command_invalid_params() -> None: +def test_create_ssh_tunnel_command_invalid_params(session: Session) -> None: + from superset import db from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.models.core import Database + engine = db.session.get_bind() + Database.metadata.create_all(engine) # pylint: disable=no-member + database = Database( - id=1, database_name="my_database", sqlalchemy_uri="postgresql://u:p@localhost:5432/db", ) @@ -76,14 +83,52 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") -def test_create_ssh_tunnel_command_no_port() -> None: +def test_create_ssh_tunnel_command_no_port(session: Session) -> None: + """ + Test that SSH Tunnel can be created without explicit port but with a default one. + """ + from superset import db + from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand + from superset.databases.ssh_tunnel.models import SSHTunnel + from superset.models.core import Database + + engine = db.session.get_bind() + Database.metadata.create_all(engine) # pylint: disable=no-member + + database = Database( + database_name="my_database", + sqlalchemy_uri="postgresql://u:p@localhost/db", + ) + + properties = { + "database": database, + "server_address": "123.132.123.1", + "server_port": "3005", + "username": "foo", + "password": "bar", + } + + result = CreateSSHTunnelCommand(database, properties).run() + + assert result is not None + assert isinstance(result, SSHTunnel) + + +def test_create_ssh_tunnel_command_no_port_no_default(session: Session) -> None: + """ + Test that error is raised when creating SSH Tunnel without explicit/default ports. + """ + from superset import db from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.models.core import Database + engine = db.session.get_bind() + Database.metadata.create_all(engine) # pylint: disable=no-member + database = Database( id=1, database_name="my_database", - sqlalchemy_uri="postgresql://u:p@localhost/db", + sqlalchemy_uri="weird+db://u:p@localhost/db", ) properties = { diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 66684eb8def..f223014ee34 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -103,6 +103,37 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None: "session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True ) def test_update_shh_tunnel_no_port(session_with_data: Session) -> None: + """ + Test that SSH Tunnel can be updated without explicit port but with a default one. + """ + from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand + from superset.daos.database import DatabaseDAO + from superset.databases.ssh_tunnel.models import SSHTunnel + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert 1 == result.database_id + assert "Test" == result.server_address + + update_payload = {"server_address": "Test2"} + UpdateSSHTunnelCommand(1, update_payload).run() + + result = DatabaseDAO.get_ssh_tunnel(1) + + assert result + assert isinstance(result, SSHTunnel) + assert "Test2" == result.server_address + + +@pytest.mark.parametrize( + "session_with_data", ["weird+db://u:p@localhost/testdb"], indirect=True +) +def test_update_shh_tunnel_no_port_no_default(session_with_data: Session) -> None: + """ + Test that error is raised when updating SSH Tunnel without explicit/default ports. + """ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from superset.daos.database import DatabaseDAO from superset.databases.ssh_tunnel.models import SSHTunnel