diff --git a/superset/mcp_service/constants.py b/superset/mcp_service/constants.py new file mode 100644 index 00000000000..7abf91147a8 --- /dev/null +++ b/superset/mcp_service/constants.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Constants for the MCP service.""" + +# Response size guard defaults +DEFAULT_TOKEN_LIMIT = 25_000 # ~25k tokens prevents overwhelming LLM context windows +DEFAULT_WARN_THRESHOLD_PCT = 80 # Log warnings above 80% of limit diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py index 23cbe7d7ee4..86a772aa483 100644 --- a/superset/mcp_service/mcp_config.py +++ b/superset/mcp_service/mcp_config.py @@ -22,6 +22,11 @@ from typing import Any, Dict, Optional from flask import Flask +from superset.mcp_service.constants import ( + DEFAULT_TOKEN_LIMIT, + DEFAULT_WARN_THRESHOLD_PCT, +) + logger = logging.getLogger(__name__) # MCP Service Configuration @@ -167,6 +172,50 @@ MCP_CACHE_CONFIG: Dict[str, Any] = { ], } +# ============================================================================= +# MCP Response Size Guard Configuration +# ============================================================================= +# +# Overview: +# --------- +# The Response Size 
Guard prevents oversized responses from overwhelming LLM +# clients (e.g., Claude Desktop). When a tool response exceeds the token limit, +# it returns a helpful error with suggestions for reducing the response size. +# +# How it works: +# ------------- +# 1. After a tool executes, the middleware estimates the response's token count +# 2. If the response exceeds the configured limit, it blocks the response +# 3. Instead, it returns an error message with smart suggestions: +# - Reduce page_size/limit +# - Use select_columns to exclude large fields +# - Add filters to narrow results +# - Tool-specific recommendations +# +# Configuration: +# -------------- +# - enabled: Toggle the guard on/off (default: True) +# - token_limit: Maximum estimated tokens per response (default: 25,000) +# - excluded_tools: Tools to skip checking (e.g., streaming tools) +# - warn_threshold_pct: Log warnings above this % of limit (default: 80%) +# +# Token Estimation: +# ----------------- +# Uses character-based heuristic (~3.5 chars per token for JSON). +# This is intentionally conservative to avoid underestimating. 
+# ============================================================================= +MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = { + "enabled": True, # Enabled by default to protect LLM clients + "token_limit": DEFAULT_TOKEN_LIMIT, + "warn_threshold_pct": DEFAULT_WARN_THRESHOLD_PCT, + "excluded_tools": [ # Tools to skip size checking + "health_check", # Always small + "get_chart_preview", # Returns URLs, not data + "generate_explore_link", # Returns URLs + "open_sql_lab_with_context", # Returns URLs + ], +} + def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]: """Default MCP auth factory using app.config values.""" diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 7b3508c76a5..cb30b47e5b8 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -27,6 +27,10 @@ from sqlalchemy.exc import OperationalError, TimeoutError from starlette.exceptions import HTTPException from superset.extensions import event_logger +from superset.mcp_service.constants import ( + DEFAULT_TOKEN_LIMIT, + DEFAULT_WARN_THRESHOLD_PCT, +) from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -837,3 +841,169 @@ class FieldPermissionsMiddleware(Middleware): "Unknown response type for field filtering: %s", type(response) ) return response + + +class ResponseSizeGuardMiddleware(Middleware): + """ + Middleware that prevents oversized responses from overwhelming LLM clients. + + When a tool response exceeds the configured token limit, this middleware + intercepts it and returns a helpful error message with suggestions for + reducing the response size. + + This is critical for protecting LLM clients like Claude Desktop which can + crash or become unresponsive when receiving extremely large responses. 
+ + Configuration via MCP_RESPONSE_SIZE_CONFIG in superset_config.py: + - enabled: Toggle the guard on/off (default: True) + - token_limit: Maximum estimated tokens per response (default: 25,000) + - warn_threshold_pct: Log warnings above this % of limit (default: 80%) + - excluded_tools: Tools to skip checking + """ + + def __init__( + self, + token_limit: int = DEFAULT_TOKEN_LIMIT, + warn_threshold_pct: int = DEFAULT_WARN_THRESHOLD_PCT, + excluded_tools: list[str] | str | None = None, + ) -> None: + self.token_limit = token_limit + self.warn_threshold_pct = warn_threshold_pct + self.warn_threshold = int(token_limit * warn_threshold_pct / 100) + if isinstance(excluded_tools, str): + excluded_tools = [excluded_tools] + self.excluded_tools = set(excluded_tools or []) + + async def on_call_tool( + self, + context: MiddlewareContext, + call_next: Callable[[MiddlewareContext], Awaitable[Any]], + ) -> Any: + """Check response size after tool execution.""" + tool_name = getattr(context.message, "name", "unknown") + + # Skip excluded tools + if tool_name in self.excluded_tools: + return await call_next(context) + + # Execute the tool + response = await call_next(context) + + # Estimate response token count (guard against huge responses causing OOM) + from superset.mcp_service.utils.token_utils import ( + estimate_response_tokens, + format_size_limit_error, + ) + + try: + estimated_tokens = estimate_response_tokens(response) + except MemoryError as me: + logger.warning( + "MemoryError while estimating tokens for %s: %s", tool_name, me + ) + # Treat as over limit to avoid further serialization + estimated_tokens = self.token_limit + 1 + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to estimate response tokens for %s: %s", tool_name, e + ) + # Conservative fallback: block rather than risk OOM + estimated_tokens = self.token_limit + 1 + + # Log warning if approaching limit + if estimated_tokens > self.warn_threshold: + logger.warning( + "Response size 
warning for %s: ~%d tokens (%.0f%% of %d limit)", + tool_name, + estimated_tokens, + (estimated_tokens / self.token_limit * 100) if self.token_limit else 0, + self.token_limit, + ) + + # Block if over limit + if estimated_tokens > self.token_limit: + # Extract params for smart suggestions + params = getattr(context.message, "params", {}) or {} + + # Log the blocked response + logger.error( + "Response blocked for %s: ~%d tokens exceeds limit of %d", + tool_name, + estimated_tokens, + self.token_limit, + ) + + # Log to event logger for monitoring + try: + user_id = get_user_id() + event_logger.log( + user_id=user_id, + action="mcp_response_size_exceeded", + curated_payload={ + "tool": tool_name, + "estimated_tokens": estimated_tokens, + "token_limit": self.token_limit, + "params": params, + }, + ) + except Exception as log_error: # noqa: BLE001 + logger.warning("Failed to log size exceeded event: %s", log_error) + + # Generate helpful error message with suggestions + # Avoid passing the full `response` (which may be huge) into the formatter + # to prevent large-memory operations during error formatting. + error_message = format_size_limit_error( + tool_name=tool_name, + params=params, + estimated_tokens=estimated_tokens, + token_limit=self.token_limit, + response=None, + ) + + raise ToolError(error_message) + + return response + + +def create_response_size_guard_middleware() -> ResponseSizeGuardMiddleware | None: + """ + Factory function to create ResponseSizeGuardMiddleware from config. + + Reads configuration from Flask app's MCP_RESPONSE_SIZE_CONFIG. + Returns None if the guard is disabled. 
+ + Returns: + ResponseSizeGuardMiddleware instance or None if disabled + """ + try: + from superset.mcp_service.flask_singleton import get_flask_app + from superset.mcp_service.mcp_config import MCP_RESPONSE_SIZE_CONFIG + + flask_app = get_flask_app() + + # Get config from Flask app, falling back to defaults + config = flask_app.config.get( + "MCP_RESPONSE_SIZE_CONFIG", MCP_RESPONSE_SIZE_CONFIG + ) + + if not config.get("enabled", True): + logger.info("Response size guard is disabled") + return None + + middleware = ResponseSizeGuardMiddleware( + token_limit=int(config.get("token_limit", DEFAULT_TOKEN_LIMIT)), + warn_threshold_pct=int( + config.get("warn_threshold_pct", DEFAULT_WARN_THRESHOLD_PCT) + ), + excluded_tools=config.get("excluded_tools"), + ) + + logger.info( + "Created ResponseSizeGuardMiddleware with token_limit=%d", + middleware.token_limit, + ) + return middleware + + except (ImportError, AttributeError, KeyError) as e: + logger.error("Failed to create ResponseSizeGuardMiddleware: %s", e) + return None diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index c159c6cb3b2..a0d7e7f5d42 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -29,10 +29,8 @@ from typing import Any import uvicorn from superset.mcp_service.app import create_mcp_app, init_fastmcp_server -from superset.mcp_service.mcp_config import ( - get_mcp_factory_config, - MCP_STORE_CONFIG, -) +from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG +from superset.mcp_service.middleware import create_response_size_guard_middleware from superset.mcp_service.storage import _create_redis_store @@ -176,6 +174,12 @@ def run_server( # Build middleware list middleware_list = [] + + # Add response size guard (protects LLM clients from huge responses) + if size_guard_middleware := create_response_size_guard_middleware(): + middleware_list.append(size_guard_middleware) + + # Add caching middleware caching_middleware 
= create_response_caching_middleware() if caching_middleware: middleware_list.append(caching_middleware) diff --git a/superset/mcp_service/utils/token_utils.py b/superset/mcp_service/utils/token_utils.py new file mode 100644 index 00000000000..b7f12a2c1cb --- /dev/null +++ b/superset/mcp_service/utils/token_utils.py @@ -0,0 +1,424 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Token counting and response size utilities for MCP service. + +This module provides utilities to estimate token counts and generate smart +suggestions when responses exceed configured limits. This prevents large +responses from overwhelming LLM clients like Claude Desktop. 
+""" + +from __future__ import annotations + +import logging +from typing import Any, Dict, List, Union + +from pydantic import BaseModel +from typing_extensions import TypeAlias + +logger = logging.getLogger(__name__) + +# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes) +ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes] + +# Approximate characters per token for estimation +# Claude tokenizer averages ~4 chars per token for English text +# JSON tends to be more verbose, so we use a slightly lower ratio +CHARS_PER_TOKEN = 3.5 + + +def estimate_token_count(text: str | bytes) -> int: + """ + Estimate the token count for a given text. + + Uses a character-based heuristic since we don't have direct access to + the actual tokenizer. This is conservative to avoid underestimating. + + Args: + text: The text to estimate tokens for (string or bytes) + + Returns: + Estimated number of tokens + """ + if isinstance(text, bytes): + text = text.decode("utf-8", errors="replace") + + # Simple heuristic: ~3.5 characters per token for JSON/code + text_length = len(text) + if text_length == 0: + return 0 + return max(1, int(text_length / CHARS_PER_TOKEN)) + + +def estimate_response_tokens(response: ToolResponse) -> int: + """ + Estimate token count for an MCP tool response. + + Handles various response types including Pydantic models, dicts, and strings. 
+ + Args: + response: The response object to estimate + + Returns: + Estimated number of tokens + """ + try: + from superset.utils import json + + # Convert response to JSON string for accurate estimation + if hasattr(response, "model_dump"): + # Pydantic model + response_str = json.dumps(response.model_dump()) + elif isinstance(response, (dict, list)): + response_str = json.dumps(response) + elif isinstance(response, bytes): + # Delegate to estimate_token_count which handles decoding safely + return estimate_token_count(response) + elif isinstance(response, str): + response_str = response + else: + response_str = str(response) + + return estimate_token_count(response_str) + except Exception as e: # noqa: BLE001 + logger.warning("Failed to estimate response tokens: %s", e) + # Return a high estimate to be safe (conservative fallback) + return 100000 + + +def get_response_size_bytes(response: ToolResponse) -> int: + """ + Get the size of a response in bytes. + + Args: + response: The response object + + Returns: + Size in bytes + """ + try: + from superset.utils import json + + if hasattr(response, "model_dump"): + response_str = json.dumps(response.model_dump()) + elif isinstance(response, (dict, list)): + response_str = json.dumps(response) + elif isinstance(response, bytes): + return len(response) + elif isinstance(response, str): + return len(response.encode("utf-8")) + else: + response_str = str(response) + + return len(response_str.encode("utf-8")) + except Exception as e: # noqa: BLE001 + logger.warning("Failed to get response size: %s", e) + # Return a conservative large value to avoid allowing oversized responses + # to bypass size checks (returning 0 would underestimate) + return 1_000_000 # 1MB fallback + + +def extract_query_params(params: Dict[str, Any] | None) -> Dict[str, Any]: + """ + Extract relevant query parameters from tool params for suggestions. 
+ + Args: + params: The tool parameters dict + + Returns: + Extracted parameters relevant for size reduction suggestions + """ + if not params: + return {} + + # Handle nested request object (common pattern in MCP tools) + if "request" in params and isinstance(params["request"], dict): + params = params["request"] + + # Keys to extract from params + extract_keys = [ + # Pagination + "page_size", + "limit", + # Column selection + "select_columns", + "columns", + # Filters + "filters", + # Search + "search", + ] + return {k: params[k] for k in extract_keys if k in params} + + +def generate_size_reduction_suggestions( + tool_name: str, + params: Dict[str, Any] | None, + estimated_tokens: int, + token_limit: int, + response: ToolResponse | None = None, +) -> List[str]: + """ + Generate smart suggestions for reducing response size. + + Analyzes the tool and parameters to provide actionable recommendations. + + Args: + tool_name: Name of the MCP tool + params: The tool parameters + estimated_tokens: Estimated token count of the response + token_limit: Configured token limit + response: Optional response object for additional analysis + + Returns: + List of suggestion strings + """ + suggestions = [] + query_params = extract_query_params(params) + reduction_needed = estimated_tokens - token_limit + reduction_pct = ( + int((reduction_needed / estimated_tokens) * 100) if estimated_tokens else 0 + ) + + # Suggestion 1: Reduce page_size or limit + raw_page_size = query_params.get("page_size") or query_params.get("limit") + try: + current_page_size = int(raw_page_size) if raw_page_size is not None else None + except (TypeError, ValueError): + current_page_size = None + if current_page_size and current_page_size > 0: + # Calculate suggested new limit based on reduction needed + suggested_limit = max( + 1, + int(current_page_size * (token_limit / estimated_tokens)) + if estimated_tokens + else 1, + ) + suggestions.append( + f"Reduce page_size/limit from {current_page_size} to 
{suggested_limit} " + f"(need ~{reduction_pct}% reduction)" + ) + else: + # No limit specified - suggest adding one + if "list_" in tool_name or tool_name.startswith("get_chart_data"): + suggestions.append( + f"Add a 'limit' or 'page_size' parameter (suggested: 10-25 items) " + f"to reduce response size by ~{reduction_pct}%" + ) + + # Suggestion 2: Use select_columns to reduce fields + current_columns = query_params.get("select_columns") or query_params.get("columns") + if current_columns and len(current_columns) > 5: + preview = ", ".join(str(c) for c in current_columns[:3]) + suffix = "..." if len(current_columns) > 3 else "" + suggestions.append( + f"Reduce select_columns from {len(current_columns)} columns to only " + f"essential fields (currently: {preview}{suffix})" + ) + elif not current_columns and "list_" in tool_name: + # Analyze response to suggest specific columns to exclude + large_fields = _identify_large_fields(response) + if large_fields: + fields_preview = ", ".join(large_fields[:3]) + suffix = "..." if len(large_fields) > 3 else "" + suggestions.append( + f"Use 'select_columns' to exclude large fields: " + f"{fields_preview}{suffix}. Only request the columns you need." 
+ ) + else: + suggestions.append( + "Use 'select_columns' to request only the specific columns you need " + "instead of fetching all fields" + ) + + # Suggestion 3: Add filters to reduce result set + current_filters = query_params.get("filters") + if not current_filters and "list_" in tool_name: + suggestions.append( + "Add filters to narrow down results (e.g., filter by owner, " + "date range, or specific attributes)" + ) + + # Suggestion 4: Tool-specific suggestions + tool_suggestions = _get_tool_specific_suggestions(tool_name, query_params, response) + suggestions.extend(tool_suggestions) + + # Suggestion 5: Use search instead of listing all + if "list_" in tool_name and not query_params.get("search"): + suggestions.append( + "Use the 'search' parameter to find specific items instead of " + "listing all and filtering client-side" + ) + + return suggestions + + +def _identify_large_fields(response: ToolResponse) -> List[str]: + """ + Identify fields that contribute most to response size. 
+ + Args: + response: The response object to analyze + + Returns: + List of field names that are particularly large + """ + large_fields: List[str] = [] + + try: + from superset.utils import json + + if hasattr(response, "model_dump"): + data = response.model_dump() + elif isinstance(response, dict): + data = response + else: + return large_fields + + # Check for list responses (e.g., charts, dashboards) + items_key = None + for key in ["charts", "dashboards", "datasets", "data", "items", "results"]: + if key in data and isinstance(data[key], list) and data[key]: + items_key = key + break + + if items_key and data[items_key]: + first_item = data[items_key][0] + if isinstance(first_item, dict): + # Analyze field sizes in first item + field_sizes = {} + for field, value in first_item.items(): + if value is not None: + field_sizes[field] = len(json.dumps(value)) + + # Sort by size and identify large fields (>500 chars) + sorted_fields = sorted( + field_sizes.items(), key=lambda x: x[1], reverse=True + ) + large_fields = [ + f + for f, size in sorted_fields + if size > 500 and f not in ("id", "uuid", "name", "slice_name") + ] + + except Exception as e: # noqa: BLE001 + logger.debug("Failed to identify large fields: %s", e) + + return large_fields + + +def _get_tool_specific_suggestions( + tool_name: str, + query_params: Dict[str, Any], + response: ToolResponse, +) -> List[str]: + """ + Generate tool-specific suggestions based on the tool being called. 
+ + Args: + tool_name: Name of the MCP tool + query_params: Extracted query parameters + response: The response object + + Returns: + List of tool-specific suggestions + """ + suggestions = [] + + if tool_name == "get_chart_data": + suggestions.append( + "For get_chart_data, use 'limit' parameter to restrict rows returned, " + "or use 'format=csv' for more compact output" + ) + + elif tool_name == "execute_sql": + suggestions.append( + "Add LIMIT clause to your SQL query to restrict the number of rows " + "(e.g., SELECT * FROM table LIMIT 100)" + ) + + elif tool_name in ("get_chart_info", "get_dashboard_info", "get_dataset_info"): + suggestions.append( + f"For {tool_name}, use 'select_columns' to fetch only specific metadata " + "fields instead of the full object" + ) + + elif tool_name == "list_charts": + suggestions.append( + "For list_charts, exclude 'params' and 'query_context' columns which " + "contain large JSON blobs - use select_columns to pick specific fields" + ) + + elif tool_name == "list_datasets": + suggestions.append( + "For list_datasets, exclude 'columns' and 'metrics' from select_columns " + "if you only need basic dataset info" + ) + + return suggestions + + +def format_size_limit_error( + tool_name: str, + params: Dict[str, Any] | None, + estimated_tokens: int, + token_limit: int, + response: ToolResponse | None = None, +) -> str: + """ + Format a user-friendly error message when response exceeds token limit. 
+ + Args: + tool_name: Name of the MCP tool + params: The tool parameters + estimated_tokens: Estimated token count + token_limit: Configured token limit + response: Optional response for analysis + + Returns: + Formatted error message with suggestions + """ + suggestions = generate_size_reduction_suggestions( + tool_name, params, estimated_tokens, token_limit, response + ) + + error_lines = [ + f"Response too large: ~{estimated_tokens:,} tokens (limit: {token_limit:,})", + "", + "This response would overwhelm the LLM context window.", + "Please modify your query to reduce the response size:", + "", + ] + + for i, suggestion in enumerate(suggestions[:5], 1): # Limit to top 5 suggestions + error_lines.append(f"{i}. {suggestion}") + + reduction_pct = ( + (estimated_tokens - token_limit) / estimated_tokens * 100 + if estimated_tokens + else 0 + ) + error_lines.extend( + [ + "", + f"Tool: {tool_name}", + f"Reduction needed: ~{reduction_pct:.0f}%", + ] + ) + + return "\n".join(error_lines) diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py new file mode 100644 index 00000000000..d380320b790 --- /dev/null +++ b/tests/unit_tests/mcp_service/test_middleware.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for MCP service middleware. +""" + +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastmcp.exceptions import ToolError + +from superset.mcp_service.middleware import ( + create_response_size_guard_middleware, + ResponseSizeGuardMiddleware, +) + + +class TestResponseSizeGuardMiddleware: + """Test ResponseSizeGuardMiddleware class.""" + + def test_init_default_values(self) -> None: + """Should initialize with default values.""" + middleware = ResponseSizeGuardMiddleware() + assert middleware.token_limit == 25_000 + assert middleware.warn_threshold_pct == 80 + assert middleware.warn_threshold == 20000 + assert middleware.excluded_tools == set() + + def test_init_custom_values(self) -> None: + """Should initialize with custom values.""" + middleware = ResponseSizeGuardMiddleware( + token_limit=10000, + warn_threshold_pct=70, + excluded_tools=["health_check", "get_chart_preview"], + ) + assert middleware.token_limit == 10000 + assert middleware.warn_threshold_pct == 70 + assert middleware.warn_threshold == 7000 + assert middleware.excluded_tools == {"health_check", "get_chart_preview"} + + def test_init_excluded_tools_as_string(self) -> None: + """Should handle excluded_tools as a single string.""" + middleware = ResponseSizeGuardMiddleware( + excluded_tools="health_check", + ) + assert middleware.excluded_tools == {"health_check"} + + @pytest.mark.asyncio + async def test_allows_small_response(self) -> None: + """Should allow responses under token limit.""" + middleware = ResponseSizeGuardMiddleware(token_limit=25000) + + # Create mock context + context = MagicMock() + context.message.name = "list_charts" + context.message.params = {} + + # Create mock call_next that returns small response + small_response = {"charts": [{"id": 1, "name": "test"}]} + call_next = AsyncMock(return_value=small_response) + + with ( + 
patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + ): + result = await middleware.on_call_tool(context, call_next) + + assert result == small_response + call_next.assert_called_once_with(context) + + @pytest.mark.asyncio + async def test_blocks_large_response(self) -> None: + """Should block responses over token limit.""" + middleware = ResponseSizeGuardMiddleware(token_limit=100) # Very low limit + + # Create mock context + context = MagicMock() + context.message.name = "list_charts" + context.message.params = {"page_size": 100} + + # Create large response + large_response = { + "charts": [{"id": i, "name": f"chart_{i}"} for i in range(1000)] + } + call_next = AsyncMock(return_value=large_response) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + pytest.raises(ToolError) as exc_info, + ): + await middleware.on_call_tool(context, call_next) + + # Verify error contains helpful information + error_message = str(exc_info.value) + assert "Response too large" in error_message + assert "limit" in error_message.lower() + + @pytest.mark.asyncio + async def test_skips_excluded_tools(self) -> None: + """Should skip checking for excluded tools.""" + middleware = ResponseSizeGuardMiddleware( + token_limit=100, excluded_tools=["health_check"] + ) + + # Create mock context for excluded tool + context = MagicMock() + context.message.name = "health_check" + context.message.params = {} + + # Create response that would exceed limit + large_response = {"data": "x" * 10000} + call_next = AsyncMock(return_value=large_response) + + # Should not raise even though response exceeds limit + result = await middleware.on_call_tool(context, call_next) + assert result == large_response + + @pytest.mark.asyncio + async def test_logs_warning_at_threshold(self) -> None: + """Should log warning when approaching limit.""" + 
middleware = ResponseSizeGuardMiddleware( + token_limit=1000, warn_threshold_pct=80 + ) + + context = MagicMock() + context.message.name = "list_charts" + context.message.params = {} + + # Response at ~85% of limit (should trigger warning but not block) + response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token + call_next = AsyncMock(return_value=response) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + patch("superset.mcp_service.middleware.logger") as mock_logger, + ): + result = await middleware.on_call_tool(context, call_next) + + # Should return response (not blocked) + assert result == response + # Should log warning + mock_logger.warning.assert_called() + + @pytest.mark.asyncio + async def test_error_includes_suggestions(self) -> None: + """Should include suggestions in error message.""" + middleware = ResponseSizeGuardMiddleware(token_limit=100) + + context = MagicMock() + context.message.name = "list_charts" + context.message.params = {"page_size": 100} + + large_response = {"charts": [{"id": i} for i in range(1000)]} + call_next = AsyncMock(return_value=large_response) + + with ( + patch("superset.mcp_service.middleware.get_user_id", return_value=1), + patch("superset.mcp_service.middleware.event_logger"), + pytest.raises(ToolError) as exc_info, + ): + await middleware.on_call_tool(context, call_next) + + error_message = str(exc_info.value) + # Should have numbered suggestions + assert "1." 
        # NOTE(review): chunk begins mid-assertion; the opening of this test
        # method is outside the visible range and is reproduced as a fragment.
        in error_message
        # Should suggest reducing page_size
        assert "page_size" in error_message.lower() or "limit" in error_message.lower()

    @pytest.mark.asyncio
    async def test_logs_size_exceeded_event(self) -> None:
        """Should log to event logger when size exceeded."""
        # Tiny limit guarantees the 10k-char payload below trips the guard.
        middleware = ResponseSizeGuardMiddleware(token_limit=100)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        large_response = {"data": "x" * 10000}
        call_next = AsyncMock(return_value=large_response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger") as mock_event_logger,
            pytest.raises(ToolError),
        ):
            await middleware.on_call_tool(context, call_next)

        # Should log to event logger
        mock_event_logger.log.assert_called()
        call_args = mock_event_logger.log.call_args
        assert call_args.kwargs["action"] == "mcp_response_size_exceeded"


