diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index bd1e2f9c361..6a596965b7f 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -62,7 +62,11 @@ from superset import db from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants from superset.databases.utils import get_table_metadata, make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.exceptions import ( + OAuth2Error, + OAuth2RedirectError, + OAuth2TokenRefreshError, +) from superset.key_value.types import JsonKeyValueCodec, KeyValueResource from superset.sql.parse import ( BaseSQLStatement, @@ -828,6 +832,8 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods if config["request_content_type"] == "data" else requests.post(uri, json=req_body, timeout=timeout) ) + if response.status_code in (400, 401, 403): + raise OAuth2TokenRefreshError(response.text) response.raise_for_status() return response.json() diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index 3bcb3a8c871..4874102c903 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -30,9 +30,7 @@ from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError from requests import Session -from requests.exceptions import HTTPError from shillelagh.adapters.api.gsheets.lib import SCOPES -from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine import create_engine from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL @@ -43,7 +41,6 @@ from superset.db_engine_specs.base import DatabaseCategory from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException -from superset.superset_typing import OAuth2TokenResponse from superset.utils import json from superset.utils.oauth2 import get_oauth2_access_token @@ -154,7 +151,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): "https://accounts.google.com/o/oauth2/v2/auth" ) oauth2_token_request_uri = "https://oauth2.googleapis.com/token" # noqa: S105 - oauth2_exception = UnauthenticatedError @classmethod def get_oauth2_authorization_uri( @@ -218,29 +214,6 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): ) ) - @classmethod - def get_oauth2_fresh_token( - cls, - config: OAuth2ClientConfig, - refresh_token: str, - ) -> OAuth2TokenResponse: - """ - Refresh an OAuth2 access token that has expired. - - When trying to refresh an expired token that was revoked on Google side, - the request fails with 400 status code. - """ - try: - return super().get_oauth2_fresh_token(config, refresh_token) - except HTTPError as ex: - if ex.response is not None and ex.response.status_code == 400: - error_data = ex.response.json() - if error_data.get("error") == "invalid_grant": - raise UnauthenticatedError( - error_data.get("error_description", "Token has been revoked") - ) from ex - raise - @classmethod def impersonate_user( cls, diff --git a/superset/exceptions.py b/superset/exceptions.py index 3a81a249c47..afceac9b5d6 100644 --- a/superset/exceptions.py +++ b/superset/exceptions.py @@ -368,6 +368,27 @@ class OAuth2RedirectError(SupersetErrorException): ) +class OAuth2TokenRefreshError(OAuth2RedirectError): + """ + Raised when an OAuth2 refresh token request fails with a 400/401/403 error. + The stored token is no longer valid and the user must re-authenticate. + + Subclasses OAuth2RedirectError so that existing oauth2_exception checks + match it automatically, triggering start_oauth2_dance() via check_for_oauth2. + """ + + def __init__(self, response_text: str) -> None: + SupersetErrorException.__init__( + self, + SupersetError( + message="OAuth2 token refresh failed, re-authentication required.", + error_type=SupersetErrorType.OAUTH2_REDIRECT, + level=ErrorLevel.WARNING, + extra={"error": response_text}, + ), + ) + + class OAuth2Error(SupersetErrorException): """ Exception for when OAuth2 goes wrong. diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 62e6559a1b6..a2a2666e7c0 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -79,9 +79,9 @@ def generate_code_challenge(code_verifier: str) -> str: @backoff.on_exception( backoff.expo, AcquireDistributedLockFailedException, - factor=10, + factor=0.1, base=2, - max_tries=5, + max_tries=8, raise_on_giveup=False, giveup_log_level=logging.DEBUG, ) @@ -143,14 +143,17 @@ def refresh_oauth2_token( config, token.refresh_token, ) - except db_engine_spec.oauth2_exception: + except db_engine_spec.oauth2_exception as ex: # OAuth token is no longer valid, delete it and start OAuth2 dance logger.warning( - "OAuth2 token refresh failed for user=%s db=%s, deleting invalid token", + "OAuth2 token refresh failed for user=%s db=%s, " + "deleting token. Error: %s", user_id, database_id, + ex, ) db.session.delete(token) + db.session.flush() raise except Exception: # non-OAuth related failure, log the exception diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 22ec7d0aa18..5eae41458de 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -21,6 +21,7 @@ from __future__ import annotations import json # noqa: TID251 import re +from datetime import timedelta from textwrap import dedent from typing import Any from urllib.parse import parse_qs, urlparse @@ -44,6 +45,7 @@ from superset.superset_typing import ( ) from superset.utils.core import FilterOperator, GenericDataType from superset.utils.oauth2 import decode_oauth2_state +from tests.conftest import with_config from tests.unit_tests.db_engine_specs.utils import assert_column_spec @@ -597,6 +599,19 @@ def test_extract_errors(mocker: MockerFixture) -> None: assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile("This connector does not support roles"): ( + "Custom error message", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + } + } + }, +) def test_extract_errors_from_config(mocker: MockerFixture) -> None: """ Test that custom error messages are extracted correctly from app config @@ -606,21 +621,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None: class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile("This connector does not support roles"): ( - "Custom error message", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - } - } - }, - ) - msg = "This connector does not support roles" result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples") @@ -631,6 +631,19 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None: assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile("This connector does not support roles"): ( + "Custom error message", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + } + } + }, +) def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> None: """ Test that custom error messages are only applied to the specified database_name. @@ -639,21 +652,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile("This connector does not support roles"): ( - "Custom error message", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - } - } - }, - ) - msg = "This connector does not support roles" # database_name doesn't match configured one, so default message is used result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2") @@ -665,6 +663,27 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile(r'message="(?P[^"]*)"'): ( + 'Unexpected error: "%(message)s"', + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + { + "custom_doc_links": [ + { + "url": "https://example.com/docs", + "label": "Check documentation", + }, + ], + "show_issue_info": False, + }, + ) + } + } + }, +) def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None: """ Test that custom error messages with regex, custom_doc_links, @@ -674,29 +693,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None: class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile(r'message="(?P[^"]*)"'): ( - 'Unexpected error: "%(message)s"', - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - { - "custom_doc_links": [ - { - "url": "https://example.com/docs", - "label": "Check documentation", - }, - ], - "show_issue_info": False, - }, - ) - } - } - }, - ) - msg = ( "db error: SomeUserError(type=USER_ERROR, name=TABLE_NOT_FOUND, " 'message="line 3:6: Table ' @@ -735,6 +731,7 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None: ] +@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}}) def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture): """ Test that extract_errors doesn't fail when custom database errors @@ -744,11 +741,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture): class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - {"CUSTOM_DATABASE_ERRORS": "not a dict"}, - ) - msg = "This connector does not support roles" result = TestEngineSpec.extract_errors(Exception(msg)) @@ -759,6 +751,7 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture): assert result == [expected] +@with_config({"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}}) def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture): """ Test that extract_errors doesn't fail when database-specific custom errors @@ -768,11 +761,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - {"CUSTOM_DATABASE_ERRORS": {"examples": "not a dict"}}, - ) - msg = "This connector does not support roles" result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples") @@ -783,6 +771,19 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile("This connector does not support roles"): ( + "", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + } + } + }, +) def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture): """ Test that when the custom error message is empty, @@ -792,21 +793,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture): class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile("This connector does not support roles"): ( - "", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - } - } - }, - ) - msg = "This connector does not support roles" result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples") @@ -817,6 +803,26 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture): assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile("connection error"): ( + "Examples DB error message", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + }, + "examples_2": { + re.compile("connection error"): ( + "Examples_2 DB error message", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + }, + } + }, +) def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) -> None: """ Test that custom error messages are matched by database_name. @@ -825,28 +831,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile("connection error"): ( - "Examples DB error message", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - }, - "examples_2": { - re.compile("connection error"): ( - "Examples_2 DB error message", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - }, - } - }, - ) - msg = "connection error occurred" # When database_name is examples_2 we should get that specific message result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2") @@ -858,6 +842,19 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) - assert result == [expected] +@with_config( + { + "CUSTOM_DATABASE_ERRORS": { + "examples": { + re.compile("connection error"): ( + "Examples DB error message", + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + {}, + ) + }, + } + }, +) def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None: """ Test that when database_name has no match, the original error message is preserved. @@ -866,21 +863,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None: class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" - mocker.patch( - "flask.current_app.config", - { - "CUSTOM_DATABASE_ERRORS": { - "examples": { - re.compile("connection error"): ( - "Examples DB error message", - SupersetErrorType.GENERIC_DB_ENGINE_ERROR, - {}, - ) - }, - } - }, - ) - msg = "some other error" result = TestEngineSpec.extract_errors(Exception(msg), database_name="examples_2") @@ -980,16 +962,13 @@ def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None: assert query["code_challenge"][0] == expected_challenge +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None: """ Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier. """ from superset.db_engine_specs.base import BaseEngineSpec - mocker.patch( - "flask.current_app.config", - {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, - ) mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") mock_post.return_value.json.return_value = { "access_token": "test-access-token", # noqa: S105 @@ -1015,6 +994,7 @@ def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None: assert "code_verifier" not in request_body +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None: """ Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided. @@ -1022,10 +1002,6 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None: from superset.db_engine_specs.base import BaseEngineSpec from superset.utils.oauth2 import generate_code_verifier - mocker.patch( - "flask.current_app.config", - {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, - ) mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") mock_post.return_value.json.return_value = { "access_token": "test-access-token", # noqa: S105 @@ -1097,6 +1073,7 @@ def test_get_oauth2_authorization_uri_additional_params( assert query["access_type"][0] == "offline" +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None: """ Test that a subclass can inject additional params into the token request body @@ -1109,10 +1086,6 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None: "audience": "https://api.example.com", } - mocker.patch( - "flask.current_app.config", - {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, - ) mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") mock_post.return_value.json.return_value = { "access_token": "test-access-token", # noqa: S105 @@ -1143,6 +1116,94 @@ def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None: assert request_body["audience"] == "https://api.example.com" +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) +def test_get_oauth2_fresh_token_success(mocker: MockerFixture) -> None: + """ + Test that get_oauth2_fresh_token returns the token response on success. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + result = BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token") + assert result == {"access_token": "new-access-token", "expires_in": 3600} + + +@pytest.mark.parametrize("status_code", [400, 401, 403]) +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) +def test_get_oauth2_fresh_token_raises_on_auth_error( + mocker: MockerFixture, + status_code: int, +) -> None: + """ + Test that get_oauth2_fresh_token raises OAuth2TokenRefreshError on 400/401/403. + """ + from superset.db_engine_specs.base import BaseEngineSpec + from superset.exceptions import OAuth2TokenRefreshError + + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.status_code = status_code + mock_post.return_value.text = '{"error": "invalid_grant"}' + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + with pytest.raises(OAuth2TokenRefreshError) as exc_info: + BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token") + + assert exc_info.value.error.extra["error"] == '{"error": "invalid_grant"}' + + +@with_config({"DATABASE_OAUTH2_TIMEOUT": timedelta(seconds=30)}) +def test_get_oauth2_fresh_token_raises_on_server_error(mocker: MockerFixture) -> None: + """ + Test that get_oauth2_fresh_token raises HTTPError (not OAuth2TokenRefreshError) + on 5xx. + """ + from requests.exceptions import HTTPError + + from superset.db_engine_specs.base import BaseEngineSpec + + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.status_code = 500 + mock_post.return_value.raise_for_status.side_effect = HTTPError("500 Server Error") + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + with pytest.raises(HTTPError): + BaseEngineSpec.get_oauth2_fresh_token(config, "refresh-token") + + def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None: """ Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set. @@ -1182,19 +1243,18 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N assert error.extra["redirect_uri"] == custom_redirect_uri +@with_config( + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + } +) def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None: """ Test that start_oauth2_dance falls back to url_for when no config is set. """ fallback_uri = "http://localhost:8088/api/v1/database/oauth2/" - mocker.patch( - "flask.current_app.config", - { - "SECRET_KEY": "test-secret-key", - "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", - }, - ) mocker.patch( "superset.utils.oauth2.url_for", return_value=fallback_uri, diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index aac67ea1fe2..77f27344858 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -24,11 +24,10 @@ import pandas as pd import pytest from pytest_mock import MockerFixture from requests.exceptions import HTTPError -from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine.url import make_url from superset.errors import ErrorLevel, SupersetError, SupersetErrorType -from superset.exceptions import SupersetException +from superset.exceptions import OAuth2TokenRefreshError, SupersetException from superset.sql.parse import Table from superset.superset_typing import OAuth2ClientConfig from superset.utils import json @@ -817,26 +816,20 @@ def test_get_oauth2_fresh_token_invalid_grant( oauth2_config: OAuth2ClientConfig, ) -> None: """ - Test that get_oauth2_fresh_token raises UnauthenticatedError for invalid_grant. + Test that get_oauth2_fresh_token raises OAuth2TokenRefreshError for a 400 response. - When a token is revoked on Google side, the refresh request returns 400 - with error=invalid_grant. + When a token is revoked on Google side, the refresh request returns 400. """ from superset.db_engine_specs.gsheets import GSheetsEngineSpec - mock_response = mocker.MagicMock() - mock_response.status_code = 400 - mock_response.json.return_value = { - "error": "invalid_grant", - "error_description": "Token has been expired or revoked.", - } - http_error = HTTPError() - http_error.response = mock_response - requests = mocker.patch("superset.db_engine_specs.base.requests") - requests.post().raise_for_status.side_effect = http_error + requests.post().status_code = 400 + requests.post().text = ( + '{"error": "invalid_grant",' + ' "error_description": "Token has been expired or revoked."}' + ) - with pytest.raises(UnauthenticatedError): + with pytest.raises(OAuth2TokenRefreshError): GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index 1ca9a4eb21b..c74b2f9570b 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -137,6 +137,7 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception( refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) db.session.delete.assert_called_with(token) + db.session.flush.assert_called_once() def test_refresh_oauth2_token_keeps_token_on_other_exception(