fix(mcp): expose individual tool parameters when MCP_PARSE_REQUEST_ENABLED=False (#38714)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-18 11:38:22 +01:00
committed by GitHub
parent 834d2abe70
commit e02ca8871d
2 changed files with 525 additions and 83 deletions

View File

@@ -25,11 +25,12 @@ for input parameters, making MCP tools more flexible for different clients.
from __future__ import annotations
import asyncio
import inspect
import logging
from functools import wraps
from typing import Any, Callable, List, Type, TypeVar
from typing import Annotated, Any, Callable, List, Type, TypeVar
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
logger = logging.getLogger(__name__)
@@ -411,9 +412,12 @@ def parse_request(
the tool function. Also modifies the function's type annotations to accept
str | RequestModel to pass FastMCP validation.
Can be disabled by setting MCP_PARSE_REQUEST_ENABLED = False in config.
When disabled, string-to-model parsing is skipped but ctx injection and
signature stripping still apply.
Behavior depends on MCP_PARSE_REQUEST_ENABLED config (checked at decoration time):
- **When True (default):** Single `request: str | Model` param with string parsing
at runtime. Preserves the Claude Code double-serialization workaround.
- **When False (production):** Flattened wrapper exposing all Pydantic model fields
as individual function parameters. FastMCP generates proper per-field schemas.
See: https://github.com/anthropics/claude-code/issues/5504
@@ -444,87 +448,232 @@ def parse_request(
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
import types
def _maybe_parse(request: Any) -> Any:
if _is_parse_request_enabled():
try:
return parse_json_or_model(request, request_class, "request")
except ValidationError as e:
from fastmcp.exceptions import ToolError
details = []
for err in e.errors():
field = " -> ".join(str(loc) for loc in err["loc"])
details.append(f"{field}: {err['msg']}")
required_fields = [
f.alias or name
for name, f in request_class.model_fields.items()
if f.is_required()
]
raise ToolError(
f"Invalid request parameters: {'; '.join(details)}. "
f"Required fields for {request_class.__name__}: "
f"{', '.join(required_fields)}"
) from None
return request
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
return await func(_maybe_parse(request), ctx, *args, **kwargs)
wrapper = async_wrapper
else:
@wraps(func)
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
return func(_maybe_parse(request), ctx, *args, **kwargs)
wrapper = sync_wrapper
# Merge original function's __globals__ into wrapper's __globals__
# This allows get_type_hints() to resolve type annotations from the
# original module (e.g., Context from fastmcp)
# FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
new_wrapper = types.FunctionType(
wrapper.__code__, # type: ignore[attr-defined]
merged_globals,
wrapper.__name__,
wrapper.__defaults__, # type: ignore[attr-defined]
wrapper.__closure__, # type: ignore[attr-defined]
)
# Copy __dict__ but exclude __wrapped__
# NOTE: We intentionally do NOT preserve __wrapped__ here.
# Setting __wrapped__ causes inspect.signature() to follow the chain
# and find 'ctx' in the original function's signature, even after
# FastMCP's create_function_without_params removes it from annotations.
# This breaks Pydantic's TypeAdapter which expects signature params
# to match type_hints.
new_wrapper.__dict__.update(
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
)
new_wrapper.__module__ = wrapper.__module__
new_wrapper.__qualname__ = wrapper.__qualname__
# Copy docstring from original function (not wrapper, which has no docstring)
new_wrapper.__doc__ = func.__doc__
request_annotation = str | request_class
_apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
return new_wrapper
if _is_parse_request_enabled():
return _create_string_parsing_wrapper(func, request_class)
return _create_flattened_wrapper(func, request_class)
return decorator
def _create_string_parsing_wrapper(
func: Callable[..., Any],
request_class: Type[BaseModel],
) -> Callable[..., Any]:
"""Create a wrapper that accepts a single `request: str | Model` parameter.
This is the original parse_request behavior: at runtime, if the request
is a JSON string it gets parsed into the Pydantic model. Used when
MCP_PARSE_REQUEST_ENABLED is True.
"""
import types
def _maybe_parse(request: Any) -> Any:
if _is_parse_request_enabled():
try:
return parse_json_or_model(request, request_class, "request")
except ValidationError as e:
from fastmcp.exceptions import ToolError
details = []
for err in e.errors():
field = " -> ".join(str(loc) for loc in err["loc"])
details.append(f"{field}: {err['msg']}")
required_fields = [
f.alias or name
for name, f in request_class.model_fields.items()
if f.is_required()
]
raise ToolError(
f"Invalid request parameters: {'; '.join(details)}. "
f"Required fields for {request_class.__name__}: "
f"{', '.join(required_fields)}"
) from None
return request
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
return await func(_maybe_parse(request), ctx, *args, **kwargs)
wrapper = async_wrapper
else:
@wraps(func)
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
return func(_maybe_parse(request), ctx, *args, **kwargs)
wrapper = sync_wrapper
# Merge original function's __globals__ into wrapper's __globals__
# This allows get_type_hints() to resolve type annotations from the
# original module (e.g., Context from fastmcp)
# FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
new_wrapper = types.FunctionType(
wrapper.__code__, # type: ignore[attr-defined]
merged_globals,
wrapper.__name__,
wrapper.__defaults__, # type: ignore[attr-defined]
wrapper.__closure__, # type: ignore[attr-defined]
)
# Copy __dict__ but exclude __wrapped__
# NOTE: We intentionally do NOT preserve __wrapped__ here.
# Setting __wrapped__ causes inspect.signature() to follow the chain
# and find 'ctx' in the original function's signature, even after
# FastMCP's create_function_without_params removes it from annotations.
# This breaks Pydantic's TypeAdapter which expects signature params
# to match type_hints.
new_wrapper.__dict__.update(
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
)
new_wrapper.__module__ = wrapper.__module__
new_wrapper.__qualname__ = wrapper.__qualname__
# Copy docstring from original function (not wrapper, which has no docstring)
new_wrapper.__doc__ = func.__doc__
request_annotation = str | request_class
_apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
return new_wrapper
def _create_flattened_wrapper(
func: Callable[..., Any],
request_class: Type[BaseModel],
) -> Callable[..., Any]:
"""Create a wrapper that exposes individual Pydantic model fields as parameters.
Used when MCP_PARSE_REQUEST_ENABLED is False. Instead of a single opaque
``request`` parameter, the wrapper accepts each field from ``request_class``
as a keyword argument. FastMCP then generates a proper JSON schema with
individual properties for each field.
"""
import types
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(**kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
model = request_class.model_validate(kwargs)
return await func(model, ctx)
wrapper = async_wrapper
else:
@wraps(func)
def sync_wrapper(**kwargs: Any) -> Any:
from fastmcp.server.dependencies import get_context
ctx = get_context()
model = request_class.model_validate(kwargs)
return func(model, ctx)
wrapper = sync_wrapper
# Merge globals so get_type_hints() can resolve annotations from
# the original function's module.
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
new_wrapper = types.FunctionType(
wrapper.__code__, # type: ignore[attr-defined]
merged_globals,
wrapper.__name__,
wrapper.__defaults__, # type: ignore[attr-defined]
wrapper.__closure__, # type: ignore[attr-defined]
)
new_wrapper.__dict__.update(
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
)
new_wrapper.__module__ = wrapper.__module__
new_wrapper.__qualname__ = wrapper.__qualname__
new_wrapper.__doc__ = func.__doc__
_apply_flattened_signature(new_wrapper, func, request_class)
return new_wrapper
def _apply_flattened_signature(
wrapper: Any,
original_func: Callable[..., Any],
request_class: Type[BaseModel],
) -> None:
"""Build a signature from the Pydantic model's fields and apply it to *wrapper*.
Each field in ``request_class.model_fields`` becomes a keyword-only parameter.
Field descriptions and constraints (ge, le, min_length, max_length, pattern,
etc.) are preserved via ``Annotated[type, Field(description=...), ...]``
so that FastMCP propagates them into the generated JSON schema.
"""
params: list[inspect.Parameter] = []
annotations: dict[str, Any] = {}
for name, field_info in request_class.model_fields.items():
# Use alias as param name when set, otherwise field name.
# This ensures model_validate() always works: aliases are always
# accepted, while field names require populate_by_name=True.
param_name = field_info.alias if isinstance(field_info.alias, str) else name
# Determine the default value for this parameter.
# Check default_factory before default because Pydantic sets
# default to PydanticUndefined when default_factory is used.
if field_info.is_required():
default = inspect.Parameter.empty
elif field_info.default_factory is not None:
default = field_info.default_factory()
else:
default = field_info.default # covers None and explicit defaults
# Build Annotated type with description and all constraint metadata
# (ge, le, min_length, max_length, pattern, etc.) so FastMCP
# propagates them into the generated JSON schema.
base_annotation = field_info.annotation
metadata_markers: list[Any] = []
if field_info.description:
metadata_markers.append(Field(description=field_info.description))
# Forward constraint metadata from the original Pydantic field.
# Pydantic v2 stores constraints (Ge, Le, MinLen, MaxLen, etc.)
# as annotation objects in field_info.metadata.
if field_info.metadata:
metadata_markers.extend(field_info.metadata)
if metadata_markers:
# Dynamically construct Annotated[base_type, meta1, meta2, ...]
annotation: Any = Annotated.__class_getitem__( # type: ignore[attr-defined]
(base_annotation, *metadata_markers)
)
else:
annotation = base_annotation
params.append(
inspect.Parameter(
param_name,
kind=inspect.Parameter.KEYWORD_ONLY,
default=default,
annotation=annotation,
)
)
annotations[param_name] = annotation
# Preserve return annotation from the original function if present
orig_sig = inspect.signature(original_func)
wrapper.__signature__ = orig_sig.replace(
parameters=params,
return_annotation=orig_sig.return_annotation,
)
wrapper.__annotations__ = annotations
if orig_sig.return_annotation is not inspect.Parameter.empty:
wrapper.__annotations__["return"] = orig_sig.return_annotation
def _apply_signature_for_fastmcp(
wrapper: Any,
original_func: Callable[..., Any],

View File

@@ -19,8 +19,12 @@
Unit tests for MCP service schema utilities.
"""
import asyncio
import inspect
from typing import Annotated, List
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from superset.mcp_service.utils.schema_utils import (
JSONParseError,
@@ -564,3 +568,292 @@ class TestParseRequestDecorator:
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool(json_str)
assert result == "test:42"
class TestParseRequestFlattenedDecorator:
"""Test parse_request decorator when MCP_PARSE_REQUEST_ENABLED=False.
In this mode the decorator produces a "flattened" wrapper that exposes
individual Pydantic model fields as function parameters instead of a
single opaque ``request`` parameter.
"""
class SimpleRequest(BaseModel):
"""Request with required and optional fields."""
name: str = Field(description="The name")
count: int = Field(description="Item count")
tag: str | None = Field(default=None, description="Optional tag")
class EmptyRequest(BaseModel):
"""Request with no fields (like GetSupersetInstanceInfoRequest)."""
class ParentModel(BaseModel):
"""Base model with inherited fields."""
use_cache: bool = Field(default=True, description="Use cache")
class ChildRequest(ParentModel):
"""Child inherits use_cache from ParentModel."""
page: int = Field(default=1, description="Page number")
class ListDefaultRequest(BaseModel):
"""Request with default_factory fields."""
items: List[str] = Field(default_factory=list, description="List of items")
label: str = Field(default="default", description="A label")
@pytest.fixture(autouse=True)
def _disable_parse_request(self):
"""Ensure MCP_PARSE_REQUEST_ENABLED=False for flattened tests."""
from unittest.mock import patch
with patch(
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
return_value=False,
):
yield
def test_flattened_signature_has_individual_params(self):
"""Decorated function should expose model fields, not 'request'."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
param_names = list(sig.parameters.keys())
assert "request" not in param_names
assert "ctx" not in param_names
assert "name" in param_names
assert "count" in param_names
assert "tag" in param_names
def test_flattened_signature_defaults(self):
"""Required fields have no default; optional fields keep their defaults."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
assert sig.parameters["name"].default is inspect.Parameter.empty
assert sig.parameters["count"].default is inspect.Parameter.empty
assert sig.parameters["tag"].default is None
def test_flattened_async_constructs_model(self):
"""Async flattened wrapper should construct model from kwargs."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
return f"{request.name}:{request.count}:{request.tag}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool(name="hello", count=7, tag="x"))
assert result == "hello:7:x"
def test_flattened_sync_constructs_model(self):
"""Sync flattened wrapper should construct model from kwargs."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
def my_tool(request, ctx=None):
return f"{request.name}:{request.count}:{request.tag}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = my_tool(name="hello", count=7)
assert result == "hello:7:None"
def test_flattened_empty_model(self):
"""Empty request model should produce zero parameters."""
@parse_request(self.EmptyRequest)
async def my_tool(request, ctx=None):
return "ok"
sig = inspect.signature(my_tool)
assert len(sig.parameters) == 0
def test_flattened_empty_model_callable(self):
"""Tool with empty model should be callable with no args."""
from unittest.mock import MagicMock, patch
@parse_request(self.EmptyRequest)
async def my_tool(request, ctx=None):
return "ok"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool())
assert result == "ok"
def test_flattened_inherited_fields(self):
"""Inherited fields from base model should appear in the signature."""
@parse_request(self.ChildRequest)
async def my_tool(request, ctx=None):
return f"{request.use_cache}:{request.page}"
sig = inspect.signature(my_tool)
assert "use_cache" in sig.parameters
assert "page" in sig.parameters
assert sig.parameters["use_cache"].default is True
assert sig.parameters["page"].default == 1
def test_flattened_inherited_fields_callable(self):
"""Tool with inherited fields should be callable."""
from unittest.mock import MagicMock, patch
@parse_request(self.ChildRequest)
async def my_tool(request, ctx=None):
return f"{request.use_cache}:{request.page}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool(use_cache=False, page=3))
assert result == "False:3"
def test_flattened_default_factory_fields(self):
"""Fields with default_factory should get the factory result as default."""
@parse_request(self.ListDefaultRequest)
def my_tool(request, ctx=None):
return f"{request.items}:{request.label}"
sig = inspect.signature(my_tool)
assert sig.parameters["items"].default == []
assert sig.parameters["label"].default == "default"
def test_flattened_preserves_docstring(self):
"""Flattened wrapper should preserve the original function's docstring."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
"""My tool docstring."""
pass
assert my_tool.__doc__ == "My tool docstring."
def test_flattened_preserves_function_name(self):
"""Flattened wrapper should preserve the original function's name."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
assert my_tool.__name__ == "my_tool"
def test_flattened_params_are_keyword_only(self):
"""All parameters should be keyword-only."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
for param in sig.parameters.values():
assert param.kind == inspect.Parameter.KEYWORD_ONLY
def test_flattened_validation_error(self):
"""Pydantic validation errors should propagate when constructing model."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
def my_tool(request, ctx=None):
return "ok"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
with pytest.raises(ValidationError):
my_tool(name="test", count="not_a_number")
def test_flattened_annotations_have_descriptions(self):
"""Annotations should include Field descriptions for FastMCP schema."""
from typing import get_origin
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
name_param = sig.parameters["name"]
# Should be Annotated[str, Field(description="The name")]
assert get_origin(name_param.annotation) is not None or hasattr(
name_param.annotation, "__metadata__"
)
def test_flattened_annotations_forward_constraints(self):
"""Constraint metadata (ge, le, min_length, pattern, etc.) should be
forwarded into the Annotated type so FastMCP includes them in the
JSON schema."""
from typing import get_args, get_origin
from annotated_types import Ge, Le, MaxLen, MinLen
class ConstrainedRequest(BaseModel):
page: int = Field(1, description="Page number", ge=1, le=1000)
name: str = Field(..., description="Name", min_length=1, max_length=255)
@parse_request(ConstrainedRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
# Check page param has ge/le constraints in metadata
page_ann = sig.parameters["page"].annotation
assert get_origin(page_ann) is Annotated
page_metadata = get_args(page_ann)[1:] # skip base type
metadata_types = [type(m) for m in page_metadata]
assert Ge in metadata_types, f"Missing Ge in {page_metadata}"
assert Le in metadata_types, f"Missing Le in {page_metadata}"
# Check name param has min_length/max_length constraints
name_ann = sig.parameters["name"].annotation
assert get_origin(name_ann) is Annotated
name_metadata = get_args(name_ann)[1:]
metadata_types = [type(m) for m in name_metadata]
assert MinLen in metadata_types, f"Missing MinLen in {name_metadata}"
assert MaxLen in metadata_types, f"Missing MaxLen in {name_metadata}"
def test_flattened_uses_alias_as_param_name(self):
"""When a field has an alias, the flattened param name should be the alias,
not the Python field name. This ensures model_validate() works for models
without populate_by_name=True."""
class AliasedRequest(BaseModel):
schema_name: str | None = Field(None, description="Schema", alias="schema")
database_id: int = Field(..., description="DB ID")
@parse_request(AliasedRequest)
async def my_tool(request, ctx=None):
return {"schema": request.schema_name, "db": request.database_id}
sig = inspect.signature(my_tool)
param_names = list(sig.parameters.keys())
# alias "schema" should be used instead of field name "schema_name"
assert "schema" in param_names
assert "schema_name" not in param_names
# non-aliased field keeps its name
assert "database_id" in param_names
def test_flattened_alias_callable(self):
"""Flattened wrapper with aliased fields should construct model correctly."""
from unittest.mock import MagicMock, patch
class AliasedRequest(BaseModel):
schema_name: str | None = Field(None, description="Schema", alias="schema")
database_id: int = Field(..., description="DB ID")
@parse_request(AliasedRequest)
def my_tool(request, ctx=None):
return {"schema": request.schema_name, "db": request.database_id}
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = my_tool(schema="public", database_id=1)
assert result == {"schema": "public", "db": 1}