mirror of
https://github.com/apache/superset.git
synced 2026-05-17 13:55:15 +00:00
Compare commits
3 Commits
fix/mcp-ex
...
superset-p
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c10f75b4ab | ||
|
|
131c63a942 | ||
|
|
5348b92e3a |
@@ -951,6 +951,59 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
excluded_tools = [excluded_tools]
|
||||
self.excluded_tools = set(excluded_tools or [])
|
||||
|
||||
@staticmethod
|
||||
def _extract_payload_from_tool_result(
|
||||
response: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract the JSON payload dict from a ToolResult's content[0].text.
|
||||
|
||||
FastMCP converts tool return values into ToolResult before middleware
|
||||
sees them. The actual data (e.g. DashboardInfo dict) is serialized
|
||||
as a JSON string inside ``content[0].text``. Truncation must operate
|
||||
on that parsed dict — not on the ToolResult wrapper — otherwise
|
||||
phases like "truncate charts list" never find the right keys.
|
||||
|
||||
Returns the payload dict when extraction succeeds, or ``None`` when
|
||||
the response is not a ToolResult or cannot be parsed.
|
||||
"""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils.json import loads as json_loads
|
||||
|
||||
if not isinstance(response, ToolResult):
|
||||
return None
|
||||
|
||||
if (
|
||||
not response.content
|
||||
or not hasattr(response.content[0], "text")
|
||||
or not response.content[0].text
|
||||
):
|
||||
return None
|
||||
|
||||
try:
|
||||
payload = json_loads(response.content[0].text)
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
if not isinstance(payload, dict):
|
||||
return None
|
||||
|
||||
return payload
|
||||
|
||||
@staticmethod
|
||||
def _rewrap_as_tool_result(payload: dict[str, Any], original: Any) -> Any:
|
||||
"""Re-serialize a truncated payload dict back into a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
from superset.utils.json import dumps as json_dumps
|
||||
|
||||
text = json_dumps(payload)
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=text)],
|
||||
meta=original.meta if isinstance(original, ToolResult) else None,
|
||||
)
|
||||
|
||||
def _try_truncate_info_response(
|
||||
self,
|
||||
tool_name: str,
|
||||
@@ -960,15 +1013,32 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
"""Attempt to dynamically truncate an info tool response to fit the limit.
|
||||
|
||||
Returns the truncated response if successful, None otherwise.
|
||||
|
||||
When the response is a ToolResult (the normal case — FastMCP wraps
|
||||
every tool return value), the actual data lives inside
|
||||
``content[0].text`` as a JSON string. We parse that string, run the
|
||||
truncation phases on the resulting dict, then re-wrap the result.
|
||||
"""
|
||||
from superset.mcp_service.utils.token_utils import (
|
||||
estimate_response_tokens,
|
||||
truncate_oversized_response,
|
||||
)
|
||||
|
||||
# Unwrap ToolResult so truncation operates on the real payload
|
||||
extracted = self._extract_payload_from_tool_result(response)
|
||||
if extracted is not None:
|
||||
truncation_target = extracted
|
||||
else:
|
||||
logger.debug(
|
||||
"Could not extract dict payload from response for %s; "
|
||||
"falling back to truncating the raw response object",
|
||||
tool_name,
|
||||
)
|
||||
truncation_target = response
|
||||
|
||||
try:
|
||||
truncated, was_truncated, notes = truncate_oversized_response(
|
||||
response, self.token_limit
|
||||
truncation_target, self.token_limit
|
||||
)
|
||||
except (MemoryError, RecursionError) as trunc_error:
|
||||
logger.warning(
|
||||
@@ -1015,6 +1085,10 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
truncated["_response_truncated"] = True
|
||||
truncated["_truncation_notes"] = notes
|
||||
|
||||
# Re-wrap into ToolResult if we unwrapped one
|
||||
if extracted is not None and isinstance(truncated, dict):
|
||||
return self._rewrap_as_tool_result(truncated, response)
|
||||
|
||||
return truncated
|
||||
|
||||
async def on_call_tool(
|
||||
@@ -1038,8 +1112,14 @@ class ResponseSizeGuardMiddleware(Middleware):
|
||||
format_size_limit_error,
|
||||
)
|
||||
|
||||
# When the response is a ToolResult, estimate tokens on the actual
|
||||
# payload inside content[0].text rather than on the ToolResult
|
||||
# wrapper (which would double-serialize the JSON string).
|
||||
extracted = self._extract_payload_from_tool_result(response)
|
||||
estimation_target = extracted if extracted is not None else response
|
||||
|
||||
try:
|
||||
estimated_tokens = estimate_response_tokens(response)
|
||||
estimated_tokens = estimate_response_tokens(estimation_target)
|
||||
except MemoryError as me:
|
||||
logger.warning(
|
||||
"MemoryError while estimating tokens for %s: %s", tool_name, me
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
Unit tests for MCP service middleware.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -378,6 +379,282 @@ class TestCreateResponseSizeGuardMiddleware:
|
||||
assert middleware is None
|
||||
|
||||
|
||||
class TestExtractPayloadFromToolResult:
|
||||
"""Tests for _extract_payload_from_tool_result static method."""
|
||||
|
||||
def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=text)],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def test_extracts_dict_payload(self) -> None:
|
||||
"""Should parse the JSON text inside content[0] and return the dict."""
|
||||
from superset.utils import json
|
||||
|
||||
payload = {"id": 1, "name": "test", "charts": [1, 2, 3]}
|
||||
tool_result = self._make_tool_result(json.dumps(payload))
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result == payload
|
||||
|
||||
def test_returns_none_for_plain_dict(self) -> None:
|
||||
"""Should return None when given a plain dict (not a ToolResult)."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
{"key": "val"}
|
||||
)
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_string(self) -> None:
|
||||
"""Should return None when given a plain string."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string")
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_list(self) -> None:
|
||||
"""Should return None when given a plain list."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3])
|
||||
is None
|
||||
)
|
||||
|
||||
def test_returns_none_for_none(self) -> None:
|
||||
"""Should return None when given None."""
|
||||
assert (
|
||||
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None
|
||||
)
|
||||
|
||||
def test_returns_none_when_payload_is_list(self) -> None:
|
||||
"""Should return None when JSON payload is a list, not a dict."""
|
||||
from superset.utils import json
|
||||
|
||||
tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_invalid_json(self) -> None:
|
||||
"""Should return None when content text is not valid JSON."""
|
||||
tool_result = self._make_tool_result("not valid {{{json")
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_text(self) -> None:
|
||||
"""Should return None when content[0].text is empty."""
|
||||
tool_result = self._make_tool_result("")
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_content(self) -> None:
|
||||
"""Should return None when ToolResult has no content items."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
tool_result = ToolResult(content=[], meta=None)
|
||||
|
||||
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
|
||||
tool_result
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestRewrapAsToolResult:
|
||||
"""Tests for _rewrap_as_tool_result static method."""
|
||||
|
||||
def _make_tool_result(
|
||||
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=json.dumps(payload))],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
def test_returns_tool_result_with_serialized_payload(self) -> None:
|
||||
"""Should return a ToolResult whose content[0].text is the JSON payload."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
original = self._make_tool_result({"old": "data"})
|
||||
new_payload = {"id": 1, "name": "truncated", "_response_truncated": True}
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
new_payload, original
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.content[0].type == "text"
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed == new_payload
|
||||
|
||||
def test_preserves_meta_from_original_tool_result(self) -> None:
|
||||
"""Should copy meta from the original ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
meta = {"request_id": "abc-123", "trace": "xyz"}
|
||||
original = self._make_tool_result({"key": "val"}, meta=meta)
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
{"key": "val"}, original
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta == meta
|
||||
|
||||
def test_sets_meta_none_for_non_tool_result_original(self) -> None:
|
||||
"""Should set meta=None when original is not a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
|
||||
{"key": "val"}, {"not": "a ToolResult"}
|
||||
)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta is None
|
||||
|
||||
|
||||
class TestToolResultWrapping:
|
||||
"""Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool."""
|
||||
|
||||
def _make_tool_result(
|
||||
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
|
||||
) -> Any:
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp.types import TextContent
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
return ToolResult(
|
||||
content=[TextContent(type="text", text=json.dumps(payload))],
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
|
||||
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_dataset_info"
|
||||
context.message.params = {}
|
||||
|
||||
large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000}
|
||||
tool_result = self._make_tool_result(large_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
# Must return a ToolResult, not a raw dict
|
||||
assert isinstance(result, ToolResult)
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed["id"] == 1
|
||||
assert reparsed["_response_truncated"] is True
|
||||
assert "[truncated" in reparsed["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_small_tool_result_passes_through_unchanged(self) -> None:
|
||||
"""Should return the original ToolResult when within the token limit."""
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_chart_info"
|
||||
context.message.params = {}
|
||||
|
||||
small_payload = {"id": 1, "name": "My Chart"}
|
||||
tool_result = self._make_tool_result(small_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
assert result is tool_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_large_non_info_tool_result_is_blocked(self) -> None:
|
||||
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
context = MagicMock()
|
||||
context.message.name = "list_charts"
|
||||
context.message.params = {}
|
||||
|
||||
large_payload = {
|
||||
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
|
||||
}
|
||||
tool_result = self._make_tool_result(large_payload)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
pytest.raises(ToolError),
|
||||
):
|
||||
await middleware.on_call_tool(context, call_next)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_meta_preserved_after_truncation(self) -> None:
|
||||
"""Should preserve the original ToolResult meta through truncation."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.utils import json
|
||||
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
context = MagicMock()
|
||||
context.message.name = "get_dashboard_info"
|
||||
context.message.params = {}
|
||||
|
||||
meta = {"request_id": "abc-123"}
|
||||
large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
|
||||
tool_result = self._make_tool_result(large_payload, meta=meta)
|
||||
call_next = AsyncMock(return_value=tool_result)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.meta == meta
|
||||
reparsed = json.loads(result.content[0].text)
|
||||
assert reparsed["_response_truncated"] is True
|
||||
|
||||
|
||||
class TestMiddlewareIntegration:
|
||||
"""Integration tests for middleware behavior."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user