mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
feat: Support OAuth2 single-use refresh tokens (#38364)
This commit is contained in:
@@ -572,6 +572,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
oauth2_token_request_uri: str | None = None
|
||||
oauth2_token_request_type = "data" # noqa: S105
|
||||
|
||||
# Driver-specific query params to be included in `get_oauth2_authorization_uri`
|
||||
oauth2_additional_auth_uri_query_params: dict[str, Any] = {}
|
||||
# Driver-specific params to be included in the `get_oauth2_token` request body
|
||||
oauth2_additional_token_request_params: dict[str, Any] = {}
|
||||
# Driver-specific exception that should be mapped to OAuth2RedirectError
|
||||
oauth2_exception = OAuth2RedirectError
|
||||
|
||||
@@ -754,6 +758,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"state": encode_oauth2_state(state),
|
||||
"redirect_uri": config["redirect_uri"],
|
||||
"client_id": config["id"],
|
||||
**cls.oauth2_additional_auth_uri_query_params,
|
||||
}
|
||||
|
||||
# Add PKCE parameters (RFC 7636) if code_verifier is provided
|
||||
@@ -784,6 +789,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
"client_secret": config["secret"],
|
||||
"redirect_uri": config["redirect_uri"],
|
||||
"grant_type": "authorization_code",
|
||||
**cls.oauth2_additional_token_request_params,
|
||||
}
|
||||
# Add PKCE code_verifier if present (RFC 7636)
|
||||
if code_verifier:
|
||||
|
||||
@@ -167,6 +167,10 @@ def refresh_oauth2_token(
|
||||
token.access_token_expiration = datetime.now() + timedelta(
|
||||
seconds=token_response["expires_in"]
|
||||
)
|
||||
# Support single-use refresh tokens
|
||||
if new_refresh_token := token_response.get("refresh_token"):
|
||||
token.refresh_token = new_refresh_token
|
||||
|
||||
db.session.add(token)
|
||||
|
||||
return token.access_token
|
||||
|
||||
@@ -1052,6 +1052,97 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
|
||||
assert request_body["code_verifier"] == code_verifier
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_additional_params(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that a subclass can inject additional query params into the authorization URI
|
||||
via `oauth2_additional_auth_uri_query_params`.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class CustomEngineSpec(BaseEngineSpec):
|
||||
oauth2_additional_auth_uri_query_params = {
|
||||
"prompt": "consent",
|
||||
"access_type": "offline",
|
||||
}
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = CustomEngineSpec.get_oauth2_authorization_uri(config, state)
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
# Standard params still present
|
||||
assert query["response_type"][0] == "code"
|
||||
assert query["client_id"][0] == "client-id"
|
||||
|
||||
# Additional params included
|
||||
assert query["prompt"][0] == "consent"
|
||||
assert query["access_type"][0] == "offline"
|
||||
|
||||
|
||||
def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that a subclass can inject additional params into the token request body
|
||||
via `oauth2_additional_token_request_params`.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
class CustomEngineSpec(BaseEngineSpec):
|
||||
oauth2_additional_token_request_params = {
|
||||
"audience": "https://api.example.com",
|
||||
}
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
result = CustomEngineSpec.get_oauth2_token(config, "auth-code")
|
||||
|
||||
assert result["access_token"] == "test-access-token" # noqa: S105
|
||||
call_kwargs = mock_post.call_args
|
||||
request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
|
||||
|
||||
# Standard params still present
|
||||
assert request_body["grant_type"] == "authorization_code"
|
||||
assert request_body["client_id"] == "client-id"
|
||||
|
||||
# Additional param included
|
||||
assert request_body["audience"] == "https://api.example.com"
|
||||
|
||||
|
||||
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set.
|
||||
|
||||
@@ -188,6 +188,62 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_updates_refresh_token(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token updates the refresh token when a new one is returned.
|
||||
|
||||
Some OAuth2 providers issue single-use refresh tokens, where each token refresh
|
||||
response includes a new refresh token that replaces the previous one.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "new-refresh-token",
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "old-refresh-token" # noqa: S105
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.access_token_expiration == datetime(2024, 1, 1, 1)
|
||||
assert token.refresh_token == "new-refresh-token" # noqa: S105
|
||||
db.session.add.assert_called_with(token)
|
||||
|
||||
|
||||
def test_refresh_oauth2_token_keeps_refresh_token(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that refresh_oauth2_token keeps the existing refresh token when none returned.
|
||||
|
||||
When the OAuth2 provider does not issue a new refresh token in the response,
|
||||
the original refresh token should be preserved.
|
||||
"""
|
||||
db = mocker.patch("superset.utils.oauth2.db")
|
||||
mocker.patch("superset.utils.oauth2.DistributedLock")
|
||||
db_engine_spec = mocker.MagicMock()
|
||||
db_engine_spec.get_oauth2_fresh_token.return_value = {
|
||||
"access_token": "new-access-token",
|
||||
"expires_in": 3600,
|
||||
}
|
||||
token = mocker.MagicMock()
|
||||
token.refresh_token = "original-refresh-token" # noqa: S105
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
assert token.access_token == "new-access-token" # noqa: S105
|
||||
assert token.refresh_token == "original-refresh-token" # noqa: S105
|
||||
db.session.add.assert_called_with(token)
|
||||
|
||||
|
||||
def test_generate_code_verifier_length() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier produces a string of valid length (RFC 7636).
|
||||
|
||||
Reference in New Issue
Block a user