feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)

This commit is contained in:
Beto Dealmeida
2026-01-30 23:28:10 -05:00
committed by GitHub
parent 05c2354997
commit 5d20dc57d7
10 changed files with 422 additions and 38 deletions

View File

@@ -18,6 +18,7 @@
import json # noqa: TID251
from unittest.mock import MagicMock
from urllib.parse import parse_qs, urlparse
from uuid import UUID
import pytest
@@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
"superset.db_engine_specs.base.uuid4",
return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
)
mocker.patch(
"superset.db_engine_specs.base.generate_code_verifier",
return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ",
)
mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries")
mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry")
mocker.patch("superset.db_engine_specs.base.db.session.commit")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
@@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
mocker.patch("superset.sql_lab.get_query", return_value=query)
payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
assert payload == {
"status": QueryStatus.FAILED,
"error": "You don't have permission to access the data.",
"errors": [
{
"message": "You don't have permission to access the data.",
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
"level": ErrorLevel.WARNING,
"extra": {
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id",
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
"redirect_uri": "http://localhost/api/v1/database/oauth2/",
},
}
],
}
assert payload["status"] == QueryStatus.FAILED
assert payload["error"] == "You don't have permission to access the data."
assert len(payload["errors"]) == 1
error = payload["errors"][0]
assert error["message"] == "You don't have permission to access the data."
assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT
assert error["level"] == ErrorLevel.WARNING
assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187"
assert error["extra"]["redirect_uri"] == "http://localhost/api/v1/database/oauth2/"
# Parse the OAuth2 authorization URL and verify components individually,
# since the JWT state and PKCE code_challenge are computed deterministically
# from mocked inputs but their exact encoding depends on library internals.
url = urlparse(error["extra"]["url"])
assert url.scheme == "https"
assert url.netloc == "abcd1234.snowflakecomputing.com"
assert url.path == "/oauth/authorize"
params = parse_qs(url.query)
assert params["scope"] == ["refresh_token session:role:USERADMIN"]
assert params["response_type"] == ["code"]
assert params["redirect_uri"] == ["http://localhost/api/v1/database/oauth2/"]
assert params["client_id"] == ["my_client_id"]
assert params["code_challenge_method"] == ["S256"]
# Verify PKCE code_challenge matches the mocked code_verifier
from superset.utils.oauth2 import generate_code_challenge
expected_code_challenge = generate_code_challenge(
"xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ"
)
assert params["code_challenge"] == [expected_code_challenge]
def test_apply_rls(mocker: MockerFixture) -> None: