mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
feat(mcp): add response size guard to prevent oversized responses (#37200)
This commit is contained in:
21
superset/mcp_service/constants.py
Normal file
21
superset/mcp_service/constants.py
Normal file
@@ -0,0 +1,21 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""Constants for the MCP service."""
|
||||
|
||||
# Response size guard defaults
|
||||
DEFAULT_TOKEN_LIMIT = 25_000 # ~25k tokens prevents overwhelming LLM context windows
|
||||
DEFAULT_WARN_THRESHOLD_PCT = 80 # Log warnings above 80% of limit
|
||||
@@ -22,6 +22,11 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
from superset.mcp_service.constants import (
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_WARN_THRESHOLD_PCT,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# MCP Service Configuration
|
||||
@@ -167,6 +172,50 @@ MCP_CACHE_CONFIG: Dict[str, Any] = {
|
||||
],
|
||||
}
|
||||
|
||||
# =============================================================================
|
||||
# MCP Response Size Guard Configuration
|
||||
# =============================================================================
|
||||
#
|
||||
# Overview:
|
||||
# ---------
|
||||
# The Response Size Guard prevents oversized responses from overwhelming LLM
|
||||
# clients (e.g., Claude Desktop). When a tool response exceeds the token limit,
|
||||
# it returns a helpful error with suggestions for reducing the response size.
|
||||
#
|
||||
# How it works:
|
||||
# -------------
|
||||
# 1. After a tool executes, the middleware estimates the response's token count
|
||||
# 2. If the response exceeds the configured limit, it blocks the response
|
||||
# 3. Instead, it returns an error message with smart suggestions:
|
||||
# - Reduce page_size/limit
|
||||
# - Use select_columns to exclude large fields
|
||||
# - Add filters to narrow results
|
||||
# - Tool-specific recommendations
|
||||
#
|
||||
# Configuration:
|
||||
# --------------
|
||||
# - enabled: Toggle the guard on/off (default: True)
|
||||
# - token_limit: Maximum estimated tokens per response (default: 25,000)
|
||||
# - excluded_tools: Tools to skip checking (e.g., streaming tools)
|
||||
# - warn_threshold_pct: Log warnings above this % of limit (default: 80%)
|
||||
#
|
||||
# Token Estimation:
|
||||
# -----------------
|
||||
# Uses character-based heuristic (~3.5 chars per token for JSON).
|
||||
# This is intentionally conservative to avoid underestimating.
|
||||
# =============================================================================
|
||||
# Default configuration for the response size guard middleware; can be
# overridden via MCP_RESPONSE_SIZE_CONFIG in superset_config.py.
MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
    "enabled": True,  # Enabled by default to protect LLM clients
    "token_limit": DEFAULT_TOKEN_LIMIT,  # Max estimated tokens per response
    "warn_threshold_pct": DEFAULT_WARN_THRESHOLD_PCT,  # Warn above this % of limit
    "excluded_tools": [  # Tools to skip size checking
        "health_check",  # Always small
        "get_chart_preview",  # Returns URLs, not data
        "generate_explore_link",  # Returns URLs
        "open_sql_lab_with_context",  # Returns URLs
    ],
}
|
||||
|
||||
|
||||
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
|
||||
"""Default MCP auth factory using app.config values."""
|
||||
|
||||
@@ -27,6 +27,10 @@ from sqlalchemy.exc import OperationalError, TimeoutError
|
||||
from starlette.exceptions import HTTPException
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.constants import (
|
||||
DEFAULT_TOKEN_LIMIT,
|
||||
DEFAULT_WARN_THRESHOLD_PCT,
|
||||
)
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -837,3 +841,169 @@ class FieldPermissionsMiddleware(Middleware):
|
||||
"Unknown response type for field filtering: %s", type(response)
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
class ResponseSizeGuardMiddleware(Middleware):
    """
    Middleware that prevents oversized responses from overwhelming LLM clients.

    When a tool response exceeds the configured token limit, this middleware
    intercepts it and returns a helpful error message with suggestions for
    reducing the response size.

    This is critical for protecting LLM clients like Claude Desktop which can
    crash or become unresponsive when receiving extremely large responses.

    Configuration via MCP_RESPONSE_SIZE_CONFIG in superset_config.py:
    - enabled: Toggle the guard on/off (default: True)
    - token_limit: Maximum estimated tokens per response (default: 25,000)
    - warn_threshold_pct: Log warnings above this % of limit (default: 80%)
    - excluded_tools: Tools to skip checking
    """

    def __init__(
        self,
        token_limit: int = DEFAULT_TOKEN_LIMIT,
        warn_threshold_pct: int = DEFAULT_WARN_THRESHOLD_PCT,
        excluded_tools: list[str] | str | None = None,
    ) -> None:
        self.token_limit = token_limit
        self.warn_threshold_pct = warn_threshold_pct
        # Pre-compute the absolute warning threshold once so per-call checks
        # are a single integer comparison.
        self.warn_threshold = int(token_limit * warn_threshold_pct / 100)
        # Accept a bare string as a convenience for a single excluded tool.
        if isinstance(excluded_tools, str):
            excluded_tools = [excluded_tools]
        self.excluded_tools = set(excluded_tools or [])

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """Check response size after tool execution."""
        tool_name = getattr(context.message, "name", "unknown")

        # Skip excluded tools
        if tool_name in self.excluded_tools:
            return await call_next(context)

        # Execute the tool
        response = await call_next(context)

        # Estimate response token count (guard against huge responses causing OOM)
        # NOTE(review): imported lazily here — presumably to avoid an import
        # cycle between middleware and token_utils; confirm before hoisting.
        from superset.mcp_service.utils.token_utils import (
            estimate_response_tokens,
            format_size_limit_error,
        )

        try:
            estimated_tokens = estimate_response_tokens(response)
        except MemoryError as me:
            logger.warning(
                "MemoryError while estimating tokens for %s: %s", tool_name, me
            )
            # Treat as over limit to avoid further serialization
            estimated_tokens = self.token_limit + 1
        except Exception as e:  # noqa: BLE001
            logger.warning(
                "Failed to estimate response tokens for %s: %s", tool_name, e
            )
            # Conservative fallback: block rather than risk OOM
            estimated_tokens = self.token_limit + 1

        # Log warning if approaching limit
        if estimated_tokens > self.warn_threshold:
            logger.warning(
                "Response size warning for %s: ~%d tokens (%.0f%% of %d limit)",
                tool_name,
                estimated_tokens,
                # Guard the division: a token_limit of 0 would otherwise raise.
                (estimated_tokens / self.token_limit * 100) if self.token_limit else 0,
                self.token_limit,
            )

        # Block if over limit
        if estimated_tokens > self.token_limit:
            # Extract params for smart suggestions
            # NOTE(review): assumes the tool-call message exposes its arguments
            # as `params` — confirm against the fastmcp message type.
            params = getattr(context.message, "params", {}) or {}

            # Log the blocked response
            logger.error(
                "Response blocked for %s: ~%d tokens exceeds limit of %d",
                tool_name,
                estimated_tokens,
                self.token_limit,
            )

            # Log to event logger for monitoring
            try:
                user_id = get_user_id()
                event_logger.log(
                    user_id=user_id,
                    action="mcp_response_size_exceeded",
                    curated_payload={
                        "tool": tool_name,
                        "estimated_tokens": estimated_tokens,
                        "token_limit": self.token_limit,
                        "params": params,
                    },
                )
            except Exception as log_error:  # noqa: BLE001
                # Telemetry is best-effort: never let logging failures mask
                # the size-limit error we are about to raise.
                logger.warning("Failed to log size exceeded event: %s", log_error)

            # Generate helpful error message with suggestions
            # Avoid passing the full `response` (which may be huge) into the formatter
            # to prevent large-memory operations during error formatting.
            error_message = format_size_limit_error(
                tool_name=tool_name,
                params=params,
                estimated_tokens=estimated_tokens,
                token_limit=self.token_limit,
                response=None,
            )

            raise ToolError(error_message)

        return response
|
||||
|
||||
|
||||
def create_response_size_guard_middleware() -> ResponseSizeGuardMiddleware | None:
    """
    Factory function to create ResponseSizeGuardMiddleware from config.

    Reads configuration from the Flask app's ``MCP_RESPONSE_SIZE_CONFIG``,
    falling back to the module-level defaults for any missing keys.

    Returns:
        ResponseSizeGuardMiddleware instance, or None if the guard is
        disabled or its configuration could not be loaded/parsed.
    """
    try:
        # Imported lazily so this module can be loaded without a Flask app.
        from superset.mcp_service.flask_singleton import get_flask_app
        from superset.mcp_service.mcp_config import MCP_RESPONSE_SIZE_CONFIG

        flask_app = get_flask_app()

        # Get config from Flask app, falling back to defaults
        config = flask_app.config.get(
            "MCP_RESPONSE_SIZE_CONFIG", MCP_RESPONSE_SIZE_CONFIG
        )

        if not config.get("enabled", True):
            logger.info("Response size guard is disabled")
            return None

        middleware = ResponseSizeGuardMiddleware(
            token_limit=int(config.get("token_limit", DEFAULT_TOKEN_LIMIT)),
            warn_threshold_pct=int(
                config.get("warn_threshold_pct", DEFAULT_WARN_THRESHOLD_PCT)
            ),
            excluded_tools=config.get("excluded_tools"),
        )

        logger.info(
            "Created ResponseSizeGuardMiddleware with token_limit=%d",
            middleware.token_limit,
        )
        return middleware

    # TypeError/ValueError cover malformed config values (e.g. a None or
    # non-numeric "token_limit" fed to int()); previously those escaped the
    # handler and crashed server startup instead of disabling the guard.
    except (ImportError, AttributeError, KeyError, TypeError, ValueError) as e:
        logger.error("Failed to create ResponseSizeGuardMiddleware: %s", e)
        return None
|
||||
|
||||
@@ -29,10 +29,8 @@ from typing import Any
|
||||
import uvicorn
|
||||
|
||||
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
|
||||
from superset.mcp_service.mcp_config import (
|
||||
get_mcp_factory_config,
|
||||
MCP_STORE_CONFIG,
|
||||
)
|
||||
from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG
|
||||
from superset.mcp_service.middleware import create_response_size_guard_middleware
|
||||
from superset.mcp_service.storage import _create_redis_store
|
||||
|
||||
|
||||
@@ -176,6 +174,12 @@ def run_server(
|
||||
|
||||
# Build middleware list
|
||||
middleware_list = []
|
||||
|
||||
# Add response size guard (protects LLM clients from huge responses)
|
||||
if size_guard_middleware := create_response_size_guard_middleware():
|
||||
middleware_list.append(size_guard_middleware)
|
||||
|
||||
# Add caching middleware
|
||||
caching_middleware = create_response_caching_middleware()
|
||||
if caching_middleware:
|
||||
middleware_list.append(caching_middleware)
|
||||
|
||||
424
superset/mcp_service/utils/token_utils.py
Normal file
424
superset/mcp_service/utils/token_utils.py
Normal file
@@ -0,0 +1,424 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Token counting and response size utilities for MCP service.
|
||||
|
||||
This module provides utilities to estimate token counts and generate smart
|
||||
suggestions when responses exceed configured limits. This prevents large
|
||||
responses from overwhelming LLM clients like Claude Desktop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
|
||||
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
|
||||
|
||||
# Approximate characters per token for estimation
# Claude tokenizer averages ~4 chars per token for English text
# JSON tends to be more verbose, so we use a slightly lower ratio
CHARS_PER_TOKEN = 3.5


def estimate_token_count(text: str | bytes) -> int:
    """
    Estimate the token count for a given text.

    A character-based heuristic is used because the real tokenizer is not
    available here; the ratio is deliberately conservative so the estimate
    errs on the high side rather than under-counting.

    Args:
        text: The text to estimate tokens for (string or bytes)

    Returns:
        Estimated number of tokens (0 only for empty input)
    """
    # Normalize bytes to text; undecodable sequences are replaced, not raised.
    decoded = (
        text.decode("utf-8", errors="replace") if isinstance(text, bytes) else text
    )
    num_chars = len(decoded)
    if not num_chars:
        return 0
    # Any non-empty input counts as at least one token.
    return max(1, int(num_chars / CHARS_PER_TOKEN))
|
||||
|
||||
|
||||
def estimate_response_tokens(response: ToolResponse) -> int:
    """
    Estimate token count for an MCP tool response.

    Handles various response types including Pydantic models, dicts, and strings.

    Args:
        response: The response object to estimate

    Returns:
        Estimated number of tokens
    """
    try:
        from superset.utils import json

        # Serialize to JSON so the character count reflects what the client
        # would actually receive over the wire.
        if hasattr(response, "model_dump"):
            # Pydantic model
            serialized = json.dumps(response.model_dump())
        elif isinstance(response, (dict, list)):
            serialized = json.dumps(response)
        elif isinstance(response, bytes):
            # estimate_token_count handles byte decoding safely
            return estimate_token_count(response)
        elif isinstance(response, str):
            serialized = response
        else:
            serialized = str(response)

        return estimate_token_count(serialized)
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to estimate response tokens: %s", e)
        # Conservative fallback: report a high estimate so an oversized
        # response cannot slip past the size guard.
        return 100000
|
||||
|
||||
|
||||
def get_response_size_bytes(response: ToolResponse) -> int:
    """
    Get the size of a response in bytes.

    Args:
        response: The response object

    Returns:
        Size in bytes (UTF-8 encoded length for textual/serialized responses)
    """
    try:
        from superset.utils import json

        # bytes/str have an exact byte size; everything else is measured by
        # the size of its JSON (or str) serialization.
        if hasattr(response, "model_dump"):
            serialized = json.dumps(response.model_dump())
        elif isinstance(response, (dict, list)):
            serialized = json.dumps(response)
        elif isinstance(response, bytes):
            return len(response)
        elif isinstance(response, str):
            return len(response.encode("utf-8"))
        else:
            serialized = str(response)

        return len(serialized.encode("utf-8"))
    except Exception as e:  # noqa: BLE001
        logger.warning("Failed to get response size: %s", e)
        # Return a conservative large value to avoid allowing oversized responses
        # to bypass size checks (returning 0 would underestimate)
        return 1_000_000  # 1MB fallback
|
||||
|
||||
|
||||
def extract_query_params(params: Dict[str, Any] | None) -> Dict[str, Any]:
    """
    Extract relevant query parameters from tool params for suggestions.

    Args:
        params: The tool parameters dict

    Returns:
        Extracted parameters relevant for size reduction suggestions
    """
    if not params:
        return {}

    # MCP tools commonly nest their arguments under a "request" object;
    # unwrap it so the relevant keys are found either way.
    nested = params.get("request")
    if isinstance(nested, dict):
        params = nested

    relevant_keys = (
        "page_size",  # pagination
        "limit",  # pagination
        "select_columns",  # column selection
        "columns",  # column selection
        "filters",  # filtering
        "search",  # search
    )
    return {key: params[key] for key in relevant_keys if key in params}
|
||||
|
||||
|
||||
def generate_size_reduction_suggestions(
    tool_name: str,
    params: Dict[str, Any] | None,
    estimated_tokens: int,
    token_limit: int,
    response: ToolResponse | None = None,
) -> List[str]:
    """
    Generate smart suggestions for reducing response size.

    Analyzes the tool and parameters to provide actionable recommendations.

    Args:
        tool_name: Name of the MCP tool
        params: The tool parameters
        estimated_tokens: Estimated token count of the response
        token_limit: Configured token limit
        response: Optional response object for additional analysis

    Returns:
        List of suggestion strings
    """
    advice: List[str] = []
    query_params = extract_query_params(params)

    # How far over the limit we are, as a percentage of the estimated total.
    overshoot = estimated_tokens - token_limit
    reduction_pct = int((overshoot / estimated_tokens) * 100) if estimated_tokens else 0

    is_listing_tool = "list_" in tool_name

    # --- Suggestion 1: shrink (or introduce) the pagination limit ---
    raw_page_size = query_params.get("page_size") or query_params.get("limit")
    try:
        page_size = int(raw_page_size) if raw_page_size is not None else None
    except (TypeError, ValueError):
        page_size = None

    if page_size and page_size > 0:
        # Scale the current page size down proportionally to fit the limit.
        scaled = (
            int(page_size * (token_limit / estimated_tokens)) if estimated_tokens else 1
        )
        advice.append(
            f"Reduce page_size/limit from {page_size} to {max(1, scaled)} "
            f"(need ~{reduction_pct}% reduction)"
        )
    elif is_listing_tool or tool_name.startswith("get_chart_data"):
        # No limit specified - suggest adding one
        advice.append(
            f"Add a 'limit' or 'page_size' parameter (suggested: 10-25 items) "
            f"to reduce response size by ~{reduction_pct}%"
        )

    # --- Suggestion 2: trim the set of requested columns ---
    chosen_columns = query_params.get("select_columns") or query_params.get("columns")
    if chosen_columns and len(chosen_columns) > 5:
        preview = ", ".join(str(c) for c in chosen_columns[:3])
        suffix = "..." if len(chosen_columns) > 3 else ""
        advice.append(
            f"Reduce select_columns from {len(chosen_columns)} columns to only "
            f"essential fields (currently: {preview}{suffix})"
        )
    elif not chosen_columns and is_listing_tool:
        # No explicit column selection: inspect the response for heavy fields.
        heavy_fields = _identify_large_fields(response)
        if heavy_fields:
            fields_preview = ", ".join(heavy_fields[:3])
            suffix = "..." if len(heavy_fields) > 3 else ""
            advice.append(
                f"Use 'select_columns' to exclude large fields: "
                f"{fields_preview}{suffix}. Only request the columns you need."
            )
        else:
            advice.append(
                "Use 'select_columns' to request only the specific columns you need "
                "instead of fetching all fields"
            )

    # --- Suggestion 3: narrow the result set with filters ---
    if not query_params.get("filters") and is_listing_tool:
        advice.append(
            "Add filters to narrow down results (e.g., filter by owner, "
            "date range, or specific attributes)"
        )

    # --- Suggestion 4: tool-specific recommendations ---
    advice.extend(_get_tool_specific_suggestions(tool_name, query_params, response))

    # --- Suggestion 5: search server-side instead of listing everything ---
    if is_listing_tool and not query_params.get("search"):
        advice.append(
            "Use the 'search' parameter to find specific items instead of "
            "listing all and filtering client-side"
        )

    return advice
|
||||
|
||||
|
||||
def _identify_large_fields(response: ToolResponse) -> List[str]:
    """
    Identify fields that contribute most to response size.

    Args:
        response: The response object to analyze

    Returns:
        List of field names that are particularly large
    """
    heavy: List[str] = []

    try:
        from superset.utils import json

        if hasattr(response, "model_dump"):
            payload = response.model_dump()
        elif isinstance(response, dict):
            payload = response
        else:
            # Only dict-like responses can be analyzed field-by-field.
            return heavy

        # Find the first known collection key holding a non-empty list
        # (e.g., charts, dashboards) whose items we can sample.
        collection_key = next(
            (
                key
                for key in ["charts", "dashboards", "datasets", "data", "items", "results"]
                if key in payload and isinstance(payload[key], list) and payload[key]
            ),
            None,
        )

        if collection_key and payload[collection_key]:
            sample = payload[collection_key][0]
            if isinstance(sample, dict):
                # Measure the serialized size of each non-null field in the
                # sample row, then rank fields largest-first.
                sizes = {
                    field: len(json.dumps(value))
                    for field, value in sample.items()
                    if value is not None
                }
                ranked = sorted(sizes.items(), key=lambda item: item[1], reverse=True)
                # Flag fields over 500 serialized chars, excluding identifiers
                # callers almost always need.
                heavy = [
                    field
                    for field, size in ranked
                    if size > 500 and field not in ("id", "uuid", "name", "slice_name")
                ]

    except Exception as e:  # noqa: BLE001
        logger.debug("Failed to identify large fields: %s", e)

    return heavy
|
||||
|
||||
|
||||
def _get_tool_specific_suggestions(
|
||||
tool_name: str,
|
||||
query_params: Dict[str, Any],
|
||||
response: ToolResponse,
|
||||
) -> List[str]:
|
||||
"""
|
||||
Generate tool-specific suggestions based on the tool being called.
|
||||
|
||||
Args:
|
||||
tool_name: Name of the MCP tool
|
||||
query_params: Extracted query parameters
|
||||
response: The response object
|
||||
|
||||
Returns:
|
||||
List of tool-specific suggestions
|
||||
"""
|
||||
suggestions = []
|
||||
|
||||
if tool_name == "get_chart_data":
|
||||
suggestions.append(
|
||||
"For get_chart_data, use 'limit' parameter to restrict rows returned, "
|
||||
"or use 'format=csv' for more compact output"
|
||||
)
|
||||
|
||||
elif tool_name == "execute_sql":
|
||||
suggestions.append(
|
||||
"Add LIMIT clause to your SQL query to restrict the number of rows "
|
||||
"(e.g., SELECT * FROM table LIMIT 100)"
|
||||
)
|
||||
|
||||
elif tool_name in ("get_chart_info", "get_dashboard_info", "get_dataset_info"):
|
||||
suggestions.append(
|
||||
f"For {tool_name}, use 'select_columns' to fetch only specific metadata "
|
||||
"fields instead of the full object"
|
||||
)
|
||||
|
||||
elif tool_name == "list_charts":
|
||||
suggestions.append(
|
||||
"For list_charts, exclude 'params' and 'query_context' columns which "
|
||||
"contain large JSON blobs - use select_columns to pick specific fields"
|
||||
)
|
||||
|
||||
elif tool_name == "list_datasets":
|
||||
suggestions.append(
|
||||
"For list_datasets, exclude 'columns' and 'metrics' from select_columns "
|
||||
"if you only need basic dataset info"
|
||||
)
|
||||
|
||||
return suggestions
|
||||
|
||||
|
||||
def format_size_limit_error(
    tool_name: str,
    params: Dict[str, Any] | None,
    estimated_tokens: int,
    token_limit: int,
    response: ToolResponse | None = None,
) -> str:
    """
    Format a user-friendly error message when response exceeds token limit.

    Args:
        tool_name: Name of the MCP tool
        params: The tool parameters
        estimated_tokens: Estimated token count
        token_limit: Configured token limit
        response: Optional response for analysis

    Returns:
        Formatted error message with suggestions
    """
    tips = generate_size_reduction_suggestions(
        tool_name, params, estimated_tokens, token_limit, response
    )

    lines = [
        f"Response too large: ~{estimated_tokens:,} tokens (limit: {token_limit:,})",
        "",
        "This response would overwhelm the LLM context window.",
        "Please modify your query to reduce the response size:",
        "",
    ]

    # Show at most the five most relevant suggestions, numbered from 1.
    lines.extend(f"{i}. {tip}" for i, tip in enumerate(tips[:5], 1))

    # Guard the division for a zero estimate.
    pct_over = (
        (estimated_tokens - token_limit) / estimated_tokens * 100
        if estimated_tokens
        else 0
    )
    lines += [
        "",
        f"Tool: {tool_name}",
        f"Reduction needed: ~{pct_over:.0f}%",
    ]

    return "\n".join(lines)
|
||||
343
tests/unit_tests/mcp_service/test_middleware.py
Normal file
343
tests/unit_tests/mcp_service/test_middleware.py
Normal file
@@ -0,0 +1,343 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP service middleware.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.middleware import (
|
||||
create_response_size_guard_middleware,
|
||||
ResponseSizeGuardMiddleware,
|
||||
)
|
||||
|
||||
|
||||
class TestResponseSizeGuardMiddleware:
|
||||
"""Test ResponseSizeGuardMiddleware class."""
|
||||
|
||||
def test_init_default_values(self) -> None:
|
||||
"""Should initialize with default values."""
|
||||
middleware = ResponseSizeGuardMiddleware()
|
||||
assert middleware.token_limit == 25_000
|
||||
assert middleware.warn_threshold_pct == 80
|
||||
assert middleware.warn_threshold == 20000
|
||||
assert middleware.excluded_tools == set()
|
||||
|
||||
def test_init_custom_values(self) -> None:
|
||||
"""Should initialize with custom values."""
|
||||
middleware = ResponseSizeGuardMiddleware(
|
||||
token_limit=10000,
|
||||
warn_threshold_pct=70,
|
||||
excluded_tools=["health_check", "get_chart_preview"],
|
||||
)
|
||||
assert middleware.token_limit == 10000
|
||||
assert middleware.warn_threshold_pct == 70
|
||||
assert middleware.warn_threshold == 7000
|
||||
assert middleware.excluded_tools == {"health_check", "get_chart_preview"}
|
||||
|
||||
def test_init_excluded_tools_as_string(self) -> None:
|
||||
"""Should handle excluded_tools as a single string."""
|
||||
middleware = ResponseSizeGuardMiddleware(
|
||||
excluded_tools="health_check",
|
||||
)
|
||||
assert middleware.excluded_tools == {"health_check"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_allows_small_response(self) -> None:
|
||||
"""Should allow responses under token limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=25000)
|
||||
|
||||
# Create mock context
|
||||
context = MagicMock()
|
||||
context.message.name = "list_charts"
|
||||
context.message.params = {}
|
||||
|
||||
# Create mock call_next that returns small response
|
||||
small_response = {"charts": [{"id": 1, "name": "test"}]}
|
||||
call_next = AsyncMock(return_value=small_response)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
):
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
|
||||
assert result == small_response
|
||||
call_next.assert_called_once_with(context)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_blocks_large_response(self) -> None:
|
||||
"""Should block responses over token limit."""
|
||||
middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit
|
||||
|
||||
# Create mock context
|
||||
context = MagicMock()
|
||||
context.message.name = "list_charts"
|
||||
context.message.params = {"page_size": 100}
|
||||
|
||||
# Create large response
|
||||
large_response = {
|
||||
"charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)]
|
||||
}
|
||||
call_next = AsyncMock(return_value=large_response)
|
||||
|
||||
with (
|
||||
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
|
||||
patch("superset.mcp_service.middleware.event_logger"),
|
||||
pytest.raises(ToolError) as exc_info,
|
||||
):
|
||||
await middleware.on_call_tool(context, call_next)
|
||||
|
||||
# Verify error contains helpful information
|
||||
error_message = str(exc_info.value)
|
||||
assert "Response too large" in error_message
|
||||
assert "limit" in error_message.lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skips_excluded_tools(self) -> None:
|
||||
"""Should skip checking for excluded tools."""
|
||||
middleware = ResponseSizeGuardMiddleware(
|
||||
token_limit=100, excluded_tools=["health_check"]
|
||||
)
|
||||
|
||||
# Create mock context for excluded tool
|
||||
context = MagicMock()
|
||||
context.message.name = "health_check"
|
||||
context.message.params = {}
|
||||
|
||||
# Create response that would exceed limit
|
||||
large_response = {"data": "x" * 10000}
|
||||
call_next = AsyncMock(return_value=large_response)
|
||||
|
||||
# Should not raise even though response exceeds limit
|
||||
result = await middleware.on_call_tool(context, call_next)
|
||||
assert result == large_response
|
||||
|
||||
@pytest.mark.asyncio
async def test_logs_warning_at_threshold(self) -> None:
    """Crossing the warn threshold should emit a warning without blocking."""
    guard = ResponseSizeGuardMiddleware(
        token_limit=1000, warn_threshold_pct=80
    )

    ctx = MagicMock()
    ctx.message.name = "list_charts"
    ctx.message.params = {}

    # ~85% of the limit: above the 80% warn line, below the hard cap.
    payload = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
    downstream = AsyncMock(return_value=payload)

    with (
        patch("superset.mcp_service.middleware.get_user_id", return_value=1),
        patch("superset.mcp_service.middleware.event_logger"),
        patch("superset.mcp_service.middleware.logger") as mock_logger,
    ):
        outcome = await guard.on_call_tool(ctx, downstream)

    # The response passes through while a warning is recorded.
    assert outcome == payload
    mock_logger.warning.assert_called()
@pytest.mark.asyncio
async def test_error_includes_suggestions(self) -> None:
    """The size-limit error should carry actionable, numbered suggestions."""
    guard = ResponseSizeGuardMiddleware(token_limit=100)

    ctx = MagicMock()
    ctx.message.name = "list_charts"
    ctx.message.params = {"page_size": 100}

    oversized = {"charts": [{"id": i} for i in range(1000)]}
    downstream = AsyncMock(return_value=oversized)

    with (
        patch("superset.mcp_service.middleware.get_user_id", return_value=1),
        patch("superset.mcp_service.middleware.event_logger"),
        pytest.raises(ToolError) as raised,
    ):
        await guard.on_call_tool(ctx, downstream)

    message = str(raised.value)
    # Suggestions are formatted as a numbered list ...
    assert "1." in message
    # ... and should point at shrinking the page size or applying a limit.
    assert "page_size" in message.lower() or "limit" in message.lower()
@pytest.mark.asyncio
async def test_logs_size_exceeded_event(self) -> None:
    """Blocking an oversized response should record an audit event."""
    guard = ResponseSizeGuardMiddleware(token_limit=100)

    ctx = MagicMock()
    ctx.message.name = "list_charts"
    ctx.message.params = {}

    downstream = AsyncMock(return_value={"data": "x" * 10000})

    with (
        patch("superset.mcp_service.middleware.get_user_id", return_value=1),
        patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
        pytest.raises(ToolError),
    ):
        await guard.on_call_tool(ctx, downstream)

    # The event logger must be invoked with the dedicated action name.
    mock_event_logger.log.assert_called()
    logged = mock_event_logger.log.call_args
    assert logged.kwargs["action"] == "mcp_response_size_exceeded"
