diff --git a/superset/commands/database/create.py b/superset/commands/database/create.py index a012e9b2a57..cde9dd8e884 100644 --- a/superset/commands/database/create.py +++ b/superset/commands/database/create.py @@ -41,6 +41,7 @@ from superset.daos.database import DatabaseDAO from superset.daos.exceptions import DAOCreateFailedError from superset.exceptions import SupersetErrorsException from superset.extensions import db, event_logger, security_manager +from superset.models.core import Database logger = logging.getLogger(__name__) stats_logger = current_app.config["STATS_LOGGER"] @@ -76,34 +77,20 @@ class CreateDatabaseCommand(BaseCommand): "{}", ) - try: - database = DatabaseDAO.create(attributes=self._properties, commit=False) - database.set_sqlalchemy_uri(database.sqlalchemy_uri) + ssh_tunnel = None + + try: + database = self._create_database() - ssh_tunnel = None if ssh_tunnel_properties := self._properties.get("ssh_tunnel"): if not is_feature_enabled("SSH_TUNNELING"): - db.session.rollback() raise SSHTunnelingNotEnabledError() - try: - # So database.id is not None - db.session.flush() - ssh_tunnel = CreateSSHTunnelCommand( - database.id, ssh_tunnel_properties - ).run() - except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: - event_logger.log_with_context( - action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", - engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], - ) - # So we can show the original message - raise ex - except Exception as ex: - event_logger.log_with_context( - action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", - engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], - ) - raise DatabaseCreateFailedError() from ex + + ssh_tunnel = CreateSSHTunnelCommand( + database, ssh_tunnel_properties + ).run() + + db.session.commit() # adding a new database we always want to force refresh schema list schemas = database.get_all_schema_names(cache=False, ssh_tunnel=ssh_tunnel) @@ -112,9 +99,23 @@ class CreateDatabaseCommand(BaseCommand): "schema_access", security_manager.get_schema_perm(database, schema) ) - db.session.commit() - - except DAOCreateFailedError as ex: + except ( + SSHTunnelInvalidError, + SSHTunnelCreateFailedError, + SSHTunnelingNotEnabledError, + ) as ex: + db.session.rollback() + event_logger.log_with_context( + action=f"db_creation_failed.{ex.__class__.__name__}.ssh_tunnel", + engine=self._properties.get("sqlalchemy_uri", "").split(":")[0], + ) + # So we can show the original message + raise ex + except ( + DAOCreateFailedError, + DatabaseInvalidError, + Exception, + ) as ex: db.session.rollback() event_logger.log_with_context( action=f"db_creation_failed.{ex.__class__.__name__}", @@ -150,3 +151,8 @@ class CreateDatabaseCommand(BaseCommand): ) ) raise exception + + def _create_database(self) -> Database: + database = DatabaseDAO.create(attributes=self._properties, commit=False) + database.set_sqlalchemy_uri(database.sqlalchemy_uri) + return database diff --git a/superset/commands/database/ssh_tunnel/create.py b/superset/commands/database/ssh_tunnel/create.py index 07209f010ba..cbfee3ce2ae 100644 --- a/superset/commands/database/ssh_tunnel/create.py +++ b/superset/commands/database/ssh_tunnel/create.py @@ -28,39 +28,32 @@ from superset.commands.database.ssh_tunnel.exceptions import ( ) from superset.daos.database import SSHTunnelDAO from superset.daos.exceptions import DAOCreateFailedError -from superset.extensions import db, event_logger +from superset.extensions import event_logger +from superset.models.core import Database logger = logging.getLogger(__name__) class CreateSSHTunnelCommand(BaseCommand): - def __init__(self, database_id: int, data: dict[str, Any]): + def __init__(self, database: Database, data: dict[str, Any]): self._properties = data.copy() - self._properties["database_id"] = database_id + self._properties["database"] = database def run(self) -> Model: try: - # Start nested transaction since we are always creating the tunnel - # through a DB command (Create or Update). Without this, we cannot - # safely rollback changes to databases if any, i.e, things like - # test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail - db.session.begin_nested() self.validate() - return SSHTunnelDAO.create(attributes=self._properties, commit=False) + ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False) + return ssh_tunnel except DAOCreateFailedError as ex: - # Rollback nested transaction - db.session.rollback() raise SSHTunnelCreateFailedError() from ex except SSHTunnelInvalidError as ex: - # Rollback nested transaction - db.session.rollback() raise ex def validate(self) -> None: # TODO(hughhh): check to make sure the server port is not localhost # using the config.SSH_TUNNEL_MANAGER + exceptions: list[ValidationError] = [] - database_id: Optional[int] = self._properties.get("database_id") server_address: Optional[str] = self._properties.get("server_address") server_port: Optional[int] = self._properties.get("server_port") username: Optional[str] = self._properties.get("username") @@ -68,8 +61,6 @@ class CreateSSHTunnelCommand(BaseCommand): private_key_password: Optional[str] = self._properties.get( "private_key_password" ) - if not database_id: - exceptions.append(SSHTunnelRequiredFieldValidationError("database_id")) if not server_address: exceptions.append(SSHTunnelRequiredFieldValidationError("server_address")) if not server_port: diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 039d731d72d..edc0ba1b989 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -78,7 +78,7 @@ class UpdateDatabaseCommand(BaseCommand): if existing_ssh_tunnel_model is None: # We couldn't found an existing tunnel so we need to create one try: - CreateSSHTunnelCommand(database.id, ssh_tunnel_properties).run() + CreateSSHTunnelCommand(database, ssh_tunnel_properties).run() except (SSHTunnelInvalidError, SSHTunnelCreateFailedError) as ex: # So we can show the original message raise ex diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index 0bc1f245a1f..f7b8cc0ec8c 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -538,14 +538,16 @@ class TestDatabaseApi(SupersetTestCase): @mock.patch( "superset.models.core.Database.get_all_schema_names", ) + @mock.patch("superset.extensions.db.session.rollback") def test_do_not_create_database_if_ssh_tunnel_creation_fails( self, + mock_rollback, mock_test_connection_database_command_run, mock_create_is_feature_enabled, mock_get_all_schema_names, ): """ - Database API: Test Database is not created if SSH Tunnel creation fails + Database API: Test rollback is called if SSH Tunnel creation fails """ mock_create_is_feature_enabled.return_value = True self.login(username="admin") @@ -566,6 +568,7 @@ class TestDatabaseApi(SupersetTestCase): rv = self.client.post(uri, json=database_data) response = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) + model_ssh_tunnel = ( db.session.query(SSHTunnel) .filter(SSHTunnel.database_id == response.get("id")) @@ -573,14 +576,9 @@ class TestDatabaseApi(SupersetTestCase): ) assert model_ssh_tunnel is None self.assertEqual(response, fail_message) - # Cleanup - model = ( - db.session.query(Database) - .filter(Database.database_name == "test-db-failure-ssh-tunnel") - .one_or_none() - ) - # the DB should not be created - assert model is None + + # Check that rollback was called + mock_rollback.assert_called() @mock.patch( "superset.commands.database.test_connection.TestConnectionDatabaseCommand.run", diff --git a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py index 1cd9afcc809..f6e5ca9d096 100644 --- a/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py +++ b/tests/integration_tests/databases/ssh_tunnel/commands/commands_tests.py @@ -30,23 +30,6 @@ from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand from tests.integration_tests.base_tests import SupersetTestCase -class TestCreateSSHTunnelCommand(SupersetTestCase): - @mock.patch("superset.utils.core.g") - def test_create_invalid_database_id(self, mock_g): - mock_g.user = security_manager.find_user("admin") - command = CreateSSHTunnelCommand( - None, - { - "server_address": "127.0.0.1", - "server_port": 5432, - "username": "test_user", - }, - ) - with pytest.raises(SSHTunnelInvalidError) as excinfo: - command.run() - assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.") - - class TestUpdateSSHTunnelCommand(SupersetTestCase): @mock.patch("superset.utils.core.g") def test_update_ssh_tunnel_not_found(self, mock_g): diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index bd891b64f05..1777bdc2e10 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -37,7 +37,7 @@ def test_create_ssh_tunnel_command() -> None: "password": "bar", } - result = CreateSSHTunnelCommand(db.id, properties).run() + result = CreateSSHTunnelCommand(db, properties).run() assert result is not None assert isinstance(result, SSHTunnel) @@ -53,14 +53,14 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: # If we are trying to create a tunnel with a private_key_password # then a private_key is mandatory properties = { - "database_id": db.id, + "database": db, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", "private_key_password": "bar", } - command = CreateSSHTunnelCommand(db.id, properties) + command = CreateSSHTunnelCommand(db, properties) with pytest.raises(SSHTunnelInvalidError) as excinfo: command.run()