feat(mcp): add BM25 tool search transform to reduce initial context size (#38562)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-13 18:06:11 +01:00
committed by GitHub
parent b6c3b3ef46
commit 97a66f7a64
11 changed files with 552 additions and 476 deletions

View File

@@ -144,7 +144,7 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
excel = ["xlrd>=1.2.0, <1.3"]
fastmcp = ["fastmcp==2.14.3"]
fastmcp = ["fastmcp>=3.1.0,<4.0"]
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
gevent = ["gevent>=23.9.1"]

View File

@@ -10,6 +10,8 @@
# via
# -r requirements/development.in
# apache-superset
aiofile==3.9.0
# via py-key-value-aio
alembic==1.15.2
# via
# -c requirements/base-constraint.txt
@@ -26,8 +28,10 @@ anyio==4.11.0
# via
# httpx
# mcp
# py-key-value-aio
# sse-starlette
# starlette
# watchfiles
apispec==6.6.1
# via
# -c requirements/base-constraint.txt
@@ -65,9 +69,7 @@ bcrypt==4.3.0
# -c requirements/base-constraint.txt
# paramiko
beartype==0.22.5
# via
# py-key-value-aio
# py-key-value-shared
# via py-key-value-aio
billiard==4.2.1
# via
# -c requirements/base-constraint.txt
@@ -100,6 +102,8 @@ cachetools==6.2.1
# -c requirements/base-constraint.txt
# google-auth
# py-key-value-aio
caio==0.9.25
# via aiofile
cattrs==25.1.1
# via
# -c requirements/base-constraint.txt
@@ -138,7 +142,6 @@ click==8.2.1
# click-repl
# flask
# flask-appbuilder
# typer
# uvicorn
click-didyoumean==0.3.1
# via
@@ -156,8 +159,6 @@ click-repl==0.3.0
# via
# -c requirements/base-constraint.txt
# celery
cloudpickle==3.1.2
# via pydocket
cmdstanpy==1.1.0
# via prophet
colorama==0.4.6
@@ -206,8 +207,6 @@ deprecation==2.1.0
# apache-superset
dill==0.4.0
# via pylint
diskcache==5.6.3
# via py-key-value-aio
distlib==0.3.8
# via virtualenv
dnspython==2.7.0
@@ -237,9 +236,7 @@ et-xmlfile==2.0.0
# openpyxl
exceptiongroup==1.3.0
# via fastmcp
fakeredis==2.32.1
# via pydocket
fastmcp==2.14.3
fastmcp==3.1.0
# via apache-superset
filelock==3.20.3
# via
@@ -474,6 +471,8 @@ jsonpath-ng==1.7.0
# via
# -c requirements/base-constraint.txt
# apache-superset
jsonref==1.1.0
# via fastmcp
jsonschema==4.23.0
# via
# -c requirements/base-constraint.txt
@@ -504,8 +503,6 @@ limits==5.1.0
# via
# -c requirements/base-constraint.txt
# flask-limiter
lupa==2.6
# via fakeredis
mako==1.3.10
# via
# -c requirements/base-constraint.txt
@@ -603,7 +600,7 @@ openpyxl==3.1.5
# -c requirements/base-constraint.txt
# pandas
opentelemetry-api==1.39.1
# via pydocket
# via fastmcp
ordered-set==4.1.0
# via
# -c requirements/base-constraint.txt
@@ -622,6 +619,7 @@ packaging==25.0
# deprecation
# docker
# duckdb-engine
# fastmcp
# google-cloud-bigquery
# gunicorn
# limits
@@ -653,8 +651,6 @@ parsedatetime==2.6
# apache-superset
pathable==0.4.3
# via jsonschema-path
pathvalidate==3.3.1
# via py-key-value-aio
pgsanity==0.2.9
# via
# -c requirements/base-constraint.txt
@@ -691,8 +687,6 @@ prison==0.2.1
# flask-appbuilder
progress==1.6
# via apache-superset
prometheus-client==0.23.1
# via pydocket
prompt-toolkit==3.0.51
# via
# -c requirements/base-constraint.txt
@@ -714,12 +708,8 @@ psutil==6.1.0
# via apache-superset
psycopg2-binary==2.9.9
# via apache-superset
py-key-value-aio==0.3.0
# via
# fastmcp
# pydocket
py-key-value-shared==0.3.0
# via py-key-value-aio
py-key-value-aio==0.4.4
# via fastmcp
pyarrow==16.1.0
# via
# -c requirements/base-constraint.txt
@@ -758,8 +748,6 @@ pydantic-settings==2.10.1
# via mcp
pydata-google-auth==1.9.0
# via pandas-gbq
pydocket==0.17.1
# via fastmcp
pydruid==0.6.9
# via apache-superset
pyfakefs==5.3.5
@@ -844,8 +832,6 @@ python-dotenv==1.1.0
# apache-superset
# fastmcp
# pydantic-settings
python-json-logger==4.0.0
# via pydocket
python-ldap==3.4.4
# via apache-superset
python-multipart==0.0.20
@@ -866,15 +852,13 @@ pyyaml==6.0.2
# -c requirements/base-constraint.txt
# apache-superset
# apispec
# fastmcp
# jsonschema-path
# pre-commit
redis==5.3.1
# via
# -c requirements/base-constraint.txt
# apache-superset
# fakeredis
# py-key-value-aio
# pydocket
referencing==0.36.2
# via
# -c requirements/base-constraint.txt
@@ -910,9 +894,7 @@ rich==13.9.4
# cyclopts
# fastmcp
# flask-limiter
# pydocket
# rich-rst
# typer
rich-rst==1.3.1
# via cyclopts
rpds-py==0.25.0
@@ -944,8 +926,6 @@ setuptools==80.9.0
# pydata-google-auth
# zope-event
# zope-interface
shellingham==1.5.4
# via typer
shillelagh==1.4.3
# via
# -c requirements/base-constraint.txt
@@ -973,7 +953,6 @@ sniffio==1.3.1
sortedcontainers==2.4.0
# via
# -c requirements/base-constraint.txt
# fakeredis
# trio
sqlalchemy==1.4.54
# via
@@ -1034,8 +1013,6 @@ trio-websocket==0.12.2
# via
# -c requirements/base-constraint.txt
# selenium
typer==0.20.0
# via pydocket
typing-extensions==4.15.0
# via
# -c requirements/base-constraint.txt
@@ -1048,16 +1025,14 @@ typing-extensions==4.15.0
# limits
# mcp
# opentelemetry-api
# py-key-value-shared
# py-key-value-aio
# pydantic
# pydantic-core
# pydocket
# pyopenssl
# referencing
# selenium
# shillelagh
# starlette
# typer
# typing-inspection
typing-inspection==0.4.1
# via
@@ -1072,6 +1047,8 @@ tzdata==2025.2
# pandas
tzlocal==5.2
# via trino
uncalled-for==0.2.0
# via fastmcp
url-normalize==2.2.1
# via
# -c requirements/base-constraint.txt
@@ -1101,6 +1078,8 @@ watchdog==6.0.0
# -c requirements/base-constraint.txt
# apache-superset
# apache-superset-extensions-cli
watchfiles==1.1.1
# via fastmcp
wcwidth==0.2.13
# via
# -c requirements/base-constraint.txt

View File

@@ -384,15 +384,12 @@ class ChartList(BaseModel):
class ColumnRef(BaseModel):
name: str = Field(
...,
description="Column name",
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
label: str | None = Field(
None, description="Display label for the column", max_length=500
)
dtype: str | None = Field(None, description="Data type hint")
label: str | None = Field(None, max_length=500)
dtype: str | None = None
aggregate: (
Literal[
"SUM",
@@ -407,11 +404,7 @@ class ColumnRef(BaseModel):
"PERCENTILE",
]
| None
) = Field(
None,
description="SQL aggregation function. Only these validated functions are "
"supported to prevent SQL errors.",
)
) = Field(None, description="SQL aggregate function")
@field_validator("name")
@classmethod
@@ -431,25 +424,22 @@ class ColumnRef(BaseModel):
class AxisConfig(BaseModel):
title: str | None = Field(None, description="Axis title", max_length=200)
scale: Literal["linear", "log"] | None = Field(
"linear", description="Axis scale type"
)
format: str | None = Field(
None, description="Format string (e.g. '$,.2f')", max_length=50
)
title: str | None = Field(None, max_length=200)
scale: Literal["linear", "log"] | None = "linear"
format: str | None = Field(None, description="e.g. '$,.2f'", max_length=50)
class LegendConfig(BaseModel):
show: bool = Field(True, description="Whether to show legend")
position: Literal["top", "bottom", "left", "right"] | None = Field(
"right", description="Legend position"
)
show: bool = True
position: Literal["top", "bottom", "left", "right"] | None = "right"
class FilterConfig(BaseModel):
column: str = Field(
..., description="Column to filter on", min_length=1, max_length=255
...,
min_length=1,
max_length=255,
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
)
op: Literal[
"=",
@@ -465,17 +455,11 @@ class FilterConfig(BaseModel):
"NOT IN",
] = Field(
...,
description=(
"Filter operator. Use LIKE/ILIKE for pattern matching with % wildcards "
"(e.g., '%mario%'). Use IN/NOT IN with a list of values."
),
description="LIKE/ILIKE use % wildcards. IN/NOT IN take a list.",
)
value: str | int | float | bool | list[str | int | float | bool] = Field(
...,
description=(
"Filter value. For IN/NOT IN operators, provide a list of values. "
"For LIKE/ILIKE, use % as wildcard (e.g., '%mario%')."
),
description="For IN/NOT IN, provide a list.",
)
@field_validator("column")
@@ -516,26 +500,13 @@ class FilterConfig(BaseModel):
class PieChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["pie"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'pie' for pie/donut charts. "
"This field is REQUIRED and tells Superset which chart "
"configuration schema to use."
),
)
dimension: ColumnRef = Field(
..., description="Category column that defines the pie slices"
)
chart_type: Literal["pie"] = "pie"
dimension: ColumnRef = Field(..., description="Category column for slices")
metric: ColumnRef = Field(
...,
description=(
"Value metric that determines slice sizes. "
"Must include an aggregate function (e.g., SUM, COUNT)."
),
..., description="Value metric (needs aggregate e.g. SUM, COUNT)"
)
donut: bool = Field(False, description="Render as a donut chart with a center hole")
show_labels: bool = Field(True, description="Display labels on slices")
donut: bool = False
show_labels: bool = True
label_type: Literal[
"key",
"value",
@@ -544,63 +515,32 @@ class PieChartConfig(BaseModel):
"key_percent",
"key_value_percent",
"value_percent",
] = Field("key_value_percent", description="Type of labels to show on slices")
sort_by_metric: bool = Field(True, description="Sort slices by metric value")
show_legend: bool = Field(True, description="Whether to show legend")
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
row_limit: int = Field(
100,
description="Maximum number of slices to display",
ge=1,
le=10000,
)
number_format: str = Field(
"SMART_NUMBER",
description="Number format string",
max_length=50,
)
show_total: bool = Field(False, description="Display aggregate count in center")
labels_outside: bool = Field(True, description="Place labels outside the pie")
outer_radius: int = Field(
70,
description="Outer edge radius as a percentage (1-100)",
ge=1,
le=100,
)
] = "key_value_percent"
sort_by_metric: bool = True
show_legend: bool = True
filters: List[FilterConfig] | None = None
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
number_format: str = Field("SMART_NUMBER", max_length=50)
show_total: bool = Field(False, description="Show total in center")
labels_outside: bool = True
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
inner_radius: int = Field(
30,
description="Inner radius as a percentage for donut (1-100)",
ge=1,
le=100,
30, description="Donut inner radius % (1-100)", ge=1, le=100
)
class PivotTableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["pivot_table"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'pivot_table' for interactive "
"pivot tables. This field is REQUIRED."
),
)
rows: List[ColumnRef] = Field(
...,
min_length=1,
description="Row grouping columns (at least one required)",
)
chart_type: Literal["pivot_table"] = "pivot_table"
rows: List[ColumnRef] = Field(..., min_length=1, description="Row grouping columns")
columns: List[ColumnRef] | None = Field(
None,
description="Column grouping columns (optional, for cross-tabulation)",
None, description="Column groups for cross-tabulation"
)
metrics: List[ColumnRef] = Field(
...,
min_length=1,
description=(
"Metrics to aggregate. Each must have an aggregate function "
"(e.g., SUM, COUNT, AVG)."
),
description="Metrics (need aggregate e.g. SUM, COUNT, AVG)",
)
aggregate_function: Literal[
"Sum",
@@ -614,108 +554,56 @@ class PivotTableChartConfig(BaseModel):
"Count Unique Values",
"First",
"Last",
] = Field("Sum", description="Default aggregation function for the pivot table")
show_row_totals: bool = Field(True, description="Show row totals")
show_column_totals: bool = Field(True, description="Show column totals")
transpose: bool = Field(False, description="Swap rows and columns")
combine_metric: bool = Field(
False,
description="Display metrics side by side within columns",
)
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
row_limit: int = Field(
10000,
description="Maximum number of cells",
ge=1,
le=50000,
)
value_format: str = Field(
"SMART_NUMBER",
description="Value format string",
max_length=50,
)
] = "Sum"
show_row_totals: bool = True
show_column_totals: bool = True
transpose: bool = False
combine_metric: bool = Field(False, description="Metrics side by side in columns")
filters: List[FilterConfig] | None = None
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
value_format: str = Field("SMART_NUMBER", max_length=50)
class MixedTimeseriesChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["mixed_timeseries"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'mixed_timeseries' for charts "
"that combine two different series types (e.g., line + bar). "
"This field is REQUIRED."
),
)
x: ColumnRef = Field(..., description="X-axis temporal column (shared)")
time_grain: TimeGrain | None = Field(
None,
description=(
"Time granularity for the x-axis. "
"Common values: PT1H (hourly), P1D (daily), P1W (weekly), "
"P1M (monthly), P1Y (yearly)."
),
)
chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
x: ColumnRef = Field(..., description="Shared temporal X-axis column")
time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y")
# Primary series (Query A)
y: List[ColumnRef] = Field(
...,
min_length=1,
description="Primary Y-axis metrics (Query A)",
)
primary_kind: Literal["line", "bar", "area", "scatter"] = Field(
"line", description="Primary series chart type"
)
group_by: ColumnRef | None = Field(
None, description="Group by column for primary series"
)
y: List[ColumnRef] = Field(..., min_length=1, description="Primary Y-axis metrics")
primary_kind: Literal["line", "bar", "area", "scatter"] = "line"
group_by: ColumnRef | None = Field(None, description="Primary series group by")
# Secondary series (Query B)
y_secondary: List[ColumnRef] = Field(
...,
min_length=1,
description="Secondary Y-axis metrics (Query B)",
)
secondary_kind: Literal["line", "bar", "area", "scatter"] = Field(
"bar", description="Secondary series chart type"
..., min_length=1, description="Secondary Y-axis metrics"
)
secondary_kind: Literal["line", "bar", "area", "scatter"] = "bar"
group_by_secondary: ColumnRef | None = Field(
None, description="Group by column for secondary series"
None, description="Secondary series group by"
)
# Display options
show_legend: bool = Field(True, description="Whether to show legend")
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
y_axis: AxisConfig | None = Field(None, description="Primary Y-axis configuration")
y_axis_secondary: AxisConfig | None = Field(
None, description="Secondary Y-axis configuration"
)
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
show_legend: bool = True
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
y_axis_secondary: AxisConfig | None = None
filters: List[FilterConfig] | None = None
class TableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["table"] = Field(
..., description="Chart type (REQUIRED: must be 'table')"
)
chart_type: Literal["table"] = "table"
viz_type: Literal["table", "ag-grid-table"] = Field(
"table",
description=(
"Visualization type: 'table' for standard table, 'ag-grid-table' for "
"AG Grid Interactive Table with advanced features like column resizing, "
"sorting, filtering, and server-side pagination"
),
"table", description="'ag-grid-table' for interactive features"
)
columns: List[ColumnRef] = Field(
...,
min_length=1,
description=(
"Columns to display. Must have at least one column. Each column must have "
"a unique label "
"(either explicitly set via 'label' field or auto-generated "
"from name/aggregate)"
),
description="Columns with unique labels",
)
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
sort_by: List[str] | None = Field(None, description="Columns to sort by")
filters: List[FilterConfig] | None = None
sort_by: List[str] | None = None
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "TableChartConfig":
@@ -748,56 +636,26 @@ class TableChartConfig(BaseModel):
class XYChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["xy"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'xy' for XY charts "
"(line, bar, area, scatter). "
"This field is REQUIRED and tells Superset which chart "
"configuration schema to use."
),
)
chart_type: Literal["xy"] = "xy"
x: ColumnRef = Field(..., description="X-axis column")
y: List[ColumnRef] = Field(
...,
min_length=1,
description="Y-axis columns (metrics). Must have at least one Y-axis column. "
"Each column must have a unique label "
"that doesn't conflict with x-axis or group_by labels",
)
kind: Literal["line", "bar", "area", "scatter"] = Field(
"line", description="Chart visualization type"
..., min_length=1, description="Y-axis metrics (unique labels)"
)
kind: Literal["line", "bar", "area", "scatter"] = "line"
time_grain: TimeGrain | None = Field(
None,
description=(
"Time granularity for the x-axis when it's a temporal column. "
"Common values: PT1S (second), PT1M (minute), PT1H (hour), "
"P1D (day), P1W (week), P1M (month), P3M (quarter), P1Y (year). "
"If not specified, Superset will use its default behavior."
),
None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y"
)
orientation: Literal["vertical", "horizontal"] | None = Field(
None,
description=(
"Bar chart orientation. Only applies when kind='bar'. "
"'vertical' (default): bars extend upward. "
"'horizontal': bars extend rightward, useful for long category names."
),
)
stacked: bool = Field(
False,
description="Stack bars/areas on top of each other instead of side-by-side",
None, description="Bar orientation (only for kind='bar')"
)
stacked: bool = False
group_by: ColumnRef | None = Field(
None,
description="Column to group by (creates series/breakdown). "
"Use this field for series grouping — do NOT use 'series'.",
None, description="Series breakdown column (not 'series')"
)
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
y_axis: AxisConfig | None = Field(None, description="Y-axis configuration")
legend: LegendConfig | None = Field(None, description="Legend configuration")
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
x_axis: AxisConfig | None = None
y_axis: AxisConfig | None = None
legend: LegendConfig | None = None
filters: List[FilterConfig] | None = None
@model_validator(mode="after")
def validate_unique_column_labels(self) -> "XYChartConfig":
@@ -949,21 +807,12 @@ class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
chart_name: str | None = Field(
None,
description="Custom chart name (optional, auto-generates if not provided)",
max_length=255,
)
save_chart: bool = Field(
default=False,
description="Whether to permanently save the chart in Superset",
)
generate_preview: bool = Field(
default=True,
description="Whether to generate a preview image",
None, description="Auto-generates if omitted", max_length=255
)
save_chart: bool = Field(default=False, description="Save permanently in Superset")
generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
default_factory=lambda: ["url"],
description="List of preview formats to generate",
)
@field_validator("chart_name")
@@ -1002,20 +851,14 @@ class GenerateExploreLinkRequest(FormDataCacheControl):
class UpdateChartRequest(QueryCacheControl):
identifier: int | str = Field(..., description="Chart identifier (ID, UUID)")
config: ChartConfig = Field(..., description="New chart configuration")
identifier: int | str = Field(..., description="Chart ID or UUID")
config: ChartConfig
chart_name: str | None = Field(
None,
description="New chart name (optional, will auto-generate if not provided)",
max_length=255,
)
generate_preview: bool = Field(
default=True,
description="Whether to generate a preview after updating",
None, description="Auto-generates if omitted", max_length=255
)
generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
default_factory=lambda: ["url"],
description="List of preview formats to generate",
)
@field_validator("chart_name")
@@ -1027,15 +870,11 @@ class UpdateChartRequest(QueryCacheControl):
class UpdateChartPreviewRequest(FormDataCacheControl):
form_data_key: str = Field(..., description="Existing form_data_key to update")
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="New chart configuration")
generate_preview: bool = Field(
default=True,
description="Whether to generate a preview after updating",
)
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
config: ChartConfig
generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
default_factory=lambda: ["url"],
description="List of preview formats to generate",
)

View File

@@ -37,22 +37,10 @@ class CacheControlMixin(BaseModel):
- Dashboard Cache: Caches rendered dashboard components
"""
use_cache: bool = Field(
default=True,
description=(
"Whether to use Superset's cache layers. When True, will serve from "
"cache if available (query results, metadata, form data). When False, "
"will bypass cache and fetch fresh data."
),
)
use_cache: bool = Field(default=True, description="Use cache if available")
force_refresh: bool = Field(
default=False,
description=(
"Whether to force refresh cached data. When True, will invalidate "
"existing cache entries and fetch fresh data, then update the cache. "
"Overrides use_cache=True if both are specified."
),
default=False, description="Invalidate cache and fetch fresh data"
)
@@ -65,12 +53,7 @@ class QueryCacheControl(CacheControlMixin):
"""
cache_timeout: int | None = Field(
default=None,
description=(
"Override the default cache timeout in seconds for this query. "
"If not specified, uses dataset-level or global cache settings. "
"Set to 0 to disable caching for this specific query."
),
default=None, description="Cache timeout override in seconds (0 to disable)"
)
@@ -83,11 +66,7 @@ class MetadataCacheControl(CacheControlMixin):
"""
refresh_metadata: bool = Field(
default=False,
description=(
"Whether to refresh metadata cache for datasets, tables, and columns. "
"Useful when database schema has changed and you need fresh metadata."
),
default=False, description="Refresh metadata cache for schema changes"
)
@@ -100,11 +79,7 @@ class FormDataCacheControl(CacheControlMixin):
"""
cache_form_data: bool = Field(
default=True,
description=(
"Whether to cache the form data configuration for future use. "
"When False, generates temporary configurations that are not cached."
),
default=True, description="Cache form data for future use"
)

