Files
superset2/tests/unit_tests/utils/oauth2_tests.py

338 lines
12 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=invalid-name, disallowed-name
import base64
import hashlib
from datetime import datetime
from typing import cast
import pytest
from freezegun import freeze_time
from pytest_mock import MockerFixture
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import (
decode_oauth2_state,
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
get_oauth2_access_token,
refresh_oauth2_token,
)
# Placeholder client configuration for the `refresh_oauth2_token` tests: an
# empty dict cast to the `OAuth2ClientConfig` type. The tests only need a value
# of the right type to pass along; its contents are never asserted on.
DUMMY_OAUTH2_CONFIG = cast(
    OAuth2ClientConfig,
    {},
)
def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
    """
    Test that `get_oauth2_access_token` returns `None` when no token is stored.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    engine_spec = mocker.MagicMock()

    # The token lookup query finds no row for this user/database pair.
    token_query = mocked_db.session.query().filter_by()
    token_query.one_or_none.return_value = None

    result = get_oauth2_access_token({}, 1, 1, engine_spec)
    assert result is None
def test_get_oauth2_access_token_base_token_valid(mocker: MockerFixture) -> None:
    """
    Test that `get_oauth2_access_token` returns the stored token while it is valid.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    engine_spec = mocker.MagicMock()

    # The stored token expires one day after the frozen "now" used below.
    stored = mocker.MagicMock()
    stored.access_token = "access-token"  # noqa: S105
    stored.access_token_expiration = datetime(2024, 1, 2)
    mocked_db.session.query().filter_by().one_or_none.return_value = stored

    with freeze_time("2024-01-01"):
        result = get_oauth2_access_token({}, 1, 1, engine_spec)
        assert result == "access-token"
def test_get_oauth2_access_token_base_refresh(mocker: MockerFixture) -> None:
    """
    Test `get_oauth2_access_token` when the token needs to be refreshed.

    The stored access token has expired but a refresh token is present, so a
    fresh token is obtained from the engine spec and written back to the row.
    """
    db = mocker.patch("superset.utils.oauth2.db")
    # The refresh branch presumably goes through `refresh_oauth2_token`, whose
    # tests all patch `DistributedLock`; patch it here as well so this test does
    # not depend on the real lock backend.
    mocker.patch("superset.utils.oauth2.DistributedLock")
    db_engine_spec = mocker.MagicMock()
    db_engine_spec.get_oauth2_fresh_token.return_value = {
        "access_token": "new-token",
        "expires_in": 3600,
    }
    token = mocker.MagicMock()
    token.access_token = "access-token"  # noqa: S105
    token.access_token_expiration = datetime(2024, 1, 1)
    token.refresh_token = "refresh-token"  # noqa: S105
    db.session.query().filter_by().one_or_none.return_value = token

    with freeze_time("2024-01-02"):
        assert get_oauth2_access_token({}, 1, 1, db_engine_spec) == "new-token"

    # check that token was updated; it now expires `expires_in` (one hour)
    # after the frozen "now"
    assert token.access_token == "new-token"  # noqa: S105
    assert token.access_token_expiration == datetime(2024, 1, 2, 1)
    db.session.add.assert_called_with(token)