|
||||
class TestCreateResponseSizeGuardMiddleware:
    """Tests for the create_response_size_guard_middleware factory."""

    def test_creates_middleware_when_enabled(self) -> None:
        """A complete config should produce a fully configured guard."""
        fake_app = MagicMock()
        fake_app.config.get.return_value = {
            "enabled": True,
            "token_limit": 30000,
            "warn_threshold_pct": 75,
            "excluded_tools": ["health_check"],
        }

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=fake_app,
        ):
            guard = create_response_size_guard_middleware()

        assert guard is not None
        assert isinstance(guard, ResponseSizeGuardMiddleware)
        assert guard.token_limit == 30000
        assert guard.warn_threshold_pct == 75
        assert "health_check" in guard.excluded_tools

    def test_returns_none_when_disabled(self) -> None:
        """A disabled config should yield no middleware at all."""
        fake_app = MagicMock()
        fake_app.config.get.return_value = {"enabled": False}

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=fake_app,
        ):
            assert create_response_size_guard_middleware() is None

    def test_uses_defaults_when_config_missing(self) -> None:
        """Missing config keys should fall back to the documented defaults."""
        fake_app = MagicMock()
        fake_app.config.get.return_value = {"enabled": True}  # nothing else set

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=fake_app,
        ):
            guard = create_response_size_guard_middleware()

        assert guard is not None
        assert guard.token_limit == 25_000  # default
        assert guard.warn_threshold_pct == 80  # default

    def test_handles_exception_gracefully(self) -> None:
        """Expected configuration-time failures should degrade to None."""
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            side_effect=ImportError("Config error"),
        ):
            assert create_response_size_guard_middleware() is None
|
||||
class TestMiddlewareIntegration:
    """Pass-through behaviour of the guard for assorted payload types."""

    @pytest.mark.asyncio
    async def test_pydantic_model_response(self) -> None:
        """A Pydantic model payload under the limit passes through."""
        from pydantic import BaseModel

        class ChartInfo(BaseModel):
            id: int
            name: str

        guard = ResponseSizeGuardMiddleware(token_limit=25000)

        ctx = MagicMock()
        ctx.message.name = "get_chart_info"
        ctx.message.params = {}

        payload = ChartInfo(id=1, name="Test Chart")
        downstream = AsyncMock(return_value=payload)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            outcome = await guard.on_call_tool(ctx, downstream)

        assert outcome == payload

    @pytest.mark.asyncio
    async def test_list_response(self) -> None:
        """A plain list payload under the limit passes through."""
        guard = ResponseSizeGuardMiddleware(token_limit=25000)

        ctx = MagicMock()
        ctx.message.name = "list_charts"
        ctx.message.params = {}

        payload = [{"id": 1}, {"id": 2}, {"id": 3}]
        downstream = AsyncMock(return_value=payload)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            outcome = await guard.on_call_tool(ctx, downstream)

        assert outcome == payload

    @pytest.mark.asyncio
    async def test_string_response(self) -> None:
        """A bare string payload passes through with no collaborators patched."""
        guard = ResponseSizeGuardMiddleware(token_limit=25000)

        ctx = MagicMock()
        ctx.message.name = "health_check"
        ctx.message.params = {}

        payload = "OK"
        downstream = AsyncMock(return_value=payload)

        outcome = await guard.on_call_tool(ctx, downstream)
        assert outcome == payload
