Files
superset2/tests/unit_tests/mcp_service/test_middleware.py
2026-04-28 13:30:39 -03:00

1033 lines
38 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service middleware.
"""
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from sqlalchemy.exc import OperationalError
from superset.commands.exceptions import (
CommandInvalidError,
ForbiddenError,
ObjectNotFoundError,
)
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetException, SupersetSecurityException
from superset.mcp_service.mcp_config import MCP_RESPONSE_SIZE_CONFIG
from superset.mcp_service.middleware import (
_is_user_error,
create_response_size_guard_middleware,
GlobalErrorHandlerMiddleware,
ResponseSizeGuardMiddleware,
)
class TestResponseSizeGuardMiddleware:
"""Test ResponseSizeGuardMiddleware class."""
def test_init_default_values(self) -> None:
"""Should initialize with default values."""
middleware = ResponseSizeGuardMiddleware()
assert middleware.token_limit == 25_000
assert middleware.warn_threshold_pct == 80
assert middleware.warn_threshold == 20000
assert middleware.excluded_tools == set()
def test_init_custom_values(self) -> None:
"""Should initialize with custom values."""
middleware = ResponseSizeGuardMiddleware(
token_limit=10000,
warn_threshold_pct=70,
excluded_tools=["health_check", "get_chart_preview"],
)
assert middleware.token_limit == 10000
assert middleware.warn_threshold_pct == 70
assert middleware.warn_threshold == 7000
assert middleware.excluded_tools == {"health_check", "get_chart_preview"}
def test_init_excluded_tools_as_string(self) -> None:
"""Should handle excluded_tools as a single string."""
middleware = ResponseSizeGuardMiddleware(
excluded_tools="health_check",
)
assert middleware.excluded_tools == {"health_check"}
@pytest.mark.asyncio
async def test_allows_small_response(self) -> None:
"""Should allow responses under token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
# Create mock context
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
# Create mock call_next that returns small response
small_response = {"charts": [{"id": 1, "name": "test"}]}
call_next = AsyncMock(return_value=small_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert result == small_response
call_next.assert_called_once_with(context)
@pytest.mark.asyncio
async def test_blocks_large_response(self) -> None:
"""Should block responses over token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit
# Create mock context
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {"page_size": 100}
# Create large response
large_response = {
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)]
}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
pytest.raises(ToolError) as exc_info,
):
await middleware.on_call_tool(context, call_next)
# Verify error contains helpful information
error_message = str(exc_info.value)
assert "Response too large" in error_message
assert "limit" in error_message.lower()
@pytest.mark.asyncio
async def test_skips_excluded_tools(self) -> None:
"""Should skip checking for excluded tools."""
middleware = ResponseSizeGuardMiddleware(
token_limit=100, excluded_tools=["health_check"]
)
# Create mock context for excluded tool
context = MagicMock()
context.message.name = "health_check"
context.message.params = {}
# Create response that would exceed limit
large_response = {"data": "x" * 10000}
call_next = AsyncMock(return_value=large_response)
# Should not raise even though response exceeds limit
result = await middleware.on_call_tool(context, call_next)
assert result == large_response
@pytest.mark.asyncio
async def test_logs_warning_at_threshold(self) -> None:
"""Should log warning when approaching limit."""
middleware = ResponseSizeGuardMiddleware(
token_limit=1000, warn_threshold_pct=80
)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
# Response at ~85% of limit (should trigger warning but not block)
response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
):
result = await middleware.on_call_tool(context, call_next)
# Should return response (not blocked)
assert result == response
# Should log warning
mock_logger.warning.assert_called()
@pytest.mark.asyncio
async def test_error_includes_suggestions(self) -> None:
"""Should include suggestions in error message."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {"page_size": 100}
large_response = {"charts": [{"id": i} for i in range(1000)]}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
pytest.raises(ToolError) as exc_info,
):
await middleware.on_call_tool(context, call_next)
error_message = str(exc_info.value)
# Should have numbered suggestions
assert "1." in error_message
# Should suggest reducing page_size
assert "page_size" in error_message.lower() or "limit" in error_message.lower()
@pytest.mark.asyncio
async def test_logs_size_exceeded_event(self) -> None:
"""Should log to event logger when size exceeded."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
large_response = {"data": "x" * 10000}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
pytest.raises(ToolError),
):
await middleware.on_call_tool(context, call_next)
# Should log to event logger
mock_event_logger.log.assert_called()
call_args = mock_event_logger.log.call_args
assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
@pytest.mark.asyncio
async def test_truncates_info_tool_instead_of_blocking(self) -> None:
"""Should truncate info tool responses instead of blocking them."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dataset_info"
context.message.params = {}
# Large info tool response with a big description
large_response = {
"id": 1,
"table_name": "test",
"description": "x" * 50000,
}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
# Should return truncated response, not raise ToolError
assert isinstance(result, dict)
assert result["id"] == 1
assert result["_response_truncated"] is True
assert "[truncated" in result["description"]
@pytest.mark.asyncio
async def test_truncates_chart_info_with_large_form_data(self) -> None:
"""Should truncate get_chart_info with large form_data."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_chart_info"
context.message.params = {}
large_response = {
"id": 1,
"slice_name": "My Chart",
"form_data": {f"key_{i}": f"value_{i}" for i in range(100)},
}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert isinstance(result, dict)
assert result["id"] == 1
assert result["_response_truncated"] is True
@pytest.mark.asyncio
async def test_still_blocks_non_info_tools(self) -> None:
"""Should still block non-info tools that exceed limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
context = MagicMock()
context.message.name = "list_charts" # Not an info tool
context.message.params = {}
large_response = {"data": "x" * 10000}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
pytest.raises(ToolError),
):
await middleware.on_call_tool(context, call_next)
@pytest.mark.asyncio
async def test_logs_truncation_event(self) -> None:
"""Should log mcp_response_truncated event on successful truncation."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dashboard_info"
context.message.params = {}
large_response = {
"id": 1,
"description": "x" * 50000,
}
call_next = AsyncMock(return_value=large_response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
):
await middleware.on_call_tool(context, call_next)
# Should log truncation event (not size_exceeded)
mock_event_logger.log.assert_called()
call_args = mock_event_logger.log.call_args
assert call_args.kwargs["action"] == "mcp_response_truncated"
class TestCreateResponseSizeGuardMiddleware:
"""Test create_response_size_guard_middleware factory function."""
def test_default_config_checks_chart_preview(self) -> None:
"""Should size-check chart preview responses by default."""
mock_flask_app = MagicMock()
mock_flask_app.config.get.return_value = MCP_RESPONSE_SIZE_CONFIG
with patch(
"superset.mcp_service.flask_singleton.get_flask_app",
return_value=mock_flask_app,
):
middleware = create_response_size_guard_middleware()
assert middleware is not None
assert "get_chart_preview" not in middleware.excluded_tools
assert "health_check" in middleware.excluded_tools
def test_creates_middleware_when_enabled(self) -> None:
"""Should create middleware when enabled in config."""
mock_config = {
"enabled": True,
"token_limit": 30000,
"warn_threshold_pct": 75,
"excluded_tools": ["health_check"],
}
mock_flask_app = MagicMock()
mock_flask_app.config.get.return_value = mock_config
with patch(
"superset.mcp_service.flask_singleton.get_flask_app",
return_value=mock_flask_app,
):
middleware = create_response_size_guard_middleware()
assert middleware is not None
assert isinstance(middleware, ResponseSizeGuardMiddleware)
assert middleware.token_limit == 30000
assert middleware.warn_threshold_pct == 75
assert "health_check" in middleware.excluded_tools
def test_returns_none_when_disabled(self) -> None:
"""Should return None when disabled in config."""
mock_config = {"enabled": False}
mock_flask_app = MagicMock()
mock_flask_app.config.get.return_value = mock_config
with patch(
"superset.mcp_service.flask_singleton.get_flask_app",
return_value=mock_flask_app,
):
middleware = create_response_size_guard_middleware()
assert middleware is None
def test_uses_defaults_when_config_missing(self) -> None:
"""Should use defaults when config values are missing."""
mock_config = {"enabled": True} # Only enabled, no other values
mock_flask_app = MagicMock()
mock_flask_app.config.get.return_value = mock_config
with patch(
"superset.mcp_service.flask_singleton.get_flask_app",
return_value=mock_flask_app,
):
middleware = create_response_size_guard_middleware()
assert middleware is not None
assert middleware.token_limit == 25_000 # Default
assert middleware.warn_threshold_pct == 80 # Default
def test_handles_exception_gracefully(self) -> None:
"""Should return None on expected configuration exceptions."""
with patch(
"superset.mcp_service.flask_singleton.get_flask_app",
side_effect=ImportError("Config error"),
):
middleware = create_response_size_guard_middleware()
assert middleware is None
class TestExtractPayloadFromToolResult:
"""Tests for _extract_payload_from_tool_result static method."""
def _make_tool_result(self, text: str, meta: dict[str, Any] | None = None) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
return ToolResult(
content=[TextContent(type="text", text=text)],
meta=meta,
)
def test_extracts_dict_payload(self) -> None:
"""Should parse the JSON text inside content[0] and return the dict."""
from superset.utils import json
payload = {"id": 1, "name": "test", "charts": [1, 2, 3]}
tool_result = self._make_tool_result(json.dumps(payload))
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is not None
assert result == payload
def test_returns_none_for_plain_dict(self) -> None:
"""Should return None when given a plain dict (not a ToolResult)."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
{"key": "val"}
)
is None
)
def test_returns_none_for_string(self) -> None:
"""Should return None when given a plain string."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result("string")
is None
)
def test_returns_none_for_list(self) -> None:
"""Should return None when given a plain list."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result([1, 2, 3])
is None
)
def test_returns_none_for_none(self) -> None:
"""Should return None when given None."""
assert (
ResponseSizeGuardMiddleware._extract_payload_from_tool_result(None) is None
)
def test_returns_none_when_payload_is_list(self) -> None:
"""Should return None when JSON payload is a list, not a dict."""
from superset.utils import json
tool_result = self._make_tool_result(json.dumps([{"id": 1}, {"id": 2}]))
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_invalid_json(self) -> None:
"""Should return None when content text is not valid JSON."""
tool_result = self._make_tool_result("not valid {{{json")
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_empty_text(self) -> None:
"""Should return None when content[0].text is empty."""
tool_result = self._make_tool_result("")
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
def test_returns_none_for_empty_content(self) -> None:
"""Should return None when ToolResult has no content items."""
from fastmcp.tools.tool import ToolResult
tool_result = ToolResult(content=[], meta=None)
result = ResponseSizeGuardMiddleware._extract_payload_from_tool_result(
tool_result
)
assert result is None
class TestRewrapAsToolResult:
"""Tests for _rewrap_as_tool_result static method."""
def _make_tool_result(
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
from superset.utils import json
return ToolResult(
content=[TextContent(type="text", text=json.dumps(payload))],
meta=meta,
)
def test_returns_tool_result_with_serialized_payload(self) -> None:
"""Should return a ToolResult whose content[0].text is the JSON payload."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
original = self._make_tool_result({"old": "data"})
new_payload = {"id": 1, "name": "truncated", "_response_truncated": True}
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
new_payload, original
)
assert isinstance(result, ToolResult)
assert result.content[0].type == "text"
reparsed = json.loads(result.content[0].text)
assert reparsed == new_payload
def test_preserves_meta_from_original_tool_result(self) -> None:
"""Should copy meta from the original ToolResult."""
from fastmcp.tools.tool import ToolResult
meta = {"request_id": "abc-123", "trace": "xyz"}
original = self._make_tool_result({"key": "val"}, meta=meta)
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
{"key": "val"}, original
)
assert isinstance(result, ToolResult)
assert result.meta == meta
def test_sets_meta_none_for_non_tool_result_original(self) -> None:
"""Should set meta=None when original is not a ToolResult."""
from fastmcp.tools.tool import ToolResult
result = ResponseSizeGuardMiddleware._rewrap_as_tool_result(
{"key": "val"}, {"not": "a ToolResult"}
)
assert isinstance(result, ToolResult)
assert result.meta is None
class TestToolResultWrapping:
"""Integration tests for ToolResult unwrap/truncate/rewrap in on_call_tool."""
def _make_tool_result(
self, payload: dict[str, Any], meta: dict[str, Any] | None = None
) -> Any:
from fastmcp.tools.tool import ToolResult
from mcp.types import TextContent
from superset.utils import json
return ToolResult(
content=[TextContent(type="text", text=json.dumps(payload))],
meta=meta,
)
@pytest.mark.asyncio
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dataset_info"
context.message.params = {}
large_payload = {"id": 1, "table_name": "test", "description": "x" * 50000}
tool_result = self._make_tool_result(large_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
# Must return a ToolResult, not a raw dict
assert isinstance(result, ToolResult)
reparsed = json.loads(result.content[0].text)
assert reparsed["id"] == 1
assert reparsed["_response_truncated"] is True
assert "[truncated" in reparsed["description"]
@pytest.mark.asyncio
async def test_small_tool_result_passes_through_unchanged(self) -> None:
"""Should return the original ToolResult when within the token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
context = MagicMock()
context.message.name = "get_chart_info"
context.message.params = {}
small_payload = {"id": 1, "name": "My Chart"}
tool_result = self._make_tool_result(small_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert result is tool_result
@pytest.mark.asyncio
async def test_large_non_info_tool_result_is_blocked(self) -> None:
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
large_payload = {
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(500)]
}
tool_result = self._make_tool_result(large_payload)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
pytest.raises(ToolError),
):
await middleware.on_call_tool(context, call_next)
@pytest.mark.asyncio
async def test_meta_preserved_after_truncation(self) -> None:
"""Should preserve the original ToolResult meta through truncation."""
from fastmcp.tools.tool import ToolResult
from superset.utils import json
middleware = ResponseSizeGuardMiddleware(token_limit=500)
context = MagicMock()
context.message.name = "get_dashboard_info"
context.message.params = {}
meta = {"request_id": "abc-123"}
large_payload = {"id": 1, "title": "My Dashboard", "description": "x" * 50000}
tool_result = self._make_tool_result(large_payload, meta=meta)
call_next = AsyncMock(return_value=tool_result)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert isinstance(result, ToolResult)
assert result.meta == meta
reparsed = json.loads(result.content[0].text)
assert reparsed["_response_truncated"] is True
class TestMiddlewareIntegration:
"""Integration tests for middleware behavior."""
@pytest.mark.asyncio
async def test_pydantic_model_response(self) -> None:
"""Should handle Pydantic model responses."""
from pydantic import BaseModel
class ChartInfo(BaseModel):
id: int
name: str
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
context = MagicMock()
context.message.name = "get_chart_info"
context.message.params = {}
response = ChartInfo(id=1, name="Test Chart")
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert result == response
@pytest.mark.asyncio
async def test_list_response(self) -> None:
"""Should handle list responses."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
context = MagicMock()
context.message.name = "list_charts"
context.message.params = {}
response = [{"id": 1}, {"id": 2}, {"id": 3}]
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
):
result = await middleware.on_call_tool(context, call_next)
assert result == response
@pytest.mark.asyncio
async def test_string_response(self) -> None:
"""Should handle string responses."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
context = MagicMock()
context.message.name = "health_check"
context.message.params = {}
response = "OK"
call_next = AsyncMock(return_value=response)
result = await middleware.on_call_tool(context, call_next)
assert result == response
def _make_security_exception(msg: str = "access denied") -> SupersetSecurityException:
"""Helper to construct SupersetSecurityException with a proper SupersetError."""
return SupersetSecurityException(
SupersetError(
message=msg,
error_type=SupersetErrorType.GENERIC_BACKEND_ERROR,
level=ErrorLevel.ERROR,
)
)
class TestIsUserError:
"""Test _is_user_error classification helper."""
@pytest.mark.parametrize(
("error", "expected"),
[
# User errors (WARNING) — expected in normal MCP operation
(ToolError("bad request"), True),
(PermissionError("access denied"), True),
(ObjectNotFoundError("Chart", "123"), True),
(ForbiddenError(), True),
(_make_security_exception(), True),
(ValueError("invalid param"), True),
(FileNotFoundError("not found"), True),
# System errors (ERROR) — unexpected failures
(RuntimeError("unexpected"), False),
(ConnectionError("connection refused"), False),
(TypeError("type mismatch"), False),
(KeyError("missing key"), False),
(Exception("generic"), False),
],
ids=[
"ToolError",
"PermissionError",
"ObjectNotFoundError",
"ForbiddenError",
"SupersetSecurityException",
"ValueError", # user error — bad param from LLM
"FileNotFoundError",
"RuntimeError",
"ConnectionError",
"TypeError",
"KeyError",
"Exception",
],
)
def test_error_classification(self, error: Exception, expected: bool) -> None:
"""Test that _is_user_error correctly classifies error types."""
assert _is_user_error(error) == expected
def test_validation_error(self) -> None:
"""Test ValidationError is classified as user error."""
from pydantic import BaseModel
class TestModel(BaseModel):
"""Test model for validation error testing."""
name: str
with pytest.raises(ValidationError) as exc_info:
TestModel.model_validate({})
assert _is_user_error(exc_info.value) is True
def test_command_invalid_error(self) -> None:
"""Test CommandInvalidError is classified as user error."""
error = CommandInvalidError()
assert _is_user_error(error) is True
assert error.status == 422
def test_operational_error(self) -> None:
"""Test OperationalError is classified as system error."""
error = OperationalError("db error", {}, Exception())
assert _is_user_error(error) is False
def test_superset_exception_status_based(self) -> None:
"""Test SupersetException classification is based on .status attribute."""
# 4xx status → user error
error_400 = SupersetException("bad request")
error_400.status = 400
assert _is_user_error(error_400) is True
error_408 = SupersetException("timeout")
error_408.status = 408
assert _is_user_error(error_408) is True
error_422 = SupersetException("unprocessable")
error_422.status = 422
assert _is_user_error(error_422) is True
# 5xx status → system error
error_500 = SupersetException("internal error")
error_500.status = 500
assert _is_user_error(error_500) is False
error_503 = SupersetException("unavailable")
error_503.status = 503
assert _is_user_error(error_503) is False
class TestGlobalErrorHandlerLogLevels:
"""Test that GlobalErrorHandlerMiddleware logs at correct levels."""
@pytest.mark.asyncio
async def test_user_error_logs_warning(self) -> None:
"""User errors (e.g. ValueError) should log at WARNING."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
call_next = AsyncMock(side_effect=ValueError("invalid page"))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError),
):
await middleware.on_message(context, call_next)
# Should log at WARNING, not ERROR
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
async def test_system_error_logs_error(self) -> None:
"""System errors (OperationalError, generic Exception) should log at ERROR."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "execute_sql"
context.method = "tools/call"
call_next = AsyncMock(side_effect=OperationalError("db error", {}, Exception()))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError),
):
await middleware.on_message(context, call_next)
# Should log at ERROR
mock_logger.error.assert_called()
@pytest.mark.asyncio
async def test_unexpected_error_logs_error(self) -> None:
"""Truly unexpected errors should log at ERROR with error_id."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
call_next = AsyncMock(side_effect=RuntimeError("something broke"))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError, match="Internal error"),
):
await middleware.on_message(context, call_next)
# Should log at ERROR (both the classification log and the error_id log)
assert mock_logger.error.call_count >= 1
@pytest.mark.asyncio
async def test_event_logger_includes_severity(self) -> None:
"""Event logger payload should include severity field."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
call_next = AsyncMock(side_effect=ValueError("bad param"))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
patch("superset.mcp_service.middleware.logger"),
pytest.raises(ToolError),
):
await middleware.on_message(context, call_next)
mock_event_logger.log.assert_called_once()
payload = mock_event_logger.log.call_args.kwargs["curated_payload"]
assert payload["severity"] == "warning"
@pytest.mark.asyncio
async def test_permission_error_logs_warning(self) -> None:
"""PermissionError should log at WARNING — agents are expected to
try tools they lack access to."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "generate_chart"
context.method = "tools/call"
call_next = AsyncMock(side_effect=PermissionError("not allowed"))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError, match="Permission denied"),
):
await middleware.on_message(context, call_next)
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
async def test_connection_error_logs_error(self) -> None:
"""ConnectionError should log at ERROR — infrastructure issue."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
call_next = AsyncMock(side_effect=ConnectionError("connection refused"))
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError, match="Connection error"),
):
await middleware.on_message(context, call_next)
mock_logger.error.assert_called()
@pytest.mark.asyncio
async def test_superset_exception_4xx_logs_warning(self) -> None:
"""SupersetException with 4xx status should log at WARNING."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
error = SupersetException("bad request")
error.status = 400
call_next = AsyncMock(side_effect=error)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError, match="Invalid request"),
):
await middleware.on_message(context, call_next)
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
async def test_superset_exception_5xx_logs_error(self) -> None:
"""SupersetException with 5xx status should log at ERROR."""
middleware = GlobalErrorHandlerMiddleware()
context = MagicMock()
context.message.name = "list_charts"
context.method = "tools/call"
error = SupersetException("internal failure")
error.status = 500
call_next = AsyncMock(side_effect=error)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch("superset.mcp_service.middleware.logger") as mock_logger,
pytest.raises(ToolError, match="Internal error"),
):
await middleware.on_message(context, call_next)
mock_logger.error.assert_called()