From dface257b8b088dc330ae31f0657563a2f74487b Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Fri, 27 Mar 2026 16:51:07 +0000 Subject: [PATCH] feat(mcp): resolve call_tool proxy name and capture error_type in logging When tool search (progressive discovery) is enabled, MCP clients use the call_tool proxy instead of individual tools. The logging middleware now resolves the actual tool name from the proxy arguments and stores it as mcp_tool in the curated payload, restoring analytics granularity. Additionally, on failure the error_type (exception class name) is captured in the curated payload so analysts can distinguish failure modes without cross-referencing the separate mcp_tool_error log entries. Co-Authored-By: Claude Opus 4.6 --- superset/mcp_service/middleware.py | 70 +++++++++-- .../mcp_service/test_middleware_logging.py | 117 +++++++++++++++++- 2 files changed, 173 insertions(+), 14 deletions(-) diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index b0ffc4f5f7c..15202060d1c 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -128,8 +128,18 @@ class LoggingMiddleware(Middleware): Tool calls are handled in on_call_tool() which wraps execution to capture duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled in on_message(). + + When tool search is enabled (progressive discovery), the MCP client calls + ``call_tool`` / ``search_tools`` proxies instead of individual tools. This + middleware resolves the underlying tool name from ``call_tool`` arguments so + that analytics queries can filter by the actual tool (stored as ``mcp_tool`` + in the curated payload). """ + #: Default proxy names used by FastMCP tool-search transforms. + _CALL_TOOL_PROXY = "call_tool" + _SEARCH_TOOLS_PROXY = "search_tools" + def _extract_context_info( self, context: MiddlewareContext ) -> tuple[ @@ -156,6 +166,28 @@ class LoggingMiddleware(Middleware): dataset_id = params.get("dataset_id") return agent_id, user_id, dashboard_id, slice_id, dataset_id, params + @staticmethod + def _resolve_tool_name(tool_name: str | None, params: dict[str, Any]) -> str | None: + """Resolve the underlying tool name from call_tool proxy arguments. + + When tool search is enabled, the MCP client uses the ``call_tool`` + proxy and passes the real tool name as the ``name`` argument. This + helper extracts that value so we can log which tool was actually + executed rather than just ``"call_tool"``. + + Returns: + The resolved tool name if *tool_name* is the call_tool proxy and + ``params["name"]`` is a non-empty string, otherwise ``None``. + """ + if ( + tool_name == LoggingMiddleware._CALL_TOOL_PROXY + and isinstance(params, dict) + and isinstance(params.get("name"), str) + and params["name"] + ): + return params["name"] + return None + async def on_call_tool( self, context: MiddlewareContext, @@ -166,15 +198,34 @@ class LoggingMiddleware(Middleware): self._extract_context_info(context) ) tool_name = getattr(context.message, "name", None) + mcp_tool = self._resolve_tool_name(tool_name, params) start_time = time.time() success = False + error_type: str | None = None try: result = await call_next(context) success = True return result + except Exception as exc: + error_type = type(exc).__name__ + raise finally: duration_ms = int((time.time() - start_time) * 1000) + payload: dict[str, Any] = { + "tool": tool_name, + "agent_id": agent_id, + "params": _sanitize_params(params), + "method": context.method, + "dashboard_id": dashboard_id, + "slice_id": slice_id, + "dataset_id": dataset_id, + "success": success, + } + if mcp_tool is not None: + payload["mcp_tool"] = mcp_tool + if error_type is not None: + payload["error_type"] = error_type if has_app_context(): event_logger.log( user_id=user_id, @@ -183,22 +234,14 @@ class LoggingMiddleware(Middleware): duration_ms=duration_ms, slice_id=slice_id, referrer=None, - curated_payload={ - "tool": tool_name, - "agent_id": agent_id, - "params": _sanitize_params(params), - "method": context.method, - "dashboard_id": dashboard_id, - "slice_id": slice_id, - "dataset_id": dataset_id, - "success": success, - }, + curated_payload=payload, ) logger.info( - "MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, " - "dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, " - "success=%s", + "MCP tool call: tool=%s, mcp_tool=%s, agent_id=%s, user_id=%s, " + "method=%s, dashboard_id=%s, slice_id=%s, dataset_id=%s, " + "duration_ms=%s, success=%s, error_type=%s", tool_name, + mcp_tool, agent_id, user_id, context.method, @@ -207,6 +250,7 @@ class LoggingMiddleware(Middleware): dataset_id, duration_ms, success, + error_type, ) async def on_message( diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py index 1f48d531ccc..67f82e1feb4 100644 --- a/tests/unit_tests/mcp_service/test_middleware_logging.py +++ b/tests/unit_tests/mcp_service/test_middleware_logging.py @@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods. Tests verify that: - on_call_tool() captures duration_ms and success status +- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool) +- on_call_tool() captures error_type on failure - on_message() logs non-tool messages without duration - _extract_context_info() extracts entity IDs from params """ @@ -88,7 +90,7 @@ class TestLoggingMiddlewareOnCallTool: async def test_on_call_tool_logs_failure_on_exception( self, mock_get_user_id, mock_event_logger ): - """on_call_tool records success=False when tool raises.""" + """on_call_tool records success=False and error_type when tool raises.""" middleware = LoggingMiddleware() ctx = _make_context(name="execute_sql") call_next = AsyncMock(side_effect=ValueError("boom")) @@ -100,8 +102,25 @@ class TestLoggingMiddlewareOnCallTool: mock_event_logger.log.assert_called_once() call_kwargs = mock_event_logger.log.call_args[1] assert call_kwargs["curated_payload"]["success"] is False + assert call_kwargs["curated_payload"]["error_type"] == "ValueError" assert call_kwargs["duration_ms"] >= 0 + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_no_error_type_on_success( + self, mock_get_user_id, mock_event_logger + ): + """on_call_tool omits error_type from payload on success.""" + middleware = LoggingMiddleware() + ctx = _make_context(name="list_charts") + call_next = AsyncMock(return_value="ok") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert "error_type" not in payload + @patch("superset.mcp_service.middleware.event_logger") @patch("superset.mcp_service.middleware.get_user_id", return_value=42) @pytest.mark.asyncio @@ -127,6 +146,66 @@ class TestLoggingMiddlewareOnCallTool: assert call_kwargs["slice_id"] == 20 assert call_kwargs["curated_payload"]["dataset_id"] == 30 + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_resolves_call_tool_proxy( + self, mock_get_user_id, mock_event_logger + ): + """call_tool proxy is resolved to the actual tool name via mcp_tool.""" + middleware = LoggingMiddleware() + ctx = _make_context( + name="call_tool", + params={"name": "list_datasets", "arguments": {"page": 1}}, + ) + call_next = AsyncMock(return_value="datasets") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "call_tool" + assert payload["mcp_tool"] == "list_datasets" + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_no_mcp_tool_for_direct_calls( + self, mock_get_user_id, mock_event_logger + ): + """Direct tool calls (not via proxy) omit mcp_tool from payload.""" + middleware = LoggingMiddleware() + ctx = _make_context(name="list_charts") + call_next = AsyncMock(return_value="charts") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "list_charts" + assert "mcp_tool" not in payload + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_proxy_failure_captures_both_fields( + self, mock_get_user_id, mock_event_logger + ): + """call_tool proxy failure captures mcp_tool and error_type.""" + middleware = LoggingMiddleware() + ctx = _make_context( + name="call_tool", + params={"name": "get_chart_data", "arguments": {"chart_id": 1}}, + ) + call_next = AsyncMock(side_effect=PermissionError("access denied")) + + with pytest.raises(PermissionError): + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "call_tool" + assert payload["mcp_tool"] == "get_chart_data" + assert payload["success"] is False + assert payload["error_type"] == "PermissionError" + class TestLoggingMiddlewareOnMessage: """Tests for LoggingMiddleware.on_message().""" @@ -155,6 +234,42 @@ class TestLoggingMiddlewareOnMessage: assert "success" not in call_kwargs["curated_payload"] +class TestResolveToolName: + """Tests for LoggingMiddleware._resolve_tool_name().""" + + def test_resolves_call_tool_proxy(self): + """Returns the real tool name when call_tool proxy is used.""" + assert ( + LoggingMiddleware._resolve_tool_name( + "call_tool", {"name": "list_datasets", "arguments": {}} + ) + == "list_datasets" + ) + + def test_returns_none_for_direct_tool(self): + """Returns None for direct tool calls (not via proxy).""" + assert LoggingMiddleware._resolve_tool_name("list_charts", {"page": 1}) is None + + def test_returns_none_when_name_missing(self): + """Returns None when call_tool params lack 'name'.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"foo": "bar"}) is None + + def test_returns_none_for_empty_name(self): + """Returns None when call_tool params have empty 'name'.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""}) is None + + def test_returns_none_for_non_string_name(self): + """Returns None when call_tool name param is not a string.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": 123}) is None + + def test_returns_none_for_search_tools(self): + """search_tools proxy is not resolved (no underlying tool name).""" + assert ( + LoggingMiddleware._resolve_tool_name("search_tools", {"query": "datasets"}) + is None + ) + + class TestExtractContextInfo: """Tests for LoggingMiddleware._extract_context_info()."""