fix(mcp): fix 4 failing unit tests and ruff import error in RBAC tool visibility

- Fix ruff error: consolidate contextlib imports into single from-import
- Fix test patch targets: middleware tests must patch middleware module
  after imports were promoted to module level (not auth module)
- Fix _tool_allowed_for_current_user: pass public tools through when
  user resolution fails (only hide tools with _class_permission_name)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-05-20 17:55:27 +00:00
parent f97e70ccdb
commit 842df5ee77
3 changed files with 49 additions and 44 deletions

View File

@@ -44,9 +44,8 @@ Configuration:
- MCP_DEV_USERNAME: Fallback username for development
"""
import contextlib
import logging
from contextlib import AbstractContextManager
from contextlib import AbstractContextManager, nullcontext
from typing import Any, Callable, TYPE_CHECKING, TypeVar
from flask import current_app, g, has_app_context, has_request_context
@@ -659,7 +658,7 @@ def _get_app_context_manager() -> AbstractContextManager[None]:
``RBACToolVisibilityMiddleware`` (tools/list filtering).
"""
if has_request_context():
return contextlib.nullcontext()
return nullcontext()
if has_app_context():
# Push a new context for the CURRENT app (not get_flask_app()
# which may return a different instance in test environments).

View File

@@ -411,7 +411,10 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
try:
g.user = get_user_from_request()
except (ValueError, PermissionError):
return False
# Can't resolve user; only hide protected tools. Public tools
# (no _class_permission_name) pass through regardless.
func = getattr(tool, "fn", tool)
return not getattr(func, "_class_permission_name", None)
return is_tool_visible_to_current_user(tool)
except (AttributeError, RuntimeError, ValueError):

View File

