Compare commits

...

12 Commits

Author SHA1 Message Date
Amin Ghadersohi
18e8217498 test(mcp): add -> None return types and remove redundant ToolError import 2026-05-07 16:55:35 +00:00
Amin Ghadersohi
32234ea0bb test(mcp): add -> None return types and remove redundant ToolError import
All 6 new test methods in TestListDatasetsRequestWrapper were missing
the required -> None return type annotation (Rule 12). Also remove the
inline `from fastmcp.exceptions import ToolError` in test_flat_kwargs_rejected
since ToolError is already imported at module top-level (line 27).
2026-05-07 16:55:35 +00:00
Amin Ghadersohi
a2b0f64176 fix(mcp): use valid opr value in docstring filter examples; fix dashboard docstring test
- Replace opr:"ct" with opr:"sw" in filter examples in list_datasets,
  list_charts, and list_dashboards docstrings — "ct" is not a valid
  ColumnOperatorEnum value, so examples using it produce validation errors
- Fix TestDashboardSortableColumns.test_sortable_columns_in_docstring:
  assertion checked for "Sortable columns for order_column:" (plain text)
  but docstring uses RST backtick format; split into two substring checks
  matching "Sortable columns for" and "order_column" separately
2026-05-07 16:55:35 +00:00
Amin Ghadersohi
4d66dc0774 test(mcp): fix and strengthen TestListDatasetsRequestWrapper tests
- Fix test_sortable_columns_in_docstring: assertion used plain-text match
  but docstring now uses RST backtick format; use substring check for both
  'Sortable columns for' and 'order_column' separately
- Fix test_dataset_filter_valid_col: opr='ct' is not a valid
  ColumnOperatorEnum value; use opr='sw' (starts-with) instead
- Add test_request_wrapper_enforced_by_tool: exercises the wrapper pattern
  through the actual FastMCP client call (not just schema instantiation),
  verifying the MCP tool layer accepts request={...} correctly
- Strengthen test_flat_kwargs_rejected: assert that the ToolError message
  references the unexpected arguments ('search', 'Unexpected', or 'request')
  so the test cannot pass on unrelated failures
2026-05-07 16:55:35 +00:00
Amin Ghadersohi
5542a2f3b1 fix(mcp): clarify request wrapper pattern in list_datasets, list_charts, list_dashboards
LLMs consistently passed flat kwargs (search, page, page_size) to list_*
tools instead of wrapping them in the required `request` object, causing
pydantic validation errors.

- Add docstring usage examples to list_datasets, list_charts, and
  list_dashboards showing the correct `request={...}` call shape and
  explicitly warning against flat kwargs
- Enumerate valid filter columns directly in DatasetFilter, ChartFilter,
  and DashboardFilter field descriptions (e.g. `created_by_fk` is not a
  valid dataset filter col)
- Add TestListDatasetsRequestWrapper tests covering: correct request
  wrapper usage, default values, valid/invalid filter col validation,
  and the flat-kwargs rejection scenario from story #105712
- Allow E501 in list_*.py tool files (docstring examples need full request
  shapes to be instructive)
