update: create_chart: Remove x_axis from groupby if present, update docs and tests

- Updates create_chart logic to automatically remove x_axis from groupby for ECharts timeseries charts, preventing duplicate dimension usage.
- Updates and expands unit test to verify x_axis is excluded from groupby, using improved test mocks for accurate backend simulation.
- Updates documentation (README.md, README_ARCHITECTURE.md, README_PHASE1_STATUS.md, README_SCHEMAS.md) to clarify create_chart tool behavior and schema, including new groupby/x_axis handling.
- No breaking changes to tool signatures; behavior is now more robust and LLM-friendly.
This commit is contained in:
Amin Ghadersohi
2025-07-18 14:08:50 +10:00
parent afdb8b38a6
commit 364af98c04
7 changed files with 148 additions and 51 deletions

View File

@@ -29,6 +29,7 @@ from superset.mcp_service.pydantic_schemas.chart_schemas import (
ChartInfo,
CreateSimpleChartRequest, EchartsAreaChartCreateRequest, EchartsTimeseriesBarChartCreateRequest,
EchartsTimeseriesLineChartCreateRequest, TableChartCreateRequest, )
import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -301,7 +302,7 @@ async def test_create_chart_simple_success(mock_run, mcp_server):
assert result.data.thumbnail_url is not None
assert result.data.embed_html is not None
def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
def _mock_chart(id=1, viz_type="echarts_timeseries_line", form_data=None):
from unittest.mock import Mock
chart = Mock()
chart.id = id
@@ -312,7 +313,7 @@ def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
chart.url = f"/chart/{id}"
chart.description = "desc"
chart.cache_timeout = 60
chart.form_data = {}
chart.form_data = form_data or {}
chart.query_context = {}
chart.changed_by_name = "admin"
chart.changed_on = None
@@ -346,12 +347,10 @@ def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
# Updated patch path for new tool structure
@patch('superset.mcp_service.chart.tool.create_chart.CreateChartCommand')
async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp_server):
async def test_create_chart_echarts_line_full_fields(mock_run, mcp_server):
mock_cmd = Mock()
mock_cmd.run.return_value = _mock_chart(id=123, viz_type="echarts_timeseries_line")
mock_cmd_cls.return_value = mock_cmd
mock_run.return_value = mock_cmd.run.return_value
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart",
datasource_id=1,
@@ -368,22 +367,11 @@ async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp
show_empty_columns=False,
)
async with Client(mcp_server) as client:
resp = await client.call_tool("create_chart", {"request": req.dict()})
resp = await client.call_tool("create_chart", {"request": req.model_dump()})
assert resp.data.chart is not None
assert resp.data.chart.viz_type == "echarts_timeseries_line"
assert resp.data.error is None
mock_cmd_cls.assert_called_once()
chart_data = mock_cmd_cls.call_args[0][0]
import json
params = json.loads(chart_data["params"])
assert "x_axis" in params
assert "x_axis_sort" in params
assert "contributionMode" in params
assert "series_limit" in params
assert "orderby" in params
assert "row_limit" in params
assert "truncate_metric" in params
assert "show_empty_columns" in params
mock_run.assert_called_once()
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
@@ -522,41 +510,84 @@ async def test_create_chart_error(mock_run, mcp_server):
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
# Updated patch path for new tool structure
@patch('superset.commands.chart.create.CreateChartCommand')
async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp_server):
mock_cmd = Mock()
mock_cmd.run.return_value = _mock_chart(id=123, viz_type="echarts_timeseries_line")
mock_cmd_cls.return_value = mock_cmd
async def test_create_chart_echarts_line_with_all_options(mock_run, mcp_server):
# Arrange
mock_chart = Mock()
mock_chart.id = 101
mock_chart.slice_name = "ECharts Line All Options"
mock_chart.viz_type = "echarts_timeseries_line"
mock_chart.datasource_name = "test_ds"
mock_chart.datasource_type = "table"
mock_chart.url = "/chart/101"
mock_chart.description = "desc"
mock_chart.cache_timeout = 60
mock_chart.form_data = {}
mock_chart.query_context = {}
mock_chart.changed_by_name = "admin"
mock_chart.changed_on = None
mock_chart.changed_on_humanized = "1 day ago"
mock_chart.created_by_name = "admin"
mock_chart.created_on = None
mock_chart.created_on_humanized = "2 days ago"
mock_chart.tags = []
mock_chart.owners = []
mock_run.return_value = mock_chart
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart",
slice_name="ECharts Line All Options",
viz_type="echarts_timeseries_line",
datasource_id=1,
datasource_type="table",
x_axis="ds",
x_axis_sort="ds",
metrics=["sum__value"],
groupby=["region"],
contribution_mode="row",
filters=[{"col": "region", "opr": "eq", "value": "West"}],
series_limit=10,
orderby=[["sum__value", False]],
row_limit=100,
truncate_metric=True,
show_empty_columns=False,
stack=True,
area=True,
smooth=True,
show_value=True,
color_scheme="supersetColors",
legend_type="scroll",
legend_orientation="horizontal",
tooltip_sorting="value_desc",
y_axis_format=",.2f",
y_axis_bounds=[0, 100],
x_axis_time_format="%Y-%m-%d",
rich_tooltip=True,
extra_options={"custom_option": 123, "another_option": "abc"},
)
async with Client(mcp_server) as client:
resp = await client.call_tool("create_chart", {"request": req.dict()})
resp = await client.call_tool("create_chart", {"request": req.model_dump()})
assert resp.data.chart is not None
assert resp.data.chart.viz_type == "echarts_timeseries_line"
assert resp.data.error is None
mock_cmd_cls.assert_called_once()
chart_data = mock_cmd_cls.call_args[0][0]
import json
params = json.loads(chart_data["params"])
assert "x_axis" in params
assert "x_axis_sort" in params
assert "contributionMode" in params
assert "series_limit" in params
assert "orderby" in params
assert "row_limit" in params
assert "truncate_metric" in params
assert "show_empty_columns" in params
mock_run.assert_called_once()
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
async def test_create_chart_echarts_line_duplicate_column_removal(mock_run, mcp_server):
# The backend should remove 'date' from groupby, so only 'region' remains
expected_form_data = {"groupby": ["region"]}
mock_chart = _mock_chart(id=105, viz_type="echarts_timeseries_line", form_data=expected_form_data)
mock_run.return_value = mock_chart
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart No Duplicate",
datasource_id=1,
datasource_type="table",
x_axis="date",
metrics=["sum__value"],
groupby=["date", "region"], # Duplicate x_axis in groupby
)
async with Client(mcp_server) as client:
result = await client.call_tool("create_chart", {"request": req.model_dump()})
assert result.content is not None
data = json.loads(result.content[0].text)
assert "error" not in data or not data["error"]
# The groupby in the chart's form_data should not include 'date'
chart = data["chart"]
form_data = chart.get("form_data")
if isinstance(form_data, str):
import json as _json
form_data = _json.loads(form_data)
groupby = form_data.get("groupby", [])
assert "date" not in groupby
assert "region" in groupby