mirror of
https://github.com/apache/superset.git
synced 2026-04-21 17:14:57 +00:00
365 lines
14 KiB
Python
365 lines
14 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
|
|
|
|
"""
|
|
Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.
|
|
|
|
Tests verify that:
|
|
- on_call_tool() captures duration_ms and success status
|
|
- on_message() logs non-tool messages without duration
|
|
- _extract_context_info() extracts entity IDs from params
|
|
"""
|
|
|
|
from typing import Any
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from superset.mcp_service.middleware import LoggingMiddleware
|
|
|
|
|
|
def _make_context(
|
|
method: str = "tools/call",
|
|
name: str = "list_charts",
|
|
params: dict[str, Any] | None = None,
|
|
metadata: dict[str, Any] | None = None,
|
|
):
|
|
"""Create a mock MiddlewareContext."""
|
|
ctx = MagicMock()
|
|
ctx.method = method
|
|
message = MagicMock()
|
|
message.name = name
|
|
message.params = params or {}
|
|
ctx.message = message
|
|
if metadata is not None:
|
|
ctx.metadata = metadata
|
|
else:
|
|
ctx.metadata = None
|
|
ctx.session = None
|
|
return ctx
|
|
|
|
|
|
class TestLoggingMiddlewareOnCallTool:
    """Cover LoggingMiddleware.on_call_tool() with the event logger patched out."""

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_duration_and_success(
        self, mock_get_user_id, mock_event_logger
    ):
        """A normal return is passed through and logged once as a success."""
        next_handler = AsyncMock(return_value="tool_result")
        context = _make_context(name="list_charts")

        outcome = await LoggingMiddleware().on_call_tool(context, next_handler)

        # The tool result comes back untouched and the handler ran exactly once.
        assert outcome == "tool_result"
        next_handler.assert_awaited_once_with(context)

        # One log entry carries the action, user, timing and success flag.
        mock_event_logger.log.assert_called_once()
        logged = mock_event_logger.log.call_args.kwargs
        assert logged["action"] == "mcp_tool_call"
        assert logged["user_id"] == 42
        assert isinstance(logged["duration_ms"], int)
        assert logged["duration_ms"] >= 0
        assert logged["curated_payload"]["success"] is True
        assert logged["curated_payload"]["tool"] == "list_charts"

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_on_exception(
        self, mock_get_user_id, mock_event_logger
    ):
        """An exception raised by the tool propagates and is logged as a failure."""
        next_handler = AsyncMock(side_effect=ValueError("boom"))
        context = _make_context(name="execute_sql")
        middleware = LoggingMiddleware()

        with pytest.raises(ValueError, match="boom"):
            await middleware.on_call_tool(context, next_handler)

        # Logging must still have happened exactly once despite the raise.
        mock_event_logger.log.assert_called_once()
        logged = mock_event_logger.log.call_args.kwargs
        assert logged["curated_payload"]["success"] is False
        assert logged["duration_ms"] >= 0

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_on_tool_error(
        self, mock_get_user_id, mock_event_logger
    ):
        """A ToolError coming from the inner chain is logged with success=False.

        Mirrors the production middleware chain, in which GlobalErrorHandler
        converts tool exceptions into ToolError. LoggingMiddleware sits between
        GlobalErrorHandler and StructuredContentStripper, so the ToolError
        reaches it directly.
        """
        from fastmcp.exceptions import ToolError

        next_handler = AsyncMock(side_effect=ToolError("Chart 999999 not found"))
        context = _make_context(name="get_chart_info")
        middleware = LoggingMiddleware()

        with pytest.raises(ToolError, match="Chart 999999 not found"):
            await middleware.on_call_tool(context, next_handler)

        mock_event_logger.log.assert_called_once()
        logged = mock_event_logger.log.call_args.kwargs
        assert logged["curated_payload"]["success"] is False
        assert logged["curated_payload"]["tool"] == "get_chart_info"
        assert logged["duration_ms"] >= 0

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_extracts_entity_ids(
        self, mock_get_user_id, mock_event_logger
    ):
        """Entity IDs supplied in the tool params end up in the log entry."""
        entity_params = {
            "dashboard_id": 10,
            "chart_id": 20,
            "dataset_id": 30,
        }
        context = _make_context(name="get_chart_info", params=entity_params)
        next_handler = AsyncMock(return_value="ok")

        await LoggingMiddleware().on_call_tool(context, next_handler)

        logged = mock_event_logger.log.call_args.kwargs
        assert logged["dashboard_id"] == 10
        # chart_id is reported under the slice_id keyword.
        assert logged["slice_id"] == 20
        # dataset_id travels inside the curated payload, not as a keyword.
        assert logged["curated_payload"]["dataset_id"] == 30
