diff --git a/superset/commands/database/exceptions.py b/superset/commands/database/exceptions.py index 6dfdadacbd7..10c1a96c67a 100644 --- a/superset/commands/database/exceptions.py +++ b/superset/commands/database/exceptions.py @@ -138,6 +138,15 @@ class DatabaseConnectionFailedError( # pylint: disable=too-many-ancestors message = _("Connection failed, please check your connection settings") +class MissingOAuth2TokenError(DatabaseUpdateFailedError): + """ + Exception for when the connection is missing an OAuth2 token + and it's not possible to initiate an OAuth2 dance. + """ + + message = _("Missing OAuth2 token") + + class DatabaseDeleteDatasetsExistFailedError(DeleteFailedError): message = _("Cannot delete a database that has datasets attached") diff --git a/superset/commands/database/sync_permissions.py b/superset/commands/database/sync_permissions.py index 4f041ce94cb..bceacff89b1 100644 --- a/superset/commands/database/sync_permissions.py +++ b/superset/commands/database/sync_permissions.py @@ -28,6 +28,7 @@ from superset.commands.database.exceptions import ( DatabaseConnectionFailedError, DatabaseConnectionSyncPermissionsError, DatabaseNotFoundError, + MissingOAuth2TokenError, UserNotFoundInSessionError, ) from superset.commands.database.utils import ( @@ -115,6 +116,11 @@ class SyncPermissionsCommand(BaseCommand): try: alive = ping(engine) except Exception as err: + if ( + self.db_connection.is_oauth2_enabled() + and self.db_connection.db_engine_spec.needs_oauth2(err) + ): + raise MissingOAuth2TokenError() from err raise DatabaseConnectionFailedError() from err if not alive: diff --git a/superset/commands/database/update.py b/superset/commands/database/update.py index f562b5d7802..92f91c15ec0 100644 --- a/superset/commands/database/update.py +++ b/superset/commands/database/update.py @@ -30,6 +30,7 @@ from superset.commands.database.exceptions import ( DatabaseInvalidError, DatabaseNotFoundError, DatabaseUpdateFailedError, + MissingOAuth2TokenError, ) from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand @@ -108,7 +109,7 @@ class UpdateDatabaseCommand(BaseCommand): db_connection=database, ssh_tunnel=ssh_tunnel, ).run() - except OAuth2RedirectError: + except (OAuth2RedirectError, MissingOAuth2TokenError): pass return database diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index fc558e519fb..0c6dad0b1df 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError from sqlalchemy.sql import func from superset import db, security_manager +from superset.commands.database.exceptions import MissingOAuth2TokenError from superset.connectors.sqla.models import SqlaTable from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe # noqa: F401 @@ -1393,6 +1394,32 @@ class TestDatabaseApi(SupersetTestCase): db.session.delete(model) db.session.commit() + @mock.patch( + "superset.commands.database.sync_permissions.SyncPermissionsCommand.run", + ) + def test_update_database_missing_oauth2_token(self, mock_sync_perms): + """ + Database API: Test update DB connection that does not have + an OAuth2 token yet does not raise. + """ + example_db = get_example_database() + test_database = self.insert_database( + "test-oauth-database", example_db.sqlalchemy_uri_decrypted + ) + mock_sync_perms.side_effect = MissingOAuth2TokenError() + self.login(ADMIN_USERNAME) + database_data = { + "database_name": "test-database-updated", + "configuration_method": ConfigurationMethod.SQLALCHEMY_FORM, + } + uri = f"api/v1/database/{test_database.id}" + rv = self.client.put(uri, json=database_data) + assert rv.status_code == 200 + # Cleanup + model = db.session.query(Database).get(test_database.id) + db.session.delete(model) + db.session.commit() + def test_update_database_uniqueness(self): """ Database API: Test update uniqueness diff --git a/tests/unit_tests/commands/databases/conftest.py b/tests/unit_tests/commands/databases/conftest.py index 49da52daf91..81f489d95ed 100644 --- a/tests/unit_tests/commands/databases/conftest.py +++ b/tests/unit_tests/commands/databases/conftest.py @@ -64,6 +64,8 @@ def database_without_catalog(mocker: MockerFixture) -> MagicMock: database.db_engine_spec.__name__ = "test_engine" database.db_engine_spec.supports_catalog = False database.get_all_schema_names.return_value = ["schema1", "schema2"] + database.is_oauth2_enabled.return_value = False + database.db_engine_spec.needs_oauth2.return_value = False return database diff --git a/tests/unit_tests/commands/databases/sync_permissions_test.py b/tests/unit_tests/commands/databases/sync_permissions_test.py index e597e384db9..78dfe3d0c15 100644 --- a/tests/unit_tests/commands/databases/sync_permissions_test.py +++ b/tests/unit_tests/commands/databases/sync_permissions_test.py @@ -25,6 +25,7 @@ from superset import db from superset.commands.database.exceptions import ( DatabaseConnectionFailedError, DatabaseNotFoundError, + MissingOAuth2TokenError, UserNotFoundInSessionError, ) from superset.commands.database.sync_permissions import SyncPermissionsCommand @@ -146,14 +147,18 @@ def test_sync_permissions_command_passing_all_values( @with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False}) -def test_sync_permissions_command_raise(mocker: MockerFixture): +def test_sync_permissions_command_raise( + mocker: MockerFixture, + database_without_catalog: MagicMock, + database_needs_oauth2: MagicMock, +): """ Test ``SyncPermissionsCommand`` when an exception is raised. """ mock_database_dao = mocker.patch( "superset.commands.database.sync_permissions.DatabaseDAO" ) - mock_database_dao.find_by_id.return_value = mocker.MagicMock() + mock_database_dao.find_by_id.return_value = database_without_catalog mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock() mock_user = mocker.patch( "superset.commands.database.sync_permissions.security_manager.get_user_by_username" @@ -169,6 +174,11 @@ def test_sync_permissions_command_raise(mocker: MockerFixture): mock_ping.side_effect = Exception with pytest.raises(DatabaseConnectionFailedError): SyncPermissionsCommand(1, "admin").run() + # OAuth2 error + mock_database_dao.find_by_id.reset_mock() + mock_database_dao.find_by_id.return_value = database_needs_oauth2 + with pytest.raises(MissingOAuth2TokenError): + SyncPermissionsCommand(1, "admin").run() # User not found in session mock_user.reset_mock()