feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)

This commit is contained in:
Beto Dealmeida
2026-01-30 23:28:10 -05:00
committed by GitHub
parent 05c2354997
commit 5d20dc57d7
10 changed files with 422 additions and 38 deletions

View File

@@ -21,7 +21,7 @@ from __future__ import annotations
import logging
import re
import warnings
from datetime import datetime
from datetime import datetime, timedelta
from inspect import signature
from re import Match, Pattern
from typing import (
@@ -36,7 +36,7 @@ from typing import (
Union,
)
from urllib.parse import urlencode, urljoin
from uuid import uuid4
from uuid import UUID, uuid4
import pandas as pd
import requests
@@ -63,6 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.sql.parse import (
BaseSQLStatement,
LimitMethod,
@@ -83,7 +84,11 @@ from superset.utils.core import ColumnSpec, GenericDataType, QuerySource
from superset.utils.hashing import hash_from_str
from superset.utils.json import redact_sensitive, reveal_sensitive
from superset.utils.network import is_hostname_valid, is_port_open
from superset.utils.oauth2 import encode_oauth2_state
from superset.utils.oauth2 import (
encode_oauth2_state,
generate_code_challenge,
generate_code_verifier,
)
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -608,13 +613,38 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
tab sends a message to the original tab informing that authorization was
successful (or not), and then closes. The original tab will automatically
re-run the query after authorization.
PKCE (RFC 7636) is used to protect against authorization code interception
attacks. A code_verifier is generated and stored server-side in the KV store,
while the code_challenge (derived from the verifier) is sent to the
authorization server.
"""
# Prevent circular import.
from superset.daos.key_value import KeyValueDAO
tab_id = str(uuid4())
default_redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
# Generate PKCE code verifier (RFC 7636)
code_verifier = generate_code_verifier()
# Store the code_verifier server-side in the KV store, keyed by tab_id.
# This avoids exposing it in the URL/browser history via the JWT state.
KeyValueDAO.delete_expired_entries(KeyValueResource.PKCE_CODE_VERIFIER)
KeyValueDAO.create_entry(
resource=KeyValueResource.PKCE_CODE_VERIFIER,
value={"code_verifier": code_verifier},
codec=JsonKeyValueCodec(),
key=UUID(tab_id),
expires_on=datetime.now() + timedelta(minutes=5),
)
# We need to commit here because we're going to raise an exception, which will
# revert any uncommitted changes.
db.session.commit()
# The state is passed to the OAuth2 provider, and sent back to Superset after
# the user authorizes the access. The redirect endpoint in Superset can then
# inspect the state to figure out to which user/database the access token
@@ -641,7 +671,11 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")
oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
oauth_url = cls.get_oauth2_authorization_uri(
oauth2_config,
state,
code_verifier=code_verifier,
)
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
@@ -685,21 +719,29 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
config: OAuth2ClientConfig,
state: OAuth2State,
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request.
Uses standard OAuth 2.0 parameters only. Subclasses can override
to add provider-specific parameters (e.g., Google's prompt=consent).
Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters.
Subclasses can override to add provider-specific parameters
(e.g., Google's prompt=consent).
"""
uri = config["authorization_request_uri"]
params = {
params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
"redirect_uri": config["redirect_uri"],
"client_id": config["id"],
}
# Add PKCE parameters (RFC 7636) if code_verifier is provided
if code_verifier:
params["code_challenge"] = generate_code_challenge(code_verifier)
params["code_challenge_method"] = "S256"
return urljoin(uri, "?" + urlencode(params))
@classmethod
@@ -707,19 +749,27 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
cls,
config: OAuth2ClientConfig,
code: str,
code_verifier: str | None = None,
) -> OAuth2TokenResponse:
"""
Exchange authorization code for refresh/access tokens.
If code_verifier is provided (PKCE flow), it will be included in the
token request per RFC 7636.
"""
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
req_body = {
req_body: dict[str, str] = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
# Add PKCE code_verifier if present (RFC 7636)
if code_verifier:
req_body["code_verifier"] = code_verifier
response = (
requests.post(uri, data=req_body, timeout=timeout)
if config["request_content_type"] == "data"

View File

@@ -161,6 +161,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request with Google-specific parameters.
@@ -172,10 +173,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"""
from urllib.parse import urlencode, urljoin
from superset.utils.oauth2 import encode_oauth2_state
from superset.utils.oauth2 import encode_oauth2_state, generate_code_challenge
uri = config["authorization_request_uri"]
params = {
params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
@@ -186,6 +187,12 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"include_granted_scopes": "false",
"prompt": "consent",
}
# Add PKCE parameters (RFC 7636) if code_verifier is provided
if code_verifier:
params["code_challenge"] = generate_code_challenge(code_verifier)
params["code_challenge_method"] = "S256"
return urljoin(uri, "?" + urlencode(params))
@classmethod