feat(mcp): add response size guard to prevent oversized responses (#37200)

This commit is contained in:
Amin Ghadersohi
2026-02-25 12:43:14 -05:00
committed by GitHub
parent c54b21ef98
commit cc1128a404
7 changed files with 1373 additions and 4 deletions

View File

@@ -0,0 +1,343 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service middleware.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp.exceptions import ToolError
from superset.mcp_service.middleware import (
create_response_size_guard_middleware,
ResponseSizeGuardMiddleware,
)
class TestResponseSizeGuardMiddleware:
    """Test ResponseSizeGuardMiddleware class."""

    def test_init_default_values(self) -> None:
        """Should initialize with default values."""
        middleware = ResponseSizeGuardMiddleware()
        assert middleware.token_limit == 25_000
        assert middleware.warn_threshold_pct == 80
        # warn_threshold is derived from the other two: 25_000 * 80% = 20_000
        assert middleware.warn_threshold == 20000
        assert middleware.excluded_tools == set()

    def test_init_custom_values(self) -> None:
        """Should initialize with custom values."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=10000,
            warn_threshold_pct=70,
            excluded_tools=["health_check", "get_chart_preview"],
        )
        assert middleware.token_limit == 10000
        assert middleware.warn_threshold_pct == 70
        assert middleware.warn_threshold == 7000
        # A list of excluded tools is normalized into a set
        assert middleware.excluded_tools == {"health_check", "get_chart_preview"}

    def test_init_excluded_tools_as_string(self) -> None:
        """Should handle excluded_tools as a single string."""
        middleware = ResponseSizeGuardMiddleware(
            excluded_tools="health_check",
        )
        # A bare string becomes a one-element set (not split per character)
        assert middleware.excluded_tools == {"health_check"}

    @pytest.mark.asyncio
    async def test_allows_small_response(self) -> None:
        """Should allow responses under token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)
        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}
        # Create mock call_next that returns small response
        small_response = {"charts": [{"id": 1, "name": "test"}]}
        call_next = AsyncMock(return_value=small_response)
        # Patch user lookup and event logging so the guard runs in isolation
        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)
        # Response passes through unchanged; downstream handler called exactly once
        assert result == small_response
        call_next.assert_called_once_with(context)

    @pytest.mark.asyncio
    async def test_blocks_large_response(self) -> None:
        """Should block responses over token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)  # Very low limit
        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}
        # Create large response
        large_response = {
            "charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)]
        }
        call_next = AsyncMock(return_value=large_response)
        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)
        # Verify error contains helpful information
        error_message = str(exc_info.value)
        assert "Response too large" in error_message
        assert "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_skips_excluded_tools(self) -> None:
        """Should skip checking for excluded tools."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=100, excluded_tools=["health_check"]
        )
        # Create mock context for excluded tool
        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}
        # Create response that would exceed limit
        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)
        # Should not raise even though response exceeds limit
        # (no patches needed: excluded tools bypass the guard entirely)
        result = await middleware.on_call_tool(context, call_next)
        assert result == large_response

    @pytest.mark.asyncio
    async def test_logs_warning_at_threshold(self) -> None:
        """Should log warning when approaching limit."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=1000, warn_threshold_pct=80
        )
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}
        # Response at ~85% of limit (should trigger warning but not block)
        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
        call_next = AsyncMock(return_value=response)
        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
        ):
            result = await middleware.on_call_tool(context, call_next)
        # Should return response (not blocked)
        assert result == response
        # Should log warning
        mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_error_includes_suggestions(self) -> None:
        """Should include suggestions in error message."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}
        large_response = {"charts": [{"id": i} for i in range(1000)]}
        call_next = AsyncMock(return_value=large_response)
        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)
        error_message = str(exc_info.value)
        # Should have numbered suggestions
        assert "1." in error_message
        # Should suggest reducing page_size
        assert "page_size" in error_message.lower() or "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_logs_size_exceeded_event(self) -> None:
        """Should log to event logger when size exceeded."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}
        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)
        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)
        # Should log to event logger with the dedicated rejection action
        mock_event_logger.log.assert_called()
        call_args = mock_event_logger.log.call_args
        assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
class TestCreateResponseSizeGuardMiddleware:
    """Test create_response_size_guard_middleware factory function."""

    @staticmethod
    def _app_returning(config: dict) -> MagicMock:
        """Build a stand-in Flask app whose ``config.get`` yields *config*."""
        app = MagicMock()
        app.config.get.return_value = config
        return app

    def test_creates_middleware_when_enabled(self) -> None:
        """Should create middleware when enabled in config."""
        app = self._app_returning(
            {
                "enabled": True,
                "token_limit": 30000,
                "warn_threshold_pct": 75,
                "excluded_tools": ["health_check"],
            }
        )
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=app,
        ):
            guard = create_response_size_guard_middleware()
        assert guard is not None
        assert isinstance(guard, ResponseSizeGuardMiddleware)
        assert guard.token_limit == 30000
        assert guard.warn_threshold_pct == 75
        assert "health_check" in guard.excluded_tools

    def test_returns_none_when_disabled(self) -> None:
        """Should return None when disabled in config."""
        app = self._app_returning({"enabled": False})
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=app,
        ):
            assert create_response_size_guard_middleware() is None

    def test_uses_defaults_when_config_missing(self) -> None:
        """Should use defaults when config values are missing."""
        # Only the enable flag is set; every tuning knob is absent.
        app = self._app_returning({"enabled": True})
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=app,
        ):
            guard = create_response_size_guard_middleware()
        assert guard is not None
        assert guard.token_limit == 25_000  # default
        assert guard.warn_threshold_pct == 80  # default

    def test_handles_exception_gracefully(self) -> None:
        """Should return None on expected configuration exceptions."""
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            side_effect=ImportError("Config error"),
        ):
            assert create_response_size_guard_middleware() is None
class TestMiddlewareIntegration:
    """Integration tests for middleware behavior."""

    @pytest.mark.asyncio
    async def test_pydantic_model_response(self) -> None:
        """Should handle Pydantic model responses."""
        from pydantic import BaseModel

        class ChartInfo(BaseModel):
            id: int
            name: str

        guard = ResponseSizeGuardMiddleware(token_limit=25000)
        ctx = MagicMock()
        ctx.message.name = "get_chart_info"
        ctx.message.params = {}
        payload = ChartInfo(id=1, name="Test Chart")
        downstream = AsyncMock(return_value=payload)
        with patch("superset.mcp_service.middleware.get_user_id", return_value=1):
            with patch("superset.mcp_service.middleware.event_logger"):
                outcome = await guard.on_call_tool(ctx, downstream)
        assert outcome == payload

    @pytest.mark.asyncio
    async def test_list_response(self) -> None:
        """Should handle list responses."""
        guard = ResponseSizeGuardMiddleware(token_limit=25000)
        ctx = MagicMock()
        ctx.message.name = "list_charts"
        ctx.message.params = {}
        payload = [{"id": n} for n in (1, 2, 3)]
        downstream = AsyncMock(return_value=payload)
        with patch("superset.mcp_service.middleware.get_user_id", return_value=1):
            with patch("superset.mcp_service.middleware.event_logger"):
                outcome = await guard.on_call_tool(ctx, downstream)
        assert outcome == payload

    @pytest.mark.asyncio
    async def test_string_response(self) -> None:
        """Should handle string responses."""
        guard = ResponseSizeGuardMiddleware(token_limit=25000)
        ctx = MagicMock()
        ctx.message.name = "health_check"
        ctx.message.params = {}
        downstream = AsyncMock(return_value="OK")
        # A tiny string stays far below the limit and passes straight through.
        assert await guard.on_call_tool(ctx, downstream) == "OK"

View File

@@ -0,0 +1,358 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service token utilities.
"""
from typing import Any, List
from pydantic import BaseModel
from superset.mcp_service.utils.token_utils import (
CHARS_PER_TOKEN,
estimate_response_tokens,
estimate_token_count,
extract_query_params,
format_size_limit_error,
generate_size_reduction_suggestions,
get_response_size_bytes,
)
class TestEstimateTokenCount:
    """Test estimate_token_count function."""

    def test_estimate_string(self) -> None:
        """Should estimate tokens for a string."""
        sample = "Hello world"
        assert estimate_token_count(sample) == int(len(sample) / CHARS_PER_TOKEN)

    def test_estimate_bytes(self) -> None:
        """Should estimate tokens for bytes."""
        raw = b"Hello world"
        assert estimate_token_count(raw) == int(len(raw) / CHARS_PER_TOKEN)

    def test_empty_string(self) -> None:
        """Should return 0 for empty string."""
        assert estimate_token_count("") == 0

    def test_json_like_content(self) -> None:
        """Should estimate tokens for JSON-like content."""
        payload = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
        tokens = estimate_token_count(payload)
        # Estimation is a simple chars-per-token ratio, so it is exact here.
        assert tokens == int(len(payload) / CHARS_PER_TOKEN)
        assert tokens > 0
