Mirror of https://github.com/apache/superset.git (synced 2026-05-13 20:05:20 +00:00)

Compare commits: fix/mcp-ex ... fix/mcp-li

12 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 18e8217498 | |
| | 32234ea0bb | |
| | a2b0f64176 | |
| | 4d66dc0774 | |
| | 5542a2f3b1 | |
| | 8c80caefa3 | |
| | 8088c5d1de | |
| | 9b520312a1 | |
| | 9ac4711ac8 | |
| | 7593d2a164 | |
| | d3c44e311e | |
| | b5186d1c65 | |
.github/workflows/ephemeral-env-pr-close.yml (vendored): 2 changes
```diff
@@ -58,7 +58,7 @@ jobs:
       - name: Login to Amazon ECR
         if: steps.describe-services.outputs.active == 'true'
         id: login-ecr
-        uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
+        uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

       - name: Delete ECR image tag
         if: steps.describe-services.outputs.active == 'true'
```
.github/workflows/ephemeral-env.yml (vendored): 4 changes
```diff
@@ -199,7 +199,7 @@ jobs:

       - name: Login to Amazon ECR
         id: login-ecr
-        uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
+        uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

       - name: Load, tag and push image to ECR
         id: push-image
@@ -235,7 +235,7 @@ jobs:

       - name: Login to Amazon ECR
         id: login-ecr
-        uses: aws-actions/amazon-ecr-login@19d944daaa35f0fa1d3f7f8af1d3f2e5de25c5b7 # v2
+        uses: aws-actions/amazon-ecr-login@fa648b43de3d4d023bcb3f89ed6940096949c419 # v2

       - name: Check target image exists in ECR
         id: check-image
```
```diff
@@ -70,7 +70,7 @@
     "@swc/core": "^1.15.33",
     "antd": "^6.3.7",
     "baseline-browser-mapping": "^2.10.27",
-    "caniuse-lite": "^1.0.30001791",
+    "caniuse-lite": "^1.0.30001792",
     "docusaurus-plugin-openapi-docs": "^5.0.2",
     "docusaurus-theme-openapi-docs": "^5.0.2",
     "js-yaml": "^4.1.1",
```
```diff
@@ -6035,10 +6035,10 @@ caniuse-api@^3.0.0:
     lodash.memoize "^4.1.2"
     lodash.uniq "^4.5.0"

-caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001791:
-  version "1.0.30001791"
-  resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001791.tgz#dfb93d85c40ad380c57123e72e10f3c575786b51"
-  integrity sha512-yk0l/YSrOnFZk3UROpDLQD9+kC1l4meK/wed583AXrzoarMGJcbRi2Q4RaUYbKxYAsZ8sWmaSa/DsLmdBeI1vQ==
+caniuse-lite@^1.0.0, caniuse-lite@^1.0.30001702, caniuse-lite@^1.0.30001759, caniuse-lite@^1.0.30001792:
+  version "1.0.30001792"
+  resolved "https://registry.yarnpkg.com/caniuse-lite/-/caniuse-lite-1.0.30001792.tgz#ca8bb9be244835a335e2018272ce7223691873c5"
+  integrity sha512-hVLMUZFgR4JJ6ACt1uEESvQN1/dBVqPAKY0hgrV70eN3391K6juAfTjKZLKvOMsx8PxA7gsY1/tLMMTcfFLLpw==

 ccount@^2.0.0:
   version "2.0.1"
```
```diff
@@ -145,7 +145,13 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
 elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
 exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
 excel = ["xlrd>=1.2.0, <1.3"]
-fastmcp = ["fastmcp>=3.2.4,<4.0"]
+fastmcp = [
+    "fastmcp>=3.2.4,<4.0",
+    # tiktoken backs the response-size-guard token estimator. Without
+    # it, the middleware falls back to a coarser character-based
+    # heuristic that under-counts JSON-heavy MCP responses.
+    "tiktoken>=0.7.0,<1.0",
+]
 firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
 firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
 gevent = ["gevent>=23.9.1"]
```
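To make the dependency note above concrete, here is a small illustrative script (not part of the change; the payload and numbers are made up) showing how a flat chars-per-token heuristic compares with a real BPE count on a JSON-heavy payload:

```python
# Illustrative only: compares the legacy 3.5 chars/token heuristic with a
# real BPE count on a JSON-heavy payload. Assumes tiktoken is installed,
# as the fastmcp extra above now guarantees.
import json

import tiktoken

payload = json.dumps(
    {"charts": [{"id": i, "slice_name": f"chart-{i}"} for i in range(50)]}
)

bpe_count = len(tiktoken.get_encoding("cl100k_base").encode(payload))
heuristic = int(len(payload) / 3.5)  # the pre-change estimator

# JSON punctuation tokenizes densely, so the heuristic typically lands
# well below the BPE count: exactly the under-count described above.
print(f"tiktoken: {bpe_count}, heuristic: {heuristic}")
```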
```diff
@@ -377,6 +383,7 @@ dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$"

 [tool.ruff.lint.per-file-ignores]
 "superset/mcp_service/app.py" = ["S608", "E501"]  # LLM instruction text: SQL examples (S608) and long lines in multiline string (E501)
+"superset/mcp_service/*/tool/list_*.py" = ["E501"]  # LLM docstring examples show full request shapes which exceed line length
 "scripts/*" = ["TID251"]
 "setup.py" = ["TID251"]
 "superset/config.py" = ["TID251"]
```
```diff
@@ -183,7 +183,9 @@ idna==3.10
     # trio
     # url-normalize
 isodate==0.7.2
-    # via apache-superset (pyproject.toml)
+    # via
+    #   apache-superset (pyproject.toml)
+    #   apache-superset-core
 itsdangerous==2.2.0
     # via
     # flask
@@ -296,6 +298,7 @@ pyarrow==20.0.0
     # via
     # -r requirements/base.in
     # apache-superset (pyproject.toml)
+    # apache-superset-core
 pyasn1==0.6.3
     # via
     # pyasn1-modules
```
```diff
@@ -442,6 +442,7 @@ isodate==0.7.2
     # via
     # -c requirements/base-constraint.txt
     # apache-superset
+    # apache-superset-core
 isort==6.0.1
     # via pylint
 itsdangerous==2.2.0
@@ -715,6 +716,7 @@ pyarrow==20.0.0
     # via
     # -c requirements/base-constraint.txt
     # apache-superset
+    # apache-superset-core
     # db-dtypes
     # pandas-gbq
 pyasn1==0.6.3
@@ -866,6 +868,8 @@ referencing==0.36.2
     # jsonschema
     # jsonschema-path
     # jsonschema-specifications
+regex==2026.4.4
+    # via tiktoken
 requests==2.33.0
     # via
     # -c requirements/base-constraint.txt
@@ -878,6 +882,7 @@ requests==2.33.0
     # requests-cache
     # requests-oauthlib
     # shillelagh
+    # tiktoken
     # trino
 requests-cache==1.2.1
     # via
@@ -1003,6 +1008,8 @@ tabulate==0.9.0
     # via
     # -c requirements/base-constraint.txt
     # apache-superset
+tiktoken==0.12.0
+    # via apache-superset
 tomli-w==1.2.0
     # via apache-superset-extensions-cli
 tomlkit==0.13.3
```
```diff
@@ -17,7 +17,8 @@
  * under the License.
  */
 import { render, screen, act } from 'spec/helpers/testing-library';
-import { StatusIndicatorDot } from './StatusIndicatorDot';
+import { supersetTheme } from '@apache-superset/core/theme';
+import { getStatusConfig, StatusIndicatorDot } from './StatusIndicatorDot';
 import { AutoRefreshStatus } from '../../types/autoRefresh';

 afterEach(() => {
@@ -62,6 +63,15 @@ test('renders with paused status', () => {
   expect(dot).toHaveAttribute('data-status', AutoRefreshStatus.Paused);
 });

+test('uses the icon color for the paused status outline', () => {
+  expect(
+    getStatusConfig(supersetTheme, AutoRefreshStatus.Paused),
+  ).toMatchObject({
+    needsBorder: true,
+    outlineColor: 'currentColor',
+  });
+});
+
 test('has correct accessibility attributes', () => {
   render(<StatusIndicatorDot status={AutoRefreshStatus.Success} />);
   const dot = screen.getByTestId('status-indicator-dot');
```
```diff
@@ -39,9 +39,10 @@ export interface StatusIndicatorDotProps {
 interface StatusConfig {
   color: string;
   needsBorder: boolean;
+  outlineColor?: string;
 }

-const getStatusConfig = (
+export const getStatusConfig = (
   theme: ReturnType<typeof useTheme>,
   status: AutoRefreshStatus,
 ): StatusConfig => {
@@ -75,6 +76,7 @@ const getStatusConfig = (
       return {
         color: theme.colorBgContainer,
         needsBorder: true,
+        outlineColor: 'currentColor',
       };
     default:
       return {
@@ -136,13 +138,15 @@ export const StatusIndicatorDot: FC<StatusIndicatorDotProps> = ({
     width: ${size}px;
     height: ${size}px;
     border-radius: 50%;
+    color: ${theme.colorTextSecondary};
     background-color: ${statusConfig.color};
     transition:
       background-color ${theme.motionDurationMid} ease-in-out,
       border-color ${theme.motionDurationMid} ease-in-out;
-    border: ${statusConfig.needsBorder
-      ? `1px solid ${theme.colorBorder}`
-      : 'none'};
+    border: ${statusConfig.needsBorder ? '1px solid' : 'none'};
+    border-color: ${statusConfig.needsBorder
+      ? statusConfig.outlineColor
+      : 'transparent'};
     box-shadow: ${statusConfig.needsBorder
       ? 'none'
       : `0 0 0 2px ${theme.colorBgContainer}`};
```
```diff
@@ -21,6 +21,10 @@ import { VizType } from '@superset-ui/core';
 import { hydrateExplore, HYDRATE_EXPLORE } from './hydrateExplore';
 import { exploreInitialData } from '../fixtures';

+afterEach(() => {
+  window.history.pushState({}, '', '/');
+});
+
 test('creates hydrate action from initial data', () => {
   const dispatch = jest.fn();
   const getState = jest.fn(() => ({
@@ -168,6 +172,84 @@ test('creates hydrate action with existing state', () => {
   );
 });

+test('hydrates sliceName from preview form data before saved slice name', () => {
+  window.history.pushState({}, '', '/explore/?form_data_key=preview-key');
+
+  const dispatch = jest.fn();
+  const getState = jest.fn(() => ({
+    user: {},
+    charts: {},
+    datasources: {},
+    common: {},
+    explore: {},
+  }));
+  const previewSliceName = 'RENAMED - Bug Evidence';
+  const savedSliceName = 'Most Populated Countries';
+  const previewInitialData = {
+    ...exploreInitialData,
+    form_data: {
+      ...exploreInitialData.form_data,
+      slice_name: previewSliceName,
+    },
+    slice: {
+      ...exploreInitialData.slice!,
+      slice_name: savedSliceName,
+    },
+  };
+
+  // @ts-expect-error we only need the fields consumed by hydrateExplore
+  hydrateExplore(previewInitialData)(dispatch, getState);
+
+  expect(dispatch).toHaveBeenCalledWith(
+    expect.objectContaining({
+      type: HYDRATE_EXPLORE,
+      data: expect.objectContaining({
+        explore: expect.objectContaining({
+          sliceName: previewSliceName,
+        }),
+      }),
+    }),
+  );
+});
+
+test('hydrates sliceName from saved slice when regular form data has stale name', () => {
+  const dispatch = jest.fn();
+  const getState = jest.fn(() => ({
+    user: {},
+    charts: {},
+    datasources: {},
+    common: {},
+    explore: {},
+  }));
+  const staleFormDataSliceName = 'Stale Params Name';
+  const savedSliceName = 'Current Saved Name';
+  const savedChartInitialData = {
+    ...exploreInitialData,
+    form_data: {
+      ...exploreInitialData.form_data,
+      slice_name: staleFormDataSliceName,
+    },
+    slice: {
+      ...exploreInitialData.slice!,
+      slice_name: savedSliceName,
+    },
+  };
+
+  // @ts-expect-error we only need the fields consumed by hydrateExplore
+  hydrateExplore(savedChartInitialData)(dispatch, getState);
+
+  expect(dispatch).toHaveBeenCalledWith(
+    expect.objectContaining({
+      type: HYDRATE_EXPLORE,
+      data: expect.objectContaining({
+        explore: expect.objectContaining({
+          sliceName: savedSliceName,
+        }),
+      }),
+    }),
+  );
+});
+
 test('uses configured default time range if not set', () => {
   const dispatch = jest.fn();
   const getState = jest.fn(() => ({
```
```diff
@@ -77,6 +77,12 @@ export const hydrateExplore =
     const fallbackSlice = sliceId ? sliceEntities?.slices?.[sliceId] : null;
     const initialSlice = slice ?? fallbackSlice;
     const initialFormData = form_data ?? initialSlice?.form_data;
+    const isCachedFormData = getUrlParam(URL_PARAMS.formDataKey) !== null;
+    const [primarySliceNameSource, fallbackSliceNameSource] = isCachedFormData
+      ? [initialFormData, initialSlice]
+      : [initialSlice, initialFormData];
+    const initialSliceName =
+      primarySliceNameSource?.slice_name ?? fallbackSliceNameSource?.slice_name;
     if (!initialFormData.viz_type) {
       const defaultVizType = common?.conf.DEFAULT_VIZ_TYPE || VizType.Table;
       initialFormData.viz_type =
@@ -183,6 +189,7 @@ export const hydrateExplore =
       // because `bootstrapData.controls` is undefined.
       controls: initialControls,
       form_data: initialFormData,
+      sliceName: initialSliceName,
       slice: initialSlice,
       controlsTransferred: explore.controlsTransferred,
       standalone: getUrlParam(URL_PARAMS.standalone),
```
```diff
@@ -179,6 +179,33 @@ test('renders the right footer buttons', () => {
   ).toBeInTheDocument();
 });

+test('initializes chart name from current Explore slice name', () => {
+  const previewSliceName = 'RENAMED - Bug Evidence';
+  const savedSliceName = 'Most Populated Countries';
+  const { getByTestId } = setup(
+    {
+      ...defaultProps,
+      form_data: {
+        ...defaultProps.form_data,
+        slice_name: previewSliceName,
+      },
+      sliceName: previewSliceName,
+    },
+    mockStore({
+      ...initialState,
+      explore: {
+        ...initialState.explore,
+        slice: {
+          ...initialState.explore.slice,
+          slice_name: savedSliceName,
+        },
+      },
+    }),
+  );
+
+  expect(getByTestId('new-chart-name')).toHaveValue(previewSliceName);
+});
+
 test('does not render a message when overriding', () => {
   const { getByRole, queryByRole } = setup();
```
```diff
@@ -537,8 +537,11 @@ class ChartFilter(ColumnOperator):
         "datasource_name",
     ] = Field(
         ...,
-        description="Column to filter on. Use get_schema(model_type='chart') for "
-        "available filter columns.",
+        description=(
+            "Column to filter on. Valid values: 'slice_name', 'viz_type', "
+            "'datasource_name'. Other column names are not valid filter columns "
+            "and will cause a validation error."
+        ),
     )
     opr: ColumnOperatorEnum = Field(
         ...,
```
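The tightened description matches what the Literal type on `col` already enforces. A minimal standalone sketch (simplified model, not the Superset class) of that enforcement:

```python
# Simplified sketch: a Literal-typed `col` makes pydantic reject any
# column outside the allowed set, which is what the new description warns about.
from typing import Literal

from pydantic import BaseModel, ValidationError


class ChartFilterSketch(BaseModel):
    col: Literal["slice_name", "viz_type", "datasource_name"]
    opr: str
    value: str


ChartFilterSketch(col="viz_type", opr="eq", value="table")  # accepted

try:
    ChartFilterSketch(col="created_by_fk", opr="eq", value="1")
except ValidationError as exc:
    print(exc)  # the error message lists the three permitted values
```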
```diff
@@ -28,7 +28,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations

 from superset.commands.exceptions import CommandException
 from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
-from superset.extensions import event_logger
+from superset.extensions import db, event_logger
 from superset.mcp_service.chart.ascii_charts import (
     generate_ascii_chart,
     generate_ascii_table,
```
```diff
@@ -1140,6 +1140,15 @@ async def _get_chart_preview_internal(  # noqa: C901
         )
         chart = find_chart_by_identifier(request.identifier)

+        # Eagerly refresh all attributes while the session is still
+        # active. SQLAlchemy expires object attributes after any
+        # commit; if a downstream operation commits before the strategy
+        # classes access chart attributes, a DetachedInstanceError will
+        # be raised. Calling refresh() here ensures all column values
+        # are loaded into the object's __dict__ upfront.
+        if chart is not None:
+            db.session.refresh(chart)
+
         # If not found and looks like a form_data_key, try transient
         if (
             not chart
@@ -1371,6 +1380,20 @@ async def _get_chart_preview_internal(  # noqa: C901

         return _sanitize_chart_preview_for_llm_context(result)

+    except SQLAlchemyError as e:
+        # Catch DetachedInstanceError and other SQLAlchemy errors that can
+        # surface when the ORM session expires or commits mid-request.
+        await ctx.error(
+            "Chart preview failed due to database session error: "
+            "identifier=%s, error_type=%s, error=%s"
+            % (request.identifier, type(e).__name__, str(e))
+        )
+        logger.exception("SQLAlchemy error in get_chart_preview: %s", e)
+        return ChartError(
+            error="Database session error while generating chart preview. "
+            "Please retry the request.",
+            error_type="InternalError",
+        )
     except (
         CommandException,
         SupersetException,
```
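The comment in the first hunk describes standard SQLAlchemy expire-on-commit behavior. Here is a self-contained sketch (in-memory SQLite, hypothetical minimal `Chart` model, not the Superset one) of the failure mode and the `refresh()` fix:

```python
# Sketch of the failure mode guarded against above: attributes expire on
# commit, and accessing them on a detached instance raises
# DetachedInstanceError unless refresh() repopulated them first.
from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Chart(Base):
    __tablename__ = "chart"
    id = Column(Integer, primary_key=True)
    slice_name = Column(String)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

session = Session(engine)
session.add(Chart(id=1, slice_name="Sales"))
session.commit()

chart = session.get(Chart, 1)
session.commit()         # expires chart's loaded attributes
session.refresh(chart)   # reload column values into __dict__ now
session.close()          # detach the instance

# Safe: "Sales" was eagerly reloaded before detaching. Without the
# refresh() call, this access would raise DetachedInstanceError.
print(chart.slice_name)
```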
```diff
@@ -91,8 +91,24 @@ async def list_charts(
     Returns chart metadata including id, name, viz_type, URL, and last
     modified time.

-    Sortable columns for order_column: id, slice_name, viz_type, description,
-    changed_on, created_on
+    **IMPORTANT**: All parameters must be wrapped in a ``request`` object.
+    Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
+    keyword arguments — they will be rejected. Use the ``request`` wrapper::
+
+        # Correct usage
+        list_charts(request={"search": "revenue", "page": 1, "page_size": 10})
+        list_charts(request={"filters": [{"col": "slice_name", "opr": "sw", "value": "sales"}]})
+        list_charts()  # no arguments returns first page with defaults
+
+        # Wrong — causes pydantic validation errors
+        list_charts(search="revenue", page=1)  # DO NOT DO THIS
+
+    Valid filter columns for ``filters[].col``:
+    ``slice_name``, ``viz_type``, ``datasource_name``
+
+    Sortable columns for ``order_column``:
+    ``id``, ``slice_name``, ``viz_type``, ``description``,
+    ``changed_on``, ``created_on``
     """
     request = request or _DEFAULT_LIST_CHARTS_REQUEST.model_copy(deep=True)
     await ctx.info(
```
```diff
@@ -176,8 +176,9 @@ class DashboardFilter(ColumnOperator):
     ] = Field(
         ...,
         description=(
-            "Column to filter on. Use "
-            "get_schema(model_type='dashboard') for available filter columns."
+            "Column to filter on. Valid values: 'dashboard_title', 'published', "
+            "'favorite'. Other column names are not valid filter columns and will "
+            "cause a validation error."
         ),
     )
     opr: ColumnOperatorEnum = Field(
```
```diff
@@ -85,8 +85,24 @@ async def list_dashboards(
     including title, slug, URL, and last modified time. Use select_columns to
     request additional fields.

-    Sortable columns for order_column: id, dashboard_title, slug, published,
-    changed_on, created_on
+    **IMPORTANT**: All parameters must be wrapped in a ``request`` object.
+    Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
+    keyword arguments — they will be rejected. Use the ``request`` wrapper::
+
+        # Correct usage
+        list_dashboards(request={"search": "sales", "page": 1, "page_size": 10})
+        list_dashboards(request={"filters": [{"col": "dashboard_title", "opr": "sw", "value": "exec"}]})
+        list_dashboards()  # no arguments returns first page with defaults
+
+        # Wrong — causes pydantic validation errors
+        list_dashboards(search="sales", page=1)  # DO NOT DO THIS
+
+    Valid filter columns for ``filters[].col``:
+    ``dashboard_title``, ``published``, ``favorite``
+
+    Sortable columns for ``order_column``:
+    ``id``, ``dashboard_title``, ``slug``, ``published``,
+    ``changed_on``, ``created_on``
     """
     request = request or _DEFAULT_LIST_DASHBOARDS_REQUEST.model_copy(deep=True)
     await ctx.info(
```
```diff
@@ -71,8 +71,11 @@ class DatasetFilter(ColumnOperator):
         "database_name",
     ] = Field(
         ...,
-        description="Column to filter on. Use get_schema(model_type='dataset') for "
-        "available filter columns.",
+        description=(
+            "Column to filter on. Valid values: 'table_name', 'schema', "
+            "'database_name'. Other column names (e.g. 'created_by_fk', 'id') "
+            "are not valid filter columns and will cause a validation error."
+        ),
     )
     opr: ColumnOperatorEnum = Field(
         ...,
```
```diff
@@ -96,8 +96,23 @@ async def list_datasets(
     Returns dataset metadata including table name, schema, and last modified
     time.

-    Sortable columns for order_column: id, table_name, schema, changed_on,
-    created_on
+    **IMPORTANT**: All parameters must be wrapped in a ``request`` object.
+    Do NOT pass ``search``, ``page``, ``page_size``, etc. as top-level
+    keyword arguments — they will be rejected. Use the ``request`` wrapper::
+
+        # Correct usage
+        list_datasets(request={"search": "sales", "page": 1, "page_size": 10})
+        list_datasets(request={"filters": [{"col": "table_name", "opr": "sw", "value": "orders"}]})
+        list_datasets()  # no arguments returns first page with defaults
+
+        # Wrong — causes pydantic validation errors
+        list_datasets(search="sales", page=1)  # DO NOT DO THIS
+
+    Valid filter columns for ``filters[].col``:
+    ``table_name``, ``schema``, ``database_name``
+
+    Sortable columns for ``order_column``:
+    ``id``, ``table_name``, ``schema``, ``changed_on``, ``created_on``
     """
     if ctx is None:
         raise RuntimeError("FastMCP context is required for list_datasets")
```
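The mechanism behind the docstring warning is ordinary Python call semantics: each tool exposes a single `request` parameter, so flat kwargs are unexpected keyword arguments before validation even starts. A minimal standalone sketch (hypothetical names, simplified schema, not the Superset code):

```python
# Hypothetical, simplified sketch of the request-wrapper call shape.
from typing import Optional

from pydantic import BaseModel


class ListRequestSketch(BaseModel):
    search: Optional[str] = None
    page: int = 1
    page_size: int = 20


def list_datasets_sketch(request: Optional[ListRequestSketch] = None) -> dict:
    request = request or ListRequestSketch()
    return request.model_dump()


list_datasets_sketch(request=ListRequestSketch(search="sales"))  # accepted

try:
    list_datasets_sketch(search="sales", page=1)  # type: ignore[call-arg]
except TypeError as exc:
    print(exc)  # "unexpected keyword argument 'search'", mirroring the ToolError
```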
```diff
@@ -41,6 +41,12 @@ from superset.mcp_service.constants import (
     DEFAULT_TOKEN_LIMIT,
     DEFAULT_WARN_THRESHOLD_PCT,
 )
+from superset.mcp_service.utils.token_utils import (
+    estimate_response_tokens,
+    format_size_limit_error,
+    INFO_TOOLS,
+    truncate_oversized_response,
+)
 from superset.utils.core import get_user_id

 logger = logging.getLogger(__name__)
@@ -1104,11 +1110,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         ``content[0].text`` as a JSON string. We parse that string, run the
         truncation phases on the resulting dict, then re-wrap the result.
         """
-        from superset.mcp_service.utils.token_utils import (
-            estimate_response_tokens,
-            truncate_oversized_response,
-        )
-
         # Unwrap ToolResult so truncation operates on the real payload
         extracted = self._extract_payload_from_tool_result(response)
         if extracted is not None:
@@ -1191,12 +1192,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         # Execute the tool
         response = await call_next(context)

-        # Estimate response token count (guard against huge responses causing OOM)
-        from superset.mcp_service.utils.token_utils import (
-            estimate_response_tokens,
-            format_size_limit_error,
-        )
-
         # When the response is a ToolResult, estimate tokens on the actual
         # payload inside content[0].text rather than on the ToolResult
         # wrapper (which would double-serialize the JSON string).
@@ -1233,8 +1228,6 @@ class ResponseSizeGuardMiddleware(Middleware):
         params = getattr(context.message, "params", {}) or {}

         # For info tools, try dynamic truncation before blocking
-        from superset.mcp_service.utils.token_utils import INFO_TOOLS
-
         if tool_name in INFO_TOOLS:
             truncated = self._try_truncate_info_response(
                 tool_name, response, estimated_tokens
```
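Taken together, these hunks only hoist the lazy imports to module level; the guard's control flow is unchanged. The sketch below condenses that flow for orientation. The helper and attribute signatures here are assumptions for illustration (only the names come from the imports and calls shown above), not the real method signatures:

```python
# Condensed, hypothetical sketch of the response-size-guard decision flow:
# estimate tokens, truncate info tools over the limit, block other
# oversized responses, and warn when near the threshold.
async def on_call_tool_sketch(self, context, call_next):
    response = await call_next(context)

    # Estimate on the unwrapped payload, not the ToolResult wrapper.
    estimated_tokens = estimate_response_tokens(response)
    tool_name = context.message.name

    if estimated_tokens > self.token_limit:  # attribute name assumed
        # Info tools get a second chance: truncate before blocking.
        if tool_name in INFO_TOOLS:
            truncated = self._try_truncate_info_response(
                tool_name, response, estimated_tokens
            )
            if truncated is not None:
                return truncated
        # format_size_limit_error arguments assumed for illustration.
        return format_size_limit_error(tool_name, estimated_tokens)

    if estimated_tokens > self.token_limit * self.warn_threshold_pct / 100:
        logger.warning(
            "Tool %s response near token limit: ~%d tokens",
            tool_name,
            estimated_tokens,
        )
    return response
```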
```diff
@@ -21,6 +21,26 @@ Token counting and response size utilities for MCP service.
 This module provides utilities to estimate token counts and generate smart
 suggestions when responses exceed configured limits. This prevents large
 responses from overwhelming LLM clients like Claude Desktop.

+Token counting strategy:
+
+1. ``tiktoken`` with the ``cl100k_base`` encoding when the package is
+   installed (it is shipped as part of the ``fastmcp`` extra). This is a
+   real BPE tokenizer trained on a similar vocabulary to Claude's; for
+   English and JSON-heavy MCP payloads it tracks Claude's tokenizer
+   within roughly ±10%, which is far more accurate than the legacy
+   character heuristic.
+2. A character-based fallback (``CHARS_PER_TOKEN``) when tiktoken is not
+   importable. The fallback uses a slightly more conservative ratio than
+   before (3.0 chars/token instead of 3.5) so that JSON-heavy responses
+   are not under-counted, which previously let oversized payloads slip
+   past the response-size guard.
+
+The exact-Claude tokenizer is only available via Anthropic's network
+``count_tokens`` API; calling it from a synchronous middleware on every
+tool result is too slow and adds an external dependency on every
+response. ``tiktoken`` is the closest approximation we can ship without
+that risk.
 """

 from __future__ import annotations
```
```diff
@@ -36,18 +56,63 @@ logger = logging.getLogger(__name__)
 # Type alias for MCP tool responses (Pydantic models, dicts, lists, strings, bytes)
 ToolResponse: TypeAlias = Union[BaseModel, Dict[str, Any], List[Any], str, bytes]

-# Approximate characters per token for estimation
-# Claude tokenizer averages ~4 chars per token for English text
-# JSON tends to be more verbose, so we use a slightly lower ratio
-CHARS_PER_TOKEN = 3.5
+# Fallback character-to-token ratio used when tiktoken is unavailable.
+# 3.0 is conservative for JSON content (the previous 3.5 under-counted
+# JSON-heavy payloads relative to Claude's actual tokenizer, which let
+# oversized responses slip past the response-size guard).
+CHARS_PER_TOKEN = 3.0
+
+# Encoding used when tiktoken is available. cl100k_base is OpenAI's
+# tokenizer for GPT-3.5/4; it is BPE-based with a vocabulary similar to
+# Claude's and tracks Claude's token counts within roughly ±10% for
+# English and JSON-heavy MCP responses.
+_TIKTOKEN_ENCODING_NAME = "cl100k_base"
+
+
+def _load_tiktoken_encoding() -> Any:
+    """Return a tiktoken encoding instance, or None if tiktoken is unavailable.
+
+    Imported lazily so the module can be used in environments without
+    tiktoken installed. The encoding is small (~1 MB) so we cache it on
+    first use.
+    """
+    try:
+        import tiktoken
+    except ImportError:
+        logger.info(
+            "tiktoken not installed; falling back to char-based token "
+            "estimation (CHARS_PER_TOKEN=%s). Install the 'fastmcp' extra "
+            "for accurate counts.",
+            CHARS_PER_TOKEN,
+        )
+        return None
+
+    try:
+        return tiktoken.get_encoding(_TIKTOKEN_ENCODING_NAME)
+    except (KeyError, ValueError) as exc:
+        # tiktoken installed but the requested encoding is missing — this
+        # only happens on partial installs. Treat as no tokenizer rather
+        # than crashing on every tool call.
+        logger.warning(
+            "tiktoken encoding '%s' unavailable: %s; falling back to "
+            "char-based token estimation",
+            _TIKTOKEN_ENCODING_NAME,
+            exc,
+        )
+        return None
+
+
+# Cached encoding instance (None if tiktoken not importable).
+_ENCODING = _load_tiktoken_encoding()


 def estimate_token_count(text: str | bytes) -> int:
     """
     Estimate the token count for a given text.

-    Uses a character-based heuristic since we don't have direct access to
-    the actual tokenizer. This is conservative to avoid underestimating.
+    Uses tiktoken's ``cl100k_base`` encoding when available for
+    Claude-aligned accuracy (within ~10%), falling back to a
+    character-based heuristic otherwise.

     Args:
         text: The text to estimate tokens for (string or bytes)
@@ -58,11 +123,19 @@ def estimate_token_count(text: str | bytes) -> int:
     if isinstance(text, bytes):
         text = text.decode("utf-8", errors="replace")

-    # Simple heuristic: ~3.5 characters per token for JSON/code
-    text_length = len(text)
-    if text_length == 0:
+    if not text:
         return 0
-    return max(1, int(text_length / CHARS_PER_TOKEN))
+
+    if _ENCODING is not None:
+        try:
+            return len(_ENCODING.encode(text))
+        except (ValueError, UnicodeError) as exc:
+            # Defensive: if tiktoken chokes on a specific input, fall
+            # back to the char heuristic for this call rather than
+            # raising — the response size guard must never fail-open.
+            logger.warning("tiktoken encode failed (%s); using fallback", exc)
+
+    return max(1, int(len(text) / CHARS_PER_TOKEN))


 def estimate_response_tokens(response: ToolResponse) -> int:
```
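With both paths in place, calling the estimator is straightforward. A short usage sketch of `estimate_token_count` as defined in the hunk above (printed values depend on whether tiktoken loaded):

```python
from superset.mcp_service.utils.token_utils import estimate_token_count

print(estimate_token_count("Hello world"))        # BPE count, or len/3.0 fallback
print(estimate_token_count(b'{"key": "value"}'))  # bytes are decoded first
print(estimate_token_count(""))                   # always 0 for empty input
```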
```diff
@@ -45,6 +45,13 @@
         color: #000;
       }
       {% endif %}
+      {% if standalone_mode %}
+      /* Keep body sized so screenshot waits don't see it as hidden before React mounts. */
+      html, body.standalone {
+        min-height: 100vh;
+        margin: 0;
+      }
+      {% endif %}
     </style>

     {% if dark_theme_bg and entry != 'embedded' %}
```
```diff
@@ -68,3 +68,50 @@ def test_spa_template_includes_css_bundles():
         "spa.html must call css_bundle for the page entry to load "
         "entry-specific extracted CSS in production builds"
     )
+
+
+def test_spa_template_standalone_body_has_min_height():
+    """Standalone body must be measurable so screenshot waits don't time out."""
+    from jinja2 import DictLoader, Environment
+
+    template_path = join(SUPERSET_DIR, "templates", "superset", "spa.html")
+    with open(template_path) as f:
+        template_content = f.read()
+
+    env = Environment(  # noqa: S701
+        loader=DictLoader(
+            {
+                "spa.html": template_content,
+                # Stub out includes/imports that are not relevant for this test.
+                "appbuilder/general/lib.html": "",
+                "superset/partials/asset_bundle.html": (
+                    "{% macro css_bundle(prefix, entry) %}{% endmacro %}"
+                    "{% macro js_bundle(prefix, entry) %}{% endmacro %}"
+                ),
+                "superset/macros.html": ("{% macro get_nonce() %}{% endmacro %}"),
+                "tail_js_custom_extra.html": "",
+                "head_custom_extra.html": "",
+            }
+        )
+    )
+
+    appbuilder = Mock()
+    appbuilder.app.config = {"FAVICONS": []}
+
+    def render(standalone_mode: bool) -> str:
+        return env.get_template("spa.html").render(
+            appbuilder=appbuilder,
+            assets_prefix="",
+            bootstrap_data="{}",
+            entry="spa",
+            standalone_mode=standalone_mode,
+            theme_tokens={},
+            spinner_svg=None,
+        )
+
+    standalone_html = render(standalone_mode=True)
+    assert "body.standalone" in standalone_html
+    assert "min-height: 100vh" in standalone_html
+
+    non_standalone_html = render(standalone_mode=False)
+    assert "body.standalone" not in non_standalone_html
```
```diff
@@ -595,3 +595,191 @@ Market Share
 """

 # These demonstrate the expected ASCII formats for different chart types
+
+
+class TestDetachedInstanceError:
+    """Tests that DetachedInstanceError is handled gracefully.
+
+    When the SQLAlchemy session commits mid-request, ORM objects expire and
+    become detached. Accessing lazy attributes on a detached Slice raises
+    DetachedInstanceError. The tool must:
+    1. Call db.session.refresh() immediately after loading the chart so all
+       column values are loaded upfront before any downstream operation.
+    2. Catch SQLAlchemyError (the base class) and return a ChartError
+       instead of propagating the exception.
+    """
+
+    @pytest.mark.asyncio
+    async def test_session_refresh_called_after_chart_load(self):
+        """db.session.refresh() is invoked right after find_chart_by_identifier."""
+        import importlib
+        from contextlib import nullcontext
+        from unittest.mock import MagicMock, patch
+
+        from superset.mcp_service.chart.schemas import URLPreview
+        from superset.utils import json
+
+        get_chart_preview_module = importlib.import_module(
+            "superset.mcp_service.chart.tool.get_chart_preview"
+        )
+
+        mock_chart = MagicMock()
+        mock_chart.id = 42
+        mock_chart.slice_name = "Sales Chart"
+        mock_chart.viz_type = "table"
+        mock_chart.datasource_id = 1
+        mock_chart.datasource_type = "table"
+        mock_chart.params = "{}"
+
+        refresh_calls: list[object] = []
+
+        def _fake_refresh(obj: object) -> None:
+            refresh_calls.append(obj)
+
+        url_preview = URLPreview(
+            preview_url="http://localhost/explore/?slice_id=42",
+            width=800,
+            height=600,
+        )
+
+        with (
+            patch.object(
+                get_chart_preview_module,
+                "find_chart_by_identifier",
+                return_value=mock_chart,
+            ),
+            patch.object(
+                get_chart_preview_module.db,
+                "session",
+                **{"refresh.side_effect": _fake_refresh},
+            ),
+            patch.object(
+                get_chart_preview_module,
+                "validate_chart_dataset",
+                return_value=MagicMock(is_valid=True, warnings=[]),
+            ),
+            patch.object(
+                get_chart_preview_module.event_logger,
+                "log_context",
+                return_value=nullcontext(),
+            ),
+            # Return a real URLPreview so Pydantic model validation succeeds
+            patch.object(
+                get_chart_preview_module.PreviewFormatGenerator,
+                "generate",
+                return_value=url_preview,
+            ),
+            patch(
+                "superset.mcp_service.utils.url_utils.get_superset_base_url",
+                return_value="http://localhost",
+            ),
+        ):
+            from fastmcp import Client
+
+            from superset.mcp_service.app import mcp
+            from superset.mcp_service.chart.schemas import GetChartPreviewRequest
+
+            with patch("superset.mcp_service.auth.get_user_from_request") as mu:
+                mu.return_value = MagicMock(id=1, username="admin")
+                with patch(
+                    "superset.mcp_service.auth.check_tool_permission", return_value=True
+                ):
+                    async with Client(mcp) as client:
+                        response = await client.call_tool(
+                            "get_chart_preview",
+                            {
+                                "request": GetChartPreviewRequest(
+                                    identifier=42, format="url"
+                                ).model_dump()
+                            },
+                        )
+
+        data = json.loads(response.content[0].text)
+        # The tool should succeed — not return a ChartError
+        assert "error_type" not in data, (
+            f"Expected ChartPreview but got ChartError: {data.get('error')}"
+        )
+        assert data.get("chart_id") == 42
+
+        assert len(refresh_calls) == 1, (
+            "db.session.refresh() should be called once after loading the chart"
+        )
+        assert refresh_calls[0] is mock_chart
+
+    @pytest.mark.asyncio
+    async def test_detached_instance_error_returns_chart_error(self):
+        """DetachedInstanceError during preview generation returns ChartError."""
+        import importlib
+        from contextlib import nullcontext
+        from unittest.mock import MagicMock, patch
+
+        from sqlalchemy.orm.exc import DetachedInstanceError
+
+        get_chart_preview_module = importlib.import_module(
+            "superset.mcp_service.chart.tool.get_chart_preview"
+        )
+
+        mock_chart = MagicMock()
+        mock_chart.id = 7
+        mock_chart.slice_name = "Broken Chart"
+        mock_chart.viz_type = "bar"
+        mock_chart.datasource_id = 3
+        mock_chart.datasource_type = "table"
+        mock_chart.params = "{}"
+
+        with (
+            patch.object(
+                get_chart_preview_module,
+                "find_chart_by_identifier",
+                return_value=mock_chart,
+            ),
+            patch.object(
+                get_chart_preview_module.db,
+                "session",
+                **{"refresh.return_value": None},
+            ),
+            patch.object(
+                get_chart_preview_module,
+                "validate_chart_dataset",
+                return_value=MagicMock(is_valid=True, warnings=[]),
+            ),
+            patch.object(
+                get_chart_preview_module.event_logger,
+                "log_context",
+                return_value=nullcontext(),
+            ),
+            # Simulate the session expiring inside the strategy
+            patch.object(
+                get_chart_preview_module.PreviewFormatGenerator,
+                "generate",
+                side_effect=DetachedInstanceError(),
+            ),
+            patch(
+                "superset.mcp_service.utils.url_utils.get_superset_base_url",
+                return_value="http://localhost",
+            ),
+        ):
+            from fastmcp import Client
+
+            from superset.mcp_service.app import mcp
+            from superset.mcp_service.chart.schemas import GetChartPreviewRequest
+            from superset.utils import json
+
+            with patch("superset.mcp_service.auth.get_user_from_request") as mu:
+                mu.return_value = MagicMock(id=1, username="admin")
+                with patch(
+                    "superset.mcp_service.auth.check_tool_permission", return_value=True
+                ):
+                    async with Client(mcp) as client:
+                        response = await client.call_tool(
+                            "get_chart_preview",
+                            {
+                                "request": GetChartPreviewRequest(
+                                    identifier=7, format="ascii"
+                                ).model_dump()
+                            },
+                        )
+
+        data = json.loads(response.content[0].text)
+        assert data["error_type"] == "InternalError"
+        assert "session" in data["error"].lower() or "retry" in data["error"].lower()
```
```diff
@@ -1349,7 +1349,8 @@ class TestDashboardSortableColumns:

         # Check list_dashboards docstring for sortable columns documentation
         assert list_dashboards.__doc__ is not None
-        assert "Sortable columns for order_column:" in list_dashboards.__doc__
+        assert "Sortable columns for" in list_dashboards.__doc__
+        assert "order_column" in list_dashboards.__doc__
         for col in SORTABLE_DASHBOARD_COLUMNS:
             assert col in list_dashboards.__doc__
```
```diff
@@ -1700,7 +1700,8 @@ class TestDatasetSortableColumns:

         # Check list_datasets docstring for sortable columns documentation
         assert list_datasets.__doc__ is not None
-        assert "Sortable columns for order_column:" in list_datasets.__doc__
+        assert "Sortable columns for" in list_datasets.__doc__
+        assert "order_column" in list_datasets.__doc__
         for col in SORTABLE_DATASET_COLUMNS:
             assert col in list_datasets.__doc__
```
```diff
@@ -2080,3 +2081,90 @@ class TestListDatasetsOwnedByMe:
         request = ListDatasetsRequest(owned_by_me=True, created_by_me=True)
         assert request.owned_by_me is True
         assert request.created_by_me is True
+
+
+class TestListDatasetsRequestWrapper:
+    """
+    Tests verifying that list_datasets requires a ``request`` wrapper object.
+
+    LLMs sometimes pass parameters like ``search``, ``page``, or ``page_size``
+    as flat top-level kwargs instead of nesting them inside a ``request``
+    object. These tests confirm the correct call shape through both the Pydantic
+    schema and the actual MCP tool layer, and verify that invalid filter column
+    names (e.g. ``created_by_fk``) are rejected.
+    """
+
+    def test_request_wrapper_with_search(self) -> None:
+        """Parameters passed inside request= are accepted by the schema."""
+        request = ListDatasetsRequest(search="sales", page=1, page_size=10)
+        assert request.search == "sales"
+        assert request.page == 1
+        assert request.page_size == 10
+
+    def test_request_wrapper_defaults(self) -> None:
+        """No-arg constructor produces valid schema defaults."""
+        request = ListDatasetsRequest()
+        assert request.search is None
+        assert request.page == 1
+        assert request.filters == []
+
+    def test_dataset_filter_valid_col(self) -> None:
+        """Valid col values are accepted by DatasetFilter."""
+        for col in ("table_name", "schema", "database_name"):
+            f = DatasetFilter(col=col, opr="sw", value="test")
+            assert f.col == col
+
+    def test_dataset_filter_invalid_col_raises(self) -> None:
+        """Column names not in the Literal are rejected with a validation error.
+
+        This guards against LLMs passing ``created_by_fk`` or similar
+        internal column names that are not exposed as filter fields.
+        """
+        from pydantic import ValidationError
+
+        for bad_col in ("created_by_fk", "id", "database_id", "owner"):
+            with pytest.raises(ValidationError):
+                DatasetFilter(col=bad_col, opr="eq", value="1")
+
+    @patch("superset.daos.dataset.DatasetDAO.list")
+    @pytest.mark.asyncio
+    async def test_request_wrapper_enforced_by_tool(
+        self, mock_list, mcp_server
+    ) -> None:
+        """The MCP tool layer accepts the request wrapper and returns results.
+
+        Verifies end-to-end that wrapping params in ``request={}`` works through
+        the actual FastMCP tool call, not just schema validation.
+        """
+        mock_list.return_value = ([], 0)
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "list_datasets",
+                {"request": {"search": "sales", "page": 1, "page_size": 5}},
+            )
+        data = json.loads(result.content[0].text)
+        assert data["count"] == 0
+        assert data["datasets"] == []
+
+    @pytest.mark.asyncio
+    async def test_flat_kwargs_rejected(self, mcp_server) -> None:
+        """Passing search/page/page_size as top-level kwargs raises a ToolError
+        that specifically mentions the unexpected arguments.
+
+        This is the exact failure pattern from story #105712: LLMs call
+        ``list_datasets(search=..., page=..., page_size=...)`` instead of
+        ``list_datasets(request={...})``.
+        """
+        with pytest.raises(ToolError) as exc_info:
+            async with Client(mcp_server) as client:
+                await client.call_tool(
+                    "list_datasets",
+                    {"search": "sales", "page": 1, "page_size": 10},
+                )
+        error_text = str(exc_info.value)
+        # The error must call out the unexpected arguments, not some unrelated failure
+        assert (
+            "search" in error_text
+            or "Unexpected" in error_text
+            or "request" in error_text
+        )
```
```diff
@@ -146,7 +146,13 @@ class TestResponseSizeGuardMiddleware:

     @pytest.mark.asyncio
     async def test_logs_warning_at_threshold(self) -> None:
-        """Should log warning when approaching limit."""
+        """Should log warning when approaching limit.
+
+        Mocks the token estimator to return a specific value above the
+        warn threshold but below the hard limit, decoupling the test
+        from whichever tokenizer (tiktoken or char heuristic) happens
+        to be loaded.
+        """
         middleware = ResponseSizeGuardMiddleware(
             token_limit=1000, warn_threshold_pct=80
         )
@@ -155,18 +161,21 @@ class TestResponseSizeGuardMiddleware:
         context.message.name = "list_charts"
         context.message.params = {}

-        # Response at ~85% of limit (should trigger warning but not block)
-        response = {"data": "x" * 2900}  # ~828 tokens at 3.5 chars/token
+        response = {"data": "approaching the limit"}
         call_next = AsyncMock(return_value=response)

         with (
             patch("superset.mcp_service.middleware.get_user_id", return_value=1),
             patch("superset.mcp_service.middleware.event_logger"),
+            patch(
+                "superset.mcp_service.middleware.estimate_response_tokens",
+                return_value=850,
+            ),
             patch("superset.mcp_service.middleware.logger") as mock_logger,
         ):
             result = await middleware.on_call_tool(context, call_next)

-        # Should return response (not blocked)
+        # Should return response (not blocked at 85% of limit)
         assert result == response
         # Should log warning
         mock_logger.warning.assert_called()
```
```diff
@@ -20,9 +20,11 @@ Unit tests for MCP service token utilities.
 """

 from typing import Any, List
+from unittest.mock import patch

 from pydantic import BaseModel

+from superset.mcp_service.utils import token_utils
 from superset.mcp_service.utils.token_utils import (
     _replace_collections_with_summaries,
     _summarize_large_dicts,
```
```diff
@@ -45,29 +47,65 @@ class TestEstimateTokenCount:
     """Test estimate_token_count function."""

     def test_estimate_string(self) -> None:
-        """Should estimate tokens for a string."""
+        """Should produce a positive non-zero estimate for a normal string.
+
+        We don't assert on a specific number because the result depends on
+        which tokenizer is loaded (tiktoken when available, char heuristic
+        otherwise).
+        """
         text = "Hello world"
         result = estimate_token_count(text)
-        expected = int(len(text) / CHARS_PER_TOKEN)
-        assert result == expected
+        assert result > 0

     def test_estimate_bytes(self) -> None:
-        """Should estimate tokens for bytes."""
-        text = b"Hello world"
-        result = estimate_token_count(text)
-        expected = int(len(text) / CHARS_PER_TOKEN)
-        assert result == expected
+        """Bytes input should be decoded and produce the same count as the
+        equivalent string."""
+        text = "Hello world"
+        assert estimate_token_count(text.encode("utf-8")) == estimate_token_count(text)

     def test_empty_string(self) -> None:
-        """Should return 0 for empty string."""
+        """Should return 0 for empty string and empty bytes."""
         assert estimate_token_count("") == 0
+        assert estimate_token_count(b"") == 0

     def test_json_like_content(self) -> None:
-        """Should estimate tokens for JSON-like content."""
+        """JSON content should produce a positive estimate."""
         json_str = '{"name": "test", "value": 123, "items": [1, 2, 3]}'
-        result = estimate_token_count(json_str)
-        assert result > 0
-        assert result == int(len(json_str) / CHARS_PER_TOKEN)
+        assert estimate_token_count(json_str) > 0

+    def test_long_text_roughly_scales_with_length(self) -> None:
+        """A doubled string should produce roughly double the token count
+        (within ±10%)."""
+        small = "the quick brown fox jumps over the lazy dog. " * 20
+        large = small * 2
+        small_n = estimate_token_count(small)
+        large_n = estimate_token_count(large)
+        # Within 10% of 2x — both tokenizers (tiktoken and the char
+        # fallback) preserve length monotonicity.
+        assert 1.8 * small_n <= large_n <= 2.2 * small_n
+
+    def test_fallback_uses_chars_per_token_when_tiktoken_unavailable(
+        self,
+    ) -> None:
+        """When the tiktoken encoding is None (not installed), the
+        function falls back to len/CHARS_PER_TOKEN math."""
+        text = "x" * 100
+        with patch.object(token_utils, "_ENCODING", None):
+            result = estimate_token_count(text)
+        assert result == int(100 / CHARS_PER_TOKEN)
+
+    def test_fallback_when_tiktoken_encode_raises(self) -> None:
+        """A misbehaving encoding should fall back to the char heuristic
+        rather than raise — the size guard must never fail-open."""
+
+        class BoomEncoding:
+            def encode(self, text: str) -> list[int]:
+                raise ValueError("simulated tiktoken failure")
+
+        text = "abc" * 50
+        with patch.object(token_utils, "_ENCODING", BoomEncoding()):
+            result = estimate_token_count(text)
+        assert result == int(len(text) / CHARS_PER_TOKEN)


 class TestEstimateResponseTokens:
```