@@ -73,7 +73,7 @@ class TestResponseSizeGuardMiddleware:
)
assert middleware.excluded_tools == {"health_check"}
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_allows_small_response(self) -> None:
"""Should allow responses under token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -96,7 +96,7 @@ class TestResponseSizeGuardMiddleware:
assert result == small_response
call_next.assert_called_once_with(context)
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_blocks_large_response(self) -> None:
"""Should block responses over token limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit
@@ -124,7 +124,7 @@ class TestResponseSizeGuardMiddleware:
assert "Response too large" in error_message
assert "limit" in error_message.lower()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_skips_excluded_tools(self) -> None:
"""Should skip checking for excluded tools."""
middleware = ResponseSizeGuardMiddleware(
@@ -144,7 +144,7 @@ class TestResponseSizeGuardMiddleware:
result = await middleware.on_call_tool(context, call_next)
assert result == large_response
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_logs_warning_at_threshold(self) -> None:
"""Should log warning when approaching limit.
@@ -180,7 +180,7 @@ class TestResponseSizeGuardMiddleware:
# Should log warning
mock_logger.warning.assert_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_error_includes_suggestions(self) -> None:
"""Should include suggestions in error message."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -205,7 +205,7 @@ class TestResponseSizeGuardMiddleware:
# Should suggest reducing page_size
assert "page_size" in error_message.lower() or "limit" in error_message.lower()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_logs_size_exceeded_event(self) -> None:
"""Should log to event logger when size exceeded."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -229,7 +229,7 @@ class TestResponseSizeGuardMiddleware:
call_args = mock_event_logger.log.call_args
assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_truncates_info_tool_instead_of_blocking(self) -> None:
"""Should truncate info tool responses instead of blocking them."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -258,7 +258,7 @@ class TestResponseSizeGuardMiddleware:
assert result["_response_truncated"] is True
assert "[truncated" in result["description"]
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_truncates_chart_info_with_large_form_data(self) -> None:
"""Should truncate get_chart_info with large form_data."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -284,7 +284,7 @@ class TestResponseSizeGuardMiddleware:
assert result["id"] == 1
assert result["_response_truncated"] is True
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_still_blocks_non_info_tools(self) -> None:
"""Should still block non-info tools that exceed limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -303,7 +303,7 @@ class TestResponseSizeGuardMiddleware:
):
await middleware.on_call_tool(context, call_next)
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_logs_truncation_event(self) -> None:
"""Should log mcp_response_truncated event on successful truncation."""
middleware = ResponseSizeGuardMiddleware(token_limit=500)
@@ -591,7 +591,7 @@ class TestToolResultWrapping:
meta=meta,
)
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
from fastmcp.tools.tool import ToolResult
@@ -620,7 +620,7 @@ class TestToolResultWrapping:
assert reparsed["_response_truncated"] is True
assert "[truncated" in reparsed["description"]
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_small_tool_result_passes_through_unchanged(self) -> None:
"""Should return the original ToolResult when within the token limit."""
@@ -641,7 +641,7 @@ class TestToolResultWrapping:
assert result is tool_result
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_large_non_info_tool_result_is_blocked(self) -> None:
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
middleware = ResponseSizeGuardMiddleware(token_limit=100)
@@ -662,7 +662,7 @@ class TestToolResultWrapping:
):
await middleware.on_call_tool(context, call_next)
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_meta_preserved_after_truncation(self) -> None:
"""Should preserve the original ToolResult meta through truncation."""
from fastmcp.tools.tool import ToolResult
@@ -694,7 +694,7 @@ class TestToolResultWrapping:
class TestMiddlewareIntegration:
"""Integration tests for middleware behavior."""
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_pydantic_model_response(self) -> None:
"""Should handle Pydantic model responses."""
from pydantic import BaseModel
@@ -720,7 +720,7 @@ class TestMiddlewareIntegration:
assert result == response
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_list_response(self) -> None:
"""Should handle list responses."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -740,7 +740,7 @@ class TestMiddlewareIntegration:
assert result == response
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_string_response(self) -> None:
"""Should handle string responses."""
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
@@ -859,7 +859,7 @@ class TestIsUserError:
class TestGlobalErrorHandlerLogLevels:
"""Test that GlobalErrorHandlerMiddleware logs at correct levels."""
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_user_error_logs_warning(self) -> None:
"""User errors (e.g. ValueError) should log at WARNING."""
middleware = GlobalErrorHandlerMiddleware()
@@ -882,7 +882,7 @@ class TestGlobalErrorHandlerLogLevels:
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_system_error_logs_error(self) -> None:
"""System errors (OperationalError, generic Exception) should log at ERROR."""
middleware = GlobalErrorHandlerMiddleware()
@@ -904,7 +904,7 @@ class TestGlobalErrorHandlerLogLevels:
# Should log at ERROR
mock_logger.error.assert_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_unexpected_error_logs_error(self) -> None:
"""Truly unexpected errors should log at ERROR with error_id."""
middleware = GlobalErrorHandlerMiddleware()
@@ -926,7 +926,7 @@ class TestGlobalErrorHandlerLogLevels:
# Should log at ERROR (both the classification log and the error_id log)
assert mock_logger.error.call_count >= 1
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_event_logger_includes_severity(self) -> None:
"""Event logger payload should include severity field."""
middleware = GlobalErrorHandlerMiddleware()
@@ -949,7 +949,7 @@ class TestGlobalErrorHandlerLogLevels:
payload = mock_event_logger.log.call_args.kwargs["curated_payload"]
assert payload["severity"] == "warning"
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_permission_error_logs_warning(self) -> None:
"""PermissionError should log at WARNING — agents are expected to
try tools they lack access to."""
@@ -972,7 +972,7 @@ class TestGlobalErrorHandlerLogLevels:
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_connection_error_logs_error(self) -> None:
"""ConnectionError should log at ERROR — infrastructure issue."""
middleware = GlobalErrorHandlerMiddleware()
@@ -993,7 +993,7 @@ class TestGlobalErrorHandlerLogLevels:
mock_logger.error.assert_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_superset_exception_4xx_logs_warning(self) -> None:
"""SupersetException with 4xx status should log at WARNING."""
middleware = GlobalErrorHandlerMiddleware()
@@ -1017,7 +1017,7 @@ class TestGlobalErrorHandlerLogLevels:
mock_logger.warning.assert_called()
mock_logger.error.assert_not_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_superset_exception_5xx_logs_error(self) -> None:
"""SupersetException with 5xx status should log at ERROR."""
middleware = GlobalErrorHandlerMiddleware()
@@ -1041,7 +1041,7 @@ class TestGlobalErrorHandlerLogLevels:
mock_logger.error.assert_called()
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_mcp_permission_denied_error_becomes_tool_error(self) -> None:
"""MCPPermissionDeniedError must convert to ToolError, not a generic error."""
from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1070,7 +1070,7 @@ class TestGlobalErrorHandlerLogLevels:
assert "can_write" in str(exc_info.value)
assert "Dashboard" in str(exc_info.value)
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_mcp_permission_denied_error_is_user_error(self) -> None:
"""MCPPermissionDeniedError must be classified as a user error (WARNING)."""
from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1081,7 +1081,7 @@ class TestGlobalErrorHandlerLogLevels:
)
assert _is_user_error(error) is True
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_mcp_permission_denied_error_logs_at_warning(self) -> None:
"""MCPPermissionDeniedError should log at WARNING, not ERROR."""
from superset.mcp_service.auth import MCPPermissionDeniedError
@@ -1121,7 +1121,7 @@ class TestRBACToolVisibilityMiddleware:
tool.name = name
return tool
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_fails_open_on_exception(self) -> None:
"""Returns all tools when get_flask_app raises (fail open)."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1138,7 +1138,7 @@ class TestRBACToolVisibilityMiddleware:
assert result == tools
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_fails_open_when_user_is_none(self, app) -> None:
"""Returns all tools when get_user_from_request returns None."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1151,13 +1151,16 @@ class TestRBACToolVisibilityMiddleware:
patch(
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
),
patch("superset.mcp_service.auth.get_user_from_request", return_value=None),
patch(
"superset.mcp_service.middleware.get_user_from_request",
return_value=None,
),
):
result = await middleware.on_list_tools(MagicMock(), call_next)
assert result == tools
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_filters_tools_by_rbac(self, app) -> None:
"""Tools denied by is_tool_visible_to_current_user are removed."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1178,11 +1181,11 @@ class TestRBACToolVisibilityMiddleware:
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
),
patch(
"superset.mcp_service.auth.get_user_from_request",
"superset.mcp_service.middleware.get_user_from_request",
return_value=mock_user,
),
patch(
"superset.mcp_service.auth.is_tool_visible_to_current_user",
"superset.mcp_service.middleware.is_tool_visible_to_current_user",
side_effect=_visible,
),
):
@@ -1191,7 +1194,7 @@ class TestRBACToolVisibilityMiddleware:
assert read_tool in result
assert write_tool not in result
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_fails_closed_on_permission_error(self, app) -> None:
"""Returns empty list when credentials are invalid (PermissionError)."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1205,7 +1208,7 @@ class TestRBACToolVisibilityMiddleware:
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
),
patch(
"superset.mcp_service.auth.get_user_from_request",
"superset.mcp_service.middleware.get_user_from_request",
side_effect=PermissionError("Invalid API key"),
),
):
@@ -1213,7 +1216,7 @@ class TestRBACToolVisibilityMiddleware:
assert result == []
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_fails_closed_on_bad_credentials_value_error(self, app) -> None:
"""Returns empty list when auth was attempted but user not found."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1227,7 +1230,7 @@ class TestRBACToolVisibilityMiddleware:
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
),
patch(
"superset.mcp_service.auth.get_user_from_request",
"superset.mcp_service.middleware.get_user_from_request",
side_effect=ValueError("User 'ghost' not found in database"),
),
):
@@ -1235,7 +1238,7 @@ class TestRBACToolVisibilityMiddleware:
assert result == []
@pytest.mark.asyncio
@pytest.mark.asyncio()
async def test_fails_open_when_no_auth_configured(self, app) -> None:
"""Returns all tools when no auth source is configured at all."""
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
@@ -1249,7 +1252,7 @@ class TestRBACToolVisibilityMiddleware:
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
),
patch(
"superset.mcp_service.auth.get_user_from_request",
"superset.mcp_service.middleware.get_user_from_request",
side_effect=ValueError("No authenticated user found"),
),
):