feat(mcp): add response size guard to prevent oversized responses (#37200)

This commit is contained in:
Amin Ghadersohi
2026-02-25 12:43:14 -05:00
committed by GitHub
parent c54b21ef98
commit cc1128a404
7 changed files with 1373 additions and 4 deletions

View File

@@ -0,0 +1,21 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Constants for the MCP service."""
# Response size guard defaults
DEFAULT_TOKEN_LIMIT = 25_000 # ~25k tokens prevents overwhelming LLM context windows
DEFAULT_WARN_THRESHOLD_PCT = 80 # Log warnings above 80% of limit

View File

@@ -22,6 +22,11 @@ from typing import Any, Dict, Optional
from flask import Flask
from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
logger = logging.getLogger(__name__)
# MCP Service Configuration
@@ -167,6 +172,50 @@ MCP_CACHE_CONFIG: Dict[str, Any] = {
],
}
# =============================================================================
# MCP Response Size Guard Configuration
# =============================================================================
#
# Overview:
# ---------
# The Response Size Guard prevents oversized responses from overwhelming LLM
# clients (e.g., Claude Desktop). When a tool response exceeds the token limit,
# it returns a helpful error with suggestions for reducing the response size.
#
# How it works:
# -------------
# 1. After a tool executes, the middleware estimates the response's token count
# 2. If the response exceeds the configured limit, it blocks the response
# 3. Instead, it returns an error message with smart suggestions:
# - Reduce page_size/limit
# - Use select_columns to exclude large fields
# - Add filters to narrow results
# - Tool-specific recommendations
#
# Configuration:
# --------------
# - enabled: Toggle the guard on/off (default: True)
# - token_limit: Maximum estimated tokens per response (default: 25,000)
# - excluded_tools: Tools to skip checking (e.g., streaming tools)
# - warn_threshold_pct: Log warnings above this % of limit (default: 80%)
#
# Token Estimation:
# -----------------
# Uses character-based heuristic (~3.5 chars per token for JSON).
# This is intentionally conservative to avoid underestimating.
# =============================================================================
# Default configuration for the response size guard middleware; deployments
# can override this dict via MCP_RESPONSE_SIZE_CONFIG in superset_config.py.
MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
    "enabled": True,  # Enabled by default to protect LLM clients
    "token_limit": DEFAULT_TOKEN_LIMIT,  # Max estimated tokens per response
    "warn_threshold_pct": DEFAULT_WARN_THRESHOLD_PCT,  # Warn above this % of limit
    "excluded_tools": [  # Tools to skip size checking
        "health_check",  # Always small
        "get_chart_preview",  # Returns URLs, not data
        "generate_explore_link",  # Returns URLs
        "open_sql_lab_with_context",  # Returns URLs
    ],
}
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
"""Default MCP auth factory using app.config values."""

View File

@@ -27,6 +27,10 @@ from sqlalchemy.exc import OperationalError, TimeoutError
from starlette.exceptions import HTTPException
from superset.extensions import event_logger
from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@@ -837,3 +841,169 @@ class FieldPermissionsMiddleware(Middleware):
"Unknown response type for field filtering: %s", type(response)
)
return response
class ResponseSizeGuardMiddleware(Middleware):
    """
    Middleware that prevents oversized responses from overwhelming LLM clients.

    When a tool response exceeds the configured token limit, this middleware
    intercepts it and returns a helpful error message with suggestions for
    reducing the response size.

    This is critical for protecting LLM clients like Claude Desktop which can
    crash or become unresponsive when receiving extremely large responses.

    Configuration via MCP_RESPONSE_SIZE_CONFIG in superset_config.py:
    - enabled: Toggle the guard on/off (default: True)
    - token_limit: Maximum estimated tokens per response (default: 25,000)
    - warn_threshold_pct: Log warnings above this % of limit (default: 80%)
    - excluded_tools: Tools to skip checking
    """

    def __init__(
        self,
        token_limit: int = DEFAULT_TOKEN_LIMIT,
        warn_threshold_pct: int = DEFAULT_WARN_THRESHOLD_PCT,
        excluded_tools: list[str] | str | None = None,
    ) -> None:
        # Absolute token budget for a single tool response.
        self.token_limit = token_limit
        self.warn_threshold_pct = warn_threshold_pct
        # Pre-computed warning threshold in tokens (pct of the limit).
        self.warn_threshold = int(token_limit * warn_threshold_pct / 100)
        # Accept a single tool name for convenience; normalize to a set.
        if isinstance(excluded_tools, str):
            excluded_tools = [excluded_tools]
        self.excluded_tools = set(excluded_tools or [])

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """Check response size after tool execution."""
        tool_name = getattr(context.message, "name", "unknown")

        # Skip excluded tools
        if tool_name in self.excluded_tools:
            return await call_next(context)

        # Execute the tool
        response = await call_next(context)

        # Estimate response token count (guard against huge responses causing OOM)
        # Imported lazily to avoid a module-level dependency cycle.
        from superset.mcp_service.utils.token_utils import (
            estimate_response_tokens,
            format_size_limit_error,
        )

        try:
            estimated_tokens = estimate_response_tokens(response)
        except MemoryError as me:
            logger.warning(
                "MemoryError while estimating tokens for %s: %s", tool_name, me
            )
            # Treat as over limit to avoid further serialization
            estimated_tokens = self.token_limit + 1
        except Exception as e:  # noqa: BLE001
            logger.warning(
                "Failed to estimate response tokens for %s: %s", tool_name, e
            )
            # Conservative fallback: block rather than risk OOM
            estimated_tokens = self.token_limit + 1

        # Log warning if approaching limit
        if estimated_tokens > self.warn_threshold:
            logger.warning(
                "Response size warning for %s: ~%d tokens (%.0f%% of %d limit)",
                tool_name,
                estimated_tokens,
                # Guard against division by zero when token_limit is 0.
                (estimated_tokens / self.token_limit * 100) if self.token_limit else 0,
                self.token_limit,
            )

        # Block if over limit
        if estimated_tokens > self.token_limit:
            # Extract params for smart suggestions
            params = getattr(context.message, "params", {}) or {}

            # Log the blocked response
            logger.error(
                "Response blocked for %s: ~%d tokens exceeds limit of %d",
                tool_name,
                estimated_tokens,
                self.token_limit,
            )

            # Log to event logger for monitoring; best-effort, never lets a
            # logging failure mask the size-limit error below.
            try:
                user_id = get_user_id()
                event_logger.log(
                    user_id=user_id,
                    action="mcp_response_size_exceeded",
                    curated_payload={
                        "tool": tool_name,
                        "estimated_tokens": estimated_tokens,
                        "token_limit": self.token_limit,
                        "params": params,
                    },
                )
            except Exception as log_error:  # noqa: BLE001
                logger.warning("Failed to log size exceeded event: %s", log_error)

            # Generate helpful error message with suggestions
            # Avoid passing the full `response` (which may be huge) into the formatter
            # to prevent large-memory operations during error formatting.
            error_message = format_size_limit_error(
                tool_name=tool_name,
                params=params,
                estimated_tokens=estimated_tokens,
                token_limit=self.token_limit,
                response=None,
            )
            raise ToolError(error_message)

        return response
def create_response_size_guard_middleware() -> ResponseSizeGuardMiddleware | None:
    """
    Build a ResponseSizeGuardMiddleware from application configuration.

    Reads MCP_RESPONSE_SIZE_CONFIG from the Flask app's config, falling back
    to the packaged defaults for any missing value.

    Returns:
        ResponseSizeGuardMiddleware instance, or None when the guard is
        disabled or the configuration cannot be loaded.
    """
    try:
        # Imported lazily so this module stays importable without a Flask app.
        from superset.mcp_service.flask_singleton import get_flask_app
        from superset.mcp_service.mcp_config import MCP_RESPONSE_SIZE_CONFIG

        # Prefer the deployment's config; fall back to packaged defaults.
        guard_config = get_flask_app().config.get(
            "MCP_RESPONSE_SIZE_CONFIG", MCP_RESPONSE_SIZE_CONFIG
        )

        if not guard_config.get("enabled", True):
            logger.info("Response size guard is disabled")
            return None

        guard = ResponseSizeGuardMiddleware(
            token_limit=int(guard_config.get("token_limit", DEFAULT_TOKEN_LIMIT)),
            warn_threshold_pct=int(
                guard_config.get("warn_threshold_pct", DEFAULT_WARN_THRESHOLD_PCT)
            ),
            excluded_tools=guard_config.get("excluded_tools"),
        )
        logger.info(
            "Created ResponseSizeGuardMiddleware with token_limit=%d",
            guard.token_limit,
        )
        return guard
    except (ImportError, AttributeError, KeyError) as e:
        logger.error("Failed to create ResponseSizeGuardMiddleware: %s", e)
        return None

View File

@@ -29,10 +29,8 @@ from typing import Any
import uvicorn
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
from superset.mcp_service.mcp_config import (
get_mcp_factory_config,
MCP_STORE_CONFIG,
)
from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG
from superset.mcp_service.middleware import create_response_size_guard_middleware
from superset.mcp_service.storage import _create_redis_store
@@ -176,6 +174,12 @@ def run_server(
# Build middleware list
middleware_list = []
# Add response size guard (protects LLM clients from huge responses)
if size_guard_middleware := create_response_size_guard_middleware():
middleware_list.append(size_guard_middleware)
# Add caching middleware
caching_middleware = create_response_caching_middleware()
if caching_middleware:
middleware_list.append(caching_middleware)

View File

@@ -0,0 +1,424 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Token counting and response size utilities for MCP service.
This module provides utilities to estimate token counts and generate smart
suggestions when responses exceed configured limits. This prevents large
responses from overwhelming LLM clients like Claude Desktop.
"""
from __future__ import annotations
import logging
from typing import Any, Dict, List, Union
from pydantic import BaseModel
from typing_extensions import TypeAlias
# Module-level logger for estimation failures and analysis debug output.
logger = logging.getLogger(__name__)

# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
# Approximate characters per token used by the size heuristic. English text
# averages ~4 chars/token on Claude's tokenizer; JSON is denser, so a lower
# ratio keeps the estimate conservative (never undercounts).
CHARS_PER_TOKEN = 3.5


def estimate_token_count(text: str | bytes) -> int:
    """
    Estimate the number of tokens in *text*.

    A character-count heuristic is used because the real tokenizer is not
    available here; the ratio is deliberately conservative so sizes are
    not underestimated.

    Args:
        text: String or UTF-8 bytes to measure.

    Returns:
        Estimated token count; 0 only for empty input, otherwise at least 1.
    """
    if isinstance(text, bytes):
        # Undecodable bytes are replaced rather than raising.
        text = text.decode("utf-8", errors="replace")
    if not text:
        return 0
    # ~3.5 characters per token for JSON/code; floor at one token.
    return max(1, int(len(text) / CHARS_PER_TOKEN))
def estimate_response_tokens(response: ToolResponse) -> int:
    """
    Estimate the token count of an MCP tool response.

    Accepts Pydantic models (serialized via ``model_dump``), dicts/lists
    (JSON-encoded), bytes, strings, and arbitrary objects (``str()``-ified).

    Args:
        response: The response object to estimate.

    Returns:
        Estimated number of tokens; a conservative 100,000 on any failure.
    """
    try:
        from superset.utils import json

        # Serialize to a string so the character heuristic is accurate.
        if hasattr(response, "model_dump"):
            # Pydantic model
            serialized = json.dumps(response.model_dump())
        elif isinstance(response, (dict, list)):
            serialized = json.dumps(response)
        elif isinstance(response, bytes):
            # estimate_token_count decodes bytes safely itself.
            return estimate_token_count(response)
        elif isinstance(response, str):
            serialized = response
        else:
            serialized = str(response)
        return estimate_token_count(serialized)
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to estimate response tokens: %s", e)
        # Return a high estimate to be safe (conservative fallback)
        return 100000
def get_response_size_bytes(response: ToolResponse) -> int:
    """
    Compute the size of a response in bytes.

    Args:
        response: The response object.

    Returns:
        Size in bytes; a conservative 1,000,000 on any failure.
    """
    try:
        from superset.utils import json

        # Bytes and strings have a direct byte length; everything else is
        # serialized first and measured as UTF-8.
        if hasattr(response, "model_dump"):
            payload = json.dumps(response.model_dump())
        elif isinstance(response, (dict, list)):
            payload = json.dumps(response)
        elif isinstance(response, bytes):
            return len(response)
        elif isinstance(response, str):
            return len(response.encode("utf-8"))
        else:
            payload = str(response)
        return len(payload.encode("utf-8"))
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to get response size: %s", e)
        # Return a conservative large value to avoid allowing oversized responses
        # to bypass size checks (returning 0 would underestimate)
        return 1_000_000  # 1MB fallback
def extract_query_params(params: Dict[str, Any] | None) -> Dict[str, Any]:
    """
    Pull out the parameters relevant to size-reduction suggestions.

    Args:
        params: The raw tool parameters dict (possibly ``None``).

    Returns:
        A dict containing only the pagination, column-selection, filter,
        and search parameters that were actually present.
    """
    if not params:
        return {}

    # MCP tools commonly nest their arguments under a "request" object.
    nested = params.get("request")
    if isinstance(nested, dict):
        params = nested

    relevant = (
        "page_size",  # pagination
        "limit",
        "select_columns",  # column selection
        "columns",
        "filters",  # filters
        "search",  # search
    )
    return {key: params[key] for key in relevant if key in params}
def generate_size_reduction_suggestions(
    tool_name: str,
    params: Dict[str, Any] | None,
    estimated_tokens: int,
    token_limit: int,
    response: ToolResponse | None = None,
) -> List[str]:
    """
    Build actionable suggestions for shrinking an oversized response.

    Analyzes the tool name and supplied parameters to recommend pagination,
    column selection, filtering, and tool-specific mitigations.

    Args:
        tool_name: Name of the MCP tool.
        params: The tool parameters.
        estimated_tokens: Estimated token count of the response.
        token_limit: Configured token limit.
        response: Optional response object for additional analysis.

    Returns:
        List of suggestion strings.
    """
    tips: List[str] = []
    query_params = extract_query_params(params)

    # How much of the response must be shaved off to fit the limit.
    over_by = estimated_tokens - token_limit
    reduction_pct = int((over_by / estimated_tokens) * 100) if estimated_tokens else 0

    is_listing = "list_" in tool_name

    # Suggestion 1: shrink pagination.
    raw_page_size = query_params.get("page_size") or query_params.get("limit")
    try:
        current_page_size = int(raw_page_size) if raw_page_size is not None else None
    except (TypeError, ValueError):
        current_page_size = None

    if current_page_size and current_page_size > 0:
        # Scale the current page size down proportionally to the overage.
        if estimated_tokens:
            scaled = int(current_page_size * (token_limit / estimated_tokens))
        else:
            scaled = 1
        suggested_limit = max(1, scaled)
        tips.append(
            f"Reduce page_size/limit from {current_page_size} to {suggested_limit} "
            f"(need ~{reduction_pct}% reduction)"
        )
    elif is_listing or tool_name.startswith("get_chart_data"):
        # No usable limit supplied - recommend adding one.
        tips.append(
            f"Add a 'limit' or 'page_size' parameter (suggested: 10-25 items) "
            f"to reduce response size by ~{reduction_pct}%"
        )

    # Suggestion 2: trim columns.
    current_columns = query_params.get("select_columns") or query_params.get("columns")
    if current_columns and len(current_columns) > 5:
        preview = ", ".join(str(c) for c in current_columns[:3])
        suffix = "..." if len(current_columns) > 3 else ""
        tips.append(
            f"Reduce select_columns from {len(current_columns)} columns to only "
            f"essential fields (currently: {preview}{suffix})"
        )
    elif not current_columns and is_listing:
        # Inspect the response for the fields that dominate its size.
        large_fields = _identify_large_fields(response)
        if large_fields:
            fields_preview = ", ".join(large_fields[:3])
            suffix = "..." if len(large_fields) > 3 else ""
            tips.append(
                f"Use 'select_columns' to exclude large fields: "
                f"{fields_preview}{suffix}. Only request the columns you need."
            )
        else:
            tips.append(
                "Use 'select_columns' to request only the specific columns you need "
                "instead of fetching all fields"
            )

    # Suggestion 3: narrow the result set with filters.
    if not query_params.get("filters") and is_listing:
        tips.append(
            "Add filters to narrow down results (e.g., filter by owner, "
            "date range, or specific attributes)"
        )

    # Suggestion 4: tool-specific advice.
    tips.extend(_get_tool_specific_suggestions(tool_name, query_params, response))

    # Suggestion 5: search instead of listing everything.
    if is_listing and not query_params.get("search"):
        tips.append(
            "Use the 'search' parameter to find specific items instead of "
            "listing all and filtering client-side"
        )

    return tips
def _identify_large_fields(response: ToolResponse) -> List[str]:
    """
    Find the fields that contribute most to a response's size.

    Looks at the first item of the first list-valued collection key and
    ranks its fields by serialized JSON length.

    Args:
        response: The response object to analyze.

    Returns:
        Field names whose serialized size exceeds 500 characters, largest
        first; empty list when nothing can be analyzed.
    """
    try:
        from superset.utils import json

        if hasattr(response, "model_dump"):
            data = response.model_dump()
        elif isinstance(response, dict):
            data = response
        else:
            return []

        # Locate the list of items (e.g., charts, dashboards) in the payload.
        items_key = None
        for key in ["charts", "dashboards", "datasets", "data", "items", "results"]:
            if key in data and isinstance(data[key], list) and data[key]:
                items_key = key
                break
        if not items_key or not data[items_key]:
            return []

        first_item = data[items_key][0]
        if not isinstance(first_item, dict):
            return []

        # Measure each non-null field by its JSON-serialized length.
        sizes = {
            field: len(json.dumps(value))
            for field, value in first_item.items()
            if value is not None
        }
        ranked = sorted(sizes.items(), key=lambda item: item[1], reverse=True)
        # Identifier-ish fields are always worth keeping, so skip them.
        return [
            name
            for name, size in ranked
            if size > 500 and name not in ("id", "uuid", "name", "slice_name")
        ]
    except Exception as e:  # noqa: BLE001
        logger.debug("Failed to identify large fields: %s", e)
        return []
def _get_tool_specific_suggestions(
tool_name: str,
query_params: Dict[str, Any],
response: ToolResponse,
) -> List[str]:
"""
Generate tool-specific suggestions based on the tool being called.
Args:
tool_name: Name of the MCP tool
query_params: Extracted query parameters
response: The response object
Returns:
List of tool-specific suggestions
"""
suggestions = []
if tool_name == "get_chart_data":
suggestions.append(
"For get_chart_data, use 'limit' parameter to restrict rows returned, "
"or use 'format=csv' for more compact output"
)
elif tool_name == "execute_sql":
suggestions.append(
"Add LIMIT clause to your SQL query to restrict the number of rows "
"(e.g., SELECT * FROM table LIMIT 100)"
)
elif tool_name in ("get_chart_info", "get_dashboard_info", "get_dataset_info"):
suggestions.append(
f"For {tool_name}, use 'select_columns' to fetch only specific metadata "
"fields instead of the full object"
)
elif tool_name == "list_charts":
suggestions.append(
"For list_charts, exclude 'params' and 'query_context' columns which "
"contain large JSON blobs - use select_columns to pick specific fields"
)
elif tool_name == "list_datasets":
suggestions.append(
"For list_datasets, exclude 'columns' and 'metrics' from select_columns "
"if you only need basic dataset info"
)
return suggestions
def format_size_limit_error(
    tool_name: str,
    params: Dict[str, Any] | None,
    estimated_tokens: int,
    token_limit: int,
    response: ToolResponse | None = None,
) -> str:
    """
    Build the user-facing error text for an over-limit response.

    Args:
        tool_name: Name of the MCP tool.
        params: The tool parameters.
        estimated_tokens: Estimated token count.
        token_limit: Configured token limit.
        response: Optional response for analysis.

    Returns:
        Multi-line message with the size overview and top suggestions.
    """
    suggestions = generate_size_reduction_suggestions(
        tool_name, params, estimated_tokens, token_limit, response
    )

    lines = [
        f"Response too large: ~{estimated_tokens:,} tokens (limit: {token_limit:,})",
        "",
        "This response would overwhelm the LLM context window.",
        "Please modify your query to reduce the response size:",
        "",
    ]
    # Keep the message focused: show at most the five best suggestions.
    lines += [f"{i}. {tip}" for i, tip in enumerate(suggestions[:5], 1)]

    if estimated_tokens:
        reduction_pct = (estimated_tokens - token_limit) / estimated_tokens * 100
    else:
        reduction_pct = 0
    lines += ["", f"Tool: {tool_name}", f"Reduction needed: ~{reduction_pct:.0f}%"]

    return "\n".join(lines)

View File

@@ -0,0 +1,343 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service middleware.
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastmcp.exceptions import ToolError
from superset.mcp_service.middleware import (
create_response_size_guard_middleware,
ResponseSizeGuardMiddleware,
)
class TestResponseSizeGuardMiddleware:
    """Test ResponseSizeGuardMiddleware class."""

    def test_init_default_values(self) -> None:
        """Should initialize with default values."""
        middleware = ResponseSizeGuardMiddleware()
        assert middleware.token_limit == 25_000
        assert middleware.warn_threshold_pct == 80
        # 80% of the 25k default.
        assert middleware.warn_threshold == 20000
        assert middleware.excluded_tools == set()

    def test_init_custom_values(self) -> None:
        """Should initialize with custom values."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=10000,
            warn_threshold_pct=70,
            excluded_tools=["health_check", "get_chart_preview"],
        )
        assert middleware.token_limit == 10000
        assert middleware.warn_threshold_pct == 70
        assert middleware.warn_threshold == 7000
        assert middleware.excluded_tools == {"health_check", "get_chart_preview"}

    def test_init_excluded_tools_as_string(self) -> None:
        """Should handle excluded_tools as a single string."""
        middleware = ResponseSizeGuardMiddleware(
            excluded_tools="health_check",
        )
        # A bare string is normalized into a one-element set.
        assert middleware.excluded_tools == {"health_check"}

    @pytest.mark.asyncio
    async def test_allows_small_response(self) -> None:
        """Should allow responses under token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        # Create mock call_next that returns small response
        small_response = {"charts": [{"id": 1, "name": "test"}]}
        call_next = AsyncMock(return_value=small_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        # Response is passed through untouched and the tool ran exactly once.
        assert result == small_response
        call_next.assert_called_once_with(context)

    @pytest.mark.asyncio
    async def test_blocks_large_response(self) -> None:
        """Should block responses over token limit."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)  # Very low limit

        # Create mock context
        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}

        # Create large response
        large_response = {
            "charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)]
        }
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)

        # Verify error contains helpful information
        error_message = str(exc_info.value)
        assert "Response too large" in error_message
        assert "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_skips_excluded_tools(self) -> None:
        """Should skip checking for excluded tools."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=100, excluded_tools=["health_check"]
        )

        # Create mock context for excluded tool
        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}

        # Create response that would exceed limit
        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        # Should not raise even though response exceeds limit
        result = await middleware.on_call_tool(context, call_next)
        assert result == large_response

    @pytest.mark.asyncio
    async def test_logs_warning_at_threshold(self) -> None:
        """Should log warning when approaching limit."""
        middleware = ResponseSizeGuardMiddleware(
            token_limit=1000, warn_threshold_pct=80
        )

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        # Response at ~85% of limit (should trigger warning but not block)
        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            patch("superset.mcp_service.middleware.logger") as mock_logger,
        ):
            result = await middleware.on_call_tool(context, call_next)

        # Should return response (not blocked)
        assert result == response
        # Should log warning
        mock_logger.warning.assert_called()

    @pytest.mark.asyncio
    async def test_error_includes_suggestions(self) -> None:
        """Should include suggestions in error message."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {"page_size": 100}

        large_response = {"charts": [{"id": i} for i in range(1000)]}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
            pytest.raises(ToolError) as exc_info,
        ):
            await middleware.on_call_tool(context, call_next)

        error_message = str(exc_info.value)
        # Should have numbered suggestions
        assert "1." in error_message
        # Should suggest reducing page_size
        assert "page_size" in error_message.lower() or "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_logs_size_exceeded_event(self) -> None:
        """Should log to event logger when size exceeded."""
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)

        # Should log to event logger
        mock_event_logger.log.assert_called()
        call_args = mock_event_logger.log.call_args
        assert call_args.kwargs["action"] == "mcp_response_size_exceeded"
class TestCreateResponseSizeGuardMiddleware:
    """Test create_response_size_guard_middleware factory function."""

    def test_creates_middleware_when_enabled(self) -> None:
        """Should create middleware when enabled in config."""
        mock_config = {
            "enabled": True,
            "token_limit": 30000,
            "warn_threshold_pct": 75,
            "excluded_tools": ["health_check"],
        }

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        # Patch the lazy import target used inside the factory.
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert isinstance(middleware, ResponseSizeGuardMiddleware)
        assert middleware.token_limit == 30000
        assert middleware.warn_threshold_pct == 75
        assert "health_check" in middleware.excluded_tools

    def test_returns_none_when_disabled(self) -> None:
        """Should return None when disabled in config."""
        mock_config = {"enabled": False}

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None

    def test_uses_defaults_when_config_missing(self) -> None:
        """Should use defaults when config values are missing."""
        mock_config = {"enabled": True}  # Only enabled, no other values

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert middleware.token_limit == 25_000  # Default
        assert middleware.warn_threshold_pct == 80  # Default

    def test_handles_exception_gracefully(self) -> None:
        """Should return None on expected configuration exceptions."""
        # ImportError is one of the exception types the factory swallows.
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            side_effect=ImportError("Config error"),
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None
class TestMiddlewareIntegration:
    """Integration tests for middleware behavior."""

    @pytest.mark.asyncio
    async def test_pydantic_model_response(self) -> None:
        """Should handle Pydantic model responses."""
        from pydantic import BaseModel

        # Minimal model standing in for a real tool response schema.
        class ChartInfo(BaseModel):
            id: int
            name: str

        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "get_chart_info"
        context.message.params = {}

        response = ChartInfo(id=1, name="Test Chart")
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_list_response(self) -> None:
        """Should handle list responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        response = [{"id": 1}, {"id": 2}, {"id": 3}]
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_string_response(self) -> None:
        """Should handle string responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}

        response = "OK"
        call_next = AsyncMock(return_value=response)

        # Tiny response: no patching needed since the guard never blocks it.
        result = await middleware.on_call_tool(context, call_next)
        assert result == response

View File

@@ -0,0 +1,358 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP service token utilities.
"""
from typing import Any, List
from pydantic import BaseModel
from superset.mcp_service.utils.token_utils import (
CHARS_PER_TOKEN,
estimate_response_tokens,
estimate_token_count,
extract_query_params,
format_size_limit_error,
generate_size_reduction_suggestions,
get_response_size_bytes,
)
class TestEstimateTokenCount:
    """Test estimate_token_count function."""

    def test_estimate_string(self) -> None:
        """Should estimate tokens for a string."""
        sample = "Hello world"
        assert estimate_token_count(sample) == int(len(sample) / CHARS_PER_TOKEN)

    def test_estimate_bytes(self) -> None:
        """Should estimate tokens for bytes."""
        sample = b"Hello world"
        assert estimate_token_count(sample) == int(len(sample) / CHARS_PER_TOKEN)

    def test_empty_string(self) -> None:
        """Should return 0 for empty string."""
        assert estimate_token_count("") == 0

    def test_json_like_content(self) -> None:
        """Should estimate tokens for JSON-like content."""
        payload = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
        count = estimate_token_count(payload)
        assert count > 0
        # Estimate is a pure character-count heuristic, so it must match exactly.
        assert count == int(len(payload) / CHARS_PER_TOKEN)
class TestEstimateResponseTokens:
    """Test estimate_response_tokens function."""

    class MockResponse(BaseModel):
        """Mock Pydantic response model."""

        name: str
        value: int
        items: List[Any]

    def test_estimate_pydantic_model(self) -> None:
        """Should estimate tokens for Pydantic model."""
        payload = self.MockResponse(name="test", value=42, items=[1, 2, 3])
        assert estimate_response_tokens(payload) > 0

    def test_estimate_dict(self) -> None:
        """Should estimate tokens for dict."""
        assert estimate_response_tokens({"name": "test", "value": 42}) > 0

    def test_estimate_list(self) -> None:
        """Should estimate tokens for list."""
        assert estimate_response_tokens([{"name": "item1"}, {"name": "item2"}]) > 0

    def test_estimate_string(self) -> None:
        """Should estimate tokens for string response."""
        assert estimate_response_tokens("Hello world") > 0

    def test_estimate_large_response(self) -> None:
        """Should estimate tokens for large response."""
        big = {"items": [{"name": f"item{i}"} for i in range(1000)]}
        # A 1000-item payload serializes well past 1000 tokens' worth of text.
        assert estimate_response_tokens(big) > 1000
class TestGetResponseSizeBytes:
    """Test get_response_size_bytes function."""

    def test_size_dict(self) -> None:
        """Should return size in bytes for dict."""
        assert get_response_size_bytes({"name": "test"}) > 0

    def test_size_string(self) -> None:
        """Should return size in bytes for string."""
        text = "Hello world"
        # Strings are measured after UTF-8 encoding.
        assert get_response_size_bytes(text) == len(text.encode("utf-8"))

    def test_size_bytes(self) -> None:
        """Should return size for bytes."""
        raw = b"Hello world"
        assert get_response_size_bytes(raw) == len(raw)
class TestExtractQueryParams:
    """Test extract_query_params function."""

    def test_extract_pagination_params(self) -> None:
        """Should extract pagination parameters."""
        extracted = extract_query_params({"page_size": 100, "limit": 50})
        assert extracted["page_size"] == 100
        assert extracted["limit"] == 50

    def test_extract_column_selection(self) -> None:
        """Should extract column selection parameters."""
        extracted = extract_query_params({"select_columns": ["name", "id"]})
        assert extracted["select_columns"] == ["name", "id"]

    def test_extract_from_nested_request(self) -> None:
        """Should extract from nested request object."""
        nested = {"request": {"page_size": 50, "filters": [{"col": "name"}]}}
        extracted = extract_query_params(nested)
        assert extracted["page_size"] == 50
        assert extracted["filters"] == [{"col": "name"}]

    def test_empty_params(self) -> None:
        """Should return empty dict for empty params."""
        assert extract_query_params(None) == {}
        assert extract_query_params({}) == {}

    def test_extract_filters(self) -> None:
        """Should extract filter parameters."""
        extracted = extract_query_params(
            {"filters": [{"col": "name", "opr": "eq", "value": "test"}]}
        )
        assert "filters" in extracted
class TestGenerateSizeReductionSuggestions:
    """Test generate_size_reduction_suggestions function."""

    @staticmethod
    def _suggest(tool_name: str, params: dict[str, Any]) -> list:
        """Generate suggestions for a response 2x over a 25k-token limit."""
        return generate_size_reduction_suggestions(
            tool_name=tool_name,
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )

    def test_suggest_reduce_page_size(self) -> None:
        """Should suggest reducing page_size when present."""
        tips = self._suggest("list_charts", {"page_size": 100})
        assert any("page_size" in t.lower() or "limit" in t.lower() for t in tips)

    def test_suggest_add_limit_for_list_tools(self) -> None:
        """Should suggest adding limit for list tools."""
        tips = self._suggest("list_charts", {})
        assert any("limit" in t.lower() or "page_size" in t.lower() for t in tips)

    def test_suggest_select_columns(self) -> None:
        """Should suggest using select_columns."""
        tips = self._suggest("list_charts", {})
        assert any(
            "select_columns" in t.lower() or "columns" in t.lower() for t in tips
        )

    def test_suggest_filters(self) -> None:
        """Should suggest adding filters."""
        assert any("filter" in t.lower() for t in self._suggest("list_charts", {}))

    def test_tool_specific_suggestions_execute_sql(self) -> None:
        """Should provide SQL-specific suggestions for execute_sql."""
        tips = self._suggest("execute_sql", {"sql": "SELECT * FROM table"})
        assert any("LIMIT" in t or "limit" in t.lower() for t in tips)

    def test_tool_specific_suggestions_list_charts(self) -> None:
        """Should provide chart-specific suggestions for list_charts."""
        tips = self._suggest("list_charts", {})
        # Should suggest excluding params or query_context
        assert any(
            "params" in t.lower() or "query_context" in t.lower() for t in tips
        )

    def test_suggests_search_parameter(self) -> None:
        """Should suggest using search parameter."""
        assert any(
            "search" in t.lower() for t in self._suggest("list_dashboards", {})
        )
class TestFormatSizeLimitError:
    """Test format_size_limit_error function."""

    @staticmethod
    def _error(
        params: dict[str, Any],
        estimated_tokens: int = 50000,
        token_limit: int = 25000,
    ) -> str:
        """Build a size-limit error for list_charts with the given inputs."""
        return format_size_limit_error(
            tool_name="list_charts",
            params=params,
            estimated_tokens=estimated_tokens,
            token_limit=token_limit,
        )

    def test_error_contains_token_counts(self) -> None:
        """Should include token counts in error message."""
        message = self._error({})
        assert "50,000" in message
        assert "25,000" in message

    def test_error_contains_tool_name(self) -> None:
        """Should include tool name in error message."""
        assert "list_charts" in self._error({})

    def test_error_contains_suggestions(self) -> None:
        """Should include suggestions in error message."""
        # Suggestions are rendered as a numbered list starting at "1."
        assert "1." in self._error({"page_size": 100})

    def test_error_contains_reduction_percentage(self) -> None:
        """Should include reduction percentage in error message."""
        message = self._error({})
        # A 2x overage implies a 50% reduction is needed.
        assert "50%" in message or "Reduction" in message

    def test_error_limits_suggestions_to_five(self) -> None:
        """Should limit suggestions to 5."""
        message = self._error({}, estimated_tokens=100000, token_limit=10000)
        numbered = sum(1 for i in range(1, 10) if f"{i}." in message)
        assert numbered <= 5

    def test_error_message_is_readable(self) -> None:
        """Should produce human-readable error message."""
        message = self._error({"page_size": 100}, estimated_tokens=75000)
        # Should be multi-line and contain key information
        assert len(message.split("\n")) > 5
        assert "Response too large" in message
        assert "Please modify your query" in message
class TestCalculatedSuggestions:
    """Test that suggestions include calculated values."""

    def test_suggested_limit_is_calculated(self) -> None:
        """Should calculate suggested limit based on reduction needed."""
        tips = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=50000,  # 2x over limit
            token_limit=25000,
        )
        matching = [t for t in tips if "page_size" in t.lower()]
        assert matching
        hint = matching[0]
        # Should mention the current value (100) and a roughly halved target.
        assert "100" in hint
        assert "50" in hint or "reduction" in hint.lower()

    def test_reduction_percentage_in_suggestions(self) -> None:
        """Should include reduction percentage in suggestions."""
        tips = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=75000,  # 3x over limit
            token_limit=25000,
        )
        # Should mention ~66% reduction needed (int truncation of 66.6%).
        assert "66%" in " ".join(tips)