mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix(mcp): improve default chart names with descriptive format (#38406)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -801,34 +801,165 @@ def map_filter_operator(op: str) -> str:
|
||||
return operator_map.get(op, op)
|
||||
|
||||
|
||||
def _humanize_column(col: ColumnRef) -> str:
|
||||
"""Return a human-readable label for a column reference."""
|
||||
if col.label:
|
||||
return col.label
|
||||
name = col.name.replace("_", " ").title()
|
||||
if col.aggregate:
|
||||
return f"{col.aggregate.capitalize()}({name})"
|
||||
return name
|
||||
|
||||
|
||||
def _summarize_filters(
|
||||
filters: list[Any] | None,
|
||||
) -> str | None:
|
||||
"""Extract a short context string from filter configs."""
|
||||
if not filters:
|
||||
return None
|
||||
parts: list[str] = []
|
||||
for f in filters[:2]:
|
||||
col = getattr(f, "column", "")
|
||||
val = getattr(f, "value", "")
|
||||
if isinstance(val, list):
|
||||
val = ", ".join(str(v) for v in val[:3])
|
||||
parts.append(f"{str(col).replace('_', ' ').title()} {val}")
|
||||
return ", ".join(parts) if parts else None
|
||||
|
||||
|
||||
def _truncate(name: str, max_length: int = 60) -> str:
|
||||
"""Truncate to *max_length*, preserving the en-dash context portion."""
|
||||
if len(name) <= max_length:
|
||||
return name
|
||||
if " \u2013 " in name:
|
||||
what, _context = name.split(" \u2013 ", 1)
|
||||
if len(what) <= max_length:
|
||||
return what
|
||||
return name[: max_length - 1] + "\u2026"
|
||||
|
||||
|
||||
def _table_chart_what(config: TableChartConfig, dataset_name: str | None) -> str:
|
||||
"""Build the descriptive fragment for a table chart."""
|
||||
has_agg = any(col.aggregate for col in config.columns)
|
||||
if has_agg:
|
||||
metrics = [col for col in config.columns if col.aggregate]
|
||||
what = ", ".join(_humanize_column(m) for m in metrics[:2])
|
||||
return f"{what} Summary"
|
||||
if dataset_name:
|
||||
return f"{dataset_name} Records"
|
||||
cols = ", ".join(_humanize_column(c) for c in config.columns[:3])
|
||||
return f"{cols} Table"
|
||||
|
||||
|
||||
def _xy_chart_what(config: XYChartConfig) -> str:
|
||||
"""Build the descriptive fragment for an XY chart."""
|
||||
primary_metric = _humanize_column(config.y[0]) if config.y else "Value"
|
||||
dimension = _humanize_column(config.x)
|
||||
|
||||
if config.kind in ("line", "area") and config.group_by is None:
|
||||
return f"{primary_metric} Over Time"
|
||||
if config.group_by is not None:
|
||||
group_label = _humanize_column(config.group_by)
|
||||
return f"{primary_metric} by {group_label}"
|
||||
if config.kind == "scatter":
|
||||
return f"{primary_metric} vs {dimension}"
|
||||
return f"{primary_metric} by {dimension}"
|
||||
|
||||
|
||||
_GRAIN_MAP: dict[str, str] = {
|
||||
"PT1H": "Hourly",
|
||||
"P1D": "Daily",
|
||||
"P1W": "Weekly",
|
||||
"P1M": "Monthly",
|
||||
"P3M": "Quarterly",
|
||||
"P1Y": "Yearly",
|
||||
}
|
||||
|
||||
|
||||
def _xy_chart_context(config: XYChartConfig) -> str | None:
|
||||
"""Build context (time grain / filters) for an XY chart name."""
|
||||
parts: list[str] = []
|
||||
if config.time_grain:
|
||||
grain_val = (
|
||||
config.time_grain.value
|
||||
if hasattr(config.time_grain, "value")
|
||||
else str(config.time_grain)
|
||||
)
|
||||
grain_str = _GRAIN_MAP.get(grain_val, grain_val)
|
||||
parts.append(grain_str)
|
||||
if filter_ctx := _summarize_filters(config.filters):
|
||||
parts.append(filter_ctx)
|
||||
return ", ".join(parts) if parts else None
|
||||
|
||||
|
||||
def _pie_chart_what(config: PieChartConfig) -> str:
|
||||
"""Build the 'what' portion for a pie chart name."""
|
||||
dim = config.dimension.name
|
||||
metric_label = config.metric.label or config.metric.name
|
||||
return f"{dim} by {metric_label}"
|
||||
|
||||
|
||||
def _pivot_table_what(config: PivotTableChartConfig) -> str:
|
||||
"""Build the 'what' portion for a pivot table chart name."""
|
||||
row_names = ", ".join(r.name for r in config.rows)
|
||||
return f"Pivot Table \u2013 {row_names}"
|
||||
|
||||
|
||||
def _mixed_timeseries_what(config: MixedTimeseriesChartConfig) -> str:
|
||||
"""Build the 'what' portion for a mixed timeseries chart name."""
|
||||
primary = config.y[0].label or config.y[0].name if config.y else "primary"
|
||||
secondary = (
|
||||
config.y_secondary[0].label or config.y_secondary[0].name
|
||||
if config.y_secondary
|
||||
else "secondary"
|
||||
)
|
||||
return f"{primary} + {secondary}"
|
||||
|
||||
|
||||
def generate_chart_name(
|
||||
config: TableChartConfig
|
||||
| XYChartConfig
|
||||
| PieChartConfig
|
||||
| PivotTableChartConfig
|
||||
| MixedTimeseriesChartConfig,
|
||||
dataset_name: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a chart name based on the configuration."""
|
||||
"""Generate a descriptive chart name following a standard format.
|
||||
|
||||
Format conventions (by chart type):
|
||||
Aggregated (bar/scatter with group_by): [Metric] by [Dimension]
|
||||
Time-series (line/area, no group_by): [Metric] Over Time
|
||||
Table (no aggregates): [Dataset] Records
|
||||
Table (with aggregates): [Metric] Summary
|
||||
Pie: [Dimension] by [Metric]
|
||||
Pivot Table: Pivot Table – [Row1, Row2]
|
||||
Mixed Timeseries: [Primary] + [Secondary]
|
||||
An en-dash followed by context (filters / time grain) is appended
|
||||
when such information is available.
|
||||
"""
|
||||
if isinstance(config, TableChartConfig):
|
||||
return f"Table Chart - {', '.join(col.name for col in config.columns)}"
|
||||
what = _table_chart_what(config, dataset_name)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, XYChartConfig):
|
||||
chart_type = config.kind.capitalize()
|
||||
x_col = config.x.name
|
||||
y_cols = ", ".join(col.name for col in config.y)
|
||||
return f"{chart_type} Chart - {x_col} vs {y_cols}"
|
||||
what = _xy_chart_what(config)
|
||||
context = _xy_chart_context(config)
|
||||
elif isinstance(config, PieChartConfig):
|
||||
metric_label = config.metric.label or config.metric.name
|
||||
return f"Pie Chart - {config.dimension.name} by {metric_label}"
|
||||
what = _pie_chart_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, PivotTableChartConfig):
|
||||
rows = ", ".join(col.name for col in config.rows)
|
||||
return f"Pivot Table - {rows}"
|
||||
what = _pivot_table_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
elif isinstance(config, MixedTimeseriesChartConfig):
|
||||
primary = ", ".join(col.name for col in config.y)
|
||||
secondary = ", ".join(col.name for col in config.y_secondary)
|
||||
return f"Mixed Chart - {primary} + {secondary}"
|
||||
what = _mixed_timeseries_what(config)
|
||||
context = _summarize_filters(config.filters)
|
||||
else:
|
||||
return "Chart"
|
||||
|
||||
name = what
|
||||
if context:
|
||||
name = f"{what} \u2013 {context}"
|
||||
return _truncate(name)
|
||||
|
||||
|
||||
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
|
||||
"""Analyze chart capabilities based on type and configuration."""
|
||||
|
||||
@@ -264,10 +264,6 @@ async def generate_chart( # noqa: C901
|
||||
await ctx.report_progress(2, 5, "Creating chart in database")
|
||||
from superset.commands.chart.create import CreateChartCommand
|
||||
|
||||
# Use custom chart name if provided, otherwise auto-generate
|
||||
chart_name = request.chart_name or generate_chart_name(request.config)
|
||||
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
|
||||
|
||||
# Find the dataset to get its numeric ID
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
@@ -344,6 +340,15 @@ async def generate_chart( # noqa: C901
|
||||
}
|
||||
)
|
||||
|
||||
# Generate chart name after dataset lookup so we can include dataset name
|
||||
dataset_name = getattr(dataset, "datasource_name", None) or getattr(
|
||||
dataset, "table_name", None
|
||||
)
|
||||
chart_name = request.chart_name or generate_chart_name(
|
||||
request.config, dataset_name=dataset_name
|
||||
)
|
||||
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
|
||||
|
||||
try:
|
||||
with event_logger.log_context(action="mcp.generate_chart.db_write"):
|
||||
command = CreateChartCommand(
|
||||
|
||||
@@ -568,8 +568,8 @@ class TestMapConfigToFormData:
|
||||
class TestGenerateChartName:
|
||||
"""Test generate_chart_name function"""
|
||||
|
||||
def test_generate_table_chart_name(self) -> None:
|
||||
"""Test generating name for table chart"""
|
||||
def test_table_no_aggregates(self) -> None:
|
||||
"""Table without aggregates uses column names."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
@@ -579,25 +579,138 @@ class TestGenerateChartName:
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Table Chart - product, revenue"
|
||||
assert result == "Product, Revenue Table"
|
||||
|
||||
def test_generate_xy_chart_name(self) -> None:
|
||||
"""Test generating name for XY chart"""
|
||||
def test_table_no_aggregates_with_dataset_name(self) -> None:
|
||||
"""Table without aggregates includes dataset name when available."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
)
|
||||
|
||||
result = generate_chart_name(config, dataset_name="Orders")
|
||||
assert result == "Orders Records"
|
||||
|
||||
def test_table_with_aggregates(self) -> None:
|
||||
"""Table with aggregates produces a summary name."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="product"),
|
||||
ColumnRef(name="revenue", aggregate="SUM"),
|
||||
],
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Sum(Revenue) Summary"
|
||||
|
||||
def test_line_chart_over_time(self) -> None:
|
||||
"""Line chart without group_by uses 'Over Time' format."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
|
||||
x=ColumnRef(name="order_date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
kind="line",
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Line Chart - date vs revenue, orders"
|
||||
assert result == "Sum(Revenue) Over Time"
|
||||
|
||||
def test_generate_chart_name_unsupported(self) -> None:
|
||||
"""Test generating name for unsupported config type"""
|
||||
def test_bar_chart_by_dimension(self) -> None:
|
||||
"""Bar chart uses 'by [X]' format."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="product_category"),
|
||||
y=[ColumnRef(name="order_count", aggregate="COUNT")],
|
||||
kind="bar",
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Count(Order Count) by Product Category"
|
||||
|
||||
def test_line_chart_with_group_by(self) -> None:
|
||||
"""Line chart with group_by uses 'by [group]' format."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="sales_rep"),
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Sum(Revenue) by Sales Rep"
|
||||
|
||||
def test_scatter_plot(self) -> None:
|
||||
"""Scatter plot uses 'Y vs X' format."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="age"),
|
||||
y=[ColumnRef(name="income")],
|
||||
kind="scatter",
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Income vs Age"
|
||||
|
||||
def test_time_grain_in_context(self) -> None:
|
||||
"""Time grain is appended as context."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
kind="line",
|
||||
time_grain="P1M",
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Sum(Revenue) Over Time \u2013 Monthly"
|
||||
|
||||
def test_filter_context(self) -> None:
|
||||
"""Filters are appended as context."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="product")],
|
||||
filters=[FilterConfig(column="region", op="=", value="West")],
|
||||
)
|
||||
|
||||
result = generate_chart_name(config, dataset_name="Orders")
|
||||
assert result == "Orders Records \u2013 Region West"
|
||||
|
||||
def test_name_truncation(self) -> None:
|
||||
"""Names exceeding 60 chars are truncated."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(
|
||||
name="very_long_metric_name_that_goes_on_and_on", aggregate="SUM"
|
||||
)
|
||||
],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="another_very_long_dimension_name_here"),
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert len(result) <= 60
|
||||
|
||||
def test_unsupported_config_type(self) -> None:
|
||||
"""Unsupported config type returns generic name."""
|
||||
result = generate_chart_name("invalid_config") # type: ignore
|
||||
assert result == "Chart"
|
||||
|
||||
def test_custom_labels_used(self) -> None:
|
||||
"""Column labels are preferred over names."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="ds", label="Date"),
|
||||
y=[ColumnRef(name="cnt", aggregate="COUNT", label="Order Count")],
|
||||
kind="bar",
|
||||
)
|
||||
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Order Count by Date"
|
||||
|
||||
|
||||
class TestGenerateExploreLink:
|
||||
"""Test generate_explore_link function"""
|
||||
|
||||
@@ -723,7 +723,7 @@ class TestGenerateChartNameNewTypes:
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM"),
|
||||
)
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Pie Chart - product by revenue"
|
||||
assert result == "product by revenue"
|
||||
|
||||
def test_pie_chart_name_with_custom_label(self) -> None:
|
||||
config = PieChartConfig(
|
||||
@@ -732,7 +732,7 @@ class TestGenerateChartNameNewTypes:
|
||||
metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
|
||||
)
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Pie Chart - product by Total Revenue"
|
||||
assert result == "product by Total Revenue"
|
||||
|
||||
def test_pivot_table_chart_name(self) -> None:
|
||||
config = PivotTableChartConfig(
|
||||
@@ -741,7 +741,7 @@ class TestGenerateChartNameNewTypes:
|
||||
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
|
||||
)
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Pivot Table - product, region"
|
||||
assert result == "Pivot Table \u2013 product, region"
|
||||
|
||||
def test_mixed_timeseries_chart_name(self) -> None:
|
||||
config = MixedTimeseriesChartConfig(
|
||||
@@ -751,7 +751,7 @@ class TestGenerateChartNameNewTypes:
|
||||
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
|
||||
)
|
||||
result = generate_chart_name(config)
|
||||
assert result == "Mixed Chart - revenue + orders"
|
||||
assert result == "revenue + orders"
|
||||
|
||||
|
||||
# ============================================================
|
||||
|
||||
Reference in New Issue
Block a user