cover edge case

This commit is contained in:
alexandrusoare
2026-04-30 15:24:45 +03:00
parent 178f24a308
commit 5e3ca00bb2
2 changed files with 35 additions and 0 deletions

View File

@@ -241,6 +241,7 @@ class LoggingMiddleware(Middleware):
tool_name = getattr(context.message, "name", None)
mcp_call_id = uuid.uuid4().hex[:12]
context.mcp_call_id = mcp_call_id
start_time = time.time()
success = False
try:
@@ -396,8 +397,10 @@ class StructuredContentStripperMiddleware(Middleware):
# unhandled exception — including ToolError from
# GlobalErrorHandlerMiddleware, ValueError, TypeError, etc. —
# will cause encoding failures on the wire.
mcp_call_id = getattr(context, "mcp_call_id", None)
return ToolResult(
content=[mt.TextContent(type="text", text=f"Error: {e}")],
meta={"mcp_call_id": mcp_call_id} if mcp_call_id else None,
)
if isinstance(result, ToolResult) and result.structured_content is not None:
result = ToolResult(content=result.content, meta=result.meta)

View File

@@ -426,3 +426,35 @@ class TestMiddlewareChainOrder:
"them. Ensure StructuredContentStripper is outermost "
"(first added) in build_middleware_list()."
)
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_real_middleware_chain_error_result_has_mcp_call_id(
self, mock_get_user_id, mock_event_logger
):
"""When a tool raises, the error ToolResult from
StructuredContentStripper still carries mcp_call_id in meta."""
from functools import partial
from fastmcp.tools.tool import ToolResult
from superset.mcp_service.server import build_middleware_list
middleware_list = build_middleware_list()
async def failing_tool(context: Any) -> Any:
raise ValueError("chart not found")
chain = failing_tool
for mw in reversed(middleware_list):
chain = partial(mw, call_next=chain)
ctx = _make_context(name="get_chart_info")
result = await chain(ctx)
assert isinstance(result, ToolResult)
assert result.content[0].text.startswith("Error:")
assert result.meta is not None
assert "mcp_call_id" in result.meta
assert len(result.meta["mcp_call_id"]) == 12