2026-05-07 16:55:35 +00:00
Richard Fogaca Nienkotter
8c80caefa3 fix(explore): preserve preview chart name on save (#39908) 2026-05-07 13:08:28 -03:00
Richard Fogaca Nienkotter
8088c5d1de fix(dashboard): match auto-refresh paused-dot outline to icon color (#39909) 2026-05-07 13:07:52 -03:00
Amin Ghadersohi
9b520312a1 fix(mcp): use tiktoken for response-size-guard token estimation (#39912) 2026-05-07 11:51:31 -04:00
Amin Ghadersohi
9ac4711ac8 fix(mcp): prevent DetachedInstanceError in get_chart_preview (#39921) 2026-05-07 11:44:11 -04:00
dependabot[bot]
7593d2a164 chore(deps): bump caniuse-lite from 1.0.30001791 to 1.0.30001792 in /docs (#39933)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:57:29 +07:00
dependabot[bot]
d3c44e311e chore(deps): bump aws-actions/amazon-ecr-login from 2.1.4 to 2.1.5 (#39931)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-05-07 21:54:59 +07:00
Enzo Martellucci
b5186d1c65 fix(reports): keep body sized so standalone screenshots don't time out (#39944) 2026-05-07 12:26:50 +02:00
28 changed files with 738 additions and 70 deletions

View File

@@ -58,7 +58,7 @@ jobs:
- name: Login to Amazon ECR
if: steps.describe-services.outputs.active == 'true'
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Delete ECR image tag
if: steps.describe-services.outputs.active == 'true'

View File

@@ -199,7 +199,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Load, tag and push image to ECR
id: push-image
@@ -235,7 +235,7 @@ jobs:
- name: Login to Amazon ECR
id: login-ecr
uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2
- name: Check target image exists in ECR
id: check-image

View File

@@ -70,7 +70,7 @@
"@swc/core": "^1.15.33",
"antd": "^6.3.7",
"baseline-browser-mapping": "^2.10.27",
"caniuse-lite": "^1.0.30001791",
"caniuse-lite": "^1.0.30001792",
"docusaurus-plugin-openapi-docs": "^5.0.2",
"docusaurus-theme-openapi-docs": "^5.0.2",
"js-yaml": "^4.1.1",

View File

@@ -6035,10 +6035,10 @@ caniuse-api@^3.0.0:
lodash.memoize "^4.1.2"
lodash.uniq "^4.5.0"
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001791:
version "1.0.30001791"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001791.tgz#dfb93d85c40ad380c57123e72e10f3c575786b51"
integrity sha512-yk0l/YSrOnFZk3UROpDLQD9+kC1l4meK/wed583AXrzoarMGJcbRi2Q4RaUYbKxYAsZ8sWmaSa/DsLmdBeI1vQ==
caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
version "1.0.30001792"
resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==
ccount@^2.0.0:
version "2.0.1"

View File

@@ -145,7 +145,13 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = ["fastmcp>=3.2.4,<4.0"]
fastmcp = [
"fastmcp>=3.2.4,<4.0",
# tiktoken backs the response-size-guard token estimator. Without
# it, the middleware falls back to a coarser character-based
# heuristic that is much less accurate for JSON-heavy MCP responses.
"tiktoken>=0.7.0,<1.0",
]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]
@@ -377,6 +383,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"
[tool.ruff.lint.per-file-ignores]
"superset/mcp_service/app.py" = ["S608", "E501"] # LLM instruction text: SQL examples (S608) and long lines in multiline string (E501)
"superset/mcp_service/*/tool/list_*.py" = ["E501"] # LLM docstring examples show full request shapes which exceed line length
"scripts/*" = ["TID251"]
"setup.py" = ["TID251"]
"superset/config.py" = ["TID251"]

View File

@@ -183,7 +183,9 @@ idna==3.10
# trio
# url-normalize
isodate==0.7.2
# via apache-superset (pyproject.toml)
# via
# apache-superset (pyproject.toml)
# apache-superset-core
itsdangerous==2.2.0
# via
# flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
# via
# -r requirements/base.in
# apache-superset (pyproject.toml)
# apache-superset-core
pyasn1==0.6.3
# via
# pyasn1-modules

View File

@@ -442,6 +442,7 @@ isodate==0.7.2
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
isort==6.0.1
# via pylint
itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
# via
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-core
# db-dtypes
# pandas-gbq
pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
# jsonschema
# jsonschema-path
# jsonschema-specifications
regex==2026.4.4
# via tiktoken
requests==2.33.0
# via
# -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
# requests-cache
# requests-oauthlib
# shillelagh
# tiktoken
# trino
requests-cache==1.2.1
# via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
# via
# -c requirements/base-constraint.txt
# apache-superset
tiktoken==0.12.0
# via apache-superset
tomli-w==1.2.0
# via apache-superset-extensions-cli
tomlkit==0.13.3

View File

@@ -17,7 +17,8 @@
* under the License.
*/
import { render, screen, act } from 'spec/helpers/testing-library';
import { StatusIndicatorDot } from './StatusIndicatorDot';
import { supersetTheme } from '@apache-superset/core/theme';
import { getStatusConfig, StatusIndicatorDot } from './StatusIndicatorDot';
import { AutoRefreshStatus } from '../../types/autoRefresh';
afterEach(() => {
@@ -62,6 +63,15 @@ test('renders with paused status', () => {
expect(dot).toHaveAttribute('data-status', AutoRefreshStatus.Paused);
});
test('uses the icon color for the paused status outline', () => {
expect(
getStatusConfig(supersetTheme, AutoRefreshStatus.Paused),
).toMatchObject({
needsBorder: true,
outlineColor: 'currentColor',
});
});
test('has correct accessibility attributes', () => {
render(<StatusIndicatorDot status={AutoRefreshStatus.Success} />);
const dot = screen.getByTestId('status-indicator-dot');

View File

@@ -39,9 +39,10 @@ export interface StatusIndicatorDotProps {
interface StatusConfig {
color: string;
needsBorder: boolean;
outlineColor?: string;
}
const getStatusConfig = (
export const getStatusConfig = (
theme: ReturnType<typeof useTheme>,
status: AutoRefreshStatus,
): StatusConfig => {
@@ -75,6 +76,7 @@ const getStatusConfig = (
return {
color: theme.colorBgContainer,
needsBorder: true,
outlineColor: 'currentColor',
};
default:
return {
@@ -136,13 +138,15 @@ export const StatusIndicatorDot: FC<StatusIndicatorDotProps> = ({
width: ${size}px;
height: ${size}px;
border-radius: 50%;
color: ${theme.colorTextSecondary};
background-color: ${statusConfig.color};
transition:
background-color ${theme.motionDurationMid} ease-in-out,
border-color ${theme.motionDurationMid} ease-in-out;
border: ${statusConfig.needsBorder
? `1px solid ${theme.colorBorder}`
: 'none'};
border: ${statusConfig.needsBorder ? '1px solid' : 'none'};
border-color: ${statusConfig.needsBorder
? statusConfig.outlineColor
: 'transparent'};
box-shadow: ${statusConfig.needsBorder
? 'none'
: `0 0 0 2px ${theme.colorBgContainer}`};

View File

@@ -21,6 +21,10 @@ import { VizType } from '@superset-ui/core';
import { hydrateExplore, HYDRATE_EXPLORE } from './hydrateExplore';
import { exploreInitialData } from '../fixtures';
afterEach(() => {
window.history.pushState({}, '', '/');
});
test('creates hydrate action from initial data', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
@@ -168,6 +172,84 @@ test('creates hydrate action with existing state', () => {
);
});
test('hydrates sliceName from preview form data before saved slice name', () => {
window.history.pushState({}, '', '/explore/?form_data_key=preview-key');
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const previewSliceName = 'RENAMED - Bug Evidence';
const savedSliceName = 'Most Populated Countries';
const previewInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: previewSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(previewInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: previewSliceName,
}),
}),
}),
);
});
test('hydrates sliceName from saved slice when regular form data has stale name', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({
user: {},
charts: {},
datasources: {},
common: {},
explore: {},
}));
const staleFormDataSliceName = 'Stale Params Name';
const savedSliceName = 'Current Saved Name';
const savedChartInitialData = {
...exploreInitialData,
form_data: {
...exploreInitialData.form_data,
slice_name: staleFormDataSliceName,
},
slice: {
...exploreInitialData.slice!,
slice_name: savedSliceName,
},
};
// @ts-expect-error we only need the fields consumed by hydrateExplore
hydrateExplore(savedChartInitialData)(dispatch, getState);
expect(dispatch).toHaveBeenCalledWith(
expect.objectContaining({
type: HYDRATE_EXPLORE,
data: expect.objectContaining({
explore: expect.objectContaining({
sliceName: savedSliceName,
}),
}),
}),
);
});
test('uses configured default time range if not set', () => {
const dispatch = jest.fn();
const getState = jest.fn(() => ({

View File

@@ -77,6 +77,12 @@ export const hydrateExplore =
const fallbackSlice = sliceId ? sliceEntities?.slices?.[sliceId] : null;
const initialSlice = slice ?? fallbackSlice;
const initialFormData = form_data ?? initialSlice?.form_data;
const isCachedFormData = getUrlParam(URL_PARAMS.formDataKey) !== null;
const [primarySliceNameSource, fallbackSliceNameSource] = isCachedFormData
? [initialFormData, initialSlice]
: [initialSlice, initialFormData];
const initialSliceName =
primarySliceNameSource?.slice_name ?? fallbackSliceNameSource?.slice_name;
if (!initialFormData.viz_type) {
const defaultVizType = common?.conf.DEFAULT_VIZ_TYPE || VizType.Table;
initialFormData.viz_type =
@@ -183,6 +189,7 @@ export const hydrateExplore =
// because `bootstrapData.controls` is undefined.
controls: initialControls,
form_data: initialFormData,
sliceName: initialSliceName,
slice: initialSlice,
controlsTransferred: explore.controlsTransferred,
standalone: getUrlParam(URL_PARAMS.standalone),

View File

@@ -179,6 +179,33 @@ test('renders the right footer buttons', () => {
).toBeInTheDocument();
});
test('initializes chart name from current Explore slice name', () => {
const previewSliceName = 'RENAMED - Bug Evidence';
const savedSliceName = 'Most Populated Countries';
const { getByTestId } = setup(
{
...defaultProps,
form_data: {
...defaultProps.form_data,
slice_name: previewSliceName,
},
sliceName: previewSliceName,
},
mockStore({
...initialState,
explore: {
...initialState.explore,
slice: {
...initialState.explore.slice,
slice_name: savedSliceName,
},
},
}),
);
expect(getByTestId('new-chart-name')).toHaveValue(previewSliceName);
});
test('does not render a message when overriding', () => {
const { getByRole, queryByRole } = setup();

View File

@@ -537,8 +537,11 @@ class ChartFilter(ColumnOperator):
"datasource_name",
] = Field(
...,
description="Column to filter on. Use get_schema(model_type='chart') for "
"available filter columns.",
description=(
"Column to filter on. Valid values: 'slice_name', 'viz_type', "
"'datasource_name'. Other column names are not valid filter columns "
"and will cause a validation error."
),
)
opr: ColumnOperatorEnum = Field(
...,

View File

@@ -28,7 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.extensions import db, event_logger
from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
@@ -1140,6 +1140,15 @@ async def _get_chart_preview_internal( # noqa: C901
)
chart = find_chart_by_identifier(request.identifier)
# Eagerly refresh all attributes while the session is still
# active. SQLAlchemy expires object attributes after any
# commit; if a downstream operation commits before the strategy
# classes access chart attributes, a DetachedInstanceError will
# be raised. Calling refresh() here ensures all column values
# are loaded into the object's __dict__ upfront.
if chart is not None:
db.session.refresh(chart)
# If not found and looks like a form_data_key, try transient
if (
not chart
@@ -1371,6 +1380,20 @@ async def _get_chart_preview_internal( # noqa: C901
return _sanitize_chart_preview_for_llm_context(result)
except SQLAlchemyError as e:
# Catch DetachedInstanceError and other SQLAlchemy errors that can
# surface when the ORM session expires or commits mid-request.
await ctx.error(
"Chart preview failed due to database session error: "
"identifier=%s, error_type=%s, error=%s"
% (request.identifier, type(e).__name__, str(e))
)
logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
return ChartError(
error="Database session error while generating chart preview. "
"Please retry the request.",
error_type="InternalError",
)
except (
CommandException,
SupersetException,

View File

@@ -91,8 +91,24 @@ async def list_charts(
Returns chart metadata including id, name, viz_type, URL, and last
modified time.
Sortable columns for order_column: id, slice_name, viz_type, description,
changed_on, created_on
**IMPORTANT**: All parameters must be wrapped in a ``request`` object.
Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
keyword arguments — they will be rejected. Use the ``request`` wrapper::
# Correct usage
list_charts(request={"search": "revenue", "page": 1, "page_size": 10})
list_charts(request={"filters": [{"col": "slice_name", "opr": "sw", "value": "sales"}]})
list_charts() # no arguments returns first page with defaults
# Wrong — causes pydantic validation errors
list_charts(search="revenue", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``slice_name``, ``viz_type``, ``datasource_name``
Sortable columns for ``order_column``:
``id``, ``slice_name``, ``viz_type``, ``description``,
``changed_on``, ``created_on``
"""
request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
await ctx.info(

View File

@@ -176,8 +176,9 @@ class DashboardFilter(ColumnOperator):
] = Field(
...,
description=(
"Column to filter on. Use "
"get_schema(model_type='dashboard') for available filter columns."
"Column to filter on. Valid values: 'dashboard_title', 'published', "
"'favorite'. Other column names are not valid filter columns and will "
"cause a validation error."
),
)
opr: ColumnOperatorEnum = Field(

View File

@@ -85,8 +85,24 @@ async def list_dashboards(
including title, slug, URL, and last modified time. Use select_columns to
request additional fields.
Sortable columns for order_column: id, dashboard_title, slug, published,
changed_on, created_on
**IMPORTANT**: All parameters must be wrapped in a ``request`` object.
Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
keyword arguments — they will be rejected. Use the ``request`` wrapper::
# Correct usage
list_dashboards(request={"search": "sales", "page": 1, "page_size": 10})
list_dashboards(request={"filters": [{"col": "dashboard_title", "opr": "sw", "value": "exec"}]})
list_dashboards() # no arguments returns first page with defaults
# Wrong — causes pydantic validation errors
list_dashboards(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``dashboard_title``, ``published``, ``favorite``
Sortable columns for ``order_column``:
``id``, ``dashboard_title``, ``slug``, ``published``,
``changed_on``, ``created_on``
"""
request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
await ctx.info(

View File

@@ -71,8 +71,11 @@ class DatasetFilter(ColumnOperator):
"database_name",
] = Field(
...,
description="Column to filter on. Use get_schema(model_type='dataset') for "
"available filter columns.",
description=(
"Column to filter on. Valid values: 'table_name', 'schema', "
"'database_name'. Other column names (e.g. 'created_by_fk', 'id') "
"are not valid filter columns and will cause a validation error."
),
)
opr: ColumnOperatorEnum = Field(
...,

View File

@@ -96,8 +96,23 @@ async def list_datasets(
Returns dataset metadata including table name, schema, and last modified
time.
Sortable columns for order_column: id, table_name, schema, changed_on,
created_on
**IMPORTANT**: All parameters must be wrapped in a ``request`` object.
Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
keyword arguments — they will be rejected. Use the ``request`` wrapper::
# Correct usage
list_datasets(request={"search": "sales", "page": 1, "page_size": 10})
list_datasets(request={"filters": [{"col": "table_name", "opr": "sw", "value": "orders"}]})
list_datasets() # no arguments returns first page with defaults
# Wrong — causes pydantic validation errors
list_datasets(search="sales", page=1) # DO NOT DO THIS
Valid filter columns for ``filters[].col``:
``table_name``, ``schema``, ``database_name``
Sortable columns for ``order_column``:
``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on``
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_datasets")

View File

@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
DEFAULT_TOKEN_LIMIT,
DEFAULT_WARN_THRESHOLD_PCT,
)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
INFO_TOOLS,
truncate_oversized_response,
)
from superset.utils.core import get_user_id
logger = logging.getLogger(__name__)
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
``content[0].text`` as a JSON string. We parse that string, run the
truncation phases on the resulting dict, then re-wrap the result.
"""
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
truncate_oversized_response,
)
# Unwrap ToolResult so truncation operates on the real payload
extracted = self._extract_payload_from_tool_result(response)
if extracted is not None:
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
# Execute the tool
response = await call_next(context)
# Estimate response token count (guard against huge responses causing OOM)
from superset.mcp_service.utils.token_utils import (
estimate_response_tokens,
format_size_limit_error,
)
# When the response is a ToolResult, estimate tokens on the actual
# payload inside content[0].text rather than on the ToolResult
# wrapper (which would double-serialize the JSON string).
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
params = getattr(context.message, "params", {}) or {}
# For info tools, try dynamic truncation before blocking
from superset.mcp_service.utils.token_utils import INFO_TOOLS
if tool_name in INFO_TOOLS:
truncated = self._try_truncate_info_response(
tool_name, response, estimated_tokens

View File

@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
This module provides utilities to estimate token counts and generate smart
suggestions when responses exceed configured limits. This prevents large
responses from overwhelming LLM clients like Claude Desktop.
Token counting strategy:
1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
installed (it is shipped as part of the ``fastmcp`` extra). This is a
real BPE tokenizer trained on a similar vocabulary to Claude's; for
English and JSON-heavy MCP payloads it tracks Claude's tokenizer
within roughly ±10%, which is far more accurate than the legacy
character heuristic.
2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
importable. The fallback uses a slightly more conservative ratio than
before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
are not under-counted, which previously let oversized payloads slip
past the response-size guard.
The exact Claude tokenizer is only available through Anthropic's
network ``count_tokens`` API; calling it from a synchronous middleware
on every tool result is too slow and would add an external dependency
on every response. ``tiktoken`` is the closest approximation we can
ship without that risk.
"""
from __future__ import annotations
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
# Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]
# Approximate characters per token for estimation
# Claude tokenizer averages ~4 chars per token for English text
# JSON tends to be more verbose, so we use a slightly lower ratio
CHARS_PER_TOKEN = 3.5
# Fallback character-to-token ratio used when tiktoken is unavailable.
# 3.0 is conservative for JSON content (the previous 3.5 under-counted
# JSON-heavy payloads relative to Claude's actual tokenizer, which let
# oversized responses slip past the response-size guard).
CHARS_PER_TOKEN = 3.0
# Encoding used when tiktoken is available. cl100k_base is OpenAI's
# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
# Claude's and tracks Claude's token counts within roughly ±10% for
# English and JSON-heavy MCP responses.
_TIKTOKEN_ENCODING_NAME = "cl100k_base"
def _load_tiktoken_encoding() -> Any:
"""Return a tiktoken encoding instance, or None if tiktoken is unavailable.
Imported lazily so the module can be used in environments without
tiktoken installed. The encoding is small (~1 MB) so we cache it on
first use.
"""
try:
import tiktoken
except ImportError:
logger.info(
"tiktoken not installed; falling back to char-based token "
"estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
"for accurate counts.",
CHARS_PER_TOKEN,
)
return None
try:
return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
except (KeyError, ValueError) as exc:
# tiktoken installed but the requested encoding is missing — this
# only happens on partial installs. Treat as no tokenizer rather
# than crashing on every tool call.
logger.warning(
"tiktoken encoding '%s' unavailable: %s; falling back to "
"char-based token estimation",
_TIKTOKEN_ENCODING_NAME,
exc,
)
return None
# Cached encoding instance (None if tiktoken not importable).
_ENCODING = _load_tiktoken_encoding()
def estimate_token_count(text: str | bytes) -> int:
"""
Estimate the token count for a given text.
Uses a character-based heuristic since we don't have direct access to
the actual tokenizer. This is conservative to avoid underestimating.
Uses tiktoken's ``cl100k_base`` encoding when available for
Claude-aligned accuracy (within ~10%), falling back to a
character-based heuristic otherwise.
Args:
text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
if isinstance(text, bytes):
text = text.decode("utf-8", errors="replace")
# Simple heuristic: ~3.5 characters per token for JSON/code
text_length = len(text)
if text_length == 0:
if not text:
return 0
return max(1, int(text_length / CHARS_PER_TOKEN))
if _ENCODING is not None:
try:
return len(_ENCODING.encode(text))
except (ValueError, UnicodeError) as exc:
# Defensive: if tiktoken chokes on a specific input, fall
# back to the char heuristic for this call rather than
# raising — the guard must always produce an estimate so an
# oversized response can never slip through unguarded (fail-open).
logger.warning("tiktoken encode failed (%s); using fallback", exc)
return max(1, int(len(text) / CHARS_PER_TOKEN))
def estimate_response_tokens(response: ToolResponse) -> int:

View File

@@ -45,6 +45,13 @@
color: #000;
}
{% endif %}
{% if standalone_mode %}
/* Keep body sized so screenshot waits don't see it as hidden before React mounts. */
html, body.standalone {
min-height: 100vh;
margin: 0;
}
{% endif %}
</style>
{% if dark_theme_bg and entry != 'embedded' %}

View File

@@ -68,3 +68,50 @@ def test_spa_template_includes_css_bundles():
"spa.html must call css_bundle for the page entry to load "
"entry-specific extracted CSS in production builds"
)
def test_spa_template_standalone_body_has_min_height():
"""Standalone body must be measurable so screenshot waits don't time out."""
from jinja2 import DictLoader, Environment
template_path = join(SUPERSET_DIR, "templates", "superset", "spa.html")
with open(template_path) as f:
template_content = f.read()
env = Environment( # noqa: S701
loader=DictLoader(
{
"spa.html": template_content,
# Stub out includes/imports that are not relevant for this test.
"appbuilder/general/lib.html": "",
"superset/partials/asset_bundle.html": (
"{% macro css_bundle(prefix, entry) %}{% endmacro %}"
"{% macro js_bundle(prefix, entry) %}{% endmacro %}"
),
"superset/macros.html": ("{% macro get_nonce() %}{% endmacro %}"),
"tail_js_custom_extra.html": "",
"head_custom_extra.html": "",
}
)
)
appbuilder = Mock()
appbuilder.app.config = {"FAVICONS": []}
def render(standalone_mode: bool) -> str:
return env.get_template("spa.html").render(
appbuilder=appbuilder,
assets_prefix="",
bootstrap_data="{}",
entry="spa",
standalone_mode=standalone_mode,
theme_tokens={},
spinner_svg=None,
)
standalone_html = render(standalone_mode=True)
assert "body.standalone" in standalone_html
assert "min-height: 100vh" in standalone_html
non_standalone_html = render(standalone_mode=False)
assert "body.standalone" not in non_standalone_html

View File

@@ -595,3 +595,191 @@ Market Share
"""
# These demonstrate the expected ASCII formats for different chart types
class TestDetachedInstanceError:
"""Tests that DetachedInstanceError is handled gracefully.
When the SQLAlchemy session commits mid-request, ORM objects expire and
become detached. Accessing lazy attributes on a detached Slice raises
DetachedInstanceError. The tool must:
1. Call db.session.refresh() immediately after loading the chart so all
column values are loaded upfront before any downstream operation.
2. Catch SQLAlchemyError (the base class) and return a ChartError
instead of propagating the exception.
"""
@pytest.mark.asyncio
async def test_session_refresh_called_after_chart_load(self):
"""db.session.refresh() is invoked right after find_chart_by_identifier."""
import importlib
from contextlib import nullcontext
from unittest.mock import MagicMock, patch
from superset.mcp_service.chart.schemas import URLPreview
from superset.utils import json
get_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.get_chart_preview"
)
mock_chart = MagicMock()
mock_chart.id = 42
mock_chart.slice_name = "Sales Chart"
mock_chart.viz_type = "table"
mock_chart.datasource_id = 1
mock_chart.datasource_type = "table"
mock_chart.params = "{}"
refresh_calls: list[object] = []
def _fake_refresh(obj: object) -> None:
refresh_calls.append(obj)
url_preview = URLPreview(
preview_url="http://localhost/explore/?slice_id=42",
width=800,
height=600,
)
with (
patch.object(
get_chart_preview_module,
"find_chart_by_identifier",
return_value=mock_chart,
),
patch.object(
get_chart_preview_module.db,
"session",
**{"refresh.side_effect": _fake_refresh},
),
patch.object(
get_chart_preview_module,
"validate_chart_dataset",
return_value=MagicMock(is_valid=True, warnings=[]),
),
patch.object(
get_chart_preview_module.event_logger,
"log_context",
return_value=nullcontext(),
),
# Return a real URLPreview so Pydantic model validation succeeds
patch.object(
get_chart_preview_module.PreviewFormatGenerator,
"generate",
return_value=url_preview,
),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost",
),
):
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import GetChartPreviewRequest
with patch("superset.mcp_service.auth.get_user_from_request") as mu:
mu.return_value = MagicMock(id=1, username="admin")
with patch(
"superset.mcp_service.auth.check_tool_permission", return_value=True
):
async with Client(mcp) as client:
response = await client.call_tool(
"get_chart_preview",
{
"request": GetChartPreviewRequest(
identifier=42, format="url"
).model_dump()
},
)
data = json.loads(response.content[0].text)
# The tool should succeed — not return a ChartError
assert "error_type" not in data, (
f"Expected ChartPreview but got ChartError: {data.get('error')}"
)
assert data.get("chart_id") == 42
assert len(refresh_calls) == 1, (
"db.session.refresh() should be called once after loading the chart"
)
assert refresh_calls[0] is mock_chart
    @pytest.mark.asyncio
    async def test_detached_instance_error_returns_chart_error(self) -> None:
        """DetachedInstanceError during preview generation returns ChartError.

        Chart lookup and dataset validation are mocked to succeed, then
        ``PreviewFormatGenerator.generate`` is made to raise
        ``DetachedInstanceError`` — simulating the SQLAlchemy session
        expiring mid-request.  The tool must surface this as an
        ``InternalError`` payload (with a message mentioning the session or
        a retry) rather than propagate the exception.
        """
        import importlib
        from contextlib import nullcontext
        from unittest.mock import MagicMock, patch

        from sqlalchemy.orm.exc import DetachedInstanceError

        get_chart_preview_module = importlib.import_module(
            "superset.mcp_service.chart.tool.get_chart_preview"
        )
        # Minimal chart stub carrying the attributes the tool reads.
        mock_chart = MagicMock()
        mock_chart.id = 7
        mock_chart.slice_name = "Broken Chart"
        mock_chart.viz_type = "bar"
        mock_chart.datasource_id = 3
        mock_chart.datasource_type = "table"
        mock_chart.params = "{}"
        with (
            patch.object(
                get_chart_preview_module,
                "find_chart_by_identifier",
                return_value=mock_chart,
            ),
            # db.session.refresh() becomes a no-op for this test.
            patch.object(
                get_chart_preview_module.db,
                "session",
                **{"refresh.return_value": None},
            ),
            patch.object(
                get_chart_preview_module,
                "validate_chart_dataset",
                return_value=MagicMock(is_valid=True, warnings=[]),
            ),
            patch.object(
                get_chart_preview_module.event_logger,
                "log_context",
                return_value=nullcontext(),
            ),
            # Simulate the session expiring inside the strategy
            patch.object(
                get_chart_preview_module.PreviewFormatGenerator,
                "generate",
                side_effect=DetachedInstanceError(),
            ),
            patch(
                "superset.mcp_service.utils.url_utils.get_superset_base_url",
                return_value="http://localhost",
            ),
        ):
            from fastmcp import Client
            from superset.mcp_service.app import mcp
            from superset.mcp_service.chart.schemas import GetChartPreviewRequest
            from superset.utils import json

            with patch("superset.mcp_service.auth.get_user_from_request") as mu:
                mu.return_value = MagicMock(id=1, username="admin")
                with patch(
                    "superset.mcp_service.auth.check_tool_permission", return_value=True
                ):
                    async with Client(mcp) as client:
                        response = await client.call_tool(
                            "get_chart_preview",
                            {
                                "request": GetChartPreviewRequest(
                                    identifier=7, format="ascii"
                                ).model_dump()
                            },
                        )
                        data = json.loads(response.content[0].text)
                        # The failure must be reported as a structured error,
                        # not raised out of the tool call.
                        assert data["error_type"] == "InternalError"
                        assert "session" in data["error"].lower() or "retry" in data["error"].lower()

View File

@@ -1349,7 +1349,8 @@ class TestDashboardSortableColumns:
# Check list_dashboards docstring for sortable columns documentation
assert list_dashboards.__doc__ is not None
assert "Sortable columns for order_column:" in list_dashboards.__doc__
assert "Sortable columns for" in list_dashboards.__doc__
assert "order_column" in list_dashboards.__doc__
for col in SORTABLE_DASHBOARD_COLUMNS:
assert col in list_dashboards.__doc__

View File

@@ -1700,7 +1700,8 @@ class TestDatasetSortableColumns:
# Check list_datasets docstring for sortable columns documentation
assert list_datasets.__doc__ is not None
assert "Sortable columns for order_column:" in list_datasets.__doc__
assert "Sortable columns for" in list_datasets.__doc__
assert "order_column" in list_datasets.__doc__
for col in SORTABLE_DATASET_COLUMNS:
assert col in list_datasets.__doc__
@@ -2080,3 +2081,90 @@ class TestListDatasetsOwnedByMe:
request = ListDatasetsRequest(owned_by_me=True, created_by_me=True)
assert request.owned_by_me is True
assert request.created_by_me is True
class TestListDatasetsRequestWrapper:
    """Verify that ``list_datasets`` parameters must be nested in ``request``.

    A common LLM failure mode is calling the tool with flat kwargs
    (``search=...``, ``page=...``, ``page_size=...``) instead of
    ``request={...}``.  The tests below pin the correct call shape at both
    the Pydantic-schema level and the live MCP tool layer, and confirm that
    internal column names such as ``created_by_fk`` are rejected by the
    filter schema.
    """

    def test_request_wrapper_with_search(self) -> None:
        """The schema accepts parameters nested inside request=."""
        req = ListDatasetsRequest(search="sales", page=1, page_size=10)
        assert req.search == "sales"
        assert (req.page, req.page_size) == (1, 10)

    def test_request_wrapper_defaults(self) -> None:
        """Constructing with no arguments yields valid schema defaults."""
        req = ListDatasetsRequest()
        assert req.search is None
        assert req.page == 1
        assert req.filters == []

    def test_dataset_filter_valid_col(self) -> None:
        """Every allowed col value is accepted by DatasetFilter."""
        for allowed in ("table_name", "schema", "database_name"):
            built = DatasetFilter(col=allowed, opr="sw", value="test")
            assert built.col == allowed

    def test_dataset_filter_invalid_col_raises(self) -> None:
        """Column names outside the Literal fail with a validation error.

        Guards against LLMs supplying ``created_by_fk`` or similar internal
        column names that are not exposed as filter fields.
        """
        from pydantic import ValidationError

        for rejected in ("created_by_fk", "id", "database_id", "owner"):
            with pytest.raises(ValidationError):
                DatasetFilter(col=rejected, opr="eq", value="1")

    @patch("superset.daos.dataset.DatasetDAO.list")
    @pytest.mark.asyncio
    async def test_request_wrapper_enforced_by_tool(
        self, mock_list, mcp_server
    ) -> None:
        """The live MCP tool accepts the request wrapper and returns results.

        Exercises the wrapper shape end-to-end through the actual FastMCP
        call path, not just schema instantiation.
        """
        mock_list.return_value = ([], 0)
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "list_datasets",
                {"request": {"search": "sales", "page": 1, "page_size": 5}},
            )
        payload = json.loads(result.content[0].text)
        assert payload["count"] == 0
        assert payload["datasets"] == []

    @pytest.mark.asyncio
    async def test_flat_kwargs_rejected(self, mcp_server) -> None:
        """Flat top-level kwargs raise a ToolError naming the bad arguments.

        This is the exact failure pattern from story #105712: LLMs call
        ``list_datasets(search=..., page=..., page_size=...)`` instead of
        ``list_datasets(request={...})``.
        """
        with pytest.raises(ToolError) as exc_info:
            async with Client(mcp_server) as client:
                await client.call_tool(
                    "list_datasets",
                    {"search": "sales", "page": 1, "page_size": 10},
                )
        message = str(exc_info.value)
        # The error must call out the unexpected arguments, not be some
        # unrelated failure.
        assert "search" in message or "Unexpected" in message or "request" in message

View File

@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:
@pytest.mark.asyncio
async def test_logs_warning_at_threshold(self) -> None:
"""Should log warning when approaching limit."""
"""Should log warning when approaching limit.
Mocks the token estimator to return a specific value above the
warn threshold but below the hard limit, decoupling the test
from whichever tokenizer (tiktoken or char heuristic) happens
to be loaded.
"""
middleware = ResponseSizeGuardMiddleware(
token_limit=1000, warn_threshold_pct=80
)
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
context.message.name = "list_charts"
context.message.params = {}
# Response at ~85% of limit (should trigger warning but not block)
response = {"data": "x" * 2900} # ~828 tokens at 3.5 chars/token
response = {"data": "approaching the limit"}
call_next = AsyncMock(return_value=response)
with (
patch("superset.mcp_service.middleware.get_user_id", return_value=1),
patch("superset.mcp_service.middleware.event_logger"),
patch(
"superset.mcp_service.middleware.estimate_response_tokens",
return_value=850,
),
patch("superset.mcp_service.middleware.logger") as mock_logger,
):
result = await middleware.on_call_tool(context, call_next)
# Should return response (not blocked)
# Should return response (not blocked at 85% of limit)
assert result == response
# Should log warning
mock_logger.warning.assert_called()

View File

@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
"""
from typing import Any, List
from unittest.mock import patch
from pydantic import BaseModel
from superset.mcp_service.utils import token_utils
from superset.mcp_service.utils.token_utils import (
_replace_collections_with_summaries,
_summarize_large_dicts,
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
"""Test estimate_token_count function."""
def test_estimate_string(self) -> None:
"""Should estimate tokens for a string."""
"""Should produce a positive non-zero estimate for a normal string.
We don't assert on a specific number because the result depends on
which tokenizer is loaded (tiktoken when available, char heuristic
otherwise).
"""
text = "Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
assert result > 0
def test_estimate_bytes(self) -> None:
"""Should estimate tokens for bytes."""
text = b"Hello world"
result = estimate_token_count(text)
expected = int(len(text) / CHARS_PER_TOKEN)
assert result == expected
"""Bytes input should be decoded and produce the same count as the
equivalent string."""
text = "Hello world"
assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)
def test_empty_string(self) -> None:
"""Should return 0 for empty string."""
"""Should return 0 for empty string and empty bytes."""
assert estimate_token_count("") == 0
assert estimate_token_count(b"") == 0
def test_json_like_content(self) -> None:
"""Should estimate tokens for JSON-like content."""
"""JSON content should produce a positive estimate."""
json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
result = estimate_token_count(json_str)
assert result > 0
assert result == int(len(json_str) / CHARS_PER_TOKEN)
assert estimate_token_count(json_str) > 0
def test_long_text_roughly_scales_with_length(self) -> None:
"""A doubled string should produce roughly double the token count
(within ±10%)."""
small = "the quick brown fox jumps over the lazy dog. " * 20
large = small * 2
small_n = estimate_token_count(small)
large_n = estimate_token_count(large)
# Within 10% of 2x — both tokenizers (tiktoken and the char
# fallback) preserve length monotonicity.
assert 1.8 * small_n <= large_n <= 2.2 * small_n
def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
self,
) -> None:
"""When the tiktoken encoding is None (not installed), the
function falls back to len/CHARS_PER_TOKEN math."""
text = "x" * 100
with patch.object(token_utils, "_ENCODING", None):
result = estimate_token_count(text)
assert result == int(100 / CHARS_PER_TOKEN)
def test_fallback_when_tiktoken_encode_raises(self) -> None:
"""A misbehaving encoding should fall back to the char heuristic
rather than raise — the size guard must never fail-open."""
class BoomEncoding:
def encode(self, text: str) -> list[int]:
raise ValueError("simulated tiktoken failure")
text = "abc" * 50
with patch.object(token_utils, "_ENCODING", BoomEncoding()):
result = estimate_token_count(text)
assert result == int(len(text) / CHARS_PER_TOKEN)
class TestEstimateResponseTokens: