Compare commits

...

3 Commits

Author SHA1 Message Date
Elizabeth Thompson
c10f75b4ab fix(mcp): simplify _extract_payload_from_tool_result return type and add fallback logging
- Return dict | None instead of tuple[dict, bool] — the bool was always
  True and discarded at every call site
- Add logger.debug when extraction falls back to the raw response object,
  so operators can diagnose why an info tool was blocked instead of truncated

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-23 01:36:06 +00:00
Elizabeth Thompson
131c63a942 test(mcp): add unit tests for _extract_payload_from_tool_result and _rewrap_as_tool_result
Covers the four gaps identified in review:
- _extract_payload_from_tool_result: dict payload, non-ToolResult inputs,
  list payload (non-dict), invalid JSON, empty content, empty text
- _rewrap_as_tool_result: round-trip, meta preservation, non-ToolResult original
- Integration: info ToolResult truncated and returned as ToolResult (not dict),
  small ToolResult passes through unchanged, non-info ToolResult blocked,
  meta preserved through truncation

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-23 00:55:42 +00:00
Amin Ghadersohi
5348b92e3a fix(mcp): unwrap ToolResult payload before truncation in ResponseSizeGuardMiddleware
FastMCP converts tool return values into ToolResult objects before
middleware sees them. The actual data (e.g. DashboardInfo) is serialized
as a JSON string inside content[0].text. The ResponseSizeGuardMiddleware
was operating on the ToolResult wrapper instead of the actual payload,
causing two problems:

1. Token estimation was double-serializing (the JSON string inside text
   gets escaped again), producing inflated estimates
2. Truncation phases (truncate charts list, truncate strings, etc.)
   could not find the right keys because they were looking at the
   ToolResult structure, not the dashboard/chart/dataset data

This caused get_dashboard_info to produce broken truncated responses
for dashboards with many charts — the middleware would char-truncate
content[0].text mid-JSON instead of intelligently reducing the payload.

The fix extracts the payload from content[0].text, parses it back to a
dict, runs the 5-phase truncation on the actual data, then re-wraps the
result into a ToolResult.
2026-04-23 00:54:24 +00:00
2 changed files with 359 additions and 2 deletions

View File

@@ -951,6 +951,59 @@ class ResponseSizeGuardMiddleware(Middleware):
excluded_tools = [excluded_tools]
self.excluded_tools = set(excluded_tools or [])
@staticmethod
def _extract_payload_from_tool_result(
    response: Any,
) -> dict[str, Any] | None:
    """Return the dict encoded as JSON in a ToolResult's first content item.

    The middleware receives tool output already wrapped in a ``ToolResult``
    whose ``content[0].text`` holds the JSON-serialized payload. Size
    checks and truncation need the parsed dict rather than the wrapper,
    so this helper digs it out.

    Args:
        response: The object handed to the middleware; may or may not be a
            ``ToolResult``.

    Returns:
        The parsed payload when ``response`` is a ``ToolResult`` whose first
        content item carries non-empty text that decodes to a JSON object.
        ``None`` in every other case (not a ``ToolResult``, no content, no
        ``text`` attribute, empty text, undecodable text, or a decoded value
        that is not a dict).
    """
    from fastmcp.tools.tool import ToolResult
    from superset.utils.json import loads as json_loads

    if not isinstance(response, ToolResult):
        return None

    # Guard clauses, checked in order: content present, first item has a
    # ``text`` attribute, and that text is non-empty.
    items = response.content
    if not items:
        return None
    first_item = items[0]
    if not hasattr(first_item, "text"):
        return None
    raw_text = first_item.text
    if not raw_text:
        return None

    try:
        parsed = json_loads(raw_text)
    except (ValueError, TypeError):
        return None

    # Only JSON objects are usable by the caller; lists/scalars are rejected.
    return parsed if isinstance(parsed, dict) else None
@staticmethod
def _rewrap_as_tool_result(payload: dict[str, Any], original: Any) -> Any:
    """Wrap ``payload`` in a new ToolResult as a single JSON text item.

    Args:
        payload: The dict to serialize into the new result's text content.
        original: The response the payload came from. When it is a
            ``ToolResult`` its ``meta`` is carried over to the new result;
            otherwise the new result gets ``meta=None``.

    Returns:
        A ``ToolResult`` whose only content item is a ``TextContent`` holding
        the JSON serialization of ``payload``.
    """
    from fastmcp.tools.tool import ToolResult
    from mcp.types import TextContent
    from superset.utils.json import dumps as json_dumps

    serialized = json_dumps(payload)

    if isinstance(original, ToolResult):
        carried_meta = original.meta
    else:
        carried_meta = None

    text_item = TextContent(type="text", text=serialized)
    return ToolResult(content=[text_item], meta=carried_meta)
def _try_truncate_info_response(
self,
tool_name: str,
@@ -960,15 +1013,32 @@ class ResponseSizeGuardMiddleware(Middleware):
"""Attempt to dynamically truncate an info tool response to fit the limit.
Returns the truncated response if successful, None otherwise.
When the response is a ToolResult (the normal case — FastMCP wraps
every tool return value), the actual data lives inside
``content[0].text`` as a JSON string. We parse that string, run the
truncation phases on the resulting dict, then re-wrap the result.
"""
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
truncate_oversized_response,
)
# Unwrap ToolResult so truncation operates on the real payload
extracted = self._extract_payload_from_tool_result(response)
if extracted is not None:
truncation_target = extracted
else:
logger.debug(
"Could not extract dict payload from response for %s; "
"falling back to truncating the raw response object",
tool_name,
)
truncation_target = response
try:
truncated, was_truncated, notes = truncate_oversized_response(
response, self.token_limit
truncation_target, self.token_limit
)
except (MemoryError, RecursionError) as trunc_error:
logger.warning(
@@ -1015,6 +1085,10 @@ class ResponseSizeGuardMiddleware(Middleware):
truncated["_response_truncated"] = True
truncated["_truncation_notes"] = notes
# Re-wrap into ToolResult if we unwrapped one
if extracted is not None and isinstance(truncated, dict):
return self._rewrap_as_tool_result(truncated, response)
return truncated
async def on_call_tool(
@@ -1038,8 +1112,14 @@ class ResponseSizeGuardMiddleware(Middleware):
format_size_limit_error,
)
# When the response is a ToolResult, estimate tokens on the actual
# payload inside content[0].text rather than on the ToolResult
# wrapper (which would double-serialize the JSON string).
extracted = self._extract_payload_from_tool_result(response)
estimation_target = extracted if extracted is not None else response
try:
estimated_tokens = estimate_response_tokens(response)
estimated_tokens = estimate_response_tokens(estimation_target)
except MemoryError as me:
logger.warning(
"MemoryError while estimating tokens for %s: %s", tool_name, me

View File

@@ -19,6 +19,7 @@
Unit tests for MCP service middleware.
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@@ -378,6 +379,282 @@ class TestCreateResponseSizeGuardMiddleware:
assert middleware is None
class TestExtractPayloadFromToolResult:
    """Exercise the ``_extract_payload_from_tool_result`` static method."""

    def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
        """Build a ToolResult holding ``text`` as its single text content item."""
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent

        text_item = TextContent(type="text", text=text)
        return ToolResult(content=[text_item], meta=meta)

    def test_extracts_dict_payload(self) -> None:
        """A JSON object in content[0].text comes back as the equivalent dict."""
        from superset.utils import json

        expected = {"id": 1, "name": "test", "charts": [1, 2, 3]}
        wrapped = self._make_tool_result(json.dumps(expected))

        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        extracted = extract(wrapped)

        assert extracted is not None
        assert extracted == expected

    def test_returns_none_for_plain_dict(self) -> None:
        """A bare dict is not a ToolResult, so nothing is extracted."""
        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract({"key": "val"}) is None

    def test_returns_none_for_string(self) -> None:
        """A bare string is not a ToolResult, so nothing is extracted."""
        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract("string") is None

    def test_returns_none_for_list(self) -> None:
        """A bare list is not a ToolResult, so nothing is extracted."""
        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract([1, 2, 3]) is None

    def test_returns_none_for_none(self) -> None:
        """``None`` input yields ``None``."""
        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract(None) is None

    def test_returns_none_when_payload_is_list(self) -> None:
        """JSON that decodes to a list (not an object) is rejected."""
        from superset.utils import json

        wrapped = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))

        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract(wrapped) is None

    def test_returns_none_for_invalid_json(self) -> None:
        """Text that is not parseable JSON is rejected."""
        wrapped = self._make_tool_result("not valid {{{json")

        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract(wrapped) is None

    def test_returns_none_for_empty_text(self) -> None:
        """An empty content[0].text is rejected."""
        wrapped = self._make_tool_result("")

        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract(wrapped) is None

    def test_returns_none_for_empty_content(self) -> None:
        """A ToolResult with no content items is rejected."""
        from fastmcp.tools.tool import ToolResult

        contentless = ToolResult(content=[], meta=None)

        extract = ResponseSizeGuardMiddleware._extract_payload_from_tool_result
        assert extract(contentless) is None
