mirror of
https://github.com/apache/superset.git
synced 2026-04-18 15:44:57 +00:00
fix(mcp): expose individual tool parameters when MCP_PARSE_REQUEST_ENABLED=False (#38714)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -25,11 +25,12 @@ for input parameters, making MCP tools more flexible for different clients.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import logging
|
||||
from functools import wraps
|
||||
from typing import Any, Callable, List, Type, TypeVar
|
||||
from typing import Annotated, Any, Callable, List, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -411,9 +412,12 @@ def parse_request(
|
||||
the tool function. Also modifies the function's type annotations to accept
|
||||
str | RequestModel to pass FastMCP validation.
|
||||
|
||||
Can be disabled by setting MCP_PARSE_REQUEST_ENABLED = False in config.
|
||||
When disabled, string-to-model parsing is skipped but ctx injection and
|
||||
signature stripping still apply.
|
||||
Behavior depends on MCP_PARSE_REQUEST_ENABLED config (checked at decoration time):
|
||||
|
||||
- **When True (default):** Single `request: str | Model` param with string parsing
|
||||
at runtime. Preserves the Claude Code double-serialization workaround.
|
||||
- **When False (production):** Flattened wrapper exposing all Pydantic model fields
|
||||
as individual function parameters. FastMCP generates proper per-field schemas.
|
||||
|
||||
See: https://github.com/anthropics/claude-code/issues/5504
|
||||
|
||||
@@ -444,87 +448,232 @@ def parse_request(
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
import types
|
||||
|
||||
def _maybe_parse(request: Any) -> Any:
|
||||
if _is_parse_request_enabled():
|
||||
try:
|
||||
return parse_json_or_model(request, request_class, "request")
|
||||
except ValidationError as e:
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
details = []
|
||||
for err in e.errors():
|
||||
field = " -> ".join(str(loc) for loc in err["loc"])
|
||||
details.append(f"{field}: {err['msg']}")
|
||||
required_fields = [
|
||||
f.alias or name
|
||||
for name, f in request_class.model_fields.items()
|
||||
if f.is_required()
|
||||
]
|
||||
raise ToolError(
|
||||
f"Invalid request parameters: {'; '.join(details)}. "
|
||||
f"Required fields for {request_class.__name__}: "
|
||||
f"{', '.join(required_fields)}"
|
||||
) from None
|
||||
return request
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
return await func(_maybe_parse(request), ctx, *args, **kwargs)
|
||||
|
||||
wrapper = async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
return func(_maybe_parse(request), ctx, *args, **kwargs)
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Merge original function's __globals__ into wrapper's __globals__
|
||||
# This allows get_type_hints() to resolve type annotations from the
|
||||
# original module (e.g., Context from fastmcp)
|
||||
# FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
|
||||
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
|
||||
new_wrapper = types.FunctionType(
|
||||
wrapper.__code__, # type: ignore[attr-defined]
|
||||
merged_globals,
|
||||
wrapper.__name__,
|
||||
wrapper.__defaults__, # type: ignore[attr-defined]
|
||||
wrapper.__closure__, # type: ignore[attr-defined]
|
||||
)
|
||||
# Copy __dict__ but exclude __wrapped__
|
||||
# NOTE: We intentionally do NOT preserve __wrapped__ here.
|
||||
# Setting __wrapped__ causes inspect.signature() to follow the chain
|
||||
# and find 'ctx' in the original function's signature, even after
|
||||
# FastMCP's create_function_without_params removes it from annotations.
|
||||
# This breaks Pydantic's TypeAdapter which expects signature params
|
||||
# to match type_hints.
|
||||
new_wrapper.__dict__.update(
|
||||
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
|
||||
)
|
||||
new_wrapper.__module__ = wrapper.__module__
|
||||
new_wrapper.__qualname__ = wrapper.__qualname__
|
||||
# Copy docstring from original function (not wrapper, which has no docstring)
|
||||
new_wrapper.__doc__ = func.__doc__
|
||||
|
||||
request_annotation = str | request_class
|
||||
_apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
|
||||
|
||||
return new_wrapper
|
||||
if _is_parse_request_enabled():
|
||||
return _create_string_parsing_wrapper(func, request_class)
|
||||
return _create_flattened_wrapper(func, request_class)
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
def _create_string_parsing_wrapper(
|
||||
func: Callable[..., Any],
|
||||
request_class: Type[BaseModel],
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a wrapper that accepts a single `request: str | Model` parameter.
|
||||
|
||||
This is the original parse_request behavior: at runtime, if the request
|
||||
is a JSON string it gets parsed into the Pydantic model. Used when
|
||||
MCP_PARSE_REQUEST_ENABLED is True.
|
||||
"""
|
||||
import types
|
||||
|
||||
def _maybe_parse(request: Any) -> Any:
|
||||
if _is_parse_request_enabled():
|
||||
try:
|
||||
return parse_json_or_model(request, request_class, "request")
|
||||
except ValidationError as e:
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
details = []
|
||||
for err in e.errors():
|
||||
field = " -> ".join(str(loc) for loc in err["loc"])
|
||||
details.append(f"{field}: {err['msg']}")
|
||||
required_fields = [
|
||||
f.alias or name
|
||||
for name, f in request_class.model_fields.items()
|
||||
if f.is_required()
|
||||
]
|
||||
raise ToolError(
|
||||
f"Invalid request parameters: {'; '.join(details)}. "
|
||||
f"Required fields for {request_class.__name__}: "
|
||||
f"{', '.join(required_fields)}"
|
||||
) from None
|
||||
return request
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
return await func(_maybe_parse(request), ctx, *args, **kwargs)
|
||||
|
||||
wrapper = async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
return func(_maybe_parse(request), ctx, *args, **kwargs)
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Merge original function's __globals__ into wrapper's __globals__
|
||||
# This allows get_type_hints() to resolve type annotations from the
|
||||
# original module (e.g., Context from fastmcp)
|
||||
# FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
|
||||
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
|
||||
new_wrapper = types.FunctionType(
|
||||
wrapper.__code__, # type: ignore[attr-defined]
|
||||
merged_globals,
|
||||
wrapper.__name__,
|
||||
wrapper.__defaults__, # type: ignore[attr-defined]
|
||||
wrapper.__closure__, # type: ignore[attr-defined]
|
||||
)
|
||||
# Copy __dict__ but exclude __wrapped__
|
||||
# NOTE: We intentionally do NOT preserve __wrapped__ here.
|
||||
# Setting __wrapped__ causes inspect.signature() to follow the chain
|
||||
# and find 'ctx' in the original function's signature, even after
|
||||
# FastMCP's create_function_without_params removes it from annotations.
|
||||
# This breaks Pydantic's TypeAdapter which expects signature params
|
||||
# to match type_hints.
|
||||
new_wrapper.__dict__.update(
|
||||
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
|
||||
)
|
||||
new_wrapper.__module__ = wrapper.__module__
|
||||
new_wrapper.__qualname__ = wrapper.__qualname__
|
||||
# Copy docstring from original function (not wrapper, which has no docstring)
|
||||
new_wrapper.__doc__ = func.__doc__
|
||||
|
||||
request_annotation = str | request_class
|
||||
_apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
|
||||
|
||||
return new_wrapper
|
||||
|
||||
|
||||
def _create_flattened_wrapper(
|
||||
func: Callable[..., Any],
|
||||
request_class: Type[BaseModel],
|
||||
) -> Callable[..., Any]:
|
||||
"""Create a wrapper that exposes individual Pydantic model fields as parameters.
|
||||
|
||||
Used when MCP_PARSE_REQUEST_ENABLED is False. Instead of a single opaque
|
||||
``request`` parameter, the wrapper accepts each field from ``request_class``
|
||||
as a keyword argument. FastMCP then generates a proper JSON schema with
|
||||
individual properties for each field.
|
||||
"""
|
||||
import types
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(**kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
model = request_class.model_validate(kwargs)
|
||||
return await func(model, ctx)
|
||||
|
||||
wrapper = async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(**kwargs: Any) -> Any:
|
||||
from fastmcp.server.dependencies import get_context
|
||||
|
||||
ctx = get_context()
|
||||
model = request_class.model_validate(kwargs)
|
||||
return func(model, ctx)
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Merge globals so get_type_hints() can resolve annotations from
|
||||
# the original function's module.
|
||||
merged_globals = {**wrapper.__globals__, **func.__globals__} # type: ignore[attr-defined]
|
||||
new_wrapper = types.FunctionType(
|
||||
wrapper.__code__, # type: ignore[attr-defined]
|
||||
merged_globals,
|
||||
wrapper.__name__,
|
||||
wrapper.__defaults__, # type: ignore[attr-defined]
|
||||
wrapper.__closure__, # type: ignore[attr-defined]
|
||||
)
|
||||
new_wrapper.__dict__.update(
|
||||
{k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
|
||||
)
|
||||
new_wrapper.__module__ = wrapper.__module__
|
||||
new_wrapper.__qualname__ = wrapper.__qualname__
|
||||
new_wrapper.__doc__ = func.__doc__
|
||||
|
||||
_apply_flattened_signature(new_wrapper, func, request_class)
|
||||
|
||||
return new_wrapper
|
||||
|
||||
|
||||
def _apply_flattened_signature(
|
||||
wrapper: Any,
|
||||
original_func: Callable[..., Any],
|
||||
request_class: Type[BaseModel],
|
||||
) -> None:
|
||||
"""Build a signature from the Pydantic model's fields and apply it to *wrapper*.
|
||||
|
||||
Each field in ``request_class.model_fields`` becomes a keyword-only parameter.
|
||||
Field descriptions and constraints (ge, le, min_length, max_length, pattern,
|
||||
etc.) are preserved via ``Annotated[type, Field(description=...), ...]``
|
||||
so that FastMCP propagates them into the generated JSON schema.
|
||||
"""
|
||||
params: list[inspect.Parameter] = []
|
||||
annotations: dict[str, Any] = {}
|
||||
|
||||
for name, field_info in request_class.model_fields.items():
|
||||
# Use alias as param name when set, otherwise field name.
|
||||
# This ensures model_validate() always works: aliases are always
|
||||
# accepted, while field names require populate_by_name=True.
|
||||
param_name = field_info.alias if isinstance(field_info.alias, str) else name
|
||||
|
||||
# Determine the default value for this parameter.
|
||||
# Check default_factory before default because Pydantic sets
|
||||
# default to PydanticUndefined when default_factory is used.
|
||||
if field_info.is_required():
|
||||
default = inspect.Parameter.empty
|
||||
elif field_info.default_factory is not None:
|
||||
default = field_info.default_factory()
|
||||
else:
|
||||
default = field_info.default # covers None and explicit defaults
|
||||
|
||||
# Build Annotated type with description and all constraint metadata
|
||||
# (ge, le, min_length, max_length, pattern, etc.) so FastMCP
|
||||
# propagates them into the generated JSON schema.
|
||||
base_annotation = field_info.annotation
|
||||
metadata_markers: list[Any] = []
|
||||
if field_info.description:
|
||||
metadata_markers.append(Field(description=field_info.description))
|
||||
# Forward constraint metadata from the original Pydantic field.
|
||||
# Pydantic v2 stores constraints (Ge, Le, MinLen, MaxLen, etc.)
|
||||
# as annotation objects in field_info.metadata.
|
||||
if field_info.metadata:
|
||||
metadata_markers.extend(field_info.metadata)
|
||||
if metadata_markers:
|
||||
# Dynamically construct Annotated[base_type, meta1, meta2, ...]
|
||||
annotation: Any = Annotated.__class_getitem__( # type: ignore[attr-defined]
|
||||
(base_annotation, *metadata_markers)
|
||||
)
|
||||
else:
|
||||
annotation = base_annotation
|
||||
|
||||
params.append(
|
||||
inspect.Parameter(
|
||||
param_name,
|
||||
kind=inspect.Parameter.KEYWORD_ONLY,
|
||||
default=default,
|
||||
annotation=annotation,
|
||||
)
|
||||
)
|
||||
annotations[param_name] = annotation
|
||||
|
||||
# Preserve return annotation from the original function if present
|
||||
orig_sig = inspect.signature(original_func)
|
||||
wrapper.__signature__ = orig_sig.replace(
|
||||
parameters=params,
|
||||
return_annotation=orig_sig.return_annotation,
|
||||
)
|
||||
wrapper.__annotations__ = annotations
|
||||
if orig_sig.return_annotation is not inspect.Parameter.empty:
|
||||
wrapper.__annotations__["return"] = orig_sig.return_annotation
|
||||
|
||||
|
||||
def _apply_signature_for_fastmcp(
|
||||
wrapper: Any,
|
||||
original_func: Callable[..., Any],
|
||||
|
||||
@@ -19,8 +19,12 @@
|
||||
Unit tests for MCP service schema utilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Annotated, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
JSONParseError,
|
||||
@@ -564,3 +568,292 @@ class TestParseRequestDecorator:
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool(json_str)
|
||||
assert result == "test:42"
|
||||
|
||||
|
||||
class TestParseRequestFlattenedDecorator:
|
||||
"""Test parse_request decorator when MCP_PARSE_REQUEST_ENABLED=False.
|
||||
|
||||
In this mode the decorator produces a "flattened" wrapper that exposes
|
||||
individual Pydantic model fields as function parameters instead of a
|
||||
single opaque ``request`` parameter.
|
||||
"""
|
||||
|
||||
class SimpleRequest(BaseModel):
|
||||
"""Request with required and optional fields."""
|
||||
|
||||
name: str = Field(description="The name")
|
||||
count: int = Field(description="Item count")
|
||||
tag: str | None = Field(default=None, description="Optional tag")
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Request with no fields (like GetSupersetInstanceInfoRequest)."""
|
||||
|
||||
class ParentModel(BaseModel):
|
||||
"""Base model with inherited fields."""
|
||||
|
||||
use_cache: bool = Field(default=True, description="Use cache")
|
||||
|
||||
class ChildRequest(ParentModel):
|
||||
"""Child inherits use_cache from ParentModel."""
|
||||
|
||||
page: int = Field(default=1, description="Page number")
|
||||
|
||||
class ListDefaultRequest(BaseModel):
|
||||
"""Request with default_factory fields."""
|
||||
|
||||
items: List[str] = Field(default_factory=list, description="List of items")
|
||||
label: str = Field(default="default", description="A label")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_parse_request(self):
|
||||
"""Ensure MCP_PARSE_REQUEST_ENABLED=False for flattened tests."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
|
||||
return_value=False,
|
||||
):
|
||||
yield
|
||||
|
||||
def test_flattened_signature_has_individual_params(self):
|
||||
"""Decorated function should expose model fields, not 'request'."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
param_names = list(sig.parameters.keys())
|
||||
assert "request" not in param_names
|
||||
assert "ctx" not in param_names
|
||||
assert "name" in param_names
|
||||
assert "count" in param_names
|
||||
assert "tag" in param_names
|
||||
|
||||
def test_flattened_signature_defaults(self):
|
||||
"""Required fields have no default; optional fields keep their defaults."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert sig.parameters["name"].default is inspect.Parameter.empty
|
||||
assert sig.parameters["count"].default is inspect.Parameter.empty
|
||||
assert sig.parameters["tag"].default is None
|
||||
|
||||
def test_flattened_async_constructs_model(self):
|
||||
"""Async flattened wrapper should construct model from kwargs."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}:{request.tag}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool(name="hello", count=7, tag="x"))
|
||||
assert result == "hello:7:x"
|
||||
|
||||
def test_flattened_sync_constructs_model(self):
|
||||
"""Sync flattened wrapper should construct model from kwargs."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}:{request.tag}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = my_tool(name="hello", count=7)
|
||||
assert result == "hello:7:None"
|
||||
|
||||
def test_flattened_empty_model(self):
|
||||
"""Empty request model should produce zero parameters."""
|
||||
|
||||
@parse_request(self.EmptyRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert len(sig.parameters) == 0
|
||||
|
||||
def test_flattened_empty_model_callable(self):
|
||||
"""Tool with empty model should be callable with no args."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.EmptyRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool())
|
||||
assert result == "ok"
|
||||
|
||||
def test_flattened_inherited_fields(self):
|
||||
"""Inherited fields from base model should appear in the signature."""
|
||||
|
||||
@parse_request(self.ChildRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.use_cache}:{request.page}"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert "use_cache" in sig.parameters
|
||||
assert "page" in sig.parameters
|
||||
assert sig.parameters["use_cache"].default is True
|
||||
assert sig.parameters["page"].default == 1
|
||||
|
||||
def test_flattened_inherited_fields_callable(self):
|
||||
"""Tool with inherited fields should be callable."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.ChildRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.use_cache}:{request.page}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool(use_cache=False, page=3))
|
||||
assert result == "False:3"
|
||||
|
||||
def test_flattened_default_factory_fields(self):
|
||||
"""Fields with default_factory should get the factory result as default."""
|
||||
|
||||
@parse_request(self.ListDefaultRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return f"{request.items}:{request.label}"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert sig.parameters["items"].default == []
|
||||
assert sig.parameters["label"].default == "default"
|
||||
|
||||
def test_flattened_preserves_docstring(self):
|
||||
"""Flattened wrapper should preserve the original function's docstring."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
"""My tool docstring."""
|
||||
pass
|
||||
|
||||
assert my_tool.__doc__ == "My tool docstring."
|
||||
|
||||
def test_flattened_preserves_function_name(self):
|
||||
"""Flattened wrapper should preserve the original function's name."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
assert my_tool.__name__ == "my_tool"
|
||||
|
||||
def test_flattened_params_are_keyword_only(self):
|
||||
"""All parameters should be keyword-only."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
for param in sig.parameters.values():
|
||||
assert param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
|
||||
def test_flattened_validation_error(self):
|
||||
"""Pydantic validation errors should propagate when constructing model."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
my_tool(name="test", count="not_a_number")
|
||||
|
||||
def test_flattened_annotations_have_descriptions(self):
|
||||
"""Annotations should include Field descriptions for FastMCP schema."""
|
||||
from typing import get_origin
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
name_param = sig.parameters["name"]
|
||||
# Should be Annotated[str, Field(description="The name")]
|
||||
assert get_origin(name_param.annotation) is not None or hasattr(
|
||||
name_param.annotation, "__metadata__"
|
||||
)
|
||||
|
||||
def test_flattened_annotations_forward_constraints(self):
|
||||
"""Constraint metadata (ge, le, min_length, pattern, etc.) should be
|
||||
forwarded into the Annotated type so FastMCP includes them in the
|
||||
JSON schema."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
from annotated_types import Ge, Le, MaxLen, MinLen
|
||||
|
||||
class ConstrainedRequest(BaseModel):
|
||||
page: int = Field(1, description="Page number", ge=1, le=1000)
|
||||
name: str = Field(..., description="Name", min_length=1, max_length=255)
|
||||
|
||||
@parse_request(ConstrainedRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
|
||||
# Check page param has ge/le constraints in metadata
|
||||
page_ann = sig.parameters["page"].annotation
|
||||
assert get_origin(page_ann) is Annotated
|
||||
page_metadata = get_args(page_ann)[1:] # skip base type
|
||||
metadata_types = [type(m) for m in page_metadata]
|
||||
assert Ge in metadata_types, f"Missing Ge in {page_metadata}"
|
||||
assert Le in metadata_types, f"Missing Le in {page_metadata}"
|
||||
|
||||
# Check name param has min_length/max_length constraints
|
||||
name_ann = sig.parameters["name"].annotation
|
||||
assert get_origin(name_ann) is Annotated
|
||||
name_metadata = get_args(name_ann)[1:]
|
||||
metadata_types = [type(m) for m in name_metadata]
|
||||
assert MinLen in metadata_types, f"Missing MinLen in {name_metadata}"
|
||||
assert MaxLen in metadata_types, f"Missing MaxLen in {name_metadata}"
|
||||
|
||||
def test_flattened_uses_alias_as_param_name(self):
|
||||
"""When a field has an alias, the flattened param name should be the alias,
|
||||
not the Python field name. This ensures model_validate() works for models
|
||||
without populate_by_name=True."""
|
||||
|
||||
class AliasedRequest(BaseModel):
|
||||
schema_name: str | None = Field(None, description="Schema", alias="schema")
|
||||
database_id: int = Field(..., description="DB ID")
|
||||
|
||||
@parse_request(AliasedRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return {"schema": request.schema_name, "db": request.database_id}
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
param_names = list(sig.parameters.keys())
|
||||
# alias "schema" should be used instead of field name "schema_name"
|
||||
assert "schema" in param_names
|
||||
assert "schema_name" not in param_names
|
||||
# non-aliased field keeps its name
|
||||
assert "database_id" in param_names
|
||||
|
||||
def test_flattened_alias_callable(self):
|
||||
"""Flattened wrapper with aliased fields should construct model correctly."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
class AliasedRequest(BaseModel):
|
||||
schema_name: str | None = Field(None, description="Schema", alias="schema")
|
||||
database_id: int = Field(..., description="DB ID")
|
||||
|
||||
@parse_request(AliasedRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return {"schema": request.schema_name, "db": request.database_id}
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = my_tool(schema="public", database_id=1)
|
||||
assert result == {"schema": "public", "db": 1}
|
||||
|
||||
Reference in New Issue
Block a user