class TestCreateResponseSizeGuardMiddleware:
    """Test create_response_size_guard_middleware factory function."""

    def test_creates_middleware_when_enabled(self) -> None:
        """Should create middleware when enabled in config."""
        mock_config = {
            "enabled": True,
            "token_limit": 30000,
            "warn_threshold_pct": 75,
            "excluded_tools": ["health_check"],
        }

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert isinstance(middleware, ResponseSizeGuardMiddleware)
        assert middleware.token_limit == 30000
        assert middleware.warn_threshold_pct == 75
        assert "health_check" in middleware.excluded_tools

    def test_returns_none_when_disabled(self) -> None:
        """Should return None when disabled in config."""
        mock_config = {"enabled": False}

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None

    def test_uses_defaults_when_config_missing(self) -> None:
        """Should use defaults when config values are missing."""
        mock_config = {"enabled": True}  # Only enabled, no other values

        mock_flask_app = MagicMock()
        mock_flask_app.config.get.return_value = mock_config

        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            return_value=mock_flask_app,
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is not None
        assert middleware.token_limit == 25_000  # Default
        assert middleware.warn_threshold_pct == 80  # Default

    def test_handles_exception_gracefully(self) -> None:
        """Should return None on expected configuration exceptions."""
        with patch(
            "superset.mcp_service.flask_singleton.get_flask_app",
            side_effect=ImportError("Config error"),
        ):
            middleware = create_response_size_guard_middleware()

        assert middleware is None


class TestMiddlewareIntegration:
    """Integration tests for middleware behavior."""

    @pytest.mark.asyncio
    async def test_pydantic_model_response(self) -> None:
        """Should handle Pydantic model responses."""
        from pydantic import BaseModel

        class ChartInfo(BaseModel):
            id: int
            name: str

        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "get_chart_info"
        context.message.params = {}

        response = ChartInfo(id=1, name="Test Chart")
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_list_response(self) -> None:
        """Should handle list responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "list_charts"
        context.message.params = {}

        response = [{"id": 1}, {"id": 2}, {"id": 3}]
        call_next = AsyncMock(return_value=response)

        with (
            patch("superset.mcp_service.middleware.get_user_id", return_value=1),
            patch("superset.mcp_service.middleware.event_logger"),
        ):
            result = await middleware.on_call_tool(context, call_next)

        assert result == response

    @pytest.mark.asyncio
    async def test_string_response(self) -> None:
        """Should handle string responses."""
        middleware = ResponseSizeGuardMiddleware(token_limit=25000)

        context = MagicMock()
        context.message.name = "health_check"
        context.message.params = {}

        response = "OK"
        call_next = AsyncMock(return_value=response)

        result = await middleware.on_call_tool(context, call_next)
        assert result == response
diff --git a/tests/unit_tests/mcp_service/utils/test_token_utils.py b/tests/unit_tests/mcp_service/utils/test_token_utils.py
new file mode 100644
index 00000000000..649536019bb
--- /dev/null
+++ b/tests/unit_tests/mcp_service/utils/test_token_utils.py
@@ -0,0 +1,358 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.
# See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unit tests for MCP service token utilities.
"""

