fix(MCP): fix MCP logs (#39159)

This commit is contained in:
Alexandru Soare
2026-04-15 15:57:04 +03:00
committed by GitHub
parent 86575e129b
commit ffcc6e8b63
5 changed files with 266 additions and 22 deletions

View File

@@ -130,6 +130,18 @@ class LoggingMiddleware(Middleware):
in on_message().
"""
def _is_error_response(self, result: ToolResult) -> bool:
"""Check if a tool result contains an error schema response.
MCP tools return error schemas (ChartError, DashboardError, etc.)
instead of raising exceptions. These serialize to JSON containing
an "error_type" field.
"""
try:
return '"error_type"' in result.content[0].text
except (AttributeError, IndexError):
return False
def _extract_context_info(
self, context: MiddlewareContext
) -> tuple[
@@ -171,8 +183,11 @@ class LoggingMiddleware(Middleware):
success = False
try:
result = await call_next(context)
success = True
success = not self._is_error_response(result)
return result
except Exception:
success = False
raise
finally:
duration_ms = int((time.time() - start_time) * 1000)
if has_app_context():

View File

@@ -28,6 +28,7 @@ from collections.abc import Sequence
from typing import Annotated, Any
import uvicorn
from fastmcp.server.middleware import Middleware
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
from superset.mcp_service.mcp_config import (
@@ -492,6 +493,24 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
return auth_provider
def build_middleware_list() -> list[Middleware]:
"""Build the core MCP middleware list in the correct order.
FastMCP wraps handlers so that the FIRST-added middleware is
outermost. Order here is outermost → innermost:
1. StructuredContentStripper — safety net, converts exceptions
to safe ToolResult text for transports that can't encode errors
2. LoggingMiddleware — logs tool calls with success/failure status
3. GlobalErrorHandler — catches tool exceptions, raises ToolError
"""
return [
StructuredContentStripperMiddleware(),
LoggingMiddleware(),
GlobalErrorHandlerMiddleware(),
]
def run_server(
host: str = "127.0.0.1",
port: int = 5008,
@@ -539,30 +558,15 @@ def run_server(
flask_app = get_flask_app()
auth_provider = _create_auth_provider(flask_app)
# Build middleware list
# FastMCP wraps handlers so that the LAST-added middleware is
# outermost. Order here is innermost → outermost.
middleware_list = []
middleware_list = build_middleware_list()
# Add caching middleware (innermost runs closest to the tool)
if caching_middleware := create_response_caching_middleware():
middleware_list.append(caching_middleware)
# Add response size guard (protects LLM clients from huge responses)
if size_guard_middleware := create_response_size_guard_middleware():
# Add optional middleware (innermost, closest to tool)
size_guard_middleware = create_response_size_guard_middleware()
if size_guard_middleware:
middleware_list.append(size_guard_middleware)
# Add logging middleware (logs all tool calls with duration tracking)
middleware_list.append(LoggingMiddleware())
# Add global error handler (catches all exceptions, raises ToolError)
middleware_list.append(GlobalErrorHandlerMiddleware())
# Strip outputSchema from tool definitions and structuredContent from
# tool responses to prevent encoding errors on Claude.ai's MCP bridge.
# MUST be outermost so it catches ToolError from GlobalErrorHandler
# and converts to plain text before the MCP SDK tries to encode it.
middleware_list.append(StructuredContentStripperMiddleware())
if caching_middleware := create_response_caching_middleware():
middleware_list.append(caching_middleware)
mcp_instance = init_fastmcp_server(
auth=auth_provider,

View File

@@ -384,6 +384,13 @@ class DBEventLogger(AbstractEventLogger):
from superset.models.core import Log
records = kwargs.get("records", [])
curated_payload = kwargs.get("curated_payload")
# If no records but curated_payload exists, use it as a single record
# This enables MCP middleware logging which passes curated_payload
if not records and curated_payload:
records = [curated_payload]
logs = []
for record in records:
json_string: str | None

View File

@@ -254,3 +254,64 @@ class TestEventLogger(unittest.TestCase):
}
]
assert payload["duration_ms"] >= 100
@patch("superset.db")
def test_curated_payload_used_when_records_empty(self, mock_db):
"""Test that curated_payload is used when records is empty (MCP pattern).
MCP middleware passes curated_payload instead of records. This test verifies
that DBEventLogger.log() creates a Log entry from curated_payload when
records is empty.
"""
logger = DBEventLogger()
with app.test_request_context():
logger.log(
user_id=1,
action="mcp_tool_call",
dashboard_id=None,
duration_ms=100,
slice_id=None,
referrer=None,
curated_payload={"tool": "list_charts", "success": True},
)
# Verify bulk_save_objects was called with one Log object
mock_db.session.bulk_save_objects.assert_called_once()
logs = mock_db.session.bulk_save_objects.call_args[0][0]
assert len(logs) == 1
assert logs[0].action == "mcp_tool_call"
assert logs[0].duration_ms == 100
# Verify JSON contains the curated_payload data
from superset.utils import json as json_utils
payload = json_utils.loads(logs[0].json)
assert payload["tool"] == "list_charts"
assert payload["success"] is True
@patch("superset.db")
def test_records_takes_precedence_over_curated_payload(self, mock_db):
"""Test that records takes precedence over curated_payload."""
logger = DBEventLogger()
with app.test_request_context():
logger.log(
user_id=1,
action="test_action",
dashboard_id=None,
duration_ms=50,
slice_id=None,
referrer=None,
records=[{"from_records": True}],
curated_payload={"from_curated": True},
)
# Verify only records data is used, not curated_payload
mock_db.session.bulk_save_objects.assert_called_once()
logs = mock_db.session.bulk_save_objects.call_args[0][0]
assert len(logs) == 1
from superset.utils import json as json_utils
payload = json_utils.loads(logs[0].json)
assert payload.get("from_records") is True
assert "from_curated" not in payload

View File

@@ -102,6 +102,34 @@ class TestLoggingMiddlewareOnCallTool:
assert call_kwargs["curated_payload"]["success"] is False
assert call_kwargs["duration_ms"] >= 0
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_on_tool_error(
self, mock_get_user_id, mock_event_logger
):
"""on_call_tool records success=False when GlobalErrorHandler raises ToolError.
This simulates the real middleware chain: GlobalErrorHandler catches
tool exceptions and re-raises them as ToolError. Since LoggingMiddleware
sits between GlobalErrorHandler and StructuredContentStripper, it
catches the ToolError directly.
"""
from fastmcp.exceptions import ToolError
middleware = LoggingMiddleware()
ctx = _make_context(name="get_chart_info")
call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found"))
with pytest.raises(ToolError, match="Chart 999999 not found"):
await middleware.on_call_tool(ctx, call_next)
mock_event_logger.log.assert_called_once()
call_kwargs = mock_event_logger.log.call_args[1]
assert call_kwargs["curated_payload"]["success"] is False
assert call_kwargs["curated_payload"]["tool"] == "get_chart_info"
assert call_kwargs["duration_ms"] >= 0
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
@@ -205,3 +233,132 @@ class TestExtractContextInfo:
_, _, _, slice_id, _, _ = middleware._extract_context_info(ctx)
assert slice_id == 66
class TestIsErrorResponse:
"""Tests for LoggingMiddleware._is_error_response()."""
def test_detects_error_schema_response(self):
"""Detects ToolResult containing a serialized error schema
(ChartError, DashboardError, etc.) via "error_type" field."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
error_json = (
'{"error": "Chart 999 not found",'
' "error_type": "not_found",'
' "timestamp": "2026-04-09T00:00:00Z"}'
)
result = ToolResult(content=[mt.TextContent(type="text", text=error_json)])
assert middleware._is_error_response(result) is True
def test_success_response_not_detected_as_error(self):
"""Normal ToolResult is not detected as error."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
result = ToolResult(
content=[mt.TextContent(type="text", text="Successfully retrieved data")]
)
assert middleware._is_error_response(result) is False
def test_empty_content_not_detected_as_error(self):
"""ToolResult with empty content is not detected as error."""
from fastmcp.tools.tool import ToolResult
middleware = LoggingMiddleware()
assert middleware._is_error_response(ToolResult(content=[])) is False
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_for_error_schema(
self, mock_get_user_id, mock_event_logger
):
"""on_call_tool logs success=False when tool returns an
error schema (e.g. ChartError)."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
ctx = _make_context(name="get_chart_info")
error_json = (
'{"error": "Chart 999999 not found",'
' "error_type": "not_found",'
' "timestamp": "2026-04-09T00:00:00Z"}'
)
error_result = ToolResult(
content=[mt.TextContent(type="text", text=error_json)]
)
call_next = AsyncMock(return_value=error_result)
result = await middleware.on_call_tool(ctx, call_next)
assert result == error_result
mock_event_logger.log.assert_called_once()
call_kwargs = mock_event_logger.log.call_args[1]
assert call_kwargs["curated_payload"]["success"] is False
assert call_kwargs["curated_payload"]["tool"] == "get_chart_info"
class TestMiddlewareChainOrder:
"""Test that the middleware order from server.py logs failures correctly.
If the order is wrong (StructuredContentStripper innermost),
it swallows exceptions before LoggingMiddleware can see them,
causing success=True for failures.
"""
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_real_middleware_chain_logs_exception_as_failure(
self, mock_get_user_id, mock_event_logger
):
"""Tool exception is logged as success=False through the
real middleware chain from build_middleware_list()."""
from functools import partial
from fastmcp.tools.tool import ToolResult
from superset.mcp_service.server import build_middleware_list
middleware_list = build_middleware_list()
async def failing_tool(context: Any) -> Any:
raise ValueError("chart not found")
# Build chain same way FastMCP does
chain = failing_tool
for mw in reversed(middleware_list):
chain = partial(mw, call_next=chain)
ctx = _make_context(name="get_chart_info")
result = await chain(ctx)
# StructuredContentStripper (outermost) must catch the re-raised
# exception and convert it to a safe ToolResult with "Error:" text.
# If it's not outermost, the exception would leak to the MCP SDK.
assert isinstance(result, ToolResult)
assert result.content[0].text.startswith("Error:")
# LoggingMiddleware must log
# success=False. If the middleware order is wrong
# (StructuredContentStripper innermost), this would be
# success=True because the exception gets swallowed
# before LoggingMiddleware sees it.
log_calls = [
c
for c in mock_event_logger.log.call_args_list
if c[1].get("action") == "mcp_tool_call"
]
assert len(log_calls) == 1
assert log_calls[0][1]["curated_payload"]["success"] is False, (
"Middleware order is wrong: StructuredContentStripper is "
"swallowing exceptions before LoggingMiddleware can detect "
"them. Ensure StructuredContentStripper is outermost "
"(first added) in build_middleware_list()."
)