From d40a5cad5dc169a2e703f497d995bb6f76e89c4b Mon Sep 17 00:00:00 2001
From: Vitor Avila <96086495+Vitor-Avila@users.noreply.github.com>
Date: Mon, 18 May 2026 13:07:54 -0300
Subject: [PATCH] fix(OAuth2): Re-query the OAuth2 token to avoid stale reference (#40071)

---
 superset/utils/oauth2.py               |  24 ++++-
 tests/unit_tests/utils/oauth2_tests.py | 131 +++++++++++++++++++++++--
 2 files changed, 146 insertions(+), 9 deletions(-)

diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index a2a2666e7c0..020f5397b3c 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
 
 if TYPE_CHECKING:
     from superset.db_engine_specs.base import BaseEngineSpec
-    from superset.models.core import Database, DatabaseUserOAuth2Tokens
+    from superset.models.core import Database
 
 JWT_EXPIRATION = timedelta(minutes=5)
 
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
         return token.access_token
 
     if token.refresh_token:
-        return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
+        return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
 
     # since the access token is expired and there's no refresh token, delete the entry
     db.session.delete(token)
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
     database_id: int,
     user_id: int,
     db_engine_spec: type[BaseEngineSpec],
-    token: DatabaseUserOAuth2Tokens,
 ) -> str | None:
+    # pylint: disable=import-outside-toplevel
+    from superset.models.core import DatabaseUserOAuth2Tokens
+
     # Use longer TTL for OAuth2 token refresh (may involve network calls)
     with DistributedLock(
         namespace="refresh_oauth2_token",
@@ -138,6 +140,22 @@ def refresh_oauth2_token(
         user_id=user_id,
         database_id=database_id,
     ):
+        # Short circuit in case another request already deleted the token
+        token = (
+            db.session.query(DatabaseUserOAuth2Tokens)
+            .filter_by(user_id=user_id, database_id=database_id)
+            .one_or_none()
+        )
+        if token is None:
+            return None
+
+        if token.access_token and datetime.now() < token.access_token_expiration:
+            return token.access_token
+
+        if not token.refresh_token:
+            db.session.delete(token)
+            return None
+
         try:
             token_response = db_engine_spec.get_oauth2_fresh_token(
                 config,
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index c74b2f9570b..ac788ce66e5 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -131,10 +131,12 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
         "Token revoked"
     )
     token = mocker.MagicMock()
+    token.access_token = None
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with pytest.raises(OAuth2ExceptionError):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     db.session.delete.assert_called_with(token)
     db.session.flush.assert_called_once()
@@ -160,10 +162,12 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
     db_engine_spec.oauth2_exception = OAuth2ExceptionError
     db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
     token = mocker.MagicMock()
+    token.access_token = None
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with pytest.raises(Exception, match="Network error"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     db.session.delete.assert_not_called()
 
@@ -176,16 +180,18 @@ def test_refresh_oauth2_token_no_access_token_in_response(
 
     This can happen when the refresh token was revoked.
     """
-    mocker.patch("superset.utils.oauth2.db")
+    db = mocker.patch("superset.utils.oauth2.db")
     mocker.patch("superset.utils.oauth2.DistributedLock")
     db_engine_spec = mocker.MagicMock()
     db_engine_spec.get_oauth2_fresh_token.return_value = {
         "error": "invalid_grant",
     }
     token = mocker.MagicMock()
+    token.access_token = None
     token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
-    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert result is None
 
@@ -208,10 +214,12 @@ def test_refresh_oauth2_token_updates_refresh_token(
         "refresh_token": "new-refresh-token",
     }
     token = mocker.MagicMock()
+    token.access_token = None
     token.refresh_token = "old-refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with freeze_time("2024-01-01"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert token.access_token == "new-access-token"  # noqa: S105
     assert token.access_token_expiration == datetime(2024, 1, 1, 1)
@@ -236,16 +244,127 @@ def test_refresh_oauth2_token_keeps_refresh_token(
         "expires_in": 3600,
     }
     token = mocker.MagicMock()
+    token.access_token = None
     token.refresh_token = "original-refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
 
     with freeze_time("2024-01-01"):
-        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
+        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
 
     assert token.access_token == "new-access-token"  # noqa: S105
     assert token.refresh_token == "original-refresh-token"  # noqa: S105
     db.session.add.assert_called_with(token)
 
 
+def test_refresh_oauth2_token_refreshes_when_access_token_expired_under_lock(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token triggers a refresh when the access_token is expired.
+
+    When the re-query under the lock returns a token whose access_token has expired
+    but a refresh_token is available, the function should call the token endpoint
+    and persist the new access_token.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db_engine_spec.get_oauth2_fresh_token.return_value = {
+        "access_token": "new-access-token",
+        "expires_in": 3600,
+    }
+    token = mocker.MagicMock()
+    token.access_token = "expired-token"  # noqa: S105
+    token.access_token_expiration = datetime(2024, 1, 1)
+    token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
+
+    with freeze_time("2024-01-02"):
+        result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
+
+    assert result == "new-access-token"
+    db_engine_spec.get_oauth2_fresh_token.assert_called_once_with(
+        DUMMY_OAUTH2_CONFIG, "refresh-token"
+    )
+    db.session.add.assert_called_with(token)
+
+
+def test_refresh_oauth2_token_returns_existing_token_when_still_valid_under_lock(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token returns the existing access_token if still valid.
+
+    When concurrent requests are triggered and the first one refreshes the token and
+    releases the lock before the second one gets to `refresh_oauth2_token`, the second
+    request should pick up the already-refreshed access_token instead of refreshing
+    it again.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    token = mocker.MagicMock()
+    token.access_token = "fresh-access-token"  # noqa: S105
+    token.access_token_expiration = datetime(2024, 1, 2)
+    token.refresh_token = "refresh-token"  # noqa: S105
+    db.session.query().filter_by().one_or_none.return_value = token
+
+    with freeze_time("2024-01-01"):
+        result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
+
+    assert result == "fresh-access-token"
+    db_engine_spec.get_oauth2_fresh_token.assert_not_called()
+    db.session.delete.assert_not_called()
+
+
+def test_refresh_oauth2_token_deletes_when_no_refresh_token_under_lock(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token deletes the row when there's no refresh_token.
+
+    When the token has expired and the re-query under the lock shows no refresh_token
+    is available, the row should be deleted and None returned so the caller can
+    trigger the OAuth2 dance.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    token = mocker.MagicMock()
+    token.access_token = "expired-token"  # noqa: S105
+    token.access_token_expiration = datetime(2024, 1, 1)
+    token.refresh_token = None
+    db.session.query().filter_by().one_or_none.return_value = token
+
+    with freeze_time("2024-01-02"):
+        result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
+
+    assert result is None
+    db.session.delete.assert_called_with(token)
+    db_engine_spec.get_oauth2_fresh_token.assert_not_called()
+
+
+def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that refresh_oauth2_token returns None when the row is gone under the lock.
+
+    When concurrent requests are triggered and the first one deletes the token row and
+    releases the lock before the second one gets to `refresh_oauth2_token`, the token
+    is queried again to avoid a stale reference.
+    """
+    db = mocker.patch("superset.utils.oauth2.db")
+    mocker.patch("superset.utils.oauth2.DistributedLock")
+    db_engine_spec = mocker.MagicMock()
+    db.session.query().filter_by().one_or_none.return_value = None
+
+    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
+
+    assert result is None
+    db_engine_spec.get_oauth2_fresh_token.assert_not_called()
+
+
 def test_generate_code_verifier_length() -> None:
     """
     Test that generate_code_verifier produces a string of valid length (RFC 7636).