fix(mcp): normalize column names to fix time series filter prompt issue (#37187)

Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-02-25 09:27:53 -05:00
committed by GitHub
parent 3084907931
commit a1312a86e8
5 changed files with 1073 additions and 11 deletions

View File

@@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import (
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.common.error_schemas import DatasetContext
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -706,3 +707,151 @@ class TestGenerateExploreLink:
assert result.data["form_data"].get("x_axis") == "date"
# Verify datasource field format: "{dataset_id}__table"
assert result.data["form_data"].get("datasource") == "1__table"
class TestGenerateExploreLinkColumnNormalization:
    """Tests that generate_explore_link normalizes column names.

    Each test submits column names whose casing differs from the dataset's
    (e.g., 'orderdate') and checks that the returned form_data carries the
    canonical dataset name (e.g., 'OrderDate') -- or, when no dataset
    context is available, the name exactly as it was submitted.
    """

    # Patch targets shared by every test in this class.
    _DATASET_CONTEXT_TARGET = "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"  # noqa: E501
    _FIND_DATASET_TARGET = "superset.daos.dataset.DatasetDAO.find_by_id"
    _CREATE_FORM_DATA_TARGET = (
        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
    )

    @staticmethod
    def _vehicle_sales_context():
        """Build the DatasetContext used by the normalization tests.

        The context describes dataset 18 with a temporal ``OrderDate``
        column, a numeric ``Sales`` column, and no metrics.
        """
        columns = [
            {"name": "OrderDate", "type": "DATE", "is_temporal": True},
            {"name": "Sales", "type": "FLOAT", "is_numeric": True},
        ]
        return DatasetContext(
            id=18,
            table_name="Vehicle Sales",
            schema="public",
            database_name="examples",
            available_columns=columns,
            available_metrics=[],
        )

    # NOTE: patch decorators apply bottom-up, so the mock parameters are
    # listed in the reverse order of the @patch stack.
    @patch(_DATASET_CONTEXT_TARGET)
    @patch(_FIND_DATASET_TARGET)
    @patch(_CREATE_FORM_DATA_TARGET)
    @pytest.mark.asyncio
    async def test_xy_chart_x_axis_normalized_in_form_data(
        self,
        mock_create_form_data,
        mock_find_dataset,
        mock_get_context,
        mcp_server,
    ):
        """x-axis column name in wrong case is normalized in form_data."""
        mock_create_form_data.return_value = "norm_test_key_1"
        mock_find_dataset.return_value = _mock_dataset(id=18)
        mock_get_context.return_value = self._vehicle_sales_context()

        x_column = ColumnRef(name="orderdate")
        y_column = ColumnRef(name="sales", aggregate="SUM")
        chart_config = XYChartConfig(
            chart_type="xy",
            x=x_column,
            y=[y_column],
            kind="line",
        )
        link_request = GenerateExploreLinkRequest(
            dataset_id="18", config=chart_config
        )

        async with Client(mcp_server) as client:
            tool_args = {"request": link_request.model_dump()}
            response = await client.call_tool("generate_explore_link", tool_args)

            assert response.data["error"] is None
            # 'orderdate' must come back as the dataset's own 'OrderDate'
            assert response.data["form_data"]["x_axis"] == "OrderDate"

    @patch(_DATASET_CONTEXT_TARGET)
    @patch(_FIND_DATASET_TARGET)
    @patch(_CREATE_FORM_DATA_TARGET)
    @pytest.mark.asyncio
    async def test_filter_column_normalized_in_form_data(
        self,
        mock_create_form_data,
        mock_find_dataset,
        mock_get_context,
        mcp_server,
    ):
        """Filter column name in wrong case is normalized in adhoc_filters."""
        mock_create_form_data.return_value = "norm_test_key_2"
        mock_find_dataset.return_value = _mock_dataset(id=18)
        mock_get_context.return_value = self._vehicle_sales_context()

        x_column = ColumnRef(name="orderdate")
        y_column = ColumnRef(name="sales", aggregate="SUM")
        date_filter = FilterConfig(column="orderdate", op=">", value="2023-01-01")
        chart_config = XYChartConfig(
            chart_type="xy",
            x=x_column,
            y=[y_column],
            kind="line",
            filters=[date_filter],
        )
        link_request = GenerateExploreLinkRequest(
            dataset_id="18", config=chart_config
        )

        async with Client(mcp_server) as client:
            tool_args = {"request": link_request.model_dump()}
            response = await client.call_tool("generate_explore_link", tool_args)

            assert response.data["error"] is None
            form_data = response.data["form_data"]

            # The x-axis is rewritten to the canonical column name ...
            assert form_data["x_axis"] == "OrderDate"

            # ... and the single filter's subject is rewritten to match it.
            adhoc_filters = form_data.get("adhoc_filters", [])
            assert len(adhoc_filters) == 1
            assert adhoc_filters[0]["subject"] == "OrderDate"

    @patch(_DATASET_CONTEXT_TARGET)
    @patch(_FIND_DATASET_TARGET)
    @patch(_CREATE_FORM_DATA_TARGET)
    @pytest.mark.asyncio
    async def test_normalization_fallback_when_dataset_not_found(
        self,
        mock_create_form_data,
        mock_find_dataset,
        mock_get_context,
        mcp_server,
    ):
        """When dataset context is unavailable, original names pass through.

        Here ``_get_dataset_context`` is patched to return ``None`` while
        ``find_by_id`` still returns a mock dataset.
        """
        mock_create_form_data.return_value = "norm_test_key_3"
        mock_find_dataset.return_value = _mock_dataset(id=99)
        mock_get_context.return_value = None

        x_column = ColumnRef(name="orderdate")
        y_column = ColumnRef(name="sales", aggregate="SUM")
        chart_config = XYChartConfig(
            chart_type="xy",
            x=x_column,
            y=[y_column],
            kind="line",
        )
        link_request = GenerateExploreLinkRequest(
            dataset_id="99", config=chart_config
        )

        async with Client(mcp_server) as client:
            tool_args = {"request": link_request.model_dump()}
            response = await client.call_tool("generate_explore_link", tool_args)

            assert response.data["error"] is None
            # With no context to normalize against, the name is left as submitted.
            assert response.data["form_data"]["x_axis"] == "orderdate"