fix(mcp): use tiktoken for response-size-guard token estimation (#39912)
@@ -145,7 +145,13 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
 elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
 exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
 excel = ["xlrd>=1.2.0, <1.3"]
-fastmcp = ["fastmcp>=3.2.4,<4.0"]
+fastmcp = [
+    "fastmcp>=3.2.4,<4.0",
+    # tiktoken backs the response-size-guard token estimator. Without
+    # it, the middleware falls back to a coarser character-based
+    # heuristic that under-counts JSON-heavy MCP responses.
+    "tiktoken>=0.7.0,<1.0",
+]
 firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
 firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
 gevent = ["gevent>=23.9.1"]
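
For reviewers who want to sanity-check the extra locally, here is a quick probe of whether the optional tokenizer actually resolved into the environment (an illustrative snippet, not part of this change):

```python
# Probe for the optional tiktoken dependency without importing it.
import importlib.util
from importlib.metadata import version

if importlib.util.find_spec("tiktoken") is None:
    print("tiktoken missing; the size guard will use the char heuristic")
else:
    print("tiktoken", version("tiktoken"), "is available")
```
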
@@ -183,7 +183,9 @@ idna==3.10
     #   trio
     #   url-normalize
 isodate==0.7.2
-    # via apache-superset (pyproject.toml)
+    # via
+    #   apache-superset (pyproject.toml)
+    #   apache-superset-core
 itsdangerous==2.2.0
     # via
     #   flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
     # via
     #   -r requirements/base.in
     #   apache-superset (pyproject.toml)
+    #   apache-superset-core
 pyasn1==0.6.3
     # via
     #   pyasn1-modules

@@ -442,6 +442,7 @@ isodate==0.7.2
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+    #   apache-superset-core
 isort==6.0.1
     # via pylint
 itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+    #   apache-superset-core
     #   db-dtypes
     #   pandas-gbq
 pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
     #   jsonschema
     #   jsonschema-path
     #   jsonschema-specifications
+regex==2026.4.4
+    # via tiktoken
 requests==2.33.0
     # via
     #   -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
     #   requests-cache
     #   requests-oauthlib
     #   shillelagh
+    #   tiktoken
     #   trino
 requests-cache==1.2.1
     # via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
     # via
     #   -c requirements/base-constraint.txt
     #   apache-superset
+tiktoken==0.12.0
+    # via apache-superset
 tomli-w==1.2.0
     # via apache-superset-extensions-cli
 tomlkit==0.13.3

@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
     DEFAULT_TOKEN_LIMIT,
     DEFAULT_WARN_THRESHOLD_PCT,
 )
+from superset.mcp_service.utils.token_utils import (
+    estimate_response_tokens,
+    format_size_limit_error,
+    INFO_TOOLS,
+    truncate_oversized_response,
+)
 from superset.utils.core import get_user_id

 logger = logging.getLogger(__name__)
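
Hoisting these imports to module scope also gives tests a stable patch target: the names can be patched where the middleware looks them up, which the updated middleware tests below rely on. Roughly:

```python
# Sketch of patching the module-level name (the target string mirrors
# the updated tests later in this diff; the body is a placeholder).
from unittest.mock import patch

with patch(
    "superset.mcp_service.middleware.estimate_response_tokens",
    return_value=850,
):
    ...  # exercise ResponseSizeGuardMiddleware.on_call_tool here
```
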
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         ``content[0].text`` as a JSON string. We parse that string, run the
         truncation phases on the resulting dict, then re-wrap the result.
         """
-        from superset.mcp_service.utils.token_utils import (
-            estimate_response_tokens,
-            truncate_oversized_response,
-        )
-
         # Unwrap ToolResult so truncation operates on the real payload
         extracted = self._extract_payload_from_tool_result(response)
         if extracted is not None:
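
For readers unfamiliar with FastMCP's ToolResult shape, here is a minimal sketch of the unwrap/truncate/re-wrap flow the docstring describes. The `TextContent`/`ToolResult` classes below are simplified stand-ins, and the truncation phase is illustrative rather than the middleware's actual logic:

```python
# Minimal sketch, assuming a ToolResult-like wrapper whose first content
# item holds the real payload as a JSON string.
import json
from dataclasses import dataclass


@dataclass
class TextContent:
    text: str


@dataclass
class ToolResult:
    content: list[TextContent]


def truncate_tool_result(result: ToolResult, max_items: int = 10) -> ToolResult:
    payload = json.loads(result.content[0].text)  # unwrap the JSON string
    items = payload.get("result")
    if isinstance(items, list):  # one illustrative truncation phase
        payload["result"] = items[:max_items]
    result.content[0].text = json.dumps(payload)  # re-wrap
    return result
```
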
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         # Execute the tool
         response = await call_next(context)

-        # Estimate response token count (guard against huge responses causing OOM)
-        from superset.mcp_service.utils.token_utils import (
-            estimate_response_tokens,
-            format_size_limit_error,
-        )
-
         # When the response is a ToolResult, estimate tokens on the actual
         # payload inside content[0].text rather than on the ToolResult
         # wrapper (which would double-serialize the JSON string).
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         params = getattr(context.message, "params", {}) or {}

         # For info tools, try dynamic truncation before blocking
-        from superset.mcp_service.utils.token_utils import INFO_TOOLS
-
         if tool_name in INFO_TOOLS:
             truncated = self._try_truncate_info_response(
                 tool_name, response, estimated_tokens
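
Putting the pieces together, the guard's decision ladder looks roughly like the sketch below. The names, defaults, and string outcomes are illustrative (the real middleware logs a warning, attempts truncation, or returns a formatted size-limit error rather than labels, and `list_charts` belonging to the info tools is an assumption):

```python
# Hypothetical condensation of the guard's control flow, using the
# token_limit/warn_threshold_pct values from the tests in this diff.
def guard_decision(
    tool_name: str,
    estimated_tokens: int,
    token_limit: int = 1000,
    warn_threshold_pct: int = 80,
    info_tools: frozenset = frozenset({"list_charts"}),  # assumed membership
) -> str:
    if estimated_tokens > token_limit:
        # Info tools get a truncation attempt before a hard block.
        return "truncate" if tool_name in info_tools else "block"
    if estimated_tokens >= token_limit * warn_threshold_pct // 100:
        return "warn"
    return "pass"
```
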
@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
 This module provides utilities to estimate token counts and generate smart
 suggestions when responses exceed configured limits. This prevents large
 responses from overwhelming LLM clients like Claude Desktop.
+
+Token counting strategy:
+
+1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
+   installed (it is shipped as part of the ``fastmcp`` extra). This is a
+   real BPE tokenizer trained on a similar vocabulary to Claude's; for
+   English and JSON-heavy MCP payloads it tracks Claude's tokenizer
+   within roughly ±10%, which is far more accurate than the legacy
+   character heuristic.
+2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
+   importable. The fallback uses a slightly more conservative ratio than
+   before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
+   are not under-counted, which previously let oversized payloads slip
+   past the response-size guard.
+
+The exact-Claude tokenizer is only available via Anthropic's network
+``count_tokens`` API; calling it from a synchronous middleware on every
+tool result is too slow and adds an external dependency on every
+response. ``tiktoken`` is the closest approximation we can ship without
+that risk.
 """

 from __future__ import annotations
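
To make the ratio change concrete: a 3,500-character JSON payload that the old 3.5 chars/token ratio estimated at 1,000 tokens now estimates at 1,166 under 3.0, so the same payload sits about 17% closer to the limit. The two strategies can be compared side by side with a snippet like this (illustrative; the tiktoken branch only runs when the package is installed):

```python
# Compare the character fallback against tiktoken on a JSON-ish payload.
import json

payload = json.dumps(
    {"columns": ["id", "name"], "rows": [[i, "x" * 8] for i in range(50)]}
)

print("fallback (len/3.0):", max(1, int(len(payload) / 3.0)))
try:
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")
    print("tiktoken cl100k_base:", len(enc.encode(payload)))
except ImportError:
    print("tiktoken not installed")
```
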
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
 # Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
 ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]

-# Approximate characters per token for estimation
-# Claude tokenizer averages ~4 chars per token for English text
-# JSON tends to be more verbose, so we use a slightly lower ratio
-CHARS_PER_TOKEN = 3.5
+# Fallback character-to-token ratio used when tiktoken is unavailable.
+# 3.0 is conservative for JSON content (the previous 3.5 under-counted
+# JSON-heavy payloads relative to Claude's actual tokenizer, which let
+# oversized responses slip past the response-size guard).
+CHARS_PER_TOKEN = 3.0
+
+# Encoding used when tiktoken is available. cl100k_base is OpenAI's
+# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
+# Claude's and tracks Claude's token counts within roughly ±10% for
+# English and JSON-heavy MCP responses.
+_TIKTOKEN_ENCODING_NAME = "cl100k_base"
+
+
+def _load_tiktoken_encoding() -> Any:
+    """Return a tiktoken encoding instance, or None if tiktoken is unavailable.
+
+    Imported lazily so the module can be used in environments without
+    tiktoken installed. The encoding is small (~1 MB) so we cache it on
+    first use.
+    """
+    try:
+        import tiktoken
+    except ImportError:
+        logger.info(
+            "tiktoken not installed; falling back to char-based token "
+            "estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
+            "for accurate counts.",
+            CHARS_PER_TOKEN,
+        )
+        return None
+
+    try:
+        return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
+    except (KeyError, ValueError) as exc:
+        # tiktoken installed but the requested encoding is missing — this
+        # only happens on partial installs. Treat as no tokenizer rather
+        # than crashing on every tool call.
+        logger.warning(
+            "tiktoken encoding '%s' unavailable: %s; falling back to "
+            "char-based token estimation",
+            _TIKTOKEN_ENCODING_NAME,
+            exc,
+        )
+        return None
+
+
+# Cached encoding instance (None if tiktoken not importable).
+_ENCODING = _load_tiktoken_encoding()


 def estimate_token_count(text: str | bytes) -> int:
     """
     Estimate the token count for a given text.

-    Uses a character-based heuristic since we don't have direct access to
-    the actual tokenizer. This is conservative to avoid underestimating.
+    Uses tiktoken's ``cl100k_base`` encoding when available for
+    Claude-aligned accuracy (within ~10%), falling back to a
+    character-based heuristic otherwise.

     Args:
         text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
     if isinstance(text, bytes):
         text = text.decode("utf-8", errors="replace")

-    # Simple heuristic: ~3.5 characters per token for JSON/code
-    text_length = len(text)
-    if text_length == 0:
+    if not text:
         return 0
-    return max(1, int(text_length / CHARS_PER_TOKEN))
+
+    if _ENCODING is not None:
+        try:
+            return len(_ENCODING.encode(text))
+        except (ValueError, UnicodeError) as exc:
+            # Defensive: if tiktoken chokes on a specific input, fall
+            # back to the char heuristic for this call rather than
+            # raising — the response size guard must never fail-open.
+            logger.warning("tiktoken encode failed (%s); using fallback", exc)
+
+    return max(1, int(len(text) / CHARS_PER_TOKEN))


 def estimate_response_tokens(response: ToolResponse) -> int:
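
Taken together, the function's contract stays small: str or bytes in, a best-effort count out, 0 only for empty input, at least 1 otherwise, and no exceptions. A usage sketch:

```python
from superset.mcp_service.utils.token_utils import estimate_token_count

assert estimate_token_count("") == 0
assert estimate_token_count(b"") == 0
n = estimate_token_count('{"status": "ok", "count": 42}')
assert n >= 1  # exact value depends on whether tiktoken is loaded
```
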
@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:

     @pytest.mark.asyncio
     async def test_logs_warning_at_threshold(self) -> None:
-        """Should log warning when approaching limit."""
+        """Should log warning when approaching limit.
+
+        Mocks the token estimator to return a specific value above the
+        warn threshold but below the hard limit, decoupling the test
+        from whichever tokenizer (tiktoken or char heuristic) happens
+        to be loaded.
+        """
         middleware = ResponseSizeGuardMiddleware(
             token_limit=1000, warn_threshold_pct=80
         )
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
         context.message.name = "list_charts"
         context.message.params = {}

-        # Response at ~85% of limit (should trigger warning but not block)
-        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
+        response = {"data": "approaching the limit"}
         call_next = AsyncMock(return_value=response)

         with (
             patch("superset.mcp_service.middleware.get_user_id", return_value=1),
             patch("superset.mcp_service.middleware.event_logger"),
+            patch(
+                "superset.mcp_service.middleware.estimate_response_tokens",
+                return_value=850,
+            ),
             patch("superset.mcp_service.middleware.logger") as mock_logger,
         ):
             result = await middleware.on_call_tool(context, call_next)

-        # Should return response (not blocked)
+        # Should return response (not blocked at 85% of limit)
         assert result == response
         # Should log warning
         mock_logger.warning.assert_called()

@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
 """

 from typing import Any, List
+from unittest.mock import patch

 from pydantic import BaseModel

+from superset.mcp_service.utils import token_utils
 from superset.mcp_service.utils.token_utils import (
     _replace_collections_with_summaries,
     _summarize_large_dicts,
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
     """Test estimate_token_count function."""

     def test_estimate_string(self) -> None:
-        """Should estimate tokens for a string."""
+        """Should produce a positive non-zero estimate for a normal string.
+
+        We don't assert on a specific number because the result depends on
+        which tokenizer is loaded (tiktoken when available, char heuristic
+        otherwise).
+        """
         text = "Hello world"
         result = estimate_token_count(text)
-        expected = int(len(text) / CHARS_PER_TOKEN)
-        assert result == expected
+        assert result > 0

     def test_estimate_bytes(self) -> None:
-        """Should estimate tokens for bytes."""
-        text = b"Hello world"
-        result = estimate_token_count(text)
-        expected = int(len(text) / CHARS_PER_TOKEN)
-        assert result == expected
+        """Bytes input should be decoded and produce the same count as the
+        equivalent string."""
+        text = "Hello world"
+        assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)

     def test_empty_string(self) -> None:
-        """Should return 0 for empty string."""
+        """Should return 0 for empty string and empty bytes."""
         assert estimate_token_count("") == 0
+        assert estimate_token_count(b"") == 0

     def test_json_like_content(self) -> None:
-        """Should estimate tokens for JSON-like content."""
+        """JSON content should produce a positive estimate."""
         json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
-        result = estimate_token_count(json_str)
-        assert result > 0
-        assert result == int(len(json_str) / CHARS_PER_TOKEN)
+        assert estimate_token_count(json_str) > 0

+    def test_long_text_roughly_scales_with_length(self) -> None:
+        """A doubled string should produce roughly double the token count
+        (within ±10%)."""
+        small = "the quick brown fox jumps over the lazy dog. " * 20
+        large = small * 2
+        small_n = estimate_token_count(small)
+        large_n = estimate_token_count(large)
+        # Within 10% of 2x — both tokenizers (tiktoken and the char
+        # fallback) preserve length monotonicity.
+        assert 1.8 * small_n <= large_n <= 2.2 * small_n
+
+    def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
+        self,
+    ) -> None:
+        """When the tiktoken encoding is None (not installed), the
+        function falls back to len/CHARS_PER_TOKEN math."""
+        text = "x" * 100
+        with patch.object(token_utils, "_ENCODING", None):
+            result = estimate_token_count(text)
+        assert result == int(100 / CHARS_PER_TOKEN)
+
+    def test_fallback_when_tiktoken_encode_raises(self) -> None:
+        """A misbehaving encoding should fall back to the char heuristic
+        rather than raise — the size guard must never fail-open."""
+
+        class BoomEncoding:
+            def encode(self, text: str) -> list[int]:
+                raise ValueError("simulated tiktoken failure")
+
+        text = "abc" * 50
+        with patch.object(token_utils, "_ENCODING", BoomEncoding()):
+            result = estimate_token_count(text)
+        assert result == int(len(text) / CHARS_PER_TOKEN)
+

 class TestEstimateResponseTokens: