diff --git a/superset/mcp_service/utils/schema_utils.py b/superset/mcp_service/utils/schema_utils.py
index f75e122b6ff..5a57ac63d13 100644
--- a/superset/mcp_service/utils/schema_utils.py
+++ b/superset/mcp_service/utils/schema_utils.py
@@ -25,11 +25,12 @@ for input parameters, making MCP tools more flexible for different clients.
 from __future__ import annotations
 
 import asyncio
+import inspect
 import logging
 from functools import wraps
-from typing import Any, Callable, List, Type, TypeVar
+from typing import Annotated, Any, Callable, List, Type, TypeVar
 
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, Field, ValidationError
 
 logger = logging.getLogger(__name__)
 
@@ -411,9 +412,12 @@ def parse_request(
     the tool function. Also modifies the function's type annotations to accept
     str | RequestModel to pass FastMCP validation.
 
-    Can be disabled by setting MCP_PARSE_REQUEST_ENABLED = False in config.
-    When disabled, string-to-model parsing is skipped but ctx injection and
-    signature stripping still apply.
+    Behavior depends on MCP_PARSE_REQUEST_ENABLED config (checked at decoration time):
+
+    - **When True (default):** Single `request: str | Model` param with string parsing
+      at runtime. Preserves the Claude Code double-serialization workaround.
+    - **When False (production):** Flattened wrapper exposing all Pydantic model fields
+      as individual function parameters. FastMCP generates proper per-field schemas.
 
     See: https://github.com/anthropics/claude-code/issues/5504
 
@@ -444,87 +448,235 @@ def parse_request(
     """
 
     def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
-        import types
-
-        def _maybe_parse(request: Any) -> Any:
-            if _is_parse_request_enabled():
-                try:
-                    return parse_json_or_model(request, request_class, "request")
-                except ValidationError as e:
-                    from fastmcp.exceptions import ToolError
-
-                    details = []
-                    for err in e.errors():
-                        field = " -> ".join(str(loc) for loc in err["loc"])
-                        details.append(f"{field}: {err['msg']}")
-                    required_fields = [
-                        f.alias or name
-                        for name, f in request_class.model_fields.items()
-                        if f.is_required()
-                    ]
-                    raise ToolError(
-                        f"Invalid request parameters: {'; '.join(details)}. "
-                        f"Required fields for {request_class.__name__}: "
-                        f"{', '.join(required_fields)}"
-                    ) from None
-            return request
-
-        if asyncio.iscoroutinefunction(func):
-
-            @wraps(func)
-            async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
-                from fastmcp.server.dependencies import get_context
-
-                ctx = get_context()
-                return await func(_maybe_parse(request), ctx, *args, **kwargs)
-
-            wrapper = async_wrapper
-        else:
-
-            @wraps(func)
-            def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
-                from fastmcp.server.dependencies import get_context
-
-                ctx = get_context()
-                return func(_maybe_parse(request), ctx, *args, **kwargs)
-
-            wrapper = sync_wrapper
-
-        # Merge original function's __globals__ into wrapper's __globals__
-        # This allows get_type_hints() to resolve type annotations from the
-        # original module (e.g., Context from fastmcp)
-        # FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
-        merged_globals = {**wrapper.__globals__, **func.__globals__}  # type: ignore[attr-defined]
-        new_wrapper = types.FunctionType(
-            wrapper.__code__,  # type: ignore[attr-defined]
-            merged_globals,
-            wrapper.__name__,
-            wrapper.__defaults__,  # type: ignore[attr-defined]
-            wrapper.__closure__,  # type: ignore[attr-defined]
-        )
-        # Copy __dict__ but exclude __wrapped__
-        # NOTE: We intentionally do NOT preserve __wrapped__ here.
-        # Setting __wrapped__ causes inspect.signature() to follow the chain
-        # and find 'ctx' in the original function's signature, even after
-        # FastMCP's create_function_without_params removes it from annotations.
-        # This breaks Pydantic's TypeAdapter which expects signature params
-        # to match type_hints.
-        new_wrapper.__dict__.update(
-            {k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
-        )
-        new_wrapper.__module__ = wrapper.__module__
-        new_wrapper.__qualname__ = wrapper.__qualname__
-        # Copy docstring from original function (not wrapper, which has no docstring)
-        new_wrapper.__doc__ = func.__doc__
-
-        request_annotation = str | request_class
-        _apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
-
-        return new_wrapper
+        if _is_parse_request_enabled():
+            return _create_string_parsing_wrapper(func, request_class)
+        return _create_flattened_wrapper(func, request_class)
 
     return decorator
 
 
+def _create_string_parsing_wrapper(
+    func: Callable[..., Any],
+    request_class: Type[BaseModel],
+) -> Callable[..., Any]:
+    """Create a wrapper that accepts a single `request: str | Model` parameter.
+
+    This is the original parse_request behavior: at runtime, if the request
+    is a JSON string it gets parsed into the Pydantic model. Used when
+    MCP_PARSE_REQUEST_ENABLED is True.
+    """
+    import types
+
+    def _maybe_parse(request: Any) -> Any:
+        if _is_parse_request_enabled():
+            try:
+                return parse_json_or_model(request, request_class, "request")
+            except ValidationError as e:
+                from fastmcp.exceptions import ToolError
+
+                details = []
+                for err in e.errors():
+                    field = " -> ".join(str(loc) for loc in err["loc"])
+                    details.append(f"{field}: {err['msg']}")
+                required_fields = [
+                    f.alias or name
+                    for name, f in request_class.model_fields.items()
+                    if f.is_required()
+                ]
+                raise ToolError(
+                    f"Invalid request parameters: {'; '.join(details)}. "
+                    f"Required fields for {request_class.__name__}: "
+                    f"{', '.join(required_fields)}"
+                ) from None
+        return request
+
+    if asyncio.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def async_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
+            from fastmcp.server.dependencies import get_context
+
+            ctx = get_context()
+            return await func(_maybe_parse(request), ctx, *args, **kwargs)
+
+        wrapper = async_wrapper
+    else:
+
+        @wraps(func)
+        def sync_wrapper(request: Any, *args: Any, **kwargs: Any) -> Any:
+            from fastmcp.server.dependencies import get_context
+
+            ctx = get_context()
+            return func(_maybe_parse(request), ctx, *args, **kwargs)
+
+        wrapper = sync_wrapper
+
+    # Merge original function's __globals__ into wrapper's __globals__
+    # This allows get_type_hints() to resolve type annotations from the
+    # original module (e.g., Context from fastmcp)
+    # FastMCP 2.13.2+ uses get_type_hints() which needs access to these types
+    merged_globals = {**wrapper.__globals__, **func.__globals__}  # type: ignore[attr-defined]
+    new_wrapper = types.FunctionType(
+        wrapper.__code__,  # type: ignore[attr-defined]
+        merged_globals,
+        wrapper.__name__,
+        wrapper.__defaults__,  # type: ignore[attr-defined]
+        wrapper.__closure__,  # type: ignore[attr-defined]
+    )
+    # Copy __dict__ but exclude __wrapped__
+    # NOTE: We intentionally do NOT preserve __wrapped__ here.
+    # Setting __wrapped__ causes inspect.signature() to follow the chain
+    # and find 'ctx' in the original function's signature, even after
+    # FastMCP's create_function_without_params removes it from annotations.
+    # This breaks Pydantic's TypeAdapter which expects signature params
+    # to match type_hints.
+    new_wrapper.__dict__.update(
+        {k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
+    )
+    new_wrapper.__module__ = wrapper.__module__
+    new_wrapper.__qualname__ = wrapper.__qualname__
+    # Copy docstring from original function (not wrapper, which has no docstring)
+    new_wrapper.__doc__ = func.__doc__
+
+    request_annotation = str | request_class
+    _apply_signature_for_fastmcp(new_wrapper, func, request_annotation)
+
+    return new_wrapper
+
+
+def _create_flattened_wrapper(
+    func: Callable[..., Any],
+    request_class: Type[BaseModel],
+) -> Callable[..., Any]:
+    """Create a wrapper that exposes individual Pydantic model fields as parameters.
+
+    Used when MCP_PARSE_REQUEST_ENABLED is False. Instead of a single opaque
+    ``request`` parameter, the wrapper accepts each field from ``request_class``
+    as a keyword argument. FastMCP then generates a proper JSON schema with
+    individual properties for each field.
+    """
+    import types
+
+    if asyncio.iscoroutinefunction(func):
+
+        @wraps(func)
+        async def async_wrapper(**kwargs: Any) -> Any:
+            from fastmcp.server.dependencies import get_context
+
+            ctx = get_context()
+            model = request_class.model_validate(kwargs)
+            return await func(model, ctx)
+
+        wrapper = async_wrapper
+    else:
+
+        @wraps(func)
+        def sync_wrapper(**kwargs: Any) -> Any:
+            from fastmcp.server.dependencies import get_context
+
+            ctx = get_context()
+            model = request_class.model_validate(kwargs)
+            return func(model, ctx)
+
+        wrapper = sync_wrapper
+
+    # Merge globals so get_type_hints() can resolve annotations from
+    # the original function's module.
+    merged_globals = {**wrapper.__globals__, **func.__globals__}  # type: ignore[attr-defined]
+    new_wrapper = types.FunctionType(
+        wrapper.__code__,  # type: ignore[attr-defined]
+        merged_globals,
+        wrapper.__name__,
+        wrapper.__defaults__,  # type: ignore[attr-defined]
+        wrapper.__closure__,  # type: ignore[attr-defined]
+    )
+    new_wrapper.__dict__.update(
+        {k: v for k, v in wrapper.__dict__.items() if k != "__wrapped__"}
+    )
+    new_wrapper.__module__ = wrapper.__module__
+    new_wrapper.__qualname__ = wrapper.__qualname__
+    new_wrapper.__doc__ = func.__doc__
+
+    _apply_flattened_signature(new_wrapper, func, request_class)
+
+    return new_wrapper
+
+
+def _apply_flattened_signature(
+    wrapper: Any,
+    original_func: Callable[..., Any],
+    request_class: Type[BaseModel],
+) -> None:
+    """Build a signature from the Pydantic model's fields and apply it to *wrapper*.
+
+    Each field in ``request_class.model_fields`` becomes a keyword-only parameter.
+    Field descriptions and constraints (ge, le, min_length, max_length, pattern,
+    etc.) are preserved via ``Annotated[type, Field(description=...), ...]``
+    so that FastMCP propagates them into the generated JSON schema.
+    """
+    params: list[inspect.Parameter] = []
+    annotations: dict[str, Any] = {}
+    # Bind the special form to an ``Any``-typed name so it can be subscripted
+    # with a runtime-built tuple without tripping static type checkers.
+    annotated_form: Any = Annotated
+
+    for name, field_info in request_class.model_fields.items():
+        # Use alias as param name when set, otherwise field name.
+        # This ensures model_validate() always works: aliases are always
+        # accepted, while field names require populate_by_name=True.
+        param_name = field_info.alias if isinstance(field_info.alias, str) else name
+
+        # Determine the default value for this parameter.
+        # Check default_factory before default because Pydantic sets
+        # default to PydanticUndefined when default_factory is used.
+        if field_info.is_required():
+            default = inspect.Parameter.empty
+        elif field_info.default_factory is not None:
+            default = field_info.default_factory()
+        else:
+            default = field_info.default  # covers None and explicit defaults
+
+        # Build Annotated type with description and all constraint metadata
+        # (ge, le, min_length, max_length, pattern, etc.) so FastMCP
+        # propagates them into the generated JSON schema.
+        base_annotation = field_info.annotation
+        metadata_markers: list[Any] = []
+        if field_info.description:
+            metadata_markers.append(Field(description=field_info.description))
+        # Forward constraint metadata from the original Pydantic field.
+        # Pydantic v2 stores constraints (Ge, Le, MinLen, MaxLen, etc.)
+        # as annotation objects in field_info.metadata.
+        if field_info.metadata:
+            metadata_markers.extend(field_info.metadata)
+        if metadata_markers:
+            # Equivalent to Annotated[base_type, meta1, meta2, ...]. Go through
+            # the public subscript protocol: Annotated.__class_getitem__ does
+            # not exist on Python 3.13+, where Annotated is a special form.
+            annotation: Any = annotated_form[(base_annotation, *metadata_markers)]
+        else:
+            annotation = base_annotation
+
+        params.append(
+            inspect.Parameter(
+                param_name,
+                kind=inspect.Parameter.KEYWORD_ONLY,
+                default=default,
+                annotation=annotation,
+            )
+        )
+        annotations[param_name] = annotation
+
+    # Preserve return annotation from the original function if present
+    orig_sig = inspect.signature(original_func)
+    wrapper.__signature__ = orig_sig.replace(
+        parameters=params,
+        return_annotation=orig_sig.return_annotation,
+    )
+    wrapper.__annotations__ = annotations
+    if orig_sig.return_annotation is not inspect.Parameter.empty:
+        wrapper.__annotations__["return"] = orig_sig.return_annotation
+
+
 def _apply_signature_for_fastmcp(
     wrapper: Any,
     original_func: Callable[..., Any],
diff --git a/tests/unit_tests/mcp_service/utils/test_schema_utils.py b/tests/unit_tests/mcp_service/utils/test_schema_utils.py
index ab3450e6182..419ff04471a 100644
--- a/tests/unit_tests/mcp_service/utils/test_schema_utils.py
+++ b/tests/unit_tests/mcp_service/utils/test_schema_utils.py
@@ -19,8 +19,12 @@
 Unit tests for MCP service schema utilities.
 """
 
+import asyncio
+import inspect
+from typing import Annotated, List
+
 import pytest
-from pydantic import BaseModel, ValidationError
+from pydantic import BaseModel, Field, ValidationError
 
 from superset.mcp_service.utils.schema_utils import (
     JSONParseError,
@@ -564,3 +568,294 @@ class TestParseRequestDecorator:
         with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
             result = sync_tool(json_str)
         assert result == "test:42"
+
+
+class TestParseRequestFlattenedDecorator:
+    """Test parse_request decorator when MCP_PARSE_REQUEST_ENABLED=False.
+
+    In this mode the decorator produces a "flattened" wrapper that exposes
+    individual Pydantic model fields as function parameters instead of a
+    single opaque ``request`` parameter.
+    """
+
+    class SimpleRequest(BaseModel):
+        """Request with required and optional fields."""
+
+        name: str = Field(description="The name")
+        count: int = Field(description="Item count")
+        tag: str | None = Field(default=None, description="Optional tag")
+
+    class EmptyRequest(BaseModel):
+        """Request with no fields (like GetSupersetInstanceInfoRequest)."""
+
+    class ParentModel(BaseModel):
+        """Base model with inherited fields."""
+
+        use_cache: bool = Field(default=True, description="Use cache")
+
+    class ChildRequest(ParentModel):
+        """Child inherits use_cache from ParentModel."""
+
+        page: int = Field(default=1, description="Page number")
+
+    class ListDefaultRequest(BaseModel):
+        """Request with default_factory fields."""
+
+        items: List[str] = Field(default_factory=list, description="List of items")
+        label: str = Field(default="default", description="A label")
+
+    @pytest.fixture(autouse=True)
+    def _disable_parse_request(self):
+        """Ensure MCP_PARSE_REQUEST_ENABLED=False for flattened tests."""
+        from unittest.mock import patch
+
+        with patch(
+            "superset.mcp_service.utils.schema_utils._is_parse_request_enabled",
+            return_value=False,
+        ):
+            yield
+
+    def test_flattened_signature_has_individual_params(self):
+        """Decorated function should expose model fields, not 'request'."""
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        sig = inspect.signature(my_tool)
+        param_names = list(sig.parameters.keys())
+        assert "request" not in param_names
+        assert "ctx" not in param_names
+        assert "name" in param_names
+        assert "count" in param_names
+        assert "tag" in param_names
+
+    def test_flattened_signature_defaults(self):
+        """Required fields have no default; optional fields keep their defaults."""
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        sig = inspect.signature(my_tool)
+        assert sig.parameters["name"].default is inspect.Parameter.empty
+        assert sig.parameters["count"].default is inspect.Parameter.empty
+        assert sig.parameters["tag"].default is None
+
+    def test_flattened_async_constructs_model(self):
+        """Async flattened wrapper should construct model from kwargs."""
+        from unittest.mock import MagicMock, patch
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            return f"{request.name}:{request.count}:{request.tag}"
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            result = asyncio.run(my_tool(name="hello", count=7, tag="x"))
+        assert result == "hello:7:x"
+
+    def test_flattened_sync_constructs_model(self):
+        """Sync flattened wrapper should construct model from kwargs."""
+        from unittest.mock import MagicMock, patch
+
+        @parse_request(self.SimpleRequest)
+        def my_tool(request, ctx=None):
+            return f"{request.name}:{request.count}:{request.tag}"
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            result = my_tool(name="hello", count=7)
+        assert result == "hello:7:None"
+
+    def test_flattened_empty_model(self):
+        """Empty request model should produce zero parameters."""
+
+        @parse_request(self.EmptyRequest)
+        async def my_tool(request, ctx=None):
+            return "ok"
+
+        sig = inspect.signature(my_tool)
+        assert len(sig.parameters) == 0
+
+    def test_flattened_empty_model_callable(self):
+        """Tool with empty model should be callable with no args."""
+        from unittest.mock import MagicMock, patch
+
+        @parse_request(self.EmptyRequest)
+        async def my_tool(request, ctx=None):
+            return "ok"
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            result = asyncio.run(my_tool())
+        assert result == "ok"
+
+    def test_flattened_inherited_fields(self):
+        """Inherited fields from base model should appear in the signature."""
+
+        @parse_request(self.ChildRequest)
+        async def my_tool(request, ctx=None):
+            return f"{request.use_cache}:{request.page}"
+
+        sig = inspect.signature(my_tool)
+        assert "use_cache" in sig.parameters
+        assert "page" in sig.parameters
+        assert sig.parameters["use_cache"].default is True
+        assert sig.parameters["page"].default == 1
+
+    def test_flattened_inherited_fields_callable(self):
+        """Tool with inherited fields should be callable."""
+        from unittest.mock import MagicMock, patch
+
+        @parse_request(self.ChildRequest)
+        async def my_tool(request, ctx=None):
+            return f"{request.use_cache}:{request.page}"
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            result = asyncio.run(my_tool(use_cache=False, page=3))
+        assert result == "False:3"
+
+    def test_flattened_default_factory_fields(self):
+        """Fields with default_factory should get the factory result as default."""
+
+        @parse_request(self.ListDefaultRequest)
+        def my_tool(request, ctx=None):
+            return f"{request.items}:{request.label}"
+
+        sig = inspect.signature(my_tool)
+        assert sig.parameters["items"].default == []
+        assert sig.parameters["label"].default == "default"
+
+    def test_flattened_preserves_docstring(self):
+        """Flattened wrapper should preserve the original function's docstring."""
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            """My tool docstring."""
+            pass
+
+        assert my_tool.__doc__ == "My tool docstring."
+
+    def test_flattened_preserves_function_name(self):
+        """Flattened wrapper should preserve the original function's name."""
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        assert my_tool.__name__ == "my_tool"
+
+    def test_flattened_params_are_keyword_only(self):
+        """All parameters should be keyword-only."""
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        sig = inspect.signature(my_tool)
+        for param in sig.parameters.values():
+            assert param.kind == inspect.Parameter.KEYWORD_ONLY
+
+    def test_flattened_validation_error(self):
+        """Pydantic validation errors should propagate when constructing model."""
+        from unittest.mock import MagicMock, patch
+
+        @parse_request(self.SimpleRequest)
+        def my_tool(request, ctx=None):
+            return "ok"
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            with pytest.raises(ValidationError):
+                my_tool(name="test", count="not_a_number")
+
+    def test_flattened_annotations_have_descriptions(self):
+        """Annotations should include Field descriptions for FastMCP schema."""
+        from typing import get_args, get_origin
+
+        @parse_request(self.SimpleRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        sig = inspect.signature(my_tool)
+        name_ann = sig.parameters["name"].annotation
+        # Should be Annotated[str, Field(description="The name")]
+        assert get_origin(name_ann) is Annotated
+        descriptions = [
+            getattr(marker, "description", None) for marker in get_args(name_ann)[1:]
+        ]
+        assert "The name" in descriptions
+
+    def test_flattened_annotations_forward_constraints(self):
+        """Constraint metadata (ge, le, min_length, pattern, etc.) should be
+        forwarded into the Annotated type so FastMCP includes them in the
+        JSON schema."""
+        from typing import get_args, get_origin
+
+        from annotated_types import Ge, Le, MaxLen, MinLen
+
+        class ConstrainedRequest(BaseModel):
+            page: int = Field(1, description="Page number", ge=1, le=1000)
+            name: str = Field(..., description="Name", min_length=1, max_length=255)
+
+        @parse_request(ConstrainedRequest)
+        async def my_tool(request, ctx=None):
+            pass
+
+        sig = inspect.signature(my_tool)
+
+        # Check page param has ge/le constraints in metadata
+        page_ann = sig.parameters["page"].annotation
+        assert get_origin(page_ann) is Annotated
+        page_metadata = get_args(page_ann)[1:]  # skip base type
+        metadata_types = [type(m) for m in page_metadata]
+        assert Ge in metadata_types, f"Missing Ge in {page_metadata}"
+        assert Le in metadata_types, f"Missing Le in {page_metadata}"
+
+        # Check name param has min_length/max_length constraints
+        name_ann = sig.parameters["name"].annotation
+        assert get_origin(name_ann) is Annotated
+        name_metadata = get_args(name_ann)[1:]
+        metadata_types = [type(m) for m in name_metadata]
+        assert MinLen in metadata_types, f"Missing MinLen in {name_metadata}"
+        assert MaxLen in metadata_types, f"Missing MaxLen in {name_metadata}"
+
+    def test_flattened_uses_alias_as_param_name(self):
+        """When a field has an alias, the flattened param name should be the alias,
+        not the Python field name. This ensures model_validate() works for models
+        without populate_by_name=True."""
+
+        class AliasedRequest(BaseModel):
+            schema_name: str | None = Field(None, description="Schema", alias="schema")
+            database_id: int = Field(..., description="DB ID")
+
+        @parse_request(AliasedRequest)
+        async def my_tool(request, ctx=None):
+            return {"schema": request.schema_name, "db": request.database_id}
+
+        sig = inspect.signature(my_tool)
+        param_names = list(sig.parameters.keys())
+        # alias "schema" should be used instead of field name "schema_name"
+        assert "schema" in param_names
+        assert "schema_name" not in param_names
+        # non-aliased field keeps its name
+        assert "database_id" in param_names
+
+    def test_flattened_alias_callable(self):
+        """Flattened wrapper with aliased fields should construct model correctly."""
+        from unittest.mock import MagicMock, patch
+
+        class AliasedRequest(BaseModel):
+            schema_name: str | None = Field(None, description="Schema", alias="schema")
+            database_id: int = Field(..., description="DB ID")
+
+        @parse_request(AliasedRequest)
+        def my_tool(request, ctx=None):
+            return {"schema": request.schema_name, "db": request.database_id}
+
+        mock_ctx = MagicMock()
+        with patch("fastmcp.server.dependencies.get_context", return_value=mock_ctx):
+            result = my_tool(schema="public", database_id=1)
+        assert result == {"schema": "public", "db": 1}