|
|
|
|
|
|
class TestLoggingMiddlewareOnMessage:
    """Cover LoggingMiddleware.on_message() for non-tool messages."""

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    @pytest.mark.asyncio
    async def test_on_message_logs_without_duration(
        self, mock_get_user_id, mock_event_logger
    ):
        """A resource read is logged as mcp_message with no duration or success."""
        next_handler = AsyncMock(return_value="resource_data")
        context = _make_context(method="resources/read", name="instance/metadata")

        outcome = await LoggingMiddleware().on_message(context, next_handler)

        # The downstream result is returned as-is after a single await.
        assert outcome == "resource_data"
        next_handler.assert_awaited_once_with(context)

        mock_event_logger.log.assert_called_once()
        logged = mock_event_logger.log.call_args.kwargs
        assert logged["action"] == "mcp_message"
        assert logged["duration_ms"] is None
        # Only tool calls report a success flag; plain messages must not.
        assert "success" not in logged["curated_payload"]
|
|
|
|
|
|
class TestExtractContextInfo:
    """Cover LoggingMiddleware._extract_context_info() and its 6-tuple result."""

    @patch("superset.mcp_service.middleware.get_user_id", return_value=99)
    def test_extract_with_metadata_agent_id(self, mock_get_user_id):
        """agent_id is read from context.metadata and user_id from get_user_id."""
        context = _make_context(metadata={"agent_id": "agent-123"})

        extracted = LoggingMiddleware()._extract_context_info(context)
        # Unpack all six positions so a change in tuple length fails loudly.
        agent_id, user_id, _, _, _, _ = extracted

        assert agent_id == "agent-123"
        assert user_id == 99

    @patch(
        "superset.mcp_service.middleware.get_user_id",
        side_effect=RuntimeError("no Flask request context"),
    )
    def test_extract_handles_missing_user(self, mock_get_user_id):
        """A RuntimeError from get_user_id results in user_id being None."""
        context = _make_context()

        extracted = LoggingMiddleware()._extract_context_info(context)
        _, user_id, _, _, _, _ = extracted

        assert user_id is None

    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    def test_extract_slice_id_from_chart_id(self, mock_get_user_id):
        """A chart_id param is surfaced as slice_id (alias)."""
        context = _make_context(params={"chart_id": 55})

        extracted = LoggingMiddleware()._extract_context_info(context)
        _, _, _, slice_id, _, _ = extracted

        assert slice_id == 55

    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    def test_extract_slice_id_from_slice_id(self, mock_get_user_id):
        """A slice_id param is surfaced as slice_id (fallback)."""
        context = _make_context(params={"slice_id": 66})

        extracted = LoggingMiddleware()._extract_context_info(context)
        _, _, _, slice_id, _, _ = extracted

        assert slice_id == 66
|
|
|
|
|
|
class TestIsErrorResponse:
    """Cover LoggingMiddleware._is_error_response() and its use in on_call_tool()."""

    def test_detects_error_schema_response(self):
        """A ToolResult whose text is a serialized error schema
        (ChartError, DashboardError, etc., carrying "error_type") is an error."""
        from fastmcp.tools.tool import ToolResult
        from mcp import types as mt

        serialized_error = (
            '{"error": "Chart 999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        text_block = mt.TextContent(type="text", text=serialized_error)
        tool_result = ToolResult(content=[text_block])

        assert LoggingMiddleware()._is_error_response(tool_result) is True

    def test_success_response_not_detected_as_error(self):
        """A ToolResult carrying plain success text is not flagged."""
        from fastmcp.tools.tool import ToolResult
        from mcp import types as mt

        text_block = mt.TextContent(type="text", text="Successfully retrieved data")
        tool_result = ToolResult(content=[text_block])

        assert LoggingMiddleware()._is_error_response(tool_result) is False

    def test_empty_content_not_detected_as_error(self):
        """A ToolResult without any content blocks is not flagged."""
        from fastmcp.tools.tool import ToolResult

        empty_result = ToolResult(content=[])

        assert LoggingMiddleware()._is_error_response(empty_result) is False

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_for_error_schema(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool logs success=False when the tool returns (rather
        than raises) an error schema such as ChartError."""
        from fastmcp.tools.tool import ToolResult
        from mcp import types as mt

        serialized_error = (
            '{"error": "Chart 999999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        text_block = mt.TextContent(type="text", text=serialized_error)
        error_result = ToolResult(content=[text_block])
        next_handler = AsyncMock(return_value=error_result)
        context = _make_context(name="get_chart_info")

        outcome = await LoggingMiddleware().on_call_tool(context, next_handler)

        # The error result itself is still handed back to the caller.
        assert outcome == error_result
        mock_event_logger.log.assert_called_once()
        logged = mock_event_logger.log.call_args.kwargs
        assert logged["curated_payload"]["success"] is False
        assert logged["curated_payload"]["tool"] == "get_chart_info"
|
|
|
|
|
|
class TestMiddlewareChainOrder:
    """Check that the middleware order from server.py logs failures correctly.

    With the wrong order (StructuredContentStripper innermost), exceptions
    are swallowed before LoggingMiddleware can see them, and failures end up
    logged with success=True.
    """

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_real_middleware_chain_logs_exception_as_failure(
        self, mock_get_user_id, mock_event_logger
    ):
        """A tool exception run through the real chain from
        build_middleware_list() is logged with success=False."""
        from functools import partial

        from fastmcp.tools.tool import ToolResult

        from superset.mcp_service.server import build_middleware_list

        async def failing_tool(context: Any) -> Any:
            raise ValueError("chart not found")

        # Wrap from the innermost layer outwards, the same way FastMCP does,
        # so the first list entry ends up as the outermost handler.
        handler = failing_tool
        for layer in reversed(build_middleware_list()):
            handler = partial(layer, call_next=handler)

        outcome = await handler(_make_context(name="get_chart_info"))

        # The outermost StructuredContentStripper has to turn the re-raised
        # exception into a safe ToolResult whose text starts with "Error:";
        # otherwise the exception would leak to the MCP SDK.
        assert isinstance(outcome, ToolResult)
        assert outcome.content[0].text.startswith("Error:")

        # LoggingMiddleware must report success=False. With
        # StructuredContentStripper innermost it would report success=True,
        # because the exception is swallowed before logging sees it.
        tool_call_logs = [
            entry
            for entry in mock_event_logger.log.call_args_list
            if entry.kwargs.get("action") == "mcp_tool_call"
        ]
        assert len(tool_call_logs) == 1
        logged_payload = tool_call_logs[0].kwargs["curated_payload"]
        assert logged_payload["success"] is False, (
            "Middleware order is wrong: StructuredContentStripper is "
            "swallowing exceptions before LoggingMiddleware can detect "
            "them. Ensure StructuredContentStripper is outermost "
            "(first added) in build_middleware_list()."
        )
|