diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index ca5e1dfeb47..b1e5be880c1 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -801,34 +801,165 @@ def map_filter_operator(op: str) -> str: return operator_map.get(op, op) +def _humanize_column(col: ColumnRef) -> str: + """Return a human-readable label for a column reference.""" + if col.label: + return col.label + name = col.name.replace("_", " ").title() + if col.aggregate: + return f"{col.aggregate.capitalize()}({name})" + return name + + +def _summarize_filters( + filters: list[Any] | None, +) -> str | None: + """Extract a short context string from filter configs.""" + if not filters: + return None + parts: list[str] = [] + for f in filters[:2]: + col = getattr(f, "column", "") + val = getattr(f, "value", "") + if isinstance(val, list): + val = ", ".join(str(v) for v in val[:3]) + parts.append(f"{str(col).replace('_', ' ').title()} {val}") + return ", ".join(parts) if parts else None + + +def _truncate(name: str, max_length: int = 60) -> str: + """Truncate to *max_length*, preserving the en-dash context portion.""" + if len(name) <= max_length: + return name + if " \u2013 " in name: + what, _context = name.split(" \u2013 ", 1) + if len(what) <= max_length: + return what + return name[: max_length - 1] + "\u2026" + + +def _table_chart_what(config: TableChartConfig, dataset_name: str | None) -> str: + """Build the descriptive fragment for a table chart.""" + has_agg = any(col.aggregate for col in config.columns) + if has_agg: + metrics = [col for col in config.columns if col.aggregate] + what = ", ".join(_humanize_column(m) for m in metrics[:2]) + return f"{what} Summary" + if dataset_name: + return f"{dataset_name} Records" + cols = ", ".join(_humanize_column(c) for c in config.columns[:3]) + return f"{cols} Table" + + +def _xy_chart_what(config: XYChartConfig) -> str: + """Build the descriptive fragment for an XY chart.""" + primary_metric = _humanize_column(config.y[0]) if config.y else "Value" + dimension = _humanize_column(config.x) + + if config.kind in ("line", "area") and config.group_by is None: + return f"{primary_metric} Over Time" + if config.group_by is not None: + group_label = _humanize_column(config.group_by) + return f"{primary_metric} by {group_label}" + if config.kind == "scatter": + return f"{primary_metric} vs {dimension}" + return f"{primary_metric} by {dimension}" + + +_GRAIN_MAP: dict[str, str] = { + "PT1H": "Hourly", + "P1D": "Daily", + "P1W": "Weekly", + "P1M": "Monthly", + "P3M": "Quarterly", + "P1Y": "Yearly", +} + + +def _xy_chart_context(config: XYChartConfig) -> str | None: + """Build context (time grain / filters) for an XY chart name.""" + parts: list[str] = [] + if config.time_grain: + grain_val = ( + config.time_grain.value + if hasattr(config.time_grain, "value") + else str(config.time_grain) + ) + grain_str = _GRAIN_MAP.get(grain_val, grain_val) + parts.append(grain_str) + if filter_ctx := _summarize_filters(config.filters): + parts.append(filter_ctx) + return ", ".join(parts) if parts else None + + +def _pie_chart_what(config: PieChartConfig) -> str: + """Build the 'what' portion for a pie chart name.""" + dim = config.dimension.name + metric_label = config.metric.label or config.metric.name + return f"{dim} by {metric_label}" + + +def _pivot_table_what(config: PivotTableChartConfig) -> str: + """Build the 'what' portion for a pivot table chart name.""" + row_names = ", ".join(r.name for r in config.rows) + return f"Pivot Table \u2013 {row_names}" + + +def _mixed_timeseries_what(config: MixedTimeseriesChartConfig) -> str: + """Build the 'what' portion for a mixed timeseries chart name.""" + primary = config.y[0].label or config.y[0].name if config.y else "primary" + secondary = ( + config.y_secondary[0].label or config.y_secondary[0].name + if config.y_secondary + else "secondary" + ) + return f"{primary} + {secondary}" + + def generate_chart_name( config: TableChartConfig | XYChartConfig | PieChartConfig | PivotTableChartConfig | MixedTimeseriesChartConfig, + dataset_name: str | None = None, ) -> str: - """Generate a chart name based on the configuration.""" + """Generate a descriptive chart name following a standard format. + + Format conventions (by chart type): + Aggregated (bar/scatter with group_by): [Metric] by [Dimension] + Time-series (line/area, no group_by): [Metric] Over Time + Table (no aggregates): [Dataset] Records + Table (with aggregates): [Metric] Summary + Pie: [Dimension] by [Metric] + Pivot Table: Pivot Table – [Row1, Row2] + Mixed Timeseries: [Primary] + [Secondary] + An en-dash followed by context (filters / time grain) is appended + when such information is available. + """ if isinstance(config, TableChartConfig): - return f"Table Chart - {', '.join(col.name for col in config.columns)}" + what = _table_chart_what(config, dataset_name) + context = _summarize_filters(config.filters) elif isinstance(config, XYChartConfig): - chart_type = config.kind.capitalize() - x_col = config.x.name - y_cols = ", ".join(col.name for col in config.y) - return f"{chart_type} Chart - {x_col} vs {y_cols}" + what = _xy_chart_what(config) + context = _xy_chart_context(config) elif isinstance(config, PieChartConfig): - metric_label = config.metric.label or config.metric.name - return f"Pie Chart - {config.dimension.name} by {metric_label}" + what = _pie_chart_what(config) + context = _summarize_filters(config.filters) elif isinstance(config, PivotTableChartConfig): - rows = ", ".join(col.name for col in config.rows) - return f"Pivot Table - {rows}" + what = _pivot_table_what(config) + context = _summarize_filters(config.filters) elif isinstance(config, MixedTimeseriesChartConfig): - primary = ", ".join(col.name for col in config.y) - secondary = ", ".join(col.name for col in config.y_secondary) - return f"Mixed Chart - {primary} + {secondary}" + what = _mixed_timeseries_what(config) + context = _summarize_filters(config.filters) else: return "Chart" + name = what + if context: + name = f"{what} \u2013 {context}" + return _truncate(name) + def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities: """Analyze chart capabilities based on type and configuration.""" diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index b5711d5e5f7..f09d8f89fdf 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -264,10 +264,6 @@ async def generate_chart( # noqa: C901 await ctx.report_progress(2, 5, "Creating chart in database") from superset.commands.chart.create import CreateChartCommand - # Use custom chart name if provided, otherwise auto-generate - chart_name = request.chart_name or generate_chart_name(request.config) - await ctx.debug("Chart name: chart_name=%s" % (chart_name,)) - # Find the dataset to get its numeric ID from superset.daos.dataset import DatasetDAO @@ -344,6 +340,15 @@ async def generate_chart( # noqa: C901 } ) + # Generate chart name after dataset lookup so we can include dataset name + dataset_name = getattr(dataset, "datasource_name", None) or getattr( + dataset, "table_name", None + ) + chart_name = request.chart_name or generate_chart_name( + request.config, dataset_name=dataset_name + ) + await ctx.debug("Chart name: chart_name=%s" % (chart_name,)) + try: with event_logger.log_context(action="mcp.generate_chart.db_write"): command = CreateChartCommand( diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 9b53024b175..ce5c618e03b 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -568,8 +568,8 @@ class TestMapConfigToFormData: class TestGenerateChartName: """Test generate_chart_name function""" - def test_generate_table_chart_name(self) -> None: - """Test generating name for table chart""" + def test_table_no_aggregates(self) -> None: + """Table without aggregates uses column names.""" config = TableChartConfig( chart_type="table", columns=[ @@ -579,25 +579,138 @@ class TestGenerateChartName: ) result = generate_chart_name(config) - assert result == "Table Chart - product, revenue" + assert result == "Product, Revenue Table" - def test_generate_xy_chart_name(self) -> None: - """Test generating name for XY chart""" + def test_table_no_aggregates_with_dataset_name(self) -> None: + """Table without aggregates includes dataset name when available.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + ) + + result = generate_chart_name(config, dataset_name="Orders") + assert result == "Orders Records" + + def test_table_with_aggregates(self) -> None: + """Table with aggregates produces a summary name.""" + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="product"), + ColumnRef(name="revenue", aggregate="SUM"), + ], + ) + + result = generate_chart_name(config) + assert result == "Sum(Revenue) Summary" + + def test_line_chart_over_time(self) -> None: + """Line chart without group_by uses 'Over Time' format.""" config = XYChartConfig( chart_type="xy", - x=ColumnRef(name="date"), - y=[ColumnRef(name="revenue"), ColumnRef(name="orders")], + x=ColumnRef(name="order_date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], kind="line", ) result = generate_chart_name(config) - assert result == "Line Chart - date vs revenue, orders" + assert result == "Sum(Revenue) Over Time" - def test_generate_chart_name_unsupported(self) -> None: - """Test generating name for unsupported config type""" + def test_bar_chart_by_dimension(self) -> None: + """Bar chart uses 'by [X]' format.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="product_category"), + y=[ColumnRef(name="order_count", aggregate="COUNT")], + kind="bar", + ) + + result = generate_chart_name(config) + assert result == "Count(Order Count) by Product Category" + + def test_line_chart_with_group_by(self) -> None: + """Line chart with group_by uses 'by [group]' format.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + group_by=ColumnRef(name="sales_rep"), + ) + + result = generate_chart_name(config) + assert result == "Sum(Revenue) by Sales Rep" + + def test_scatter_plot(self) -> None: + """Scatter plot uses 'Y vs X' format.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="age"), + y=[ColumnRef(name="income")], + kind="scatter", + ) + + result = generate_chart_name(config) + assert result == "Income vs Age" + + def test_time_grain_in_context(self) -> None: + """Time grain is appended as context.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + time_grain="P1M", + ) + + result = generate_chart_name(config) + assert result == "Sum(Revenue) Over Time \u2013 Monthly" + + def test_filter_context(self) -> None: + """Filters are appended as context.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + filters=[FilterConfig(column="region", op="=", value="West")], + ) + + result = generate_chart_name(config, dataset_name="Orders") + assert result == "Orders Records \u2013 Region West" + + def test_name_truncation(self) -> None: + """Names exceeding 60 chars are truncated.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ + ColumnRef( + name="very_long_metric_name_that_goes_on_and_on", aggregate="SUM" + ) + ], + kind="line", + group_by=ColumnRef(name="another_very_long_dimension_name_here"), + ) + + result = generate_chart_name(config) + assert len(result) <= 60 + + def test_unsupported_config_type(self) -> None: + """Unsupported config type returns generic name.""" result = generate_chart_name("invalid_config") # type: ignore assert result == "Chart" + def test_custom_labels_used(self) -> None: + """Column labels are preferred over names.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="ds", label="Date"), + y=[ColumnRef(name="cnt", aggregate="COUNT", label="Order Count")], + kind="bar", + ) + + result = generate_chart_name(config) + assert result == "Order Count by Date" + class TestGenerateExploreLink: """Test generate_explore_link function""" diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py index e930ab83091..e2e37ba99e6 100644 --- a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -723,7 +723,7 @@ class TestGenerateChartNameNewTypes: metric=ColumnRef(name="revenue", aggregate="SUM"), ) result = generate_chart_name(config) - assert result == "Pie Chart - product by revenue" + assert result == "product by revenue" def test_pie_chart_name_with_custom_label(self) -> None: config = PieChartConfig( @@ -732,7 +732,7 @@ class TestGenerateChartNameNewTypes: metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"), ) result = generate_chart_name(config) - assert result == "Pie Chart - product by Total Revenue" + assert result == "product by Total Revenue" def test_pivot_table_chart_name(self) -> None: config = PivotTableChartConfig( @@ -741,7 +741,7 @@ class TestGenerateChartNameNewTypes: metrics=[ColumnRef(name="revenue", aggregate="SUM")], ) result = generate_chart_name(config) - assert result == "Pivot Table - product, region" + assert result == "Pivot Table \u2013 product, region" def test_mixed_timeseries_chart_name(self) -> None: config = MixedTimeseriesChartConfig( @@ -751,7 +751,7 @@ class TestGenerateChartNameNewTypes: y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], ) result = generate_chart_name(config) - assert result == "Mixed Chart - revenue + orders" + assert result == "revenue + orders" # ============================================================