mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
fix(mcp): normalize column names to fix time series filter prompt issue (#37187)
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.common.error_schemas import DatasetContext
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -706,3 +707,151 @@ class TestGenerateExploreLink:
|
||||
assert result.data["form_data"].get("x_axis") == "date"
|
||||
# Verify datasource field format: "{dataset_id}__table"
|
||||
assert result.data["form_data"].get("datasource") == "1__table"
|
||||
|
||||
|
||||
class TestGenerateExploreLinkColumnNormalization:
|
||||
"""Tests that generate_explore_link normalizes column names.
|
||||
|
||||
This verifies the fix where user-provided column names in wrong case
|
||||
(e.g., 'order_date') are normalized to the canonical dataset name
|
||||
(e.g., 'OrderDate') before being used in form_data.
|
||||
"""
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
|
||||
)
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_xy_chart_x_axis_normalized_in_form_data(
|
||||
self,
|
||||
mock_create_form_data,
|
||||
mock_find_dataset,
|
||||
mock_get_context,
|
||||
mcp_server,
|
||||
):
|
||||
"""x-axis column name in wrong case is normalized in form_data."""
|
||||
mock_create_form_data.return_value = "norm_test_key_1"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=18)
|
||||
mock_get_context.return_value = DatasetContext(
|
||||
id=18,
|
||||
table_name="Vehicle Sales",
|
||||
schema="public",
|
||||
database_name="examples",
|
||||
available_columns=[
|
||||
{"name": "OrderDate", "type": "DATE", "is_temporal": True},
|
||||
{"name": "Sales", "type": "FLOAT", "is_numeric": True},
|
||||
],
|
||||
available_metrics=[],
|
||||
)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="orderdate"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="line",
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="18", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
# x-axis should be normalized from 'orderdate' to 'OrderDate'
|
||||
assert result.data["form_data"]["x_axis"] == "OrderDate"
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
|
||||
)
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_filter_column_normalized_in_form_data(
|
||||
self,
|
||||
mock_create_form_data,
|
||||
mock_find_dataset,
|
||||
mock_get_context,
|
||||
mcp_server,
|
||||
):
|
||||
"""Filter column name in wrong case is normalized in adhoc_filters."""
|
||||
mock_create_form_data.return_value = "norm_test_key_2"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=18)
|
||||
mock_get_context.return_value = DatasetContext(
|
||||
id=18,
|
||||
table_name="Vehicle Sales",
|
||||
schema="public",
|
||||
database_name="examples",
|
||||
available_columns=[
|
||||
{"name": "OrderDate", "type": "DATE", "is_temporal": True},
|
||||
{"name": "Sales", "type": "FLOAT", "is_numeric": True},
|
||||
],
|
||||
available_metrics=[],
|
||||
)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="orderdate"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="line",
|
||||
filters=[
|
||||
FilterConfig(column="orderdate", op=">", value="2023-01-01"),
|
||||
],
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="18", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
form_data = result.data["form_data"]
|
||||
# x-axis normalized
|
||||
assert form_data["x_axis"] == "OrderDate"
|
||||
# filter subject normalized to match x-axis
|
||||
adhoc_filters = form_data.get("adhoc_filters", [])
|
||||
assert len(adhoc_filters) == 1
|
||||
assert adhoc_filters[0]["subject"] == "OrderDate"
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
|
||||
)
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_normalization_fallback_when_dataset_not_found(
|
||||
self,
|
||||
mock_create_form_data,
|
||||
mock_find_dataset,
|
||||
mock_get_context,
|
||||
mcp_server,
|
||||
):
|
||||
"""When dataset context is unavailable, original names pass through."""
|
||||
mock_create_form_data.return_value = "norm_test_key_3"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=99)
|
||||
mock_get_context.return_value = None
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="orderdate"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="line",
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="99", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
# original names should pass through unchanged
|
||||
assert result.data["form_data"]["x_axis"] == "orderdate"
|
||||
|
||||
Reference in New Issue
Block a user