mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)
This commit is contained in:
@@ -17,7 +17,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
import secrets
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, Iterator, TYPE_CHECKING
|
||||
@@ -40,6 +43,37 @@ JWT_EXPIRATION = timedelta(minutes=5)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# PKCE code verifier length (RFC 7636 recommends 43-128 characters)
|
||||
PKCE_CODE_VERIFIER_LENGTH = 64
|
||||
|
||||
|
||||
def generate_code_verifier() -> str:
|
||||
"""
|
||||
Generate a PKCE code verifier (RFC 7636).
|
||||
|
||||
The code verifier is a high-entropy cryptographic random string using
|
||||
unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~",
|
||||
with a minimum length of 43 characters and a maximum length of 128.
|
||||
"""
|
||||
# Generate random bytes and encode as URL-safe base64
|
||||
random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH)
|
||||
# Use URL-safe base64 encoding without padding
|
||||
code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii")
|
||||
return code_verifier
|
||||
|
||||
|
||||
def generate_code_challenge(code_verifier: str) -> str:
|
||||
"""
|
||||
Generate a PKCE code challenge from a code verifier (RFC 7636).
|
||||
|
||||
Uses the S256 method: BASE64URL(SHA256(code_verifier))
|
||||
"""
|
||||
# Compute SHA-256 hash of the code verifier
|
||||
digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
|
||||
# Encode as URL-safe base64 without padding
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
|
||||
return code_challenge
|
||||
|
||||
|
||||
@backoff.on_exception(
|
||||
backoff.expo,
|
||||
@@ -140,13 +174,14 @@ def encode_oauth2_state(state: OAuth2State) -> str:
|
||||
"""
|
||||
Encode the OAuth2 state.
|
||||
"""
|
||||
payload = {
|
||||
payload: dict[str, Any] = {
|
||||
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
|
||||
"database_id": state["database_id"],
|
||||
"user_id": state["user_id"],
|
||||
"default_redirect_uri": state["default_redirect_uri"],
|
||||
"tab_id": state["tab_id"],
|
||||
}
|
||||
|
||||
encoded_state = jwt.encode(
|
||||
payload=payload,
|
||||
key=app.config["SECRET_KEY"],
|
||||
@@ -172,12 +207,12 @@ class OAuth2StateSchema(Schema):
|
||||
data: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> OAuth2State:
|
||||
return OAuth2State(
|
||||
database_id=data["database_id"],
|
||||
user_id=data["user_id"],
|
||||
default_redirect_uri=data["default_redirect_uri"],
|
||||
tab_id=data["tab_id"],
|
||||
)
|
||||
return {
|
||||
"database_id": data["database_id"],
|
||||
"user_id": data["user_id"],
|
||||
"default_redirect_uri": data["default_redirect_uri"],
|
||||
"tab_id": data["tab_id"],
|
||||
}
|
||||
|
||||
class Meta: # pylint: disable=too-few-public-methods
|
||||
# ignore `exp`
|
||||
|
||||
Reference in New Issue
Block a user