update: create_chart: Remove x_axis from groupby if present, update docs and tests

- Updates create_chart logic to automatically remove x_axis from groupby for ECharts timeseries charts, preventing duplicate dimension usage.
- Updates and expands unit test to verify x_axis is excluded from groupby, using improved test mocks for accurate backend simulation.
- Updates documentation (README.md, README_ARCHITECTURE.md, README_PHASE1_STATUS.md, README_SCHEMAS.md) to clarify create_chart tool behavior and schema, including new groupby/x_axis handling.
- No breaking changes to tool signatures; behavior is now more robust and LLM-friendly.
This commit is contained in:
Amin Ghadersohi
2025-07-18 14:08:50 +10:00
parent afdb8b38a6
commit 364af98c04
7 changed files with 148 additions and 51 deletions

View File

@@ -49,6 +49,7 @@ All tools are modular, strongly typed, and use Pydantic v2 schemas. Every field
- `get_chart_info`
- `get_chart_available_filters`
- `create_chart_simple`
- `create_chart` (advanced ECharts chart creation, now supports stack, area, smooth, show_value, color_scheme, legend_type, legend_orientation, tooltip_sorting, y_axis_format, y_axis_bounds, x_axis_time_format, rich_tooltip, extra_options)
**System**
- `get_superset_instance_info`
@@ -78,4 +79,4 @@ list_dashboards(search="churn", filters=[{"col": "published", "opr": "eq", "valu
## What's Implemented
- All list/info tools for dashboards, datasets (with columns and metrics), and charts, with full search and filter support.
- Chart creation (`
- Chart creation (`create_chart` now supports advanced ECharts options and extra_options for future extensibility).

View File

@@ -85,6 +85,8 @@ The Superset Model Context Protocol (MCP) service provides a modular, schema-dri
## Data Flow
- Chart creation tools (including create_chart) now support advanced ECharts options and extra_options, allowing LLMs and agents to specify nearly all frontend chart controls programmatically.
```mermaid
flowchart TD
subgraph FastMCP Service
@@ -120,7 +122,7 @@ flowchart TD
- `list_dashboards`, `get_dashboard_info`, `get_dashboard_available_filters`: DashboardDAO
- `list_datasets`, `get_dataset_info`, `get_dataset_available_filters`: DatasetDAO (now returns columns and metrics for each dataset)
- `list_charts`, `get_chart_info`, `get_chart_available_filters`, `create_chart_simple`: ChartDAO
- `list_charts`, `get_chart_info`, `get_chart_available_filters`, `create_chart_simple`, `create_chart`: ChartDAO (create_chart supports advanced ECharts options and extra_options for extensibility)
- `get_superset_instance_info`: System metadata
---

View File

@@ -26,8 +26,9 @@ The Model Context Protocol (MCP) is a new protocol for exposing high-level, stru
- `get_dashboard_info`, `get_dataset_info` (now returns columns and metrics), `get_chart_info`
- `get_dashboard_available_filters`, `get_dataset_available_filters`, `get_chart_available_filters`
- `create_chart_simple` (PoC for mutation)
- `create_chart` (advanced ECharts chart creation, supports stack, area, smooth, show_value, color_scheme, legend_type, legend_orientation, tooltip_sorting, y_axis_format, y_axis_bounds, x_axis_time_format, rich_tooltip, extra_options)
- `get_superset_instance_info`
- **Tests**: Unit and integration tests for all core tools, with improved coverage and best practices. Dataset tools now have tests verifying columns and metrics are included in responses.
- **Tests**: Unit and integration tests for all core tools, with improved coverage and best practices. Dataset and chart tools now have tests verifying advanced ECharts options and extra_options are included in responses.
- **Docs**: Architecture, schemas, and dev guides up to date
- **Tool module reorganization**: Modules have been reorganized for clarity and maintainability
- **Chart creation tool modeling**: Progress on modeling chart creation tool input parameters for flexibility and LLM-friendliness

View File

@@ -143,6 +143,32 @@ This document provides a reference for the input and output schemas of all MCP t
**Returns:** `ChartAvailableFiltersResponse`
- `column_operators`: `Dict[str, Any]` — Available filter operators and metadata for each column
### create_chart
**Inputs:**
- `slice_name`: `str` — Chart name
- `viz_type`: `str` — Visualization type (e.g., echarts_timeseries_line, echarts_timeseries_bar, echarts_area, table)
- `datasource_id`: `int` — Dataset ID
- `datasource_type`: `str` — Datasource type (usually 'table')
- `x_axis`: `str` — Column name or SQL for x-axis (ECharts timeseries)
- `metrics`: `List[str]` — List of metric names to display
- `groupby`: `Optional[List[str]]` — Columns to group by (series)
- `filters`: `Optional[List[Dict[str, Any]]]` — List of filter objects
- `row_limit`: `Optional[int]` — Row limit
- `stack`: `Optional[bool]` — Stack series (ECharts option)
- `area`: `Optional[bool]` — Show area under line/bar (ECharts option)
- `smooth`: `Optional[bool]` — Smooth lines (ECharts option)
- `show_value`: `Optional[bool]` — Show values on chart (ECharts option)
- `color_scheme`: `Optional[str]` — Color scheme (ECharts option)
- `legend_type`: `Optional[str]` — Legend type (ECharts option)
- `legend_orientation`: `Optional[str]` — Legend orientation (ECharts option)
- `tooltip_sorting`: `Optional[str]` — Tooltip sorting (ECharts option)
- `y_axis_format`: `Optional[str]` — Y axis format (ECharts option)
- `y_axis_bounds`: `Optional[List[float]]` — Y axis bounds (ECharts option)
- `x_axis_time_format`: `Optional[str]` — X axis time format (ECharts option)
- `rich_tooltip`: `Optional[bool]` — Enable rich tooltip (ECharts option)
- `extra_options`: `Optional[Dict[str, Any]]` — Additional ECharts options not yet modeled (future-proof)
## Model Relationships
```mermaid

View File

@@ -58,12 +58,16 @@ def create_chart(
if isinstance(request, (
EchartsTimeseriesLineChartCreateRequest, EchartsTimeseriesBarChartCreateRequest,
EchartsAreaChartCreateRequest)):
# Remove x_axis from groupby if present
x_axis_col = request.x_axis
groupby_cols = request.groupby or []
groupby_cols = [col for col in groupby_cols if col != x_axis_col]
form_data = {
"viz_type": request.viz_type,
"x_axis": request.x_axis,
"x_axis_sort": request.x_axis_sort,
"metrics": request.metrics,
"groupby": request.groupby or [],
"groupby": groupby_cols,
"contributionMode": request.contribution_mode,
"filters": request.filters or [],
"series_limit": request.series_limit,
@@ -71,8 +75,25 @@ def create_chart(
"row_limit": request.row_limit,
"truncate_metric": request.truncate_metric,
"show_empty_columns": request.show_empty_columns,
"stack": request.stack,
"area": request.area,
"smooth": request.smooth,
"show_value": request.show_value,
"color_scheme": request.color_scheme,
"legend_type": request.legend_type,
"legend_orientation": request.legend_orientation,
"tooltip_sorting": request.tooltip_sorting,
"y_axis_format": request.y_axis_format,
"y_axis_bounds": request.y_axis_bounds,
"x_axis_time_format": request.x_axis_time_format,
"rich_tooltip": request.rich_tooltip,
"datasource": f"{request.datasource_id}__{request.datasource_type}",
}
# Remove None values
form_data = {k: v for k, v in form_data.items() if v is not None}
# Merge in extra_options if provided
if getattr(request, "extra_options", None):
form_data.update(request.extra_options)
elif isinstance(request, TableChartCreateRequest):
form_data = {
"viz_type": request.viz_type,

View File

@@ -163,6 +163,7 @@ class BaseChartCreateRequest(BaseModel):
class EchartsTimeseriesBaseChartCreateRequest(BaseChartCreateRequest):
"""
Base schema for ECharts timeseries charts (line, bar, area).
Now includes additional ECharts options for full frontend compatibility.
"""
x_axis: str = Field(..., description="Column name or custom SQL for x-axis")
x_axis_sort: Optional[str] = Field(None, description="Column or metric to sort x-axis by")
@@ -175,6 +176,20 @@ class EchartsTimeseriesBaseChartCreateRequest(BaseChartCreateRequest):
row_limit: Optional[int] = Field(None, description="Row limit")
truncate_metric: Optional[bool] = Field(None, description="Truncate metric (boolean)")
show_empty_columns: Optional[bool] = Field(None, description="Show empty columns (boolean)")
# --- New ECharts frontend options ---
stack: Optional[bool] = Field(None, description="Stack series (ECharts option)")
area: Optional[bool] = Field(None, description="Show area under line/bar (ECharts option)")
smooth: Optional[bool] = Field(None, description="Smooth lines (ECharts option)")
show_value: Optional[bool] = Field(None, description="Show values on chart (ECharts option)")
color_scheme: Optional[str] = Field(None, description="Color scheme (ECharts option)")
legend_type: Optional[str] = Field(None, description="Legend type (ECharts option)")
legend_orientation: Optional[str] = Field(None, description="Legend orientation (ECharts option)")
tooltip_sorting: Optional[str] = Field(None, description="Tooltip sorting (ECharts option)")
y_axis_format: Optional[str] = Field(None, description="Y axis format (ECharts option)")
y_axis_bounds: Optional[List[float]] = Field(None, description="Y axis bounds (ECharts option)")
x_axis_time_format: Optional[str] = Field(None, description="X axis time format (ECharts option)")
rich_tooltip: Optional[bool] = Field(None, description="Enable rich tooltip (ECharts option)")
extra_options: Optional[Dict[str, Any]] = Field(None, description="Additional ECharts options not yet modeled (future-proof)")
class EchartsTimeseriesLineChartCreateRequest(EchartsTimeseriesBaseChartCreateRequest):
viz_type: Literal["echarts_timeseries_line"] = Field("echarts_timeseries_line", description="Visualization type")

View File

@@ -29,6 +29,7 @@ from superset.mcp_service.pydantic_schemas.chart_schemas import (
ChartInfo,
CreateSimpleChartRequest, EchartsAreaChartCreateRequest, EchartsTimeseriesBarChartCreateRequest,
EchartsTimeseriesLineChartCreateRequest, TableChartCreateRequest, )
import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -301,7 +302,7 @@ async def test_create_chart_simple_success(mock_run, mcp_server):
assert result.data.thumbnail_url is not None
assert result.data.embed_html is not None
def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
def _mock_chart(id=1, viz_type="echarts_timeseries_line", form_data=None):
from unittest.mock import Mock
chart = Mock()
chart.id = id
@@ -312,7 +313,7 @@ def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
chart.url = f"/chart/{id}"
chart.description = "desc"
chart.cache_timeout = 60
chart.form_data = {}
chart.form_data = form_data or {}
chart.query_context = {}
chart.changed_by_name = "admin"
chart.changed_on = None
@@ -346,12 +347,10 @@ def _mock_chart(id=1, viz_type="echarts_timeseries_line"):
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
# Updated patch path for new tool structure
@patch('superset.mcp_service.chart.tool.create_chart.CreateChartCommand')
async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp_server):
async def test_create_chart_echarts_line_full_fields(mock_run, mcp_server):
mock_cmd = Mock()
mock_cmd.run.return_value = _mock_chart(id=123, viz_type="echarts_timeseries_line")
mock_cmd_cls.return_value = mock_cmd
mock_run.return_value = mock_cmd.run.return_value
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart",
datasource_id=1,
@@ -368,22 +367,11 @@ async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp
show_empty_columns=False,
)
async with Client(mcp_server) as client:
resp = await client.call_tool("create_chart", {"request": req.dict()})
resp = await client.call_tool("create_chart", {"request": req.model_dump()})
assert resp.data.chart is not None
assert resp.data.chart.viz_type == "echarts_timeseries_line"
assert resp.data.error is None
mock_cmd_cls.assert_called_once()
chart_data = mock_cmd_cls.call_args[0][0]
import json
params = json.loads(chart_data["params"])
assert "x_axis" in params
assert "x_axis_sort" in params
assert "contributionMode" in params
assert "series_limit" in params
assert "orderby" in params
assert "row_limit" in params
assert "truncate_metric" in params
assert "show_empty_columns" in params
mock_run.assert_called_once()
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
@@ -522,41 +510,84 @@ async def test_create_chart_error(mock_run, mcp_server):
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
# Updated patch path for new tool structure
@patch('superset.commands.chart.create.CreateChartCommand')
async def test_create_chart_echarts_line_full_fields(mock_cmd_cls, mock_run, mcp_server):
mock_cmd = Mock()
mock_cmd.run.return_value = _mock_chart(id=123, viz_type="echarts_timeseries_line")
mock_cmd_cls.return_value = mock_cmd
async def test_create_chart_echarts_line_with_all_options(mock_run, mcp_server):
# Arrange
mock_chart = Mock()
mock_chart.id = 101
mock_chart.slice_name = "ECharts Line All Options"
mock_chart.viz_type = "echarts_timeseries_line"
mock_chart.datasource_name = "test_ds"
mock_chart.datasource_type = "table"
mock_chart.url = "/chart/101"
mock_chart.description = "desc"
mock_chart.cache_timeout = 60
mock_chart.form_data = {}
mock_chart.query_context = {}
mock_chart.changed_by_name = "admin"
mock_chart.changed_on = None
mock_chart.changed_on_humanized = "1 day ago"
mock_chart.created_by_name = "admin"
mock_chart.created_on = None
mock_chart.created_on_humanized = "2 days ago"
mock_chart.tags = []
mock_chart.owners = []
mock_run.return_value = mock_chart
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart",
slice_name="ECharts Line All Options",
viz_type="echarts_timeseries_line",
datasource_id=1,
datasource_type="table",
x_axis="ds",
x_axis_sort="ds",
metrics=["sum__value"],
groupby=["region"],
contribution_mode="row",
filters=[{"col": "region", "opr": "eq", "value": "West"}],
series_limit=10,
orderby=[["sum__value", False]],
row_limit=100,
truncate_metric=True,
show_empty_columns=False,
stack=True,
area=True,
smooth=True,
show_value=True,
color_scheme="supersetColors",
legend_type="scroll",
legend_orientation="horizontal",
tooltip_sorting="value_desc",
y_axis_format=",.2f",
y_axis_bounds=[0, 100],
x_axis_time_format="%Y-%m-%d",
rich_tooltip=True,
extra_options={"custom_option": 123, "another_option": "abc"},
)
async with Client(mcp_server) as client:
resp = await client.call_tool("create_chart", {"request": req.dict()})
resp = await client.call_tool("create_chart", {"request": req.model_dump()})
assert resp.data.chart is not None
assert resp.data.chart.viz_type == "echarts_timeseries_line"
assert resp.data.error is None
mock_cmd_cls.assert_called_once()
chart_data = mock_cmd_cls.call_args[0][0]
import json
params = json.loads(chart_data["params"])
assert "x_axis" in params
assert "x_axis_sort" in params
assert "contributionMode" in params
assert "series_limit" in params
assert "orderby" in params
assert "row_limit" in params
assert "truncate_metric" in params
assert "show_empty_columns" in params
mock_run.assert_called_once()
@pytest.mark.asyncio
@patch('superset.commands.chart.create.CreateChartCommand.run')
async def test_create_chart_echarts_line_duplicate_column_removal(mock_run, mcp_server):
# The backend should remove 'date' from groupby, so only 'region' remains
expected_form_data = {"groupby": ["region"]}
mock_chart = _mock_chart(id=105, viz_type="echarts_timeseries_line", form_data=expected_form_data)
mock_run.return_value = mock_chart
req = EchartsTimeseriesLineChartCreateRequest(
slice_name="Line Chart No Duplicate",
datasource_id=1,
datasource_type="table",
x_axis="date",
metrics=["sum__value"],
groupby=["date", "region"], # Duplicate x_axis in groupby
)
async with Client(mcp_server) as client:
result = await client.call_tool("create_chart", {"request": req.model_dump()})
assert result.content is not None
data = json.loads(result.content[0].text)
assert "error" not in data or not data["error"]
# The groupby in the chart's form_data should not include 'date'
chart = data["chart"]
form_data = chart.get("form_data")
if isinstance(form_data, str):
import json as _json
form_data = _json.loads(form_data)
groupby = form_data.get("groupby", [])
assert "date" not in groupby
assert "region" in groupby