358
tests/unit_tests/mcp_service/utils/test_token_utils.py
Normal file
358
tests/unit_tests/mcp_service/utils/test_token_utils.py
Normal file
@@ -0,0 +1,358 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP service token utilities.
|
||||
"""
|
||||
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from superset.mcp_service.utils.token_utils import (
|
||||
CHARS_PER_TOKEN,
|
||||
estimate_response_tokens,
|
||||
estimate_token_count,
|
||||
extract_query_params,
|
||||
format_size_limit_error,
|
||||
generate_size_reduction_suggestions,
|
||||
get_response_size_bytes,
|
||||
)
|
||||
|
||||
|
||||
class TestEstimateTokenCount:
    """Tests for estimate_token_count."""

    def test_estimate_string(self) -> None:
        """A plain string is estimated at len / CHARS_PER_TOKEN."""
        sample = "Hello world"
        assert estimate_token_count(sample) == int(len(sample) / CHARS_PER_TOKEN)

    def test_estimate_bytes(self) -> None:
        """Bytes input uses the same length-based heuristic."""
        raw = b"Hello world"
        assert estimate_token_count(raw) == int(len(raw) / CHARS_PER_TOKEN)

    def test_empty_string(self) -> None:
        """An empty string costs zero tokens."""
        assert estimate_token_count("") == 0

    def test_json_like_content(self) -> None:
        """JSON-shaped text follows the same chars-per-token rule."""
        payload = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
        estimate = estimate_token_count(payload)
        assert estimate > 0
        assert estimate == int(len(payload) / CHARS_PER_TOKEN)
||||
|
||||
class TestEstimateResponseTokens:
    """Tests for estimate_response_tokens across payload shapes."""

    class MockResponse(BaseModel):
        """Minimal Pydantic model standing in for a real tool response."""

        name: str
        value: int
        items: List[Any]

    def test_estimate_pydantic_model(self) -> None:
        """A Pydantic model yields a positive token estimate."""
        payload = self.MockResponse(name="test", value=42, items=[1, 2, 3])
        assert estimate_response_tokens(payload) > 0

    def test_estimate_dict(self) -> None:
        """A dict yields a positive token estimate."""
        assert estimate_response_tokens({"name": "test", "value": 42}) > 0

    def test_estimate_list(self) -> None:
        """A list yields a positive token estimate."""
        assert estimate_response_tokens([{"name": "item1"}, {"name": "item2"}]) > 0

    def test_estimate_string(self) -> None:
        """A bare string yields a positive token estimate."""
        assert estimate_response_tokens("Hello world") > 0

    def test_estimate_large_response(self) -> None:
        """A large payload should dominate the token estimate."""
        payload = {"items": [{"name": f"item{i}"} for i in range(1000)]}
        # 1000 records serialize to well over 1000 tokens' worth of text.
        assert estimate_response_tokens(payload) > 1000
||||
|
||||
class TestGetResponseSizeBytes:
    """Tests for get_response_size_bytes."""

    def test_size_dict(self) -> None:
        """A dict has a positive serialized byte size."""
        assert get_response_size_bytes({"name": "test"}) > 0

    def test_size_string(self) -> None:
        """A string's size equals its UTF-8 encoded length."""
        text = "Hello world"
        assert get_response_size_bytes(text) == len(text.encode("utf-8"))

    def test_size_bytes(self) -> None:
        """Raw bytes report their own length."""
        blob = b"Hello world"
        assert get_response_size_bytes(blob) == len(blob)
