mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
fix(mcp): unwrap ToolResult payload before truncation in ResponseSizeGuardMiddleware (#39578)
Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@
|
||||
Unit tests for MCP service middleware.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -378,6 +379,282 @@ class TestCreateResponseSizeGuardMiddleware:
|
||||
assert middleware is None
|
||||
|
||||
|
||||
class TestExtractPayloadFromToolResult:
|
||||
"""Tests for _extract_payload_from_tool_result static method."""
|
||||
|
||||
def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=text)],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def test_extracts_dict_payload(self) -> None:
|
||||
"""Should parse the JSON text inside content[0] and return the dict."""
|
||||
from superset.utils import json
|
||||
|
||||
payload = {"id": 1, "name": "test", "charts": [1, 2, 3]}
|
||||
tool_result = self._make_tool_result(json.dumps(payload))
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result == payload
|
||||
|
||||
def test_returns_none_for_plain_dict(self) -> None:
|
||||
"""Should return None when given a plain dict (not a ToolResult)."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
{"key": "val"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_string(self) -> None:
|
||||
"""Should return None when given a plain string."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string")
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_list(self) -> None:
|
||||
"""Should return None when given a plain list."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3])
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_none(self) -> None:
|
||||
"""Should return None when given None."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_payload_is_list(self) -> None:
|
||||
"""Should return None when JSON payload is a list, not a dict."""
|
||||
from superset.utils import json
|
||||
|
||||
tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self) -> None:
|
||||
"""Should return None when content text is not valid JSON."""
|
||||
tool_result = self._make_tool_result("not valid {{{json")
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_text(self) -> None:
|
||||
"""Should return None when content[0].text is empty."""
|
||||
tool_result = self._make_tool_result("")
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_content(self) -> None:
|
||||
"""Should return None when ToolResult has no content items."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
tool_result = ToolResult(content=[], meta=None)
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRewrapAsToolResult:
|
||||
"""Tests for _rewrap_as_tool_result static method."""
|
||||
|
||||
def _make_tool_result(
|
||||
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=json.dumps(payload))],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def test_returns_tool_result_with_serialized_payload(self) -> None:
|
||||
"""Should return a ToolResult whose content[0].text is the JSON payload."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
original = self._make_tool_result({"old": "data"})
|
||||
new_payload = {"id": 1, "name": "truncated", "_response_truncated": True}
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
new_payload, original
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.content[0].type == "text"
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed == new_payload
|
||||
|
||||
def test_preserves_meta_from_original_tool_result(self) -> None:
|
||||
"""Should copy meta from the original ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
meta = {"request_id": "abc-123", "trace": "xyz"}
|
||||
original = self._make_tool_result({"key": "val"}, meta=meta)
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
{"key": "val"}, original
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta == meta
|
||||
|
||||
def test_sets_meta_none_for_non_tool_result_original(self) -> None:
|
||||
"""Should set meta=None when original is not a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
{"key": "val"}, {"not": "a ToolResult"}
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta is None
|
||||
|
||||
|
||||
class TestToolResultWrapping:
|
||||
"""Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool."""
|
||||
|
||||
def _make_tool_result(
|
||||
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=json.dumps(payload))],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
|
||||
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_dataset_info"
|
||||
context.message.params = {}
|
||||
|
||||
large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000}
|
||||
tool_result = self._make_tool_result(large_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
# Must return a ToolResult, not a raw dict
|
||||
assert isinstance(result, ToolResult)
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed["id"] == 1
|
||||
assert reparsed["_response_truncated"] is True
|
||||
assert "[truncated" in reparsed["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_tool_result_passes_through_unchanged(self) -> None:
|
||||
"""Should return the original ToolResult when within the token limit."""
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_chart_info"
|
||||
context.message.params = {}
|
||||
|
||||
small_payload = {"id": 1, "name": "My Chart"}
|
||||
tool_result = self._make_tool_result(small_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
assert result is tool_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_non_info_tool_result_is_blocked(self) -> None:
|
||||
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
context = MagicMock()
|
||||
context.message.name = "list_charts"
|
||||
context.message.params = {}
|
||||
|
||||
large_payload = {
|
||||
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
|
||||
}
|
||||
tool_result = self._make_tool_result(large_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
pytest.raises(ToolError),
|
||||
):
|
||||
await middleware.on_call_tool(context, call_next)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_preserved_after_truncation(self) -> None:
|
||||
"""Should preserve the original ToolResult meta through truncation."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_dashboard_info"
|
||||
context.message.params = {}
|
||||
|
||||
meta = {"request_id": "abc-123"}
|
||||
large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
|
||||
tool_result = self._make_tool_result(large_payload, meta=meta)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta == meta
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed["_response_truncated"] is True
|
||||
|
||||
|
||||
class TestMiddlewareIntegration:
|
||||
"""Integration tests for middleware behavior."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user