from typing import Any, List

from pydantic import BaseModel

from superset.mcp_service.utils.token_utils import (
    CHARS_PER_TOKEN,
    estimate_response_tokens,
    estimate_token_count,
    extract_query_params,
    format_size_limit_error,
    generate_size_reduction_suggestions,
    get_response_size_bytes,
)


class TestEstimateTokenCount:
    """Test estimate_token_count function."""

    def test_estimate_string(self) -> None:
        """Should estimate tokens for a string."""
        text = "Hello world"
        result = estimate_token_count(text)
        expected = int(len(text) / CHARS_PER_TOKEN)
        assert result == expected

    def test_estimate_bytes(self) -> None:
        """Should estimate tokens for bytes."""
        text = b"Hello world"
        result = estimate_token_count(text)
        expected = int(len(text) / CHARS_PER_TOKEN)
        assert result == expected

    def test_empty_string(self) -> None:
        """Should return 0 for empty string."""
        assert estimate_token_count("") == 0

    def test_json_like_content(self) -> None:
        """Should estimate tokens for JSON-like content."""
        json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
        result = estimate_token_count(json_str)
        assert result > 0
        assert result == int(len(json_str) / CHARS_PER_TOKEN)


class TestEstimateResponseTokens:
    """Test estimate_response_tokens function."""

    class MockResponse(BaseModel):
        """Mock Pydantic response model."""

        name: str
        value: int
        items: List[Any]

    def test_estimate_pydantic_model(self) -> None:
        """Should estimate tokens for Pydantic model."""
        response = self.MockResponse(name="test", value=42, items=[1, 2, 3])
        result = estimate_response_tokens(response)
        assert result > 0

    def test_estimate_dict(self) -> None:
        """Should estimate tokens for dict."""
        response = {"name": "test", "value": 42}
        result = estimate_response_tokens(response)
        assert result > 0

    def test_estimate_list(self) -> None:
        """Should estimate tokens for list."""
        response = [{"name": "item1"}, {"name": "item2"}]
        result = estimate_response_tokens(response)
        assert result > 0

    def test_estimate_string(self) -> None:
        """Should estimate tokens for string response."""
        response = "Hello world"
        result = estimate_response_tokens(response)
        assert result > 0

    def test_estimate_large_response(self) -> None:
        """Should estimate tokens for large response."""
        response = {"items": [{"name": f"item{i}"} for i in range(1000)]}
        result = estimate_response_tokens(response)
        assert result > 1000  # Large response should have many tokens


