taking care of comments

This commit is contained in:
alexandrusoare
2026-05-04 12:36:03 +03:00
parent 5e3ca00bb2
commit ede52bd6b2
2 changed files with 9 additions and 32 deletions

View File

@@ -16,8 +16,8 @@
# under the License.
import logging
import secrets
import time
import uuid
from collections import defaultdict
from typing import Any, Awaitable, Callable, Dict, Protocol, Sequence
@@ -240,7 +240,7 @@ class LoggingMiddleware(Middleware):
)
tool_name = getattr(context.message, "name", None)
mcp_call_id = uuid.uuid4().hex[:12]
mcp_call_id = secrets.token_hex(16)
context.mcp_call_id = mcp_call_id
start_time = time.time()
success = False

View File

@@ -24,10 +24,14 @@ Tests verify that:
- _extract_context_info() extracts entity IDs from params
"""
from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp.exceptions import ToolError
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
from superset.mcp_service.middleware import LoggingMiddleware
@@ -115,8 +119,6 @@ class TestLoggingMiddlewareOnCallTool:
sits between GlobalErrorHandler and StructuredContentStripper, it
catches the ToolError directly.
"""
from fastmcp.exceptions import ToolError
middleware = LoggingMiddleware()
ctx = _make_context(name="get_chart_info")
call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found"))
@@ -146,7 +148,7 @@ class TestLoggingMiddlewareOnCallTool:
call_kwargs = mock_event_logger.log.call_args[1]
mcp_call_id = call_kwargs["curated_payload"]["mcp_call_id"]
assert isinstance(mcp_call_id, str)
assert len(mcp_call_id) == 12
assert len(mcp_call_id) == 32
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@@ -155,9 +157,6 @@ class TestLoggingMiddlewareOnCallTool:
self, mock_get_user_id, mock_event_logger
):
"""on_call_tool injects mcp_call_id into ToolResult.meta."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
original_result = ToolResult(content=[mt.TextContent(type="text", text="ok")])
@@ -167,7 +166,7 @@ class TestLoggingMiddlewareOnCallTool:
assert isinstance(result, ToolResult)
assert "mcp_call_id" in result.meta
assert len(result.meta["mcp_call_id"]) == 12
assert len(result.meta["mcp_call_id"]) == 32
@patch("superset.mcp_service.middleware.event_logger")
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
@@ -176,9 +175,6 @@ class TestLoggingMiddlewareOnCallTool:
self, mock_get_user_id, mock_event_logger
):
"""on_call_tool merges mcp_call_id with existing ToolResult.meta."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
ctx = _make_context(name="list_charts")
original_result = ToolResult(
@@ -303,9 +299,6 @@ class TestIsErrorResponse:
def test_detects_error_schema_response(self):
"""Detects ToolResult containing a serialized error schema
(ChartError, DashboardError, etc.) via "error_type" field."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
error_json = (
'{"error": "Chart 999 not found",'
@@ -317,9 +310,6 @@ class TestIsErrorResponse:
def test_success_response_not_detected_as_error(self):
"""Normal ToolResult is not detected as error."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
result = ToolResult(
content=[mt.TextContent(type="text", text="Successfully retrieved data")]
@@ -328,8 +318,6 @@ class TestIsErrorResponse:
def test_empty_content_not_detected_as_error(self):
"""ToolResult with empty content is not detected as error."""
from fastmcp.tools.tool import ToolResult
middleware = LoggingMiddleware()
assert middleware._is_error_response(ToolResult(content=[])) is False
@@ -341,9 +329,6 @@ class TestIsErrorResponse:
):
"""on_call_tool logs success=False when tool returns an
error schema (e.g. ChartError)."""
from fastmcp.tools.tool import ToolResult
from mcp import types as mt
middleware = LoggingMiddleware()
ctx = _make_context(name="get_chart_info")
@@ -384,10 +369,6 @@ class TestMiddlewareChainOrder:
):
"""Tool exception is logged as success=False through the
real middleware chain from build_middleware_list()."""
from functools import partial
from fastmcp.tools.tool import ToolResult
from superset.mcp_service.server import build_middleware_list
middleware_list = build_middleware_list()
@@ -435,10 +416,6 @@ class TestMiddlewareChainOrder:
):
"""When a tool raises, the error ToolResult from
StructuredContentStripper still carries mcp_call_id in meta."""
from functools import partial
from fastmcp.tools.tool import ToolResult
from superset.mcp_service.server import build_middleware_list
middleware_list = build_middleware_list()
@@ -457,4 +434,4 @@ class TestMiddlewareChainOrder:
assert result.content[0].text.startswith("Error:")
assert result.meta is not None
assert "mcp_call_id" in result.meta
assert len(result.meta["mcp_call_id"]) == 12
assert len(result.meta["mcp_call_id"]) == 32