From b1b6a057d88bc298743f1ea716af2f22cbdd8c2c Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 23 Apr 2026 12:35:13 -0400 Subject: [PATCH] fix(mcp): unwrap ToolResult payload before truncation in ResponseSizeGuardMiddleware (#39578) Co-authored-by: Elizabeth Thompson --- superset/mcp_service/middleware.py | 84 +++++- .../unit_tests/mcp_service/test_middleware.py | 277 ++++++++++++++++++ 2 files changed, 359 insertions(+), 2 deletions(-) diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 797bb70076c..af8c8c12bd7 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -951,6 +951,59 @@ class ResponseSizeGuardMiddleware(Middleware): excluded_tools = [excluded_tools] self.excluded_tools = set(excluded_tools or []) + @staticmethod + def _extract_payload_from_tool_result( + response: Any, + ) -> dict[str, Any] | None: + """Extract the JSON payload dict from a ToolResult's content[0].text. + + FastMCP converts tool return values into ToolResult before middleware + sees them. The actual data (e.g. DashboardInfo dict) is serialized + as a JSON string inside ``content[0].text``. Truncation must operate + on that parsed dict — not on the ToolResult wrapper — otherwise + phases like "truncate charts list" never find the right keys. + + Returns the payload dict when extraction succeeds, or ``None`` when + the response is not a ToolResult or cannot be parsed. + """ + from fastmcp.tools.tool import ToolResult + + from superset.utils.json import loads as json_loads + + if not isinstance(response, ToolResult): + return None + + if ( + not response.content + or not hasattr(response.content[0], "text") + or not response.content[0].text + ): + return None + + try: + payload = json_loads(response.content[0].text) + except (ValueError, TypeError): + return None + + if not isinstance(payload, dict): + return None + + return payload + + @staticmethod + def _rewrap_as_tool_result(payload: dict[str, Any], original: Any) -> Any: + """Re-serialize a truncated payload dict back into a ToolResult.""" + from fastmcp.tools.tool import ToolResult + from mcp.types import TextContent + + from superset.utils.json import dumps as json_dumps + + text = json_dumps(payload) + return ToolResult( + content=[TextContent(type="text", text=text)], + meta=original.meta if isinstance(original, ToolResult) else None, + ) + def _try_truncate_info_response( self, tool_name: str, @@ -960,15 +1013,32 @@ class ResponseSizeGuardMiddleware(Middleware): """Attempt to dynamically truncate an info tool response to fit the limit. Returns the truncated response if successful, None otherwise. + + When the response is a ToolResult (the normal case — FastMCP wraps + every tool return value), the actual data lives inside + ``content[0].text`` as a JSON string. We parse that string, run the + truncation phases on the resulting dict, then re-wrap the result. """ from superset.mcp_service.utils.token_utils import ( estimate_response_tokens, truncate_oversized_response, ) + # Unwrap ToolResult so truncation operates on the real payload + extracted = self._extract_payload_from_tool_result(response) + if extracted is not None: + truncation_target = extracted + else: + logger.debug( + "Could not extract dict payload from response for %s; " + "falling back to truncating the raw response object", + tool_name, + ) + truncation_target = response + try: truncated, was_truncated, notes = truncate_oversized_response( - response, self.token_limit + truncation_target, self.token_limit ) except (MemoryError, RecursionError) as trunc_error: logger.warning( @@ -1015,6 +1085,10 @@ class ResponseSizeGuardMiddleware(Middleware): truncated["_response_truncated"] = True truncated["_truncation_notes"] = notes + # Re-wrap into ToolResult if we unwrapped one + if extracted is not None and isinstance(truncated, dict): + return self._rewrap_as_tool_result(truncated, response) + return truncated async def on_call_tool( @@ -1038,8 +1112,14 @@ class ResponseSizeGuardMiddleware(Middleware): format_size_limit_error, ) + # When the response is a ToolResult, estimate tokens on the actual + # payload inside content[0].text rather than on the ToolResult + # wrapper (which would double-serialize the JSON string). + extracted = self._extract_payload_from_tool_result(response) + estimation_target = extracted if extracted is not None else response + try: - estimated_tokens = estimate_response_tokens(response) + estimated_tokens = estimate_response_tokens(estimation_target) except MemoryError as me: logger.warning( "MemoryError while estimating tokens for %s: %s", tool_name, me diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py index bcc164e048f..deaa01efed0 100644 --- a/tests/unit_tests/mcp_service/test_middleware.py +++ b/tests/unit_tests/mcp_service/test_middleware.py @@ -19,6 +19,7 @@ Unit tests for MCP service middleware. """ +from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -378,6 +379,282 @@ class TestCreateResponseSizeGuardMiddleware: assert middleware is None +class TestExtractPayloadFromToolResult: + """Tests for _extract_payload_from_tool_result static method.""" + + def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any: + from fastmcp.tools.tool import ToolResult + from mcp.types import TextContent + + return ToolResult( + content=[TextContent(type="text", text=text)], + meta=meta, + ) + + def test_extracts_dict_payload(self) -> None: + """Should parse the JSON text inside content[0] and return the dict.""" + from superset.utils import json + + payload = {"id": 1, "name": "test", "charts": [1, 2, 3]} + tool_result = self._make_tool_result(json.dumps(payload)) + + result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + tool_result + ) + + assert result is not None + assert result == payload + + def test_returns_none_for_plain_dict(self) -> None: + """Should return None when given a plain dict (not a ToolResult).""" + assert ( + ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + {"key": "val"} + ) + is None + ) + + def test_returns_none_for_string(self) -> None: + """Should return None when given a plain string.""" + assert ( + ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string") + is None + ) + + def test_returns_none_for_list(self) -> None: + """Should return None when given a plain list.""" + assert ( + ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3]) + is None + ) + + def test_returns_none_for_none(self) -> None: + """Should return None when given None.""" + assert ( + ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None + ) + + def test_returns_none_when_payload_is_list(self) -> None: + """Should return None when JSON payload is a list, not a dict.""" + from superset.utils import json + + tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}])) + + result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + tool_result + ) + + assert result is None + + def test_returns_none_for_invalid_json(self) -> None: + """Should return None when content text is not valid JSON.""" + tool_result = self._make_tool_result("not valid {{{json") + + result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + tool_result + ) + + assert result is None + + def test_returns_none_for_empty_text(self) -> None: + """Should return None when content[0].text is empty.""" + tool_result = self._make_tool_result("") + + result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + tool_result + ) + + assert result is None + + def test_returns_none_for_empty_content(self) -> None: + """Should return None when ToolResult has no content items.""" + from fastmcp.tools.tool import ToolResult + + tool_result = ToolResult(content=[], meta=None) + + result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result( + tool_result + ) + + assert result is None + + +class TestRewrapAsToolResult: + """Tests for _rewrap_as_tool_result static method.""" + + def _make_tool_result( + self, payload: dict[str, Any], meta: dict[str, Any] | None = None + ) -> Any: + from fastmcp.tools.tool import ToolResult + from mcp.types import TextContent + + from superset.utils import json + + return ToolResult( + content=[TextContent(type="text", text=json.dumps(payload))], + meta=meta, + ) + + def test_returns_tool_result_with_serialized_payload(self) -> None: + """Should return a ToolResult whose content[0].text is the JSON payload.""" + from fastmcp.tools.tool import ToolResult + + from superset.utils import json + + original = self._make_tool_result({"old": "data"}) + new_payload = {"id": 1, "name": "truncated", "_response_truncated": True} + + result = ResponseSizeGuardMiddleware._rewrap_as_tool_result( + new_payload, original + ) + + assert isinstance(result, ToolResult) + assert result.content[0].type == "text" + reparsed = json.loads(result.content[0].text) + assert reparsed == new_payload + + def test_preserves_meta_from_original_tool_result(self) -> None: + """Should copy meta from the original ToolResult.""" + from fastmcp.tools.tool import ToolResult + + meta = {"request_id": "abc-123", "trace": "xyz"} + original = self._make_tool_result({"key": "val"}, meta=meta) + + result = ResponseSizeGuardMiddleware._rewrap_as_tool_result( + {"key": "val"}, original + ) + + assert isinstance(result, ToolResult) + assert result.meta == meta + + def test_sets_meta_none_for_non_tool_result_original(self) -> None: + """Should set meta=None when original is not a ToolResult.""" + from fastmcp.tools.tool import ToolResult + + result = ResponseSizeGuardMiddleware._rewrap_as_tool_result( + {"key": "val"}, {"not": "a ToolResult"} + ) + + assert isinstance(result, ToolResult) + assert result.meta is None + + +class TestToolResultWrapping: + """Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool.""" + + def _make_tool_result( + self, payload: dict[str, Any], meta: dict[str, Any] | None = None + ) -> Any: + from fastmcp.tools.tool import ToolResult + from mcp.types import TextContent + + from superset.utils import json + + return ToolResult( + content=[TextContent(type="text", text=json.dumps(payload))], + meta=meta, + ) + + @pytest.mark.asyncio + async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None: + """Truncate a ToolResult-wrapped info response and return a ToolResult.""" + from fastmcp.tools.tool import ToolResult + + from superset.utils import json + + middleware = ResponseSizeGuardMiddleware(token_limit=500) + context = MagicMock() + context.message.name = "get_dataset_info" + context.message.params = {} + + large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000} + tool_result = self._make_tool_result(large_payload) + call_next = AsyncMock(return_value=tool_result) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + ): + result = await middleware.on_call_tool(context, call_next) + + # Must return a ToolResult, not a raw dict + assert isinstance(result, ToolResult) + reparsed = json.loads(result.content[0].text) + assert reparsed["id"] == 1 + assert reparsed["_response_truncated"] is True + assert "[truncated" in reparsed["description"] + + @pytest.mark.asyncio + async def test_small_tool_result_passes_through_unchanged(self) -> None: + """Should return the original ToolResult when within the token limit.""" + + middleware = ResponseSizeGuardMiddleware(token_limit=25000) + context = MagicMock() + context.message.name = "get_chart_info" + context.message.params = {} + + small_payload = {"id": 1, "name": "My Chart"} + tool_result = self._make_tool_result(small_payload) + call_next = AsyncMock(return_value=tool_result) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + ): + result = await middleware.on_call_tool(context, call_next) + + assert result is tool_result + + @pytest.mark.asyncio + async def test_large_non_info_tool_result_is_blocked(self) -> None: + """Should raise ToolError for a non-info ToolResult that exceeds the limit.""" + middleware = ResponseSizeGuardMiddleware(token_limit=100) + context = MagicMock() + context.message.name = "list_charts" + context.message.params = {} + + large_payload = { + "charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)] + } + tool_result = self._make_tool_result(large_payload) + call_next = AsyncMock(return_value=tool_result) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + pytest.raises(ToolError), + ): + await middleware.on_call_tool(context, call_next) + + @pytest.mark.asyncio + async def test_meta_preserved_after_truncation(self) -> None: + """Should preserve the original ToolResult meta through truncation.""" + from fastmcp.tools.tool import ToolResult + + from superset.utils import json + + middleware = ResponseSizeGuardMiddleware(token_limit=500) + context = MagicMock() + context.message.name = "get_dashboard_info" + context.message.params = {} + + meta = {"request_id": "abc-123"} + large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000} + tool_result = self._make_tool_result(large_payload, meta=meta) + call_next = AsyncMock(return_value=tool_result) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + ): + result = await middleware.on_call_tool(context, call_next) + + assert isinstance(result, ToolResult) + assert result.meta == meta + reparsed = json.loads(result.content[0].text) + assert reparsed["_response_truncated"] is True + + class TestMiddlewareIntegration: """Integration tests for middleware behavior."""