|
||||
class TestExtractQueryParams:
    """Tests for extract_query_params."""

    def test_extract_pagination_params(self) -> None:
        """Pagination keys are carried through unchanged."""
        extracted = extract_query_params({"page_size": 100, "limit": 50})
        assert extracted["page_size"] == 100
        assert extracted["limit"] == 50

    def test_extract_column_selection(self) -> None:
        """Column-selection keys are carried through unchanged."""
        extracted = extract_query_params({"select_columns": ["name", "id"]})
        assert extracted["select_columns"] == ["name", "id"]

    def test_extract_from_nested_request(self) -> None:
        """Keys inside a nested ``request`` object are hoisted out."""
        extracted = extract_query_params(
            {"request": {"page_size": 50, "filters": [{"col": "name"}]}}
        )
        assert extracted["page_size"] == 50
        assert extracted["filters"] == [{"col": "name"}]

    def test_empty_params(self) -> None:
        """Both None and an empty dict yield an empty result."""
        assert extract_query_params(None) == {}
        assert extract_query_params({}) == {}

    def test_extract_filters(self) -> None:
        """Filter definitions are preserved in the output."""
        extracted = extract_query_params(
            {"filters": [{"col": "name", "opr": "eq", "value": "test"}]}
        )
        assert "filters" in extracted