class TestGetResponseSizeBytes:
    """Test get_response_size_bytes function."""

    def test_size_dict(self) -> None:
        """Should return size in bytes for dict."""
        response = {"name": "test"}
        result = get_response_size_bytes(response)
        assert result > 0

    def test_size_string(self) -> None:
        """Should return size in bytes for string."""
        response = "Hello world"
        result = get_response_size_bytes(response)
        assert result == len(response.encode("utf-8"))

    def test_size_bytes(self) -> None:
        """Should return size for bytes."""
        response = b"Hello world"
        result = get_response_size_bytes(response)
        assert result == len(response)


class TestExtractQueryParams:
    """Test extract_query_params function."""

    def test_extract_pagination_params(self) -> None:
        """Should extract pagination parameters."""
        params = {"page_size": 100, "limit": 50}
        result = extract_query_params(params)
        assert result["page_size"] == 100
        assert result["limit"] == 50

    def test_extract_column_selection(self) -> None:
        """Should extract column selection parameters."""
        params = {"select_columns": ["name", "id"]}
        result = extract_query_params(params)
        assert result["select_columns"] == ["name", "id"]

    def test_extract_from_nested_request(self) -> None:
        """Should extract from nested request object."""
        params = {"request": {"page_size": 50, "filters": [{"col": "name"}]}}
        result = extract_query_params(params)
        assert result["page_size"] == 50
        assert result["filters"] == [{"col": "name"}]

    def test_empty_params(self) -> None:
        """Should return empty dict for empty params."""
        assert extract_query_params(None) == {}
        assert extract_query_params({}) == {}

    def test_extract_filters(self) -> None:
        """Should extract filter parameters."""
        params = {"filters": [{"col": "name", "opr": "eq", "value": "test"}]}
        result = extract_query_params(params)
        assert "filters" in result