class TestEstimateResponseTokens:
    """Test estimate_response_tokens function."""

    class MockResponse(BaseModel):
        """Mock Pydantic response model."""

        name: str
        value: int
        items: List[Any]

    def test_estimate_pydantic_model(self) -> None:
        """Should estimate tokens for Pydantic model."""
        model = self.MockResponse(name="test", value=42, items=[1, 2, 3])
        assert estimate_response_tokens(model) > 0

    def test_estimate_dict(self) -> None:
        """Should estimate tokens for dict."""
        assert estimate_response_tokens({"name": "test", "value": 42}) > 0

    def test_estimate_list(self) -> None:
        """Should estimate tokens for list."""
        assert estimate_response_tokens([{"name": "item1"}, {"name": "item2"}]) > 0

    def test_estimate_string(self) -> None:
        """Should estimate tokens for string response."""
        assert estimate_response_tokens("Hello world") > 0

    def test_estimate_large_response(self) -> None:
        """Should estimate tokens for large response."""
        bulky = {"items": [{"name": f"item{i}"} for i in range(1000)]}
        # A thousand-item payload must estimate well past 1000 tokens.
        assert estimate_response_tokens(bulky) > 1000
class TestGetResponseSizeBytes:
    """Test get_response_size_bytes function."""

    def test_size_dict(self) -> None:
        """Should return size in bytes for dict."""
        assert get_response_size_bytes({"name": "test"}) > 0

    def test_size_string(self) -> None:
        """Should return size in bytes for string."""
        text = "Hello world"
        # Size is the UTF-8 encoded byte length, not the character count.
        assert get_response_size_bytes(text) == len(text.encode("utf-8"))

    def test_size_bytes(self) -> None:
        """Should return size for bytes."""
        raw = b"Hello world"
        assert get_response_size_bytes(raw) == len(raw)
class TestExtractQueryParams:
    """Test extract_query_params function."""

    def test_extract_pagination_params(self) -> None:
        """Should extract pagination parameters."""
        extracted = extract_query_params({"page_size": 100, "limit": 50})
        assert extracted["page_size"] == 100
        assert extracted["limit"] == 50

    def test_extract_column_selection(self) -> None:
        """Should extract column selection parameters."""
        extracted = extract_query_params({"select_columns": ["name", "id"]})
        assert extracted["select_columns"] == ["name", "id"]

    def test_extract_from_nested_request(self) -> None:
        """Should extract from nested request object."""
        nested = {"request": {"page_size": 50, "filters": [{"col": "name"}]}}
        extracted = extract_query_params(nested)
        # Values inside the wrapper "request" object are flattened out.
        assert extracted["page_size"] == 50
        assert extracted["filters"] == [{"col": "name"}]

    def test_empty_params(self) -> None:
        """Should return empty dict for empty params."""
        for empty in (None, {}):
            assert extract_query_params(empty) == {}

    def test_extract_filters(self) -> None:
        """Should extract filter parameters."""
        extracted = extract_query_params(
            {"filters": [{"col": "name", "opr": "eq", "value": "test"}]}
        )
        assert "filters" in extracted
class TestGenerateSizeReductionSuggestions:
    """Test generate_size_reduction_suggestions function."""

    @staticmethod
    def _suggestions_for(tool_name, params):
        """Generate suggestions at a fixed 2x overshoot (50k vs 25k tokens)."""
        return generate_size_reduction_suggestions(
            tool_name=tool_name,
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )

    def test_suggest_reduce_page_size(self) -> None:
        """Should suggest reducing page_size when present."""
        hints = self._suggestions_for("list_charts", {"page_size": 100})
        assert any("page_size" in h.lower() or "limit" in h.lower() for h in hints)

    def test_suggest_add_limit_for_list_tools(self) -> None:
        """Should suggest adding limit for list tools."""
        hints = self._suggestions_for("list_charts", {})
        assert any("limit" in h.lower() or "page_size" in h.lower() for h in hints)

    def test_suggest_select_columns(self) -> None:
        """Should suggest using select_columns."""
        hints = self._suggestions_for("list_charts", {})
        assert any(
            "select_columns" in h.lower() or "columns" in h.lower() for h in hints
        )

    def test_suggest_filters(self) -> None:
        """Should suggest adding filters."""
        hints = self._suggestions_for("list_charts", {})
        assert any("filter" in h.lower() for h in hints)

    def test_tool_specific_suggestions_execute_sql(self) -> None:
        """Should provide SQL-specific suggestions for execute_sql."""
        hints = self._suggestions_for("execute_sql", {"sql": "SELECT * FROM table"})
        assert any("LIMIT" in h or "limit" in h.lower() for h in hints)

    def test_tool_specific_suggestions_list_charts(self) -> None:
        """Should provide chart-specific suggestions for list_charts."""
        hints = self._suggestions_for("list_charts", {})
        # Should steer clients away from heavy fields like params/query_context.
        assert any(
            "params" in h.lower() or "query_context" in h.lower() for h in hints
        )

    def test_suggests_search_parameter(self) -> None:
        """Should suggest using search parameter."""
        hints = self._suggestions_for("list_dashboards", {})
        assert any("search" in h.lower() for h in hints)
class TestFormatSizeLimitError:
    """Test format_size_limit_error function."""

    @staticmethod
    def _error(params=None, estimated_tokens=50000, token_limit=25000):
        """Format a list_charts size error with overridable sizing inputs."""
        return format_size_limit_error(
            tool_name="list_charts",
            params={} if params is None else params,
            estimated_tokens=estimated_tokens,
            token_limit=token_limit,
        )

    def test_error_contains_token_counts(self) -> None:
        """Should include token counts in error message."""
        message = self._error()
        # Counts are rendered with thousands separators.
        assert "50,000" in message
        assert "25,000" in message

    def test_error_contains_tool_name(self) -> None:
        """Should include tool name in error message."""
        assert "list_charts" in self._error()

    def test_error_contains_suggestions(self) -> None:
        """Should include suggestions in error message."""
        # Suggestions are numbered, so "1." marks the first one.
        assert "1." in self._error(params={"page_size": 100})

    def test_error_contains_reduction_percentage(self) -> None:
        """Should include reduction percentage in error message."""
        message = self._error()
        # Halving 50k down to 25k is a 50% reduction.
        assert "50%" in message or "Reduction" in message

    def test_error_limits_suggestions_to_five(self) -> None:
        """Should limit suggestions to 5."""
        message = self._error(estimated_tokens=100000, token_limit=10000)
        numbered = [i for i in range(1, 10) if f"{i}." in message]
        assert len(numbered) <= 5

    def test_error_message_is_readable(self) -> None:
        """Should produce human-readable error message."""
        message = self._error(params={"page_size": 100}, estimated_tokens=75000)
        # Multi-line output carrying the key phrases a user needs.
        assert len(message.split("\n")) > 5
        assert "Response too large" in message
        assert "Please modify your query" in message
class TestCalculatedSuggestions:
    """Test that suggestions include calculated values."""

    def test_suggested_limit_is_calculated(self) -> None:
        """Should calculate suggested limit based on reduction needed."""
        hints = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=50000,  # double the allowed budget
            token_limit=25000,
        )
        matching = [h for h in hints if "page_size" in h.lower()]
        assert matching
        first_hint = matching[0]
        # Must mention the current value (100) and either the roughly
        # halved target (50) or the word "reduction".
        assert "100" in first_hint
        assert "50" in first_hint or "reduction" in first_hint.lower()

    def test_reduction_percentage_in_suggestions(self) -> None:
        """Should include reduction percentage in suggestions."""
        hints = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=75000,  # 3x over limit
            token_limit=25000,
        )
        # int() truncation of 66.6% yields "66%" in the combined text.
        assert "66%" in " ".join(hints)