|
||||
class TestGenerateSizeReductionSuggestions:
    """Tests for generate_size_reduction_suggestions."""

    @staticmethod
    def _suggest(tool_name: str, params: dict) -> Any:
        # Every case below uses the same 2x-over-budget scenario
        # (50k estimated tokens against a 25k limit); centralize the call.
        return generate_size_reduction_suggestions(
            tool_name=tool_name,
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )

    def test_suggest_reduce_page_size(self) -> None:
        """An explicit page_size should trigger a shrink-page-size hint."""
        hints = self._suggest("list_charts", {"page_size": 100})
        assert any(
            "page_size" in h.lower() or "limit" in h.lower() for h in hints
        )

    def test_suggest_add_limit_for_list_tools(self) -> None:
        """List tools without pagination should be told to add one."""
        hints = self._suggest("list_charts", {})
        assert any(
            "limit" in h.lower() or "page_size" in h.lower() for h in hints
        )

    def test_suggest_select_columns(self) -> None:
        """Column narrowing should appear among the hints."""
        hints = self._suggest("list_charts", {})
        assert any(
            "select_columns" in h.lower() or "columns" in h.lower() for h in hints
        )

    def test_suggest_filters(self) -> None:
        """Filtering should appear among the hints."""
        hints = self._suggest("list_charts", {})
        assert any("filter" in h.lower() for h in hints)

    def test_tool_specific_suggestions_execute_sql(self) -> None:
        """execute_sql should get SQL-flavoured advice (a LIMIT clause)."""
        hints = self._suggest("execute_sql", {"sql": "SELECT * FROM table"})
        assert any("LIMIT" in h or "limit" in h.lower() for h in hints)

    def test_tool_specific_suggestions_list_charts(self) -> None:
        """list_charts should be told to drop heavy fields."""
        hints = self._suggest("list_charts", {})
        # Either chart params or query_context should be called out.
        assert any(
            "params" in h.lower() or "query_context" in h.lower() for h in hints
        )

    def test_suggests_search_parameter(self) -> None:
        """Dashboard listings should be pointed at the search parameter."""
        hints = self._suggest("list_dashboards", {})
        assert any("search" in h.lower() for h in hints)
