mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat(mcp): add BM25 tool search transform to reduce initial context size (#38562)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -144,7 +144,7 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
|
||||
elasticsearch = ["elasticsearch-dbapi>=0.2.12, <0.3.0"]
|
||||
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
|
||||
excel = ["xlrd>=1.2.0, <1.3"]
|
||||
fastmcp = ["fastmcp==2.14.3"]
|
||||
fastmcp = ["fastmcp>=3.1.0,<4.0"]
|
||||
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
|
||||
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
|
||||
gevent = ["gevent>=23.9.1"]
|
||||
|
||||
@@ -10,6 +10,8 @@
|
||||
# via
|
||||
# -r requirements/development.in
|
||||
# apache-superset
|
||||
aiofile==3.9.0
|
||||
# via py-key-value-aio
|
||||
alembic==1.15.2
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -26,8 +28,10 @@ anyio==4.11.0
|
||||
# via
|
||||
# httpx
|
||||
# mcp
|
||||
# py-key-value-aio
|
||||
# sse-starlette
|
||||
# starlette
|
||||
# watchfiles
|
||||
apispec==6.6.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -65,9 +69,7 @@ bcrypt==4.3.0
|
||||
# -c requirements/base-constraint.txt
|
||||
# paramiko
|
||||
beartype==0.22.5
|
||||
# via
|
||||
# py-key-value-aio
|
||||
# py-key-value-shared
|
||||
# via py-key-value-aio
|
||||
billiard==4.2.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -100,6 +102,8 @@ cachetools==6.2.1
|
||||
# -c requirements/base-constraint.txt
|
||||
# google-auth
|
||||
# py-key-value-aio
|
||||
caio==0.9.25
|
||||
# via aiofile
|
||||
cattrs==25.1.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -138,7 +142,6 @@ click==8.2.1
|
||||
# click-repl
|
||||
# flask
|
||||
# flask-appbuilder
|
||||
# typer
|
||||
# uvicorn
|
||||
click-didyoumean==0.3.1
|
||||
# via
|
||||
@@ -156,8 +159,6 @@ click-repl==0.3.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# celery
|
||||
cloudpickle==3.1.2
|
||||
# via pydocket
|
||||
cmdstanpy==1.1.0
|
||||
# via prophet
|
||||
colorama==0.4.6
|
||||
@@ -206,8 +207,6 @@ deprecation==2.1.0
|
||||
# apache-superset
|
||||
dill==0.4.0
|
||||
# via pylint
|
||||
diskcache==5.6.3
|
||||
# via py-key-value-aio
|
||||
distlib==0.3.8
|
||||
# via virtualenv
|
||||
dnspython==2.7.0
|
||||
@@ -237,9 +236,7 @@ et-xmlfile==2.0.0
|
||||
# openpyxl
|
||||
exceptiongroup==1.3.0
|
||||
# via fastmcp
|
||||
fakeredis==2.32.1
|
||||
# via pydocket
|
||||
fastmcp==2.14.3
|
||||
fastmcp==3.1.0
|
||||
# via apache-superset
|
||||
filelock==3.20.3
|
||||
# via
|
||||
@@ -474,6 +471,8 @@ jsonpath-ng==1.7.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
jsonref==1.1.0
|
||||
# via fastmcp
|
||||
jsonschema==4.23.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -504,8 +503,6 @@ limits==5.1.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# flask-limiter
|
||||
lupa==2.6
|
||||
# via fakeredis
|
||||
mako==1.3.10
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -603,7 +600,7 @@ openpyxl==3.1.5
|
||||
# -c requirements/base-constraint.txt
|
||||
# pandas
|
||||
opentelemetry-api==1.39.1
|
||||
# via pydocket
|
||||
# via fastmcp
|
||||
ordered-set==4.1.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -622,6 +619,7 @@ packaging==25.0
|
||||
# deprecation
|
||||
# docker
|
||||
# duckdb-engine
|
||||
# fastmcp
|
||||
# google-cloud-bigquery
|
||||
# gunicorn
|
||||
# limits
|
||||
@@ -653,8 +651,6 @@ parsedatetime==2.6
|
||||
# apache-superset
|
||||
pathable==0.4.3
|
||||
# via jsonschema-path
|
||||
pathvalidate==3.3.1
|
||||
# via py-key-value-aio
|
||||
pgsanity==0.2.9
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -691,8 +687,6 @@ prison==0.2.1
|
||||
# flask-appbuilder
|
||||
progress==1.6
|
||||
# via apache-superset
|
||||
prometheus-client==0.23.1
|
||||
# via pydocket
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -714,12 +708,8 @@ psutil==6.1.0
|
||||
# via apache-superset
|
||||
psycopg2-binary==2.9.9
|
||||
# via apache-superset
|
||||
py-key-value-aio==0.3.0
|
||||
# via
|
||||
# fastmcp
|
||||
# pydocket
|
||||
py-key-value-shared==0.3.0
|
||||
# via py-key-value-aio
|
||||
py-key-value-aio==0.4.4
|
||||
# via fastmcp
|
||||
pyarrow==16.1.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -758,8 +748,6 @@ pydantic-settings==2.10.1
|
||||
# via mcp
|
||||
pydata-google-auth==1.9.0
|
||||
# via pandas-gbq
|
||||
pydocket==0.17.1
|
||||
# via fastmcp
|
||||
pydruid==0.6.9
|
||||
# via apache-superset
|
||||
pyfakefs==5.3.5
|
||||
@@ -844,8 +832,6 @@ python-dotenv==1.1.0
|
||||
# apache-superset
|
||||
# fastmcp
|
||||
# pydantic-settings
|
||||
python-json-logger==4.0.0
|
||||
# via pydocket
|
||||
python-ldap==3.4.4
|
||||
# via apache-superset
|
||||
python-multipart==0.0.20
|
||||
@@ -866,15 +852,13 @@ pyyaml==6.0.2
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
# apispec
|
||||
# fastmcp
|
||||
# jsonschema-path
|
||||
# pre-commit
|
||||
redis==5.3.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
# fakeredis
|
||||
# py-key-value-aio
|
||||
# pydocket
|
||||
referencing==0.36.2
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -910,9 +894,7 @@ rich==13.9.4
|
||||
# cyclopts
|
||||
# fastmcp
|
||||
# flask-limiter
|
||||
# pydocket
|
||||
# rich-rst
|
||||
# typer
|
||||
rich-rst==1.3.1
|
||||
# via cyclopts
|
||||
rpds-py==0.25.0
|
||||
@@ -944,8 +926,6 @@ setuptools==80.9.0
|
||||
# pydata-google-auth
|
||||
# zope-event
|
||||
# zope-interface
|
||||
shellingham==1.5.4
|
||||
# via typer
|
||||
shillelagh==1.4.3
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -973,7 +953,6 @@ sniffio==1.3.1
|
||||
sortedcontainers==2.4.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# fakeredis
|
||||
# trio
|
||||
sqlalchemy==1.4.54
|
||||
# via
|
||||
@@ -1034,8 +1013,6 @@ trio-websocket==0.12.2
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
# selenium
|
||||
typer==0.20.0
|
||||
# via pydocket
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -1048,16 +1025,14 @@ typing-extensions==4.15.0
|
||||
# limits
|
||||
# mcp
|
||||
# opentelemetry-api
|
||||
# py-key-value-shared
|
||||
# py-key-value-aio
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pydocket
|
||||
# pyopenssl
|
||||
# referencing
|
||||
# selenium
|
||||
# shillelagh
|
||||
# starlette
|
||||
# typer
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via
|
||||
@@ -1072,6 +1047,8 @@ tzdata==2025.2
|
||||
# pandas
|
||||
tzlocal==5.2
|
||||
# via trino
|
||||
uncalled-for==0.2.0
|
||||
# via fastmcp
|
||||
url-normalize==2.2.1
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
@@ -1101,6 +1078,8 @@ watchdog==6.0.0
|
||||
# -c requirements/base-constraint.txt
|
||||
# apache-superset
|
||||
# apache-superset-extensions-cli
|
||||
watchfiles==1.1.1
|
||||
# via fastmcp
|
||||
wcwidth==0.2.13
|
||||
# via
|
||||
# -c requirements/base-constraint.txt
|
||||
|
||||
@@ -384,15 +384,12 @@ class ChartList(BaseModel):
|
||||
class ColumnRef(BaseModel):
|
||||
name: str = Field(
|
||||
...,
|
||||
description="Column name",
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
|
||||
)
|
||||
label: str | None = Field(
|
||||
None, description="Display label for the column", max_length=500
|
||||
)
|
||||
dtype: str | None = Field(None, description="Data type hint")
|
||||
label: str | None = Field(None, max_length=500)
|
||||
dtype: str | None = None
|
||||
aggregate: (
|
||||
Literal[
|
||||
"SUM",
|
||||
@@ -407,11 +404,7 @@ class ColumnRef(BaseModel):
|
||||
"PERCENTILE",
|
||||
]
|
||||
| None
|
||||
) = Field(
|
||||
None,
|
||||
description="SQL aggregation function. Only these validated functions are "
|
||||
"supported to prevent SQL errors.",
|
||||
)
|
||||
) = Field(None, description="SQL aggregate function")
|
||||
|
||||
@field_validator("name")
|
||||
@classmethod
|
||||
@@ -431,25 +424,22 @@ class ColumnRef(BaseModel):
|
||||
|
||||
|
||||
class AxisConfig(BaseModel):
|
||||
title: str | None = Field(None, description="Axis title", max_length=200)
|
||||
scale: Literal["linear", "log"] | None = Field(
|
||||
"linear", description="Axis scale type"
|
||||
)
|
||||
format: str | None = Field(
|
||||
None, description="Format string (e.g. '$,.2f')", max_length=50
|
||||
)
|
||||
title: str | None = Field(None, max_length=200)
|
||||
scale: Literal["linear", "log"] | None = "linear"
|
||||
format: str | None = Field(None, description="e.g. '$,.2f'", max_length=50)
|
||||
|
||||
|
||||
class LegendConfig(BaseModel):
|
||||
show: bool = Field(True, description="Whether to show legend")
|
||||
position: Literal["top", "bottom", "left", "right"] | None = Field(
|
||||
"right", description="Legend position"
|
||||
)
|
||||
show: bool = True
|
||||
position: Literal["top", "bottom", "left", "right"] | None = "right"
|
||||
|
||||
|
||||
class FilterConfig(BaseModel):
|
||||
column: str = Field(
|
||||
..., description="Column to filter on", min_length=1, max_length=255
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=255,
|
||||
pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
|
||||
)
|
||||
op: Literal[
|
||||
"=",
|
||||
@@ -465,17 +455,11 @@ class FilterConfig(BaseModel):
|
||||
"NOT IN",
|
||||
] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Filter operator. Use LIKE/ILIKE for pattern matching with % wildcards "
|
||||
"(e.g., '%mario%'). Use IN/NOT IN with a list of values."
|
||||
),
|
||||
description="LIKE/ILIKE use % wildcards. IN/NOT IN take a list.",
|
||||
)
|
||||
value: str | int | float | bool | list[str | int | float | bool] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Filter value. For IN/NOT IN operators, provide a list of values. "
|
||||
"For LIKE/ILIKE, use % as wildcard (e.g., '%mario%')."
|
||||
),
|
||||
description="For IN/NOT IN, provide a list.",
|
||||
)
|
||||
|
||||
@field_validator("column")
|
||||
@@ -516,26 +500,13 @@ class FilterConfig(BaseModel):
|
||||
class PieChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chart_type: Literal["pie"] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Chart type discriminator - MUST be 'pie' for pie/donut charts. "
|
||||
"This field is REQUIRED and tells Superset which chart "
|
||||
"configuration schema to use."
|
||||
),
|
||||
)
|
||||
dimension: ColumnRef = Field(
|
||||
..., description="Category column that defines the pie slices"
|
||||
)
|
||||
chart_type: Literal["pie"] = "pie"
|
||||
dimension: ColumnRef = Field(..., description="Category column for slices")
|
||||
metric: ColumnRef = Field(
|
||||
...,
|
||||
description=(
|
||||
"Value metric that determines slice sizes. "
|
||||
"Must include an aggregate function (e.g., SUM, COUNT)."
|
||||
),
|
||||
..., description="Value metric (needs aggregate e.g. SUM, COUNT)"
|
||||
)
|
||||
donut: bool = Field(False, description="Render as a donut chart with a center hole")
|
||||
show_labels: bool = Field(True, description="Display labels on slices")
|
||||
donut: bool = False
|
||||
show_labels: bool = True
|
||||
label_type: Literal[
|
||||
"key",
|
||||
"value",
|
||||
@@ -544,63 +515,32 @@ class PieChartConfig(BaseModel):
|
||||
"key_percent",
|
||||
"key_value_percent",
|
||||
"value_percent",
|
||||
] = Field("key_value_percent", description="Type of labels to show on slices")
|
||||
sort_by_metric: bool = Field(True, description="Sort slices by metric value")
|
||||
show_legend: bool = Field(True, description="Whether to show legend")
|
||||
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
|
||||
row_limit: int = Field(
|
||||
100,
|
||||
description="Maximum number of slices to display",
|
||||
ge=1,
|
||||
le=10000,
|
||||
)
|
||||
number_format: str = Field(
|
||||
"SMART_NUMBER",
|
||||
description="Number format string",
|
||||
max_length=50,
|
||||
)
|
||||
show_total: bool = Field(False, description="Display aggregate count in center")
|
||||
labels_outside: bool = Field(True, description="Place labels outside the pie")
|
||||
outer_radius: int = Field(
|
||||
70,
|
||||
description="Outer edge radius as a percentage (1-100)",
|
||||
ge=1,
|
||||
le=100,
|
||||
)
|
||||
] = "key_value_percent"
|
||||
sort_by_metric: bool = True
|
||||
show_legend: bool = True
|
||||
filters: List[FilterConfig] | None = None
|
||||
row_limit: int = Field(100, description="Max slices", ge=1, le=10000)
|
||||
number_format: str = Field("SMART_NUMBER", max_length=50)
|
||||
show_total: bool = Field(False, description="Show total in center")
|
||||
labels_outside: bool = True
|
||||
outer_radius: int = Field(70, description="Outer radius % (1-100)", ge=1, le=100)
|
||||
inner_radius: int = Field(
|
||||
30,
|
||||
description="Inner radius as a percentage for donut (1-100)",
|
||||
ge=1,
|
||||
le=100,
|
||||
30, description="Donut inner radius % (1-100)", ge=1, le=100
|
||||
)
|
||||
|
||||
|
||||
class PivotTableChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chart_type: Literal["pivot_table"] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Chart type discriminator - MUST be 'pivot_table' for interactive "
|
||||
"pivot tables. This field is REQUIRED."
|
||||
),
|
||||
)
|
||||
rows: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Row grouping columns (at least one required)",
|
||||
)
|
||||
chart_type: Literal["pivot_table"] = "pivot_table"
|
||||
rows: List[ColumnRef] = Field(..., min_length=1, description="Row grouping columns")
|
||||
columns: List[ColumnRef] | None = Field(
|
||||
None,
|
||||
description="Column grouping columns (optional, for cross-tabulation)",
|
||||
None, description="Column groups for cross-tabulation"
|
||||
)
|
||||
metrics: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description=(
|
||||
"Metrics to aggregate. Each must have an aggregate function "
|
||||
"(e.g., SUM, COUNT, AVG)."
|
||||
),
|
||||
description="Metrics (need aggregate e.g. SUM, COUNT, AVG)",
|
||||
)
|
||||
aggregate_function: Literal[
|
||||
"Sum",
|
||||
@@ -614,108 +554,56 @@ class PivotTableChartConfig(BaseModel):
|
||||
"Count Unique Values",
|
||||
"First",
|
||||
"Last",
|
||||
] = Field("Sum", description="Default aggregation function for the pivot table")
|
||||
show_row_totals: bool = Field(True, description="Show row totals")
|
||||
show_column_totals: bool = Field(True, description="Show column totals")
|
||||
transpose: bool = Field(False, description="Swap rows and columns")
|
||||
combine_metric: bool = Field(
|
||||
False,
|
||||
description="Display metrics side by side within columns",
|
||||
)
|
||||
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
|
||||
row_limit: int = Field(
|
||||
10000,
|
||||
description="Maximum number of cells",
|
||||
ge=1,
|
||||
le=50000,
|
||||
)
|
||||
value_format: str = Field(
|
||||
"SMART_NUMBER",
|
||||
description="Value format string",
|
||||
max_length=50,
|
||||
)
|
||||
] = "Sum"
|
||||
show_row_totals: bool = True
|
||||
show_column_totals: bool = True
|
||||
transpose: bool = False
|
||||
combine_metric: bool = Field(False, description="Metrics side by side in columns")
|
||||
filters: List[FilterConfig] | None = None
|
||||
row_limit: int = Field(10000, description="Max cells", ge=1, le=50000)
|
||||
value_format: str = Field("SMART_NUMBER", max_length=50)
|
||||
|
||||
|
||||
class MixedTimeseriesChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chart_type: Literal["mixed_timeseries"] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Chart type discriminator - MUST be 'mixed_timeseries' for charts "
|
||||
"that combine two different series types (e.g., line + bar). "
|
||||
"This field is REQUIRED."
|
||||
),
|
||||
)
|
||||
x: ColumnRef = Field(..., description="X-axis temporal column (shared)")
|
||||
time_grain: TimeGrain | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Time granularity for the x-axis. "
|
||||
"Common values: PT1H (hourly), P1D (daily), P1W (weekly), "
|
||||
"P1M (monthly), P1Y (yearly)."
|
||||
),
|
||||
)
|
||||
chart_type: Literal["mixed_timeseries"] = "mixed_timeseries"
|
||||
x: ColumnRef = Field(..., description="Shared temporal X-axis column")
|
||||
time_grain: TimeGrain | None = Field(None, description="PT1H, P1D, P1W, P1M, P1Y")
|
||||
# Primary series (Query A)
|
||||
y: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Primary Y-axis metrics (Query A)",
|
||||
)
|
||||
primary_kind: Literal["line", "bar", "area", "scatter"] = Field(
|
||||
"line", description="Primary series chart type"
|
||||
)
|
||||
group_by: ColumnRef | None = Field(
|
||||
None, description="Group by column for primary series"
|
||||
)
|
||||
y: List[ColumnRef] = Field(..., min_length=1, description="Primary Y-axis metrics")
|
||||
primary_kind: Literal["line", "bar", "area", "scatter"] = "line"
|
||||
group_by: ColumnRef | None = Field(None, description="Primary series group by")
|
||||
# Secondary series (Query B)
|
||||
y_secondary: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Secondary Y-axis metrics (Query B)",
|
||||
)
|
||||
secondary_kind: Literal["line", "bar", "area", "scatter"] = Field(
|
||||
"bar", description="Secondary series chart type"
|
||||
..., min_length=1, description="Secondary Y-axis metrics"
|
||||
)
|
||||
secondary_kind: Literal["line", "bar", "area", "scatter"] = "bar"
|
||||
group_by_secondary: ColumnRef | None = Field(
|
||||
None, description="Group by column for secondary series"
|
||||
None, description="Secondary series group by"
|
||||
)
|
||||
# Display options
|
||||
show_legend: bool = Field(True, description="Whether to show legend")
|
||||
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
|
||||
y_axis: AxisConfig | None = Field(None, description="Primary Y-axis configuration")
|
||||
y_axis_secondary: AxisConfig | None = Field(
|
||||
None, description="Secondary Y-axis configuration"
|
||||
)
|
||||
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
|
||||
show_legend: bool = True
|
||||
x_axis: AxisConfig | None = None
|
||||
y_axis: AxisConfig | None = None
|
||||
y_axis_secondary: AxisConfig | None = None
|
||||
filters: List[FilterConfig] | None = None
|
||||
|
||||
|
||||
class TableChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chart_type: Literal["table"] = Field(
|
||||
..., description="Chart type (REQUIRED: must be 'table')"
|
||||
)
|
||||
chart_type: Literal["table"] = "table"
|
||||
viz_type: Literal["table", "ag-grid-table"] = Field(
|
||||
"table",
|
||||
description=(
|
||||
"Visualization type: 'table' for standard table, 'ag-grid-table' for "
|
||||
"AG Grid Interactive Table with advanced features like column resizing, "
|
||||
"sorting, filtering, and server-side pagination"
|
||||
),
|
||||
"table", description="'ag-grid-table' for interactive features"
|
||||
)
|
||||
columns: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description=(
|
||||
"Columns to display. Must have at least one column. Each column must have "
|
||||
"a unique label "
|
||||
"(either explicitly set via 'label' field or auto-generated "
|
||||
"from name/aggregate)"
|
||||
),
|
||||
description="Columns with unique labels",
|
||||
)
|
||||
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
|
||||
sort_by: List[str] | None = Field(None, description="Columns to sort by")
|
||||
filters: List[FilterConfig] | None = None
|
||||
sort_by: List[str] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_unique_column_labels(self) -> "TableChartConfig":
|
||||
@@ -748,56 +636,26 @@ class TableChartConfig(BaseModel):
|
||||
class XYChartConfig(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
chart_type: Literal["xy"] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Chart type discriminator - MUST be 'xy' for XY charts "
|
||||
"(line, bar, area, scatter). "
|
||||
"This field is REQUIRED and tells Superset which chart "
|
||||
"configuration schema to use."
|
||||
),
|
||||
)
|
||||
chart_type: Literal["xy"] = "xy"
|
||||
x: ColumnRef = Field(..., description="X-axis column")
|
||||
y: List[ColumnRef] = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
description="Y-axis columns (metrics). Must have at least one Y-axis column. "
|
||||
"Each column must have a unique label "
|
||||
"that doesn't conflict with x-axis or group_by labels",
|
||||
)
|
||||
kind: Literal["line", "bar", "area", "scatter"] = Field(
|
||||
"line", description="Chart visualization type"
|
||||
..., min_length=1, description="Y-axis metrics (unique labels)"
|
||||
)
|
||||
kind: Literal["line", "bar", "area", "scatter"] = "line"
|
||||
time_grain: TimeGrain | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Time granularity for the x-axis when it's a temporal column. "
|
||||
"Common values: PT1S (second), PT1M (minute), PT1H (hour), "
|
||||
"P1D (day), P1W (week), P1M (month), P3M (quarter), P1Y (year). "
|
||||
"If not specified, Superset will use its default behavior."
|
||||
),
|
||||
None, description="PT1S, PT1M, PT1H, P1D, P1W, P1M, P3M, P1Y"
|
||||
)
|
||||
orientation: Literal["vertical", "horizontal"] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Bar chart orientation. Only applies when kind='bar'. "
|
||||
"'vertical' (default): bars extend upward. "
|
||||
"'horizontal': bars extend rightward, useful for long category names."
|
||||
),
|
||||
)
|
||||
stacked: bool = Field(
|
||||
False,
|
||||
description="Stack bars/areas on top of each other instead of side-by-side",
|
||||
None, description="Bar orientation (only for kind='bar')"
|
||||
)
|
||||
stacked: bool = False
|
||||
group_by: ColumnRef | None = Field(
|
||||
None,
|
||||
description="Column to group by (creates series/breakdown). "
|
||||
"Use this field for series grouping — do NOT use 'series'.",
|
||||
None, description="Series breakdown column (not 'series')"
|
||||
)
|
||||
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
|
||||
y_axis: AxisConfig | None = Field(None, description="Y-axis configuration")
|
||||
legend: LegendConfig | None = Field(None, description="Legend configuration")
|
||||
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
|
||||
x_axis: AxisConfig | None = None
|
||||
y_axis: AxisConfig | None = None
|
||||
legend: LegendConfig | None = None
|
||||
filters: List[FilterConfig] | None = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_unique_column_labels(self) -> "XYChartConfig":
|
||||
@@ -949,21 +807,12 @@ class GenerateChartRequest(QueryCacheControl):
|
||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="Chart configuration")
|
||||
chart_name: str | None = Field(
|
||||
None,
|
||||
description="Custom chart name (optional, auto-generates if not provided)",
|
||||
max_length=255,
|
||||
)
|
||||
save_chart: bool = Field(
|
||||
default=False,
|
||||
description="Whether to permanently save the chart in Superset",
|
||||
)
|
||||
generate_preview: bool = Field(
|
||||
default=True,
|
||||
description="Whether to generate a preview image",
|
||||
None, description="Auto-generates if omitted", max_length=255
|
||||
)
|
||||
save_chart: bool = Field(default=False, description="Save permanently in Superset")
|
||||
generate_preview: bool = True
|
||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
default_factory=lambda: ["url"],
|
||||
description="List of preview formats to generate",
|
||||
)
|
||||
|
||||
@field_validator("chart_name")
|
||||
@@ -1002,20 +851,14 @@ class GenerateExploreLinkRequest(FormDataCacheControl):
|
||||
|
||||
|
||||
class UpdateChartRequest(QueryCacheControl):
|
||||
identifier: int | str = Field(..., description="Chart identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="New chart configuration")
|
||||
identifier: int | str = Field(..., description="Chart ID or UUID")
|
||||
config: ChartConfig
|
||||
chart_name: str | None = Field(
|
||||
None,
|
||||
description="New chart name (optional, will auto-generate if not provided)",
|
||||
max_length=255,
|
||||
)
|
||||
generate_preview: bool = Field(
|
||||
default=True,
|
||||
description="Whether to generate a preview after updating",
|
||||
None, description="Auto-generates if omitted", max_length=255
|
||||
)
|
||||
generate_preview: bool = True
|
||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
default_factory=lambda: ["url"],
|
||||
description="List of preview formats to generate",
|
||||
)
|
||||
|
||||
@field_validator("chart_name")
|
||||
@@ -1027,15 +870,11 @@ class UpdateChartRequest(QueryCacheControl):
|
||||
|
||||
class UpdateChartPreviewRequest(FormDataCacheControl):
|
||||
form_data_key: str = Field(..., description="Existing form_data_key to update")
|
||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="New chart configuration")
|
||||
generate_preview: bool = Field(
|
||||
default=True,
|
||||
description="Whether to generate a preview after updating",
|
||||
)
|
||||
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
|
||||
config: ChartConfig
|
||||
generate_preview: bool = True
|
||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
default_factory=lambda: ["url"],
|
||||
description="List of preview formats to generate",
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -37,22 +37,10 @@ class CacheControlMixin(BaseModel):
|
||||
- Dashboard Cache: Caches rendered dashboard components
|
||||
"""
|
||||
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether to use Superset's cache layers. When True, will serve from "
|
||||
"cache if available (query results, metadata, form data). When False, "
|
||||
"will bypass cache and fetch fresh data."
|
||||
),
|
||||
)
|
||||
use_cache: bool = Field(default=True, description="Use cache if available")
|
||||
|
||||
force_refresh: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to force refresh cached data. When True, will invalidate "
|
||||
"existing cache entries and fetch fresh data, then update the cache. "
|
||||
"Overrides use_cache=True if both are specified."
|
||||
),
|
||||
default=False, description="Invalidate cache and fetch fresh data"
|
||||
)
|
||||
|
||||
|
||||
@@ -65,12 +53,7 @@ class QueryCacheControl(CacheControlMixin):
|
||||
"""
|
||||
|
||||
cache_timeout: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Override the default cache timeout in seconds for this query. "
|
||||
"If not specified, uses dataset-level or global cache settings. "
|
||||
"Set to 0 to disable caching for this specific query."
|
||||
),
|
||||
default=None, description="Cache timeout override in seconds (0 to disable)"
|
||||
)
|
||||
|
||||
|
||||
@@ -83,11 +66,7 @@ class MetadataCacheControl(CacheControlMixin):
|
||||
"""
|
||||
|
||||
refresh_metadata: bool = Field(
|
||||
default=False,
|
||||
description=(
|
||||
"Whether to refresh metadata cache for datasets, tables, and columns. "
|
||||
"Useful when database schema has changed and you need fresh metadata."
|
||||
),
|
||||
default=False, description="Refresh metadata cache for schema changes"
|
||||
)
|
||||
|
||||
|
||||
@@ -100,11 +79,7 @@ class FormDataCacheControl(CacheControlMixin):
|
||||
"""
|
||||
|
||||
cache_form_data: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"Whether to cache the form data configuration for future use. "
|
||||
"When False, generates temporary configurations that are not cached."
|
||||
),
|
||||
default=True, description="Cache form data for future use"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -287,45 +287,31 @@ class GetDashboardInfoRequest(MetadataCacheControl):
|
||||
|
||||
|
||||
class DashboardInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Dashboard ID")
|
||||
dashboard_title: str | None = Field(None, description="Dashboard title")
|
||||
slug: str | None = Field(None, description="Dashboard slug")
|
||||
description: str | None = Field(None, description="Dashboard description")
|
||||
css: str | None = Field(None, description="Custom CSS for the dashboard")
|
||||
certified_by: str | None = Field(None, description="Who certified the dashboard")
|
||||
certification_details: str | None = Field(None, description="Certification details")
|
||||
json_metadata: str | None = Field(
|
||||
None, description="Dashboard metadata (JSON string)"
|
||||
)
|
||||
position_json: str | None = Field(None, description="Chart positions (JSON string)")
|
||||
published: bool | None = Field(
|
||||
None, description="Whether the dashboard is published"
|
||||
)
|
||||
is_managed_externally: bool | None = Field(
|
||||
None, description="Whether managed externally"
|
||||
)
|
||||
external_url: str | None = Field(None, description="External URL")
|
||||
created_on: str | datetime | None = Field(None, description="Creation timestamp")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
created_by: str | None = Field(None, description="Dashboard creator (username)")
|
||||
changed_by: str | None = Field(None, description="Last modifier (username)")
|
||||
uuid: str | None = Field(None, description="Dashboard UUID (converted to string)")
|
||||
url: str | None = Field(None, description="Dashboard URL")
|
||||
created_on_humanized: str | None = Field(
|
||||
None, description="Humanized creation time"
|
||||
)
|
||||
changed_on_humanized: str | None = Field(
|
||||
None, description="Humanized modification time"
|
||||
)
|
||||
chart_count: int = Field(0, description="Number of charts in the dashboard")
|
||||
owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners")
|
||||
tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags")
|
||||
roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard roles")
|
||||
charts: List[ChartInfo] = Field(
|
||||
default_factory=list, description="Dashboard charts"
|
||||
)
|
||||
id: int | None = None
|
||||
dashboard_title: str | None = None
|
||||
slug: str | None = None
|
||||
description: str | None = None
|
||||
css: str | None = None
|
||||
certified_by: str | None = None
|
||||
certification_details: str | None = None
|
||||
json_metadata: str | None = None
|
||||
position_json: str | None = None
|
||||
published: bool | None = None
|
||||
is_managed_externally: bool | None = None
|
||||
external_url: str | None = None
|
||||
created_on: str | datetime | None = None
|
||||
changed_on: str | datetime | None = None
|
||||
created_by: str | None = None
|
||||
changed_by: str | None = None
|
||||
uuid: str | None = None
|
||||
url: str | None = None
|
||||
created_on_humanized: str | None = None
|
||||
changed_on_humanized: str | None = None
|
||||
chart_count: int = 0
|
||||
owners: List[UserInfo] = Field(default_factory=list)
|
||||
tags: List[TagInfo] = Field(default_factory=list)
|
||||
roles: List[RoleInfo] = Field(default_factory=list)
|
||||
charts: List[ChartInfo] = Field(default_factory=list)
|
||||
|
||||
# Fields for permalink/filter state support
|
||||
permalink_key: str | None = Field(
|
||||
|
||||
@@ -25,7 +25,6 @@ Following the Stack Overflow recommendation:
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from flask import current_app, Flask, has_app_context
|
||||
|
||||
@@ -52,62 +51,45 @@ try:
|
||||
logger.info("Reusing existing Flask app from app context for MCP service")
|
||||
# Use _get_current_object() to get the actual Flask app, not the LocalProxy
|
||||
app = current_app._get_current_object()
|
||||
elif appbuilder_initialized:
|
||||
# appbuilder is initialized but we have no app context. Calling
|
||||
# create_app() here would invoke appbuilder.init_app() a second
|
||||
# time with a *different* Flask app, overwriting shared internal
|
||||
# state (views, security manager, etc.). Fail loudly instead of
|
||||
# silently corrupting the singleton.
|
||||
raise RuntimeError(
|
||||
"appbuilder is already initialized but no Flask app context is "
|
||||
"available. Cannot call create_app() as it would re-initialize "
|
||||
"appbuilder with a different Flask app instance."
|
||||
)
|
||||
else:
|
||||
# Either appbuilder is not initialized (standalone MCP server),
|
||||
# or appbuilder is initialized but we're not in an app context
|
||||
# (edge case - should rarely happen). In both cases, create a minimal app.
|
||||
# Standalone MCP server — Superset models are deeply coupled to
|
||||
# appbuilder, security_manager, event_logger, encrypted_field_factory,
|
||||
# etc. so we use create_app() for full initialization rather than
|
||||
# trying to init a minimal subset (which leads to cascading failures).
|
||||
#
|
||||
# We avoid calling create_app() which would run full FAB initialization
|
||||
# and could corrupt the shared appbuilder singleton if main app starts.
|
||||
from superset.app import SupersetApp
|
||||
# create_app() is safe here because in standalone mode the main
|
||||
# Superset web server is not running in-process.
|
||||
from superset.app import create_app
|
||||
from superset.mcp_service.mcp_config import get_mcp_config
|
||||
|
||||
if appbuilder_initialized:
|
||||
logger.warning(
|
||||
"Appbuilder initialized but not in app context - "
|
||||
"creating separate MCP Flask app"
|
||||
)
|
||||
else:
|
||||
logger.info("Creating minimal Flask app for standalone MCP service")
|
||||
|
||||
# Disable debug mode to avoid side-effects like file watchers
|
||||
_mcp_app = SupersetApp(__name__)
|
||||
logger.info("Creating fully initialized Flask app for standalone MCP service")
|
||||
_mcp_app = create_app()
|
||||
_mcp_app.debug = False
|
||||
|
||||
# Load configuration
|
||||
config_module = os.environ.get("SUPERSET_CONFIG", "superset.config")
|
||||
_mcp_app.config.from_object(config_module)
|
||||
|
||||
# Apply MCP-specific configuration
|
||||
# Apply MCP-specific configuration on top
|
||||
mcp_config = get_mcp_config(_mcp_app.config)
|
||||
_mcp_app.config.update(mcp_config)
|
||||
|
||||
# Initialize only the minimal dependencies needed for MCP service
|
||||
with _mcp_app.app_context():
|
||||
try:
|
||||
from superset.extensions import db
|
||||
from superset.core.mcp.core_mcp_injection import (
|
||||
initialize_core_mcp_dependencies,
|
||||
)
|
||||
|
||||
db.init_app(_mcp_app)
|
||||
|
||||
# Initialize only MCP-specific dependencies
|
||||
# MCP tools import directly from superset.daos/models, so we only need
|
||||
# the MCP decorator injection, not the full superset_core abstraction
|
||||
from superset.core.mcp.core_mcp_injection import (
|
||||
initialize_core_mcp_dependencies,
|
||||
)
|
||||
|
||||
initialize_core_mcp_dependencies()
|
||||
|
||||
logger.info(
|
||||
"Minimal MCP dependencies initialized for standalone MCP service"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to initialize dependencies for MCP service: %s", e
|
||||
)
|
||||
initialize_core_mcp_dependencies()
|
||||
|
||||
app = _mcp_app
|
||||
logger.info("Minimal Flask app instance created successfully for MCP service")
|
||||
logger.info("Flask app fully initialized for standalone MCP service")
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Failed to create Flask app: %s", e)
|
||||
|
||||
@@ -227,10 +227,44 @@ MCP_RESPONSE_SIZE_CONFIG: Dict[str, Any] = {
|
||||
"get_chart_preview", # Returns URLs, not data
|
||||
"generate_explore_link", # Returns URLs
|
||||
"open_sql_lab_with_context", # Returns URLs
|
||||
"search_tools", # Returns tool schemas for discovery (intentionally large)
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# MCP Tool Search Transform Configuration
|
||||
# =============================================================================
|
||||
#
|
||||
# Overview:
|
||||
# ---------
|
||||
# When enabled, replaces the full tool catalog with a search interface.
|
||||
# LLMs see only 2 synthetic tools (search_tools + call_tool) plus any
|
||||
# pinned tools, and discover other tools on-demand via natural language search.
|
||||
# This reduces initial context by ~70% (from ~40k tokens to ~5-8k tokens).
|
||||
#
|
||||
# Strategies:
|
||||
# -----------
|
||||
# - "bm25": Natural language search using BM25 ranking (recommended)
|
||||
# - "regex": Pattern-based search using regular expressions
|
||||
#
|
||||
# Rollback:
|
||||
# ---------
|
||||
# Set enabled=False in superset_config.py for instant rollback.
|
||||
# =============================================================================
|
||||
MCP_TOOL_SEARCH_CONFIG: Dict[str, Any] = {
|
||||
"enabled": True, # Enabled by default — reduces initial context by ~70%
|
||||
"strategy": "bm25", # "bm25" (natural language) or "regex" (pattern matching)
|
||||
"max_results": 5, # Max tools returned per search
|
||||
"always_visible": [ # Tools always shown in list_tools (pinned)
|
||||
"health_check",
|
||||
"get_instance_info",
|
||||
],
|
||||
"search_tool_name": "search_tools", # Name of the search tool
|
||||
"call_tool_name": "call_tool", # Name of the call proxy tool
|
||||
}
|
||||
|
||||
|
||||
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
|
||||
"""Default MCP auth factory using app.config values."""
|
||||
if not app.config.get("MCP_AUTH_ENABLED", False):
|
||||
|
||||
@@ -24,12 +24,17 @@ For multi-pod deployments, configure MCP_EVENT_STORE_CONFIG with Redis URL.
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
import uvicorn
|
||||
|
||||
from superset.mcp_service.app import create_mcp_app, init_fastmcp_server
|
||||
from superset.mcp_service.mcp_config import get_mcp_factory_config, MCP_STORE_CONFIG
|
||||
from superset.mcp_service.mcp_config import (
|
||||
get_mcp_factory_config,
|
||||
MCP_STORE_CONFIG,
|
||||
MCP_TOOL_SEARCH_CONFIG,
|
||||
)
|
||||
from superset.mcp_service.middleware import (
|
||||
create_response_size_guard_middleware,
|
||||
GlobalErrorHandlerMiddleware,
|
||||
@@ -111,8 +116,7 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
|
||||
if config is None:
|
||||
config = MCP_STORE_CONFIG
|
||||
|
||||
redis_url = config.get("CACHE_REDIS_URL")
|
||||
if not redis_url:
|
||||
if not config.get("CACHE_REDIS_URL"):
|
||||
logging.info("EventStore: Using in-memory storage (single-pod mode)")
|
||||
return None
|
||||
|
||||
@@ -151,6 +155,117 @@ def create_event_store(config: dict[str, Any] | None = None) -> Any | None:
|
||||
return None
|
||||
|
||||
|
||||
def _strip_titles(obj: Any, in_properties_map: bool = False) -> Any:
|
||||
"""Recursively strip schema metadata ``title`` keys.
|
||||
|
||||
Keeps real field names inside ``properties`` (e.g. a property literally
|
||||
named ``title``), while removing auto-generated schema title metadata.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
result: dict[str, Any] = {}
|
||||
for key, value in obj.items():
|
||||
if key == "title" and not in_properties_map:
|
||||
continue
|
||||
result[key] = _strip_titles(value, in_properties_map=(key == "properties"))
|
||||
return result
|
||||
if isinstance(obj, list):
|
||||
return [_strip_titles(item, in_properties_map=False) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _serialize_tools_without_output_schema(
|
||||
tools: Sequence[Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Serialize tools to JSON, stripping outputSchema and titles to reduce tokens.
|
||||
|
||||
LLMs only need inputSchema to call tools. outputSchema accounts for
|
||||
50-80% of the per-tool schema size, and auto-generated 'title' fields
|
||||
add ~12% bloat. Stripping both cuts search result tokens significantly.
|
||||
"""
|
||||
results = []
|
||||
for tool in tools:
|
||||
data = tool.to_mcp_tool().model_dump(mode="json", exclude_none=True)
|
||||
data.pop("outputSchema", None)
|
||||
if input_schema := data.get("inputSchema"):
|
||||
data["inputSchema"] = _strip_titles(input_schema)
|
||||
results.append(data)
|
||||
return results
|
||||
|
||||
|
||||
def _fix_call_tool_arguments(tool: Any) -> Any:
|
||||
"""Fix anyOf schema in call_tool ``arguments`` for MCP bridge compatibility.
|
||||
|
||||
FastMCP's BaseSearchTransform defines ``arguments`` as
|
||||
``dict[str, Any] | None`` which emits an ``anyOf`` JSON Schema.
|
||||
Some MCP bridges (mcp-remote, Claude Desktop) don't handle ``anyOf``
|
||||
and strip it, leaving the field without a ``type`` — causing all
|
||||
call_tool invocations to fail with "Input should be a valid dictionary".
|
||||
|
||||
Replaces the ``anyOf`` with a flat ``type: object``.
|
||||
"""
|
||||
if "arguments" in (props := (tool.parameters or {}).get("properties", {})):
|
||||
props["arguments"] = {
|
||||
"additionalProperties": True,
|
||||
"default": None,
|
||||
"description": "Arguments to pass to the tool",
|
||||
"type": "object",
|
||||
}
|
||||
return tool
|
||||
|
||||
|
||||
def _apply_tool_search_transform(mcp_instance: Any, config: dict[str, Any]) -> None:
    """Apply tool search transform to reduce initial context size.

    When enabled, replaces the full tool catalog with a search interface.
    LLMs see only synthetic search/call tools plus pinned tools, and
    discover other tools on-demand via natural language search.

    Uses subclassing (not monkey-patching) to override ``_make_call_tool``
    and fix the ``arguments`` schema for MCP bridge compatibility.

    NOTE: ``_make_call_tool`` is a private API in FastMCP 3.x
    (fastmcp>=3.1.0,<4.0). If FastMCP changes or removes this method
    in a future major version, this subclass will need to be updated.

    Args:
        mcp_instance: FastMCP server instance to attach the transform to.
        config: MCP_TOOL_SEARCH_CONFIG-style dict; recognized keys are
            ``strategy``, ``max_results``, ``always_visible``,
            ``search_tool_name`` and ``call_tool_name`` (all optional).
    """
    strategy = config.get("strategy", "bm25")
    kwargs: dict[str, Any] = {
        "max_results": config.get("max_results", 5),
        "always_visible": config.get("always_visible", []),
        "search_tool_name": config.get("search_tool_name", "search_tools"),
        "call_tool_name": config.get("call_tool_name", "call_tool"),
        "search_result_serializer": _serialize_tools_without_output_schema,
    }

    # Both strategies need the exact same call_tool fix, so select the
    # base class by strategy and define the subclass once instead of
    # duplicating it per branch. Any unknown strategy falls back to BM25.
    if strategy == "regex":
        from fastmcp.server.transforms.search import RegexSearchTransform

        base_class: Any = RegexSearchTransform
    else:
        from fastmcp.server.transforms.search import BM25SearchTransform

        base_class = BM25SearchTransform

    class _FixedSearchTransform(base_class):  # type: ignore[misc, valid-type]
        """Search transform with fixed call_tool arguments schema."""

        def _make_call_tool(self) -> Any:
            return _fix_call_tool_arguments(super()._make_call_tool())

    mcp_instance.add_transform(_FixedSearchTransform(**kwargs))
    logger.info(
        "Tool search transform enabled (strategy=%s, max_results=%d, pinned=%s)",
        strategy,
        kwargs["max_results"],
        kwargs["always_visible"],
    )
|
||||
|
||||
|
||||
def _create_auth_provider(flask_app: Any) -> Any | None:
|
||||
"""Create an auth provider from Flask app config.
|
||||
|
||||
@@ -218,6 +333,11 @@ def run_server(
|
||||
logging.info("Creating MCP app from factory configuration...")
|
||||
factory_config = get_mcp_factory_config()
|
||||
mcp_instance = create_mcp_app(**factory_config)
|
||||
|
||||
# Apply tool search transform if configured
|
||||
tool_search_config = MCP_TOOL_SEARCH_CONFIG
|
||||
if tool_search_config.get("enabled", False):
|
||||
_apply_tool_search_transform(mcp_instance, tool_search_config)
|
||||
else:
|
||||
# Use default initialization with auth from Flask config
|
||||
logging.info("Creating MCP app with default configuration...")
|
||||
@@ -233,8 +353,7 @@ def run_server(
|
||||
middleware_list = []
|
||||
|
||||
# Add caching middleware (innermost – runs closest to the tool)
|
||||
caching_middleware = create_response_caching_middleware()
|
||||
if caching_middleware:
|
||||
if caching_middleware := create_response_caching_middleware():
|
||||
middleware_list.append(caching_middleware)
|
||||
|
||||
# Add response size guard (protects LLM clients from huge responses)
|
||||
@@ -252,6 +371,18 @@ def run_server(
|
||||
middleware=middleware_list or None,
|
||||
)
|
||||
|
||||
# Apply tool search transform if configured
|
||||
tool_search_config = flask_app.config.get(
|
||||
"MCP_TOOL_SEARCH_CONFIG", MCP_TOOL_SEARCH_CONFIG
|
||||
)
|
||||
if tool_search_config.get("enabled", False):
|
||||
_apply_tool_search_transform(mcp_instance, tool_search_config)
|
||||
# Ensure the configured search tool name is excluded from the
|
||||
# response size guard (search results are intentionally large)
|
||||
if size_guard_middleware:
|
||||
search_name = tool_search_config.get("search_tool_name", "search_tools")
|
||||
size_guard_middleware.excluded_tools.add(search_name)
|
||||
|
||||
# Create EventStore for session management (Redis for multi-pod, None for in-memory)
|
||||
event_store = create_event_store(event_store_config)
|
||||
|
||||
|
||||
@@ -58,54 +58,40 @@ class GetSupersetInstanceInfoRequest(BaseModel):
|
||||
|
||||
|
||||
class InstanceSummary(BaseModel):
|
||||
total_dashboards: int = Field(..., description="Total number of dashboards")
|
||||
total_charts: int = Field(..., description="Total number of charts")
|
||||
total_datasets: int = Field(..., description="Total number of datasets")
|
||||
total_databases: int = Field(..., description="Total number of databases")
|
||||
total_users: int = Field(..., description="Total number of users")
|
||||
total_roles: int = Field(..., description="Total number of roles")
|
||||
total_tags: int = Field(..., description="Total number of tags")
|
||||
avg_charts_per_dashboard: float = Field(
|
||||
..., description="Average number of charts per dashboard"
|
||||
)
|
||||
total_dashboards: int
|
||||
total_charts: int
|
||||
total_datasets: int
|
||||
total_databases: int
|
||||
total_users: int
|
||||
total_roles: int
|
||||
total_tags: int
|
||||
avg_charts_per_dashboard: float
|
||||
|
||||
|
||||
class RecentActivity(BaseModel):
|
||||
dashboards_created_last_30_days: int = Field(
|
||||
..., description="Dashboards created in the last 30 days"
|
||||
)
|
||||
charts_created_last_30_days: int = Field(
|
||||
..., description="Charts created in the last 30 days"
|
||||
)
|
||||
datasets_created_last_30_days: int = Field(
|
||||
..., description="Datasets created in the last 30 days"
|
||||
)
|
||||
dashboards_modified_last_7_days: int = Field(
|
||||
..., description="Dashboards modified in the last 7 days"
|
||||
)
|
||||
charts_modified_last_7_days: int = Field(
|
||||
..., description="Charts modified in the last 7 days"
|
||||
)
|
||||
datasets_modified_last_7_days: int = Field(
|
||||
..., description="Datasets modified in the last 7 days"
|
||||
)
|
||||
dashboards_created_last_30_days: int
|
||||
charts_created_last_30_days: int
|
||||
datasets_created_last_30_days: int
|
||||
dashboards_modified_last_7_days: int
|
||||
charts_modified_last_7_days: int
|
||||
datasets_modified_last_7_days: int
|
||||
|
||||
|
||||
class DashboardBreakdown(BaseModel):
|
||||
published: int = Field(..., description="Number of published dashboards")
|
||||
unpublished: int = Field(..., description="Number of unpublished dashboards")
|
||||
certified: int = Field(..., description="Number of certified dashboards")
|
||||
with_charts: int = Field(..., description="Number of dashboards with charts")
|
||||
without_charts: int = Field(..., description="Number of dashboards without charts")
|
||||
published: int
|
||||
unpublished: int
|
||||
certified: int
|
||||
with_charts: int
|
||||
without_charts: int
|
||||
|
||||
|
||||
class DatabaseBreakdown(BaseModel):
|
||||
by_type: Dict[str, int] = Field(..., description="Breakdown of databases by type")
|
||||
by_type: Dict[str, int]
|
||||
|
||||
|
||||
class PopularContent(BaseModel):
|
||||
top_tags: List[str] = Field(..., description="Most popular tags")
|
||||
top_creators: List[str] = Field(..., description="Most active creators")
|
||||
top_tags: List[str] = Field(default_factory=list)
|
||||
top_creators: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class FeatureAvailability(BaseModel):
|
||||
@@ -125,33 +111,19 @@ class FeatureAvailability(BaseModel):
|
||||
|
||||
|
||||
class InstanceInfo(BaseModel):
|
||||
instance_summary: InstanceSummary = Field(
|
||||
..., description="Instance summary information"
|
||||
)
|
||||
recent_activity: RecentActivity = Field(
|
||||
..., description="Recent activity information"
|
||||
)
|
||||
dashboard_breakdown: DashboardBreakdown = Field(
|
||||
..., description="Dashboard breakdown information"
|
||||
)
|
||||
database_breakdown: DatabaseBreakdown = Field(
|
||||
..., description="Database breakdown by type"
|
||||
)
|
||||
popular_content: PopularContent = Field(
|
||||
..., description="Popular content information"
|
||||
)
|
||||
instance_summary: InstanceSummary
|
||||
recent_activity: RecentActivity
|
||||
dashboard_breakdown: DashboardBreakdown
|
||||
database_breakdown: DatabaseBreakdown
|
||||
popular_content: PopularContent
|
||||
current_user: UserInfo | None = Field(
|
||||
None,
|
||||
description="The authenticated user making the request. "
|
||||
"Use current_user.id with created_by_fk filter to find your own assets.",
|
||||
)
|
||||
feature_availability: FeatureAvailability = Field(
|
||||
...,
|
||||
description=(
|
||||
"Dynamic feature availability for the current user and deployment"
|
||||
"Use current_user.id with created_by_fk filter to find your own assets."
|
||||
),
|
||||
)
|
||||
timestamp: datetime = Field(..., description="Response timestamp")
|
||||
feature_availability: FeatureAvailability
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
@@ -207,10 +179,10 @@ class RoleInfo(BaseModel):
|
||||
|
||||
|
||||
class PaginationInfo(BaseModel):
|
||||
page: int = Field(..., description="Current page number")
|
||||
page_size: int = Field(..., description="Number of items per page")
|
||||
total_count: int = Field(..., description="Total number of items")
|
||||
total_pages: int = Field(..., description="Total number of pages")
|
||||
has_next: bool = Field(..., description="Whether there is a next page")
|
||||
has_previous: bool = Field(..., description="Whether there is a previous page")
|
||||
page: int
|
||||
page_size: int
|
||||
total_count: int
|
||||
total_pages: int
|
||||
has_next: bool
|
||||
has_previous: bool
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@@ -17,25 +17,32 @@
|
||||
|
||||
"""Test MCP app imports and tool/prompt registration."""
|
||||
|
||||
import asyncio
|
||||
|
||||
|
||||
def _run(coro):
|
||||
"""Run an async coroutine synchronously."""
|
||||
return asyncio.run(coro)
|
||||
|
||||
|
||||
def test_mcp_app_imports_successfully():
|
||||
"""Test that the MCP app can be imported without errors."""
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
assert mcp is not None
|
||||
assert hasattr(mcp, "_tool_manager")
|
||||
|
||||
tools = mcp._tool_manager._tools
|
||||
assert len(tools) > 0
|
||||
assert "health_check" in tools
|
||||
assert "list_charts" in tools
|
||||
tools = _run(mcp.list_tools())
|
||||
tool_names = [t.name for t in tools]
|
||||
assert len(tool_names) > 0
|
||||
assert "health_check" in tool_names
|
||||
assert "list_charts" in tool_names
|
||||
|
||||
|
||||
def test_mcp_prompts_registered():
|
||||
"""Test that MCP prompts are registered."""
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
prompts = mcp._prompt_manager._prompts
|
||||
prompts = _run(mcp.list_prompts())
|
||||
assert len(prompts) > 0
|
||||
|
||||
|
||||
@@ -48,12 +55,10 @@ def test_mcp_resources_registered():
|
||||
"""
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
resource_manager = mcp._resource_manager
|
||||
resources = resource_manager._resources
|
||||
resources = _run(mcp.list_resources())
|
||||
assert len(resources) > 0, "No MCP resources registered"
|
||||
|
||||
# Verify the two documented resources are registered
|
||||
resource_uris = set(resources.keys())
|
||||
resource_uris = {str(r.uri) for r in resources}
|
||||
assert "chart://configs" in resource_uris, (
|
||||
"chart://configs resource not registered - "
|
||||
"check superset/mcp_service/chart/__init__.py exists"
|
||||
|
||||
173
tests/unit_tests/mcp_service/test_tool_search_transform.py
Normal file
173
tests/unit_tests/mcp_service/test_tool_search_transform.py
Normal file
@@ -0,0 +1,173 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for MCP tool search transform configuration and application."""
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from fastmcp.server.transforms.search import BM25SearchTransform, RegexSearchTransform
|
||||
|
||||
from superset.mcp_service.mcp_config import MCP_TOOL_SEARCH_CONFIG
|
||||
from superset.mcp_service.server import (
|
||||
_apply_tool_search_transform,
|
||||
_fix_call_tool_arguments,
|
||||
_serialize_tools_without_output_schema,
|
||||
)
|
||||
|
||||
|
||||
def test_tool_search_config_defaults():
    """Default config has expected keys and values."""
    config = MCP_TOOL_SEARCH_CONFIG
    assert config["enabled"] is True
    assert config["strategy"] == "bm25"
    assert config["max_results"] == 5
    # Both operational tools are pinned so they stay visible without search.
    for pinned in ("health_check", "get_instance_info"):
        assert pinned in config["always_visible"]
    assert config["search_tool_name"] == "search_tools"
    assert config["call_tool_name"] == "call_tool"
|
||||
|
||||
|
||||
def test_apply_bm25_transform():
    """BM25 subclass is created and added when strategy is 'bm25'."""
    mcp = MagicMock()

    _apply_tool_search_transform(
        mcp,
        {
            "strategy": "bm25",
            "max_results": 5,
            "always_visible": ["health_check"],
            "search_tool_name": "search_tools",
            "call_tool_name": "call_tool",
        },
    )

    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, BM25SearchTransform)
|
||||
|
||||
|
||||
def test_apply_regex_transform():
    """Regex subclass is created and added when strategy is 'regex'."""
    mcp = MagicMock()

    _apply_tool_search_transform(
        mcp,
        {
            "strategy": "regex",
            "max_results": 10,
            "always_visible": ["health_check", "get_instance_info"],
            "search_tool_name": "find_tools",
            "call_tool_name": "invoke_tool",
        },
    )

    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, RegexSearchTransform)
|
||||
|
||||
|
||||
def test_apply_transform_uses_defaults_for_missing_keys():
    """Missing config keys fall back to sensible defaults (BM25)."""
    mcp = MagicMock()

    # Empty config: every key is missing, so defaults must apply.
    _apply_tool_search_transform(mcp, {})

    mcp.add_transform.assert_called_once()
    (added,) = mcp.add_transform.call_args[0]
    assert isinstance(added, BM25SearchTransform)
|
||||
|
||||
|
||||
def test_fix_call_tool_arguments_replaces_anyof():
    """_fix_call_tool_arguments replaces anyOf with flat type: object."""
    anyof_schema = {
        "anyOf": [
            {"type": "object", "additionalProperties": True},
            {"type": "null"},
        ],
        "default": None,
    }
    tool = SimpleNamespace(
        parameters={
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "arguments": anyof_schema,
            },
        }
    )

    fixed = _fix_call_tool_arguments(tool)

    expected_arguments = {
        "additionalProperties": True,
        "default": None,
        "description": "Arguments to pass to the tool",
        "type": "object",
    }
    assert fixed.parameters["properties"]["arguments"] == expected_arguments
    # Other properties untouched
    assert fixed.parameters["properties"]["name"] == {"type": "string"}
|
||||
|
||||
|
||||
def test_fix_call_tool_arguments_no_arguments_field():
    """_fix_call_tool_arguments is a no-op when arguments field is absent."""
    tool = SimpleNamespace(
        parameters={"type": "object", "properties": {"name": {"type": "string"}}}
    )

    fixed = _fix_call_tool_arguments(tool)

    assert "arguments" not in fixed.parameters["properties"]
|
||||
|
||||
|
||||
def test_serialize_tools_strips_output_schema():
    """Custom serializer removes outputSchema from tool definitions."""
    tool = MagicMock()
    tool.to_mcp_tool.return_value.model_dump.return_value = {
        "name": "test_tool",
        "description": "A test tool",
        "inputSchema": {"type": "object", "properties": {"x": {"type": "integer"}}},
        "outputSchema": {
            "type": "object",
            "properties": {"result": {"type": "string"}},
        },
    }

    result = _serialize_tools_without_output_schema([tool])

    # Exactly one tool in, one serialized payload out.
    (serialized,) = result
    assert serialized["name"] == "test_tool"
    assert "inputSchema" in serialized
    assert "outputSchema" not in serialized
|
||||
|
||||
|
||||
def test_serialize_tools_handles_no_output_schema():
    """Custom serializer works when tool has no outputSchema."""
    tool = MagicMock()
    tool.to_mcp_tool.return_value.model_dump.return_value = {
        "name": "simple_tool",
        "inputSchema": {"type": "object"},
    }

    result = _serialize_tools_without_output_schema([tool])

    (serialized,) = result
    assert serialized["name"] == "simple_tool"
    assert "outputSchema" not in serialized
|
||||
Reference in New Issue
Block a user