fix(mcp): expose individual tool parameters when MCP_PARSE_REQUEST_ENABLED=False (#38714)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-18 11:38:22 +01:00
committed by GitHub
parent 834d2abe70
commit e02ca8871d
2 changed files with 525 additions and 83 deletions

View File

@@ -19,8 +19,12 @@
Unit tests for MCP service schema utilities.
"""
import asyncio
import inspect
from typing import Annotated, List
import pytest
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel, Field, ValidationError
from superset.mcp_service.utils.schema_utils import (
JSONParseError,
@@ -564,3 +568,292 @@ class TestParseRequestDecorator:
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = sync_tool(json_str)
assert result == "test:42"
class TestParseRequestFlattenedDecorator:
"""Test parse_request decorator when MCP_PARSE_REQUEST_ENABLED=False.
In this mode the decorator produces a "flattened" wrapper that exposes
individual Pydantic model fields as function parameters instead of a
single opaque ``request`` parameter.
"""
class SimpleRequest(BaseModel):
"""Request with required and optional fields."""
name: str = Field(description="The name")
count: int = Field(description="Item count")
tag: str | None = Field(default=None, description="Optional tag")
class EmptyRequest(BaseModel):
"""Request with no fields (like GetSupersetInstanceInfoRequest)."""
class ParentModel(BaseModel):
"""Base model with inherited fields."""
use_cache: bool = Field(default=True, description="Use cache")
class ChildRequest(ParentModel):
"""Child inherits use_cache from ParentModel."""
page: int = Field(default=1, description="Page number")
class ListDefaultRequest(BaseModel):
"""Request with default_factory fields."""
items: List[str] = Field(default_factory=list, description="List of items")
label: str = Field(default="default", description="A label")
@pytest.fixture(autouse=True)
def _disable_parse_request(self):
"""Ensure MCP_PARSE_REQUEST_ENABLED=False for flattened tests."""
from unittest.mock import patch
with patch(
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
return_value=False,
):
yield
def test_flattened_signature_has_individual_params(self):
"""Decorated function should expose model fields, not 'request'."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
param_names = list(sig.parameters.keys())
assert "request" not in param_names
assert "ctx" not in param_names
assert "name" in param_names
assert "count" in param_names
assert "tag" in param_names
def test_flattened_signature_defaults(self):
"""Required fields have no default; optional fields keep their defaults."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
assert sig.parameters["name"].default is inspect.Parameter.empty
assert sig.parameters["count"].default is inspect.Parameter.empty
assert sig.parameters["tag"].default is None
def test_flattened_async_constructs_model(self):
"""Async flattened wrapper should construct model from kwargs."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
return f"{request.name}:{request.count}:{request.tag}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool(name="hello", count=7, tag="x"))
assert result == "hello:7:x"
def test_flattened_sync_constructs_model(self):
"""Sync flattened wrapper should construct model from kwargs."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
def my_tool(request, ctx=None):
return f"{request.name}:{request.count}:{request.tag}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = my_tool(name="hello", count=7)
assert result == "hello:7:None"
def test_flattened_empty_model(self):
"""Empty request model should produce zero parameters."""
@parse_request(self.EmptyRequest)
async def my_tool(request, ctx=None):
return "ok"
sig = inspect.signature(my_tool)
assert len(sig.parameters) == 0
def test_flattened_empty_model_callable(self):
"""Tool with empty model should be callable with no args."""
from unittest.mock import MagicMock, patch
@parse_request(self.EmptyRequest)
async def my_tool(request, ctx=None):
return "ok"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool())
assert result == "ok"
def test_flattened_inherited_fields(self):
"""Inherited fields from base model should appear in the signature."""
@parse_request(self.ChildRequest)
async def my_tool(request, ctx=None):
return f"{request.use_cache}:{request.page}"
sig = inspect.signature(my_tool)
assert "use_cache" in sig.parameters
assert "page" in sig.parameters
assert sig.parameters["use_cache"].default is True
assert sig.parameters["page"].default == 1
def test_flattened_inherited_fields_callable(self):
"""Tool with inherited fields should be callable."""
from unittest.mock import MagicMock, patch
@parse_request(self.ChildRequest)
async def my_tool(request, ctx=None):
return f"{request.use_cache}:{request.page}"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = asyncio.run(my_tool(use_cache=False, page=3))
assert result == "False:3"
def test_flattened_default_factory_fields(self):
"""Fields with default_factory should get the factory result as default."""
@parse_request(self.ListDefaultRequest)
def my_tool(request, ctx=None):
return f"{request.items}:{request.label}"
sig = inspect.signature(my_tool)
assert sig.parameters["items"].default == []
assert sig.parameters["label"].default == "default"
def test_flattened_preserves_docstring(self):
"""Flattened wrapper should preserve the original function's docstring."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
"""My tool docstring."""
pass
assert my_tool.__doc__ == "My tool docstring."
def test_flattened_preserves_function_name(self):
"""Flattened wrapper should preserve the original function's name."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
assert my_tool.__name__ == "my_tool"
def test_flattened_params_are_keyword_only(self):
"""All parameters should be keyword-only."""
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
for param in sig.parameters.values():
assert param.kind == inspect.Parameter.KEYWORD_ONLY
def test_flattened_validation_error(self):
"""Pydantic validation errors should propagate when constructing model."""
from unittest.mock import MagicMock, patch
@parse_request(self.SimpleRequest)
def my_tool(request, ctx=None):
return "ok"
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
with pytest.raises(ValidationError):
my_tool(name="test", count="not_a_number")
def test_flattened_annotations_have_descriptions(self):
"""Annotations should include Field descriptions for FastMCP schema."""
from typing import get_origin
@parse_request(self.SimpleRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
name_param = sig.parameters["name"]
# Should be Annotated[str, Field(description="The name")]
assert get_origin(name_param.annotation) is not None or hasattr(
name_param.annotation, "__metadata__"
)
def test_flattened_annotations_forward_constraints(self):
"""Constraint metadata (ge, le, min_length, pattern, etc.) should be
forwarded into the Annotated type so FastMCP includes them in the
JSON schema."""
from typing import get_args, get_origin
from annotated_types import Ge, Le, MaxLen, MinLen
class ConstrainedRequest(BaseModel):
page: int = Field(1, description="Page number", ge=1, le=1000)
name: str = Field(..., description="Name", min_length=1, max_length=255)
@parse_request(ConstrainedRequest)
async def my_tool(request, ctx=None):
pass
sig = inspect.signature(my_tool)
# Check page param has ge/le constraints in metadata
page_ann = sig.parameters["page"].annotation
assert get_origin(page_ann) is Annotated
page_metadata = get_args(page_ann)[1:] # skip base type
metadata_types = [type(m) for m in page_metadata]
assert Ge in metadata_types, f"Missing Ge in {page_metadata}"
assert Le in metadata_types, f"Missing Le in {page_metadata}"
# Check name param has min_length/max_length constraints
name_ann = sig.parameters["name"].annotation
assert get_origin(name_ann) is Annotated
name_metadata = get_args(name_ann)[1:]
metadata_types = [type(m) for m in name_metadata]
assert MinLen in metadata_types, f"Missing MinLen in {name_metadata}"
assert MaxLen in metadata_types, f"Missing MaxLen in {name_metadata}"
def test_flattened_uses_alias_as_param_name(self):
"""When a field has an alias, the flattened param name should be the alias,
not the Python field name. This ensures model_validate() works for models
without populate_by_name=True."""
class AliasedRequest(BaseModel):
schema_name: str | None = Field(None, description="Schema", alias="schema")
database_id: int = Field(..., description="DB ID")
@parse_request(AliasedRequest)
async def my_tool(request, ctx=None):
return {"schema": request.schema_name, "db": request.database_id}
sig = inspect.signature(my_tool)
param_names = list(sig.parameters.keys())
# alias "schema" should be used instead of field name "schema_name"
assert "schema" in param_names
assert "schema_name" not in param_names
# non-aliased field keeps its name
assert "database_id" in param_names
def test_flattened_alias_callable(self):
"""Flattened wrapper with aliased fields should construct model correctly."""
from unittest.mock import MagicMock, patch
class AliasedRequest(BaseModel):
schema_name: str | None = Field(None, description="Schema", alias="schema")
database_id: int = Field(..., description="DB ID")
@parse_request(AliasedRequest)
def my_tool(request, ctx=None):
return {"schema": request.schema_name, "db": request.database_id}
mock_ctx = MagicMock()
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
result = my_tool(schema="public", database_id=1)
assert result == {"schema": "public", "db": 1}