View File

@@ -287,45 +287,31 @@ class GetDashboardInfoRequest(MetadataCacheControl):
class DashboardInfo(BaseModel):
id: int | None = Field(None, description="Dashboard ID")
dashboard_title: str | None = Field(None, description="Dashboard title")
slug: str | None = Field(None, description="Dashboard slug")
description: str | None = Field(None, description="Dashboard description")
css: str | None = Field(None, description="Custom CSS for the dashboard")
certified_by: str | None = Field(None, description="Who certified the dashboard")
certification_details: str | None = Field(None, description="Certification details")
json_metadata: str | None = Field(
None, description="Dashboard metadata (JSON string)"
)
position_json: str | None = Field(None, description="Chart positions (JSON string)")
published: bool | None = Field(
None, description="Whether the dashboard is published"
)
is_managed_externally: bool | None = Field(
None, description="Whether managed externally"
)
external_url: str | None = Field(None, description="External URL")
created_on: str | datetime | None = Field(None, description="Creation timestamp")
changed_on: str | datetime | None = Field(
None, description="Last modification timestamp"
)
created_by: str | None = Field(None, description="Dashboard creator (username)")
changed_by: str | None = Field(None, description="Last modifier (username)")
uuid: str | None = Field(None, description="Dashboard UUID (converted to string)")
url: str | None = Field(None, description="Dashboard URL")
created_on_humanized: str | None = Field(
None, description="Humanized creation time"
)
changed_on_humanized: str | None = Field(
None, description="Humanized modification time"
)
chart_count: int = Field(0, description="Number of charts in the dashboard")
owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners")
tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags")
roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard roles")
charts: List[ChartInfo] = Field(
default_factory=list, description="Dashboard charts"
)
id: int | None = None
dashboard_title: str | None = None
slug: str | None = None
description: str | None = None
css: str | None = None
certified_by: str | None = None
certification_details: str | None = None
json_metadata: str | None = None
position_json: str | None = None
published: bool | None = None
is_managed_externally: bool | None = None
external_url: str | None = None
created_on: str | datetime | None = None
changed_on: str | datetime | None = None
created_by: str | None = None
changed_by: str | None = None
uuid: str | None = None
url: str | None = None
created_on_humanized: str | None = None
changed_on_humanized: str | None = None
chart_count: int = 0
owners: List[UserInfo] = Field(default_factory=list)
tags: List[TagInfo] = Field(default_factory=list)
roles: List[RoleInfo] = Field(default_factory=list)
charts: List[ChartInfo] = Field(default_factory=list)
# Fields for permalink/filter state support
permalink_key: str | None = Field(

View File

@@ -25,7 +25,6 @@ Following the Stack Overflow recommendation:
"""
import logging
import os
from flask import current_app, Flask, has_app_context
@@ -52,62 +51,45 @@ try:
logger.info("Reusing existing Flask app from app context for MCP service")
# Use _get_current_object() to get the actual Flask app, not the LocalProxy
app = current_app._get_current_object()
elif appbuilder_initialized:
# appbuilder is initialized but we have no app context. Calling
# create_app() here would invoke appbuilder.init_app() a second
# time with a *different* Flask app, overwriting shared internal
# state (views, security manager, etc.). Fail loudly instead of
# silently corrupting the singleton.
raise RuntimeError(
"appbuilder is already initialized but no Flask app context is "
"available. Cannot call create_app() as it would re-initialize "
"appbuilder with a different Flask app instance."
)
else:
# Either appbuilder is not initialized (standalone MCP server),
# or appbuilder is initialized but we're not in an app context
# (edge case - should rarely happen). In both cases, create a minimal app.
# Standalone MCP server — Superset models are deeply coupled to
# appbuilder, security_manager, event_logger, encrypted_field_factory,
# etc. so we use create_app() for full initialization rather than
# trying to init a minimal subset (which leads to cascading failures).
#
# We avoid calling create_app() which would run full FAB initialization
# and could corrupt the shared appbuilder singleton if main app starts.
from superset.app import SupersetApp
# create_app() is safe here because in standalone mode the main
# Superset web server is not running in-process.
from superset.app import create_app
from superset.mcp_service.mcp_config import get_mcp_config
if appbuilder_initialized:
logger.warning(
"Appbuilder initialized but not in app context - "
"creating separate MCP Flask app"
)
else:
logger.info("Creating minimal Flask app for standalone MCP service")
# Disable debug mode to avoid side-effects like file watchers
_mcp_app = SupersetApp(__name__)
logger.info("Creating fully initialized Flask app for standalone MCP service")
_mcp_app = create_app()
_mcp_app.debug = False
# Load configuration
config_module = os.environ.get("SUPERSET_CONFIG", "superset.config")
_mcp_app.config.from_object(config_module)
# Apply MCP-specific configuration
# Apply MCP-specific configuration on top
mcp_config = get_mcp_config(_mcp_app.config)
_mcp_app.config.update(mcp_config)
# Initialize only the minimal dependencies needed for MCP service
with _mcp_app.app_context():
try:
from superset.extensions import db
from superset.core.mcp.core_mcp_injection import (
initialize_core_mcp_dependencies,
)
db.init_app(_mcp_app)
# Initialize only MCP-specific dependencies
# MCP tools import directly from superset.daos/models, so we only need
# the MCP decorator injection, not the full superset_core abstraction
from superset.core.mcp.core_mcp_injection import (
initialize_core_mcp_dependencies,
)
initialize_core_mcp_dependencies()
logger.info(
"Minimal MCP dependencies initialized for standalone MCP service"
)
except Exception as e:
logger.warning(
"Failed to initialize dependencies for MCP service: %s", e
)
initialize_core_mcp_dependencies()
app = _mcp_app
logger.info("Minimal Flask app instance created successfully for MCP service")
logger.info("Flask app fully initialized for standalone MCP service")
except Exception as e:
logger.error("Failed to create Flask app: %s", e)

View File

@@ -227,10 +227,44 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
"get_chart_preview", # Returns URLs, not data
"generate_explore_link", # Returns URLs
"open_sql_lab_with_context", # Returns URLs
"search_tools", # Returns tool schemas for discovery (intentionally large)
],
}
# =============================================================================
# MCP Tool Search Transform Configuration
# =============================================================================
#
# Overview:
# ---------
# When enabled, replaces the full tool catalog with a search interface.
# LLMs see only 2 synthetic tools (search_tools + call_tool) plus any
# pinned tools, and discover other tools on-demand via natural language search.
# This reduces initial context by ~80% (from ~40k tokens to ~5-8k tokens).
#
# Strategies:
# -----------
# - "bm25": Natural language search using BM25 ranking (recommended)
# - "regex": Pattern-based search using regular expressions
#
# Rollback:
# ---------
# Set enabled=False in superset_config.py for instant rollback.
# =============================================================================
MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
"enabled": True, # Enabled by default — reduces initial context by ~70%
"strategy": "bm25", # "bm25" (natural language) or "regex" (pattern matching)
"max_results": 5, # Max tools returned per search
"always_visible": [ # Tools always shown in list_tools (pinned)
"health_check",
"get_instance_info",
],
"search_tool_name": "search_tools", # Name of the search tool
"call_tool_name": "call_tool", # Name of the call proxy tool
}
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
"""Default MCP auth factory using app.config values."""
if not app.config.get("MCP_AUTH_ENABLED", False):

View File

@@ -24,12 +24,17 @@ For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL.
import logging
import os
from collections.abc import Sequence
from typing import Any
import uvicorn
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG
from superset.mcp_service.mcp_config import (
get_mcp_factory_config,
MCP_STORE_CONFIG,
MCP_TOOL_SEARCH_CONFIG,
)
from superset.mcp_service.middleware import (
create_response_size_guard_middleware,
GlobalErrorHandlerMiddleware,
@@ -111,8 +116,7 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
if config is None:
config = MCP_STORE_CONFIG
redis_url = config.get("CACHE_REDIS_URL")
if not redis_url:
if not config.get("CACHE_REDIS_URL"):
logging.info("EventStore: Using in-memory storage (single-pod mode)")
return None
@@ -151,6 +155,117 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
return None
def _strip_titles(obj: Any, in_properties_map: bool = False) -> Any:
"""Recursively strip schema metadata ``title`` keys.
Keeps real field names inside ``properties`` (e.g. a property literally
named ``title``), while removing auto-generated schema title metadata.
"""
if isinstance(obj, dict):
result: dict[str, Any] = {}
for key, value in obj.items():
if key == "title" and not in_properties_map:
continue
result[key] = _strip_titles(value, in_properties_map=(key == "properties"))
return result
if isinstance(obj, list):
return [_strip_titles(item, in_properties_map=False) for item in obj]
return obj
def _serialize_tools_without_output_schema(
tools: Sequence[Any],
) -> list[dict[str, Any]]:
"""Serialize tools to JSON, stripping outputSchema and titles to reduce tokens.
LLMs only need inputSchema to call tools. outputSchema accounts for
50-80% of the per-tool schema size, and auto-generated 'title' fields
add ~12% bloat. Stripping both cuts search result tokens significantly.
"""
results = []
for tool in tools:
data = tool.to_mcp_tool().model_dump(mode="json", exclude_none=True)
data.pop("outputSchema", None)
if input_schema := data.get("inputSchema"):
data["inputSchema"] = _strip_titles(input_schema)
results.append(data)
return results
def _fix_call_tool_arguments(tool: Any) -> Any:
"""Fix anyOf schema in call_tool ``arguments`` for MCP bridge compatibility.
FastMCP's BaseSearchTransform defines ``arguments`` as
``dict[str, Any] | None`` which emits an ``anyOf`` JSON Schema.
Some MCP bridges (mcp-remote, Claude Desktop) don't handle ``anyOf``
and strip it, leaving the field without a ``type`` — causing all
call_tool invocations to fail with "Input should be a valid dictionary".
Replaces the ``anyOf`` with a flat ``type: object``.
"""
if "arguments" in (props := (tool.parameters or {}).get("properties", {})):
props["arguments"] = {
"additionalProperties": True,
"default": None,
"description": "Arguments to pass to the tool",
"type": "object",
}
return tool
def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> None:
    """Install a tool search transform to shrink the initial context.

    Replaces the full tool catalog with a search interface: clients see
    only the synthetic search/call tools plus pinned tools, and discover
    the rest on demand. Any strategy other than "regex" falls back to the
    BM25 (natural-language) transform.

    Uses subclassing (not monkey-patching) to override ``_make_call_tool``
    and repair the ``arguments`` schema for MCP bridge compatibility.

    NOTE: ``_make_call_tool`` is a private API in FastMCP 3.x
    (fastmcp>=3.1.0,<4.0). If FastMCP changes or removes this method in a
    future major version, this subclass will need to be updated.
    """
    strategy = config.get("strategy", "bm25")
    options: dict[str, Any] = {
        "max_results": config.get("max_results", 5),
        "always_visible": config.get("always_visible", []),
        "search_tool_name": config.get("search_tool_name", "search_tools"),
        "call_tool_name": config.get("call_tool_name", "call_tool"),
        "search_result_serializer": _serialize_tools_without_output_schema,
    }

    # Import lazily so fastmcp is only required when the feature is on.
    if strategy == "regex":
        from fastmcp.server.transforms.search import RegexSearchTransform as _Base
    else:
        from fastmcp.server.transforms.search import BM25SearchTransform as _Base

    class _FixedSearchTransform(_Base):
        """Search transform with fixed call_tool arguments schema."""

        def _make_call_tool(self) -> Any:
            return _fix_call_tool_arguments(super()._make_call_tool())

    mcp_instance.add_transform(_FixedSearchTransform(**options))
    logger.info(
        "Tool search transform enabled (strategy=%s, max_results=%d, pinned=%s)",
        strategy,
        options["max_results"],
        options["always_visible"],
    )
def _create_auth_provider(flask_app: Any) -> Any | None:
"""Create an auth provider from Flask app config.
@@ -218,6 +333,11 @@ def run_server(
logging.info("Creating MCP app from factory configuration...")
factory_config = get_mcp_factory_config()
mcp_instance = create_mcp_app(**factory_config)
# Apply tool search transform if configured
tool_search_config = MCP_TOOL_SEARCH_CONFIG
if tool_search_config.get("enabled", False):
_apply_tool_search_transform(mcp_instance, tool_search_config)
else:
# Use default initialization with auth from Flask config
logging.info("Creating MCP app with default configuration...")
@@ -233,8 +353,7 @@ def run_server(
middleware_list = []
# Add caching middleware (innermost runs closest to the tool)
caching_middleware = create_response_caching_middleware()
if caching_middleware:
if caching_middleware := create_response_caching_middleware():
middleware_list.append(caching_middleware)
# Add response size guard (protects LLM clients from huge responses)
@@ -252,6 +371,18 @@ def run_server(
middleware=middleware_list or None,
)
# Apply tool search transform if configured
tool_search_config = flask_app.config.get(
"MCP_TOOL_SEARCH_CONFIG", MCP_TOOL_SEARCH_CONFIG
)
if tool_search_config.get("enabled", False):
_apply_tool_search_transform(mcp_instance, tool_search_config)
# Ensure the configured search tool name is excluded from the
# response size guard (search results are intentionally large)
if size_guard_middleware:
search_name = tool_search_config.get("search_tool_name", "search_tools")
size_guard_middleware.excluded_tools.add(search_name)
# Create EventStore for session management (Redis for multi-pod, None for in-memory)
event_store = create_event_store(event_store_config)

View File

@@ -58,54 +58,40 @@ class GetSupersetInstanceInfoRequest(BaseModel):
class InstanceSummary(BaseModel):
total_dashboards: int = Field(..., description="Total number of dashboards")
total_charts: int = Field(..., description="Total number of charts")
total_datasets: int = Field(..., description="Total number of datasets")
total_databases: int = Field(..., description="Total number of databases")
total_users: int = Field(..., description="Total number of users")
total_roles: int = Field(..., description="Total number of roles")
total_tags: int = Field(..., description="Total number of tags")
avg_charts_per_dashboard: float = Field(
..., description="Average number of charts per dashboard"
)
total_dashboards: int
total_charts: int
total_datasets: int
total_databases: int
total_users: int
total_roles: int
total_tags: int
avg_charts_per_dashboard: float
class RecentActivity(BaseModel):
dashboards_created_last_30_days: int = Field(
..., description="Dashboards created in the last 30 days"
)
charts_created_last_30_days: int = Field(
..., description="Charts created in the last 30 days"
)
datasets_created_last_30_days: int = Field(
..., description="Datasets created in the last 30 days"
)
dashboards_modified_last_7_days: int = Field(
..., description="Dashboards modified in the last 7 days"
)
charts_modified_last_7_days: int = Field(
..., description="Charts modified in the last 7 days"
)
datasets_modified_last_7_days: int = Field(
..., description="Datasets modified in the last 7 days"
)
dashboards_created_last_30_days: int
charts_created_last_30_days: int
datasets_created_last_30_days: int
dashboards_modified_last_7_days: int
charts_modified_last_7_days: int
datasets_modified_last_7_days: int
class DashboardBreakdown(BaseModel):
published: int = Field(..., description="Number of published dashboards")
unpublished: int = Field(..., description="Number of unpublished dashboards")
certified: int = Field(..., description="Number of certified dashboards")
with_charts: int = Field(..., description="Number of dashboards with charts")
without_charts: int = Field(..., description="Number of dashboards without charts")
published: int
unpublished: int
certified: int
with_charts: int
without_charts: int
class DatabaseBreakdown(BaseModel):
by_type: Dict[str, int] = Field(..., description="Breakdown of databases by type")
by_type: Dict[str, int]
class PopularContent(BaseModel):
top_tags: List[str] = Field(..., description="Most popular tags")
top_creators: List[str] = Field(..., description="Most active creators")
top_tags: List[str] = Field(default_factory=list)
top_creators: List[str] = Field(default_factory=list)
class FeatureAvailability(BaseModel):
@@ -125,33 +111,19 @@ class FeatureAvailability(BaseModel):
class InstanceInfo(BaseModel):
instance_summary: InstanceSummary = Field(
..., description="Instance summary information"
)
recent_activity: RecentActivity = Field(
..., description="Recent activity information"
)
dashboard_breakdown: DashboardBreakdown = Field(
..., description="Dashboard breakdown information"
)
database_breakdown: DatabaseBreakdown = Field(
..., description="Database breakdown by type"
)
popular_content: PopularContent = Field(
..., description="Popular content information"
)
instance_summary: InstanceSummary
recent_activity: RecentActivity
dashboard_breakdown: DashboardBreakdown
database_breakdown: DatabaseBreakdown
popular_content: PopularContent
current_user: UserInfo | None = Field(
None,
description="The authenticated user making the request. "
"Use current_user.id with created_by_fk filter to find your own assets.",
)
feature_availability: FeatureAvailability = Field(
...,
description=(
"Dynamic feature availability for the current user and deployment"
"Use current_user.id with created_by_fk filter to find your own assets."
),
)
timestamp: datetime = Field(..., description="Response timestamp")
feature_availability: FeatureAvailability
timestamp: datetime
class UserInfo(BaseModel):
@@ -207,10 +179,10 @@ class RoleInfo(BaseModel):
class PaginationInfo(BaseModel):
page: int = Field(..., description="Current page number")
page_size: int = Field(..., description="Number of items per page")
total_count: int = Field(..., description="Total number of items")
total_pages: int = Field(..., description="Total number of pages")
has_next: bool = Field(..., description="Whether there is a next page")
has_previous: bool = Field(..., description="Whether there is a previous page")
page: int
page_size: int
total_count: int
total_pages: int
has_next: bool
has_previous: bool
model_config = ConfigDict(ser_json_timedelta="iso8601")

View File

@@ -17,25 +17,32 @@
"""Test MCP app imports and tool/prompt registration."""
import asyncio
def _run(coro):
"""Run an async coroutine synchronously."""
return asyncio.run(coro)
def test_mcp_app_imports_successfully():
"""Test that the MCP app can be imported without errors."""
from superset.mcp_service.app import mcp
assert mcp is not None
assert hasattr(mcp, "_tool_manager")
tools = mcp._tool_manager._tools
assert len(tools) > 0
assert "health_check" in tools
assert "list_charts" in tools
tools = _run(mcp.list_tools())
tool_names = [t.name for t in tools]
assert len(tool_names) > 0
assert "health_check" in tool_names
assert "list_charts" in tool_names
def test_mcp_prompts_registered():
"""Test that MCP prompts are registered."""
from superset.mcp_service.app import mcp
prompts = mcp._prompt_manager._prompts
prompts = _run(mcp.list_prompts())
assert len(prompts) > 0
@@ -48,12 +55,10 @@ def test_mcp_resources_registered():
"""
from superset.mcp_service.app import mcp
resource_manager = mcp._resource_manager
resources = resource_manager._resources
resources = _run(mcp.list_resources())
assert len(resources) > 0, "No MCP resources registered"
# Verify the two documented resources are registered
resource_uris = set(resources.keys())
resource_uris = {str(r.uri) for r in resources}
assert "chart://configs" in resource_uris, (
"chart://configs resource not registered - "
"check superset/mcp_service/chart/__init__.py exists"

View File

@@ -0,0 +1,173 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for MCP tool search transform configuration and application."""
from types import SimpleNamespace
from unittest.mock import MagicMock
from fastmcp.server.transforms.search import BM25SearchTransform, RegexSearchTransform
from superset.mcp_service.mcp_config import MCP_TOOL_SEARCH_CONFIG
from superset.mcp_service.server import (
_apply_tool_search_transform,
_fix_call_tool_arguments,
_serialize_tools_without_output_schema,
)
def test_tool_search_config_defaults():
    """Default config has expected keys and values."""
    config = MCP_TOOL_SEARCH_CONFIG
    assert config["enabled"] is True
    assert config["strategy"] == "bm25"
    assert config["max_results"] == 5
    for pinned in ("health_check", "get_instance_info"):
        assert pinned in config["always_visible"]
    assert config["search_tool_name"] == "search_tools"
    assert config["call_tool_name"] == "call_tool"
def test_apply_bm25_transform():
    """BM25 subclass is created and added when strategy is 'bm25'."""
    mcp = MagicMock()
    _apply_tool_search_transform(
        mcp,
        {
            "strategy": "bm25",
            "max_results": 5,
            "always_visible": ["health_check"],
            "search_tool_name": "search_tools",
            "call_tool_name": "call_tool",
        },
    )
    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, BM25SearchTransform)
def test_apply_regex_transform():
    """Regex subclass is created and added when strategy is 'regex'."""
    mcp = MagicMock()
    _apply_tool_search_transform(
        mcp,
        {
            "strategy": "regex",
            "max_results": 10,
            "always_visible": ["health_check", "get_instance_info"],
            "search_tool_name": "find_tools",
            "call_tool_name": "invoke_tool",
        },
    )
    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, RegexSearchTransform)
def test_apply_transform_uses_defaults_for_missing_keys():
    """Missing config keys fall back to sensible defaults (BM25)."""
    mcp = MagicMock()
    # Empty config: every key should fall back to its default.
    _apply_tool_search_transform(mcp, {})
    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, BM25SearchTransform)
