diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index 92f91c15ec0..98e3ffac2b3 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -80,9 +80,19 @@ class UpdateDatabaseCommand(BaseCommand): # existing personal tokens. self._handle_oauth2() - # build new DB + # Some DBs require running a query to get the default catalog. + # In these cases, if the current connection is broken then + # `get_default_catalog` would raise an exception. We need to + # gracefully handle that so that the connection can be fixed. original_database_name = self._model.database_name - original_catalog = self._model.get_default_catalog() + force_update: bool = False + try: + original_catalog = self._model.get_default_catalog() + except Exception: + original_catalog = None + force_update = True + + # build new DB database = DatabaseDAO.update(self._model, self._properties) database.set_sqlalchemy_uri(database.sqlalchemy_uri) ssh_tunnel = self._handle_ssh_tunnel(database) @@ -92,7 +102,8 @@ class UpdateDatabaseCommand(BaseCommand): # configured with multi-catalog support; if it was enabled or is enabled in the # update we don't update the assets if ( - new_catalog != original_catalog + force_update + or new_catalog != original_catalog and not self._model.allow_multi_catalog and not database.allow_multi_catalog ): diff --git a/superset/db_engine_specs/databricks.py b/superset/db_engine_specs/databricks.py index 48144bb35bc..66e9078bee8 100644 --- a/superset/db_engine_specs/databricks.py +++ b/superset/db_engine_specs/databricks.py @@ -440,6 +440,8 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): """ Return the default catalog. + It's optionally specified in `connect_args.catalog`. If not: + The default behavior for Databricks is confusing. When Unity Catalog is not enabled we have (the DB engine spec hasn't been tested with it enabled): @@ -451,6 +453,10 @@ class DatabricksNativeEngineSpec(DatabricksDynamicBaseEngineSpec): To handle permissions correctly we use the result of `SHOW CATALOGS` when a single catalog is returned. """ + connect_args = cls.get_extra_params(database)["engine_params"]["connect_args"] + if default_catalog := connect_args.get("catalog"): + return default_catalog + with database.get_sqla_engine() as engine: catalogs = {catalog for (catalog,) in engine.execute("SHOW CATALOGS")} if len(catalogs) == 1: diff --git a/tests/unit_tests/commands/databases/update_test.py b/tests/unit_tests/commands/databases/update_test.py index 5e6b0869d72..32b816295be 100644 --- a/tests/unit_tests/commands/databases/update_test.py +++ b/tests/unit_tests/commands/databases/update_test.py @@ -642,3 +642,28 @@ def test_update_without_catalog_change(mocker: MockerFixture) -> None: UpdateDatabaseCommand(1, {}).run() update_catalog_attribute.assert_not_called() + + +def test_update_broken_connection(mocker: MockerFixture) -> None: + """ + Test that updating a database with a broken connection works + even if it has to run a query to get the default catalog. + """ + database = mocker.MagicMock() + database.get_default_catalog.side_effect = Exception("Broken connection") + database.id = 1 + new_db = mocker.MagicMock() + new_db.get_default_catalog.return_value = "main" + + database_dao = mocker.patch("superset.commands.database.update.DatabaseDAO") + database_dao.find_by_id.return_value = database + database_dao.update.return_value = new_db + mocker.patch("superset.commands.database.update.SyncPermissionsCommand") + + update_catalog_attribute = mocker.patch.object( + UpdateDatabaseCommand, + "_update_catalog_attribute", + ) + UpdateDatabaseCommand(1, {}).run() + + update_catalog_attribute.assert_called_once_with(1, "main")