From b6f545e61e1fbbc943e9dd5cec491c69a37a13c7 Mon Sep 17 00:00:00 2001 From: Maxime Beauchemin Date: Tue, 26 May 2026 11:37:37 -0700 Subject: [PATCH] feat(mcp): resolve call_tool proxy name and capture error_type in logging (#38915) Co-authored-by: Amin Ghadersohi --- superset/mcp_service/middleware.py | 85 ++++++++-- .../mcp_service/test_middleware_logging.py | 153 +++++++++++++++--- 2 files changed, 202 insertions(+), 36 deletions(-) diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 844b5ff59d0..685022c3a9f 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -188,10 +188,15 @@ def _sanitize_params(params: dict[str, Any]) -> dict[str, Any]: """Remove sensitive fields from params before logging.""" if not isinstance(params, dict): return params - return { - k: "[REDACTED]" if k.lower() in _SENSITIVE_PARAM_KEYS else v - for k, v in params.items() - } + result: dict[str, Any] = {} + for k, v in params.items(): + if k.lower() in _SENSITIVE_PARAM_KEYS: + result[k] = "[REDACTED]" + elif k == "arguments" and isinstance(v, dict): + result[k] = _sanitize_params(v) + else: + result[k] = v + return result class LoggingMiddleware(Middleware): @@ -204,8 +209,17 @@ class LoggingMiddleware(Middleware): Tool calls are handled in on_call_tool() which wraps execution to capture duration_ms. Non-tool messages (resource reads, prompts, etc.) are handled in on_message(). + + When tool search is enabled (progressive discovery), the MCP client calls + ``call_tool`` proxies instead of individual tools. This middleware resolves + the underlying tool name from ``call_tool`` arguments so that analytics + queries can filter by the actual tool (stored as ``mcp_tool`` in the curated + payload). """ + #: Proxy name used by FastMCP tool-search transforms. + _CALL_TOOL_PROXY = "call_tool" + def _is_error_response(self, result: ToolResult) -> bool: """Check if a tool result contains an error schema response. @@ -244,6 +258,28 @@ class LoggingMiddleware(Middleware): dataset_id = params.get("dataset_id") return agent_id, user_id, dashboard_id, slice_id, dataset_id, params + @staticmethod + def _resolve_tool_name(tool_name: str | None, params: Any) -> str | None: + """Resolve the underlying tool name from call_tool proxy arguments. + + When tool search is enabled, the MCP client uses the ``call_tool`` + proxy and passes the real tool name as the ``name`` argument. This + helper extracts that value so we can log which tool was actually + executed rather than just ``"call_tool"``. + + Returns: + The resolved tool name if *tool_name* is the call_tool proxy and + ``params["name"]`` is a non-empty string, otherwise ``None``. + """ + if ( + tool_name == LoggingMiddleware._CALL_TOOL_PROXY + and isinstance(params, dict) + and isinstance(params.get("name"), str) + and params["name"] + ): + return params["name"] + return None + async def on_call_tool( self, context: MiddlewareContext, @@ -254,11 +290,13 @@ class LoggingMiddleware(Middleware): self._extract_context_info(context) ) tool_name = getattr(context.message, "name", None) + mcp_tool = self._resolve_tool_name(tool_name, params) mcp_call_id = secrets.token_hex(16) _mcp_call_id_var.set(mcp_call_id) start_time = time.time() success = False + error_type: str | None = None try: result = await call_next(context) success = not self._is_error_response(result) @@ -270,11 +308,27 @@ class LoggingMiddleware(Middleware): structured_content=result.structured_content, ) return result - except Exception: + except Exception as exc: + error_type = type(exc).__name__ success = False raise finally: duration_ms = int((time.time() - start_time) * 1000) + payload: dict[str, Any] = { + "mcp_call_id": mcp_call_id, + "tool": tool_name, + "agent_id": agent_id, + "params": _sanitize_params(params), + "method": context.method, + "dashboard_id": dashboard_id, + "slice_id": slice_id, + "dataset_id": dataset_id, + "success": success, + } + if mcp_tool is not None: + payload["mcp_tool"] = mcp_tool + if error_type is not None: + payload["error_type"] = error_type if has_app_context(): event_logger.log( user_id=user_id, @@ -283,22 +337,18 @@ class LoggingMiddleware(Middleware): duration_ms=duration_ms, slice_id=slice_id, referrer=None, - curated_payload={ - "mcp_call_id": mcp_call_id, - "tool": tool_name, - "agent_id": agent_id, - "params": _sanitize_params(params), - "method": context.method, - "dashboard_id": dashboard_id, - "slice_id": slice_id, - "dataset_id": dataset_id, - "success": success, - }, + curated_payload=payload, ) + extra_parts = [] + if mcp_tool is not None: + extra_parts.append(f"mcp_tool={mcp_tool}") + if error_type is not None: + extra_parts.append(f"error_type={error_type}") + extra = (", " + ", ".join(extra_parts)) if extra_parts else "" logger.info( "MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, " "dashboard_id=%s, slice_id=%s, dataset_id=%s, duration_ms=%s, " - "success=%s, mcp_call_id=%s", + "success=%s, mcp_call_id=%s%s", tool_name, agent_id, user_id, @@ -309,6 +359,7 @@ class LoggingMiddleware(Middleware): duration_ms, success, mcp_call_id, + extra, ) async def on_message( diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py index 3f81dc3ae9b..de8f246b599 100644 --- a/tests/unit_tests/mcp_service/test_middleware_logging.py +++ b/tests/unit_tests/mcp_service/test_middleware_logging.py @@ -20,6 +20,8 @@ Unit tests for LoggingMiddleware on_call_tool() and on_message() methods. Tests verify that: - on_call_tool() captures duration_ms and success status +- on_call_tool() resolves call_tool proxy to actual tool name (mcp_tool) +- on_call_tool() captures error_type on failure - on_message() logs non-tool messages without duration - _extract_context_info() extracts entity IDs from params """ @@ -65,7 +67,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_logs_duration_and_success( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool records duration_ms and success=True on normal return.""" middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") @@ -91,8 +93,8 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_logs_failure_on_exception( self, mock_get_user_id, mock_event_logger - ): - """on_call_tool records success=False when tool raises.""" + ) -> None: + """on_call_tool records success=False and error_type when tool raises.""" middleware = LoggingMiddleware() ctx = _make_context(name="execute_sql") call_next = AsyncMock(side_effect=ValueError("boom")) @@ -104,6 +106,7 @@ class TestLoggingMiddlewareOnCallTool: mock_event_logger.log.assert_called_once() call_kwargs = mock_event_logger.log.call_args[1] assert call_kwargs["curated_payload"]["success"] is False + assert call_kwargs["curated_payload"]["error_type"] == "ValueError" assert call_kwargs["duration_ms"] >= 0 @patch("superset.mcp_service.middleware.event_logger") @@ -111,7 +114,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_logs_failure_on_tool_error( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool records success=False when GlobalErrorHandler raises ToolError. This simulates the real middleware chain: GlobalErrorHandler catches @@ -137,7 +140,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_includes_mcp_call_id_in_curated_payload( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool adds mcp_call_id to curated_payload.""" middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") @@ -155,7 +158,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_injects_mcp_call_id_into_tool_result_meta( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool injects mcp_call_id into ToolResult.meta.""" middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") @@ -173,7 +176,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_preserves_existing_meta( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool merges mcp_call_id with existing ToolResult.meta.""" middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") @@ -193,7 +196,7 @@ class TestLoggingMiddlewareOnCallTool: @pytest.mark.asyncio async def test_on_call_tool_extracts_entity_ids( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool extracts dashboard_id, chart_id, dataset_id from params.""" middleware = LoggingMiddleware() ctx = _make_context( @@ -222,7 +225,7 @@ class TestLoggingMiddlewareOnMessage: @pytest.mark.asyncio async def test_on_message_logs_without_duration( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_message logs with action=mcp_message and duration_ms=None.""" middleware = LoggingMiddleware() ctx = _make_context(method="resources/read", name="instance/metadata") @@ -240,12 +243,124 @@ class TestLoggingMiddlewareOnMessage: # on_message should NOT have success field assert "success" not in call_kwargs["curated_payload"] + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_no_error_type_on_success( + self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock + ) -> None: + """on_call_tool omits error_type from payload on success.""" + middleware = LoggingMiddleware() + ctx = _make_context(name="list_charts") + call_next = AsyncMock(return_value="ok") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert "error_type" not in payload + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_resolves_call_tool_proxy( + self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock + ) -> None: + """call_tool proxy is resolved to the actual tool name via mcp_tool.""" + middleware = LoggingMiddleware() + ctx = _make_context( + name="call_tool", + params={"name": "list_datasets", "arguments": {"page": 1}}, + ) + call_next = AsyncMock(return_value="datasets") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "call_tool" + assert payload["mcp_tool"] == "list_datasets" + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_no_mcp_tool_for_direct_calls( + self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock + ) -> None: + """Direct tool calls (not via proxy) omit mcp_tool from payload.""" + middleware = LoggingMiddleware() + ctx = _make_context(name="list_charts") + call_next = AsyncMock(return_value="charts") + + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "list_charts" + assert "mcp_tool" not in payload + + @patch("superset.mcp_service.middleware.event_logger") + @patch("superset.mcp_service.middleware.get_user_id", return_value=42) + @pytest.mark.asyncio + async def test_on_call_tool_proxy_failure_captures_both_fields( + self, mock_get_user_id: MagicMock, mock_event_logger: MagicMock + ) -> None: + """call_tool proxy failure captures mcp_tool and error_type.""" + middleware = LoggingMiddleware() + ctx = _make_context( + name="call_tool", + params={"name": "get_chart_data", "arguments": {"chart_id": 1}}, + ) + call_next = AsyncMock(side_effect=PermissionError("access denied")) + + with pytest.raises(PermissionError): + await middleware.on_call_tool(ctx, call_next) + + payload = mock_event_logger.log.call_args[1]["curated_payload"] + assert payload["tool"] == "call_tool" + assert payload["mcp_tool"] == "get_chart_data" + assert payload["success"] is False + assert payload["error_type"] == "PermissionError" + + +class TestResolveToolName: + """Tests for LoggingMiddleware._resolve_tool_name().""" + + def test_resolves_call_tool_proxy(self) -> None: + """Returns the real tool name when call_tool proxy is used.""" + assert ( + LoggingMiddleware._resolve_tool_name( + "call_tool", {"name": "list_datasets", "arguments": {}} + ) + == "list_datasets" + ) + + def test_returns_none_for_direct_tool(self) -> None: + """Returns None for direct tool calls (not via proxy).""" + assert LoggingMiddleware._resolve_tool_name("list_charts", {"page": 1}) is None + + def test_returns_none_when_name_missing(self) -> None: + """Returns None when call_tool params lack 'name'.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"foo": "bar"}) is None + + def test_returns_none_for_empty_name(self) -> None: + """Returns None when call_tool params have empty 'name'.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": ""}) is None + + def test_returns_none_for_non_string_name(self) -> None: + """Returns None when call_tool name param is not a string.""" + assert LoggingMiddleware._resolve_tool_name("call_tool", {"name": 123}) is None + + def test_returns_none_for_search_tools(self) -> None: + """search_tools proxy is not resolved (no underlying tool name).""" + assert ( + LoggingMiddleware._resolve_tool_name("search_tools", {"query": "datasets"}) + is None + ) + class TestExtractContextInfo: """Tests for LoggingMiddleware._extract_context_info().""" @patch("superset.mcp_service.middleware.get_user_id", return_value=99) - def test_extract_with_metadata_agent_id(self, mock_get_user_id): + def test_extract_with_metadata_agent_id(self, mock_get_user_id) -> None: """Extracts agent_id from context.metadata.""" middleware = LoggingMiddleware() ctx = _make_context(metadata={"agent_id": "agent-123"}) @@ -261,7 +376,7 @@ class TestExtractContextInfo: "superset.mcp_service.middleware.get_user_id", side_effect=RuntimeError("no Flask request context"), ) - def test_extract_handles_missing_user(self, mock_get_user_id): + def test_extract_handles_missing_user(self, mock_get_user_id) -> None: """Gracefully handles missing user context.""" middleware = LoggingMiddleware() ctx = _make_context() @@ -273,7 +388,7 @@ class TestExtractContextInfo: assert user_id is None @patch("superset.mcp_service.middleware.get_user_id", return_value=1) - def test_extract_slice_id_from_chart_id(self, mock_get_user_id): + def test_extract_slice_id_from_chart_id(self, mock_get_user_id) -> None: """Extracts slice_id from chart_id param (alias).""" middleware = LoggingMiddleware() ctx = _make_context(params={"chart_id": 55}) @@ -283,7 +398,7 @@ class TestExtractContextInfo: assert slice_id == 55 @patch("superset.mcp_service.middleware.get_user_id", return_value=1) - def test_extract_slice_id_from_slice_id(self, mock_get_user_id): + def test_extract_slice_id_from_slice_id(self, mock_get_user_id) -> None: """Extracts slice_id from slice_id param (fallback).""" middleware = LoggingMiddleware() ctx = _make_context(params={"slice_id": 66}) @@ -296,7 +411,7 @@ class TestExtractContextInfo: class TestIsErrorResponse: """Tests for LoggingMiddleware._is_error_response().""" - def test_detects_error_schema_response(self): + def test_detects_error_schema_response(self) -> None: """Detects ToolResult containing a serialized error schema (ChartError, DashboardError, etc.) via "error_type" field.""" middleware = LoggingMiddleware() @@ -308,7 +423,7 @@ class TestIsErrorResponse: result = ToolResult(content=[mt.TextContent(type="text", text=error_json)]) assert middleware._is_error_response(result) is True - def test_success_response_not_detected_as_error(self): + def test_success_response_not_detected_as_error(self) -> None: """Normal ToolResult is not detected as error.""" middleware = LoggingMiddleware() result = ToolResult( @@ -316,7 +431,7 @@ class TestIsErrorResponse: ) assert middleware._is_error_response(result) is False - def test_empty_content_not_detected_as_error(self): + def test_empty_content_not_detected_as_error(self) -> None: """ToolResult with empty content is not detected as error.""" middleware = LoggingMiddleware() assert middleware._is_error_response(ToolResult(content=[])) is False @@ -326,7 +441,7 @@ class TestIsErrorResponse: @pytest.mark.asyncio async def test_on_call_tool_logs_failure_for_error_schema( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """on_call_tool logs success=False when tool returns an error schema (e.g. ChartError).""" middleware = LoggingMiddleware() @@ -366,7 +481,7 @@ class TestMiddlewareChainOrder: @pytest.mark.asyncio async def test_real_middleware_chain_logs_exception_as_failure( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """Tool exception is logged as success=False through the real middleware chain from build_middleware_list().""" from superset.mcp_service.server import build_middleware_list @@ -413,7 +528,7 @@ class TestMiddlewareChainOrder: @pytest.mark.asyncio async def test_real_middleware_chain_error_result_has_mcp_call_id( self, mock_get_user_id, mock_event_logger - ): + ) -> None: """When a tool raises, the error ToolResult from StructuredContentStripper still carries mcp_call_id in meta.""" from superset.mcp_service.server import build_middleware_list