fix(mcp): unwrap ToolResult payload before truncation in ResponseSizeGuardMiddleware (#39578)

Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
Amin Ghadersohi
2026-04-23 12:35:13 -04:00
committed by GitHub
parent 9b52110ab1
commit b1b6a057d8
2 changed files with 359 additions and 2 deletions

View File

@@ -19,6 +19,7 @@
Unit tests for MCP service middleware.
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -378,6 +379,282 @@ class TestCreateResponseSizeGuardMiddleware:
assert middleware is None
class TestExtractPayloadFromToolResult:
"""Tests for _extract_payload_from_tool_result static method."""
def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
return ToolResult(
content=[TextContent(type="text", text=text)],
meta=meta,
)
def test_extracts_dict_payload(self) -> None:
"""Should parse the JSON text inside content[0] and return the dict."""
from superset.utils import json
payload = {"id": 1, "name": "test", "charts": [1, 2, 3]}
tool_result = self._make_tool_result(json.dumps(payload))
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is not None
assert result == payload
def test_returns_none_for_plain_dict(self) -> None:
"""Should return None when given a plain dict (not a ToolResult)."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
{"key": "val"}
)
is None
)
def test_returns_none_for_string(self) -> None:
"""Should return None when given a plain string."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string")
is None
)
def test_returns_none_for_list(self) -> None:
"""Should return None when given a plain list."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3])
is None
)
def test_returns_none_for_none(self) -> None:
"""Should return None when given None."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None
)
def test_returns_none_when_payload_is_list(self) -> None:
"""Should return None when JSON payload is a list, not a dict."""
from superset.utils import json
tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_invalid_json(self) -> None:
"""Should return None when content text is not valid JSON."""
tool_result = self._make_tool_result("not valid {{{json")
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_empty_text(self) -> None:
"""Should return None when content[0].text is empty."""
tool_result = self._make_tool_result("")
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_empty_content(self) -> None:
"""Should return None when ToolResult has no content items."""
from fastmcp.tools.tool import ToolResult
tool_result = ToolResult(content=[], meta=None)
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
class TestRewrapAsToolResult:
"""Tests for _rewrap_as_tool_result static method."""
def _make_tool_result(
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
from superset.utils import json
return ToolResult(
content=[TextContent(type="text", text=json.dumps(payload))],
meta=meta,
)
def test_returns_tool_result_with_serialized_payload(self) -> None:
"""Should return a ToolResult whose content[0].text is the JSON payload."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
original = self._make_tool_result({"old": "data"})
new_payload = {"id": 1, "name": "truncated", "_response_truncated": True}
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
new_payload, original
)
assert isinstance(result, ToolResult)
assert result.content[0].type == "text"
reparsed = json.loads(result.content[0].text)
assert reparsed == new_payload
def test_preserves_meta_from_original_tool_result(self) -> None:
"""Should copy meta from the original ToolResult."""
from fastmcp.tools.tool import ToolResult
meta = {"request_id": "abc-123", "trace": "xyz"}
original = self._make_tool_result({"key": "val"}, meta=meta)
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
{"key": "val"}, original
)
assert isinstance(result, ToolResult)
assert result.meta == meta
def test_sets_meta_none_for_non_tool_result_original(self) -> None:
"""Should set meta=None when original is not a ToolResult."""
from fastmcp.tools.tool import ToolResult
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
{"key": "val"}, {"not": "a ToolResult"}
)
assert isinstance(result, ToolResult)
assert result.meta is None
class TestToolResultWrapping:
"""Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool."""
def _make_tool_result(
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
from superset.utils import json
return ToolResult(
content=[TextContent(type="text", text=json.dumps(payload))],
meta=meta,
)
@pytest.mark.asyncio
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dataset_info"
context.message.params = {}
large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000}
tool_result = self._make_tool_result(large_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
# Must return a ToolResult, not a raw dict
assert isinstance(result, ToolResult)
reparsed = json.loads(result.content[0].text)
assert reparsed["id"] == 1
assert reparsed["_response_truncated"] is True
assert "[truncated" in reparsed["description"]
@pytest.mark.asyncio
async def test_small_tool_result_passes_through_unchanged(self) -> None:
"""Should return the original ToolResult when within the token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
context = MagicMock()
context.message.name = "get_chart_info"
context.message.params = {}
small_payload = {"id": 1, "name": "My Chart"}
tool_result = self._make_tool_result(small_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert result is tool_result
@pytest.mark.asyncio
async def test_large_non_info_tool_result_is_blocked(self) -> None:
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
large_payload = {
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
}
tool_result = self._make_tool_result(large_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
pytest.raises(ToolError),
):
await middleware.on_call_tool(context, call_next)
@pytest.mark.asyncio
async def test_meta_preserved_after_truncation(self) -> None:
"""Should preserve the original ToolResult meta through truncation."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dashboard_info"
context.message.params = {}
meta = {"request_id": "abc-123"}
large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
tool_result = self._make_tool_result(large_payload, meta=meta)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert isinstance(result, ToolResult)
assert result.meta == meta
reparsed = json.loads(result.content[0].text)
assert reparsed["_response_truncated"] is True
class TestMiddlewareIntegration:
"""Integration tests for middleware behavior."""