fix(mcp): improve default chart names with descriptive format (#38406)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-10 10:53:05 +01:00
committed by GitHub
parent 13fe88000a
commit 0cfd760a36
4 changed files with 280 additions and 31 deletions

View File

@@ -801,34 +801,165 @@ def map_filter_operator(op: str) -> str:
return operator_map.get(op, op)
def _humanize_column(col: ColumnRef) -> str:
"""Return a human-readable label for a column reference."""
if col.label:
return col.label
name = col.name.replace("_", " ").title()
if col.aggregate:
return f"{col.aggregate.capitalize()}({name})"
return name
def _summarize_filters(
filters: list[Any] | None,
) -> str | None:
"""Extract a short context string from filter configs."""
if not filters:
return None
parts: list[str] = []
for f in filters[:2]:
col = getattr(f, "column", "")
val = getattr(f, "value", "")
if isinstance(val, list):
val = ", ".join(str(v) for v in val[:3])
parts.append(f"{str(col).replace('_', ' ').title()} {val}")
return ", ".join(parts) if parts else None
def _truncate(name: str, max_length: int = 60) -> str:
"""Truncate to *max_length*, preserving the en-dash context portion."""
if len(name) <= max_length:
return name
if " \u2013 " in name:
what, _context = name.split(" \u2013 ", 1)
if len(what) <= max_length:
return what
return name[: max_length - 1] + "\u2026"
def _table_chart_what(config: TableChartConfig, dataset_name: str | None) -> str:
"""Build the descriptive fragment for a table chart."""
has_agg = any(col.aggregate for col in config.columns)
if has_agg:
metrics = [col for col in config.columns if col.aggregate]
what = ", ".join(_humanize_column(m) for m in metrics[:2])
return f"{what} Summary"
if dataset_name:
return f"{dataset_name} Records"
cols = ", ".join(_humanize_column(c) for c in config.columns[:3])
return f"{cols} Table"
def _xy_chart_what(config: XYChartConfig) -> str:
"""Build the descriptive fragment for an XY chart."""
primary_metric = _humanize_column(config.y[0]) if config.y else "Value"
dimension = _humanize_column(config.x)
if config.kind in ("line", "area") and config.group_by is None:
return f"{primary_metric} Over Time"
if config.group_by is not None:
group_label = _humanize_column(config.group_by)
return f"{primary_metric} by {group_label}"
if config.kind == "scatter":
return f"{primary_metric} vs {dimension}"
return f"{primary_metric} by {dimension}"
_GRAIN_MAP: dict[str, str] = {
"PT1H": "Hourly",
"P1D": "Daily",
"P1W": "Weekly",
"P1M": "Monthly",
"P3M": "Quarterly",
"P1Y": "Yearly",
}
def _xy_chart_context(config: XYChartConfig) -> str | None:
"""Build context (time grain / filters) for an XY chart name."""
parts: list[str] = []
if config.time_grain:
grain_val = (
config.time_grain.value
if hasattr(config.time_grain, "value")
else str(config.time_grain)
)
grain_str = _GRAIN_MAP.get(grain_val, grain_val)
parts.append(grain_str)
if filter_ctx := _summarize_filters(config.filters):
parts.append(filter_ctx)
return ", ".join(parts) if parts else None
def _pie_chart_what(config: PieChartConfig) -> str:
"""Build the 'what' portion for a pie chart name."""
dim = config.dimension.name
metric_label = config.metric.label or config.metric.name
return f"{dim} by {metric_label}"
def _pivot_table_what(config: PivotTableChartConfig) -> str:
"""Build the 'what' portion for a pivot table chart name."""
row_names = ", ".join(r.name for r in config.rows)
return f"Pivot Table \u2013 {row_names}"
def _mixed_timeseries_what(config: MixedTimeseriesChartConfig) -> str:
"""Build the 'what' portion for a mixed timeseries chart name."""
primary = config.y[0].label or config.y[0].name if config.y else "primary"
secondary = (
config.y_secondary[0].label or config.y_secondary[0].name
if config.y_secondary
else "secondary"
)
return f"{primary} + {secondary}"
def generate_chart_name(
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig,
dataset_name: str | None = None,
) -> str:
"""Generate a chart name based on the configuration."""
"""Generate a descriptive chart name following a standard format.
Format conventions (by chart type):
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
Time-series (line/area, no group_by): [Metric] Over Time
Table (no aggregates): [Dataset] Records
Table (with aggregates): [Metric] Summary
Pie: [Dimension] by [Metric]
Pivot Table: Pivot Table [Row1, Row2]
Mixed Timeseries: [Primary] + [Secondary]
An en-dash followed by context (filters / time grain) is appended
when such information is available.
"""
if isinstance(config, TableChartConfig):
return f"Table Chart - {', '.join(col.name for col in config.columns)}"
what = _table_chart_what(config, dataset_name)
context = _summarize_filters(config.filters)
elif isinstance(config, XYChartConfig):
chart_type = config.kind.capitalize()
x_col = config.x.name
y_cols = ", ".join(col.name for col in config.y)
return f"{chart_type} Chart - {x_col} vs {y_cols}"
what = _xy_chart_what(config)
context = _xy_chart_context(config)
elif isinstance(config, PieChartConfig):
metric_label = config.metric.label or config.metric.name
return f"Pie Chart - {config.dimension.name} by {metric_label}"
what = _pie_chart_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, PivotTableChartConfig):
rows = ", ".join(col.name for col in config.rows)
return f"Pivot Table - {rows}"
what = _pivot_table_what(config)
context = _summarize_filters(config.filters)
elif isinstance(config, MixedTimeseriesChartConfig):
primary = ", ".join(col.name for col in config.y)
secondary = ", ".join(col.name for col in config.y_secondary)
return f"Mixed Chart - {primary} + {secondary}"
what = _mixed_timeseries_what(config)
context = _summarize_filters(config.filters)
else:
return "Chart"
name = what
if context:
name = f"{what} \u2013 {context}"
return _truncate(name)
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
"""Analyze chart capabilities based on type and configuration."""