class TestRewrapAsToolResult:
    """Exercise the ``_rewrap_as_tool_result`` static method."""

    def _make_tool_result(
        self, payload: dict[str, Any], meta: dict[str, Any] | None = None
    ) -> Any:
        """Build a ToolResult whose single text item is ``payload`` as JSON."""
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent
        from superset.utils import json

        text_item = TextContent(type="text", text=json.dumps(payload))
        return ToolResult(content=[text_item], meta=meta)

    def test_returns_tool_result_with_serialized_payload(self) -> None:
        """The new ToolResult's content[0].text round-trips to the payload."""
        from fastmcp.tools.tool import ToolResult
        from superset.utils import json

        source_result = self._make_tool_result({"old": "data"})
        replacement = {"id": 1, "name": "truncated", "_response_truncated": True}

        rewrap = ResponseSizeGuardMiddleware._rewrap_as_tool_result
        rewrapped = rewrap(replacement, source_result)

        assert isinstance(rewrapped, ToolResult)
        assert rewrapped.content[0].type == "text"
        decoded = json.loads(rewrapped.content[0].text)
        assert decoded == replacement

    def test_preserves_meta_from_original_tool_result(self) -> None:
        """Meta on the original ToolResult is carried onto the new one."""
        from fastmcp.tools.tool import ToolResult

        expected_meta = {"request_id": "abc-123", "trace": "xyz"}
        source_result = self._make_tool_result({"key": "val"}, meta=expected_meta)

        rewrap = ResponseSizeGuardMiddleware._rewrap_as_tool_result
        rewrapped = rewrap({"key": "val"}, source_result)

        assert isinstance(rewrapped, ToolResult)
        assert rewrapped.meta == expected_meta

    def test_sets_meta_none_for_non_tool_result_original(self) -> None:
        """When the original is not a ToolResult, the new meta is ``None``."""
        from fastmcp.tools.tool import ToolResult

        rewrap = ResponseSizeGuardMiddleware._rewrap_as_tool_result
        rewrapped = rewrap({"key": "val"}, {"not": "a ToolResult"})

        assert isinstance(rewrapped, ToolResult)
        assert rewrapped.meta is None
class TestToolResultWrapping:
    """End-to-end checks of ToolResult handling through ``on_call_tool``."""

    def _make_tool_result(
        self, payload: dict[str, Any], meta: dict[str, Any] | None = None
    ) -> Any:
        """Build a ToolResult whose single text item is ``payload`` as JSON."""
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent
        from superset.utils import json

        text_item = TextContent(type="text", text=json.dumps(payload))
        return ToolResult(content=[text_item], meta=meta)

    @staticmethod
    def _context_for(tool_name: str) -> MagicMock:
        """Return a mock middleware context for a call to ``tool_name``."""
        ctx = MagicMock()
        ctx.message.name = tool_name
        ctx.message.params = {}
        return ctx

    @pytest.mark.asyncio
    async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
        """An oversized info-tool ToolResult is truncated and stays a ToolResult."""
        from fastmcp.tools.tool import ToolResult
        from superset.utils import json

        guard = ResponseSizeGuardMiddleware(token_limit=500)
        ctx = self._context_for("get_dataset_info")
        oversized = {"id": 1, "table_name": "test", "description": "x" * 50000}
        wrapped = self._make_tool_result(oversized)
        downstream = AsyncMock(return_value=wrapped)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            outcome = await guard.on_call_tool(ctx, downstream)

        # The middleware must hand back a ToolResult, never the bare dict.
        assert isinstance(outcome, ToolResult)
        decoded = json.loads(outcome.content[0].text)
        assert decoded["id"] == 1
        assert decoded["_response_truncated"] is True
        assert "[truncated" in decoded["description"]

    @pytest.mark.asyncio
    async def test_small_tool_result_passes_through_unchanged(self) -> None:
        """A ToolResult within the token limit is returned as the same object."""
        guard = ResponseSizeGuardMiddleware(token_limit=25000)
        ctx = self._context_for("get_chart_info")
        compact = {"id": 1, "name": "My Chart"}
        wrapped = self._make_tool_result(compact)
        downstream = AsyncMock(return_value=wrapped)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            outcome = await guard.on_call_tool(ctx, downstream)

        assert outcome is wrapped

    @pytest.mark.asyncio
    async def test_large_non_info_tool_result_is_blocked(self) -> None:
        """An oversized ToolResult from a non-info tool raises ToolError."""
        guard = ResponseSizeGuardMiddleware(token_limit=100)
        ctx = self._context_for("list_charts")
        oversized = {
            "charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
        }
        wrapped = self._make_tool_result(oversized)
        downstream = AsyncMock(return_value=wrapped)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError),
        ):
            await guard.on_call_tool(ctx, downstream)

    @pytest.mark.asyncio
    async def test_meta_preserved_after_truncation(self) -> None:
        """Meta on the original ToolResult survives the truncation path."""
        from fastmcp.tools.tool import ToolResult
        from superset.utils import json

        guard = ResponseSizeGuardMiddleware(token_limit=500)
        ctx = self._context_for("get_dashboard_info")
        expected_meta = {"request_id": "abc-123"}
        oversized = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
        wrapped = self._make_tool_result(oversized, meta=expected_meta)
        downstream = AsyncMock(return_value=wrapped)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            outcome = await guard.on_call_tool(ctx, downstream)

        assert isinstance(outcome, ToolResult)
        assert outcome.meta == expected_meta
        decoded = json.loads(outcome.content[0].text)
        assert decoded["_response_truncated"] is True
class TestMiddlewareIntegration:
"""Integration tests for middleware behavior."""