def test_fix_call_tool_arguments_replaces_anyof():
    """_fix_call_tool_arguments replaces anyOf with flat type: object."""
    anyof_schema = {
        "anyOf": [
            {"type": "object", "additionalProperties": True},
            {"type": "null"},
        ],
        "default": None,
    }
    tool = SimpleNamespace(
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "arguments": anyof_schema,
            },
        }
    )
    fixed = _fix_call_tool_arguments(tool)
    assert fixed.parameters["properties"]["arguments"] == {
        "additionalProperties": True,
        "default": None,
        "description": "Arguments to pass to the tool",
        "type": "object",
    }
    # Other properties untouched
    assert fixed.parameters["properties"]["name"] == {"type": "string"}
def test_fix_call_tool_arguments_no_arguments_field():
    """_fix_call_tool_arguments is a no-op when arguments field is absent."""
    tool = SimpleNamespace(
        parameters={"type": "object", "properties": {"name": {"type": "string"}}}
    )
    fixed = _fix_call_tool_arguments(tool)
    assert "arguments" not in fixed.parameters["properties"]
def test_serialize_tools_strips_output_schema():
    """Custom serializer removes outputSchema from tool definitions."""
    dumped = {
        "name": "test_tool",
        "description": "A test tool",
        "inputSchema": {"type": "object", "properties": {"x": {"type": "integer"}}},
        "outputSchema": {
            "type": "object",
            "properties": {"result": {"type": "string"}},
        },
    }
    tool = MagicMock()
    tool.to_mcp_tool.return_value.model_dump.return_value = dumped
    result = _serialize_tools_without_output_schema([tool])
    assert len(result) == 1
    assert result[0]["name"] == "test_tool"
    assert "inputSchema" in result[0]
    assert "outputSchema" not in result[0]
def test_serialize_tools_handles_no_output_schema():
    """Custom serializer works when tool has no outputSchema."""
    tool = MagicMock()
    tool.to_mcp_tool.return_value.model_dump.return_value = {
        "name": "simple_tool",
        "inputSchema": {"type": "object"},
    }
    result = _serialize_tools_without_output_schema([tool])
    assert len(result) == 1
    assert result[0]["name"] == "simple_tool"
    assert "outputSchema" not in result[0]