# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for LoggingMiddleware on_call_tool() and on_message() methods.

Tests verify that:
- on_call_tool() captures duration_ms and success status
- on_call_tool() adds an mcp_call_id to the logged curated_payload and to
  the returned ToolResult.meta
- on_message() logs non-tool messages without duration
- _extract_context_info() extracts entity IDs from params
- _is_error_response() detects serialized error schemas in ToolResult content
- the middleware order produced by build_middleware_list() still logs tool
  exceptions as success=False
"""

from functools import partial
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastmcp.exceptions import ToolError
from fastmcp.tools.tool import ToolResult
from mcp import types as mt

from superset.mcp_service.middleware import LoggingMiddleware


def _make_context(
    method: str = "tools/call",
    name: str = "list_charts",
    params: dict[str, Any] | None = None,
    metadata: dict[str, Any] | None = None,
) -> MagicMock:
    """Create a mock MiddlewareContext.

    :param method: Value exposed as ``ctx.method``.
    :param name: Value exposed as ``ctx.message.name``.
    :param params: Value exposed as ``ctx.message.params``; a fresh empty
        dict is used when omitted or falsy.
    :param metadata: Value exposed as ``ctx.metadata`` (``None`` if omitted).
    :returns: A ``MagicMock`` standing in for the middleware context, with
        ``ctx.session`` set to ``None``.
    """
    ctx = MagicMock()
    ctx.method = method
    message = MagicMock()
    message.name = name
    message.params = params or {}
    ctx.message = message
    # Set explicitly: an unset attribute on a MagicMock would be an
    # auto-created (truthy) child mock rather than None.
    ctx.metadata = metadata
    ctx.session = None
    return ctx


class TestLoggingMiddlewareOnCallTool:
    """Tests for LoggingMiddleware.on_call_tool()."""

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_duration_and_success(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool records duration_ms and success=True on normal return."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="list_charts")
        call_next = AsyncMock(return_value="tool_result")

        result = await middleware.on_call_tool(ctx, call_next)

        assert result == "tool_result"
        call_next.assert_awaited_once_with(ctx)

        # Verify event_logger.log was called with duration_ms and success
        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["action"] == "mcp_tool_call"
        assert call_kwargs["user_id"] == 42
        assert isinstance(call_kwargs["duration_ms"], int)
        assert call_kwargs["duration_ms"] >= 0
        assert call_kwargs["curated_payload"]["success"] is True
        assert call_kwargs["curated_payload"]["tool"] == "list_charts"

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_on_exception(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool records success=False when tool raises."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="execute_sql")
        call_next = AsyncMock(side_effect=ValueError("boom"))

        with pytest.raises(ValueError, match="boom"):
            await middleware.on_call_tool(ctx, call_next)

        # Verify event_logger.log was still called (in the finally block)
        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["curated_payload"]["success"] is False
        assert call_kwargs["duration_ms"] >= 0

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_on_tool_error(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool records success=False when GlobalErrorHandler raises
        ToolError.

        This simulates the real middleware chain: GlobalErrorHandler catches
        tool exceptions and re-raises them as ToolError. Since
        LoggingMiddleware sits between GlobalErrorHandler and
        StructuredContentStripper, it catches the ToolError directly.
        """
        middleware = LoggingMiddleware()
        ctx = _make_context(name="get_chart_info")
        call_next = AsyncMock(side_effect=ToolError("Chart 999999 not found"))

        with pytest.raises(ToolError, match="Chart 999999 not found"):
            await middleware.on_call_tool(ctx, call_next)

        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["curated_payload"]["success"] is False
        assert call_kwargs["curated_payload"]["tool"] == "get_chart_info"
        assert call_kwargs["duration_ms"] >= 0

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_includes_mcp_call_id_in_curated_payload(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool adds mcp_call_id to curated_payload."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="list_charts")
        call_next = AsyncMock(return_value="tool_result")

        await middleware.on_call_tool(ctx, call_next)

        # Assert the call happened first so a missing log call fails with a
        # clear message instead of a TypeError on call_args being None.
        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        mcp_call_id = call_kwargs["curated_payload"]["mcp_call_id"]
        assert isinstance(mcp_call_id, str)
        assert len(mcp_call_id) == 32

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_injects_mcp_call_id_into_tool_result_meta(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool injects mcp_call_id into ToolResult.meta."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="list_charts")
        original_result = ToolResult(content=[mt.TextContent(type="text", text="ok")])
        call_next = AsyncMock(return_value=original_result)

        result = await middleware.on_call_tool(ctx, call_next)

        assert isinstance(result, ToolResult)
        # Guard before membership checks so a missing meta is reported as an
        # assertion failure rather than a TypeError.
        assert result.meta is not None
        assert "mcp_call_id" in result.meta
        assert len(result.meta["mcp_call_id"]) == 32

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_preserves_existing_meta(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool merges mcp_call_id with existing ToolResult.meta."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="list_charts")
        original_result = ToolResult(
            content=[mt.TextContent(type="text", text="ok")],
            meta={"existing_key": "existing_value"},
        )
        call_next = AsyncMock(return_value=original_result)

        result = await middleware.on_call_tool(ctx, call_next)

        assert result.meta is not None
        assert result.meta["existing_key"] == "existing_value"
        assert "mcp_call_id" in result.meta

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_extracts_entity_ids(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool extracts dashboard_id, chart_id, dataset_id from params."""
        middleware = LoggingMiddleware()
        ctx = _make_context(
            name="get_chart_info",
            params={
                "dashboard_id": 10,
                "chart_id": 20,
                "dataset_id": 30,
            },
        )
        call_next = AsyncMock(return_value="ok")

        await middleware.on_call_tool(ctx, call_next)

        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["dashboard_id"] == 10
        # chart_id is logged under the slice_id keyword.
        assert call_kwargs["slice_id"] == 20
        assert call_kwargs["curated_payload"]["dataset_id"] == 30


class TestLoggingMiddlewareOnMessage:
    """Tests for LoggingMiddleware.on_message()."""

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    @pytest.mark.asyncio
    async def test_on_message_logs_without_duration(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_message logs with action=mcp_message and duration_ms=None."""
        middleware = LoggingMiddleware()
        ctx = _make_context(method="resources/read", name="instance/metadata")
        call_next = AsyncMock(return_value="resource_data")

        result = await middleware.on_message(ctx, call_next)

        assert result == "resource_data"
        call_next.assert_awaited_once_with(ctx)

        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["action"] == "mcp_message"
        assert call_kwargs["duration_ms"] is None
        # on_message should NOT have success field
        assert "success" not in call_kwargs["curated_payload"]


class TestExtractContextInfo:
    """Tests for LoggingMiddleware._extract_context_info()."""

    @patch("superset.mcp_service.middleware.get_user_id", return_value=99)
    def test_extract_with_metadata_agent_id(self, mock_get_user_id):
        """Extracts agent_id from context.metadata."""
        middleware = LoggingMiddleware()
        ctx = _make_context(metadata={"agent_id": "agent-123"})

        # Tuple order: agent_id, user_id, dashboard_id, slice_id,
        # dataset_id, params.
        agent_id, user_id, _, _, _, _ = middleware._extract_context_info(ctx)

        assert agent_id == "agent-123"
        assert user_id == 99

    @patch(
        "superset.mcp_service.middleware.get_user_id",
        side_effect=RuntimeError("no Flask request context"),
    )
    def test_extract_handles_missing_user(self, mock_get_user_id):
        """Gracefully handles missing user context."""
        middleware = LoggingMiddleware()
        ctx = _make_context()

        _, user_id, _, _, _, _ = middleware._extract_context_info(ctx)

        assert user_id is None

    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    def test_extract_slice_id_from_chart_id(self, mock_get_user_id):
        """Extracts slice_id from chart_id param (alias)."""
        middleware = LoggingMiddleware()
        ctx = _make_context(params={"chart_id": 55})

        _, _, _, slice_id, _, _ = middleware._extract_context_info(ctx)

        assert slice_id == 55

    @patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    def test_extract_slice_id_from_slice_id(self, mock_get_user_id):
        """Extracts slice_id from slice_id param (fallback)."""
        middleware = LoggingMiddleware()
        ctx = _make_context(params={"slice_id": 66})

        _, _, _, slice_id, _, _ = middleware._extract_context_info(ctx)

        assert slice_id == 66


class TestIsErrorResponse:
    """Tests for LoggingMiddleware._is_error_response()."""

    def test_detects_error_schema_response(self):
        """Detects ToolResult containing a serialized error schema
        (ChartError, DashboardError, etc.) via "error_type" field."""
        middleware = LoggingMiddleware()
        error_json = (
            '{"error": "Chart 999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        result = ToolResult(content=[mt.TextContent(type="text", text=error_json)])
        assert middleware._is_error_response(result) is True

    def test_success_response_not_detected_as_error(self):
        """Normal ToolResult is not detected as error."""
        middleware = LoggingMiddleware()
        result = ToolResult(
            content=[mt.TextContent(type="text", text="Successfully retrieved data")]
        )
        assert middleware._is_error_response(result) is False

    def test_empty_content_not_detected_as_error(self):
        """ToolResult with empty content is not detected as error."""
        middleware = LoggingMiddleware()
        assert middleware._is_error_response(ToolResult(content=[])) is False

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_on_call_tool_logs_failure_for_error_schema(
        self, mock_get_user_id, mock_event_logger
    ):
        """on_call_tool logs success=False when tool returns an error schema
        (e.g. ChartError)."""
        middleware = LoggingMiddleware()
        ctx = _make_context(name="get_chart_info")
        error_json = (
            '{"error": "Chart 999999 not found",'
            ' "error_type": "not_found",'
            ' "timestamp": "2026-04-09T00:00:00Z"}'
        )
        error_result = ToolResult(
            content=[mt.TextContent(type="text", text=error_json)]
        )
        call_next = AsyncMock(return_value=error_result)

        result = await middleware.on_call_tool(ctx, call_next)

        # The error result is passed through unchanged apart from meta.
        assert isinstance(result, ToolResult)
        assert result.content == error_result.content
        assert result.meta is not None
        assert "mcp_call_id" in result.meta

        mock_event_logger.log.assert_called_once()
        call_kwargs = mock_event_logger.log.call_args.kwargs
        assert call_kwargs["curated_payload"]["success"] is False
        assert call_kwargs["curated_payload"]["tool"] == "get_chart_info"


class TestMiddlewareChainOrder:
    """Test that the middleware order from server.py logs failures correctly.

    If the order is wrong (StructuredContentStripper innermost), it
    swallows exceptions before LoggingMiddleware can see them, causing
    success=True for failures.
    """

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_real_middleware_chain_logs_exception_as_failure(
        self, mock_get_user_id, mock_event_logger
    ):
        """Tool exception is logged as success=False through the real
        middleware chain from build_middleware_list()."""
        from superset.mcp_service.server import build_middleware_list

        middleware_list = build_middleware_list()

        async def failing_tool(context: Any) -> Any:
            raise ValueError("chart not found")

        # Build chain same way FastMCP does: the first middleware in the list
        # ends up outermost.
        chain = failing_tool
        for mw in reversed(middleware_list):
            chain = partial(mw, call_next=chain)

        ctx = _make_context(name="get_chart_info")
        result = await chain(ctx)

        # StructuredContentStripper (outermost) must catch the re-raised
        # exception and convert it to a safe ToolResult with "Error:" text.
        # If it's not outermost, the exception would leak to the MCP SDK.
        assert isinstance(result, ToolResult)
        assert result.content[0].text.startswith("Error:")

        # LoggingMiddleware must log success=False. If the middleware order
        # is wrong (StructuredContentStripper innermost), this would be
        # success=True because the exception gets swallowed before
        # LoggingMiddleware sees it.
        log_calls = [
            c
            for c in mock_event_logger.log.call_args_list
            if c.kwargs.get("action") == "mcp_tool_call"
        ]
        assert len(log_calls) == 1
        assert log_calls[0].kwargs["curated_payload"]["success"] is False, (
            "Middleware order is wrong: StructuredContentStripper is "
            "swallowing exceptions before LoggingMiddleware can detect "
            "them. Ensure StructuredContentStripper is outermost "
            "(first added) in build_middleware_list()."
        )

    @patch("superset.mcp_service.middleware.event_logger")
    @patch("superset.mcp_service.middleware.get_user_id", return_value=42)
    @pytest.mark.asyncio
    async def test_real_middleware_chain_error_result_has_mcp_call_id(
        self, mock_get_user_id, mock_event_logger
    ):
        """When a tool raises, the error ToolResult from
        StructuredContentStripper still carries mcp_call_id in meta."""
        from superset.mcp_service.server import build_middleware_list

        middleware_list = build_middleware_list()

        async def failing_tool(context: Any) -> Any:
            raise ValueError("chart not found")

        chain = failing_tool
        for mw in reversed(middleware_list):
            chain = partial(mw, call_next=chain)

        ctx = _make_context(name="get_chart_info")
        result = await chain(ctx)

        assert isinstance(result, ToolResult)
        assert result.content[0].text.startswith("Error:")
        assert result.meta is not None
        assert "mcp_call_id" in result.meta
        assert len(result.meta["mcp_call_id"]) == 32