mirror of
https://github.com/apache/superset.git
synced 2026-06-02 14:19:21 +00:00
fix(MCP): fix MCP logs (#39159)
This commit is contained in:
@@ -130,6 +130,18 @@ class LoggingMiddleware(Middleware):
|
||||
in on_message().
|
||||
"""
|
||||
|
||||
def _is_error_response(self, result: ToolResult) -> bool:
    """Check whether a tool result carries a serialized error schema.

    MCP tools return error schemas (ChartError, DashboardError, etc.)
    instead of raising exceptions. These serialize to JSON containing
    an "error_type" field, so the text of the first content item is
    searched for that quoted key.

    This check runs on the logging path of a tool call and therefore
    must never raise: a result whose shape does not match (missing,
    ``None`` or empty content, a first item without ``text``, or a
    non-string ``text``) is reported as "not an error".

    Args:
        result: The value returned by the downstream handler.

    Returns:
        True if the first content item's text contains ``"error_type"``,
        False otherwise.
    """
    try:
        text = result.content[0].text
    except (AttributeError, LookupError, TypeError):
        # AttributeError: no ``content`` / no ``text`` attribute.
        # LookupError:    empty content list (IndexError) or mapping miss.
        # TypeError:      ``content`` is None or otherwise not subscriptable.
        return False
    # A non-string payload (e.g. ``text`` is None) cannot hold the schema;
    # guarding here avoids a TypeError from the ``in`` operator.
    return isinstance(text, str) and '"error_type"' in text
|
||||
|
||||
def _extract_context_info(
|
||||
self, context: MiddlewareContext
|
||||
) -> tuple[
|
||||
@@ -171,8 +183,11 @@ class LoggingMiddleware(Middleware):
|
||||
success = False
|
||||
try:
|
||||
result = await call_next(context)
|
||||
success = True
|
||||
success = not self._is_error_response(result)
|
||||
return result
|
||||
except Exception:
|
||||
success = False
|
||||
raise
|
||||
finally:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
if has_app_context():
|
||||
|
||||
@@ -28,6 +28,7 @@ from collections.abc import Sequence
|
||||
from typing import Annotated, Any
|
||||
|
||||
import uvicorn
|
||||
from fastmcp.server.middleware import Middleware
|
||||
|
||||
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
|
||||
from superset.mcp_service.mcp_config import (
|
||||
@@ -492,6 +493,24 @@ def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
return auth_provider
|
||||
|
||||
|
||||
def build_middleware_list() -> list[Middleware]:
    """Assemble the core MCP middleware stack, outermost layer first.

    FastMCP wraps handlers so that the first-added middleware ends up
    outermost, which means the returned list reads outermost → innermost:

    1. StructuredContentStripperMiddleware — safety net, converts
       exceptions to safe ToolResult text for transports that can't
       encode errors
    2. LoggingMiddleware — logs tool calls with success/failure status
    3. GlobalErrorHandlerMiddleware — catches tool exceptions, raises
       ToolError

    Returns:
        A new list holding one fresh instance of each core middleware.
    """
    outermost_stripper = StructuredContentStripperMiddleware()
    tool_call_logger = LoggingMiddleware()
    innermost_error_handler = GlobalErrorHandlerMiddleware()
    return [outermost_stripper, tool_call_logger, innermost_error_handler]
|
||||
|
||||
|
||||
def run_server(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 5008,
|
||||
@@ -539,30 +558,15 @@ def run_server(
|
||||
flask_app = get_flask_app()
|
||||
auth_provider = _create_auth_provider(flask_app)
|
||||
|
||||
# Build middleware list
|
||||
# FastMCP wraps handlers so that the LAST-added middleware is
|
||||
# outermost. Order here is innermost → outermost.
|
||||
middleware_list = []
|
||||
middleware_list = build_middleware_list()
|
||||
|
||||
# Add caching middleware (innermost – runs closest to the tool)
|
||||
if caching_middleware := create_response_caching_middleware():
|
||||
middleware_list.append(caching_middleware)
|
||||
|
||||
# Add response size guard (protects LLM clients from huge responses)
|
||||
if size_guard_middleware := create_response_size_guard_middleware():
|
||||
# Add optional middleware (innermost, closest to tool)
|
||||
size_guard_middleware = create_response_size_guard_middleware()
|
||||
if size_guard_middleware:
|
||||
middleware_list.append(size_guard_middleware)
|
||||
|
||||
# Add logging middleware (logs all tool calls with duration tracking)
|
||||
middleware_list.append(LoggingMiddleware())
|
||||
|
||||
# Add global error handler (catches all exceptions, raises ToolError)
|
||||
middleware_list.append(GlobalErrorHandlerMiddleware())
|
||||
|
||||
# Strip outputSchema from tool definitions and structuredContent from
|
||||
# tool responses to prevent encoding errors on Claude.ai's MCP bridge.
|
||||
# MUST be outermost so it catches ToolError from GlobalErrorHandler
|
||||
# and converts to plain text before the MCP SDK tries to encode it.
|
||||
middleware_list.append(StructuredContentStripperMiddleware())
|
||||
if caching_middleware := create_response_caching_middleware():
|
||||
middleware_list.append(caching_middleware)
|
||||
|
||||
mcp_instance = init_fastmcp_server(
|
||||
auth=auth_provider,
|
||||
|
||||
@@ -384,6 +384,13 @@ class DBEventLogger(AbstractEventLogger):
|
||||
from superset.models.core import Log
|
||||
|
||||
records = kwargs.get("records", [])
|
||||
curated_payload = kwargs.get("curated_payload")
|
||||
|
||||
# If no records but curated_payload exists, use it as a single record
|
||||
# This enables MCP middleware logging which passes curated_payload
|
||||
if not records and curated_payload:
|
||||
records = [curated_payload]
|
||||
|
||||
logs = []
|
||||
for record in records:
|
||||
json_string: str | None
|
||||
|
||||
@@ -254,3 +254,64 @@ class TestEventLogger(unittest.TestCase):
|
||||
}
|
||||
]
|
||||
assert payload["duration_ms"] >= 100
|
||||
|
||||
@patch("superset.db")
def test_curated_payload_used_when_records_empty(self, mock_db):
    """A lone curated_payload is persisted as one Log row (MCP pattern).

    MCP middleware hands DBEventLogger.log() a curated_payload and no
    records. This test checks that, in that situation, exactly one Log
    entry is built and that its JSON carries the curated_payload data.
    """
    event_logger = DBEventLogger()
    log_kwargs = {
        "user_id": 1,
        "action": "mcp_tool_call",
        "dashboard_id": None,
        "duration_ms": 100,
        "slice_id": None,
        "referrer": None,
        "curated_payload": {"tool": "list_charts", "success": True},
    }

    with app.test_request_context():
        event_logger.log(**log_kwargs)

    # Exactly one bulk save, carrying exactly one Log object.
    bulk_save = mock_db.session.bulk_save_objects
    bulk_save.assert_called_once()
    saved_logs = bulk_save.call_args[0][0]
    assert len(saved_logs) == 1

    entry = saved_logs[0]
    assert entry.action == "mcp_tool_call"
    assert entry.duration_ms == 100

    # The stored JSON must round-trip back to the curated_payload values.
    from superset.utils import json as json_utils

    decoded = json_utils.loads(entry.json)
    assert decoded["tool"] == "list_charts"
    assert decoded["success"] is True
|
||||
|
||||
@patch("superset.db")
def test_records_takes_precedence_over_curated_payload(self, mock_db):
    """When both are supplied, records win and curated_payload is ignored."""
    event_logger = DBEventLogger()
    log_kwargs = {
        "user_id": 1,
        "action": "test_action",
        "dashboard_id": None,
        "duration_ms": 50,
        "slice_id": None,
        "referrer": None,
        "records": [{"from_records": True}],
        "curated_payload": {"from_curated": True},
    }

    with app.test_request_context():
        event_logger.log(**log_kwargs)

    # One bulk save with one Log object, built from ``records`` only.
    bulk_save = mock_db.session.bulk_save_objects
    bulk_save.assert_called_once()
    saved_logs = bulk_save.call_args[0][0]
    assert len(saved_logs) == 1

    from superset.utils import json as json_utils

    decoded = json_utils.loads(saved_logs[0].json)
    assert decoded.get("from_records") is True
    assert "from_curated" not in decoded
|
||||
|
||||
@@ -102,6 +102,34 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
assert call_kwargs["curated_payload"]["success"] is False
|
||||
assert call_kwargs["duration_ms"] >= 0
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@pytest.mark.asyncio
async def test_on_call_tool_logs_failure_on_tool_error(
    self, mock_get_user_id, mock_event_logger
):
    """on_call_tool records success=False when a ToolError propagates.

    Mirrors the real middleware chain: GlobalErrorHandler turns tool
    exceptions into ToolError, and LoggingMiddleware, which sits between
    GlobalErrorHandler and StructuredContentStripper, sees that ToolError
    directly. The error must be re-raised and logged as a failure.
    """
    from fastmcp.exceptions import ToolError

    failing_next = AsyncMock(side_effect=ToolError("Chart 999999 not found"))
    tool_context = _make_context(name="get_chart_info")
    logging_middleware = LoggingMiddleware()

    with pytest.raises(ToolError, match="Chart 999999 not found"):
        await logging_middleware.on_call_tool(tool_context, failing_next)

    mock_event_logger.log.assert_called_once()
    logged = mock_event_logger.log.call_args[1]
    curated = logged["curated_payload"]
    assert curated["success"] is False
    assert curated["tool"] == "get_chart_info"
    assert logged["duration_ms"] >= 0
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@pytest.mark.asyncio
|
||||
@@ -205,3 +233,132 @@ class TestExtractContextInfo:
|
||||
_, _, _, slice_id, _, _ = middleware._extract_context_info(ctx)
|
||||
|
||||
assert slice_id == 66
|
||||
|
||||
|
||||
class TestIsErrorResponse:
    """Tests for LoggingMiddleware._is_error_response()."""

    @staticmethod
    def _text_result(text):
        """Wrap ``text`` in a ToolResult holding a single TextContent item."""
        from fastmcp.tools.tool import ToolResult
        from mcp import types as mt

        return ToolResult(content=[mt.TextContent(type="text", text=text)])

    def test_detects_error_schema_response(self):
        """A serialized error schema (ChartError, DashboardError, etc.)
        is recognised through its "error_type" field."""
        serialized_error = (
            '{"error": "Chart 999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        logging_middleware = LoggingMiddleware()
        tool_result = self._text_result(serialized_error)

        assert logging_middleware._is_error_response(tool_result) is True

    def test_success_response_not_detected_as_error(self):
        """An ordinary text ToolResult is not classified as an error."""
        logging_middleware = LoggingMiddleware()
        tool_result = self._text_result("Successfully retrieved data")

        assert logging_middleware._is_error_response(tool_result) is False

    def test_empty_content_not_detected_as_error(self):
        """A ToolResult whose content list is empty is not an error."""
        from fastmcp.tools.tool import ToolResult

        logging_middleware = LoggingMiddleware()
        empty_result = ToolResult(content=[])

        assert logging_middleware._is_error_response(empty_result) is False

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_for_error_schema(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool logs success=False when the tool returns an
        error schema (e.g. ChartError) instead of raising."""
        serialized_error = (
            '{"error": "Chart 999999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        error_result = self._text_result(serialized_error)
        returning_next = AsyncMock(return_value=error_result)
        tool_context = _make_context(name="get_chart_info")
        logging_middleware = LoggingMiddleware()

        outcome = await logging_middleware.on_call_tool(
            tool_context, returning_next
        )

        # The error schema is passed through untouched...
        assert outcome == error_result
        # ...but the call is recorded as a failure.
        mock_event_logger.log.assert_called_once()
        curated = mock_event_logger.log.call_args[1]["curated_payload"]
        assert curated["success"] is False
        assert curated["tool"] == "get_chart_info"
|
||||
|
||||
|
||||
class TestMiddlewareChainOrder:
    """Check that the middleware order from server.py logs failures correctly.

    With the wrong order (StructuredContentStripper innermost) the
    stripper swallows exceptions before LoggingMiddleware can see them,
    and failures get logged as success=True.
    """

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_real_middleware_chain_logs_exception_as_failure(
        self, mock_get_user_id, mock_event_logger
    ):
        """A tool exception is logged as success=False when routed through
        the real middleware chain from build_middleware_list()."""
        from functools import partial

        from fastmcp.tools.tool import ToolResult

        from superset.mcp_service.server import build_middleware_list

        async def failing_tool(context: Any) -> Any:
            raise ValueError("chart not found")

        # Compose the chain the way FastMCP does: start at the tool and
        # wrap it from the innermost middleware outwards.
        handler = failing_tool
        for layer in reversed(build_middleware_list()):
            handler = partial(layer, call_next=handler)

        outcome = await handler(_make_context(name="get_chart_info"))

        # StructuredContentStripper (outermost) must catch the re-raised
        # exception and turn it into a safe ToolResult with "Error:" text;
        # otherwise the exception would leak to the MCP SDK.
        assert isinstance(outcome, ToolResult)
        assert outcome.content[0].text.startswith("Error:")

        # LoggingMiddleware must record success=False. With the wrong
        # order (StructuredContentStripper innermost) it would record
        # success=True, because the exception is swallowed before
        # LoggingMiddleware sees it.
        tool_call_logs = []
        for logged_call in mock_event_logger.log.call_args_list:
            if logged_call[1].get("action") == "mcp_tool_call":
                tool_call_logs.append(logged_call)

        assert len(tool_call_logs) == 1
        curated = tool_call_logs[0][1]["curated_payload"]
        assert curated["success"] is False, (
            "Middleware order is wrong: StructuredContentStripper is "
            "swallowing exceptions before LoggingMiddleware can detect "
            "them. Ensure StructuredContentStripper is outermost "
            "(first added) in build_middleware_list()."
        )
|
||||
|
||||
Reference in New Issue
Block a user