def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:
    """
    Test that an expired token with no refresh token is discarded.

    `get_oauth2_access_token` should return `None` and delete the stale row.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    engine_spec = mocker.MagicMock()

    # Expired one day before the frozen "now", with nothing to refresh it.
    expired = mocker.MagicMock()
    expired.access_token = "access-token"  # noqa: S105
    expired.access_token_expiration = datetime(2024, 1, 1)
    expired.refresh_token = None
    mocked_db.session.query().filter_by().one_or_none.return_value = expired

    with freeze_time("2024-01-02"):
        result = get_oauth2_access_token({}, 1, 1, engine_spec)
        assert result is None

    mocked_db.session.delete.assert_called_with(expired)
def test_refresh_oauth2_token_deletes_token_on_oauth2_exception(
    mocker: MockerFixture,
) -> None:
    """
    Test that `refresh_oauth2_token` drops the token on an OAuth2-specific error.

    If the refresh fails with the engine spec's OAuth2 exception (e.g. the token
    was revoked), the stored token is deleted and the exception propagates to
    the caller.
    """

    class RevokedTokenError(Exception):
        pass

    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")

    engine_spec = mocker.MagicMock()
    engine_spec.oauth2_exception = RevokedTokenError
    engine_spec.get_oauth2_fresh_token.side_effect = RevokedTokenError(
        "Token revoked"
    )

    stored = mocker.MagicMock()
    stored.refresh_token = "refresh-token"  # noqa: S105

    with pytest.raises(RevokedTokenError):
        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec, stored)

    mocked_db.session.delete.assert_called_with(stored)
def test_refresh_oauth2_token_keeps_token_on_other_exception(
    mocker: MockerFixture,
) -> None:
    """
    Test that `refresh_oauth2_token` keeps the token on non-OAuth2 errors.

    A transient failure (e.g. a network problem) says nothing about whether the
    refresh token is still valid, so the stored token is kept while the
    exception is still re-raised.
    """

    class ProviderOAuth2Error(Exception):
        pass

    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")

    engine_spec = mocker.MagicMock()
    engine_spec.oauth2_exception = ProviderOAuth2Error
    # A plain Exception, deliberately *not* the engine spec's OAuth2 exception.
    engine_spec.get_oauth2_fresh_token.side_effect = Exception("Network error")

    stored = mocker.MagicMock()
    stored.refresh_token = "refresh-token"  # noqa: S105

    with pytest.raises(Exception, match="Network error"):
        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec, stored)

    mocked_db.session.delete.assert_not_called()
def test_refresh_oauth2_token_no_access_token_in_response(
    mocker: MockerFixture,
) -> None:
    """
    Test that refresh_oauth2_token returns None when no access_token in response.

    This can happen when the refresh token was revoked. Besides returning None,
    the error response must not be persisted as if it were a successful
    refresh: the token is not re-added and its refresh token is left as-is.
    """
    db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")
    db_engine_spec = mocker.MagicMock()
    db_engine_spec.get_oauth2_fresh_token.return_value = {
        "error": "invalid_grant",
    }
    token = mocker.MagicMock()
    token.refresh_token = "refresh-token"  # noqa: S105

    result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)

    assert result is None
    # nothing from the error payload should be written back to the database
    db.session.add.assert_not_called()
    assert token.refresh_token == "refresh-token"  # noqa: S105
def test_refresh_oauth2_token_updates_refresh_token(
    mocker: MockerFixture,
) -> None:
    """
    Test that a refresh token returned by the provider replaces the stored one.

    Providers that issue single-use refresh tokens include a new refresh token
    in every refresh response; `refresh_oauth2_token` must store it together
    with the new access token.
    """
    mocked_db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")

    provider_response = {
        "access_token": "new-access-token",
        "expires_in": 3600,
        "refresh_token": "new-refresh-token",
    }
    engine_spec = mocker.MagicMock()
    engine_spec.get_oauth2_fresh_token.return_value = provider_response

    stored = mocker.MagicMock()
    stored.refresh_token = "old-refresh-token"  # noqa: S105

    with freeze_time("2024-01-01"):
        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, engine_spec, stored)

    # `expires_in` (one hour) after the frozen "now".
    expected_expiration = datetime(2024, 1, 1, 1)
    assert stored.access_token == "new-access-token"  # noqa: S105
    assert stored.access_token_expiration == expected_expiration
    assert stored.refresh_token == "new-refresh-token"  # noqa: S105
    mocked_db.session.add.assert_called_with(stored)
def test_refresh_oauth2_token_keeps_refresh_token(
    mocker: MockerFixture,
) -> None:
    """
    Test that refresh_oauth2_token keeps the existing refresh token when none returned.

    When the OAuth2 provider does not issue a new refresh token in the response,
    the original refresh token should be preserved, while the access token and
    its expiration are still updated.
    """
    db = mocker.patch("superset.utils.oauth2.db")
    mocker.patch("superset.utils.oauth2.DistributedLock")
    db_engine_spec = mocker.MagicMock()
    db_engine_spec.get_oauth2_fresh_token.return_value = {
        "access_token": "new-access-token",
        "expires_in": 3600,
    }
    token = mocker.MagicMock()
    token.refresh_token = "original-refresh-token"  # noqa: S105

    with freeze_time("2024-01-01"):
        refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)

    assert token.access_token == "new-access-token"  # noqa: S105
    # the expiration must be refreshed too: `expires_in` after the frozen "now"
    assert token.access_token_expiration == datetime(2024, 1, 1, 1)
    assert token.refresh_token == "original-refresh-token"  # noqa: S105
    db.session.add.assert_called_with(token)
def test_generate_code_verifier_length() -> None:
    """
    Test that `generate_code_verifier` output has an RFC 7636 compliant length.
    """
    # RFC 7636, section 4.1: a code verifier is 43 to 128 characters long.
    min_length, max_length = 43, 128
    verifier_length = len(generate_code_verifier())
    assert min_length <= verifier_length <= max_length
def test_generate_code_verifier_uniqueness() -> None:
    """
    Test that repeated calls to `generate_code_verifier` yield distinct values.
    """
    sample_size = 100
    seen = {generate_code_verifier() for _ in range(sample_size)}
    # Any collision would leave fewer distinct values than calls made.
    assert len(seen) == sample_size
def test_generate_code_verifier_valid_characters() -> None:
    """
    Test that `generate_code_verifier` emits only URL-safe base64 characters.
    """
    # RFC 7636 permits [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~". The
    # verifier is expected to be URL-safe base64, whose alphabet is the stricter
    # subset below (no "." or "~").
    url_safe_alphabet = frozenset(
        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
    )
    verifier = generate_code_verifier()
    assert set(verifier) <= url_safe_alphabet
def test_generate_code_challenge_s256() -> None:
    """
    Test that `generate_code_challenge` applies the S256 transform.

    S256 is BASE64URL(SHA256(ASCII(code_verifier))) with "=" padding stripped.
    """
    verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"

    # Derive the expected challenge independently, straight from the definition.
    sha256_digest = hashlib.sha256(verifier.encode("ascii")).digest()
    encoded = base64.urlsafe_b64encode(sha256_digest)
    expected = encoded.rstrip(b"=").decode("ascii")

    assert generate_code_challenge(verifier) == expected
def test_generate_code_challenge_rfc_example() -> None:
    """
    Test `generate_code_challenge` against the worked example in RFC 7636.

    See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
    """
    # (code_verifier, S256 code_challenge) pair published in Appendix B.
    rfc_verifier, rfc_challenge = (
        "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk",
        "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM",
    )
    assert generate_code_challenge(rfc_verifier) == rfc_challenge
def test_encode_decode_oauth2_state(
    mocker: MockerFixture,
) -> None:
    """
    Test that encode/decode cycle preserves state fields.

    Every field of the input state must survive the JWT round trip, and no PKCE
    `code_verifier` may appear when none was supplied.
    """
    from superset.superset_typing import OAuth2State

    mocker.patch(
        "flask.current_app.config",
        {
            "SECRET_KEY": "test-secret-key",
            "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
        },
    )

    state: OAuth2State = {
        "database_id": 1,
        "user_id": 2,
        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
        "tab_id": "test-tab-id",
    }

    # encode and decode at the same frozen instant so the token cannot expire
    # between the two calls
    with freeze_time("2024-01-01"):
        encoded = encode_oauth2_state(state)
        decoded = decode_oauth2_state(encoded)

    assert "code_verifier" not in decoded
    assert decoded["database_id"] == 1
    assert decoded["user_id"] == 2
    # the string fields must round-trip as well, not just the numeric ids
    assert decoded["default_redirect_uri"] == state["default_redirect_uri"]
    assert decoded["tab_id"] == state["tab_id"]