mirror of
https://github.com/apache/superset.git
synced 2026-05-14 20:35:23 +00:00
Compare commits
17 Commits
embedded-e
...
amin/mcp-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3e997cd5c2 | ||
|
|
c1fcabbc55 | ||
|
|
fd80f76661 | ||
|
|
5bb315591b | ||
|
|
d8db1c9230 | ||
|
|
202b19951a | ||
|
|
d675a97686 | ||
|
|
ba97a29468 | ||
|
|
637f74d0d8 | ||
|
|
5649b28495 | ||
|
|
aaabec317e | ||
|
|
9ece5f42d7 | ||
|
|
76ad5e1bf7 | ||
|
|
38b9ea5484 | ||
|
|
135579b8e1 | ||
|
|
b9aee62f5f | ||
|
|
7ad0b5e3f8 |
@@ -49,7 +49,9 @@ from contextlib import AbstractContextManager
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import g, has_request_context
|
||||
from flask_appbuilder.security.sqla.models import Group, User
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
@@ -148,23 +150,14 @@ def check_tool_permission(func: Callable[..., Any]) -> bool:
|
||||
def load_user_with_relationships(
|
||||
username: str | None = None, email: str | None = None
|
||||
) -> User | None:
|
||||
"""
|
||||
Load a user with all relationships needed for permission checks.
|
||||
"""Load a user with roles and group roles eagerly loaded.
|
||||
|
||||
This function eagerly loads User.roles, User.groups, and Group.roles
|
||||
to prevent detached instance errors when the session is closed/rolled back.
|
||||
|
||||
IMPORTANT: Always use this function instead of security_manager.find_user()
|
||||
when loading users for MCP tool execution. The find_user() method doesn't
|
||||
eagerly load Group.roles, causing "detached instance" errors when permission
|
||||
checks access group.roles after the session is rolled back.
|
||||
|
||||
Args:
|
||||
username: The username to look up (optional if email provided)
|
||||
email: The email to look up (optional if username provided)
|
||||
|
||||
Returns:
|
||||
User object with relationships loaded, or None if not found
|
||||
Delegates to :meth:`SupersetSecurityManager.find_user_with_relationships`,
|
||||
which mirrors FAB's ``find_user`` (including ``auth_username_ci`` and
|
||||
``MultipleResultsFound`` handling) while adding eager loading of
|
||||
``User.roles`` and ``User.groups.roles`` to prevent detached-instance
|
||||
errors when the SQLAlchemy session is closed or rolled back after the
|
||||
lookup — as happens in MCP tool-execution contexts.
|
||||
|
||||
Raises:
|
||||
ValueError: If neither username nor email is provided
|
||||
@@ -172,21 +165,9 @@ def load_user_with_relationships(
|
||||
if not username and not email:
|
||||
raise ValueError("Either username or email must be provided")
|
||||
|
||||
from sqlalchemy.orm import joinedload
|
||||
from superset import security_manager
|
||||
|
||||
from superset.extensions import db
|
||||
|
||||
query = db.session.query(User).options(
|
||||
joinedload(User.roles),
|
||||
joinedload(User.groups).joinedload(Group.roles),
|
||||
)
|
||||
|
||||
if username:
|
||||
query = query.filter(User.username == username)
|
||||
else:
|
||||
query = query.filter(User.email == email)
|
||||
|
||||
return query.first()
|
||||
return security_manager.find_user_with_relationships(username=username, email=email)
|
||||
|
||||
|
||||
def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
@@ -218,6 +199,25 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
if access_token is None:
|
||||
return None
|
||||
|
||||
# API key pass-through: CompositeTokenVerifier accepted this token
|
||||
# at the transport layer but defers actual validation to
|
||||
# _resolve_user_from_api_key() (priority 2 in get_user_from_request).
|
||||
# Require client_id=="api_key" (set by CompositeTokenVerifier) in addition
|
||||
# to the claim so that an external IdP JWT that happens to include the
|
||||
# claim name is not misclassified as an API-key pass-through.
|
||||
claims = getattr(access_token, "claims", None)
|
||||
if isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM):
|
||||
if getattr(access_token, "client_id", None) == "api_key":
|
||||
logger.debug(
|
||||
"API key pass-through token detected, deferring to API key auth"
|
||||
)
|
||||
return None
|
||||
logger.debug(
|
||||
"Ignoring %s claim on non-API-key token (client_id=%r); processing as JWT",
|
||||
API_KEY_PASSTHROUGH_CLAIM,
|
||||
getattr(access_token, "client_id", None),
|
||||
)
|
||||
|
||||
# Use configurable resolver or default
|
||||
from superset.mcp_service.mcp_config import default_user_resolver
|
||||
|
||||
@@ -248,37 +248,57 @@ def _resolve_user_from_jwt_context(app: Any) -> User | None:
|
||||
|
||||
def _resolve_user_from_api_key(app: Any) -> User | None:
|
||||
"""
|
||||
Resolve the current user from an API key in the Authorization header.
|
||||
Resolve the current user from an API key passed via Bearer token.
|
||||
|
||||
Uses FAB SecurityManager's API key validation. Only attempts when
|
||||
FAB_API_KEY_ENABLED is True and a request context is active.
|
||||
Reads the token from FastMCP's per-request ``AccessToken`` (set by
|
||||
``CompositeTokenVerifier`` when a Bearer token matches an API key
|
||||
prefix). The streamable-http transport does not push a Flask request
|
||||
context, so we cannot rely on ``flask.request`` headers — the verifier
|
||||
already saw the token and stashed it on the ``AccessToken``.
|
||||
|
||||
Returns:
|
||||
User object with relationships loaded, or None if no API key present
|
||||
or API key auth is not enabled/available.
|
||||
User object with relationships loaded, or None if no API key
|
||||
pass-through token is present or API key auth is not enabled.
|
||||
|
||||
Raises:
|
||||
PermissionError: If an API key is present but invalid/expired,
|
||||
or if validation is not available in this FAB version.
|
||||
PermissionError: If an API key pass-through token is present but
|
||||
invalid/expired (fail closed — do NOT fall through to weaker
|
||||
auth sources like ``MCP_DEV_USERNAME``), or if validation is
|
||||
not available in this FAB version.
|
||||
"""
|
||||
if not app.config.get("FAB_API_KEY_ENABLED", False) or not has_request_context():
|
||||
if not app.config.get("FAB_API_KEY_ENABLED", False):
|
||||
return None
|
||||
|
||||
try:
|
||||
from fastmcp.server.dependencies import get_access_token
|
||||
except ImportError:
|
||||
logger.debug("fastmcp.server.dependencies not available, skipping API key auth")
|
||||
return None
|
||||
|
||||
access_token = get_access_token()
|
||||
if access_token is None:
|
||||
return None
|
||||
|
||||
# Only validate tokens that the CompositeTokenVerifier flagged as
|
||||
# API key pass-throughs. Plain JWTs were already validated by the JWT
|
||||
# verifier and resolved in _resolve_user_from_jwt_context.
|
||||
claims = getattr(access_token, "claims", None)
|
||||
if not (isinstance(claims, dict) and claims.get(API_KEY_PASSTHROUGH_CLAIM)):
|
||||
return None
|
||||
# Defense-in-depth: require client_id=="api_key" (set by CompositeTokenVerifier)
|
||||
# to guard against rogue external IdP JWTs that include the passthrough claim.
|
||||
if getattr(access_token, "client_id", None) != "api_key":
|
||||
return None
|
||||
|
||||
api_key_string = getattr(access_token, "token", None)
|
||||
if not api_key_string:
|
||||
# Passthrough claim is set but the raw token is absent — fail closed
|
||||
# rather than silently falling through to weaker auth sources.
|
||||
raise PermissionError(
|
||||
"API key pass-through token is missing the raw token value."
|
||||
)
|
||||
|
||||
sm = app.appbuilder.sm
|
||||
# extract_api_key_from_request is FAB's method for reading
|
||||
# the Bearer token from the Authorization header and matching prefixes.
|
||||
# Not all FAB versions include this method, so guard with hasattr.
|
||||
if not hasattr(sm, "extract_api_key_from_request"):
|
||||
logger.debug(
|
||||
"FAB SecurityManager does not have extract_api_key_from_request; "
|
||||
"API key authentication is not available in this FAB version"
|
||||
)
|
||||
return None
|
||||
|
||||
api_key_string = sm.extract_api_key_from_request()
|
||||
if api_key_string is None:
|
||||
return None
|
||||
|
||||
if not hasattr(sm, "validate_api_key"):
|
||||
logger.warning(
|
||||
"FAB SecurityManager does not have validate_api_key; "
|
||||
@@ -445,7 +465,6 @@ def _setup_user_context() -> User | None:
|
||||
# tool calls when no per-request middleware refreshes it.
|
||||
# Only clear in app-context-only mode; preserve g.user when
|
||||
# a request context is active (external middleware set it).
|
||||
from flask import has_request_context
|
||||
|
||||
if not has_request_context():
|
||||
g.pop("user", None)
|
||||
@@ -489,7 +508,7 @@ def _setup_user_context() -> User | None:
|
||||
logger.error("DB connection failed on retry during user setup: %s", e)
|
||||
_cleanup_session_on_error()
|
||||
raise
|
||||
except ValueError as e:
|
||||
except (ValueError, PermissionError) as e:
|
||||
# User resolution failed — fail closed. Do not fall back to
|
||||
# g.user from middleware, as that could allow a request to
|
||||
# proceed as a different user in multi-tenant deployments.
|
||||
@@ -535,7 +554,7 @@ def mcp_auth_hook(tool_func: F) -> F: # noqa: C901
|
||||
import inspect
|
||||
import types
|
||||
|
||||
from flask import current_app, has_app_context, has_request_context
|
||||
from flask import current_app, has_app_context
|
||||
|
||||
def _get_app_context_manager() -> AbstractContextManager[None]:
|
||||
"""Push a fresh app context unless a request context is active.
|
||||
|
||||
113
superset/mcp_service/composite_token_verifier.py
Normal file
113
superset/mcp_service/composite_token_verifier.py
Normal file
@@ -0,0 +1,113 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Composite token verifier for MCP authentication.
|
||||
|
||||
Routes Bearer tokens to the appropriate verifier based on prefix:
|
||||
- Tokens matching FAB_API_KEY_PREFIXES (e.g. ``sst_``) are passed through
|
||||
to the Flask layer where ``_resolve_user_from_api_key()`` handles
|
||||
actual validation via FAB SecurityManager.
|
||||
- All other tokens are delegated to the wrapped JWT verifier (when one is
|
||||
configured); when no JWT verifier is configured, non-API-key tokens are
|
||||
rejected at the transport layer.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp.server.auth import AccessToken
|
||||
from fastmcp.server.auth.providers.jwt import TokenVerifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Namespaced claim that flags an AccessToken as an API-key pass-through.
|
||||
# Namespacing avoids collision with custom claims an external IdP might
|
||||
# happen to mint on a JWT — a plain ``_api_key_passthrough`` claim could
|
||||
# be silently misidentified as a Superset API-key request.
|
||||
API_KEY_PASSTHROUGH_CLAIM = "_superset_mcp_api_key_passthrough"
|
||||
|
||||
|
||||
class CompositeTokenVerifier(TokenVerifier):
|
||||
"""Routes Bearer tokens between API key pass-through and JWT verification.
|
||||
|
||||
API key tokens (identified by prefix) are accepted at the transport layer
|
||||
with a marker claim so that ``_resolve_user_from_jwt_context()`` can
|
||||
detect them and fall through to ``_resolve_user_from_api_key()`` for
|
||||
actual validation.
|
||||
|
||||
Args:
|
||||
jwt_verifier: The wrapped JWT verifier for non-API-key tokens.
|
||||
When ``None``, only API-key tokens are accepted; all other
|
||||
Bearer tokens are rejected at the transport layer (used when
|
||||
``MCP_AUTH_ENABLED=False`` but ``FAB_API_KEY_ENABLED=True``).
|
||||
api_key_prefixes: List of prefixes that identify API key tokens
|
||||
(e.g. ``["sst_"]``).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
jwt_verifier: TokenVerifier | None,
|
||||
api_key_prefixes: list[str],
|
||||
) -> None:
|
||||
super().__init__(
|
||||
base_url=getattr(jwt_verifier, "base_url", None),
|
||||
required_scopes=getattr(jwt_verifier, "required_scopes", None) or [],
|
||||
)
|
||||
self._jwt_verifier = jwt_verifier
|
||||
valid: list[str] = [
|
||||
p for p in api_key_prefixes if isinstance(p, str) and p.strip()
|
||||
]
|
||||
invalid = [p for p in api_key_prefixes if p not in valid]
|
||||
if invalid:
|
||||
logger.warning(
|
||||
"FAB_API_KEY_PREFIXES contains invalid entries (ignored): %r", invalid
|
||||
)
|
||||
self._api_key_prefixes = tuple(valid)
|
||||
|
||||
async def verify_token(self, token: str) -> AccessToken | None:
|
||||
"""Verify a Bearer token.
|
||||
|
||||
If the token starts with an API key prefix, return a pass-through
|
||||
AccessToken with the namespaced ``API_KEY_PASSTHROUGH_CLAIM``
|
||||
(``_superset_mcp_api_key_passthrough``). The Flask-layer
|
||||
``_resolve_user_from_api_key()`` performs the real validation.
|
||||
|
||||
Otherwise, delegate to the wrapped JWT verifier when one is
|
||||
configured; if no JWT verifier is configured, reject the token.
|
||||
"""
|
||||
if any(token.startswith(prefix) for prefix in self._api_key_prefixes):
|
||||
logger.debug("API key token detected (prefix match), passing through")
|
||||
# Populate ``scopes`` from ``self.required_scopes`` so FastMCP's
|
||||
# ``RequireAuthMiddleware`` (transport-layer scope check) is
|
||||
# satisfied for API-key requests. Without this, MCP_REQUIRED_SCOPES
|
||||
# being non-empty would 403 every API-key call before
|
||||
# ``_resolve_user_from_api_key`` even runs.
|
||||
return AccessToken(
|
||||
token=token,
|
||||
client_id="api_key",
|
||||
scopes=list(self.required_scopes or []),
|
||||
claims={API_KEY_PASSTHROUGH_CLAIM: True},
|
||||
)
|
||||
|
||||
if self._jwt_verifier is None:
|
||||
logger.debug(
|
||||
"Bearer token does not match any API key prefix and no JWT "
|
||||
"verifier is configured; rejecting"
|
||||
)
|
||||
return None
|
||||
|
||||
return await self._jwt_verifier.verify_token(token)
|
||||
@@ -20,12 +20,15 @@ import logging
|
||||
import secrets
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
from flask import Flask
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import CompositeTokenVerifier
|
||||
from superset.mcp_service.constants import (
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_WARN_THRESHOLD_PCT,
|
||||
)
|
||||
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -284,56 +287,94 @@ MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
|
||||
|
||||
|
||||
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
|
||||
"""Default MCP auth factory using app.config values."""
|
||||
if not app.config.get("MCP_AUTH_ENABLED", False):
|
||||
"""Default MCP auth factory using app.config values.
|
||||
|
||||
Returns an auth provider when ``MCP_AUTH_ENABLED=True`` (JWT verifier,
|
||||
optionally wrapped with ``CompositeTokenVerifier`` for API keys) or
|
||||
when only ``FAB_API_KEY_ENABLED=True`` (API-key-only verifier that
|
||||
rejects all non-API-key Bearer tokens at the transport).
|
||||
"""
|
||||
auth_enabled = app.config.get("MCP_AUTH_ENABLED", False)
|
||||
api_key_enabled = app.config.get("FAB_API_KEY_ENABLED", False)
|
||||
|
||||
if not (auth_enabled or api_key_enabled):
|
||||
return None
|
||||
|
||||
jwks_uri = app.config.get("MCP_JWKS_URI")
|
||||
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
|
||||
secret = app.config.get("MCP_JWT_SECRET")
|
||||
jwt_verifier: Any | None = None
|
||||
|
||||
if not (jwks_uri or public_key or secret):
|
||||
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
|
||||
return None
|
||||
if auth_enabled:
|
||||
jwks_uri = app.config.get("MCP_JWKS_URI")
|
||||
public_key = app.config.get("MCP_JWT_PUBLIC_KEY")
|
||||
secret = app.config.get("MCP_JWT_SECRET")
|
||||
|
||||
try:
|
||||
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
|
||||
|
||||
common_kwargs: dict[str, Any] = {
|
||||
"issuer": app.config.get("MCP_JWT_ISSUER"),
|
||||
"audience": app.config.get("MCP_JWT_AUDIENCE"),
|
||||
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
|
||||
}
|
||||
|
||||
# For HS256 (symmetric), use the secret as the public_key parameter
|
||||
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
|
||||
common_kwargs["public_key"] = secret
|
||||
common_kwargs["algorithm"] = "HS256"
|
||||
if not (jwks_uri or public_key or secret):
|
||||
logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured")
|
||||
if not api_key_enabled:
|
||||
return None
|
||||
else:
|
||||
# For RS256 (asymmetric), use public key or JWKS
|
||||
common_kwargs["jwks_uri"] = jwks_uri
|
||||
common_kwargs["public_key"] = public_key
|
||||
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
|
||||
try:
|
||||
jwt_verifier = _build_jwt_verifier(
|
||||
app=app,
|
||||
jwks_uri=jwks_uri,
|
||||
public_key=public_key,
|
||||
secret=secret,
|
||||
)
|
||||
except Exception: # noqa: BLE001 — JWT lib raises many types; broad catch intentional
|
||||
# Do not log the exception — it may contain secrets (e.g., key material)
|
||||
logger.error("Failed to create MCP JWT verifier")
|
||||
if not api_key_enabled:
|
||||
return None
|
||||
|
||||
if debug_errors:
|
||||
# DetailedJWTVerifier: detailed server-side logging of JWT
|
||||
# validation failures. HTTP responses are always generic per
|
||||
# RFC 6750 Section 3.1.
|
||||
from superset.mcp_service.jwt_verifier import DetailedJWTVerifier
|
||||
|
||||
auth_provider = DetailedJWTVerifier(**common_kwargs)
|
||||
if api_key_enabled:
|
||||
raw_prefixes = app.config.get("FAB_API_KEY_PREFIXES", ["sst_"])
|
||||
# Normalize: a plain string (e.g. "sst_") would iterate as characters;
|
||||
# wrap it in a list so CompositeTokenVerifier receives a proper sequence.
|
||||
if isinstance(raw_prefixes, str):
|
||||
api_key_prefixes = [raw_prefixes]
|
||||
else:
|
||||
# Default JWTVerifier: minimal logging, generic error responses.
|
||||
from fastmcp.server.auth.providers.jwt import JWTVerifier
|
||||
api_key_prefixes = list(raw_prefixes)
|
||||
logger.info("API key auth enabled for MCP")
|
||||
return CompositeTokenVerifier(
|
||||
jwt_verifier=jwt_verifier,
|
||||
api_key_prefixes=api_key_prefixes,
|
||||
)
|
||||
|
||||
auth_provider = JWTVerifier(**common_kwargs)
|
||||
return jwt_verifier
|
||||
|
||||
return auth_provider
|
||||
except Exception:
|
||||
# Do not log the exception — it may contain the HS256 secret
|
||||
# from common_kwargs["public_key"]
|
||||
logger.error("Failed to create MCP auth provider")
|
||||
return None
|
||||
|
||||
def _build_jwt_verifier(
|
||||
app: Flask,
|
||||
jwks_uri: Optional[str],
|
||||
public_key: Optional[str],
|
||||
secret: Optional[str],
|
||||
) -> Any:
|
||||
"""Construct the JWT verifier from configured keys/secret."""
|
||||
debug_errors = app.config.get("MCP_JWT_DEBUG_ERRORS", False)
|
||||
|
||||
common_kwargs: Dict[str, Any] = {
|
||||
"issuer": app.config.get("MCP_JWT_ISSUER"),
|
||||
"audience": app.config.get("MCP_JWT_AUDIENCE"),
|
||||
"required_scopes": app.config.get("MCP_REQUIRED_SCOPES", []),
|
||||
}
|
||||
|
||||
# For HS256 (symmetric), use the secret as the public_key parameter
|
||||
if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret:
|
||||
common_kwargs["public_key"] = secret
|
||||
common_kwargs["algorithm"] = "HS256"
|
||||
else:
|
||||
# For RS256 (asymmetric), use public key or JWKS
|
||||
common_kwargs["jwks_uri"] = jwks_uri
|
||||
common_kwargs["public_key"] = public_key
|
||||
common_kwargs["algorithm"] = app.config.get("MCP_JWT_ALGORITHM", "RS256")
|
||||
|
||||
if debug_errors:
|
||||
# DetailedJWTVerifier: detailed server-side logging of JWT
|
||||
# validation failures. HTTP responses are always generic per
|
||||
# RFC 6750 Section 3.1.
|
||||
return DetailedJWTVerifier(**common_kwargs)
|
||||
|
||||
# Default JWTVerifier: minimal logging, generic error responses.
|
||||
return JWTVerifier(**common_kwargs)
|
||||
|
||||
|
||||
def default_user_resolver(app: Any, access_token: Any) -> str | None:
|
||||
|
||||
@@ -429,13 +429,13 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
|
||||
if not getattr(g, "user", None):
|
||||
try:
|
||||
g.user = get_user_from_request()
|
||||
except ValueError:
|
||||
except (ValueError, PermissionError):
|
||||
return False
|
||||
|
||||
method_permission_name = getattr(tool_func, METHOD_PERMISSION_ATTR, "read")
|
||||
permission_name = f"{PERMISSION_PREFIX}{method_permission_name}"
|
||||
return security_manager.can_access(permission_name, class_permission_name)
|
||||
except (AttributeError, RuntimeError, ValueError):
|
||||
except (AttributeError, RuntimeError, ValueError, PermissionError):
|
||||
logger.debug("Could not evaluate tool search permission", exc_info=True)
|
||||
return False
|
||||
|
||||
@@ -673,7 +673,9 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
"""Create an auth provider from Flask app config.
|
||||
|
||||
Tries MCP_AUTH_FACTORY first, then falls back to the default factory
|
||||
when MCP_AUTH_ENABLED is True.
|
||||
when either ``MCP_AUTH_ENABLED`` (JWT auth) or ``FAB_API_KEY_ENABLED``
|
||||
(API key auth) is True. The default factory builds a
|
||||
``CompositeTokenVerifier`` that handles either or both auth modes.
|
||||
"""
|
||||
auth_provider = None
|
||||
if auth_factory := flask_app.config.get("MCP_AUTH_FACTORY"):
|
||||
@@ -686,7 +688,9 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
except Exception:
|
||||
# Do not log the exception — it may contain secrets
|
||||
logger.error("Failed to create auth provider from MCP_AUTH_FACTORY")
|
||||
elif flask_app.config.get("MCP_AUTH_ENABLED", False):
|
||||
elif flask_app.config.get("MCP_AUTH_ENABLED", False) or flask_app.config.get(
|
||||
"FAB_API_KEY_ENABLED", False
|
||||
):
|
||||
from superset.mcp_service.mcp_config import (
|
||||
create_default_mcp_auth_factory,
|
||||
)
|
||||
|
||||
@@ -49,7 +49,8 @@ from flask_login import AnonymousUserMixin, LoginManager
|
||||
from jwt.api_jwt import _jwt_global_obj
|
||||
from sqlalchemy import and_, inspect, or_
|
||||
from sqlalchemy.engine.base import Connection
|
||||
from sqlalchemy.orm import eagerload
|
||||
from sqlalchemy.orm import eagerload, joinedload
|
||||
from sqlalchemy.orm.exc import MultipleResultsFound
|
||||
from sqlalchemy.orm.mapper import Mapper
|
||||
from sqlalchemy.orm.query import Query as SqlaQuery
|
||||
from sqlalchemy.sql import exists
|
||||
@@ -401,6 +402,10 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
"PermissionViewMenu",
|
||||
"ViewMenu",
|
||||
"User",
|
||||
# FAB ApiKeyApi blueprint (active when FAB_API_KEY_ENABLED=True).
|
||||
# Listed unconditionally — harmless when the feature is off because
|
||||
# no PVMs exist under this view menu.
|
||||
"ApiKey",
|
||||
} | USER_MODEL_VIEWS
|
||||
|
||||
ALPHA_ONLY_VIEW_MENUS = {
|
||||
@@ -2862,6 +2867,60 @@ class SupersetSecurityManager( # pylint: disable=too-many-public-methods
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
def find_user_with_relationships(
|
||||
self,
|
||||
username: Optional[str] = None,
|
||||
email: Optional[str] = None,
|
||||
) -> Optional[User]:
|
||||
"""Find a user with roles and group roles eagerly loaded.
|
||||
|
||||
Mirrors FAB's ``SecurityManager.find_user``
|
||||
(including ``auth_username_ci`` case-insensitive handling and
|
||||
``MultipleResultsFound`` guard) and additionally eager-loads
|
||||
``User.roles`` and ``User.groups.roles`` to prevent detached-instance
|
||||
errors when the SQLAlchemy session is closed or rolled back after the
|
||||
lookup — as happens in MCP tool-execution contexts.
|
||||
"""
|
||||
eager = [
|
||||
joinedload(self.user_model.roles),
|
||||
joinedload(self.user_model.groups).joinedload("roles"),
|
||||
]
|
||||
if username:
|
||||
try:
|
||||
if self.auth_username_ci:
|
||||
from sqlalchemy import func as sa_func
|
||||
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter(
|
||||
sa_func.lower(self.user_model.username)
|
||||
== sa_func.lower(username)
|
||||
)
|
||||
.one_or_none()
|
||||
)
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter(self.user_model.username == username)
|
||||
.one_or_none()
|
||||
)
|
||||
except MultipleResultsFound:
|
||||
logger.error("Multiple results found for user %s", username)
|
||||
return None
|
||||
if email:
|
||||
try:
|
||||
return (
|
||||
self.session.query(self.user_model)
|
||||
.options(*eager)
|
||||
.filter_by(email=email)
|
||||
.one_or_none()
|
||||
)
|
||||
except MultipleResultsFound:
|
||||
logger.error("Multiple results found for user with email %s", email)
|
||||
return None
|
||||
return None
|
||||
|
||||
def get_anonymous_user(self) -> User:
|
||||
return AnonymousUserMixin()
|
||||
|
||||
|
||||
@@ -15,25 +15,55 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for API key authentication in get_user_from_request()."""
|
||||
"""Tests for API key authentication in get_user_from_request().
|
||||
|
||||
The streamable-http transport does not push a Flask request context, so
|
||||
``_resolve_user_from_api_key`` reads the token from FastMCP's per-request
|
||||
``AccessToken`` (populated by ``CompositeTokenVerifier``) rather than from
|
||||
``flask.request``. These tests mock ``get_access_token`` accordingly.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask import g
|
||||
|
||||
from superset.mcp_service.auth import get_user_from_request
|
||||
from superset.app import SupersetApp
|
||||
from superset.mcp_service.auth import (
|
||||
_resolve_user_from_jwt_context,
|
||||
get_user_from_request,
|
||||
)
|
||||
from superset.mcp_service.composite_token_verifier import API_KEY_PASSTHROUGH_CLAIM
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user():
|
||||
def mock_user() -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.username = "api_key_user"
|
||||
return user
|
||||
|
||||
|
||||
def _passthrough_access_token(token: str) -> MagicMock:
|
||||
"""Build an AccessToken matching what CompositeTokenVerifier emits."""
|
||||
access_token = MagicMock()
|
||||
access_token.token = token
|
||||
access_token.client_id = "api_key"
|
||||
access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
|
||||
return access_token
|
||||
|
||||
|
||||
def _patch_access_token(access_token: MagicMock | None):
|
||||
"""Patch get_access_token where _resolve_user_from_api_key imports it."""
|
||||
return patch(
|
||||
"fastmcp.server.dependencies.get_access_token",
|
||||
return_value=access_token,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _enable_api_keys(app):
|
||||
def _enable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
"""Enable FAB API key auth and clear MCP_DEV_USERNAME so the API key
|
||||
path is exercised instead of falling through to the dev-user fallback."""
|
||||
app.config["FAB_API_KEY_ENABLED"] = True
|
||||
@@ -45,7 +75,7 @@ def _enable_api_keys(app):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def _disable_api_keys(app):
|
||||
def _disable_api_keys(app: SupersetApp) -> Generator[None, None, None]:
|
||||
app.config["FAB_API_KEY_ENABLED"] = False
|
||||
old_dev = app.config.pop("MCP_DEV_USERNAME", None)
|
||||
yield
|
||||
@@ -54,24 +84,45 @@ def _disable_api_keys(app):
|
||||
app.config["MCP_DEV_USERNAME"] = old_dev
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _mock_sm_ctx(app: SupersetApp, mock_sm: MagicMock):
|
||||
"""Push an app context with g.user cleared and appbuilder.sm mocked."""
|
||||
with app.app_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
yield
|
||||
|
||||
|
||||
def _patch_load_user_not_found():
|
||||
"""Patch load_user_with_relationships to return None (user not found).
|
||||
|
||||
load_user_with_relationships delegates to the global security_manager
|
||||
(not app.appbuilder.sm), so tests that need the JWT path to raise
|
||||
ValueError("not found") must patch it directly at the module level.
|
||||
"""
|
||||
return patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
|
||||
# -- Valid API key -> user loaded --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_valid_api_key_returns_user(app, mock_user) -> None:
|
||||
"""A valid API key should authenticate and return the user."""
|
||||
def test_valid_api_key_returns_user(app: SupersetApp, mock_user: MagicMock) -> None:
|
||||
"""A valid API key pass-through token should authenticate and return the user."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
mock_sm.validate_api_key.return_value = mock_user
|
||||
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=mock_user,
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with (
|
||||
_patch_access_token(_passthrough_access_token("sst_abc123")),
|
||||
patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=mock_user,
|
||||
),
|
||||
):
|
||||
result = get_user_from_request()
|
||||
|
||||
@@ -79,75 +130,70 @@ def test_valid_api_key_returns_user(app, mock_user) -> None:
|
||||
mock_sm.validate_api_key.assert_called_once_with("sst_abc123")
|
||||
|
||||
|
||||
# -- Invalid API key -> PermissionError --
|
||||
# -- Invalid API key -> PermissionError (does not silently fall back) --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_invalid_api_key_raises(app) -> None:
|
||||
"""An invalid API key should raise PermissionError."""
|
||||
def test_invalid_api_key_raises(app: SupersetApp) -> None:
|
||||
"""An invalid API key pass-through token should raise PermissionError
|
||||
(fail closed — do NOT fall through to MCP_DEV_USERNAME)."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_bad_key"
|
||||
mock_sm.validate_api_key.return_value = None
|
||||
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_bad_key"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(PermissionError, match="Invalid or expired API key"):
|
||||
get_user_from_request()
|
||||
# The dangerous fallthrough scenario: dev username IS set, but the
|
||||
# request presented an invalid API key. The dev fallback must not
|
||||
# mask the rejection.
|
||||
app.config["MCP_DEV_USERNAME"] = "admin"
|
||||
try:
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_bad_key")):
|
||||
with pytest.raises(PermissionError, match="Invalid or expired API key"):
|
||||
get_user_from_request()
|
||||
finally:
|
||||
app.config.pop("MCP_DEV_USERNAME", None)
|
||||
|
||||
|
||||
# -- API key disabled -> falls through to next auth method --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_disable_api_keys")
|
||||
def test_api_key_disabled_skips_auth(app) -> None:
|
||||
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely."""
|
||||
def test_api_key_disabled_skips_auth(app: SupersetApp) -> None:
|
||||
"""When FAB_API_KEY_ENABLED is False, API key auth is skipped entirely
|
||||
even if an AccessToken is present."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
# Without API key auth or MCP_DEV_USERNAME, should raise ValueError
|
||||
# about no authenticated user (not about invalid API key)
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
# SecurityManager API key methods should never be called
|
||||
mock_sm.extract_api_key_from_request.assert_not_called()
|
||||
|
||||
|
||||
# -- No request context -> API key auth skipped --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_no_request_context_skips_api_key_auth(app) -> None:
|
||||
"""Without a request context, API key auth should be skipped
|
||||
(e.g., during MCP tool discovery with only an app context)."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
with app.app_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
# Explicitly mock has_request_context to False because the test
|
||||
# framework's app fixture may implicitly provide a request context.
|
||||
with patch("superset.mcp_service.auth.has_request_context", return_value=False):
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_abc123")):
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.extract_api_key_from_request.assert_not_called()
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- No AccessToken -> API key auth skipped --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_no_access_token_skips_api_key_auth(app: SupersetApp) -> None:
|
||||
"""Without a FastMCP AccessToken (e.g., MCP_AUTH_ENABLED=False and no
|
||||
auth provider installed), API key auth is skipped."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(None):
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- g.user fallback when no higher-priority auth succeeds --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_disable_api_keys")
|
||||
def test_g_user_fallback_when_no_jwt_or_api_key(app, mock_user) -> None:
|
||||
def test_g_user_fallback_when_no_jwt_or_api_key(
|
||||
app: SupersetApp, mock_user: MagicMock
|
||||
) -> None:
|
||||
"""When no JWT or API key auth succeeds and MCP_DEV_USERNAME is not set,
|
||||
g.user (set by external middleware) is used as fallback."""
|
||||
with app.test_request_context():
|
||||
@@ -158,89 +204,174 @@ def test_g_user_fallback_when_no_jwt_or_api_key(app, mock_user) -> None:
|
||||
assert result.username == "api_key_user"
|
||||
|
||||
|
||||
# -- FAB version without extract_api_key_from_request --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_fab_without_extract_method_skips_gracefully(app) -> None:
|
||||
"""If FAB SecurityManager lacks extract_api_key_from_request,
|
||||
API key auth should be skipped with a debug log, not crash."""
|
||||
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
|
||||
|
||||
with app.test_request_context():
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(ValueError, match="No authenticated user found"):
|
||||
get_user_from_request()
|
||||
|
||||
|
||||
# -- FAB version without validate_api_key --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_fab_without_validate_method_raises(app) -> None:
|
||||
"""If FAB has extract_api_key_from_request but not validate_api_key,
|
||||
should raise PermissionError about unavailable validation."""
|
||||
mock_sm = MagicMock(spec=["extract_api_key_from_request"])
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
def test_fab_without_validate_method_raises(app: SupersetApp) -> None:
|
||||
"""If FAB SecurityManager lacks validate_api_key, should raise
|
||||
PermissionError about unavailable validation."""
|
||||
mock_sm = MagicMock(spec=[]) # empty spec = no attributes
|
||||
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with pytest.raises(
|
||||
PermissionError, match="API key validation is not available"
|
||||
):
|
||||
get_user_from_request()
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(_passthrough_access_token("sst_abc123")):
|
||||
with pytest.raises(
|
||||
PermissionError, match="API key validation is not available"
|
||||
):
|
||||
get_user_from_request()
|
||||
|
||||
|
||||
# -- Relationship reload fallback --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_relationship_reload_failure_returns_original_user(app, mock_user) -> None:
|
||||
def test_relationship_reload_failure_returns_original_user(
|
||||
app: SupersetApp, mock_user: MagicMock
|
||||
) -> None:
|
||||
"""If load_user_with_relationships fails, the original user from
|
||||
validate_api_key should be returned as fallback."""
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.extract_api_key_from_request.return_value = "sst_abc123"
|
||||
mock_sm.validate_api_key.return_value = mock_user
|
||||
|
||||
with app.test_request_context(headers={"Authorization": "Bearer sst_abc123"}):
|
||||
g.user = None
|
||||
app.appbuilder = MagicMock()
|
||||
app.appbuilder.sm = mock_sm
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with (
|
||||
_patch_access_token(_passthrough_access_token("sst_abc123")),
|
||||
patch(
|
||||
"superset.mcp_service.auth.load_user_with_relationships",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = get_user_from_request()
|
||||
|
||||
assert result is mock_user
|
||||
|
||||
|
||||
# -- AccessToken without passthrough claim (plain JWT) -> skip API key auth --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_jwt_access_token_skips_api_key_auth(app: SupersetApp) -> None:
|
||||
"""When the AccessToken is a plain JWT (no API_KEY_PASSTHROUGH_CLAIM),
|
||||
API key auth is skipped — the JWT was already validated by the JWT
|
||||
verifier and resolved in _resolve_user_from_jwt_context."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
jwt_access_token = MagicMock()
|
||||
jwt_access_token.token = "eyJhbGciOiJIUzI1NiJ9.not-an-api-key" # noqa: S105
|
||||
jwt_access_token.claims = {"sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(jwt_access_token), _patch_load_user_not_found():
|
||||
# _resolve_user_from_jwt_context resolves "alice" from JWT claims
|
||||
# and raises ValueError because the username is not a real user.
|
||||
# We assert that _resolve_user_from_api_key did NOT short-circuit
|
||||
# to the API key path.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- API key pass-through detection in JWT context resolver --
|
||||
|
||||
|
||||
def test_jwt_context_with_api_key_passthrough_returns_none(app: SupersetApp) -> None:
|
||||
"""When CompositeTokenVerifier passes through an API key token,
|
||||
_resolve_user_from_jwt_context should detect the namespaced
|
||||
pass-through claim AND client_id=="api_key" and return None so
|
||||
get_user_from_request falls through to _resolve_user_from_api_key."""
|
||||
mock_access_token = MagicMock()
|
||||
mock_access_token.client_id = "api_key"
|
||||
mock_access_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True}
|
||||
|
||||
with patch(
|
||||
"fastmcp.server.dependencies.get_access_token",
|
||||
return_value=mock_access_token,
|
||||
):
|
||||
result = _resolve_user_from_jwt_context(app)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_namespaced_claim_without_api_key_client_id_is_ignored(
|
||||
app: SupersetApp,
|
||||
) -> None:
|
||||
"""An external IdP JWT that includes the namespaced API_KEY_PASSTHROUGH_CLAIM
|
||||
but does NOT have client_id=='api_key' must NOT divert into the API-key path.
|
||||
The client_id guard prevents misclassification / DoS for affected JWT users."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
rogue_token = MagicMock()
|
||||
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.idp_jwt_with_rogue_claim" # noqa: S105
|
||||
rogue_token.client_id = "some-idp-client"
|
||||
rogue_token.claims = {API_KEY_PASSTHROUGH_CLAIM: True, "sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(rogue_token), _patch_load_user_not_found():
|
||||
# JWT path resolves "alice" from claims and raises ValueError
|
||||
# because no such user exists.
|
||||
# validate_api_key must NOT be called — the rogue claim was ignored.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- Plain JWT with a colliding non-namespaced claim is NOT mistaken for API key --
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("_enable_api_keys")
|
||||
def test_unnamespaced_passthrough_claim_does_not_trigger_api_key_path(
|
||||
app: SupersetApp,
|
||||
) -> None:
|
||||
"""A JWT minted by an external IdP that happens to include a custom
|
||||
``_api_key_passthrough`` claim (legacy unnamespaced name) must NOT be
|
||||
treated as an API-key pass-through. Only the namespaced
|
||||
``API_KEY_PASSTHROUGH_CLAIM`` triggers the API-key path."""
|
||||
mock_sm = MagicMock()
|
||||
|
||||
rogue_token = MagicMock()
|
||||
rogue_token.token = "eyJhbGciOiJSUzI1NiJ9.rogue_jwt" # noqa: S105
|
||||
rogue_token.claims = {"_api_key_passthrough": True, "sub": "alice"}
|
||||
|
||||
with _mock_sm_ctx(app, mock_sm):
|
||||
with _patch_access_token(rogue_token), _patch_load_user_not_found():
|
||||
# JWT path resolves "alice" from claims and raises ValueError.
|
||||
# validate_api_key must NOT be called — the rogue claim was ignored.
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
get_user_from_request()
|
||||
|
||||
mock_sm.validate_api_key.assert_not_called()
|
||||
|
||||
|
||||
# -- SecurityManager method name regression test --
|
||||
|
||||
|
||||
def test_security_manager_has_expected_api_key_methods() -> None:
|
||||
"""Regression test: verify the SecurityManager method names referenced in
|
||||
auth._resolve_user_from_api_key() actually exist on the FAB SecurityManager
|
||||
class. This catches future renames before they silently break API key auth
|
||||
at runtime (SC-99414: _extract_api_key_from_request vs
|
||||
extract_api_key_from_request)."""
|
||||
from superset import security_manager
|
||||
def test_security_manager_has_expected_api_key_methods(app: SupersetApp) -> None:
|
||||
"""Regression test: verify the SecurityManager method name referenced in
|
||||
auth._resolve_user_from_api_key() actually exists on the FAB
|
||||
SecurityManager class. Catches future renames before they silently break
|
||||
API key auth at runtime (see PR #39437)."""
|
||||
with app.app_context():
|
||||
from superset import security_manager
|
||||
|
||||
sm = security_manager
|
||||
assert hasattr(sm, "extract_api_key_from_request"), (
|
||||
"FAB SecurityManager is missing 'extract_api_key_from_request'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
assert hasattr(sm, "validate_api_key"), (
|
||||
"FAB SecurityManager is missing 'validate_api_key'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
sm = security_manager
|
||||
assert hasattr(sm, "validate_api_key"), (
|
||||
"FAB SecurityManager is missing 'validate_api_key'. "
|
||||
"auth._resolve_user_from_api_key() references this method by name — "
|
||||
"update auth.py if the FAB API changed."
|
||||
)
|
||||
|
||||
|
||||
def test_security_manager_has_find_user_with_relationships(app: SupersetApp) -> None:
|
||||
"""Regression test: verify SupersetSecurityManager.find_user_with_relationships
|
||||
exists. load_user_with_relationships() in auth.py delegates to it — a rename
|
||||
or removal would silently break MCP user resolution at runtime."""
|
||||
with app.app_context():
|
||||
from superset import security_manager
|
||||
|
||||
assert hasattr(security_manager, "find_user_with_relationships"), (
|
||||
"SupersetSecurityManager is missing 'find_user_with_relationships'. "
|
||||
"auth.load_user_with_relationships() delegates to this method — "
|
||||
"update auth.py if the method was renamed or removed."
|
||||
)
|
||||
|
||||
@@ -285,7 +285,7 @@ def test_mcp_auth_hook_clears_stale_g_user(app) -> None:
|
||||
# framework's autouse app_context fixture may implicitly provide
|
||||
# a request context in some CI environments.
|
||||
with (
|
||||
patch("flask.has_request_context", return_value=False),
|
||||
patch("superset.mcp_service.auth.has_request_context", return_value=False),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
side_effect=lambda: _assert_cleared_then_return(),
|
||||
@@ -324,7 +324,7 @@ def test_mcp_auth_hook_clears_stale_g_user_async(app) -> None:
|
||||
with app.app_context():
|
||||
g.user = stale_user
|
||||
with (
|
||||
patch("flask.has_request_context", return_value=False),
|
||||
patch("superset.mcp_service.auth.has_request_context", return_value=False),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
side_effect=lambda: _assert_cleared_then_return(),
|
||||
|
||||
218
tests/unit_tests/mcp_service/test_composite_token_verifier.py
Normal file
218
tests/unit_tests/mcp_service/test_composite_token_verifier.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for CompositeTokenVerifier."""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastmcp.server.auth import AccessToken
|
||||
|
||||
from superset.mcp_service.composite_token_verifier import (
|
||||
API_KEY_PASSTHROUGH_CLAIM,
|
||||
CompositeTokenVerifier,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwt_verifier() -> MagicMock:
|
||||
verifier = MagicMock()
|
||||
verifier.required_scopes = []
|
||||
verifier.verify_token = AsyncMock()
|
||||
return verifier
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composite_verifier(mock_jwt_verifier: MagicMock) -> CompositeTokenVerifier:
|
||||
return CompositeTokenVerifier(
|
||||
jwt_verifier=mock_jwt_verifier,
|
||||
api_key_prefixes=["sst_", "pat_"],
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_token_returns_passthrough(
|
||||
composite_verifier: CompositeTokenVerifier,
|
||||
) -> None:
|
||||
"""Tokens matching an API key prefix return a pass-through AccessToken."""
|
||||
api_key = "sst_abc123secret" # noqa: S105
|
||||
result = await composite_verifier.verify_token(api_key)
|
||||
|
||||
assert result is not None
|
||||
assert result.token == api_key
|
||||
assert result.client_id == "api_key"
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_second_prefix_matches(
|
||||
composite_verifier: CompositeTokenVerifier,
|
||||
) -> None:
|
||||
"""All configured prefixes are checked, not just the first."""
|
||||
result = await composite_verifier.verify_token("pat_mytoken")
|
||||
|
||||
assert result is not None
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_jwt_token_delegates_to_wrapped_verifier(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""Non-API-key tokens are delegated to the wrapped JWT verifier."""
|
||||
jwt_token = "eyJhbGciOiJSUzI1NiJ9.jwt_payload" # noqa: S105
|
||||
jwt_result = AccessToken(
|
||||
token=jwt_token,
|
||||
client_id="oauth_client",
|
||||
scopes=["read"],
|
||||
claims={"sub": "user1"},
|
||||
)
|
||||
mock_jwt_verifier.verify_token.return_value = jwt_result
|
||||
|
||||
result = await composite_verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
|
||||
assert result is jwt_result
|
||||
mock_jwt_verifier.verify_token.assert_awaited_once_with(
|
||||
"eyJhbGciOiJSUzI1NiJ9.jwt_payload"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_jwt_returns_none(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""When the JWT verifier rejects a token, None is returned."""
|
||||
mock_jwt_verifier.verify_token.return_value = None
|
||||
|
||||
result = await composite_verifier.verify_token("not_a_valid_token")
|
||||
|
||||
assert result is None
|
||||
mock_jwt_verifier.verify_token.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_does_not_call_jwt_verifier(
|
||||
composite_verifier: CompositeTokenVerifier, mock_jwt_verifier: MagicMock
|
||||
) -> None:
|
||||
"""API key tokens bypass the JWT verifier entirely."""
|
||||
await composite_verifier.verify_token("sst_test_key")
|
||||
|
||||
mock_jwt_verifier.verify_token.assert_not_awaited()
|
||||
|
||||
|
||||
# -- API-key-only mode (no JWT verifier configured) --
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_only_mode_accepts_api_keys() -> None:
|
||||
"""When jwt_verifier is None, API key tokens are still passed through."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["sst_"])
|
||||
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
|
||||
assert result is not None
|
||||
assert result.claims.get(API_KEY_PASSTHROUGH_CLAIM) is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_only_mode_rejects_non_api_key_tokens() -> None:
|
||||
"""When jwt_verifier is None, non-API-key Bearer tokens are rejected at
|
||||
the transport instead of being silently accepted."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["sst_"])
|
||||
|
||||
result = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_string_prefix_is_filtered_out() -> None:
|
||||
"""An empty-string prefix would match every Bearer token (DoS vector).
|
||||
It must be silently dropped and never stored in _api_key_prefixes."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=[""])
|
||||
|
||||
assert "" not in verifier._api_key_prefixes
|
||||
# A plain JWT must NOT be misidentified as an API key.
|
||||
result = await verifier.verify_token("eyJhbGciOiJSUzI1NiJ9.jwt_payload")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_only_prefix_is_filtered_out() -> None:
|
||||
"""A whitespace-only prefix is also invalid and must be dropped."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=[" "])
|
||||
|
||||
assert " " not in verifier._api_key_prefixes
|
||||
result = await verifier.verify_token(" starts_with_spaces")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_non_string_prefix_is_filtered_out() -> None:
|
||||
"""Non-string entries (e.g. None, int) must not be stored and must not
|
||||
cause a TypeError during verify_token."""
|
||||
verifier = CompositeTokenVerifier(
|
||||
jwt_verifier=None,
|
||||
api_key_prefixes=[None, 42, "sst_"], # type: ignore[list-item]
|
||||
)
|
||||
|
||||
assert None not in verifier._api_key_prefixes
|
||||
assert 42 not in verifier._api_key_prefixes
|
||||
assert verifier._api_key_prefixes == ("sst_",)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_prefixes_emit_warning(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""Invalid prefix entries must trigger a logger.warning so operators can
|
||||
detect misconfiguration in FAB_API_KEY_PREFIXES."""
|
||||
import logging
|
||||
|
||||
logger_name = "superset.mcp_service.composite_token_verifier"
|
||||
with caplog.at_level(logging.WARNING, logger=logger_name):
|
||||
CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["", "sst_"])
|
||||
|
||||
assert any("invalid" in record.message.lower() for record in caplog.records)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_invalid_prefixes_accepts_no_api_keys() -> None:
|
||||
"""When all prefixes are invalid and filtered out, no token should match
|
||||
the API key path."""
|
||||
verifier = CompositeTokenVerifier(jwt_verifier=None, api_key_prefixes=["", " "])
|
||||
|
||||
assert verifier._api_key_prefixes == ()
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_api_key_passthrough_propagates_required_scopes() -> None:
|
||||
"""The pass-through AccessToken must carry the verifier's required_scopes
|
||||
so FastMCP's transport-level ``RequireAuthMiddleware`` does not 403 the
|
||||
request before ``_resolve_user_from_api_key`` runs."""
|
||||
jwt_verifier = MagicMock()
|
||||
jwt_verifier.required_scopes = ["read", "write"]
|
||||
jwt_verifier.verify_token = AsyncMock()
|
||||
|
||||
verifier = CompositeTokenVerifier(
|
||||
jwt_verifier=jwt_verifier, api_key_prefixes=["sst_"]
|
||||
)
|
||||
|
||||
result = await verifier.verify_token("sst_abc123")
|
||||
|
||||
assert result is not None
|
||||
assert result.scopes == ["read", "write"]
|
||||
Reference in New Issue
Block a user