class TestGenerateSizeReductionSuggestions:
    """Test generate_size_reduction_suggestions function."""

    def test_suggest_reduce_page_size(self) -> None:
        """Should suggest reducing page_size when present."""
        params = {"page_size": 100}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any(
            "page_size" in s.lower() or "limit" in s.lower() for s in suggestions
        )

    def test_suggest_add_limit_for_list_tools(self) -> None:
        """Should suggest adding limit for list tools."""
        params: dict[str, Any] = {}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any(
            "limit" in s.lower() or "page_size" in s.lower() for s in suggestions
        )

    def test_suggest_select_columns(self) -> None:
        """Should suggest using select_columns."""
        params: dict[str, Any] = {}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any(
            "select_columns" in s.lower() or "columns" in s.lower() for s in suggestions
        )

    def test_suggest_filters(self) -> None:
        """Should suggest adding filters."""
        params: dict[str, Any] = {}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any("filter" in s.lower() for s in suggestions)

    def test_tool_specific_suggestions_execute_sql(self) -> None:
        """Should provide SQL-specific suggestions for execute_sql."""
        suggestions = generate_size_reduction_suggestions(
            tool_name="execute_sql",
            params={"sql": "SELECT * FROM table"},
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any("LIMIT" in s or "limit" in s.lower() for s in suggestions)

    def test_tool_specific_suggestions_list_charts(self) -> None:
        """Should provide chart-specific suggestions for list_charts."""
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params={},
            estimated_tokens=50000,
            token_limit=25000,
        )
        # Should suggest excluding params or query_context
        assert any(
            "params" in s.lower() or "query_context" in s.lower() for s in suggestions
        )

    def test_suggests_search_parameter(self) -> None:
        """Should suggest using search parameter."""
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_dashboards",
            params={},
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert any("search" in s.lower() for s in suggestions)


