mirror of
https://github.com/apache/superset.git
synced 2026-05-21 15:55:10 +00:00
fix(mcp): fix 4 failing unit tests and ruff import error in RBAC tool visibility
- Fix ruff error: consolidate contextlib imports into single from-import - Fix test patch targets: middleware tests must patch middleware module after imports were promoted to module level (not auth module) - Fix _tool_allowed_for_current_user: pass public tools through when user resolution fails (only hide tools with _class_permission_name) Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -44,9 +44,8 @@ Configuration:
|
||||
- MCP_DEV_USERNAME: Fallback username for development
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
from contextlib import AbstractContextManager
|
||||
from contextlib import AbstractContextManager, nullcontext
|
||||
from typing import Any, Callable, TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import current_app, g, has_app_context, has_request_context
|
||||
@@ -659,7 +658,7 @@ def _get_app_context_manager() -> AbstractContextManager[None]:
|
||||
``RBACToolVisibilityMiddleware`` (tools/list filtering).
|
||||
"""
|
||||
if has_request_context():
|
||||
return contextlib.nullcontext()
|
||||
return nullcontext()
|
||||
if has_app_context():
|
||||
# Push a new context for the CURRENT app (not get_flask_app()
|
||||
# which may return a different instance in test environments).
|
||||
|
||||
@@ -411,7 +411,10 @@ def _tool_allowed_for_current_user(tool: Any) -> bool:
|
||||
try:
|
||||
g.user = get_user_from_request()
|
||||
except (ValueError, PermissionError):
|
||||
return False
|
||||
# Can't resolve user; only hide protected tools. Public tools
|
||||
# (no _class_permission_name) pass through regardless.
|
||||
func = getattr(tool, "fn", tool)
|
||||
return not getattr(func, "_class_permission_name", None)
|
||||
|
||||
return is_tool_visible_to_current_user(tool)
|
||||
except (AttributeError, RuntimeError, ValueError):
|
||||
|
||||
@@ -73,7 +73,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
)
|
||||
assert middleware.excluded_tools == {"health_check"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_allows_small_response(self) -> None:
|
||||
"""Should allow responses under token limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
@@ -96,7 +96,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
assert result == small_response
|
||||
call_next.assert_called_once_with(context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_blocks_large_response(self) -> None:
|
||||
"""Should block responses over token limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit
|
||||
@@ -124,7 +124,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
assert "Response too large" in error_message
|
||||
assert "limit" in error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_skips_excluded_tools(self) -> None:
|
||||
"""Should skip checking for excluded tools."""
|
||||
middleware = ResponseSizeGuardMiddleware(
|
||||
@@ -144,7 +144,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
assert result == large_response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_logs_warning_at_threshold(self) -> None:
|
||||
"""Should log warning when approaching limit.
|
||||
|
||||
@@ -180,7 +180,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
# Should log warning
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_error_includes_suggestions(self) -> None:
|
||||
"""Should include suggestions in error message."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
@@ -205,7 +205,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
# Should suggest reducing page_size
|
||||
assert "page_size" in error_message.lower() or "limit" in error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_logs_size_exceeded_event(self) -> None:
|
||||
"""Should log to event logger when size exceeded."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
@@ -229,7 +229,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
call_args = mock_event_logger.log.call_args
|
||||
assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_truncates_info_tool_instead_of_blocking(self) -> None:
|
||||
"""Should truncate info tool responses instead of blocking them."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
@@ -258,7 +258,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
assert result["_response_truncated"] is True
|
||||
assert "[truncated" in result["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_truncates_chart_info_with_large_form_data(self) -> None:
|
||||
"""Should truncate get_chart_info with large form_data."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
@@ -284,7 +284,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
assert result["id"] == 1
|
||||
assert result["_response_truncated"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_still_blocks_non_info_tools(self) -> None:
|
||||
"""Should still block non-info tools that exceed limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
@@ -303,7 +303,7 @@ class TestResponseSizeGuardMiddleware:
|
||||
):
|
||||
await middleware.on_call_tool(context, call_next)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_logs_truncation_event(self) -> None:
|
||||
"""Should log mcp_response_truncated event on successful truncation."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=500)
|
||||
@@ -591,7 +591,7 @@ class TestToolResultWrapping:
|
||||
meta=meta,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None:
|
||||
"""Truncate a ToolResult-wrapped info response and return a ToolResult."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
@@ -620,7 +620,7 @@ class TestToolResultWrapping:
|
||||
assert reparsed["_response_truncated"] is True
|
||||
assert "[truncated" in reparsed["description"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_small_tool_result_passes_through_unchanged(self) -> None:
|
||||
"""Should return the original ToolResult when within the token limit."""
|
||||
|
||||
@@ -641,7 +641,7 @@ class TestToolResultWrapping:
|
||||
|
||||
assert result is tool_result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_large_non_info_tool_result_is_blocked(self) -> None:
|
||||
"""Should raise ToolError for a non-info ToolResult that exceeds the limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100)
|
||||
@@ -662,7 +662,7 @@ class TestToolResultWrapping:
|
||||
):
|
||||
await middleware.on_call_tool(context, call_next)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_meta_preserved_after_truncation(self) -> None:
|
||||
"""Should preserve the original ToolResult meta through truncation."""
|
||||
from fastmcp.tools.tool import ToolResult
|
||||
@@ -694,7 +694,7 @@ class TestToolResultWrapping:
|
||||
class TestMiddlewareIntegration:
|
||||
"""Integration tests for middleware behavior."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_pydantic_model_response(self) -> None:
|
||||
"""Should handle Pydantic model responses."""
|
||||
from pydantic import BaseModel
|
||||
@@ -720,7 +720,7 @@ class TestMiddlewareIntegration:
|
||||
|
||||
assert result == response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_list_response(self) -> None:
|
||||
"""Should handle list responses."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
@@ -740,7 +740,7 @@ class TestMiddlewareIntegration:
|
||||
|
||||
assert result == response
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_string_response(self) -> None:
|
||||
"""Should handle string responses."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
@@ -859,7 +859,7 @@ class TestIsUserError:
|
||||
class TestGlobalErrorHandlerLogLevels:
|
||||
"""Test that GlobalErrorHandlerMiddleware logs at correct levels."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_user_error_logs_warning(self) -> None:
|
||||
"""User errors (e.g. ValueError) should log at WARNING."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -882,7 +882,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
mock_logger.warning.assert_called()
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_system_error_logs_error(self) -> None:
|
||||
"""System errors (OperationalError, generic Exception) should log at ERROR."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -904,7 +904,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
# Should log at ERROR
|
||||
mock_logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_unexpected_error_logs_error(self) -> None:
|
||||
"""Truly unexpected errors should log at ERROR with error_id."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -926,7 +926,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
# Should log at ERROR (both the classification log and the error_id log)
|
||||
assert mock_logger.error.call_count >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_event_logger_includes_severity(self) -> None:
|
||||
"""Event logger payload should include severity field."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -949,7 +949,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
payload = mock_event_logger.log.call_args.kwargs["curated_payload"]
|
||||
assert payload["severity"] == "warning"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_permission_error_logs_warning(self) -> None:
|
||||
"""PermissionError should log at WARNING — agents are expected to
|
||||
try tools they lack access to."""
|
||||
@@ -972,7 +972,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
mock_logger.warning.assert_called()
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_connection_error_logs_error(self) -> None:
|
||||
"""ConnectionError should log at ERROR — infrastructure issue."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -993,7 +993,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
|
||||
mock_logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_superset_exception_4xx_logs_warning(self) -> None:
|
||||
"""SupersetException with 4xx status should log at WARNING."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -1017,7 +1017,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
mock_logger.warning.assert_called()
|
||||
mock_logger.error.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_superset_exception_5xx_logs_error(self) -> None:
|
||||
"""SupersetException with 5xx status should log at ERROR."""
|
||||
middleware = GlobalErrorHandlerMiddleware()
|
||||
@@ -1041,7 +1041,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
|
||||
mock_logger.error.assert_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_mcp_permission_denied_error_becomes_tool_error(self) -> None:
|
||||
"""MCPPermissionDeniedError must convert to ToolError, not a generic error."""
|
||||
from superset.mcp_service.auth import MCPPermissionDeniedError
|
||||
@@ -1070,7 +1070,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
assert "can_write" in str(exc_info.value)
|
||||
assert "Dashboard" in str(exc_info.value)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_mcp_permission_denied_error_is_user_error(self) -> None:
|
||||
"""MCPPermissionDeniedError must be classified as a user error (WARNING)."""
|
||||
from superset.mcp_service.auth import MCPPermissionDeniedError
|
||||
@@ -1081,7 +1081,7 @@ class TestGlobalErrorHandlerLogLevels:
|
||||
)
|
||||
assert _is_user_error(error) is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_mcp_permission_denied_error_logs_at_warning(self) -> None:
|
||||
"""MCPPermissionDeniedError should log at WARNING, not ERROR."""
|
||||
from superset.mcp_service.auth import MCPPermissionDeniedError
|
||||
@@ -1121,7 +1121,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
tool.name = name
|
||||
return tool
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_fails_open_on_exception(self) -> None:
|
||||
"""Returns all tools when get_flask_app raises (fail open)."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1138,7 +1138,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
|
||||
assert result == tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_fails_open_when_user_is_none(self, app) -> None:
|
||||
"""Returns all tools when get_user_from_request returns None."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1151,13 +1151,16 @@ class TestRBACToolVisibilityMiddleware:
|
||||
patch(
|
||||
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
|
||||
),
|
||||
patch("superset.mcp_service.auth.get_user_from_request", return_value=None),
|
||||
patch(
|
||||
"superset.mcp_service.middleware.get_user_from_request",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
result = await middleware.on_list_tools(MagicMock(), call_next)
|
||||
|
||||
assert result == tools
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_filters_tools_by_rbac(self, app) -> None:
|
||||
"""Tools denied by is_tool_visible_to_current_user are removed."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1178,11 +1181,11 @@ class TestRBACToolVisibilityMiddleware:
|
||||
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
|
||||
),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
"superset.mcp_service.middleware.get_user_from_request",
|
||||
return_value=mock_user,
|
||||
),
|
||||
patch(
|
||||
"superset.mcp_service.auth.is_tool_visible_to_current_user",
|
||||
"superset.mcp_service.middleware.is_tool_visible_to_current_user",
|
||||
side_effect=_visible,
|
||||
),
|
||||
):
|
||||
@@ -1191,7 +1194,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
assert read_tool in result
|
||||
assert write_tool not in result
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_fails_closed_on_permission_error(self, app) -> None:
|
||||
"""Returns empty list when credentials are invalid (PermissionError)."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1205,7 +1208,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
|
||||
),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
"superset.mcp_service.middleware.get_user_from_request",
|
||||
side_effect=PermissionError("Invalid API key"),
|
||||
),
|
||||
):
|
||||
@@ -1213,7 +1216,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_fails_closed_on_bad_credentials_value_error(self, app) -> None:
|
||||
"""Returns empty list when auth was attempted but user not found."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1227,7 +1230,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
|
||||
),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
"superset.mcp_service.middleware.get_user_from_request",
|
||||
side_effect=ValueError("User 'ghost' not found in database"),
|
||||
),
|
||||
):
|
||||
@@ -1235,7 +1238,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
|
||||
assert result == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.asyncio()
|
||||
async def test_fails_open_when_no_auth_configured(self, app) -> None:
|
||||
"""Returns all tools when no auth source is configured at all."""
|
||||
from superset.mcp_service.middleware import RBACToolVisibilityMiddleware
|
||||
@@ -1249,7 +1252,7 @@ class TestRBACToolVisibilityMiddleware:
|
||||
"superset.mcp_service.flask_singleton.get_flask_app", return_value=app
|
||||
),
|
||||
patch(
|
||||
"superset.mcp_service.auth.get_user_from_request",
|
||||
"superset.mcp_service.middleware.get_user_from_request",
|
||||
side_effect=ValueError("No authenticated user found"),
|
||||
),
|
||||
):
|
||||
|
||||
Reference in New Issue
Block a user