|
||||
class TestFormatSizeLimitError:
    """Tests for format_size_limit_error."""

    @staticmethod
    def _format(
        params: Any = None,
        estimated_tokens: int = 50000,
        token_limit: int = 25000,
    ) -> str:
        # Shared invocation; every case targets the list_charts tool.
        return format_size_limit_error(
            tool_name="list_charts",
            params=params if params is not None else {},
            estimated_tokens=estimated_tokens,
            token_limit=token_limit,
        )

    def test_error_contains_token_counts(self) -> None:
        """Both the estimate and the limit appear, comma-formatted."""
        message = self._format()
        assert "50,000" in message
        assert "25,000" in message

    def test_error_contains_tool_name(self) -> None:
        """The offending tool is named in the message."""
        assert "list_charts" in self._format()

    def test_error_contains_suggestions(self) -> None:
        """The message carries a numbered suggestion list."""
        message = self._format(params={"page_size": 100})
        assert "1." in message

    def test_error_contains_reduction_percentage(self) -> None:
        """A 2x overrun should surface the 50% reduction figure."""
        message = self._format()
        assert "50%" in message or "Reduction" in message

    def test_error_limits_suggestions_to_five(self) -> None:
        """No more than five numbered suggestions are emitted."""
        message = self._format(estimated_tokens=100000, token_limit=10000)
        numbered = sum(1 for i in range(1, 10) if f"{i}." in message)
        assert numbered <= 5

    def test_error_message_is_readable(self) -> None:
        """The message is multi-line and states the key facts."""
        message = self._format(params={"page_size": 100}, estimated_tokens=75000)
        assert len(message.split("\n")) > 5
        assert "Response too large" in message
        assert "Please modify your query" in message
||||
class TestCalculatedSuggestions:
    """Suggestions should embed concrete, calculated target values."""

    def test_suggested_limit_is_calculated(self) -> None:
        """With a 2x overrun, page_size=100 should suggest roughly 50."""
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=50000,  # double the allowed budget
            token_limit=25000,
        )
        # Locate the page_size-specific hint.
        hint = next((s for s in suggestions if "page_size" in s.lower()), None)
        assert hint is not None
        # The current value is echoed, and a halved target (or the
        # reduction itself) is proposed.
        assert "100" in hint
        assert "50" in hint or "reduction" in hint.lower()

    def test_reduction_percentage_in_suggestions(self) -> None:
        """A 3x overrun should surface a ~66% reduction figure."""
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=75000,  # triple the allowed budget
            token_limit=25000,
        )
        # int truncation of 66.6% -> "66%" somewhere in the combined text.
        assert "66%" in " ".join(suggestions)
||||
Reference in New Issue
Block a user