# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unit tests for MCP service middleware.
"""

from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from sqlalchemy.exc import OperationalError

from superset.commands.exceptions import (
    CommandInvalidError,
    ForbiddenError,
    ObjectNotFoundError,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.mcp_service.mcp_config import MCP_RESPONSE_SIZE_CONFIG
from superset.mcp_service.middleware import (
    _is_user_error,
    create_response_size_guard_middleware,
    GlobalErrorHandlerMiddleware,
    ResponseSizeGuardMiddleware,
)

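# Note: a minimal sketch of how these middlewares would typically be attached to
# a FastMCP server, kept here as orientation for the tests below. It assumes
# FastMCP's add_middleware() API and is illustrative only; the actual wiring in
# Superset's MCP service may differ:
#
#     from fastmcp import FastMCP
#
#     mcp = FastMCP("superset")
#     mcp.add_middleware(GlobalErrorHandlerMiddleware())
#     if (size_guard := create_response_size_guard_middleware()) is not None:
#         mcp.add_middleware(size_guard)

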
class TestResponseSizeGuardMiddleware:
    """Test ResponseSizeGuardMiddleware class."""

    def test_init_default_values(self) -> None:
        """Should initialize with default values."""
        middleware = ResponseSizeGuardMiddleware()
        assert middleware.token_limit == 25_000
        assert middleware.warn_threshold_pct == 80
        assert middleware.warn_threshold == 20000
        assert middleware.excluded_tools == set()

    def test_init_custom_values(self) -> None:
        """Should initialize with custom values."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=10000,
            warn_threshold_pct=70,
            excluded_tools=["health_check", "get_chart_preview"],
        )
        assert middleware.token_limit == 10000
        assert middleware.warn_threshold_pct == 70
        assert middleware.warn_threshold == 7000
        assert middleware.excluded_tools == {"health_check", "get_chart_preview"}

    def test_init_excluded_tools_as_string(self) -> None:
        """Should handle excluded_tools as a single string."""
        middleware = ResponseSizeGuardMiddleware(
            excluded_tools="health_check",
        )
        assert middleware.excluded_tools == {"health_check"}

    @pytest.mark.asyncio
    async def test_allows_small_response(self) -> None:
        """Should allow responses under the token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        # Create mock call_next that returns a small response
        small_response = {"charts": [{"id": 1, "name": "test"}]}
        call_next = AsyncMock(return_value=small_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == small_response
        call_next.assert_called_once_with(context)

    @pytest.mark.asyncio
    async def test_blocks_large_response(self) -> None:
        """Should block responses over the token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)  # Very low limit

        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}

        # Create large response
        large_response = {
            "charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)]
        }
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)

        # Verify error contains helpful information
        error_message = str(exc_info.value)
        assert "Response too large" in error_message
        assert "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_skips_excluded_tools(self) -> None:
        """Should skip checking for excluded tools."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=100, excluded_tools=["health_check"]
        )

        # Create mock context for excluded tool
        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}

        # Create response that would exceed the limit
        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        # Should not raise even though the response exceeds the limit
        result = await middleware.on_call_tool(context, call_next)
        assert result == large_response

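    # Note on the size math in the warning-threshold test below: the guard
    # estimates tokens from the serialized payload length (the inline comment
    # assumes roughly 3.5 characters per token). With token_limit=1000 and
    # warn_threshold_pct=80 the warn threshold is 800 tokens, so a ~2900
    # character payload (~828 estimated tokens) sits between the warning
    # threshold and the hard limit: it should be logged but not blocked.
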
    @pytest.mark.asyncio
    async def test_logs_warning_at_threshold(self) -> None:
        """Should log a warning when approaching the limit."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=1000, warn_threshold_pct=80
        )

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        # Response at ~83% of limit (should trigger warning but not block)
        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
        ):
            result = await middleware.on_call_tool(context, call_next)

        # Should return response (not blocked)
        assert result == response
        # Should log warning
        mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_error_includes_suggestions(self) -> None:
        """Should include suggestions in the error message."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}

        large_response = {"charts": [{"id": i} for i in range(1000)]}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)

        error_message = str(exc_info.value)
        # Should have numbered suggestions
        assert "1." in error_message
        # Should suggest reducing page_size
        assert "page_size" in error_message.lower() or "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_logs_size_exceeded_event(self) -> None:
        """Should log to the event logger when the size is exceeded."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)

        # Should log to event logger
        mock_event_logger.log.assert_called()
        call_args = mock_event_logger.log.call_args
        assert call_args.kwargs["action"] == "mcp_response_size_exceeded"

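    # Note: the tests below exercise the truncation path for "info" tools
    # (get_dataset_info, get_chart_info, get_dashboard_info). Judging by the
    # assertions, oversized single-object responses from these tools are
    # truncated in place (large fields shortened with a "[truncated" marker and
    # a "_response_truncated": True flag added) rather than rejected outright,
    # while list-style tools such as list_charts are still blocked.
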
    @pytest.mark.asyncio
    async def test_truncates_info_tool_instead_of_blocking(self) -> None:
        """Should truncate info tool responses instead of blocking them."""
        middleware = ResponseSizeGuardMiddleware(token_limit=500)

        context = MagicMock()
        context.message.name = "get_dataset_info"
        context.message.params = {}

        # Large info tool response with a big description
        large_response = {
            "id": 1,
            "table_name": "test",
            "description": "x" * 50000,
        }
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        # Should return a truncated response, not raise ToolError
        assert isinstance(result, dict)
        assert result["id"] == 1
        assert result["_response_truncated"] is True
        assert "[truncated" in result["description"]

    @pytest.mark.asyncio
    async def test_truncates_chart_info_with_large_form_data(self) -> None:
        """Should truncate get_chart_info with large form_data."""
        middleware = ResponseSizeGuardMiddleware(token_limit=500)

        context = MagicMock()
        context.message.name = "get_chart_info"
        context.message.params = {}

        large_response = {
            "id": 1,
            "slice_name": "My Chart",
            "form_data": {f"key_{i}": f"value_{i}" for i in range(100)},
        }
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert isinstance(result, dict)
        assert result["id"] == 1
        assert result["_response_truncated"] is True

    @pytest.mark.asyncio
    async def test_still_blocks_non_info_tools(self) -> None:
        """Should still block non-info tools that exceed the limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"  # Not an info tool
        context.message.params = {}

        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)

    @pytest.mark.asyncio
    async def test_logs_truncation_event(self) -> None:
        """Should log mcp_response_truncated event on successful truncation."""
        middleware = ResponseSizeGuardMiddleware(token_limit=500)

        context = MagicMock()
        context.message.name = "get_dashboard_info"
        context.message.params = {}

        large_response = {
            "id": 1,
            "description": "x" * 50000,
        }
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
        ):
            await middleware.on_call_tool(context, call_next)

        # Should log truncation event (not size_exceeded)
        mock_event_logger.log.assert_called()
        call_args = mock_event_logger.log.call_args
        assert call_args.kwargs["action"] == "mcp_response_truncated"


class TestCreateResponseSizeGuardMiddleware:
    """Test create_response_size_guard_middleware factory function."""

    def test_default_config_checks_chart_preview(self) -> None:
        """Should size-check chart preview responses by default."""
        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = MCP_RESPONSE_SIZE_CONFIG

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert "get_chart_preview" not in middleware.excluded_tools
        assert "health_check" in middleware.excluded_tools

    def test_creates_middleware_when_enabled(self) -> None:
        """Should create middleware when enabled in config."""
        mock_config = {
            "enabled": True,
            "token_limit": 30000,
            "warn_threshold_pct": 75,
            "excluded_tools": ["health_check"],
        }

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert isinstance(middleware, ResponseSizeGuardMiddleware)
        assert middleware.token_limit == 30000
        assert middleware.warn_threshold_pct == 75
        assert "health_check" in middleware.excluded_tools

    def test_returns_none_when_disabled(self) -> None:
        """Should return None when disabled in config."""
        mock_config = {"enabled": False}

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None

    def test_uses_defaults_when_config_missing(self) -> None:
        """Should use defaults when config values are missing."""
        mock_config = {"enabled": True}  # Only enabled, no other values

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert middleware.token_limit == 25_000  # Default
        assert middleware.warn_threshold_pct == 80  # Default

    def test_handles_exception_gracefully(self) -> None:
        """Should return None on expected configuration exceptions."""
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            side_effect=ImportError("Config error"),
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None


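# Note: a sketch of the configuration shape the factory tests above imply. The
# key names come from the mocked config dicts in those tests; treat this as an
# illustration of the expected structure rather than documented Superset
# configuration:
#
#     MCP_RESPONSE_SIZE_CONFIG = {
#         "enabled": True,
#         "token_limit": 25_000,
#         "warn_threshold_pct": 80,
#         "excluded_tools": ["health_check"],
#     }

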
class TestExtractPayloadFromToolResult:
    """Tests for _extract_payload_from_tool_result static method."""

    def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent

        return ToolResult(
            content=[TextContent(type="text", text=text)],
            meta=meta,
        )

    def test_extracts_dict_payload(self) -> None:
        """Should parse the JSON text inside content[0] and return the dict."""
        from superset.utils import json

        payload = {"id": 1, "name": "test", "charts": [1, 2, 3]}
        tool_result = self._make_tool_result(json.dumps(payload))

        result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
            tool_result
        )

        assert result is not None
        assert result == payload

    def test_returns_none_for_plain_dict(self) -> None:
        """Should return None when given a plain dict (not a ToolResult)."""
        assert (
            ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
                {"key": "val"}
            )
            is None
        )

    def test_returns_none_for_string(self) -> None:
        """Should return None when given a plain string."""
        assert (
            ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string")
            is None
        )

    def test_returns_none_for_list(self) -> None:
        """Should return None when given a plain list."""
        assert (
            ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3])
            is None
        )

    def test_returns_none_for_none(self) -> None:
        """Should return None when given None."""
        assert (
            ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None
        )

    def test_returns_none_when_payload_is_list(self) -> None:
        """Should return None when the JSON payload is a list, not a dict."""
        from superset.utils import json

        tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))

        result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
            tool_result
        )

        assert result is None

    def test_returns_none_for_invalid_json(self) -> None:
        """Should return None when content text is not valid JSON."""
        tool_result = self._make_tool_result("not valid {{{json")

        result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
            tool_result
        )

        assert result is None

    def test_returns_none_for_empty_text(self) -> None:
        """Should return None when content[0].text is empty."""
        tool_result = self._make_tool_result("")

        result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
            tool_result
        )

        assert result is None

    def test_returns_none_for_empty_content(self) -> None:
        """Should return None when ToolResult has no content items."""
        from fastmcp.tools.tool import ToolResult

        tool_result = ToolResult(content=[], meta=None)

        result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
            tool_result
        )

        assert result is None


class TestRewrapAsToolResult:
    """Tests for _rewrap_as_tool_result static method."""

    def _make_tool_result(
        self, payload: dict[str, Any], meta: dict[str, Any] | None = None
    ) -> Any:
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent

        from superset.utils import json

        return ToolResult(
            content=[TextContent(type="text", text=json.dumps(payload))],
            meta=meta,
        )

    def test_returns_tool_result_with_serialized_payload(self) -> None:
        """Should return a ToolResult whose content[0].text is the JSON payload."""
        from fastmcp.tools.tool import ToolResult

        from superset.utils import json

        original = self._make_tool_result({"old": "data"})
        new_payload = {"id": 1, "name": "truncated", "_response_truncated": True}

        result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
            new_payload, original
        )

        assert isinstance(result, ToolResult)
        assert result.content[0].type == "text"
        reparsed = json.loads(result.content[0].text)
        assert reparsed == new_payload

    def test_preserves_meta_from_original_tool_result(self) -> None:
        """Should copy meta from the original ToolResult."""
        from fastmcp.tools.tool import ToolResult

        meta = {"request_id": "abc-123", "trace": "xyz"}
        original = self._make_tool_result({"key": "val"}, meta=meta)

        result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
            {"key": "val"}, original
        )

        assert isinstance(result, ToolResult)
        assert result.meta == meta

    def test_sets_meta_none_for_non_tool_result_original(self) -> None:
        """Should set meta=None when the original is not a ToolResult."""
        from fastmcp.tools.tool import ToolResult

        result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
            {"key": "val"}, {"not": "a ToolResult"}
        )

        assert isinstance(result, ToolResult)
        assert result.meta is None


class TestToolResultWrapping:
    """Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool."""

    def _make_tool_result(
        self, payload: dict[str, Any], meta: dict[str, Any] | None = None
    ) -> Any:
        from fastmcp.tools.tool import ToolResult
        from mcp.types import TextContent

        from superset.utils import json

        return ToolResult(
            content=[TextContent(type="text", text=json.dumps(payload))],
            meta=meta,
        )

    @pytest.mark.asyncio
    async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
        """Truncate a ToolResult-wrapped info response and return a ToolResult."""
        from fastmcp.tools.tool import ToolResult

        from superset.utils import json

        middleware = ResponseSizeGuardMiddleware(token_limit=500)
        context = MagicMock()
        context.message.name = "get_dataset_info"
        context.message.params = {}

        large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000}
        tool_result = self._make_tool_result(large_payload)
        call_next = AsyncMock(return_value=tool_result)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        # Must return a ToolResult, not a raw dict
        assert isinstance(result, ToolResult)
        reparsed = json.loads(result.content[0].text)
        assert reparsed["id"] == 1
        assert reparsed["_response_truncated"] is True
        assert "[truncated" in reparsed["description"]

    @pytest.mark.asyncio
    async def test_small_tool_result_passes_through_unchanged(self) -> None:
        """Should return the original ToolResult when within the token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)
        context = MagicMock()
        context.message.name = "get_chart_info"
        context.message.params = {}

        small_payload = {"id": 1, "name": "My Chart"}
        tool_result = self._make_tool_result(small_payload)
        call_next = AsyncMock(return_value=tool_result)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result is tool_result

    @pytest.mark.asyncio
    async def test_large_non_info_tool_result_is_blocked(self) -> None:
        """Should raise ToolError for a non-info ToolResult that exceeds the limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        large_payload = {
            "charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
        }
        tool_result = self._make_tool_result(large_payload)
        call_next = AsyncMock(return_value=tool_result)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)

    @pytest.mark.asyncio
    async def test_meta_preserved_after_truncation(self) -> None:
        """Should preserve the original ToolResult meta through truncation."""
        from fastmcp.tools.tool import ToolResult

        from superset.utils import json

        middleware = ResponseSizeGuardMiddleware(token_limit=500)
        context = MagicMock()
        context.message.name = "get_dashboard_info"
        context.message.params = {}

        meta = {"request_id": "abc-123"}
        large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
        tool_result = self._make_tool_result(large_payload, meta=meta)
        call_next = AsyncMock(return_value=tool_result)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert isinstance(result, ToolResult)
        assert result.meta == meta
        reparsed = json.loads(result.content[0].text)
        assert reparsed["_response_truncated"] is True


class TestMiddlewareIntegration:
    """Integration tests for middleware behavior."""

    @pytest.mark.asyncio
    async def test_pydantic_model_response(self) -> None:
        """Should handle Pydantic model responses."""
        from pydantic import BaseModel

        class ChartInfo(BaseModel):
            id: int
            name: str

        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "get_chart_info"
        context.message.params = {}

        response = ChartInfo(id=1, name="Test Chart")
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_list_response(self) -> None:
        """Should handle list responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        response = [{"id": 1}, {"id": 2}, {"id": 3}]
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_string_response(self) -> None:
        """Should handle string responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}

        response = "OK"
        call_next = AsyncMock(return_value=response)

        result = await middleware.on_call_tool(context, call_next)
        assert result == response


def _make_security_exception(msg: str = "access denied") -> SupersetSecurityException:
    """Helper to construct SupersetSecurityException with a proper SupersetError."""
    return SupersetSecurityException(
        SupersetError(
            message=msg,
            error_type=SupersetErrorType.GENERIC_BACKEND_ERROR,
            level=ErrorLevel.ERROR,
        )
    )


class TestIsUserError:
    """Test _is_user_error classification helper."""

    @pytest.mark.parametrize(
        ("error", "expected"),
        [
            # User errors (WARNING) — expected in normal MCP operation
            (ToolError("bad request"), True),
            (PermissionError("access denied"), True),
            (ObjectNotFoundError("Chart", "123"), True),
            (ForbiddenError(), True),
            (_make_security_exception(), True),
            (ValueError("invalid param"), True),
            (FileNotFoundError("not found"), True),
            # System errors (ERROR) — unexpected failures
            (RuntimeError("unexpected"), False),
            (ConnectionError("connection refused"), False),
            (TypeError("type mismatch"), False),
            (KeyError("missing key"), False),
            (Exception("generic"), False),
        ],
        ids=[
            "ToolError",
            "PermissionError",
            "ObjectNotFoundError",
            "ForbiddenError",
            "SupersetSecurityException",
            "ValueError",  # user error — bad param from LLM
            "FileNotFoundError",
            "RuntimeError",
            "ConnectionError",
            "TypeError",
            "KeyError",
            "Exception",
        ],
    )
    def test_error_classification(self, error: Exception, expected: bool) -> None:
        """Test that _is_user_error correctly classifies error types."""
        assert _is_user_error(error) == expected

    def test_validation_error(self) -> None:
        """Test ValidationError is classified as a user error."""
        from pydantic import BaseModel

        class TestModel(BaseModel):
            """Test model for validation error testing."""

            name: str

        with pytest.raises(ValidationError) as exc_info:
            TestModel.model_validate({})
        assert _is_user_error(exc_info.value) is True

    def test_command_invalid_error(self) -> None:
        """Test CommandInvalidError is classified as a user error."""
        error = CommandInvalidError()
        assert _is_user_error(error) is True
        assert error.status == 422

    def test_operational_error(self) -> None:
        """Test OperationalError is classified as a system error."""
        error = OperationalError("db error", {}, Exception())
        assert _is_user_error(error) is False

    def test_superset_exception_status_based(self) -> None:
        """Test SupersetException classification is based on .status attribute."""
        # 4xx status → user error
        error_400 = SupersetException("bad request")
        error_400.status = 400
        assert _is_user_error(error_400) is True

        error_408 = SupersetException("timeout")
        error_408.status = 408
        assert _is_user_error(error_408) is True

        error_422 = SupersetException("unprocessable")
        error_422.status = 422
        assert _is_user_error(error_422) is True

        # 5xx status → system error
        error_500 = SupersetException("internal error")
        error_500.status = 500
        assert _is_user_error(error_500) is False

        error_503 = SupersetException("unavailable")
        error_503.status = 503
        assert _is_user_error(error_503) is False


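# Note: the match= patterns in the tests below summarize how the error handler
# appears to map failures to client-facing ToolError messages (inferred from the
# assertions, not from separate documentation): PermissionError → "Permission
# denied", ConnectionError → "Connection error", 4xx SupersetException →
# "Invalid request", and unexpected or 5xx failures → "Internal error".

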
class TestGlobalErrorHandlerLogLevels:
    """Test that GlobalErrorHandlerMiddleware logs at correct levels."""

    @pytest.mark.asyncio
    async def test_user_error_logs_warning(self) -> None:
        """User errors (e.g. ValueError) should log at WARNING."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=ValueError("invalid page"))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_message(context, call_next)

        # Should log at WARNING, not ERROR
        mock_logger.warning.assert_called()
        mock_logger.error.assert_not_called()

    @pytest.mark.asyncio
    async def test_system_error_logs_error(self) -> None:
        """System errors (OperationalError, generic Exception) should log at ERROR."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "execute_sql"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=OperationalError("db error", {}, Exception()))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_message(context, call_next)

        # Should log at ERROR
        mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_unexpected_error_logs_error(self) -> None:
        """Truly unexpected errors should log at ERROR with error_id."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=RuntimeError("something broke"))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError, match="Internal error"),
        ):
            await middleware.on_message(context, call_next)

        # Should log at ERROR (both the classification log and the error_id log)
        assert mock_logger.error.call_count >= 1

    @pytest.mark.asyncio
    async def test_event_logger_includes_severity(self) -> None:
        """Event logger payload should include a severity field."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=ValueError("bad param"))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
            patch("superset.mcp_service.middleware.logger"),
            pytest.raises(ToolError),
        ):
            await middleware.on_message(context, call_next)

        mock_event_logger.log.assert_called_once()
        payload = mock_event_logger.log.call_args.kwargs["curated_payload"]
        assert payload["severity"] == "warning"

    @pytest.mark.asyncio
    async def test_permission_error_logs_warning(self) -> None:
        """PermissionError should log at WARNING — agents are expected to
        try tools they lack access to."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "generate_chart"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=PermissionError("not allowed"))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError, match="Permission denied"),
        ):
            await middleware.on_message(context, call_next)

        mock_logger.warning.assert_called()
        mock_logger.error.assert_not_called()

    @pytest.mark.asyncio
    async def test_connection_error_logs_error(self) -> None:
        """ConnectionError should log at ERROR — infrastructure issue."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        call_next = AsyncMock(side_effect=ConnectionError("connection refused"))

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError, match="Connection error"),
        ):
            await middleware.on_message(context, call_next)

        mock_logger.error.assert_called()

    @pytest.mark.asyncio
    async def test_superset_exception_4xx_logs_warning(self) -> None:
        """SupersetException with 4xx status should log at WARNING."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        error = SupersetException("bad request")
        error.status = 400
        call_next = AsyncMock(side_effect=error)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError, match="Invalid request"),
        ):
            await middleware.on_message(context, call_next)

        mock_logger.warning.assert_called()
        mock_logger.error.assert_not_called()

    @pytest.mark.asyncio
    async def test_superset_exception_5xx_logs_error(self) -> None:
        """SupersetException with 5xx status should log at ERROR."""
        middleware = GlobalErrorHandlerMiddleware()

        context = MagicMock()
        context.message.name = "list_charts"
        context.method = "tools/call"

        error = SupersetException("internal failure")
        error.status = 500
        call_next = AsyncMock(side_effect=error)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
            pytest.raises(ToolError, match="Internal error"),
        ):
            await middleware.on_message(context, call_next)

        mock_logger.error.assert_called()