View File

@@ -264,10 +264,6 @@ async def generate_chart( # noqa: C901
await ctx.report_progress(2, 5, "Creating chart in database")
from superset.commands.chart.create import CreateChartCommand
# Use custom chart name if provided, otherwise auto-generate
chart_name = request.chart_name or generate_chart_name(request.config)
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
# Find the dataset to get its numeric ID
from superset.daos.dataset import DatasetDAO
@@ -344,6 +340,15 @@ async def generate_chart( # noqa: C901
}
)
# Generate chart name after dataset lookup so we can include dataset name
dataset_name = getattr(dataset, "datasource_name", None) or getattr(
dataset, "table_name", None
)
chart_name = request.chart_name or generate_chart_name(
request.config, dataset_name=dataset_name
)
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
try:
with event_logger.log_context(action="mcp.generate_chart.db_write"):
command = CreateChartCommand(

View File

@@ -568,8 +568,8 @@ class TestMapConfigToFormData:
class TestGenerateChartName:
"""Test generate_chart_name function"""
def test_generate_table_chart_name(self) -> None:
"""Test generating name for table chart"""
def test_table_no_aggregates(self) -> None:
"""Table without aggregates uses column names."""
config = TableChartConfig(
chart_type="table",
columns=[
@@ -579,25 +579,138 @@ class TestGenerateChartName:
)
result = generate_chart_name(config)
assert result == "Table Chart - product, revenue"
assert result == "Product, Revenue Table"
def test_generate_xy_chart_name(self) -> None:
"""Test generating name for XY chart"""
def test_table_no_aggregates_with_dataset_name(self) -> None:
"""Table without aggregates includes dataset name when available."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
)
result = generate_chart_name(config, dataset_name="Orders")
assert result == "Orders Records"
def test_table_with_aggregates(self) -> None:
"""Table with aggregates produces a summary name."""
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="product"),
ColumnRef(name="revenue", aggregate="SUM"),
],
)
result = generate_chart_name(config)
assert result == "Sum(Revenue) Summary"
def test_line_chart_over_time(self) -> None:
"""Line chart without group_by uses 'Over Time' format."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
kind="line",
)
result = generate_chart_name(config)
assert result == "Line Chart - date vs revenue, orders"
assert result == "Sum(Revenue) Over Time"
def test_generate_chart_name_unsupported(self) -> None:
"""Test generating name for unsupported config type"""
def test_bar_chart_by_dimension(self) -> None:
"""Bar chart uses 'by [X]' format."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="product_category"),
y=[ColumnRef(name="order_count", aggregate="COUNT")],
kind="bar",
)
result = generate_chart_name(config)
assert result == "Count(Order Count) by Product Category"
def test_line_chart_with_group_by(self) -> None:
"""Line chart with group_by uses 'by [group]' format."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="sales_rep"),
)
result = generate_chart_name(config)
assert result == "Sum(Revenue) by Sales Rep"
def test_scatter_plot(self) -> None:
"""Scatter plot uses 'Y vs X' format."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="age"),
y=[ColumnRef(name="income")],
kind="scatter",
)
result = generate_chart_name(config)
assert result == "Income vs Age"
def test_time_grain_in_context(self) -> None:
"""Time grain is appended as context."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
kind="line",
time_grain="P1M",
)
result = generate_chart_name(config)
assert result == "Sum(Revenue) Over Time \u2013 Monthly"
def test_filter_context(self) -> None:
"""Filters are appended as context."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="product")],
filters=[FilterConfig(column="region", op="=", value="West")],
)
result = generate_chart_name(config, dataset_name="Orders")
assert result == "Orders Records \u2013 Region West"
def test_name_truncation(self) -> None:
"""Names exceeding 60 chars are truncated."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(
name="very_long_metric_name_that_goes_on_and_on", aggregate="SUM"
)
],
kind="line",
group_by=ColumnRef(name="another_very_long_dimension_name_here"),
)
result = generate_chart_name(config)
assert len(result) <= 60
def test_unsupported_config_type(self) -> None:
"""Unsupported config type returns generic name."""
result = generate_chart_name("invalid_config") # type: ignore
assert result == "Chart"
def test_custom_labels_used(self) -> None:
"""Column labels are preferred over names."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="ds", label="Date"),
y=[ColumnRef(name="cnt", aggregate="COUNT", label="Order Count")],
kind="bar",
)
result = generate_chart_name(config)
assert result == "Order Count by Date"
class TestGenerateExploreLink:
"""Test generate_explore_link function"""

View File

@@ -723,7 +723,7 @@ class TestGenerateChartNameNewTypes:
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = generate_chart_name(config)
assert result == "Pie Chart - product by revenue"
assert result == "product by revenue"
def test_pie_chart_name_with_custom_label(self) -> None:
config = PieChartConfig(
@@ -732,7 +732,7 @@ class TestGenerateChartNameNewTypes:
metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
)
result = generate_chart_name(config)
assert result == "Pie Chart - product by Total Revenue"
assert result == "product by Total Revenue"
def test_pivot_table_chart_name(self) -> None:
config = PivotTableChartConfig(
@@ -741,7 +741,7 @@ class TestGenerateChartNameNewTypes:
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = generate_chart_name(config)
assert result == "Pivot Table - product, region"
assert result == "Pivot Table \u2013 product, region"
def test_mixed_timeseries_chart_name(self) -> None:
config = MixedTimeseriesChartConfig(
@@ -751,7 +751,7 @@ class TestGenerateChartNameNewTypes:
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = generate_chart_name(config)
assert result == "Mixed Chart - revenue + orders"
assert result == "revenue + orders"
# ============================================================