def _try_truncate_info_response(
    self,
    tool_name: str,
    response: Any,
    estimated_tokens: int,
) -> Any | None:
    """Attempt to dynamically truncate an info tool response to fit the limit.

    The truncation markers (``_response_truncated`` / ``_truncation_notes``)
    are attached *before* the final size check, so the payload that is
    measured against ``self.token_limit`` is exactly the payload returned
    to the caller.

    Args:
        tool_name: Name of the tool that produced the response; used in
            log messages and in the logged event payload.
        response: The oversized tool response to truncate.
        estimated_tokens: Token estimate of the original response; used in
            log messages and in the logged event payload.

    Returns:
        The truncated response when truncation happened and the result
        (markers included) fits within ``self.token_limit``; ``None`` when
        truncation failed, changed nothing, or was not enough.
    """
    from superset.mcp_service.utils.token_utils import (
        estimate_response_tokens,
        truncate_oversized_response,
    )

    try:
        truncated, was_truncated, notes = truncate_oversized_response(
            response, self.token_limit
        )
    except (MemoryError, RecursionError) as trunc_error:
        logger.warning(
            "Truncation failed for %s due to %s: %s",
            tool_name,
            type(trunc_error).__name__,
            trunc_error,
        )
        return None

    if not was_truncated:
        return None

    # Attach the markers first: one note is recorded per truncated field, so
    # the notes themselves can be sizeable and must count towards the limit.
    if isinstance(truncated, dict):
        truncated["_response_truncated"] = True
        # Copy so the response does not share a list with the logged payload.
        truncated["_truncation_notes"] = list(notes)

    truncated_tokens = estimate_response_tokens(truncated)
    if truncated_tokens > self.token_limit:
        return None

    logger.warning(
        "Response for %s truncated from ~%d to ~%d tokens (limit: %d). Fields: %s",
        tool_name,
        estimated_tokens,
        truncated_tokens,
        self.token_limit,
        "; ".join(notes),
    )

    # Event logging is best-effort: a logging failure must not turn a
    # successfully truncated response into an error.
    try:
        user_id = get_user_id()
        event_logger.log(
            user_id=user_id,
            action="mcp_response_truncated",
            curated_payload={
                "tool": tool_name,
                "original_tokens": estimated_tokens,
                "truncated_tokens": truncated_tokens,
                "token_limit": self.token_limit,
                "truncation_notes": notes,
            },
        )
    except Exception as log_error:  # noqa: BLE001
        logger.warning("Failed to log truncation event: %s", log_error)

    return truncated
error_message = format_size_limit_error( tool_name=tool_name, params=params, diff --git a/superset/mcp_service/utils/token_utils.py b/superset/mcp_service/utils/token_utils.py index d4891634a8e..423bbc6a224 100644 --- a/superset/mcp_service/utils/token_utils.py +++ b/superset/mcp_service/utils/token_utils.py @@ -382,6 +382,210 @@ def _get_tool_specific_suggestions( return suggestions +# Tools eligible for dynamic response truncation instead of hard blocking. +# These tools return single objects (not paginated lists) where truncation +# is preferable to returning an error. +INFO_TOOLS = frozenset( + { + "get_chart_info", + "get_dataset_info", + "get_dashboard_info", + "get_instance_info", + } +) + +# Maximum character length for string fields before truncation +_MAX_STRING_CHARS = 500 +# Maximum items to keep in list fields before truncation +_MAX_LIST_ITEMS = 30 +# Maximum keys to keep when summarizing large dict fields +_MAX_DICT_KEYS = 20 + + +def _truncate_strings( + data: Dict[str, Any], notes: List[str], max_chars: int = _MAX_STRING_CHARS +) -> bool: + """Truncate string fields exceeding max_chars at the top level only.""" + changed = False + for key, value in data.items(): + if isinstance(value, str) and len(value) > max_chars: + original_len = len(value) + data[key] = value[:max_chars] + f"... [truncated from {original_len} chars]" + notes.append(f"Field '{key}' truncated from {original_len} chars") + changed = True + return changed + + +def _truncate_strings_recursive( + data: Any, + notes: List[str], + max_chars: int = _MAX_STRING_CHARS, + path: str = "", + _depth: int = 0, +) -> bool: + """Recursively truncate strings throughout the entire data tree. + + Walks nested dicts and list items to catch strings like + ``charts[0].description`` that top-level truncation misses. + Depth is capped at 10 to avoid runaway recursion. 
+ """ + if _depth > 10: + return False + changed = False + if isinstance(data, dict): + for key, value in data.items(): + field_path = f"{path}.{key}" if path else key + if isinstance(value, str) and len(value) > max_chars: + original_len = len(value) + data[key] = ( + value[:max_chars] + f"... [truncated from {original_len} chars]" + ) + notes.append( + f"Field '{field_path}' truncated from {original_len} chars" + ) + changed = True + elif isinstance(value, (dict, list)): + changed |= _truncate_strings_recursive( + value, notes, max_chars, field_path, _depth + 1 + ) + elif isinstance(data, list): + for i, item in enumerate(data): + if isinstance(item, (dict, list)): + changed |= _truncate_strings_recursive( + item, notes, max_chars, f"{path}[{i}]", _depth + 1 + ) + return changed + + +def _truncate_lists(data: Dict[str, Any], notes: List[str], max_items: int) -> bool: + """Truncate list fields exceeding max_items. Returns True if any truncated. + + Does NOT append marker objects into the list to preserve the element type + contract (e.g. ``List[TableColumnInfo]`` stays homogeneous). Truncation + metadata is communicated through the *notes* list and top-level response + fields ``_response_truncated`` / ``_truncation_notes``. + """ + changed = False + for key, value in data.items(): + if isinstance(value, list) and len(value) > max_items: + original_len = len(value) + data[key] = value[:max_items] + notes.append( + f"Field '{key}' truncated from {original_len} to {max_items} items" + ) + changed = True + return changed + + +def _summarize_large_dicts( + data: Dict[str, Any], notes: List[str], max_keys: int = _MAX_DICT_KEYS +) -> bool: + """Replace large dict fields with key summaries. Returns True if any changed.""" + changed = False + for key, value in data.items(): + if isinstance(value, dict) and len(value) > max_keys: + keys_list = list(value.keys())[:max_keys] + data[key] = { + "_truncated": True, + "_message": ( + f"Dict with {len(value)} keys truncated. 
" + f"Keys: {', '.join(str(k) for k in keys_list)}..." + ), + } + notes.append(f"Field '{key}' dict summarized ({len(value)} keys)") + changed = True + return changed + + +def _replace_collections_with_summaries(data: Dict[str, Any], notes: List[str]) -> bool: + """Replace all non-empty list/dict fields with empty/minimal values. + + Lists are emptied (preserving the list type) rather than replaced with + marker objects to avoid breaking typed list contracts. + """ + changed = False + for key, value in list(data.items()): + if not isinstance(value, (list, dict)) or not value: + continue + count = len(value) + if isinstance(value, list): + data[key] = [] + notes.append(f"Field '{key}' list ({count} items) cleared to fit limit") + else: + data[key] = {} + notes.append(f"Field '{key}' dict ({count} keys) cleared to fit limit") + changed = True + return changed + + +def _is_under_limit(data: Dict[str, Any], token_limit: int) -> bool: + """Check if the serialized data fits within the token limit.""" + from superset.utils import json as utils_json + + return estimate_token_count(utils_json.dumps(data)) <= token_limit + + +def truncate_oversized_response( + response: ToolResponse, + token_limit: int, +) -> tuple[ToolResponse, bool, list[str]]: + """ + Dynamically truncate large fields in a response to fit within the token limit. + + Applies five progressive phases of truncation: + 1. Truncate long top-level string fields + 2. Truncate large list fields to _MAX_LIST_ITEMS + 3. Recursively truncate strings in nested structures (list items, nested dicts) + 4. Aggressively reduce lists to 10 items and summarize large dicts + 5. Replace all collections with empty values + + Args: + response: The tool response (Pydantic model, dict, or other). + token_limit: Maximum estimated tokens allowed. + + Returns: + A tuple of (possibly-truncated response, was_truncated, list of notes). 
+ """ + notes: list[str] = [] + + # Convert to a mutable dict for manipulation + if hasattr(response, "model_dump"): + data = response.model_dump() + elif isinstance(response, dict): + data = dict(response) + else: + return response, False, notes + + was_truncated = False + + # Phase 1: Truncate long string fields + was_truncated |= _truncate_strings(data, notes) + if _is_under_limit(data, token_limit): + return data, was_truncated, notes + + # Phase 2: Truncate large list fields + was_truncated |= _truncate_lists(data, notes, _MAX_LIST_ITEMS) + if _is_under_limit(data, token_limit): + return data, was_truncated, notes + + # Phase 3: Recursively truncate strings inside nested structures + # (e.g. charts[i].description, native_filters[i].config, etc.) + was_truncated |= _truncate_strings_recursive(data, notes) + if _is_under_limit(data, token_limit): + return data, was_truncated, notes + + # Phase 4: Aggressively reduce lists and summarize large dicts + was_truncated |= _truncate_lists(data, notes, max_items=10) + was_truncated |= _summarize_large_dicts(data, notes) + if _is_under_limit(data, token_limit): + return data, was_truncated, notes + + # Phase 5: Nuclear — replace all collections with empty values + was_truncated |= _replace_collections_with_summaries(data, notes) + + return data, was_truncated, notes + + def format_size_limit_error( tool_name: str, params: Dict[str, Any] | None, diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py index d380320b790..bcc164e048f 100644 --- a/tests/unit_tests/mcp_service/test_middleware.py +++ b/tests/unit_tests/mcp_service/test_middleware.py @@ -207,6 +207,106 @@ class TestResponseSizeGuardMiddleware: call_args = mock_event_logger.log.call_args assert call_args.kwargs["action"] == "mcp_response_size_exceeded" + @pytest.mark.asyncio + async def test_truncates_info_tool_instead_of_blocking(self) -> None: + """Should truncate info tool responses instead of blocking 
@pytest.mark.asyncio
async def test_truncates_info_tool_instead_of_blocking(self) -> None:
    """An oversized info-tool payload comes back truncated, not rejected."""
    guard = ResponseSizeGuardMiddleware(token_limit=500)

    ctx = MagicMock()
    ctx.message.name = "get_dataset_info"
    ctx.message.params = {}

    # A single huge description dominates the payload size.
    oversized = {"id": 1, "table_name": "test", "description": "x" * 50000}
    next_handler = AsyncMock(return_value=oversized)

    user_patch = patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    event_patch = patch("superset.mcp_service.middleware.event_logger")
    with user_patch, event_patch:
        outcome = await guard.on_call_tool(ctx, next_handler)

    # A dict is returned; no ToolError was raised on the way.
    assert isinstance(outcome, dict)
    assert outcome["id"] == 1
    assert outcome["_response_truncated"] is True
    assert "[truncated" in outcome["description"]


@pytest.mark.asyncio
async def test_truncates_chart_info_with_large_form_data(self) -> None:
    """get_chart_info with an oversized form_data dict is truncated."""
    guard = ResponseSizeGuardMiddleware(token_limit=500)

    ctx = MagicMock()
    ctx.message.name = "get_chart_info"
    ctx.message.params = {}

    form_data = {f"key_{i}": f"value_{i}" for i in range(100)}
    oversized = {"id": 1, "slice_name": "My Chart", "form_data": form_data}
    next_handler = AsyncMock(return_value=oversized)

    user_patch = patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    event_patch = patch("superset.mcp_service.middleware.event_logger")
    with user_patch, event_patch:
        outcome = await guard.on_call_tool(ctx, next_handler)

    assert isinstance(outcome, dict)
    assert outcome["id"] == 1
    assert outcome["_response_truncated"] is True


@pytest.mark.asyncio
async def test_still_blocks_non_info_tools(self) -> None:
    """Tools outside the info set are still rejected when over the limit."""
    guard = ResponseSizeGuardMiddleware(token_limit=100)

    ctx = MagicMock()
    ctx.message.name = "list_charts"  # Not an info tool
    ctx.message.params = {}

    next_handler = AsyncMock(return_value={"data": "x" * 10000})

    user_patch = patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    event_patch = patch("superset.mcp_service.middleware.event_logger")
    with user_patch, event_patch:
        with pytest.raises(ToolError):
            await guard.on_call_tool(ctx, next_handler)


@pytest.mark.asyncio
async def test_logs_truncation_event(self) -> None:
    """A successful truncation emits an mcp_response_truncated event."""
    guard = ResponseSizeGuardMiddleware(token_limit=500)

    ctx = MagicMock()
    ctx.message.name = "get_dashboard_info"
    ctx.message.params = {}

    next_handler = AsyncMock(return_value={"id": 1, "description": "x" * 50000})

    user_patch = patch("superset.mcp_service.middleware.get_user_id", return_value=1)
    event_patch = patch("superset.mcp_service.middleware.event_logger")
    with user_patch, event_patch as mock_event_logger:
        await guard.on_call_tool(ctx, next_handler)

    # The logged event is the truncation one (not size_exceeded).
    mock_event_logger.log.assert_called()
    logged_kwargs = mock_event_logger.log.call_args.kwargs
    assert logged_kwargs["action"] == "mcp_response_truncated"
class TestInfoToolsSet:
    """Membership checks for the INFO_TOOLS constant."""

    def test_info_tools_contains_expected_tools(self) -> None:
        """Every info tool is a member of the set."""
        expected = (
            "get_chart_info",
            "get_dataset_info",
            "get_dashboard_info",
            "get_instance_info",
        )
        for tool in expected:
            assert tool in INFO_TOOLS

    def test_info_tools_does_not_contain_list_tools(self) -> None:
        """List and write tools are not members of the set."""
        for tool in ("list_charts", "execute_sql", "generate_chart"):
            assert tool not in INFO_TOOLS


class TestTruncateStrings:
    """Checks for the top-level _truncate_strings helper."""

    def test_truncates_long_strings(self) -> None:
        """Strings longer than max_chars are cut and annotated."""
        payload: dict[str, Any] = {"description": "x" * 1000, "name": "short"}
        log: list[str] = []
        assert _truncate_strings(payload, log, max_chars=500) is True
        shortened = payload["description"]
        assert len(shortened) < 1000
        assert "[truncated from 1000 chars]" in shortened
        assert payload["name"] == "short"
        assert len(log) == 1

    def test_does_not_truncate_short_strings(self) -> None:
        """Strings within the limit are left untouched."""
        payload: dict[str, Any] = {"name": "hello", "id": 123}
        log: list[str] = []
        assert _truncate_strings(payload, log, max_chars=500) is False
        assert payload["name"] == "hello"
        assert len(log) == 0


class TestTruncateStringsRecursive:
    """Checks for the _truncate_strings_recursive helper."""

    def test_truncates_nested_strings_in_list_items(self) -> None:
        """Strings inside list items (e.g. charts[i].description) are cut."""
        long_chart = {"id": 1, "description": "x" * 1000}
        short_chart = {"id": 2, "description": "short"}
        payload: dict[str, Any] = {"id": 1, "charts": [long_chart, short_chart]}
        log: list[str] = []
        assert _truncate_strings_recursive(payload, log, max_chars=500) is True
        first, second = payload["charts"]
        assert "[truncated" in first["description"]
        assert second["description"] == "short"
        assert len(log) == 1
        assert "charts[0].description" in log[0]

    def test_truncates_nested_strings_in_dicts(self) -> None:
        """Strings inside nested dicts are cut."""
        payload: dict[str, Any] = {
            "filter_state": {"dataMask": {"some_filter": "y" * 2000}},
        }
        log: list[str] = []
        assert _truncate_strings_recursive(payload, log, max_chars=500) is True
        assert "[truncated" in payload["filter_state"]["dataMask"]["some_filter"]

    def test_respects_depth_limit(self) -> None:
        """Recursion stops at depth 10."""
        # Chain 15 nested dicts below the root, each with its own long string.
        payload: dict[str, Any] = {"level": "x" * 1000}
        cursor = payload
        for _ in range(15):
            child: dict[str, Any] = {"level": "x" * 1000}
            cursor["nested"] = child
            cursor = child
        log: list[str] = []
        _truncate_strings_recursive(payload, log, max_chars=500)
        # Levels 0-10 are truncated; the walk stops before reaching level 15.
        assert len(log) <= 11

    def test_handles_empty_structures(self) -> None:
        """Empty dicts and lists are handled without changes."""
        payload: dict[str, Any] = {"items": [], "meta": {}, "name": "ok"}
        log: list[str] = []
        assert _truncate_strings_recursive(payload, log, max_chars=500) is False

    def test_dashboard_with_many_charts_edge_case(self) -> None:
        """A dashboard with 30 charts, each carrying a long description."""
        charts = [
            {"id": i, "slice_name": f"Chart {i}", "description": "d" * 2000}
            for i in range(30)
        ]
        payload: dict[str, Any] = {
            "id": 1,
            "dashboard_title": "Big Dashboard",
            "charts": charts,
        }
        log: list[str] = []
        assert _truncate_strings_recursive(payload, log, max_chars=500) is True
        # One note per chart description.
        assert len(log) == 30
        for chart in payload["charts"]:
            description = chart["description"]
            assert len(description) < 2000
            assert "[truncated" in description


class TestTruncateLists:
    """Checks for the _truncate_lists helper."""

    def test_truncates_long_lists(self) -> None:
        """Over-long lists are cut without inline marker elements."""
        payload: dict[str, Any] = {
            "columns": [{"name": f"col_{i}"} for i in range(50)],
            "tags": [1, 2],
        }
        log: list[str] = []
        assert _truncate_lists(payload, log, max_items=10) is True
        kept = payload["columns"]
        # Exactly 10 items remain and every one is still a column dict.
        assert len(kept) == 10
        assert all(isinstance(col, dict) and "name" in col for col in kept)
        assert payload["tags"] == [1, 2]  # Short list is untouched
        assert len(log) == 1
        assert "50" in log[0]

    def test_does_not_truncate_short_lists(self) -> None:
        """Lists within the limit are left untouched."""
        payload: dict[str, Any] = {"items": [1, 2, 3]}
        log: list[str] = []
        assert _truncate_lists(payload, log, max_items=10) is False


class TestSummarizeLargeDicts:
    """Checks for the _summarize_large_dicts helper."""

    def test_summarizes_large_dicts(self) -> None:
        """Dicts with too many keys are replaced by a key summary."""
        payload: dict[str, Any] = {
            "form_data": {f"key_{i}": f"value_{i}" for i in range(30)},
            "id": 1,
        }
        log: list[str] = []
        assert _summarize_large_dicts(payload, log, max_keys=20) is True
        summary = payload["form_data"]
        assert summary["_truncated"] is True
        assert "30 keys" in summary["_message"]
        assert payload["id"] == 1

    def test_does_not_summarize_small_dicts(self) -> None:
        """Dicts within the limit are left untouched."""
        payload: dict[str, Any] = {"params": {"a": 1, "b": 2}}
        log: list[str] = []
        assert _summarize_large_dicts(payload, log, max_keys=20) is False


class TestReplaceCollectionsWithSummaries:
    """Checks for the _replace_collections_with_summaries helper."""

    def test_replaces_lists_and_dicts(self) -> None:
        """Non-empty collections are cleared; everything else is kept."""
        payload: dict[str, Any] = {
            "columns": [1, 2, 3],
            "params": {"a": 1},
            "name": "test",
            "empty": [],
        }
        log: list[str] = []
        assert _replace_collections_with_summaries(payload, log) is True
        # Each collection keeps its type but loses its content.
        assert payload["columns"] == []
        assert payload["params"] == {}
        # Scalars and already-empty collections are unchanged.
        assert payload["name"] == "test"
        assert payload["empty"] == []
        assert len(log) == 2


class TestTruncateOversizedResponse:
    """Checks for the truncate_oversized_response entry point."""

    def test_no_truncation_needed(self) -> None:
        """A response under the limit is reported as not truncated."""
        _, truncated, log = truncate_oversized_response(
            {"id": 1, "name": "test"}, 10000
        )
        assert truncated is False
        assert log == []

    def test_truncates_large_string_fields(self) -> None:
        """A very large description is cut to fit."""
        oversized = {"id": 1, "description": "x" * 50000}
        outcome, truncated, log = truncate_oversized_response(oversized, 500)
        assert truncated is True
        assert isinstance(outcome, dict)
        assert "[truncated" in outcome["description"]
        assert any("description" in note for note in log)

    def test_truncates_large_lists(self) -> None:
        """Lists are cut when there are no strings to shorten."""
        columns = [{"name": f"col_{i}", "type": "VARCHAR"} for i in range(200)]
        oversized = {"id": 1, "columns": columns}
        outcome, truncated, _ = truncate_oversized_response(oversized, 500)
        assert truncated is True
        assert isinstance(outcome, dict)
        assert len(outcome["columns"]) < 200

    def test_handles_pydantic_model(self) -> None:
        """Pydantic model input is accepted and returned as a dict."""

        class FakeInfo(BaseModel):
            id: int = 1
            description: str = "x" * 5000

        outcome, truncated, _ = truncate_oversized_response(FakeInfo(), 200)
        assert truncated is True
        assert isinstance(outcome, dict)

    def test_returns_non_dict_unchanged(self) -> None:
        """Input that is neither a dict nor a model is passed through."""
        outcome, truncated, _ = truncate_oversized_response("just a string", 100)
        assert truncated is False
        assert outcome == "just a string"

    def test_progressive_truncation(self) -> None:
        """A payload needing several phases still keeps its scalar fields."""
        oversized = {
            "id": 1,
            "description": "x" * 2000,
            "css": "y" * 2000,
            "columns": [{"name": f"col_{i}"} for i in range(100)],
            "form_data": {f"key_{i}": f"val_{i}" for i in range(50)},
        }
        outcome, truncated, log = truncate_oversized_response(oversized, 300)
        assert truncated is True
        assert isinstance(outcome, dict)
        assert outcome["id"] == 1  # Scalar fields preserved
        assert len(log) > 0