mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
fix(mcp): expose individual tool parameters when MCP_PARSE_REQUEST_ENABLED=False (#38714)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -19,8 +19,12 @@
|
||||
Unit tests for MCP service schema utilities.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
from typing import Annotated, List
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
JSONParseError,
|
||||
@@ -564,3 +568,292 @@ class TestParseRequestDecorator:
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = sync_tool(json_str)
|
||||
assert result == "test:42"
|
||||
|
||||
|
||||
class TestParseRequestFlattenedDecorator:
|
||||
"""Test parse_request decorator when MCP_PARSE_REQUEST_ENABLED=False.
|
||||
|
||||
In this mode the decorator produces a "flattened" wrapper that exposes
|
||||
individual Pydantic model fields as function parameters instead of a
|
||||
single opaque ``request`` parameter.
|
||||
"""
|
||||
|
||||
class SimpleRequest(BaseModel):
|
||||
"""Request with required and optional fields."""
|
||||
|
||||
name: str = Field(description="The name")
|
||||
count: int = Field(description="Item count")
|
||||
tag: str | None = Field(default=None, description="Optional tag")
|
||||
|
||||
class EmptyRequest(BaseModel):
|
||||
"""Request with no fields (like GetSupersetInstanceInfoRequest)."""
|
||||
|
||||
class ParentModel(BaseModel):
|
||||
"""Base model with inherited fields."""
|
||||
|
||||
use_cache: bool = Field(default=True, description="Use cache")
|
||||
|
||||
class ChildRequest(ParentModel):
|
||||
"""Child inherits use_cache from ParentModel."""
|
||||
|
||||
page: int = Field(default=1, description="Page number")
|
||||
|
||||
class ListDefaultRequest(BaseModel):
|
||||
"""Request with default_factory fields."""
|
||||
|
||||
items: List[str] = Field(default_factory=list, description="List of items")
|
||||
label: str = Field(default="default", description="A label")
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_parse_request(self):
|
||||
"""Ensure MCP_PARSE_REQUEST_ENABLED=False for flattened tests."""
|
||||
from unittest.mock import patch
|
||||
|
||||
with patch(
|
||||
"superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
|
||||
return_value=False,
|
||||
):
|
||||
yield
|
||||
|
||||
def test_flattened_signature_has_individual_params(self):
|
||||
"""Decorated function should expose model fields, not 'request'."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
param_names = list(sig.parameters.keys())
|
||||
assert "request" not in param_names
|
||||
assert "ctx" not in param_names
|
||||
assert "name" in param_names
|
||||
assert "count" in param_names
|
||||
assert "tag" in param_names
|
||||
|
||||
def test_flattened_signature_defaults(self):
|
||||
"""Required fields have no default; optional fields keep their defaults."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert sig.parameters["name"].default is inspect.Parameter.empty
|
||||
assert sig.parameters["count"].default is inspect.Parameter.empty
|
||||
assert sig.parameters["tag"].default is None
|
||||
|
||||
def test_flattened_async_constructs_model(self):
|
||||
"""Async flattened wrapper should construct model from kwargs."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}:{request.tag}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool(name="hello", count=7, tag="x"))
|
||||
assert result == "hello:7:x"
|
||||
|
||||
def test_flattened_sync_constructs_model(self):
|
||||
"""Sync flattened wrapper should construct model from kwargs."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return f"{request.name}:{request.count}:{request.tag}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = my_tool(name="hello", count=7)
|
||||
assert result == "hello:7:None"
|
||||
|
||||
def test_flattened_empty_model(self):
|
||||
"""Empty request model should produce zero parameters."""
|
||||
|
||||
@parse_request(self.EmptyRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert len(sig.parameters) == 0
|
||||
|
||||
def test_flattened_empty_model_callable(self):
|
||||
"""Tool with empty model should be callable with no args."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.EmptyRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool())
|
||||
assert result == "ok"
|
||||
|
||||
def test_flattened_inherited_fields(self):
|
||||
"""Inherited fields from base model should appear in the signature."""
|
||||
|
||||
@parse_request(self.ChildRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.use_cache}:{request.page}"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert "use_cache" in sig.parameters
|
||||
assert "page" in sig.parameters
|
||||
assert sig.parameters["use_cache"].default is True
|
||||
assert sig.parameters["page"].default == 1
|
||||
|
||||
def test_flattened_inherited_fields_callable(self):
|
||||
"""Tool with inherited fields should be callable."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.ChildRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return f"{request.use_cache}:{request.page}"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = asyncio.run(my_tool(use_cache=False, page=3))
|
||||
assert result == "False:3"
|
||||
|
||||
def test_flattened_default_factory_fields(self):
|
||||
"""Fields with default_factory should get the factory result as default."""
|
||||
|
||||
@parse_request(self.ListDefaultRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return f"{request.items}:{request.label}"
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
assert sig.parameters["items"].default == []
|
||||
assert sig.parameters["label"].default == "default"
|
||||
|
||||
def test_flattened_preserves_docstring(self):
|
||||
"""Flattened wrapper should preserve the original function's docstring."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
"""My tool docstring."""
|
||||
pass
|
||||
|
||||
assert my_tool.__doc__ == "My tool docstring."
|
||||
|
||||
def test_flattened_preserves_function_name(self):
|
||||
"""Flattened wrapper should preserve the original function's name."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
assert my_tool.__name__ == "my_tool"
|
||||
|
||||
def test_flattened_params_are_keyword_only(self):
|
||||
"""All parameters should be keyword-only."""
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
for param in sig.parameters.values():
|
||||
assert param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
|
||||
def test_flattened_validation_error(self):
|
||||
"""Pydantic validation errors should propagate when constructing model."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return "ok"
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
with pytest.raises(ValidationError):
|
||||
my_tool(name="test", count="not_a_number")
|
||||
|
||||
def test_flattened_annotations_have_descriptions(self):
|
||||
"""Annotations should include Field descriptions for FastMCP schema."""
|
||||
from typing import get_origin
|
||||
|
||||
@parse_request(self.SimpleRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
name_param = sig.parameters["name"]
|
||||
# Should be Annotated[str, Field(description="The name")]
|
||||
assert get_origin(name_param.annotation) is not None or hasattr(
|
||||
name_param.annotation, "__metadata__"
|
||||
)
|
||||
|
||||
def test_flattened_annotations_forward_constraints(self):
|
||||
"""Constraint metadata (ge, le, min_length, pattern, etc.) should be
|
||||
forwarded into the Annotated type so FastMCP includes them in the
|
||||
JSON schema."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
from annotated_types import Ge, Le, MaxLen, MinLen
|
||||
|
||||
class ConstrainedRequest(BaseModel):
|
||||
page: int = Field(1, description="Page number", ge=1, le=1000)
|
||||
name: str = Field(..., description="Name", min_length=1, max_length=255)
|
||||
|
||||
@parse_request(ConstrainedRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
pass
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
|
||||
# Check page param has ge/le constraints in metadata
|
||||
page_ann = sig.parameters["page"].annotation
|
||||
assert get_origin(page_ann) is Annotated
|
||||
page_metadata = get_args(page_ann)[1:] # skip base type
|
||||
metadata_types = [type(m) for m in page_metadata]
|
||||
assert Ge in metadata_types, f"Missing Ge in {page_metadata}"
|
||||
assert Le in metadata_types, f"Missing Le in {page_metadata}"
|
||||
|
||||
# Check name param has min_length/max_length constraints
|
||||
name_ann = sig.parameters["name"].annotation
|
||||
assert get_origin(name_ann) is Annotated
|
||||
name_metadata = get_args(name_ann)[1:]
|
||||
metadata_types = [type(m) for m in name_metadata]
|
||||
assert MinLen in metadata_types, f"Missing MinLen in {name_metadata}"
|
||||
assert MaxLen in metadata_types, f"Missing MaxLen in {name_metadata}"
|
||||
|
||||
def test_flattened_uses_alias_as_param_name(self):
|
||||
"""When a field has an alias, the flattened param name should be the alias,
|
||||
not the Python field name. This ensures model_validate() works for models
|
||||
without populate_by_name=True."""
|
||||
|
||||
class AliasedRequest(BaseModel):
|
||||
schema_name: str | None = Field(None, description="Schema", alias="schema")
|
||||
database_id: int = Field(..., description="DB ID")
|
||||
|
||||
@parse_request(AliasedRequest)
|
||||
async def my_tool(request, ctx=None):
|
||||
return {"schema": request.schema_name, "db": request.database_id}
|
||||
|
||||
sig = inspect.signature(my_tool)
|
||||
param_names = list(sig.parameters.keys())
|
||||
# alias "schema" should be used instead of field name "schema_name"
|
||||
assert "schema" in param_names
|
||||
assert "schema_name" not in param_names
|
||||
# non-aliased field keeps its name
|
||||
assert "database_id" in param_names
|
||||
|
||||
def test_flattened_alias_callable(self):
|
||||
"""Flattened wrapper with aliased fields should construct model correctly."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
class AliasedRequest(BaseModel):
|
||||
schema_name: str | None = Field(None, description="Schema", alias="schema")
|
||||
database_id: int = Field(..., description="DB ID")
|
||||
|
||||
@parse_request(AliasedRequest)
|
||||
def my_tool(request, ctx=None):
|
||||
return {"schema": request.schema_name, "db": request.database_id}
|
||||
|
||||
mock_ctx = MagicMock()
|
||||
with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
|
||||
result = my_tool(schema="public", database_id=1)
|
||||
assert result == {"schema": "public", "db": 1}
|
||||
|
||||
Reference in New Issue
Block a user