From 842df5ee774fb251a9bc91eaad164baae20e689d Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 17:55:27 +0000 Subject: [PATCH] fix(mcp): fix 4 failing unit tests and ruff import error in RBAC tool visibility - Fix ruff error: consolidate contextlib imports into single from-import - Fix test patch targets: middleware tests must patch middleware module after imports were promoted to module level (not auth module) - Fix _tool_allowed_for_current_user: pass public tools through when user resolution fails (only hide tools with _class_permission_name) Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/auth.py | 5 +- superset/mcp_service/server.py | 5 +- .../unit_tests/mcp_service/test_middleware.py | 83 ++++++++++--------- 3 files changed, 49 insertions(+), 44 deletions(-) diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py index e8cb675228a..69d8e625531 100644 --- a/superset/mcp_service/auth.py +++ b/superset/mcp_service/auth.py @@ -44,9 +44,8 @@ Configuration: - MCP_DEV_USERNAME: Fallback username for development """ -import contextlib import logging -from contextlib import AbstractContextManager +from contextlib import AbstractContextManager, nullcontext from typing import Any, Callable, TYPE_CHECKING, TypeVar from flask import current_app, g, has_app_context, has_request_context @@ -659,7 +658,7 @@ def _get_app_context_manager() -> AbstractContextManager[None]: ``RBACToolVisibilityMiddleware`` (tools/list filtering). """ if has_request_context(): - return contextlib.nullcontext() + return nullcontext() if has_app_context(): # Push a new context for the CURRENT app (not get_flask_app() # which may return a different instance in test environments). diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 9d3c8d5f350..92166606e43 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -411,7 +411,10 @@ def _tool_allowed_for_current_user(tool: Any) -> bool: try: g.user = get_user_from_request() except (ValueError, PermissionError): - return False + # Can't resolve user; only hide protected tools. Public tools + # (no _class_permission_name) pass through regardless. + func = getattr(tool, "fn", tool) + return not getattr(func, "_class_permission_name", None) return is_tool_visible_to_current_user(tool) except (AttributeError, RuntimeError, ValueError): diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py index 00e3f9457f0..4056a1f6b0b 100644 --- a/tests/unit_tests/mcp_service/test_middleware.py +++ b/tests/unit_tests/mcp_service/test_middleware.py @@ -73,7 +73,7 @@ class TestResponseSizeGuardMiddleware: ) assert middleware.excluded_tools == {"health_check"} - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_allows_small_response(self) -> None: """Should allow responses under token limit.""" middleware = ResponseSizeGuardMiddleware(token_limit=25000) @@ -96,7 +96,7 @@ class TestResponseSizeGuardMiddleware: assert result == small_response call_next.assert_called_once_with(context) - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_blocks_large_response(self) -> None: """Should block responses over token limit.""" middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit @@ -124,7 +124,7 @@ class TestResponseSizeGuardMiddleware: assert "Response too large" in error_message assert "limit" in error_message.lower() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_skips_excluded_tools(self) -> None: """Should skip checking for excluded tools.""" middleware = ResponseSizeGuardMiddleware( @@ -144,7 +144,7 @@ class TestResponseSizeGuardMiddleware: result = await middleware.on_call_tool(context, call_next) assert result == large_response - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_logs_warning_at_threshold(self) -> None: """Should log warning when approaching limit. @@ -180,7 +180,7 @@ class TestResponseSizeGuardMiddleware: # Should log warning mock_logger.warning.assert_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_error_includes_suggestions(self) -> None: """Should include suggestions in error message.""" middleware = ResponseSizeGuardMiddleware(token_limit=100) @@ -205,7 +205,7 @@ class TestResponseSizeGuardMiddleware: # Should suggest reducing page_size assert "page_size" in error_message.lower() or "limit" in error_message.lower() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_logs_size_exceeded_event(self) -> None: """Should log to event logger when size exceeded.""" middleware = ResponseSizeGuardMiddleware(token_limit=100) @@ -229,7 +229,7 @@ class TestResponseSizeGuardMiddleware: call_args = mock_event_logger.log.call_args assert call_args.kwargs["action"] == "mcp_response_size_exceeded" - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_truncates_info_tool_instead_of_blocking(self) -> None: """Should truncate info tool responses instead of blocking them.""" middleware = ResponseSizeGuardMiddleware(token_limit=500) @@ -258,7 +258,7 @@ class TestResponseSizeGuardMiddleware: assert result["_response_truncated"] is True assert "[truncated" in result["description"] - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_truncates_chart_info_with_large_form_data(self) -> None: """Should truncate get_chart_info with large form_data.""" middleware = ResponseSizeGuardMiddleware(token_limit=500) @@ -284,7 +284,7 @@ class TestResponseSizeGuardMiddleware: assert result["id"] == 1 assert result["_response_truncated"] is True - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_still_blocks_non_info_tools(self) -> None: """Should still block non-info tools that exceed limit.""" middleware = ResponseSizeGuardMiddleware(token_limit=100) @@ -303,7 +303,7 @@ class TestResponseSizeGuardMiddleware: ): await middleware.on_call_tool(context, call_next) - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_logs_truncation_event(self) -> None: """Should log mcp_response_truncated event on successful truncation.""" middleware = ResponseSizeGuardMiddleware(token_limit=500) @@ -591,7 +591,7 @@ class TestToolResultWrapping: meta=meta, ) - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_info_tool_result_is_truncated_and_rewrapped(self) -> None: """Truncate a ToolResult-wrapped info response and return a ToolResult.""" from fastmcp.tools.tool import ToolResult @@ -620,7 +620,7 @@ class TestToolResultWrapping: assert reparsed["_response_truncated"] is True assert "[truncated" in reparsed["description"] - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_small_tool_result_passes_through_unchanged(self) -> None: """Should return the original ToolResult when within the token limit.""" @@ -641,7 +641,7 @@ class TestToolResultWrapping: assert result is tool_result - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_large_non_info_tool_result_is_blocked(self) -> None: """Should raise ToolError for a non-info ToolResult that exceeds the limit.""" middleware = ResponseSizeGuardMiddleware(token_limit=100) @@ -662,7 +662,7 @@ class TestToolResultWrapping: ): await middleware.on_call_tool(context, call_next) - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_meta_preserved_after_truncation(self) -> None: """Should preserve the original ToolResult meta through truncation.""" from fastmcp.tools.tool import ToolResult @@ -694,7 +694,7 @@ class TestToolResultWrapping: class TestMiddlewareIntegration: """Integration tests for middleware behavior.""" - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_pydantic_model_response(self) -> None: """Should handle Pydantic model responses.""" from pydantic import BaseModel @@ -720,7 +720,7 @@ class TestMiddlewareIntegration: assert result == response - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_list_response(self) -> None: """Should handle list responses.""" middleware = ResponseSizeGuardMiddleware(token_limit=25000) @@ -740,7 +740,7 @@ class TestMiddlewareIntegration: assert result == response - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_string_response(self) -> None: """Should handle string responses.""" middleware = ResponseSizeGuardMiddleware(token_limit=25000) @@ -859,7 +859,7 @@ class TestIsUserError: class TestGlobalErrorHandlerLogLevels: """Test that GlobalErrorHandlerMiddleware logs at correct levels.""" - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_user_error_logs_warning(self) -> None: """User errors (e.g. ValueError) should log at WARNING.""" middleware = GlobalErrorHandlerMiddleware() @@ -882,7 +882,7 @@ class TestGlobalErrorHandlerLogLevels: mock_logger.warning.assert_called() mock_logger.error.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_system_error_logs_error(self) -> None: """System errors (OperationalError, generic Exception) should log at ERROR.""" middleware = GlobalErrorHandlerMiddleware() @@ -904,7 +904,7 @@ class TestGlobalErrorHandlerLogLevels: # Should log at ERROR mock_logger.error.assert_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_unexpected_error_logs_error(self) -> None: """Truly unexpected errors should log at ERROR with error_id.""" middleware = GlobalErrorHandlerMiddleware() @@ -926,7 +926,7 @@ class TestGlobalErrorHandlerLogLevels: # Should log at ERROR (both the classification log and the error_id log) assert mock_logger.error.call_count >= 1 - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_event_logger_includes_severity(self) -> None: """Event logger payload should include severity field.""" middleware = GlobalErrorHandlerMiddleware() @@ -949,7 +949,7 @@ class TestGlobalErrorHandlerLogLevels: payload = mock_event_logger.log.call_args.kwargs["curated_payload"] assert payload["severity"] == "warning" - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_permission_error_logs_warning(self) -> None: """PermissionError should log at WARNING — agents are expected to try tools they lack access to.""" @@ -972,7 +972,7 @@ class TestGlobalErrorHandlerLogLevels: mock_logger.warning.assert_called() mock_logger.error.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_connection_error_logs_error(self) -> None: """ConnectionError should log at ERROR — infrastructure issue.""" middleware = GlobalErrorHandlerMiddleware() @@ -993,7 +993,7 @@ class TestGlobalErrorHandlerLogLevels: mock_logger.error.assert_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_superset_exception_4xx_logs_warning(self) -> None: """SupersetException with 4xx status should log at WARNING.""" middleware = GlobalErrorHandlerMiddleware() @@ -1017,7 +1017,7 @@ class TestGlobalErrorHandlerLogLevels: mock_logger.warning.assert_called() mock_logger.error.assert_not_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_superset_exception_5xx_logs_error(self) -> None: """SupersetException with 5xx status should log at ERROR.""" middleware = GlobalErrorHandlerMiddleware() @@ -1041,7 +1041,7 @@ class TestGlobalErrorHandlerLogLevels: mock_logger.error.assert_called() - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_mcp_permission_denied_error_becomes_tool_error(self) -> None: """MCPPermissionDeniedError must convert to ToolError, not a generic error.""" from superset.mcp_service.auth import MCPPermissionDeniedError @@ -1070,7 +1070,7 @@ class TestGlobalErrorHandlerLogLevels: assert "can_write" in str(exc_info.value) assert "Dashboard" in str(exc_info.value) - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_mcp_permission_denied_error_is_user_error(self) -> None: """MCPPermissionDeniedError must be classified as a user error (WARNING).""" from superset.mcp_service.auth import MCPPermissionDeniedError @@ -1081,7 +1081,7 @@ class TestGlobalErrorHandlerLogLevels: ) assert _is_user_error(error) is True - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_mcp_permission_denied_error_logs_at_warning(self) -> None: """MCPPermissionDeniedError should log at WARNING, not ERROR.""" from superset.mcp_service.auth import MCPPermissionDeniedError @@ -1121,7 +1121,7 @@ class TestRBACToolVisibilityMiddleware: tool.name = name return tool - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_fails_open_on_exception(self) -> None: """Returns all tools when get_flask_app raises (fail open).""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1138,7 +1138,7 @@ class TestRBACToolVisibilityMiddleware: assert result == tools - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_fails_open_when_user_is_none(self, app) -> None: """Returns all tools when get_user_from_request returns None.""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1151,13 +1151,16 @@ class TestRBACToolVisibilityMiddleware: patch( "superset.mcp_service.flask_singleton.get_flask_app", return_value=app ), - patch("superset.mcp_service.auth.get_user_from_request", return_value=None), + patch( + "superset.mcp_service.middleware.get_user_from_request", + return_value=None, + ), ): result = await middleware.on_list_tools(MagicMock(), call_next) assert result == tools - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_filters_tools_by_rbac(self, app) -> None: """Tools denied by is_tool_visible_to_current_user are removed.""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1178,11 +1181,11 @@ class TestRBACToolVisibilityMiddleware: "superset.mcp_service.flask_singleton.get_flask_app", return_value=app ), patch( - "superset.mcp_service.auth.get_user_from_request", + "superset.mcp_service.middleware.get_user_from_request", return_value=mock_user, ), patch( - "superset.mcp_service.auth.is_tool_visible_to_current_user", + "superset.mcp_service.middleware.is_tool_visible_to_current_user", side_effect=_visible, ), ): @@ -1191,7 +1194,7 @@ class TestRBACToolVisibilityMiddleware: assert read_tool in result assert write_tool not in result - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_fails_closed_on_permission_error(self, app) -> None: """Returns empty list when credentials are invalid (PermissionError).""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1205,7 +1208,7 @@ class TestRBACToolVisibilityMiddleware: "superset.mcp_service.flask_singleton.get_flask_app", return_value=app ), patch( - "superset.mcp_service.auth.get_user_from_request", + "superset.mcp_service.middleware.get_user_from_request", side_effect=PermissionError("Invalid API key"), ), ): @@ -1213,7 +1216,7 @@ class TestRBACToolVisibilityMiddleware: assert result == [] - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_fails_closed_on_bad_credentials_value_error(self, app) -> None: """Returns empty list when auth was attempted but user not found.""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1227,7 +1230,7 @@ class TestRBACToolVisibilityMiddleware: "superset.mcp_service.flask_singleton.get_flask_app", return_value=app ), patch( - "superset.mcp_service.auth.get_user_from_request", + "superset.mcp_service.middleware.get_user_from_request", side_effect=ValueError("User 'ghost' not found in database"), ), ): @@ -1235,7 +1238,7 @@ class TestRBACToolVisibilityMiddleware: assert result == [] - @pytest.mark.asyncio + @pytest.mark.asyncio() async def test_fails_open_when_no_auth_configured(self, app) -> None: """Returns all tools when no auth source is configured at all.""" from superset.mcp_service.middleware import RBACToolVisibilityMiddleware @@ -1249,7 +1252,7 @@ class TestRBACToolVisibilityMiddleware: "superset.mcp_service.flask_singleton.get_flask_app", return_value=app ), patch( - "superset.mcp_service.auth.get_user_from_request", + "superset.mcp_service.middleware.get_user_from_request", side_effect=ValueError("No authenticated user found"), ), ):