mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)
This commit is contained in:
@@ -255,7 +255,7 @@ def test_database_connection(
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173", # noqa: E501
|
||||
"private_key": "XXXXXXXXXX",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
|
||||
"client_id": "114567578578109757129",
|
||||
@@ -621,6 +621,10 @@ def test_oauth2_happy_path(
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "ZZZ",
|
||||
}
|
||||
mocker.patch(
|
||||
"superset.commands.database.oauth2.KeyValueDAO.get_value",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"user_id": 1,
|
||||
@@ -641,7 +645,11 @@ def test_oauth2_happy_path(
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
|
||||
get_oauth2_token.assert_called_with(
|
||||
{"id": "one", "secret": "two"},
|
||||
"XXX",
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
token = db.session.query(DatabaseUserOAuth2Tokens).one()
|
||||
assert token.user_id == 1
|
||||
@@ -689,6 +697,10 @@ def test_oauth2_permissions(
|
||||
"expires_in": 3600,
|
||||
"refresh_token": "ZZZ",
|
||||
}
|
||||
mocker.patch(
|
||||
"superset.commands.database.oauth2.KeyValueDAO.get_value",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"user_id": 1,
|
||||
@@ -709,7 +721,11 @@ def test_oauth2_permissions(
|
||||
)
|
||||
|
||||
assert response.status_code == 200
|
||||
get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
|
||||
get_oauth2_token.assert_called_with(
|
||||
{"id": "one", "secret": "two"},
|
||||
"XXX",
|
||||
code_verifier=None,
|
||||
)
|
||||
|
||||
token = db.session.query(DatabaseUserOAuth2Tokens).one()
|
||||
assert token.user_id == 1
|
||||
@@ -762,6 +778,10 @@ def test_oauth2_multiple_tokens(
|
||||
"refresh_token": "ZZZ2",
|
||||
},
|
||||
]
|
||||
mocker.patch(
|
||||
"superset.commands.database.oauth2.KeyValueDAO.get_value",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"user_id": 1,
|
||||
|
||||
@@ -889,6 +889,124 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
|
||||
assert "access_type" not in query
|
||||
assert "include_granted_scopes" not in query
|
||||
|
||||
# Verify PKCE parameters are NOT included when code_verifier is not provided
|
||||
assert "code_challenge" not in query
|
||||
assert "code_challenge_method" not in query
|
||||
|
||||
|
||||
def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE parameters
|
||||
when code_verifier is passed as a parameter (RFC 7636).
|
||||
"""
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils.oauth2 import generate_code_challenge, generate_code_verifier
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 1,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "1234",
|
||||
}
|
||||
|
||||
url = BaseEngineSpec.get_oauth2_authorization_uri(
|
||||
config, state, code_verifier=code_verifier
|
||||
)
|
||||
parsed = urlparse(url)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
# Verify PKCE parameters are included (RFC 7636)
|
||||
assert "code_challenge" in query
|
||||
assert query["code_challenge_method"][0] == "S256"
|
||||
# Verify the code_challenge matches the expected value
|
||||
expected_challenge = generate_code_challenge(code_verifier)
|
||||
assert query["code_challenge"][0] == expected_challenge
|
||||
|
||||
|
||||
def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
result = BaseEngineSpec.get_oauth2_token(config, "auth-code")
|
||||
|
||||
assert result["access_token"] == "test-access-token" # noqa: S105
|
||||
# Verify code_verifier is NOT in the request body
|
||||
call_kwargs = mock_post.call_args
|
||||
request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
|
||||
assert "code_verifier" not in request_body
|
||||
|
||||
|
||||
def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
|
||||
"""
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils.oauth2 import generate_code_verifier
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
|
||||
)
|
||||
mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
|
||||
mock_post.return_value.json.return_value = {
|
||||
"access_token": "test-access-token", # noqa: S105
|
||||
"expires_in": 3600,
|
||||
}
|
||||
|
||||
config: OAuth2ClientConfig = {
|
||||
"id": "client-id",
|
||||
"secret": "client-secret",
|
||||
"scope": "read write",
|
||||
"redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
|
||||
"authorization_request_uri": "https://oauth.example.com/authorize",
|
||||
"token_request_uri": "https://oauth.example.com/token",
|
||||
"request_content_type": "json",
|
||||
}
|
||||
|
||||
code_verifier = generate_code_verifier()
|
||||
result = BaseEngineSpec.get_oauth2_token(config, "auth-code", code_verifier)
|
||||
|
||||
assert result["access_token"] == "test-access-token" # noqa: S105
|
||||
# Verify code_verifier IS in the request body (PKCE)
|
||||
call_kwargs = mock_post.call_args
|
||||
request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
|
||||
assert request_body["code_verifier"] == code_verifier
|
||||
|
||||
|
||||
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None:
|
||||
"""
|
||||
@@ -904,6 +1022,8 @@ def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
},
|
||||
)
|
||||
mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||
mocker.patch("superset.db_engine_specs.base.db")
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user.id = 1
|
||||
@@ -944,6 +1064,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None
|
||||
"superset.db_engine_specs.base.url_for",
|
||||
return_value=fallback_uri,
|
||||
)
|
||||
mocker.patch("superset.daos.key_value.KeyValueDAO")
|
||||
mocker.patch("superset.db_engine_specs.base.db")
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user.id = 1
|
||||
|
||||
@@ -18,6 +18,7 @@
|
||||
|
||||
import json # noqa: TID251
|
||||
from unittest.mock import MagicMock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
from uuid import UUID
|
||||
|
||||
import pytest
|
||||
@@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
"superset.db_engine_specs.base.uuid4",
|
||||
return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
|
||||
)
|
||||
mocker.patch(
|
||||
"superset.db_engine_specs.base.generate_code_verifier",
|
||||
return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ",
|
||||
)
|
||||
mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries")
|
||||
mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry")
|
||||
mocker.patch("superset.db_engine_specs.base.db.session.commit")
|
||||
|
||||
g = mocker.patch("superset.db_engine_specs.base.g")
|
||||
g.user = mocker.MagicMock()
|
||||
@@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, app) -> None:
|
||||
mocker.patch("superset.sql_lab.get_query", return_value=query)
|
||||
|
||||
payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
|
||||
assert payload == {
|
||||
"status": QueryStatus.FAILED,
|
||||
"error": "You don't have permission to access the data.",
|
||||
"errors": [
|
||||
{
|
||||
"message": "You don't have permission to access the data.",
|
||||
"error_type": SupersetErrorType.OAUTH2_REDIRECT,
|
||||
"level": ErrorLevel.WARNING,
|
||||
"extra": {
|
||||
"url": "https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocalhost%2Fapi%2Fv1%2Fdatabase%2Foauth2%2F&client_id=my_client_id",
|
||||
"tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
|
||||
"redirect_uri": "http://localhost/api/v1/database/oauth2/",
|
||||
},
|
||||
}
|
||||
],
|
||||
}
|
||||
assert payload["status"] == QueryStatus.FAILED
|
||||
assert payload["error"] == "You don't have permission to access the data."
|
||||
assert len(payload["errors"]) == 1
|
||||
|
||||
error = payload["errors"][0]
|
||||
assert error["message"] == "You don't have permission to access the data."
|
||||
assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT
|
||||
assert error["level"] == ErrorLevel.WARNING
|
||||
assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187"
|
||||
assert error["extra"]["redirect_uri"] == "http://localhost/api/v1/database/oauth2/"
|
||||
|
||||
# Parse the OAuth2 authorization URL and verify components individually,
|
||||
# since the JWT state and PKCE code_challenge are computed deterministically
|
||||
# from mocked inputs but their exact encoding depends on library internals.
|
||||
url = urlparse(error["extra"]["url"])
|
||||
assert url.scheme == "https"
|
||||
assert url.netloc == "abcd1234.snowflakecomputing.com"
|
||||
assert url.path == "/oauth/authorize"
|
||||
|
||||
params = parse_qs(url.query)
|
||||
assert params["scope"] == ["refresh_token session:role:USERADMIN"]
|
||||
assert params["response_type"] == ["code"]
|
||||
assert params["redirect_uri"] == ["http://localhost/api/v1/database/oauth2/"]
|
||||
assert params["client_id"] == ["my_client_id"]
|
||||
assert params["code_challenge_method"] == ["S256"]
|
||||
|
||||
# Verify PKCE code_challenge matches the mocked code_verifier
|
||||
from superset.utils.oauth2 import generate_code_challenge
|
||||
|
||||
expected_code_challenge = generate_code_challenge(
|
||||
"xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ"
|
||||
)
|
||||
assert params["code_challenge"] == [expected_code_challenge]
|
||||
|
||||
|
||||
def test_apply_rls(mocker: MockerFixture) -> None:
|
||||
|
||||
@@ -17,6 +17,8 @@
|
||||
|
||||
# pylint: disable=invalid-name, disallowed-name
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
|
||||
@@ -25,7 +27,14 @@ from freezegun import freeze_time
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from superset.superset_typing import OAuth2ClientConfig
|
||||
from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
|
||||
from superset.utils.oauth2 import (
|
||||
decode_oauth2_state,
|
||||
encode_oauth2_state,
|
||||
generate_code_challenge,
|
||||
generate_code_verifier,
|
||||
get_oauth2_access_token,
|
||||
refresh_oauth2_token,
|
||||
)
|
||||
|
||||
DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
|
||||
|
||||
@@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response(
|
||||
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_generate_code_verifier_length() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier produces a string of valid length (RFC 7636).
|
||||
"""
|
||||
code_verifier = generate_code_verifier()
|
||||
# RFC 7636 requires 43-128 characters
|
||||
assert 43 <= len(code_verifier) <= 128
|
||||
|
||||
|
||||
def test_generate_code_verifier_uniqueness() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier produces unique values.
|
||||
"""
|
||||
verifiers = {generate_code_verifier() for _ in range(100)}
|
||||
# All generated verifiers should be unique
|
||||
assert len(verifiers) == 100
|
||||
|
||||
|
||||
def test_generate_code_verifier_valid_characters() -> None:
|
||||
"""
|
||||
Test that generate_code_verifier only uses valid characters (RFC 7636).
|
||||
"""
|
||||
code_verifier = generate_code_verifier()
|
||||
# RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
|
||||
# URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_"
|
||||
valid_chars = set(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
|
||||
)
|
||||
assert all(char in valid_chars for char in code_verifier)
|
||||
|
||||
|
||||
def test_generate_code_challenge_s256() -> None:
|
||||
"""
|
||||
Test that generate_code_challenge produces correct S256 challenge.
|
||||
"""
|
||||
# Use a known code_verifier to verify the challenge computation
|
||||
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
|
||||
# Compute expected challenge manually
|
||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
|
||||
code_challenge = generate_code_challenge(code_verifier)
|
||||
assert code_challenge == expected_challenge
|
||||
|
||||
|
||||
def test_generate_code_challenge_rfc_example() -> None:
|
||||
"""
|
||||
Test PKCE code challenge against RFC 7636 Appendix B example.
|
||||
|
||||
See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
|
||||
"""
|
||||
# RFC 7636 example code_verifier (Appendix B)
|
||||
code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
|
||||
# RFC 7636 expected code_challenge for S256 method
|
||||
expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
|
||||
|
||||
code_challenge = generate_code_challenge(code_verifier)
|
||||
assert code_challenge == expected_challenge
|
||||
|
||||
|
||||
def test_encode_decode_oauth2_state(
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
"""
|
||||
Test that encode/decode cycle preserves state fields.
|
||||
"""
|
||||
from superset.superset_typing import OAuth2State
|
||||
|
||||
mocker.patch(
|
||||
"flask.current_app.config",
|
||||
{
|
||||
"SECRET_KEY": "test-secret-key",
|
||||
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
|
||||
},
|
||||
)
|
||||
|
||||
state: OAuth2State = {
|
||||
"database_id": 1,
|
||||
"user_id": 2,
|
||||
"default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
|
||||
"tab_id": "test-tab-id",
|
||||
}
|
||||
|
||||
with freeze_time("2024-01-01"):
|
||||
encoded = encode_oauth2_state(state)
|
||||
decoded = decode_oauth2_state(encoded)
|
||||
|
||||
assert "code_verifier" not in decoded
|
||||
assert decoded["database_id"] == 1
|
||||
assert decoded["user_id"] == 2
|
||||
|
||||
Reference in New Issue
Block a user