mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
cover edge case
This commit is contained in:
@@ -241,6 +241,7 @@ class LoggingMiddleware(Middleware):
|
||||
tool_name = getattr(context.message, "name", None)
|
||||
|
||||
mcp_call_id = uuid.uuid4().hex[:12]
|
||||
context.mcp_call_id = mcp_call_id
|
||||
start_time = time.time()
|
||||
success = False
|
||||
try:
|
||||
@@ -396,8 +397,10 @@ class StructuredContentStripperMiddleware(Middleware):
|
||||
# unhandled exception — including ToolError from
|
||||
# GlobalErrorHandlerMiddleware, ValueError, TypeError, etc. —
|
||||
# will cause encoding failures on the wire.
|
||||
mcp_call_id = getattr(context, "mcp_call_id", None)
|
||||
return ToolResult(
|
||||
content=[mt.TextContent(type="text", text=f"Error: {e}")],
|
||||
meta={"mcp_call_id": mcp_call_id} if mcp_call_id else None,
|
||||
)
|
||||
if isinstance(result, ToolResult) and result.structured_content is not None:
|
||||
result = ToolResult(content=result.content, meta=result.meta)
|
||||
|
||||
@@ -426,3 +426,35 @@ class TestMiddlewareChainOrder:
|
||||
"them. Ensure StructuredContentStripper is outermost "
|
||||
"(first added) in build_middleware_list()."
|
||||
)
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
async def test_real_middleware_chain_error_result_has_mcp_call_id(
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""When a tool raises, the error ToolResult from
|
||||
StructuredContentStripper still carries mcp_call_id in meta."""
|
||||
from functools import partial
|
||||
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.mcp_service.server import build_middleware_list
|
||||
|
||||
middleware_list = build_middleware_list()
|
||||
|
||||
async def failing_tool(context: Any) -> Any:
|
||||
raise ValueError("chart not found")
|
||||
|
||||
chain = failing_tool
|
||||
for mw in reversed(middleware_list):
|
||||
chain = partial(mw, call_next=chain)
|
||||
|
||||
ctx = _make_context(name="get_chart_info")
|
||||
result = await chain(ctx)
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert result.content[0].text.startswith("Error:")
|
||||
assert result.meta is not None
|
||||
assert "mcp_call_id" in result.meta
|
||||
assert len(result.meta["mcp_call_id"]) == 12
|
||||
|
||||
Reference in New Issue
Block a user