From 9b520312a1a111affbfc669cc707adf1270d64d8 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Thu, 7 May 2026 11:51:31 -0400 Subject: [PATCH] fix(mcp): use tiktoken for response-size-guard token estimation (#39912) --- pyproject.toml | 8 +- requirements/base.txt | 5 +- requirements/development.txt | 7 ++ superset/mcp_service/middleware.py | 19 ++-- superset/mcp_service/utils/token_utils.py | 93 +++++++++++++++++-- .../unit_tests/mcp_service/test_middleware.py | 17 +++- .../mcp_service/utils/test_token_utils.py | 64 ++++++++++--- 7 files changed, 171 insertions(+), 42 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index fbbd8d82b7a..933e5175971 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,13 @@ solr = ["sqlalchemy-solr >= 0.2.0"] elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"] exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"] excel = ["xlrd>=1.2.0, <1.3"] -fastmcp = ["fastmcp>=3.2.4,<4.0"] +fastmcp = [ + "fastmcp>=3.2.4,<4.0", + # tiktoken backs the response-size-guard token estimator. Without + # it, the middleware falls back to a coarser character-based + # heuristic that under-counts JSON-heavy MCP responses. + "tiktoken>=0.7.0,<1.0", +] firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"] firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"] gevent = ["gevent>=23.9.1"] diff --git a/requirements/base.txt b/requirements/base.txt index 3c0575c1f4a..9c99184f920 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -183,7 +183,9 @@ idna==3.10 # trio # url-normalize isodate==0.7.2 - # via apache-superset (pyproject.toml) + # via + # apache-superset (pyproject.toml) + # apache-superset-core itsdangerous==2.2.0 # via # flask @@ -296,6 +298,7 @@ pyarrow==20.0.0 # via # -r requirements/base.in # apache-superset (pyproject.toml) + # apache-superset-core pyasn1==0.6.3 # via # pyasn1-modules diff --git a/requirements/development.txt b/requirements/development.txt index 28219acfd9c..b7664ddae6b 100644 --- a/requirements/development.txt +++ b/requirements/development.txt @@ -442,6 +442,7 @@ isodate==0.7.2 # via # -c requirements/base-constraint.txt # apache-superset + # apache-superset-core isort==6.0.1 # via pylint itsdangerous==2.2.0 @@ -715,6 +716,7 @@ pyarrow==20.0.0 # via # -c requirements/base-constraint.txt # apache-superset + # apache-superset-core # db-dtypes # pandas-gbq pyasn1==0.6.3 @@ -866,6 +868,8 @@ referencing==0.36.2 # jsonschema # jsonschema-path # jsonschema-specifications +regex==2026.4.4 + # via tiktoken requests==2.33.0 # via # -c requirements/base-constraint.txt @@ -878,6 +882,7 @@ requests==2.33.0 # requests-cache # requests-oauthlib # shillelagh + # tiktoken # trino requests-cache==1.2.1 # via @@ -1003,6 +1008,8 @@ tabulate==0.9.0 # via # -c requirements/base-constraint.txt # apache-superset +tiktoken==0.12.0 + # via apache-superset tomli-w==1.2.0 # via apache-superset-extensions-cli tomlkit==0.13.3 diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index de592a3da02..64a5145d7fc 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -41,6 +41,12 @@ from superset.mcp_service.constants import ( DEFAULT_TOKEN_LIMIT, DEFAULT_WARN_THRESHOLD_PCT, ) +from superset.mcp_service.utils.token_utils import ( + estimate_response_tokens, + format_size_limit_error, + INFO_TOOLS, + truncate_oversized_response, +) from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware): ``content[0].text`` as a JSON string. We parse that string, run the truncation phases on the resulting dict, then re-wrap the result. """ - from superset.mcp_service.utils.token_utils import ( - estimate_response_tokens, - truncate_oversized_response, - ) - # Unwrap ToolResult so truncation operates on the real payload extracted = self._extract_payload_from_tool_result(response) if extracted is not None: @@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware): # Execute the tool response = await call_next(context) - # Estimate response token count (guard against huge responses causing OOM) - from superset.mcp_service.utils.token_utils import ( - estimate_response_tokens, - format_size_limit_error, - ) - # When the response is a ToolResult, estimate tokens on the actual # payload inside content[0].text rather than on the ToolResult # wrapper (which would double-serialize the JSON string). @@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware): params = getattr(context.message, "params", {}) or {} # For info tools, try dynamic truncation before blocking - from superset.mcp_service.utils.token_utils import INFO_TOOLS - if tool_name in INFO_TOOLS: truncated = self._try_truncate_info_response( tool_name, response, estimated_tokens diff --git a/superset/mcp_service/utils/token_utils.py b/superset/mcp_service/utils/token_utils.py index 00e6664e729..b14c8f5a564 100644 --- a/superset/mcp_service/utils/token_utils.py +++ b/superset/mcp_service/utils/token_utils.py @@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service. This module provides utilities to estimate token counts and generate smart suggestions when responses exceed configured limits. This prevents large responses from overwhelming LLM clients like Claude Desktop. + +Token counting strategy: + +1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is + installed (it is shipped as part of the ``fastmcp`` extra). This is a + real BPE tokenizer trained on a similar vocabulary to Claude's; for + English and JSON-heavy MCP payloads it tracks Claude's tokenizer + within roughly ±10%, which is far more accurate than the legacy + character heuristic. +2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not + importable. The fallback uses a slightly more conservative ratio than + before (3.0 chars/token instead of 3.5) so that JSON-heavy responses + are not under-counted, which previously let oversized payloads slip + past the response-size guard. + +The exact-Claude tokenizer is only available via Anthropic's network +``count_tokens`` API; calling it from a synchronous middleware on every +tool result is too slow and adds an external dependency on every +response. ``tiktoken`` is the closest approximation we can ship without +that risk. """ from __future__ import annotations @@ -36,18 +56,63 @@ logger = logging.getLogger(__name__) # Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes) ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes] -# Approximate characters per token for estimation -# Claude tokenizer averages ~4 chars per token for English text -# JSON tends to be more verbose, so we use a slightly lower ratio -CHARS_PER_TOKEN = 3.5 +# Fallback character-to-token ratio used when tiktoken is unavailable. +# 3.0 is conservative for JSON content (the previous 3.5 under-counted +# JSON-heavy payloads relative to Claude's actual tokenizer, which let +# oversized responses slip past the response-size guard). +CHARS_PER_TOKEN = 3.0 + +# Encoding used when tiktoken is available. cl100k_base is OpenAI's +# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to +# Claude's and tracks Claude's token counts within roughly ±10% for +# English and JSON-heavy MCP responses. +_TIKTOKEN_ENCODING_NAME = "cl100k_base" + + +def _load_tiktoken_encoding() -> Any: + """Return a tiktoken encoding instance, or None if tiktoken is unavailable. + + Imported lazily so the module can be used in environments without + tiktoken installed. The encoding is small (~1 MB) so we cache it on + first use. + """ + try: + import tiktoken + except ImportError: + logger.info( + "tiktoken not installed; falling back to char-based token " + "estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra " + "for accurate counts.", + CHARS_PER_TOKEN, + ) + return None + + try: + return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME) + except (KeyError, ValueError) as exc: + # tiktoken installed but the requested encoding is missing — this + # only happens on partial installs. Treat as no tokenizer rather + # than crashing on every tool call. + logger.warning( + "tiktoken encoding '%s' unavailable: %s; falling back to " + "char-based token estimation", + _TIKTOKEN_ENCODING_NAME, + exc, + ) + return None + + +# Cached encoding instance (None if tiktoken not importable). +_ENCODING = _load_tiktoken_encoding() def estimate_token_count(text: str | bytes) -> int: """ Estimate the token count for a given text. - Uses a character-based heuristic since we don't have direct access to - the actual tokenizer. This is conservative to avoid underestimating. + Uses tiktoken's ``cl100k_base`` encoding when available for + Claude-aligned accuracy (within ~10%), falling back to a + character-based heuristic otherwise. Args: text: The text to estimate tokens for (string or bytes) @@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int: if isinstance(text, bytes): text = text.decode("utf-8", errors="replace") - # Simple heuristic: ~3.5 characters per token for JSON/code - text_length = len(text) - if text_length == 0: + if not text: return 0 - return max(1, int(text_length / CHARS_PER_TOKEN)) + + if _ENCODING is not None: + try: + return len(_ENCODING.encode(text)) + except (ValueError, UnicodeError) as exc: + # Defensive: if tiktoken chokes on a specific input, fall + # back to the char heuristic for this call rather than + # raising — the response size guard must never fail-open. + logger.warning("tiktoken encode failed (%s); using fallback", exc) + + return max(1, int(len(text) / CHARS_PER_TOKEN)) def estimate_response_tokens(response: ToolResponse) -> int: diff --git a/tests/unit_tests/mcp_service/test_middleware.py b/tests/unit_tests/mcp_service/test_middleware.py index e95c0d87220..948ba2547cb 100644 --- a/tests/unit_tests/mcp_service/test_middleware.py +++ b/tests/unit_tests/mcp_service/test_middleware.py @@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware: @pytest.mark.asyncio async def test_logs_warning_at_threshold(self) -> None: - """Should log warning when approaching limit.""" + """Should log warning when approaching limit. + + Mocks the token estimator to return a specific value above the + warn threshold but below the hard limit, decoupling the test + from whichever tokenizer (tiktoken or char heuristic) happens + to be loaded. + """ middleware = ResponseSizeGuardMiddleware( token_limit=1000, warn_threshold_pct=80 ) @@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware: context.message.name = "list_charts" context.message.params = {} - # Response at ~85% of limit (should trigger warning but not block) - response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token + response = {"data": "approaching the limit"} call_next = AsyncMock(return_value=response) with ( patch("superset.mcp_service.middleware.get_user_id", return_value=1), patch("superset.mcp_service.middleware.event_logger"), + patch( + "superset.mcp_service.middleware.estimate_response_tokens", + return_value=850, + ), patch("superset.mcp_service.middleware.logger") as mock_logger, ): result = await middleware.on_call_tool(context, call_next) - # Should return response (not blocked) + # Should return response (not blocked at 85% of limit) assert result == response # Should log warning mock_logger.warning.assert_called() diff --git a/tests/unit_tests/mcp_service/utils/test_token_utils.py b/tests/unit_tests/mcp_service/utils/test_token_utils.py index 9a49264bd93..4254bd9f539 100644 --- a/tests/unit_tests/mcp_service/utils/test_token_utils.py +++ b/tests/unit_tests/mcp_service/utils/test_token_utils.py @@ -20,9 +20,11 @@ Unit tests for MCP service token utilities. """ from typing import Any, List +from unittest.mock import patch from pydantic import BaseModel +from superset.mcp_service.utils import token_utils from superset.mcp_service.utils.token_utils import ( _replace_collections_with_summaries, _summarize_large_dicts, @@ -45,29 +47,65 @@ class TestEstimateTokenCount: """Test estimate_token_count function.""" def test_estimate_string(self) -> None: - """Should estimate tokens for a string.""" + """Should produce a positive non-zero estimate for a normal string. + + We don't assert on a specific number because the result depends on + which tokenizer is loaded (tiktoken when available, char heuristic + otherwise). + """ text = "Hello world" result = estimate_token_count(text) - expected = int(len(text) / CHARS_PER_TOKEN) - assert result == expected + assert result > 0 def test_estimate_bytes(self) -> None: - """Should estimate tokens for bytes.""" - text = b"Hello world" - result = estimate_token_count(text) - expected = int(len(text) / CHARS_PER_TOKEN) - assert result == expected + """Bytes input should be decoded and produce the same count as the + equivalent string.""" + text = "Hello world" + assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text) def test_empty_string(self) -> None: - """Should return 0 for empty string.""" + """Should return 0 for empty string and empty bytes.""" assert estimate_token_count("") == 0 + assert estimate_token_count(b"") == 0 def test_json_like_content(self) -> None: - """Should estimate tokens for JSON-like content.""" + """JSON content should produce a positive estimate.""" json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}' - result = estimate_token_count(json_str) - assert result > 0 - assert result == int(len(json_str) / CHARS_PER_TOKEN) + assert estimate_token_count(json_str) > 0 + + def test_long_text_roughly_scales_with_length(self) -> None: + """A doubled string should produce roughly double the token count + (within ±10%).""" + small = "the quick brown fox jumps over the lazy dog. " * 20 + large = small * 2 + small_n = estimate_token_count(small) + large_n = estimate_token_count(large) + # Within 10% of 2x — both tokenizers (tiktoken and the char + # fallback) preserve length monotonicity. + assert 1.8 * small_n <= large_n <= 2.2 * small_n + + def test_fallback_uses_chars_per_token_when_tiktoken_unavailable( + self, + ) -> None: + """When the tiktoken encoding is None (not installed), the + function falls back to len/CHARS_PER_TOKEN math.""" + text = "x" * 100 + with patch.object(token_utils, "_ENCODING", None): + result = estimate_token_count(text) + assert result == int(100 / CHARS_PER_TOKEN) + + def test_fallback_when_tiktoken_encode_raises(self) -> None: + """A misbehaving encoding should fall back to the char heuristic + rather than raise — the size guard must never fail-open.""" + + class BoomEncoding: + def encode(self, text: str) -> list[int]: + raise ValueError("simulated tiktoken failure") + + text = "abc" * 50 + with patch.object(token_utils, "_ENCODING", BoomEncoding()): + result = estimate_token_count(text) + assert result == int(len(text) / CHARS_PER_TOKEN) class TestEstimateResponseTokens: