diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index fb0e26e77e7..965eec46b12 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -572,6 +572,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods oauth2_token_request_uri: str | None = None oauth2_token_request_type = "data" # noqa: S105 + # Driver-specific query params to be included in `get_oauth2_authorization_uri` + oauth2_additional_auth_uri_query_params: dict[str, Any] = {} + # Driver-specific params to be included in the `get_oauth2_token` request body + oauth2_additional_token_request_params: dict[str, Any] = {} # Driver-specific exception that should be mapped to OAuth2RedirectError oauth2_exception = OAuth2RedirectError @@ -754,6 +758,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "state": encode_oauth2_state(state), "redirect_uri": config["redirect_uri"], "client_id": config["id"], + **cls.oauth2_additional_auth_uri_query_params, } # Add PKCE parameters (RFC 7636) if code_verifier is provided @@ -784,6 +789,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "client_secret": config["secret"], "redirect_uri": config["redirect_uri"], "grant_type": "authorization_code", + **cls.oauth2_additional_token_request_params, } # Add PKCE code_verifier if present (RFC 7636) if code_verifier: diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index 57cc0a25ce9..4978c0af5c5 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -167,6 +167,10 @@ def refresh_oauth2_token( token.access_token_expiration = datetime.now() + timedelta( seconds=token_response["expires_in"] ) + # Support single-use refresh tokens + if new_refresh_token := token_response.get("refresh_token"): + token.refresh_token = new_refresh_token + db.session.add(token) return token.access_token diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index 6c6b98e0593..14fa82b55af 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -1052,6 +1052,97 @@ def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None: assert request_body["code_verifier"] == code_verifier +def test_get_oauth2_authorization_uri_additional_params( + mocker: MockerFixture, +) -> None: + """ + Test that a subclass can inject additional query params into the authorization URI + via `oauth2_additional_auth_uri_query_params`. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + class CustomEngineSpec(BaseEngineSpec): + oauth2_additional_auth_uri_query_params = { + "prompt": "consent", + "access_type": "offline", + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + state: OAuth2State = { + "database_id": 1, + "user_id": 1, + "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/", + "tab_id": "1234", + } + + url = CustomEngineSpec.get_oauth2_authorization_uri(config, state) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + # Standard params still present + assert query["response_type"][0] == "code" + assert query["client_id"][0] == "client-id" + + # Additional params included + assert query["prompt"][0] == "consent" + assert query["access_type"][0] == "offline" + + +def test_get_oauth2_token_additional_params(mocker: MockerFixture) -> None: + """ + Test that a subclass can inject additional params into the token request body + via `oauth2_additional_token_request_params`. + """ + from superset.db_engine_specs.base import BaseEngineSpec + + class CustomEngineSpec(BaseEngineSpec): + oauth2_additional_token_request_params = { + "audience": "https://api.example.com", + } + + mocker.patch( + "flask.current_app.config", + {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)}, + ) + mock_post = mocker.patch("superset.db_engine_specs.base.requests.post") + mock_post.return_value.json.return_value = { + "access_token": "test-access-token", # noqa: S105 + "expires_in": 3600, + } + + config: OAuth2ClientConfig = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + result = CustomEngineSpec.get_oauth2_token(config, "auth-code") + + assert result["access_token"] == "test-access-token" # noqa: S105 + call_kwargs = mock_post.call_args + request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data") + + # Standard params still present + assert request_body["grant_type"] == "authorization_code" + assert request_body["client_id"] == "client-id" + + # Additional param included + assert request_body["audience"] == "https://api.example.com" + + def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None: """ Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set. diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index 08b7cc9c6e7..f04ae26e7c2 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -188,6 +188,62 @@ def test_refresh_oauth2_token_no_access_token_in_response( assert result is None +def test_refresh_oauth2_token_updates_refresh_token( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token updates the refresh token when a new one is returned. + + Some OAuth2 providers issue single-use refresh tokens, where each token refresh + response includes a new refresh token that replaces the previous one. + """ + db = mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.DistributedLock") + db_engine_spec = mocker.MagicMock() + db_engine_spec.get_oauth2_fresh_token.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + "refresh_token": "new-refresh-token", + } + token = mocker.MagicMock() + token.refresh_token = "old-refresh-token" # noqa: S105 + + with freeze_time("2024-01-01"): + refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + assert token.access_token == "new-access-token" # noqa: S105 + assert token.access_token_expiration == datetime(2024, 1, 1, 1) + assert token.refresh_token == "new-refresh-token" # noqa: S105 + db.session.add.assert_called_with(token) + + +def test_refresh_oauth2_token_keeps_refresh_token( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token keeps the existing refresh token when none returned. + + When the OAuth2 provider does not issue a new refresh token in the response, + the original refresh token should be preserved. + """ + db = mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.DistributedLock") + db_engine_spec = mocker.MagicMock() + db_engine_spec.get_oauth2_fresh_token.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + } + token = mocker.MagicMock() + token.refresh_token = "original-refresh-token" # noqa: S105 + + with freeze_time("2024-01-01"): + refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + assert token.access_token == "new-access-token" # noqa: S105 + assert token.refresh_token == "original-refresh-token" # noqa: S105 + db.session.add.assert_called_with(token) + + def test_generate_code_verifier_length() -> None: """ Test that generate_code_verifier produces a string of valid length (RFC 7636).