Compare commits

...

2 Commits

Author SHA1 Message Date
Vitor Avila
7575f31a49 Addressing PR feedback 2026-05-13 15:15:11 -03:00
Vitor Avila
47676a35eb fix(OAuth2): Re-query the OAuth2 token to avoid stale reference 2026-05-12 12:29:55 -03:00
2 changed files with 146 additions and 9 deletions

View File

@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
if TYPE_CHECKING:
from superset.db_engine_specs.base import BaseEngineSpec
from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.models.core import Database
JWT_EXPIRATION = timedelta(minutes=5)
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
return token.access_token
if token.refresh_token:
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
# since the access token is expired and there's no refresh token, delete the entry
db.session.delete(token)
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
database_id: int,
user_id: int,
db_engine_spec: type[BaseEngineSpec],
token: DatabaseUserOAuth2Tokens,
) -> str | None:
# pylint: disable=import-outside-toplevel
from superset.models.core import DatabaseUserOAuth2Tokens
# Use longer TTL for OAuth2 token refresh (may involve network calls)
with DistributedLock(
namespace="refresh_oauth2_token",
@@ -138,6 +140,22 @@ def refresh_oauth2_token(
user_id=user_id,
database_id=database_id,
):
# Short circuit in case another request already deleted the token
token = (
db.session.query(DatabaseUserOAuth2Tokens)
.filter_by(user_id=user_id, database_id=database_id)
.one_or_none()
)
if token is None:
return None
if token.access_token and datetime.now() < token.access_token_expiration:
return token.access_token
if not token.refresh_token:
db.session.delete(token)
return None
try:
token_response = db_engine_spec.get_oauth2_fresh_token(
config,

View File

@@ -131,10 +131,12 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
"Token revoked"
)
token = mocker.MagicMock()
token.access_token = None
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with pytest.raises(OAuth2ExceptionError):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
db.session.delete.assert_called_with(token)
db.session.flush.assert_called_once()
@@ -160,10 +162,12 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
db_engine_spec.oauth2_exception = OAuth2ExceptionError
db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
token = mocker.MagicMock()
token.access_token = None
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with pytest.raises(Exception, match="Network error"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
db.session.delete.assert_not_called()
@@ -176,16 +180,18 @@ def test_refresh_oauth2_token_no_access_token_in_response(
This can happen when the refresh token was revoked.
"""
mocker.patch("superset.utils.oauth2.db")
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"error": "invalid_grant",
}
token = mocker.MagicMock()
token.access_token = None
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result is None
@@ -208,10 +214,12 @@ def test_refresh_oauth2_token_updates_refresh_token(
"refresh_token": "new-refresh-token",
}
token = mocker.MagicMock()
token.access_token = None
token.refresh_token = "old-refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-01"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert token.access_token == "new-access-token" # noqa: S105
assert token.access_token_expiration == datetime(2024, 1, 1, 1)
@@ -236,16 +244,127 @@ def test_refresh_oauth2_token_keeps_refresh_token(
"expires_in": 3600,
}
token = mocker.MagicMock()
token.access_token = None
token.refresh_token = "original-refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-01"):
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert token.access_token == "new-access-token" # noqa: S105
assert token.refresh_token == "original-refresh-token" # noqa: S105
db.session.add.assert_called_with(token)
def test_refresh_oauth2_token_refreshes_when_access_token_expired_under_lock(
mocker: MockerFixture,
) -> None:
"""
Test that refresh_oauth2_token triggers a refresh when the access_token is expired.
When the re-query under the lock returns a token whose access_token has expired
but a refresh_token is available, the function should call the token endpoint
and persist the new access_token.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db_engine_spec.get_oauth2_fresh_token.return_value = {
"access_token": "new-access-token",
"expires_in": 3600,
}
token = mocker.MagicMock()
token.access_token = "expired-token" # noqa: S105
token.access_token_expiration = datetime(2024, 1, 1)
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-02"):
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result == "new-access-token"
db_engine_spec.get_oauth2_fresh_token.assert_called_once_with(
DUMMY_OAUTH2_CONFIG, "refresh-token"
)
db.session.add.assert_called_with(token)
def test_refresh_oauth2_token_returns_existing_token_when_still_valid_under_lock(
mocker: MockerFixture,
) -> None:
"""
Test that refresh_oauth2_token returns the existing access_token if still valid.
When concurrent requests are triggered and the first one refreshes the token and
releases the lock before the second one gets to `refresh_oauth2_token`, the second
request should pick up the already-refreshed access_token instead of refreshing
it again.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
token = mocker.MagicMock()
token.access_token = "fresh-access-token" # noqa: S105
token.access_token_expiration = datetime(2024, 1, 2)
token.refresh_token = "refresh-token" # noqa: S105
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-01"):
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result == "fresh-access-token"
db_engine_spec.get_oauth2_fresh_token.assert_not_called()
db.session.delete.assert_not_called()
def test_refresh_oauth2_token_deletes_when_no_refresh_token_under_lock(
mocker: MockerFixture,
) -> None:
"""
Test that refresh_oauth2_token deletes the row when there's no refresh_token.
When the token has expired and the re-query under the lock shows no refresh_token
is available, the row should be deleted and None returned so the caller can
trigger the OAuth2 dance.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
token = mocker.MagicMock()
token.access_token = "expired-token" # noqa: S105
token.access_token_expiration = datetime(2024, 1, 1)
token.refresh_token = None
db.session.query().filter_by().one_or_none.return_value = token
with freeze_time("2024-01-02"):
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result is None
db.session.delete.assert_called_with(token)
db_engine_spec.get_oauth2_fresh_token.assert_not_called()
def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
mocker: MockerFixture,
) -> None:
"""
Test that refresh_oauth2_token returns None when the row is gone under the lock.
When concurrent requests are triggered and the first one deletes the token row and
releases the lock before the second one gets to `refresh_oauth2_token`, the token
is queried again to avoid a stale reference.
"""
db = mocker.patch("superset.utils.oauth2.db")
mocker.patch("superset.utils.oauth2.DistributedLock")
db_engine_spec = mocker.MagicMock()
db.session.query().filter_by().one_or_none.return_value = None
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
assert result is None
db_engine_spec.get_oauth2_fresh_token.assert_not_called()
def test_generate_code_verifier_length() -> None:
"""
Test that generate_code_verifier produces a string of valid length (RFC 7636).