From 9e313ba82f18d47795a2ee43e2472f6205537f10 Mon Sep 17 00:00:00 2001 From: Alexandru Soare <37236580+alexandrusoare@users.noreply.github.com> Date: Wed, 15 Apr 2026 15:57:04 +0300 Subject: [PATCH] fix(MCP): fix MCP logs (#39159) (cherry picked from commit ffcc6e8b635fb66437a7ce855d98a2fe5ec3ce80) --- superset/mcp_service/middleware.py | 17 +- superset/mcp_service/server.py | 46 ++--- superset/utils/log.py | 7 + tests/integration_tests/event_logger_tests.py | 61 +++++++ .../mcp_service/test_middleware_logging.py | 157 ++++++++++++++++++ 5 files changed, 266 insertions(+), 22 deletions(-) diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 2254e8cc597..797bb70076c 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -130,6 +130,18 @@ class LoggingMiddleware(Middleware): in on_message(). """ + def _is_error_response(self, result: ToolResult) -> bool: + """Check if a tool result contains an error schema response. + + MCP tools return error schemas (ChartError, DashboardError, etc.) + instead of raising exceptions. These serialize to JSON containing + an "error_type" field. + """ + try: + return '"error_type"' in result.content[0].text + except (AttributeError, IndexError): + return False + def _extract_context_info( self, context: MiddlewareContext ) -> tuple[ @@ -171,8 +183,11 @@ class LoggingMiddleware(Middleware): success = False try: result = await call_next(context) - success = True + success = not self._is_error_response(result) return result + except Exception: + success = False + raise finally: duration_ms = int((time.time() - start_time) * 1000) if has_app_context(): diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 8f0d6feccaf..074914de4e7 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -28,6 +28,7 @@ from collections.abc import Sequence from typing import Annotated, Any import uvicorn +from fastmcp.server.middleware import Middleware from superset.mcp_service.app import create_mcp_app, init_fastmcp_server from superset.mcp_service.mcp_config import ( @@ -492,6 +493,24 @@ def _create_auth_provider(flask_app: Any) -> Any | None: return auth_provider +def build_middleware_list() -> list[Middleware]: + """Build the core MCP middleware list in the correct order. + + FastMCP wraps handlers so that the FIRST-added middleware is + outermost. Order here is outermost → innermost: + + 1. StructuredContentStripper — safety net, converts exceptions + to safe ToolResult text for transports that can't encode errors + 2. LoggingMiddleware — logs tool calls with success/failure status + 3. GlobalErrorHandler — catches tool exceptions, raises ToolError + """ + return [ + StructuredContentStripperMiddleware(), + LoggingMiddleware(), + GlobalErrorHandlerMiddleware(), + ] + + def run_server( host: str = "127.0.0.1", port: int = 5008, @@ -539,30 +558,15 @@ def run_server( flask_app = get_flask_app() auth_provider = _create_auth_provider(flask_app) - # Build middleware list - # FastMCP wraps handlers so that the LAST-added middleware is - # outermost. Order here is innermost → outermost. - middleware_list = [] + middleware_list = build_middleware_list() - # Add caching middleware (innermost – runs closest to the tool) - if caching_middleware := create_response_caching_middleware(): - middleware_list.append(caching_middleware) - - # Add response size guard (protects LLM clients from huge responses) - if size_guard_middleware := create_response_size_guard_middleware(): + # Add optional middleware (innermost, closest to tool) + size_guard_middleware = create_response_size_guard_middleware() + if size_guard_middleware: middleware_list.append(size_guard_middleware) - # Add logging middleware (logs all tool calls with duration tracking) - middleware_list.append(LoggingMiddleware()) - - # Add global error handler (catches all exceptions, raises ToolError) - middleware_list.append(GlobalErrorHandlerMiddleware()) - - # Strip outputSchema from tool definitions and structuredContent from - # tool responses to prevent encoding errors on Claude.ai's MCP bridge. - # MUST be outermost so it catches ToolError from GlobalErrorHandler - # and converts to plain text before the MCP SDK tries to encode it. - middleware_list.append(StructuredContentStripperMiddleware()) + if caching_middleware := create_response_caching_middleware(): + middleware_list.append(caching_middleware) mcp_instance = init_fastmcp_server( auth=auth_provider, diff --git a/superset/utils/log.py b/superset/utils/log.py index 6685010b4d2..e840c61f037 100644 --- a/superset/utils/log.py +++ b/superset/utils/log.py @@ -384,6 +384,13 @@ class DBEventLogger(AbstractEventLogger): from superset.models.core import Log records = kwargs.get("records", []) + curated_payload = kwargs.get("curated_payload") + + # If no records but curated_payload exists, use it as a single record + # This enables MCP middleware logging which passes curated_payload + if not records and curated_payload: + records = [curated_payload] + logs = [] for record in records: json_string: str | None diff --git a/tests/integration_tests/event_logger_tests.py b/tests/integration_tests/event_logger_tests.py index f87b5a782e3..6bd8b96a945 100644 --- a/tests/integration_tests/event_logger_tests.py +++ b/tests/integration_tests/event_logger_tests.py @@ -254,3 +254,64 @@ class TestEventLogger(unittest.TestCase): } ] assert payload["duration_ms"] >= 100 + + @patch("superset.db") + def test_curated_payload_used_when_records_empty(self, mock_db): + """Test that curated_payload is used when records is empty (MCP pattern). + + MCP middleware passes curated_payload instead of records. This test verifies + that DBEventLogger.log() creates a Log entry from curated_payload when + records is empty. + """ + logger = DBEventLogger() + + with app.test_request_context(): + logger.log( + user_id=1, + action="mcp_tool_call", + dashboard_id=None, + duration_ms=100, + slice_id=None, + referrer=None, + curated_payload={"tool": "list_charts", "success": True}, + ) + + # Verify bulk_save_objects was called with one Log object + mock_db.session.bulk_save_objects.assert_called_once() + logs = mock_db.session.bulk_save_objects.call_args[0][0] + assert len(logs) == 1 + assert logs[0].action == "mcp_tool_call" + assert logs[0].duration_ms == 100 + # Verify JSON contains the curated_payload data + from superset.utils import json as json_utils + + payload = json_utils.loads(logs[0].json) + assert payload["tool"] == "list_charts" + assert payload["success"] is True + + @patch("superset.db") + def test_records_takes_precedence_over_curated_payload(self, mock_db): + """Test that records takes precedence over curated_payload.""" + logger = DBEventLogger() + + with app.test_request_context(): + logger.log( + user_id=1, + action="test_action", + dashboard_id=None, + duration_ms=50, + slice_id=None, + referrer=None, + records=[{"from_records": True}], + curated_payload={"from_curated": True}, + ) + + # Verify only records data is used, not curated_payload + mock_db.session.bulk_save_objects.assert_called_once() + logs = mock_db.session.bulk_save_objects.call_args[0][0] + assert len(logs) == 1 + from superset.utils import json as json_utils + + payload = json_utils.loads(logs[0].json) + assert payload.get("from_records") is True + assert "from_curated" not in payload diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py index 1f48d531ccc..50d23449707 100644 --- a/tests/unit_tests/mcp_service/test_middleware_logging.py +++ b/tests/unit_tests/mcp_service/test_middleware_logging.py @@ -102,6 +102,34 @@ class TestLoggingMiddlewareOnCallTool: assert call_kwargs["curated_payload"]["success"] is False assert call_kwargs["duration_ms"] >= 0 + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_logs_failure_on_tool_error( + self, mock_get_user_id, mock_event_logger + ): + """on_call_tool records success=False when GlobalErrorHandler raises ToolError. + + This simulates the real middleware chain: GlobalErrorHandler catches + tool exceptions and re-raises them as ToolError. Since LoggingMiddleware + sits between GlobalErrorHandler and StructuredContentStripper, it + catches the ToolError directly. + """ + from fastmcp.exceptions import ToolError + + middleware = LoggingMiddleware() + ctx = _make_context(name="get_chart_info") + call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found")) + + with pytest.raises(ToolError, match="Chart 999999 not found"): + await middleware.on_call_tool(ctx, call_next) + + mock_event_logger.log.assert_called_once() + call_kwargs = mock_event_logger.log.call_args[1] + assert call_kwargs["curated_payload"]["success"] is False + assert call_kwargs["curated_payload"]["tool"] == "get_chart_info" + assert call_kwargs["duration_ms"] >= 0 + @patch("superset.mcp_service.middleware.event_logger") @patch("superset.mcp_service.middleware.get_user_id", return_value=42) @pytest.mark.asyncio @@ -205,3 +233,132 @@ class TestExtractContextInfo: _, _, _, slice_id, _, _ = middleware._extract_context_info(ctx) assert slice_id == 66 + + +class TestIsErrorResponse: + """Tests for LoggingMiddleware._is_error_response().""" + + def test_detects_error_schema_response(self): + """Detects ToolResult containing a serialized error schema + (ChartError, DashboardError, etc.) via "error_type" field.""" + from fastmcp.tools.tool import ToolResult + from mcp import types as mt + + middleware = LoggingMiddleware() + error_json = ( + '{"error": "Chart 999 not found",' + ' "error_type": "not_found",' + ' "timestamp": "2026-04-09T00:00:00Z"}' + ) + result = ToolResult(content=[mt.TextContent(type="text", text=error_json)]) + assert middleware._is_error_response(result) is True + + def test_success_response_not_detected_as_error(self): + """Normal ToolResult is not detected as error.""" + from fastmcp.tools.tool import ToolResult + from mcp import types as mt + + middleware = LoggingMiddleware() + result = ToolResult( + content=[mt.TextContent(type="text", text="Successfully retrieved data")] + ) + assert middleware._is_error_response(result) is False + + def test_empty_content_not_detected_as_error(self): + """ToolResult with empty content is not detected as error.""" + from fastmcp.tools.tool import ToolResult + + middleware = LoggingMiddleware() + assert middleware._is_error_response(ToolResult(content=[])) is False + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_logs_failure_for_error_schema( + self, mock_get_user_id, mock_event_logger + ): + """on_call_tool logs success=False when tool returns an + error schema (e.g. ChartError).""" + from fastmcp.tools.tool import ToolResult + from mcp import types as mt + + middleware = LoggingMiddleware() + ctx = _make_context(name="get_chart_info") + + error_json = ( + '{"error": "Chart 999999 not found",' + ' "error_type": "not_found",' + ' "timestamp": "2026-04-09T00:00:00Z"}' + ) + error_result = ToolResult( + content=[mt.TextContent(type="text", text=error_json)] + ) + call_next = AsyncMock(return_value=error_result) + + result = await middleware.on_call_tool(ctx, call_next) + + assert result == error_result + mock_event_logger.log.assert_called_once() + call_kwargs = mock_event_logger.log.call_args[1] + assert call_kwargs["curated_payload"]["success"] is False + assert call_kwargs["curated_payload"]["tool"] == "get_chart_info" + + +class TestMiddlewareChainOrder: + """Test that the middleware order from server.py logs failures correctly. + + If the order is wrong (StructuredContentStripper innermost), + it swallows exceptions before LoggingMiddleware can see them, + causing success=True for failures. + """ + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_real_middleware_chain_logs_exception_as_failure( + self, mock_get_user_id, mock_event_logger + ): + """Tool exception is logged as success=False through the + real middleware chain from build_middleware_list().""" + from functools import partial + + from fastmcp.tools.tool import ToolResult + + from superset.mcp_service.server import build_middleware_list + + middleware_list = build_middleware_list() + + async def failing_tool(context: Any) -> Any: + raise ValueError("chart not found") + + # Build chain same way FastMCP does + chain = failing_tool + for mw in reversed(middleware_list): + chain = partial(mw, call_next=chain) + + ctx = _make_context(name="get_chart_info") + result = await chain(ctx) + + # StructuredContentStripper (outermost) must catch the re-raised + # exception and convert it to a safe ToolResult with "Error:" text. + # If it's not outermost, the exception would leak to the MCP SDK. + assert isinstance(result, ToolResult) + assert result.content[0].text.startswith("Error:") + + # LoggingMiddleware must log + # success=False. If the middleware order is wrong + # (StructuredContentStripper innermost), this would be + # success=True because the exception gets swallowed + # before LoggingMiddleware sees it. + log_calls = [ + c + for c in mock_event_logger.log.call_args_list + if c[1].get("action") == "mcp_tool_call" + ] + assert len(log_calls) == 1 + assert log_calls[0][1]["curated_payload"]["success"] is False, ( + "Middleware order is wrong: StructuredContentStripper is " + "swallowing exceptions before LoggingMiddleware can detect " + "them. Ensure StructuredContentStripper is outermost " + "(first added) in build_middleware_list()." + )