feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)

This commit is contained in:
Beto Dealmeida
2026-01-30 23:28:10 -05:00
committed by GitHub
parent 05c2354997
commit 5d20dc57d7
10 changed files with 422 additions and 38 deletions

View File

@@ -17,6 +17,8 @@
# pylint: disable=invalid-name, disallowed-name
import base64
import hashlib
from datetime import datetime
from typing import cast
@@ -25,7 +27,14 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
from superset.utils.oauth2 import (
decode_oauth2_state,
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
get_oauth2_access_token,
refresh_oauth2_token,
)
DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
@@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response(
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
assert result is None
def test_generate_code_verifier_length() -> None:
"""
Test that generate_code_verifier produces a string of valid length (RFC 7636).
"""
code_verifier = generate_code_verifier()
# RFC 7636 requires 43-128 characters
assert 43 <= len(code_verifier) <= 128
def test_generate_code_verifier_uniqueness() -> None:
"""
Test that generate_code_verifier produces unique values.
"""
verifiers = {generate_code_verifier() for _ in range(100)}
# All generated verifiers should be unique
assert len(verifiers) == 100
def test_generate_code_verifier_valid_characters() -> None:
"""
Test that generate_code_verifier only uses valid characters (RFC 7636).
"""
code_verifier = generate_code_verifier()
# RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
# URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_"
valid_chars = set(
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
)
assert all(char in valid_chars for char in code_verifier)
def test_generate_code_challenge_s256() -> None:
"""
Test that generate_code_challenge produces correct S256 challenge.
"""
# Use a known code_verifier to verify the challenge computation
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
# Compute expected challenge manually
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
code_challenge = generate_code_challenge(code_verifier)
assert code_challenge == expected_challenge
def test_generate_code_challenge_rfc_example() -> None:
"""
Test PKCE code challenge against RFC 7636 Appendix B example.
See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
"""
# RFC 7636 example code_verifier (Appendix B)
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
# RFC 7636 expected code_challenge for S256 method
expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
code_challenge = generate_code_challenge(code_verifier)
assert code_challenge == expected_challenge
def test_encode_decode_oauth2_state(
mocker: MockerFixture,
) -> None:
"""
Test that encode/decode cycle preserves state fields.
"""
from superset.superset_typing import OAuth2State
mocker.patch(
"flask.current_app.config",
{
"SECRET_KEY": "test-secret-key",
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
state: OAuth2State = {
"database_id": 1,
"user_id": 2,
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
"tab_id": "test-tab-id",
}
with freeze_time("2024-01-01"):
encoded = encode_oauth2_state(state)
decoded = decode_oauth2_state(encoded)
assert "code_verifier" not in decoded
assert decoded["database_id"] == 1
assert decoded["user_id"] == 2