diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index a90c8448321..ba9ad417a73 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -16,8 +16,8 @@ # under the License. import logging +import secrets import time -import uuid from collections import defaultdict from typing import Any, Awaitable, Callable, Dict, Protocol, Sequence @@ -240,7 +240,7 @@ class LoggingMiddleware(Middleware): ) tool_name = getattr(context.message, "name", None) - mcp_call_id = uuid.uuid4().hex[:12] + mcp_call_id = secrets.token_hex(16) context.mcp_call_id = mcp_call_id start_time = time.time() success = False diff --git a/tests/unit_tests/mcp_service/test_middleware_logging.py b/tests/unit_tests/mcp_service/test_middleware_logging.py index ca68689557b..19487810bbf 100644 --- a/tests/unit_tests/mcp_service/test_middleware_logging.py +++ b/tests/unit_tests/mcp_service/test_middleware_logging.py @@ -24,10 +24,14 @@ Tests verify that: - _extract_context_info() extracts entity IDs from params """ +from functools import partial from typing import Any from unittest.mock import AsyncMock, MagicMock, patch import pytest +from fastmcp.exceptions import ToolError +from fastmcp.tools.tool import ToolResult +from mcp import types as mt from superset.mcp_service.middleware import LoggingMiddleware @@ -115,8 +119,6 @@ class TestLoggingMiddlewareOnCallTool: sits between GlobalErrorHandler and StructuredContentStripper, it catches the ToolError directly. """ - from fastmcp.exceptions import ToolError - middleware = LoggingMiddleware() ctx = _make_context(name="get_chart_info") call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found")) @@ -146,7 +148,7 @@ class TestLoggingMiddlewareOnCallTool: call_kwargs = mock_event_logger.log.call_args[1] mcp_call_id = call_kwargs["curated_payload"]["mcp_call_id"] assert isinstance(mcp_call_id, str) - assert len(mcp_call_id) == 12 + assert len(mcp_call_id) == 32 @patch("superset.mcp_service.middleware.event_logger") @patch("superset.mcp_service.middleware.get_user_id", return_value=42) @@ -155,9 +157,6 @@ class TestLoggingMiddlewareOnCallTool: self, mock_get_user_id, mock_event_logger ): """on_call_tool injects mcp_call_id into ToolResult.meta.""" - from fastmcp.tools.tool import ToolResult - from mcp import types as mt - middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") original_result = ToolResult(content=[mt.TextContent(type="text", text="ok")]) @@ -167,7 +166,7 @@ class TestLoggingMiddlewareOnCallTool: assert isinstance(result, ToolResult) assert "mcp_call_id" in result.meta - assert len(result.meta["mcp_call_id"]) == 12 + assert len(result.meta["mcp_call_id"]) == 32 @patch("superset.mcp_service.middleware.event_logger") @patch("superset.mcp_service.middleware.get_user_id", return_value=42) @@ -176,9 +175,6 @@ class TestLoggingMiddlewareOnCallTool: self, mock_get_user_id, mock_event_logger ): """on_call_tool merges mcp_call_id with existing ToolResult.meta.""" - from fastmcp.tools.tool import ToolResult - from mcp import types as mt - middleware = LoggingMiddleware() ctx = _make_context(name="list_charts") original_result = ToolResult( @@ -303,9 +299,6 @@ class TestIsErrorResponse: def test_detects_error_schema_response(self): """Detects ToolResult containing a serialized error schema (ChartError, DashboardError, etc.) via "error_type" field.""" - from fastmcp.tools.tool import ToolResult - from mcp import types as mt - middleware = LoggingMiddleware() error_json = ( '{"error": "Chart 999 not found",' @@ -317,9 +310,6 @@ class TestIsErrorResponse: def test_success_response_not_detected_as_error(self): """Normal ToolResult is not detected as error.""" - from fastmcp.tools.tool import ToolResult - from mcp import types as mt - middleware = LoggingMiddleware() result = ToolResult( content=[mt.TextContent(type="text", text="Successfully retrieved data")] @@ -328,8 +318,6 @@ class TestIsErrorResponse: def test_empty_content_not_detected_as_error(self): """ToolResult with empty content is not detected as error.""" - from fastmcp.tools.tool import ToolResult - middleware = LoggingMiddleware() assert middleware._is_error_response(ToolResult(content=[])) is False @@ -341,9 +329,6 @@ class TestIsErrorResponse: ): """on_call_tool logs success=False when tool returns an error schema (e.g. ChartError).""" - from fastmcp.tools.tool import ToolResult - from mcp import types as mt - middleware = LoggingMiddleware() ctx = _make_context(name="get_chart_info") @@ -384,10 +369,6 @@ class TestMiddlewareChainOrder: ): """Tool exception is logged as success=False through the real middleware chain from build_middleware_list().""" - from functools import partial - - from fastmcp.tools.tool import ToolResult - from superset.mcp_service.server import build_middleware_list middleware_list = build_middleware_list() @@ -435,10 +416,6 @@ class TestMiddlewareChainOrder: ): """When a tool raises, the error ToolResult from StructuredContentStripper still carries mcp_call_id in meta.""" - from functools import partial - - from fastmcp.tools.tool import ToolResult - from superset.mcp_service.server import build_middleware_list middleware_list = build_middleware_list() @@ -457,4 +434,4 @@ class TestMiddlewareChainOrder: assert result.content[0].text.startswith("Error:") assert result.meta is not None assert "mcp_call_id" in result.meta - assert len(result.meta["mcp_call_id"]) == 12 + assert len(result.meta["mcp_call_id"]) == 32