mirror of
https://github.com/apache/superset.git
synced 2026-05-07 17:04:58 +00:00
taking care of comments
This commit is contained in:
@@ -16,8 +16,8 @@
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from collections import defaultdict
|
||||
from typing import Any, Awaitable, Callable, Dict, Protocol, Sequence
|
||||
|
||||
@@ -240,7 +240,7 @@ class LoggingMiddleware(Middleware):
|
||||
)
|
||||
tool_name = getattr(context.message, "name", None)
|
||||
|
||||
mcp_call_id = uuid.uuid4().hex[:12]
|
||||
mcp_call_id = secrets.token_hex(16)
|
||||
context.mcp_call_id = mcp_call_id
|
||||
start_time = time.time()
|
||||
success = False
|
||||
|
||||
@@ -24,10 +24,14 @@ Tests verify that:
|
||||
- _extract_context_info() extracts entity IDs from params
|
||||
"""
|
||||
|
||||
from functools import partial
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp.exceptions import ToolError
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
from superset.mcp_service.middleware import LoggingMiddleware
|
||||
|
||||
@@ -115,8 +119,6 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
sits between GlobalErrorHandler and StructuredContentStripper, it
|
||||
catches the ToolError directly.
|
||||
"""
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="get_chart_info")
|
||||
call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found"))
|
||||
@@ -146,7 +148,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
call_kwargs = mock_event_logger.log.call_args[1]
|
||||
mcp_call_id = call_kwargs["curated_payload"]["mcp_call_id"]
|
||||
assert isinstance(mcp_call_id, str)
|
||||
assert len(mcp_call_id) == 12
|
||||
assert len(mcp_call_id) == 32
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@@ -155,9 +157,6 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""on_call_tool injects mcp_call_id into ToolResult.meta."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
original_result = ToolResult(content=[mt.TextContent(type="text", text="ok")])
|
||||
@@ -167,7 +166,7 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
|
||||
assert isinstance(result, ToolResult)
|
||||
assert "mcp_call_id" in result.meta
|
||||
assert len(result.meta["mcp_call_id"]) == 12
|
||||
assert len(result.meta["mcp_call_id"]) == 32
|
||||
|
||||
@patch("superset.mcp_service.middleware.event_logger")
|
||||
@patch("superset.mcp_service.middleware.get_user_id", return_value=42)
|
||||
@@ -176,9 +175,6 @@ class TestLoggingMiddlewareOnCallTool:
|
||||
self, mock_get_user_id, mock_event_logger
|
||||
):
|
||||
"""on_call_tool merges mcp_call_id with existing ToolResult.meta."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="list_charts")
|
||||
original_result = ToolResult(
|
||||
@@ -303,9 +299,6 @@ class TestIsErrorResponse:
|
||||
def test_detects_error_schema_response(self):
|
||||
"""Detects ToolResult containing a serialized error schema
|
||||
(ChartError, DashboardError, etc.) via "error_type" field."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
error_json = (
|
||||
'{"error": "Chart 999 not found",'
|
||||
@@ -317,9 +310,6 @@ class TestIsErrorResponse:
|
||||
|
||||
def test_success_response_not_detected_as_error(self):
|
||||
"""Normal ToolResult is not detected as error."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
result = ToolResult(
|
||||
content=[mt.TextContent(type="text", text="Successfully retrieved data")]
|
||||
@@ -328,8 +318,6 @@ class TestIsErrorResponse:
|
||||
|
||||
def test_empty_content_not_detected_as_error(self):
|
||||
"""ToolResult with empty content is not detected as error."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
assert middleware._is_error_response(ToolResult(content=[])) is False
|
||||
|
||||
@@ -341,9 +329,6 @@ class TestIsErrorResponse:
|
||||
):
|
||||
"""on_call_tool logs success=False when tool returns an
|
||||
error schema (e.g. ChartError)."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
from mcp import types as mt
|
||||
|
||||
middleware = LoggingMiddleware()
|
||||
ctx = _make_context(name="get_chart_info")
|
||||
|
||||
@@ -384,10 +369,6 @@ class TestMiddlewareChainOrder:
|
||||
):
|
||||
"""Tool exception is logged as success=False through the
|
||||
real middleware chain from build_middleware_list()."""
|
||||
from functools import partial
|
||||
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.mcp_service.server import build_middleware_list
|
||||
|
||||
middleware_list = build_middleware_list()
|
||||
@@ -435,10 +416,6 @@ class TestMiddlewareChainOrder:
|
||||
):
|
||||
"""When a tool raises, the error ToolResult from
|
||||
StructuredContentStripper still carries mcp_call_id in meta."""
|
||||
from functools import partial
|
||||
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
|
||||
from superset.mcp_service.server import build_middleware_list
|
||||
|
||||
middleware_list = build_middleware_list()
|
||||
@@ -457,4 +434,4 @@ class TestMiddlewareChainOrder:
|
||||
assert result.content[0].text.startswith("Error:")
|
||||
assert result.meta is not None
|
||||
assert "mcp_call_id" in result.meta
|
||||
assert len(result.meta["mcp_call_id"]) == 12
|
||||
assert len(result.meta["mcp_call_id"]) == 32
|
||||
|
||||
Reference in New Issue
Block a user