mirror of
https://github.com/apache/superset.git
synced 2026-05-13 11:55:16 +00:00
Compare commits
1 Commits
master
...
fix/oauth2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
47676a35eb |
@@ -38,7 +38,7 @@ from superset.superset_typing import OAuth2ClientConfig, OAuth2State
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.models.core import Database, DatabaseUserOAuth2Tokens
|
||||
from superset.models.core import Database
|
||||
|
||||
JWT_EXPIRATION = timedelta(minutes=5)
|
||||
|
||||
@@ -116,7 +116,7 @@ def get_oauth2_access_token(
|
||||
return token.access_token
|
||||
|
||||
if token.refresh_token:
|
||||
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec, token)
|
||||
return refresh_oauth2_token(config, database_id, user_id, db_engine_spec)
|
||||
|
||||
# since the access token is expired and there's no refresh token, delete the entry
|
||||
db.session.delete(token)
|
||||
@@ -129,8 +129,10 @@ def refresh_oauth2_token(
|
||||
database_id: int,
|
||||
user_id: int,
|
||||
db_engine_spec: type[BaseEngineSpec],
|
||||
token: DatabaseUserOAuth2Tokens,
|
||||
) -> str | None:
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from superset.models.core import DatabaseUserOAuth2Tokens
|
||||
|
||||
# Use longer TTL for OAuth2 token refresh (may involve network calls)
|
||||
with DistributedLock(
|
||||
namespace="refresh_oauth2_token",
|
||||
@@ -138,6 +140,15 @@ def refresh_oauth2_token(
|
||||
user_id=user_id,
|
||||
database_id=database_id,
|
||||
):
|
||||
# Short circuit in case another request already deleted the token
|
||||
token = (
|
||||
db.session.query(DatabaseUserOAuth2Tokens)
|
||||
.filter_by(user_id=user_id, database_id=database_id)
|
||||
.one_or_none()
|
||||
)
|
||||
if token is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
token_response = db_engine_spec.get_oauth2_fresh_token(
|
||||
config,
|
||||
|
||||
@@ -132,9 +132,10 @@ def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
|
||||
)
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with pytest.raises(OAuth2ExceptionError):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
db.session.delete.assert_called_with(token)
|
||||
db.session.flush.assert_called_once()
|
||||
@@ -161,9 +162,10 @@ def test_refresh_oauth2_token_keeps_token_on_other_exception(
|
||||
db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with pytest.raises(Exception, match="Network error"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
db.session.delete.assert_not_called()
|
||||
|
||||
@@ -176,7 +178,7 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
|
||||
This can happen when the refresh token was revoked.
|
||||
"""
|
||||
mocker.patch("superset.utils.oauth2.db")
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
@@ -184,8 +186,9 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -209,9 +212,10 @@ def test_refresh_oauth2_token_updates_refresh_token(
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "old-refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.access_token_expiration == datetime(2024, 1, 1, 1)
|
||||
@@ -237,15 +241,37 @@ def test_refresh_oauth2_token_keeps_refresh_token(
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "original-refresh-token" # noqa: S105
|
||||
db.session.query().filter_by().one_or_none.return_value = token
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.refresh_token == "original-refresh-token" # noqa: S105
|
||||
db.session.add.assert_called_with(token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_returns_none_when_row_deleted_under_lock(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token returns None when the row is gone under the lock.
|
||||
|
||||
When concurrent requests are triggered and the first one deletes the token row and
|
||||
releases the lock before the second one gets to `refresh_oauth2_token`, the token
|
||||
is queried again to avoid a stale reference.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db.session.query().filter_by().one_or_none.return_value = None
|
||||
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec)
|
||||
|
||||
assert result is None
|
||||
db_engine_spec.get_oauth2_fresh_token.assert_not_called()
|
||||
|
||||
|
||||
def test_generate_code_verifier_length() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier produces a string of valid length (RFC 7636).
|
||||
|
||||
Reference in New Issue
Block a user