Compare commits

...

1 Commits

Author SHA1 Message Date
Vitor Avila
47676a35eb fix(OAuth2): Re-query the OAuth2 token to avoid stale reference 2026-05-12 12:29:55 -03:00
2 changed files with 46 additions and 9 deletions

View File

@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.models.core import Database
JWT_EXPIRATION = timedelta(minutes=5)
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
return token.access_token
if token.refresh_token:
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
# since the access token is expired and there's no refresh token, delete the entry
db.session.delete(token)
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
database_id: int,
user_id: int,
db_engine_spec: type[BaseEngineSpec],
token: DatabaseUserOAuth2Tokens,
) -> str | None:
# pylint: disable=import-outside-toplevel
from superset.models.core import DatabaseUserOAuth2Tokens
# Use longer TTL for OAuth2 token refresh (may involve network calls)
with DistributedLock(
namespace="refresh_oauth2_token",
@@ -138,6 +140,15 @@ def refresh_oauth2_token(
user_id=user_id,
database_id=database_id,
):
# Short circuit in case another request already deleted the token
token = (
db.session.query(DatabaseUserOAuth2Tokens)
.filter_by(user_id=user_id, database_id=database_id)
.one_or_none()
)
if token is None:
return None
try:
token_response = db_engine_spec.get_oauth2_fresh_token(
config,

View File

@@ -132,9 +132,10 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
)
token = mocker.MagicMock()
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with pytest.raises(OAuth2ExceptionError):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
db.session.delete.assert_called_with(token)
db.session.flush.assert_called_once()
@@ -161,9 +162,10 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
token = mocker.MagicMock()
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with pytest.raises(Exception, match="Network error"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
db.session.delete.assert_not_called()
@@ -176,7 +178,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
@@ -184,8 +186,9 @@ def test_refresh_oauth2_token_no_access_token_in_response(
}
token = mocker.MagicMock()
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result is None
@@ -209,9 +212,10 @@ def test_refresh_oauth2_token_updates_refresh_token(
}
token = mocker.MagicMock()
token.refresh_token = "old-refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-01"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert token.access_token == "new-access-token" # noqa: S105
assert token.access_token_expiration == datetime(2024, 1, 1, 1)
@@ -237,15 +241,37 @@ def test_refresh_oauth2_token_keeps_refresh_token(
}
token = mocker.MagicMock()
token.refresh_token = "original-refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-01"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert token.access_token == "new-access-token" # noqa: S105
assert token.refresh_token == "original-refresh-token" # noqa: S105
db.session.add.assert_called_with(token)
def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
mocker: MockerFixture,
) -> None:
"""
Test that refresh_oauth2_token returns None when the row is gone under the lock.
When concurrent requests are triggered and the first one deletes the token row and
releases the lock before the second one gets to `refresh_oauth2_token`, the token
is queried again to avoid a stale reference.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db.session.query().filter_by().one_or_none.return_value = None
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result is None
db_engine_spec.get_oauth2_fresh_token.assert_not_called()
def test_generate_code_verifier_length() -> None:
"""
Test that generate_code_verifier produces a string of valid length (RFC 7636).