feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)

This commit is contained in:
Beto Dealmeida
2026-01-30 23:28:10 -05:00
committed by GitHub
parent 05c2354997
commit 5d20dc57d7
10 changed files with 422 additions and 38 deletions

View File

@@ -255,7 +255,7 @@ def test_database_connection(
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", # noqa: E501
"private_key": "XXXXXXXXXX",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
"client_id": "114567578578109757129",
@@ -621,6 +621,10 @@ def test_oauth2_happy_path(
"expires_in": 3600,
"refresh_token": "ZZZ",
}
mocker.patch(
"superset.commands.database.oauth2.KeyValueDAO.get_value",
return_value=None,
)
state: OAuth2State = {
"user_id": 1,
@@ -641,7 +645,11 @@ def test_oauth2_happy_path(
)
assert response.status_code == 200
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
get_oauth2_token.assert_called_with(
{"id": "one", "secret": "two"},
"XXX",
code_verifier=None,
)
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
@@ -689,6 +697,10 @@ def test_oauth2_permissions(
"expires_in": 3600,
"refresh_token": "ZZZ",
}
mocker.patch(
"superset.commands.database.oauth2.KeyValueDAO.get_value",
return_value=None,
)
state: OAuth2State = {
"user_id": 1,
@@ -709,7 +721,11 @@ def test_oauth2_permissions(
)
assert response.status_code == 200
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
get_oauth2_token.assert_called_with(
{"id": "one", "secret": "two"},
"XXX",
code_verifier=None,
)
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
@@ -762,6 +778,10 @@ def test_oauth2_multiple_tokens(
"refresh_token": "ZZZ2",
},
]
mocker.patch(
"superset.commands.database.oauth2.KeyValueDAO.get_value",
return_value=None,
)
state: OAuth2State = {
"user_id": 1,

View File

@@ -889,6 +889,124 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
assert "access_type" not in query
assert "include_granted_scopes" not in query
# Verify PKCE parameters are NOT included when code_verifier is not provided
assert "code_challenge" not in query
assert "code_challenge_method" not in query
def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
    """
    Check that the authorization URI carries the PKCE query parameters.

    When ``code_verifier`` is handed to
    ``BaseEngineSpec.get_oauth2_authorization_uri``, the resulting URL must
    advertise the ``S256`` challenge method and a ``code_challenge`` equal to
    what ``generate_code_challenge`` derives from that verifier (RFC 7636).
    """
    from urllib.parse import parse_qs, urlparse

    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.utils.oauth2 import generate_code_challenge, generate_code_verifier

    oauth2_config: OAuth2ClientConfig = {
        "id": "client-id",
        "secret": "client-secret",
        "scope": "read write",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://oauth.example.com/authorize",
        "token_request_uri": "https://oauth.example.com/token",
        "request_content_type": "json",
    }
    verifier = generate_code_verifier()
    oauth2_state: OAuth2State = {
        "database_id": 1,
        "user_id": 1,
        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
        "tab_id": "1234",
    }

    authorization_uri = BaseEngineSpec.get_oauth2_authorization_uri(
        oauth2_config,
        oauth2_state,
        code_verifier=verifier,
    )
    query_params = parse_qs(urlparse(authorization_uri).query)

    # Both PKCE parameters have to be present on the request (RFC 7636).
    assert "code_challenge" in query_params
    assert query_params["code_challenge_method"][0] == "S256"

    # The challenge must be the one derived from the verifier passed in above.
    assert query_params["code_challenge"][0] == generate_code_challenge(verifier)
def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
    """
    Check that ``BaseEngineSpec.get_oauth2_token`` works with no PKCE verifier.

    ``requests.post`` is mocked; the request sent through it must not contain
    a ``code_verifier`` entry.
    """
    from superset.db_engine_specs.base import BaseEngineSpec

    mocker.patch(
        "flask.current_app.config",
        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
    )
    requests_post = mocker.patch("superset.db_engine_specs.base.requests.post")
    requests_post.return_value.json.return_value = {
        "access_token": "test-access-token",  # noqa: S105
        "expires_in": 3600,
    }

    oauth2_config: OAuth2ClientConfig = {
        "id": "client-id",
        "secret": "client-secret",
        "scope": "read write",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://oauth.example.com/authorize",
        "token_request_uri": "https://oauth.example.com/token",
        "request_content_type": "json",
    }

    token_response = BaseEngineSpec.get_oauth2_token(oauth2_config, "auth-code")
    assert token_response["access_token"] == "test-access-token"  # noqa: S105

    # The payload may be passed as ``json=`` or ``data=``; whichever is used,
    # it must not carry a code_verifier.
    sent_kwargs = requests_post.call_args.kwargs
    payload = sent_kwargs.get("json") or sent_kwargs.get("data")
    assert "code_verifier" not in payload
def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
    """
    Check that ``BaseEngineSpec.get_oauth2_token`` forwards the PKCE verifier.

    ``requests.post`` is mocked; when a ``code_verifier`` is supplied as the
    third argument, the request sent through it must contain that exact value.
    """
    from superset.db_engine_specs.base import BaseEngineSpec
    from superset.utils.oauth2 import generate_code_verifier

    mocker.patch(
        "flask.current_app.config",
        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
    )
    requests_post = mocker.patch("superset.db_engine_specs.base.requests.post")
    requests_post.return_value.json.return_value = {
        "access_token": "test-access-token",  # noqa: S105
        "expires_in": 3600,
    }

    oauth2_config: OAuth2ClientConfig = {
        "id": "client-id",
        "secret": "client-secret",
        "scope": "read write",
        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
        "authorization_request_uri": "https://oauth.example.com/authorize",
        "token_request_uri": "https://oauth.example.com/token",
        "request_content_type": "json",
    }
    verifier = generate_code_verifier()

    token_response = BaseEngineSpec.get_oauth2_token(
        oauth2_config,
        "auth-code",
        verifier,
    )
    assert token_response["access_token"] == "test-access-token"  # noqa: S105

    # The payload may be passed as ``json=`` or ``data=``; whichever is used,
    # it must include the verifier that was passed in (PKCE).
    sent_kwargs = requests_post.call_args.kwargs
    payload = sent_kwargs.get("json") or sent_kwargs.get("data")
    assert payload["code_verifier"] == verifier
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
"""
@@ -904,6 +1022,8 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
mocker.patch("superset.daos.key_value.KeyValueDAO")
mocker.patch("superset.db_engine_specs.base.db")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1
@@ -944,6 +1064,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
"superset.db_engine_specs.base.url_for",
return_value=fallback_uri,
)
mocker.patch("superset.daos.key_value.KeyValueDAO")
mocker.patch("superset.db_engine_specs.base.db")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1

View File

@@ -18,6 +18,7 @@
import json # noqa: TID251
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
from uuid import UUID
import pytest
@@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
"superset.db_engine_specs.base.uuid4",
return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
)
mocker.patch(
"superset.db_engine_specs.base.generate_code_verifier",
return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ",
)
mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries")
mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry")
mocker.patch("superset.db_engine_specs.base.db.session.commit")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
@@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
mocker.patch("superset.sql_lab.get_query", return_value=query)
payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
assert payload == {
"status": QueryStatus.FAILED,
"error": "You don't have permission to access the data.",
"errors": [
{
"message": "You don't have permission to access the data.",
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
"level": ErrorLevel.WARNING,
"extra": {
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id",
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
"redirect_uri": "http://localhost/api/v1/database/oauth2/",
},
}
],
}
assert payload["status"] == QueryStatus.FAILED
assert payload["error"] == "You don't have permission to access the data."
assert len(payload["errors"]) == 1
error = payload["errors"][0]
assert error["message"] == "You don't have permission to access the data."
assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT
assert error["level"] == ErrorLevel.WARNING
assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187"
assert error["extra"]["redirect_uri"] == "http://localhost/api/v1/database/oauth2/"
# Parse the OAuth2 authorization URL and verify components individually,
# since the JWT state and PKCE code_challenge are computed deterministically
# from mocked inputs but their exact encoding depends on library internals.
url = urlparse(error["extra"]["url"])
assert url.scheme == "https"
assert url.netloc == "abcd1234.snowflakecomputing.com"
assert url.path == "/oauth/authorize"
params = parse_qs(url.query)
assert params["scope"] == ["refresh_token session:role:USERADMIN"]
assert params["response_type"] == ["code"]
assert params["redirect_uri"] == ["http://localhost/api/v1/database/oauth2/"]
assert params["client_id"] == ["my_client_id"]
assert params["code_challenge_method"] == ["S256"]
# Verify PKCE code_challenge matches the mocked code_verifier
from superset.utils.oauth2 import generate_code_challenge
expected_code_challenge = generate_code_challenge(
"xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ"
)
assert params["code_challenge"] == [expected_code_challenge]
def test_apply_rls(mocker: MockerFixture) -> None:

View File

@@ -17,6 +17,8 @@
# pylint: disable=invalid-name, disallowed-name
import base64
import hashlib
from datetime import datetime
from typing import cast
@@ -25,7 +27,14 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from superset.superset_typing import OAuth2ClientConfig
from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
from superset.utils.oauth2 import (
decode_oauth2_state,
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
get_oauth2_access_token,
refresh_oauth2_token,
)
DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
@@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response(
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
assert result is None
def test_generate_code_verifier_length() -> None:
    """
    Check that a generated code verifier has an RFC 7636 compliant length.
    """
    # RFC 7636 bounds the verifier to between 43 and 128 characters inclusive.
    verifier_length = len(generate_code_verifier())
    assert 43 <= verifier_length <= 128
def test_generate_code_verifier_uniqueness() -> None:
    """
    Check that repeated calls to generate_code_verifier never repeat a value.
    """
    sample_size = 100
    seen = set()
    for _ in range(sample_size):
        seen.add(generate_code_verifier())

    # A set drops duplicates, so any collision would shrink it below the
    # number of verifiers generated.
    assert len(seen) == sample_size
def test_generate_code_verifier_valid_characters() -> None:
    """
    Check that a generated code verifier sticks to the permitted alphabet.
    """
    # RFC 7636 permits [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"; this test
    # checks against the narrower URL-safe base64 alphabet, which omits "."
    # and "~".
    allowed = set("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_")

    verifier = generate_code_verifier()
    assert set(verifier) <= allowed
def test_generate_code_challenge_s256() -> None:
    """
    Check generate_code_challenge against a by-hand S256 computation.
    """
    # Fixed verifier so the expected challenge can be derived independently.
    verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"

    # S256 by hand: SHA-256 the ASCII verifier, URL-safe base64 encode the
    # digest, then strip the "=" padding.
    sha256_digest = hashlib.sha256(verifier.encode("ascii")).digest()
    unpadded = base64.urlsafe_b64encode(sha256_digest).rstrip(b"=")

    assert generate_code_challenge(verifier) == unpadded.decode("ascii")
def test_generate_code_challenge_rfc_example() -> None:
    """
    Check generate_code_challenge against the RFC 7636 Appendix B test vector.

    See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
    """
    # Verifier / S256 challenge pair taken from RFC 7636, Appendix B.
    rfc_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
    rfc_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"

    assert generate_code_challenge(rfc_verifier) == rfc_challenge
def test_encode_decode_oauth2_state(
    mocker: MockerFixture,
) -> None:
    """
    Check that an OAuth2 state survives an encode/decode round trip.

    The decoded state must keep ``database_id`` and ``user_id`` and must not
    contain a ``code_verifier`` key.
    """
    from superset.superset_typing import OAuth2State

    app_config = {
        "SECRET_KEY": "test-secret-key",
        "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
    }
    mocker.patch("flask.current_app.config", app_config)

    original_state: OAuth2State = {
        "database_id": 1,
        "user_id": 2,
        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
        "tab_id": "test-tab-id",
    }

    # Encode and decode under the same frozen clock.
    # NOTE(review): presumably this keeps any time-based expiry embedded in
    # the encoded state valid while decoding -- confirm against
    # encode_oauth2_state.
    with freeze_time("2024-01-01"):
        round_tripped = decode_oauth2_state(encode_oauth2_state(original_state))

    assert "code_verifier" not in round_tripped
    assert round_tripped["database_id"] == 1
    assert round_tripped["user_id"] == 2