class TestFormatSizeLimitError:
    """Test format_size_limit_error function."""

    def test_error_contains_token_counts(self) -> None:
        """Should include token counts in error message."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={},
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert "50,000" in error
        assert "25,000" in error

    def test_error_contains_tool_name(self) -> None:
        """Should include tool name in error message."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={},
            estimated_tokens=50000,
            token_limit=25000,
        )
        assert "list_charts" in error

    def test_error_contains_suggestions(self) -> None:
        """Should include suggestions in error message."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=50000,
            token_limit=25000,
        )
        # Should have numbered suggestions
        assert "1." in error

    def test_error_contains_reduction_percentage(self) -> None:
        """Should include reduction percentage in error message."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={},
            estimated_tokens=50000,
            token_limit=25000,
        )
        # 50% reduction needed
        assert "50%" in error or "Reduction" in error

    def test_error_limits_suggestions_to_five(self) -> None:
        """Should limit suggestions to 5."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={},
            estimated_tokens=100000,
            token_limit=10000,
        )
        # Count numbered suggestions (1. through 5.)
        suggestion_count = sum(1 for i in range(1, 10) if f"{i}." in error)
        assert suggestion_count <= 5

    def test_error_message_is_readable(self) -> None:
        """Should produce human-readable error message."""
        error = format_size_limit_error(
            tool_name="list_charts",
            params={"page_size": 100},
            estimated_tokens=75000,
            token_limit=25000,
        )
        # Should be multi-line and contain key information
        lines = error.split("\n")
        assert len(lines) > 5
        assert "Response too large" in error
        assert "Please modify your query" in error


class TestCalculatedSuggestions:
    """Test that suggestions include calculated values."""

    def test_suggested_limit_is_calculated(self) -> None:
        """Should calculate suggested limit based on reduction needed."""
        params = {"page_size": 100}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=50000,  # 2x over limit
            token_limit=25000,
        )
        # Find the page_size suggestion
        page_size_suggestion = next(
            (s for s in suggestions if "page_size" in s.lower()), None
        )
        assert page_size_suggestion is not None
        # Should suggest reducing from 100 to approximately 50
        assert "100" in page_size_suggestion
        assert (
            "50" in page_size_suggestion or "reduction" in page_size_suggestion.lower()
        )

    def test_reduction_percentage_in_suggestions(self) -> None:
        """Should include reduction percentage in suggestions."""
        params = {"page_size": 100}
        suggestions = generate_size_reduction_suggestions(
            tool_name="list_charts",
            params=params,
            estimated_tokens=75000,  # 3x over limit
            token_limit=25000,
        )
        # Should mention ~66% reduction needed (int truncation of 66.6%)
        combined = " ".join(suggestions)
        assert "66%" in combined