From e6853894ab1c63d8537efcfcaa9c7c19e011ec3e Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Tue, 21 Apr 2026 20:10:49 -0400 Subject: [PATCH] chore(mcp): extract shared chart helpers and ASCII rendering into separate modules (#39438) --- superset/mcp_service/chart/ascii_charts.py | 863 +++++++++++++++++ superset/mcp_service/chart/chart_helpers.py | 81 ++ .../mcp_service/chart/tool/generate_chart.py | 11 +- .../mcp_service/chart/tool/get_chart_data.py | 60 +- .../mcp_service/chart/tool/get_chart_info.py | 22 +- .../chart/tool/get_chart_preview.py | 912 +----------------- .../mcp_service/chart/tool/update_chart.py | 37 +- .../chart/tool/update_chart_preview.py | 31 +- .../explore/tool/generate_explore_link.py | 12 +- .../mcp_service/chart/test_chart_helpers.py | 108 +++ .../chart/tool/test_update_chart.py | 12 +- 11 files changed, 1156 insertions(+), 993 deletions(-) create mode 100644 superset/mcp_service/chart/ascii_charts.py create mode 100644 superset/mcp_service/chart/chart_helpers.py create mode 100644 tests/unit_tests/mcp_service/chart/test_chart_helpers.py diff --git a/superset/mcp_service/chart/ascii_charts.py b/superset/mcp_service/chart/ascii_charts.py new file mode 100644 index 00000000000..b5aa4d71f49 --- /dev/null +++ b/superset/mcp_service/chart/ascii_charts.py @@ -0,0 +1,863 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +ASCII chart rendering functions for MCP chart previews. + +Pure rendering functions that take data in and return strings out. +No dependencies on MCP tools, FastMCP Context, or request objects. +""" + +from __future__ import annotations + +import logging +import math +from typing import Any + +logger = logging.getLogger(__name__) + + +def generate_ascii_chart( + data: list[Any], chart_type: str, width: int = 80, height: int = 20 +) -> str: + """Generate ASCII art chart from data.""" + if not data: + return "No data available for ASCII chart" + + try: + # Clamp to safe minimums to prevent negative plot sizes + width = max(width, 21) + height = max(height, 9) + + logger.debug( + "generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data) + ) + + if chart_type in ["bar", "column", "echarts_timeseries_bar"]: + return _generate_ascii_bar_chart(data, width, height) + elif chart_type in ["line", "echarts_timeseries_line"]: + return _generate_ascii_line_chart(data, width, height) + elif chart_type in ["scatter", "echarts_timeseries_scatter"]: + return _generate_ascii_scatter_chart(data, width, height) + else: + logger.debug( + "Unsupported chart type '%s', falling back to table", chart_type + ) + return generate_ascii_table(data, width) + except (TypeError, ValueError, KeyError, IndexError) as e: + logger.error("ASCII chart generation failed: %s", e, exc_info=True) + return "ASCII chart generation failed" + + +def _generate_ascii_bar_chart(data: list[Any], width: int, height: int) -> str: + """Generate enhanced ASCII bar chart with horizontal and vertical options.""" + if not data: + return "No data for bar chart" + + # Extract numeric values for bars + values = [] + labels = [] + + for row in data[:12]: # Increased limit for better charts + if isinstance(row, dict): + # Find numeric and string values + numeric_val = None + label_val = None + + for _key, val in row.items(): + if isinstance(val, (int, float)) and numeric_val is None: + numeric_val = val + elif isinstance(val, str) and label_val is None: + label_val = val + + if numeric_val is not None: + values.append(numeric_val) + labels.append(label_val or f"Item {len(values)}") + + if not values: + return "No numeric data found for bar chart" + + # Decide between horizontal and vertical based on label lengths + avg_label_length = sum(len(str(label)) for label in labels) / len(labels) + use_horizontal = avg_label_length > 8 or len(values) > 8 + + if use_horizontal: + return _generate_horizontal_bar_chart(values, labels, width) + else: + return _generate_vertical_bar_chart(values, labels, width, height) + + +def _generate_horizontal_bar_chart( + values: list[float], labels: list[str], width: int +) -> str: + """Generate horizontal ASCII bar chart.""" + lines = [] + lines.append("šŸ“Š Horizontal Bar Chart") + lines.append("═" * min(width, 60)) + + max_val = max(values) if values else 1 + min_val = min(values) if values else 0 + max_bar_width = min(40, width - 20) # Leave space for labels and values + + # Add scale indicator + lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}") + lines.append("") + + for label, value in zip(labels, values, strict=False): + # Calculate bar length + if max_val > min_val: + normalized = (value - min_val) / (max_val - min_val) + bar_length = max(1, int(normalized * max_bar_width)) + else: + bar_length = 1 + + # Create bar with gradient effect + bar = _create_gradient_bar(bar_length, value, max_val) + + # Format value + if abs(value) >= 1000000: + value_str = f"{value / 1000000:.1f}M" + elif abs(value) >= 1000: + value_str = f"{value / 1000:.1f}K" + else: + value_str = f"{value:.1f}" + + # Truncate label if too long + display_label = label[:15] if len(label) > 15 else label + lines.append(f"{display_label:>15} ▐{bar:<{max_bar_width}} {value_str}") + + return "\n".join(lines) + + +def _generate_vertical_bar_chart( # noqa: C901 + values: list[float], labels: list[str], width: int, height: int +) -> str: + """Generate vertical ASCII bar chart.""" + lines = [] + lines.append("šŸ“Š Vertical Bar Chart") + lines.append("═" * min(width, 60)) + + max_val = max(values) if values else 1 + min_val = min(values) if values else 0 + chart_height = min(15, height - 8) # Leave space for title and labels + + # Create the chart grid + grid = [] + for _ in range(chart_height): + grid.append([" "] * len(values)) + + # Fill the bars + for col, value in enumerate(values): + if max_val > min_val: + normalized = (value - min_val) / (max_val - min_val) + bar_height = max(1, int(normalized * chart_height)) + else: + bar_height = 1 + + # Fill from bottom up + for row_idx in range(chart_height - bar_height, chart_height): + if row_idx < len(grid): + # Use different characters for height effect + if row_idx == chart_height - bar_height: + grid[row_idx][col] = "ā–€" # Top of bar + elif row_idx == chart_height - 1: + grid[row_idx][col] = "ā–ˆ" # Bottom of bar + else: + grid[row_idx][col] = "ā–ˆ" # Middle of bar + + # Add Y-axis scale + for i, row_data in enumerate(grid): + y_val = ( + max_val - (i / (chart_height - 1)) * (max_val - min_val) + if chart_height > 1 + else max_val + ) + if abs(y_val) >= 1000: + y_label = f"{y_val:.0f}" + else: + y_label = f"{y_val:.1f}" + lines.append(f"{y_label:>6} ┤ " + "".join(f"{cell:^3}" for cell in row_data)) + + # Add X-axis + lines.append(" ā””" + "───" * len(values)) + + # Add labels + label_line = " " + for label in labels: + short_label = label[:3] if len(label) > 3 else label + label_line += f"{short_label:^3}" + lines.append(label_line) + + return "\n".join(lines) + + +def _create_gradient_bar(length: int, value: float, max_val: float) -> str: + """Create a gradient bar with visual effects.""" + if length <= 0: + return "" + + # Create gradient effect based on value intensity + intensity = value / max_val if max_val > 0 else 0 + + if intensity > 0.8: + # High values - solid bars + return "ā–ˆ" * length + elif intensity > 0.6: + # Medium-high values - mostly solid with some texture + return "ā–ˆ" * (length - 1) + "ā–‰" if length > 1 else "ā–ˆ" + elif intensity > 0.4: + # Medium values - mixed texture + return "ā–Š" * length + elif intensity > 0.2: + # Low-medium values - lighter texture + return "ā–‹" * length + else: + # Low values - lightest texture + return "ā–Œ" * length + + +def _generate_ascii_line_chart(data: list[Any], width: int, height: int) -> str: + """Generate enhanced ASCII line chart with trend analysis.""" + if not data: + return "No data for line chart" + + lines = [] + lines.append("šŸ“ˆ Line Chart with Trend Analysis") + lines.append("═" * min(width, 60)) + + # Extract values and labels for plotting + values, labels = _extract_time_series_data(data) + + if not values: + return "No numeric data found for line chart" + + # Generate enhanced line chart + if len(values) >= 3: + lines.extend(_create_enhanced_line_chart(values, labels, width, height)) + else: + # Fallback to sparkline for small datasets + sparkline_data = _create_sparkline(values) + lines.extend(sparkline_data) + + # Add trend analysis + trend_analysis = _analyze_trend(values) + lines.append("") + lines.append("šŸ“Š Trend Analysis:") + lines.extend(trend_analysis) + + return "\n".join(lines) + + +def _extract_time_series_data(data: list[Any]) -> tuple[list[float], list[str]]: + """Extract time series data with labels.""" + values = [] + labels = [] + + for row in data[:20]: # Limit points for readability + if isinstance(row, dict): + # Find the first numeric value and first string/date value + numeric_val = None + label_val = None + + for key, val in row.items(): + if isinstance(val, (int, float)) and numeric_val is None: + numeric_val = val + elif isinstance(val, str) and label_val is None: + # Use the key name if it looks like a date/time field + if any( + date_word in key.lower() + for date_word in ["date", "time", "month", "day", "year"] + ): + label_val = str(val)[:10] # Truncate long dates + else: + label_val = str(val)[:8] # Truncate long strings + + if numeric_val is not None: + values.append(numeric_val) + labels.append(label_val or f"P{len(values)}") + + return values, labels + + +def _create_enhanced_line_chart( + values: list[float], labels: list[str], width: int, height: int +) -> list[str]: + """Create an enhanced ASCII line chart with better visualization.""" + lines = [] + chart_width = min(50, width - 15) + chart_height = min(12, height - 8) + + if len(values) < 2: + return ["Insufficient data for line chart"] + + # Normalize values to chart height + min_val = min(values) + max_val = max(values) + val_range = max_val - min_val if max_val != min_val else 1 + + # Create chart grid + grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)] + + # Plot the line with connecting segments + prev_x, prev_y = None, None + + for i, value in enumerate(values): + # Map to grid coordinates + x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0 + y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1)) + + # Ensure coordinates are within bounds + x = max(0, min(x, chart_width - 1)) + y = max(0, min(y, chart_height - 1)) + + # Mark the point + grid[y][x] = "ā—" + + # Draw line segment to previous point + if prev_x is not None and prev_y is not None: + _draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height) + + prev_x, prev_y = x, y + + # Render the chart with Y-axis labels + for i, row in enumerate(grid): + y_val = ( + max_val - (i / (chart_height - 1)) * val_range + if chart_height > 1 + else max_val + ) + if abs(y_val) >= 1000: + y_label = f"{y_val:.0f}" + else: + y_label = f"{y_val:.1f}" + lines.append(f"{y_label:>8} ┤ " + "".join(row)) + + # Add X-axis + lines.append(" ā””" + "─" * chart_width) + + # Add selected X-axis labels (show every few labels to avoid crowding) + label_line = " " + step = max(1, len(labels) // 6) # Show max 6 labels + for i in range(0, len(labels), step): + pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0 + # Add spacing to position the label + while len(label_line) - 10 < pos: + label_line += " " + label_line += labels[i][:8] + + lines.append(label_line) + + return lines + + +def _draw_line_segment( + grid: list[list[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int +) -> None: + """Draw a line segment between two points using Bresenham-like algorithm.""" + # Simple line drawing - connect points with line characters + if x1 == x2: # Vertical line + start_y, end_y = sorted([y1, y2]) + for y in range(start_y + 1, end_y): + if 0 <= y < height and 0 <= x1 < width: + grid[y][x1] = "│" + elif y1 == y2: # Horizontal line + start_x, end_x = sorted([x1, x2]) + for x in range(start_x + 1, end_x): + if 0 <= y1 < height and 0 <= x < width: + grid[y1][x] = "─" + else: # Diagonal line - use simple interpolation + steps = max(abs(x2 - x1), abs(y2 - y1)) + for step in range(1, steps): + x = x1 + int((x2 - x1) * step / steps) + y = y1 + int((y2 - y1) * step / steps) + if 0 <= x < width and 0 <= y < height: + if abs(x2 - x1) > abs(y2 - y1): + grid[y][x] = "─" + else: + grid[y][x] = "│" + + +def _analyze_trend(values: list[float]) -> list[str]: + """Analyze trend in the data.""" + if len(values) < 2: + return ["• Insufficient data for trend analysis"] + + analysis = [] + + # Calculate basic statistics + first_val = values[0] + last_val = values[-1] + min_val = min(values) + max_val = max(values) + avg_val = sum(values) / len(values) + + # Overall trend + if last_val > first_val * 1.1: + trend_icon = "šŸ“ˆ" + trend_desc = "Strong upward trend" + elif last_val > first_val * 1.05: + trend_icon = "šŸ“Š" + trend_desc = "Moderate upward trend" + elif last_val < first_val * 0.9: + trend_icon = "šŸ“‰" + trend_desc = "Strong downward trend" + elif last_val < first_val * 0.95: + trend_icon = "šŸ“Š" + trend_desc = "Moderate downward trend" + else: + trend_icon = "āž”ļø" + trend_desc = "Relatively stable" + + analysis.append(f"• {trend_icon} {trend_desc}") + analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})") + + # Volatility + if len(values) >= 3: + changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))] + avg_change = sum(changes) / len(changes) + volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low" + analysis.append(f"• Volatility: {volatility}") + + return analysis + + +def _create_sparkline(values: list[float]) -> list[str]: + """Create sparkline visualization from values.""" + if len(values) <= 1: + return [] + + max_val = max(values) + min_val = min(values) + range_val = max_val - min_val if max_val != min_val else 1 + + sparkline = "" + for val in values: + normalized = (val - min_val) / range_val + if normalized < 0.2: + sparkline += "▁" + elif normalized < 0.4: + sparkline += "ā–‚" + elif normalized < 0.6: + sparkline += "ā–„" + elif normalized < 0.8: + sparkline += "ā–†" + else: + sparkline += "ā–ˆ" + + # Safe formatting to avoid NaN display + if _is_nan_value(min_val) or _is_nan_value(max_val): + return ["Range: Unable to calculate from data", sparkline] + else: + return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline] + + +def _is_nan_value(value: Any) -> bool: + """Check if a value is NaN or invalid.""" + try: + return math.isnan(float(value)) + except (ValueError, TypeError): + return True + + +def _generate_ascii_scatter_chart(data: list[Any], width: int, height: int) -> str: + """Generate ASCII scatter plot.""" + if not data: + return "No data for scatter chart" + + lines = [] + lines.append("ASCII Scatter Plot") + lines.append("=" * min(width, 50)) + + # Extract data points + x_values, y_values, x_column, y_column = _extract_scatter_data(data) + + # Log debug info server-side only + logger.debug( + "Scatter chart: x_column=%s, y_column=%s, valid_pairs=%d", + x_column, + y_column, + len(x_values), + ) + + # Check if we have enough data + if len(x_values) < 2: + return generate_ascii_table(data, width) + + # Add axis info + lines.extend(_create_axis_info(x_values, y_values, x_column, y_column)) + + # Create and render grid + grid = _create_scatter_grid(x_values, y_values, width, height) + lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height)) + + return "\n".join(lines) + + +def _extract_scatter_data( + data: list[Any], +) -> tuple[list[float], list[float], str | None, str | None]: + """Extract X,Y data from scatter chart data.""" + x_values = [] + y_values = [] + x_column = None + y_column = None + numeric_columns = [] + + if data and isinstance(data[0], dict): + # Find the first two numeric columns + for key, val in data[0].items(): + if isinstance(val, (int, float)) and not ( + isinstance(val, float) and (val != val) + ): # Exclude NaN + numeric_columns.append(key) + + if len(numeric_columns) >= 2: + x_column = numeric_columns[0] + y_column = numeric_columns[1] + + # Extract X,Y pairs + for row in data[:50]: # Limit for ASCII display + if isinstance(row, dict): + x_val = row.get(x_column) + y_val = row.get(y_column) + # Check for valid numbers (not NaN) + if ( + isinstance(x_val, (int, float)) + and isinstance(y_val, (int, float)) + and not ( + isinstance(x_val, float) and (x_val != x_val) + ) # Not NaN + and not (isinstance(y_val, float) and (y_val != y_val)) + ): # Not NaN + x_values.append(x_val) + y_values.append(y_val) + + return x_values, y_values, x_column, y_column + + +def _create_axis_info( + x_values: list[float], + y_values: list[float], + x_column: str | None, + y_column: str | None, +) -> list[str]: + """Create axis information lines.""" + return [ + f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})", + f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})", + f"Showing {len(x_values)} data points", + "", + ] + + +def _create_scatter_grid( + x_values: list[float], y_values: list[float], width: int, height: int +) -> list[list[str]]: + """Create and populate the scatter plot grid.""" + plot_width = min(40, width - 10) + plot_height = min(15, height - 8) + + # Normalize values to fit in grid + x_min, x_max = min(x_values), max(x_values) + y_min, y_max = min(y_values), max(y_values) + x_range = x_max - x_min if x_max != x_min else 1 + y_range = y_max - y_min if y_max != y_min else 1 + + # Create grid + grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)] + + # Plot points + for x, y in zip(x_values, y_values, strict=False): + try: + grid_x = int(((x - x_min) / x_range) * (plot_width - 1)) + grid_y = int(((y - y_min) / y_range) * (plot_height - 1)) + except (ValueError, OverflowError): + # Skip points that can't be converted to integers (NaN, inf, etc.) + continue + grid_y = plot_height - 1 - grid_y # Flip Y axis for display + + if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height: + if grid[grid_y][grid_x] == " ": + grid[grid_y][grid_x] = "•" + else: + grid[grid_y][grid_x] = "ā–ˆ" # Multiple points + + return grid + + +def _render_scatter_grid( + grid: list[list[str]], + x_values: list[float], + y_values: list[float], + width: int, + height: int, +) -> list[str]: + """Render the scatter plot grid with axes and labels.""" + lines = [] + plot_width = min(40, width - 10) + plot_height = min(15, height - 8) + + x_min, x_max = min(x_values), max(x_values) + y_min, y_max = min(y_values), max(y_values) + y_range = y_max - y_min if y_max != y_min else 1 + + # Add Y-axis labels and plot + for i, row in enumerate(grid): + y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max + y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}" + lines.append(f"{y_label:>6} |{''.join(row)}") + + # Add X-axis + x_axis_line = " " * 7 + "+" + "-" * plot_width + lines.append(x_axis_line) + + # Add X-axis labels + x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}" + x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}" + x_labels = ( + " " * 8 + + x_left_label + + " " * (plot_width - len(x_left_label) - len(x_right_label)) + + x_right_label + ) + lines.append(x_labels) + + return lines + + +def generate_ascii_table(data: list[Any], width: int) -> str: + """Generate enhanced ASCII table with better formatting.""" + if not data: + return "No data for table" + + lines = [] + lines.append("šŸ“‹ Data Table") + lines.append("═" * min(width, 70)) + + # Get column headers from first row + if isinstance(data[0], dict): + # Select best columns to display + all_headers = list(data[0].keys()) + headers = _select_display_columns(all_headers, data, max_cols=6) + + # Calculate optimal column widths + col_widths = _calculate_column_widths(headers, data, width) + + # Create enhanced header + lines.append(_create_table_header(headers, col_widths)) + lines.append(_create_table_separator(col_widths)) + + # Add data rows with better formatting + row_count = min(15, len(data)) # Show more rows + for i, row in enumerate(data[:row_count]): + formatted_row = _format_table_row(row, headers, col_widths) + lines.append(formatted_row) + + # Add separator every 5 rows for readability + if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1: + lines.append(_create_light_separator(col_widths)) + + # Add footer with stats + lines.append(_create_table_separator(col_widths)) + lines.append(f"šŸ“Š Showing {row_count} of {len(data)} rows") + + # Add column summaries for numeric columns + numeric_summaries = _create_numeric_summaries(data, headers) + if numeric_summaries: + lines.append("") + lines.append("šŸ“ˆ Numeric Summaries:") + lines.extend(numeric_summaries) + + return "\n".join(lines) + + +def _select_display_columns( + all_headers: list[str], data: list[Any], max_cols: int = 6 +) -> list[str]: + """Select the most interesting columns to display.""" + if len(all_headers) <= max_cols: + return all_headers + + # Prioritize columns by interest level + priority_scores = {} + + for header in all_headers: + score = 0 + header_lower = header.lower() + + # Higher priority for common business fields + if any(word in header_lower for word in ["name", "title", "id"]): + score += 10 + if any( + word in header_lower + for word in ["amount", "price", "cost", "revenue", "sales"] + ): + score += 8 + if any(word in header_lower for word in ["date", "time", "created", "updated"]): + score += 6 + if any(word in header_lower for word in ["count", "total", "sum", "avg"]): + score += 5 + + # Check data variety (more variety = more interesting) + sample_values = [ + str(row.get(header, "")) for row in data[:10] if isinstance(row, dict) + ] + unique_values = len(set(sample_values)) + if unique_values > 1: + score += min(unique_values, 5) + + priority_scores[header] = score + + # Return top scoring columns + sorted_headers = sorted( + all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True + ) + return sorted_headers[:max_cols] + + +def _calculate_column_widths( + headers: list[str], data: list[Any], total_width: int +) -> list[int]: + """Calculate optimal column widths based on content.""" + if not headers: + return [] + + # Start with minimum widths based on header lengths + min_widths = [max(8, min(len(h) + 2, 15)) for h in headers] + + # Check actual data to adjust widths + for row in data[:10]: # Sample first 10 rows + if isinstance(row, dict): + for i, header in enumerate(headers): + val = row.get(header, "") + if isinstance(val, float): + content_len = len(f"{val:.2f}") + elif isinstance(val, int): + content_len = len(str(val)) + else: + content_len = len(str(val)) + + min_widths[i] = max(min_widths[i], min(content_len + 1, 20)) + + # Distribute remaining space proportionally + used_width = sum(min_widths) + len(headers) * 3 # Account for separators + available_width = min(total_width - 10, 80) # Leave margins + + if used_width < available_width: + # Distribute extra space + extra_space = available_width - used_width + for i in range(len(min_widths)): + min_widths[i] += extra_space // len(min_widths) + + return min_widths + + +def _create_table_header(headers: list[str], widths: list[int]) -> str: + """Create formatted table header.""" + formatted_headers = [] + for header, width in zip(headers, widths, strict=False): + # Truncate and center header + display_header = header[: width - 2] if len(header) > width - 2 else header + formatted_headers.append(f"{display_header:^{width}}") + + return ( + "ā”Œ" + + "┬".join("─" * w for w in widths) + + "┐\n│" + + "│".join(formatted_headers) + + "│" + ) + + +def _create_table_separator(widths: list[int]) -> str: + """Create table separator line.""" + return "ā”œ" + "┼".join("─" * w for w in widths) + "┤" + + +def _create_light_separator(widths: list[int]) -> str: + """Create light separator line.""" + return "ā”œ" + "┼".join("ā”ˆ" * w for w in widths) + "┤" + + +def _format_table_row( + row: dict[str, Any], headers: list[str], widths: list[int] +) -> str: + """Format a single table row.""" + formatted_values = [] + + for header, width in zip(headers, widths, strict=False): + val = row.get(header, "") + + # Format value based on type + if isinstance(val, float): + if abs(val) >= 1000000: + formatted_val = f"{val / 1000000:.1f}M" + elif abs(val) >= 1000: + formatted_val = f"{val / 1000:.1f}K" + else: + formatted_val = f"{val:.2f}" + elif isinstance(val, int): + if abs(val) >= 1000000: + formatted_val = f"{val // 1000000}M" + elif abs(val) >= 1000: + formatted_val = f"{val // 1000}K" + else: + formatted_val = str(val) + else: + formatted_val = str(val) + + # Truncate if too long + if len(formatted_val) > width - 2: + formatted_val = formatted_val[: width - 5] + "..." + + # Align numbers right, text left + if isinstance(val, (int, float)): + formatted_values.append(f"{formatted_val:>{width - 2}}") + else: + formatted_values.append(f"{formatted_val:<{width - 2}}") + + return "│ " + " │ ".join(formatted_values) + " │" + + +def _create_numeric_summaries(data: list[Any], headers: list[str]) -> list[str]: + """Create summaries for numeric columns.""" + summaries = [] + + for header in headers: + # Collect numeric values + values = [] + for row in data: + if isinstance(row, dict): + val = row.get(header) + if isinstance(val, (int, float)): + values.append(val) + + if len(values) >= 2: + min_val = min(values) + max_val = max(values) + avg_val = sum(values) / len(values) + + if abs(avg_val) >= 1000: + avg_str = f"{avg_val / 1000:.1f}K" + else: + avg_str = f"{avg_val:.1f}" + + summaries.append( + f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}" + ) + + return summaries diff --git a/superset/mcp_service/chart/chart_helpers.py b/superset/mcp_service/chart/chart_helpers.py new file mode 100644 index 00000000000..05477e76eee --- /dev/null +++ b/superset/mcp_service/chart/chart_helpers.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Shared helper functions for MCP chart tools. + +This module contains reusable utility functions for common operations +across chart tools: chart lookup, cached form data retrieval, and +URL parameter extraction. Config mapping logic lives in chart_utils.py. +""" + +from __future__ import annotations + +import logging +from typing import TYPE_CHECKING +from urllib.parse import parse_qs, urlparse + +if TYPE_CHECKING: + from superset.models.slice import Slice + +logger = logging.getLogger(__name__) + + +def find_chart_by_identifier(identifier: int | str) -> Slice | None: + """Find a chart by numeric ID or UUID string. + + Accepts an integer ID, a string that looks like a digit (e.g. "123"), + or a UUID string. Returns the Slice model instance or None. + """ + from superset.daos.chart import ChartDAO # avoid circular import + + if isinstance(identifier, int) or ( + isinstance(identifier, str) and identifier.isdigit() + ): + chart_id = int(identifier) if isinstance(identifier, str) else identifier + return ChartDAO.find_by_id(chart_id) + return ChartDAO.find_by_id(identifier, id_column="uuid") + + +def get_cached_form_data(form_data_key: str) -> str | None: + """Retrieve form_data from cache using form_data_key. + + Returns the JSON string of form_data if found, None otherwise. + """ + # avoid circular import — commands depend on app initialization + from superset.commands.exceptions import CommandException + from superset.commands.explore.form_data.get import GetFormDataCommand + from superset.commands.explore.form_data.parameters import CommandParameters + + try: + cmd_params = CommandParameters(key=form_data_key) + return GetFormDataCommand(cmd_params).run() + except (KeyError, ValueError, CommandException) as e: + logger.warning("Failed to retrieve form_data from cache: %s", e) + return None + + +def extract_form_data_key_from_url(url: str | None) -> str | None: + """Extract the form_data_key query parameter from an explore URL. + + Returns the form_data_key value or None if not found or URL is empty. + """ + if not url: + return None + parsed = urlparse(url) + values = parse_qs(parsed.query).get("form_data_key", []) + return values[0] if values else None diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 494af9b5614..7ed0a5cc6b0 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -22,7 +22,6 @@ import logging import time from dataclasses import dataclass, field from typing import Any, Dict, List -from urllib.parse import parse_qs, urlparse from fastmcp import Context from sqlalchemy.exc import SQLAlchemyError @@ -32,6 +31,7 @@ from superset.commands.exceptions import CommandException from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger from superset.mcp_service.auth import has_dataset_access +from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, analyze_chart_semantics, @@ -540,13 +540,8 @@ async def generate_chart( # noqa: C901 explore_url = generate_explore_link(request.dataset_id, form_data) await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,)) - # Extract form_data_key from the explore URL using proper URL parsing - if explore_url: - parsed = urlparse(explore_url) - query_params = parse_qs(parsed.query) - form_data_key_list = query_params.get("form_data_key", []) - if form_data_key_list: - form_data_key = form_data_key_list[0] + # Extract form_data_key from the explore URL + form_data_key = extract_form_data_key_from_url(explore_url) # Compile check for preview-only mode # Validate dataset existence and user access before running queries diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index 7ee7eefee02..d901b23e638 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -25,15 +25,19 @@ from typing import Any, Dict, List, TYPE_CHECKING from fastmcp import Context from flask import current_app +from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations if TYPE_CHECKING: from superset.models.slice import Slice from superset.commands.exceptions import CommandException -from superset.commands.explore.form_data.parameters import CommandParameters from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger +from superset.mcp_service.chart.chart_helpers import ( + find_chart_by_identifier, + get_cached_form_data, +) from superset.mcp_service.chart.chart_utils import validate_chart_dataset from superset.mcp_service.chart.schemas import ( ChartData, @@ -62,21 +66,6 @@ def _apply_extra_form_data( merge_extra_filters(form_data) -def _get_cached_form_data(form_data_key: str) -> str | None: - """Retrieve form_data from cache using form_data_key. - - Returns the JSON string of form_data if found, None otherwise. - """ - from superset.commands.explore.form_data.get import GetFormDataCommand - - try: - cmd_params = CommandParameters(key=form_data_key) - return GetFormDataCommand(cmd_params).run() - except (KeyError, ValueError, CommandException) as e: - logger.warning("Failed to retrieve form_data from cache: %s", e) - return None - - @tool( tags=["data"], class_permission_name="Chart", @@ -127,7 +116,6 @@ async def get_chart_data( # noqa: C901 try: await ctx.report_progress(1, 4, "Looking up chart") - from superset.daos.chart import ChartDAO from superset.utils import json as utils_json chart = None @@ -141,7 +129,7 @@ async def get_chart_data( # noqa: C901 "No chart identifier - querying data from unsaved chart cache: " "form_data_key=%s" % (request.form_data_key,) ) - cached_form_data = _get_cached_form_data(request.form_data_key) + cached_form_data = get_cached_form_data(request.form_data_key) if not cached_form_data: return ChartError( error="No cached chart data found for form_data_key. " @@ -166,25 +154,13 @@ async def get_chart_data( # noqa: C901 # Find the chart by identifier with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"): - if isinstance(request.identifier, int) or ( - isinstance(request.identifier, str) and request.identifier.isdigit() - ): - chart_id = ( - int(request.identifier) - if isinstance(request.identifier, str) - else request.identifier + await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,)) + if request.identifier is None: + return ChartError( + error="Chart identifier is required", + error_type="ValidationError", ) - await ctx.debug( - "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) - ) - chart = ChartDAO.find_by_id(chart_id) - elif isinstance(request.identifier, str): - await ctx.debug( - "Performing UUID-based chart lookup: uuid=%s" - % (request.identifier,) - ) - # Try UUID lookup using DAO flexible method - chart = ChartDAO.find_by_id(request.identifier, id_column="uuid") + chart = find_chart_by_identifier(request.identifier) if not chart: await ctx.error("Chart not found: identifier=%s" % (request.identifier,)) @@ -239,7 +215,7 @@ async def get_chart_data( # noqa: C901 "Retrieving unsaved chart state from cache: form_data_key=%s" % (request.form_data_key,) ) - if cached_form_data := _get_cached_form_data(request.form_data_key): + if cached_form_data := get_cached_form_data(request.form_data_key): try: parsed_form_data = utils_json.loads(cached_form_data) # Only use if it's actually a dict (not null, list, etc.) @@ -818,7 +794,15 @@ async def get_chart_data( # noqa: C901 error=OAUTH2_CONFIG_ERROR_MESSAGE, error_type="OAUTH2_REDIRECT_ERROR", ) - except Exception as e: + except ( + SupersetException, + CommandException, + SQLAlchemyError, + KeyError, + ValueError, + TypeError, + AttributeError, + ) as e: await ctx.error( "Chart data retrieval failed: identifier=%s, error=%s, error_type=%s" % ( diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index 0477d28f039..fb8720a1c1b 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -25,9 +25,8 @@ from fastmcp import Context from sqlalchemy.orm import subqueryload from superset_core.mcp.decorators import tool, ToolAnnotations -from superset.commands.exceptions import CommandException -from superset.commands.explore.form_data.parameters import CommandParameters from superset.extensions import event_logger +from superset.mcp_service.chart.chart_helpers import get_cached_form_data from superset.mcp_service.chart.chart_utils import validate_chart_dataset from superset.mcp_service.chart.schemas import ( ChartError, @@ -41,26 +40,11 @@ from superset.mcp_service.mcp_core import ModelGetInfoCore logger = logging.getLogger(__name__) -def _get_cached_form_data(form_data_key: str) -> str | None: - """Retrieve form_data from cache using form_data_key. - - Returns the JSON string of form_data if found, None otherwise. - """ - from superset.commands.explore.form_data.get import GetFormDataCommand - - try: - cmd_params = CommandParameters(key=form_data_key) - return GetFormDataCommand(cmd_params).run() - except (KeyError, ValueError, CommandException) as e: - logger.warning("Failed to retrieve form_data from cache: %s", e) - return None - - def _build_unsaved_chart_info(form_data_key: str) -> ChartInfo | ChartError: """Build a ChartInfo from cached form_data when no chart identifier exists.""" from superset.utils import json as utils_json - cached_form_data = _get_cached_form_data(form_data_key) + cached_form_data = get_cached_form_data(form_data_key) if not cached_form_data: return ChartError( error="No cached chart data found for form_data_key. " @@ -94,7 +78,7 @@ def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None """Override a ChartInfo's form_data with cached unsaved state.""" from superset.utils import json as utils_json - if cached_form_data := _get_cached_form_data(form_data_key): + if cached_form_data := get_cached_form_data(form_data_key): try: result.form_data = utils_json.loads(cached_form_data) result.form_data_key = form_data_key diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index 77072a5943c..dbffd152ea5 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -23,11 +23,17 @@ import logging from typing import Any, Dict, List, Protocol from fastmcp import Context +from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger +from superset.mcp_service.chart.ascii_charts import ( + generate_ascii_chart, + generate_ascii_table, +) +from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier from superset.mcp_service.chart.chart_utils import validate_chart_dataset from superset.mcp_service.chart.schemas import ( AccessibilityMetadata, @@ -262,7 +268,7 @@ class TablePreviewStrategy(PreviewFormatStrategy): if result and "queries" in result and len(result["queries"]) > 0: data = result["queries"][0].get("data", []) - table_data = _generate_ascii_table(data, 120) + table_data = generate_ascii_table(data, 120) return TablePreview( table_data=table_data, @@ -964,874 +970,6 @@ class PreviewFormatGenerator: return strategy.generate() -def generate_ascii_chart( - data: List[Any], chart_type: str, width: int = 80, height: int = 20 -) -> str: - """Generate ASCII art chart from data.""" - if not data or len(data) == 0: - return "No data available for ASCII chart" - - try: - logger.info( - "generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data) - ) - - # Generate appropriate ASCII chart based on type - if chart_type in ["bar", "column", "echarts_timeseries_bar"]: - logger.info("Generating bar chart") - return _generate_ascii_bar_chart(data, width, height) - elif chart_type in ["line", "echarts_timeseries_line"]: - logger.info("Generating line chart") - return _generate_ascii_line_chart(data, width, height) - elif chart_type in ["scatter", "echarts_timeseries_scatter"]: - logger.info("Generating scatter chart") - return _generate_ascii_scatter_chart(data, width, height) - else: - # Default to table format for unsupported chart types - logger.info( - "Unsupported chart type '%s', falling back to table", chart_type - ) - return _generate_ascii_table(data, width) - except (TypeError, ValueError, KeyError, IndexError) as e: - logger.error("ASCII chart generation failed: %s", e) - import traceback - - logger.error("Traceback: %s", traceback.format_exc()) - return f"ASCII chart generation failed: {str(e)}" - - -def _generate_ascii_bar_chart(data: List[Any], width: int, height: int) -> str: - """Generate enhanced ASCII bar chart with horizontal and vertical options.""" - if not data: - return "No data for bar chart" - - # Extract numeric values for bars - values = [] - labels = [] - - for row in data[:12]: # Increased limit for better charts - if isinstance(row, dict): - # Find numeric and string values - numeric_val = None - label_val = None - - for _key, val in row.items(): - if isinstance(val, (int, float)) and numeric_val is None: - numeric_val = val - elif isinstance(val, str) and label_val is None: - label_val = val - - if numeric_val is not None: - values.append(numeric_val) - labels.append(label_val or f"Item {len(values)}") - - if not values: - return "No numeric data found for bar chart" - - # Decide between horizontal and vertical based on label lengths - avg_label_length = sum(len(str(label)) for label in labels) / len(labels) - use_horizontal = avg_label_length > 8 or len(values) > 8 - - if use_horizontal: - return _generate_horizontal_bar_chart(values, labels, width) - else: - return _generate_vertical_bar_chart(values, labels, width, height) - - -def _generate_horizontal_bar_chart( - values: List[float], labels: List[str], width: int -) -> str: - """Generate horizontal ASCII bar chart.""" - lines = [] - lines.append("šŸ“Š Horizontal Bar Chart") - lines.append("═" * min(width, 60)) - - max_val = max(values) if values else 1 - min_val = min(values) if values else 0 - max_bar_width = min(40, width - 20) # Leave space for labels and values - - # Add scale indicator - lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}") - lines.append("") - - for label, value in zip(labels, values, strict=False): - # Calculate bar length - if max_val > min_val: - normalized = (value - min_val) / (max_val - min_val) - bar_length = max(1, int(normalized * max_bar_width)) - else: - bar_length = 1 - - # Create bar with gradient effect - bar = _create_gradient_bar(bar_length, value, max_val) - - # Format value - if abs(value) >= 1000000: - value_str = f"{value / 1000000:.1f}M" - elif abs(value) >= 1000: - value_str = f"{value / 1000:.1f}K" - else: - value_str = f"{value:.1f}" - - # Truncate label if too long - display_label = label[:15] if len(label) > 15 else label - lines.append(f"{display_label:>15} ▐{bar:<{max_bar_width}} {value_str}") - - return "\n".join(lines) - - -def _generate_vertical_bar_chart( # noqa: C901 - values: List[float], labels: List[str], width: int, height: int -) -> str: - """Generate vertical ASCII bar chart.""" - lines = [] - lines.append("šŸ“Š Vertical Bar Chart") - lines.append("═" * min(width, 60)) - - max_val = max(values) if values else 1 - min_val = min(values) if values else 0 - chart_height = min(15, height - 8) # Leave space for title and labels - - # Create the chart grid - grid = [] - for _ in range(chart_height): - grid.append([" "] * len(values)) - - # Fill the bars - for col, value in enumerate(values): - if max_val > min_val: - normalized = (value - min_val) / (max_val - min_val) - bar_height = max(1, int(normalized * chart_height)) - else: - bar_height = 1 - - # Fill from bottom up - for row_idx in range(chart_height - bar_height, chart_height): - if row_idx < len(grid): - # Use different characters for height effect - if row_idx == chart_height - bar_height: - grid[row_idx][col] = "ā–€" # Top of bar - elif row_idx == chart_height - 1: - grid[row_idx][col] = "ā–ˆ" # Bottom of bar - else: - grid[row_idx][col] = "ā–ˆ" # Middle of bar - - # Add Y-axis scale - for i, row_data in enumerate(grid): - y_val = ( - max_val - (i / (chart_height - 1)) * (max_val - min_val) - if chart_height > 1 - else max_val - ) - if abs(y_val) >= 1000: - y_label = f"{y_val:.0f}" - else: - y_label = f"{y_val:.1f}" - lines.append(f"{y_label:>6} ┤ " + "".join(f"{cell:^3}" for cell in row_data)) - - # Add X-axis - lines.append(" ā””" + "───" * len(values)) - - # Add labels - label_line = " " - for label in labels: - short_label = label[:3] if len(label) > 3 else label - label_line += f"{short_label:^3}" - lines.append(label_line) - - return "\n".join(lines) - - -def _create_gradient_bar(length: int, value: float, max_val: float) -> str: - """Create a gradient bar with visual effects.""" - if length <= 0: - return "" - - # Create gradient effect based on value intensity - intensity = value / max_val if max_val > 0 else 0 - - if intensity > 0.8: - # High values - solid bars - return "ā–ˆ" * length - elif intensity > 0.6: - # Medium-high values - mostly solid with some texture - return "ā–ˆ" * (length - 1) + "ā–‰" if length > 1 else "ā–ˆ" - elif intensity > 0.4: - # Medium values - mixed texture - return "ā–Š" * length - elif intensity > 0.2: - # Low-medium values - lighter texture - return "ā–‹" * length - else: - # Low values - lightest texture - return "ā–Œ" * length - - -def _generate_ascii_line_chart(data: List[Any], width: int, height: int) -> str: - """Generate enhanced ASCII line chart with trend analysis.""" - if not data: - return "No data for line chart" - - lines = [] - lines.append("šŸ“ˆ Line Chart with Trend Analysis") - lines.append("═" * min(width, 60)) - - # Extract values and labels for plotting - values, labels = _extract_time_series_data(data) - - if not values: - return "No numeric data found for line chart" - - # Generate enhanced line chart - if len(values) >= 3: - lines.extend(_create_enhanced_line_chart(values, labels, width, height)) - else: - # Fallback to sparkline for small datasets - sparkline_data = _create_sparkline(values) - lines.extend(sparkline_data) - - # Add trend analysis - trend_analysis = _analyze_trend(values) - lines.append("") - lines.append("šŸ“Š Trend Analysis:") - lines.extend(trend_analysis) - - return "\n".join(lines) - - -def _extract_time_series_data(data: List[Any]) -> tuple[List[float], List[str]]: - """Extract time series data with labels.""" - values = [] - labels = [] - - for row in data[:20]: # Limit points for readability - if isinstance(row, dict): - # Find the first numeric value and first string/date value - numeric_val = None - label_val = None - - for key, val in row.items(): - if isinstance(val, (int, float)) and numeric_val is None: - numeric_val = val - elif isinstance(val, str) and label_val is None: - # Use the key name if it looks like a date/time field - if any( - date_word in key.lower() - for date_word in ["date", "time", "month", "day", "year"] - ): - label_val = str(val)[:10] # Truncate long dates - else: - label_val = str(val)[:8] # Truncate long strings - - if numeric_val is not None: - values.append(numeric_val) - labels.append(label_val or f"P{len(values)}") - - return values, labels - - -def _create_enhanced_line_chart( - values: List[float], labels: List[str], width: int, height: int -) -> List[str]: - """Create an enhanced ASCII line chart with better visualization.""" - lines = [] - chart_width = min(50, width - 15) - chart_height = min(12, height - 8) - - if len(values) < 2: - return ["Insufficient data for line chart"] - - # Normalize values to chart height - min_val = min(values) - max_val = max(values) - val_range = max_val - min_val if max_val != min_val else 1 - - # Create chart grid - grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)] - - # Plot the line with connecting segments - prev_x, prev_y = None, None - - for i, value in enumerate(values): - # Map to grid coordinates - x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0 - y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1)) - - # Ensure coordinates are within bounds - x = max(0, min(x, chart_width - 1)) - y = max(0, min(y, chart_height - 1)) - - # Mark the point - grid[y][x] = "ā—" - - # Draw line segment to previous point - if prev_x is not None and prev_y is not None: - _draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height) - - prev_x, prev_y = x, y - - # Render the chart with Y-axis labels - for i, row in enumerate(grid): - y_val = ( - max_val - (i / (chart_height - 1)) * val_range - if chart_height > 1 - else max_val - ) - if abs(y_val) >= 1000: - y_label = f"{y_val:.0f}" - else: - y_label = f"{y_val:.1f}" - lines.append(f"{y_label:>8} ┤ " + "".join(row)) - - # Add X-axis - lines.append(" ā””" + "─" * chart_width) - - # Add selected X-axis labels (show every few labels to avoid crowding) - label_line = " " - step = max(1, len(labels) // 6) # Show max 6 labels - for i in range(0, len(labels), step): - pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0 - # Add spacing to position the label - while len(label_line) - 10 < pos: - label_line += " " - label_line += labels[i][:8] - - lines.append(label_line) - - return lines - - -def _draw_line_segment( - grid: List[List[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int -) -> None: - """Draw a line segment between two points using Bresenham-like algorithm.""" - # Simple line drawing - connect points with line characters - if x1 == x2: # Vertical line - start_y, end_y = sorted([y1, y2]) - for y in range(start_y + 1, end_y): - if 0 <= y < height and 0 <= x1 < width: - grid[y][x1] = "│" - elif y1 == y2: # Horizontal line - start_x, end_x = sorted([x1, x2]) - for x in range(start_x + 1, end_x): - if 0 <= y1 < height and 0 <= x < width: - grid[y1][x] = "─" - else: # Diagonal line - use simple interpolation - steps = max(abs(x2 - x1), abs(y2 - y1)) - for step in range(1, steps): - x = x1 + int((x2 - x1) * step / steps) - y = y1 + int((y2 - y1) * step / steps) - if 0 <= x < width and 0 <= y < height: - if abs(x2 - x1) > abs(y2 - y1): - grid[y][x] = "─" - else: - grid[y][x] = "│" - - -def _analyze_trend(values: List[float]) -> List[str]: - """Analyze trend in the data.""" - if len(values) < 2: - return ["• Insufficient data for trend analysis"] - - analysis = [] - - # Calculate basic statistics - first_val = values[0] - last_val = values[-1] - min_val = min(values) - max_val = max(values) - avg_val = sum(values) / len(values) - - # Overall trend - if last_val > first_val * 1.1: - trend_icon = "šŸ“ˆ" - trend_desc = "Strong upward trend" - elif last_val > first_val * 1.05: - trend_icon = "šŸ“Š" - trend_desc = "Moderate upward trend" - elif last_val < first_val * 0.9: - trend_icon = "šŸ“‰" - trend_desc = "Strong downward trend" - elif last_val < first_val * 0.95: - trend_icon = "šŸ“Š" - trend_desc = "Moderate downward trend" - else: - trend_icon = "āž”ļø" - trend_desc = "Relatively stable" - - analysis.append(f"• {trend_icon} {trend_desc}") - analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})") - - # Volatility - if len(values) >= 3: - changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))] - avg_change = sum(changes) / len(changes) - volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low" - analysis.append(f"• Volatility: {volatility}") - - return analysis - - -def _extract_numeric_values(data: List[Any]) -> List[float]: - """Extract numeric values from data for line chart.""" - values = [] - for row in data[:20]: # Limit points - if isinstance(row, dict): - for _key, val in row.items(): - if isinstance(val, (int, float)): - values.append(val) - break - return values - - -def _create_sparkline(values: List[float]) -> List[str]: - """Create sparkline visualization from values.""" - if len(values) <= 1: - return [] - - max_val = max(values) - min_val = min(values) - range_val = max_val - min_val if max_val != min_val else 1 - - sparkline = "" - for val in values: - normalized = (val - min_val) / range_val - if normalized < 0.2: - sparkline += "▁" - elif normalized < 0.4: - sparkline += "ā–‚" - elif normalized < 0.6: - sparkline += "ā–„" - elif normalized < 0.8: - sparkline += "ā–†" - else: - sparkline += "ā–ˆ" - - # Safe formatting to avoid NaN display - if _is_nan_value(min_val) or _is_nan_value(max_val): - return ["Range: Unable to calculate from data", sparkline] - else: - return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline] - - -def _is_nan_value(value: Any) -> bool: - """Check if a value is NaN or invalid.""" - try: - import math - - return math.isnan(float(value)) - except (ValueError, TypeError): - return True - - -def _generate_ascii_scatter_chart(data: List[Any], width: int, height: int) -> str: - """Generate ASCII scatter plot.""" - if not data: - return "No data for scatter chart" - - lines = [] - lines.append("ASCII Scatter Plot") - lines.append("=" * min(width, 50)) - - # Extract data points - x_values, y_values, x_column, y_column = _extract_scatter_data(data) - - # Debug info - lines.extend(_create_debug_info(x_values, y_values, x_column, y_column, data)) - - # Check if we have enough data - if len(x_values) < 2: - return _generate_ascii_table(data, width) - - # Add axis info - lines.extend(_create_axis_info(x_values, y_values, x_column, y_column)) - - # Create and render grid - grid = _create_scatter_grid(x_values, y_values, width, height) - lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height)) - - return "\n".join(lines) - - -def _extract_scatter_data( - data: List[Any], -) -> tuple[List[float], List[float], str | None, str | None]: - """Extract X,Y data from scatter chart data.""" - x_values = [] - y_values = [] - x_column = None - y_column = None - numeric_columns = [] - - if data and isinstance(data[0], dict): - # Find the first two numeric columns - for key, val in data[0].items(): - if isinstance(val, (int, float)) and not ( - isinstance(val, float) and (val != val) - ): # Exclude NaN - numeric_columns.append(key) - - if len(numeric_columns) >= 2: - x_column = numeric_columns[0] - y_column = numeric_columns[1] - - # Extract X,Y pairs - for row in data[:50]: # Limit for ASCII display - if isinstance(row, dict): - x_val = row.get(x_column) - y_val = row.get(y_column) - # Check for valid numbers (not NaN) - if ( - isinstance(x_val, (int, float)) - and isinstance(y_val, (int, float)) - and not ( - isinstance(x_val, float) and (x_val != x_val) - ) # Not NaN - and not (isinstance(y_val, float) and (y_val != y_val)) - ): # Not NaN - x_values.append(x_val) - y_values.append(y_val) - - return x_values, y_values, x_column, y_column - - -def _create_debug_info( - x_values: List[float], - y_values: List[float], - x_column: str | None, - y_column: str | None, - data: List[Any], -) -> List[str]: - """Create debug information lines for scatter chart.""" - numeric_columns = [] - if data and isinstance(data[0], dict): - for key, val in data[0].items(): - if isinstance(val, (int, float)) and not ( - isinstance(val, float) and (val != val) - ): - numeric_columns.append(key) - - return [ - f"DEBUG: Found {len(numeric_columns)} numeric columns: {numeric_columns}", - f"DEBUG: X column: {x_column}, Y column: {y_column}", - f"DEBUG: Valid X,Y pairs: {len(x_values)}", - ] - - -def _create_axis_info( - x_values: List[float], - y_values: List[float], - x_column: str | None, - y_column: str | None, -) -> List[str]: - """Create axis information lines.""" - return [ - f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})", - f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})", - f"Showing {len(x_values)} data points", - "", - ] - - -def _create_scatter_grid( - x_values: List[float], y_values: List[float], width: int, height: int -) -> List[List[str]]: - """Create and populate the scatter plot grid.""" - plot_width = min(40, width - 10) - plot_height = min(15, height - 8) - - # Normalize values to fit in grid - x_min, x_max = min(x_values), max(x_values) - y_min, y_max = min(y_values), max(y_values) - x_range = x_max - x_min if x_max != x_min else 1 - y_range = y_max - y_min if y_max != y_min else 1 - - # Create grid - grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)] - - # Plot points - for x, y in zip(x_values, y_values, strict=False): - try: - grid_x = int(((x - x_min) / x_range) * (plot_width - 1)) - grid_y = int(((y - y_min) / y_range) * (plot_height - 1)) - grid_y = plot_height - 1 - grid_y # Flip Y axis for display - - if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height: - if grid[grid_y][grid_x] == " ": - grid[grid_y][grid_x] = "•" - else: - grid[grid_y][grid_x] = "ā–ˆ" # Multiple points - except (ValueError, OverflowError): - # Skip points that can't be converted to integers (NaN, inf, etc.) - continue - - return grid - - -def _render_scatter_grid( - grid: List[List[str]], - x_values: List[float], - y_values: List[float], - width: int, - height: int, -) -> List[str]: - """Render the scatter plot grid with axes and labels.""" - lines = [] - plot_width = min(40, width - 10) - plot_height = min(15, height - 8) - - x_min, x_max = min(x_values), max(x_values) - y_min, y_max = min(y_values), max(y_values) - y_range = y_max - y_min if y_max != y_min else 1 - - # Add Y-axis labels and plot - for i, row in enumerate(grid): - y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max - y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}" - lines.append(f"{y_label:>6} |{''.join(row)}") - - # Add X-axis - x_axis_line = " " * 7 + "+" + "-" * plot_width - lines.append(x_axis_line) - - # Add X-axis labels - x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}" - x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}" - x_labels = ( - " " * 8 - + x_left_label - + " " * (plot_width - len(x_left_label) - len(x_right_label)) - + x_right_label - ) - lines.append(x_labels) - - return lines - - -def _generate_ascii_table(data: List[Any], width: int) -> str: - """Generate enhanced ASCII table with better formatting.""" - if not data: - return "No data for table" - - lines = [] - lines.append("šŸ“‹ Data Table") - lines.append("═" * min(width, 70)) - - # Get column headers from first row - if isinstance(data[0], dict): - # Select best columns to display - all_headers = list(data[0].keys()) - headers = _select_display_columns(all_headers, data, max_cols=6) - - # Calculate optimal column widths - col_widths = _calculate_column_widths(headers, data, width) - - # Create enhanced header - lines.append(_create_table_header(headers, col_widths)) - lines.append(_create_table_separator(col_widths)) - - # Add data rows with better formatting - row_count = min(15, len(data)) # Show more rows - for i, row in enumerate(data[:row_count]): - formatted_row = _format_table_row(row, headers, col_widths) - lines.append(formatted_row) - - # Add separator every 5 rows for readability - if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1: - lines.append(_create_light_separator(col_widths)) - - # Add footer with stats - lines.append(_create_table_separator(col_widths)) - lines.append(f"šŸ“Š Showing {row_count} of {len(data)} rows") - - # Add column summaries for numeric columns - numeric_summaries = _create_numeric_summaries(data, headers) - if numeric_summaries: - lines.append("") - lines.append("šŸ“ˆ Numeric Summaries:") - lines.extend(numeric_summaries) - - return "\n".join(lines) - - -def _select_display_columns( - all_headers: List[str], data: List[Any], max_cols: int = 6 -) -> List[str]: - """Select the most interesting columns to display.""" - if len(all_headers) <= max_cols: - return all_headers - - # Prioritize columns by interest level - priority_scores = {} - - for header in all_headers: - score = 0 - header_lower = header.lower() - - # Higher priority for common business fields - if any(word in header_lower for word in ["name", "title", "id"]): - score += 10 - if any( - word in header_lower - for word in ["amount", "price", "cost", "revenue", "sales"] - ): - score += 8 - if any(word in header_lower for word in ["date", "time", "created", "updated"]): - score += 6 - if any(word in header_lower for word in ["count", "total", "sum", "avg"]): - score += 5 - - # Check data variety (more variety = more interesting) - sample_values = [ - str(row.get(header, "")) for row in data[:10] if isinstance(row, dict) - ] - unique_values = len(set(sample_values)) - if unique_values > 1: - score += min(unique_values, 5) - - priority_scores[header] = score - - # Return top scoring columns - sorted_headers = sorted( - all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True - ) - return sorted_headers[:max_cols] - - -def _calculate_column_widths( - headers: List[str], data: List[Any], total_width: int -) -> List[int]: - """Calculate optimal column widths based on content.""" - if not headers: - return [] - - # Start with minimum widths based on header lengths - min_widths = [max(8, min(len(h) + 2, 15)) for h in headers] - - # Check actual data to adjust widths - for row in data[:10]: # Sample first 10 rows - if isinstance(row, dict): - for i, header in enumerate(headers): - val = row.get(header, "") - if isinstance(val, float): - content_len = len(f"{val:.2f}") - elif isinstance(val, int): - content_len = len(str(val)) - else: - content_len = len(str(val)) - - min_widths[i] = max(min_widths[i], min(content_len + 1, 20)) - - # Distribute remaining space proportionally - used_width = sum(min_widths) + len(headers) * 3 # Account for separators - available_width = min(total_width - 10, 80) # Leave margins - - if used_width < available_width: - # Distribute extra space - extra_space = available_width - used_width - for i in range(len(min_widths)): - min_widths[i] += extra_space // len(min_widths) - - return min_widths - - -def _create_table_header(headers: List[str], widths: List[int]) -> str: - """Create formatted table header.""" - formatted_headers = [] - for header, width in zip(headers, widths, strict=False): - # Truncate and center header - display_header = header[: width - 2] if len(header) > width - 2 else header - formatted_headers.append(f"{display_header:^{width}}") - - return ( - "ā”Œ" - + "┬".join("─" * w for w in widths) - + "┐\n│" - + "│".join(formatted_headers) - + "│" - ) - - -def _create_table_separator(widths: List[int]) -> str: - """Create table separator line.""" - return "ā”œ" + "┼".join("─" * w for w in widths) + "┤" - - -def _create_light_separator(widths: List[int]) -> str: - """Create light separator line.""" - return "ā”œ" + "┼".join("ā”ˆ" * w for w in widths) + "┤" - - -def _format_table_row( - row: Dict[str, Any], headers: List[str], widths: List[int] -) -> str: - """Format a single table row.""" - formatted_values = [] - - for header, width in zip(headers, widths, strict=False): - val = row.get(header, "") - - # Format value based on type - if isinstance(val, float): - if abs(val) >= 1000000: - formatted_val = f"{val / 1000000:.1f}M" - elif abs(val) >= 1000: - formatted_val = f"{val / 1000:.1f}K" - else: - formatted_val = f"{val:.2f}" - elif isinstance(val, int): - if abs(val) >= 1000000: - formatted_val = f"{val // 1000000}M" - elif abs(val) >= 1000: - formatted_val = f"{val // 1000}K" - else: - formatted_val = str(val) - else: - formatted_val = str(val) - - # Truncate if too long - if len(formatted_val) > width - 2: - formatted_val = formatted_val[: width - 5] + "..." - - # Align numbers right, text left - if isinstance(val, (int, float)): - formatted_values.append(f"{formatted_val:>{width - 2}}") - else: - formatted_values.append(f"{formatted_val:<{width - 2}}") - - return "│ " + " │ ".join(formatted_values) + " │" - - -def _create_numeric_summaries(data: List[Any], headers: List[str]) -> List[str]: - """Create summaries for numeric columns.""" - summaries = [] - - for header in headers: - # Collect numeric values - values = [] - for row in data: - if isinstance(row, dict): - val = row.get(header) - if isinstance(val, (int, float)): - values.append(val) - - if len(values) >= 2: - min_val = min(values) - max_val = max(values) - avg_val = sum(values) / len(values) - - if abs(avg_val) >= 1000: - avg_str = f"{avg_val / 1000:.1f}K" - else: - avg_str = f"{avg_val:.1f}" - - summaries.append( - f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}" - ) - - return summaries - - async def _get_chart_preview_internal( # noqa: C901 request: GetChartPreviewRequest, ctx: Context, @@ -1852,7 +990,6 @@ async def _get_chart_preview_internal( # noqa: C901 """ try: await ctx.report_progress(1, 3, "Looking up chart") - from superset.daos.chart import ChartDAO # Find the chart with event_logger.log_context(action="mcp.get_chart_preview.chart_lookup"): @@ -1921,25 +1058,16 @@ async def _get_chart_preview_internal( # noqa: C901 error_type="NotFound", ) - elif isinstance(request.identifier, int) or ( - isinstance(request.identifier, str) and request.identifier.isdigit() - ): - chart_id = ( - int(request.identifier) - if isinstance(request.identifier, str) - else request.identifier - ) + else: await ctx.debug( - "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) + "Looking up chart: identifier=%s" % (request.identifier,) ) - chart = ChartDAO.find_by_id(chart_id) - elif isinstance(request.identifier, str): - await ctx.debug( - "Performing UUID-based chart lookup: uuid=%s" - % (request.identifier,) - ) - # Try UUID lookup using DAO flexible method - chart = ChartDAO.find_by_id(request.identifier, id_column="uuid") + if request.identifier is None: + return ChartError( + error="Chart identifier is required", + error_type="ValidationError", + ) + chart = find_chart_by_identifier(request.identifier) # If not found and looks like a form_data_key, try transient if ( @@ -2268,7 +1396,15 @@ async def get_chart_preview( error=OAUTH2_CONFIG_ERROR_MESSAGE, error_type="OAUTH2_REDIRECT_ERROR", ) - except Exception as e: + except ( + SupersetException, + CommandException, + SQLAlchemyError, + KeyError, + ValueError, + TypeError, + AttributeError, + ) as e: await ctx.error( "Chart preview generation failed: identifier=%s, error=%s, error_type=%s" % ( diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py index b9b5ab49af9..40b9a2f6e85 100644 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -22,7 +22,6 @@ MCP tool: update_chart import logging import time from typing import Any -from urllib.parse import parse_qs, urlparse from fastmcp import Context from sqlalchemy.exc import SQLAlchemyError @@ -31,6 +30,10 @@ from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException from superset.exceptions import OAuth2Error, OAuth2RedirectError from superset.extensions import event_logger +from superset.mcp_service.chart.chart_helpers import ( + extract_form_data_key_from_url, + find_chart_by_identifier, +) from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, analyze_chart_semantics, @@ -54,18 +57,6 @@ from superset.utils import json logger = logging.getLogger(__name__) -def _find_chart(identifier: int | str) -> Any | None: - """Find a chart by numeric ID or UUID string.""" - from superset.daos.chart import ChartDAO - - if isinstance(identifier, int) or ( - isinstance(identifier, str) and identifier.isdigit() - ): - chart_id = int(identifier) if isinstance(identifier, str) else identifier - return ChartDAO.find_by_id(chart_id) - return ChartDAO.find_by_id(identifier, id_column="uuid") - - def _validation_error_response(message: str, details: str) -> GenerateChartResponse: return GenerateChartResponse.model_validate( { @@ -277,7 +268,7 @@ async def update_chart( # noqa: C901 try: with event_logger.log_context(action="mcp.update_chart.chart_lookup"): - chart = _find_chart(request.identifier) + chart = find_chart_by_identifier(request.identifier) if not chart: return GenerateChartResponse.model_validate( @@ -412,15 +403,21 @@ async def update_chart( # noqa: C901 if hasattr(preview_result, "content"): previews[format_type] = preview_result.content - except Exception as e: + except ( + OAuth2RedirectError, + OAuth2Error, + CommandException, + SQLAlchemyError, + KeyError, + ValueError, + TypeError, + AttributeError, + ) as e: logger.warning("Preview generation failed: %s", e) # Fallback: extract form_data_key from explore_url if not set - if form_data_key is None and explore_url and "form_data_key=" in explore_url: - parsed = urlparse(explore_url) - values = parse_qs(parsed.query).get("form_data_key") - if values: - form_data_key = values[0] + if form_data_key is None: + form_data_key = extract_form_data_key_from_url(explore_url) chart_id = updated_chart.id if saved and updated_chart else chart.id chart_uuid = ( diff --git a/superset/mcp_service/chart/tool/update_chart_preview.py b/superset/mcp_service/chart/tool/update_chart_preview.py index ae86d8df1c0..bc613b5a8da 100644 --- a/superset/mcp_service/chart/tool/update_chart_preview.py +++ b/superset/mcp_service/chart/tool/update_chart_preview.py @@ -24,10 +24,13 @@ import time from typing import Any, Dict from fastmcp import Context +from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations -from superset.exceptions import OAuth2Error, OAuth2RedirectError +from superset.commands.exceptions import CommandException +from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger +from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, analyze_chart_semantics, @@ -123,9 +126,19 @@ def update_chart_preview( explore_url = generate_explore_link(request.dataset_id, new_form_data) # Extract new form_data_key from the explore URL - new_form_data_key = None - if "form_data_key=" in explore_url: - new_form_data_key = explore_url.split("form_data_key=")[1].split("&")[0] + new_form_data_key = extract_form_data_key_from_url(explore_url) + if not new_form_data_key: + return { + "chart": None, + "error": { + "error_type": "PreviewError", + "message": "Failed to generate preview: missing form_data_key", + "details": "The explore URL did not contain a form_data_key", + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } with event_logger.log_context(action="mcp.update_chart_preview.metadata"): # Generate semantic analysis @@ -199,7 +212,15 @@ def update_chart_preview( "error": OAUTH2_CONFIG_ERROR_MESSAGE, "success": False, } - except Exception as e: + except ( + SupersetException, + CommandException, + SQLAlchemyError, + KeyError, + ValueError, + TypeError, + AttributeError, + ) as e: execution_time = int((time.time() - start_time) * 1000) return { "chart": None, diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py index 9df6e030e97..40473a9cfe8 100644 --- a/superset/mcp_service/explore/tool/generate_explore_link.py +++ b/superset/mcp_service/explore/tool/generate_explore_link.py @@ -23,12 +23,12 @@ chart configuration. """ from typing import Any, Dict -from urllib.parse import parse_qs, urlparse from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations from superset.extensions import event_logger +from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url from superset.mcp_service.chart.chart_utils import ( generate_explore_link as generate_url, map_config_to_form_data, @@ -175,14 +175,8 @@ async def generate_explore_link( dataset_id=request.dataset_id, form_data=form_data ) - # Extract form_data_key from the explore URL using proper URL parsing - form_data_key = None - if explore_url: - parsed = urlparse(explore_url) - query_params = parse_qs(parsed.query) - form_data_key_list = query_params.get("form_data_key", []) - if form_data_key_list: - form_data_key = form_data_key_list[0] + # Extract form_data_key from the explore URL + form_data_key = extract_form_data_key_from_url(explore_url) await ctx.report_progress(4, 4, "URL generation complete") await ctx.info( diff --git a/tests/unit_tests/mcp_service/chart/test_chart_helpers.py b/tests/unit_tests/mcp_service/chart/test_chart_helpers.py new file mode 100644 index 00000000000..5318f0fe8ac --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_chart_helpers.py @@ -0,0 +1,108 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from unittest.mock import MagicMock, patch + +from superset.mcp_service.chart.chart_helpers import ( + extract_form_data_key_from_url, + find_chart_by_identifier, + get_cached_form_data, +) + + +def test_extract_form_data_key_from_url_with_key(): + url = "http://localhost:8088/explore/?form_data_key=abc123&slice_id=1" + assert extract_form_data_key_from_url(url) == "abc123" + + +def test_extract_form_data_key_from_url_no_key(): + url = "http://localhost:8088/explore/?slice_id=1" + assert extract_form_data_key_from_url(url) is None + + +def test_extract_form_data_key_from_url_none(): + assert extract_form_data_key_from_url(None) is None + + +def test_extract_form_data_key_from_url_empty(): + assert extract_form_data_key_from_url("") is None + + +def test_extract_form_data_key_from_url_multiple_params(): + url = "http://localhost:8088/explore/?slice_id=5&form_data_key=xyz789&other=val" + assert extract_form_data_key_from_url(url) == "xyz789" + + +@patch("superset.daos.chart.ChartDAO.find_by_id") +def test_find_chart_by_identifier_int(mock_find): + mock_chart = MagicMock() + mock_chart.id = 42 + mock_find.return_value = mock_chart + + result = find_chart_by_identifier(42) + mock_find.assert_called_once_with(42) + assert result == mock_chart + + +@patch("superset.daos.chart.ChartDAO.find_by_id") +def test_find_chart_by_identifier_str_digit(mock_find): + mock_chart = MagicMock() + mock_find.return_value = mock_chart + + result = find_chart_by_identifier("123") + mock_find.assert_called_once_with(123) + assert result == mock_chart + + +@patch("superset.daos.chart.ChartDAO.find_by_id") +def test_find_chart_by_identifier_uuid(mock_find): + mock_chart = MagicMock() + mock_find.return_value = mock_chart + + uuid_str = "a1b2c3d4-5678-90ab-cdef-1234567890ab" + result = find_chart_by_identifier(uuid_str) + mock_find.assert_called_once_with(uuid_str, id_column="uuid") + assert result == mock_chart + + +@patch("superset.daos.chart.ChartDAO.find_by_id") +def test_find_chart_by_identifier_not_found(mock_find): + mock_find.return_value = None + result = find_chart_by_identifier(999) + assert result is None + + +@patch( + "superset.commands.explore.form_data.get.GetFormDataCommand.run", + return_value='{"viz_type": "table"}', +) +@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__") +def test_get_cached_form_data_success(mock_init, mock_run): + mock_init.return_value = None + result = get_cached_form_data("test_key") + assert result == '{"viz_type": "table"}' + + +@patch( + "superset.commands.explore.form_data.get.GetFormDataCommand.run", + side_effect=KeyError("not found"), +) +@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__") +def test_get_cached_form_data_key_error(mock_init, mock_run): + mock_init.return_value = None + result = get_cached_form_data("bad_key") + assert result is None diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py index b027930be27..d6a49c42674 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py @@ -26,6 +26,7 @@ import pytest from fastmcp import Client from superset.mcp_service.app import mcp +from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier from superset.mcp_service.chart.chart_utils import DatasetValidationResult from superset.mcp_service.chart.schemas import ( AxisConfig, @@ -40,7 +41,6 @@ from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.tool.update_chart import ( _build_preview_form_data, _build_update_payload, - _find_chart, ) # The __init__.py re-exports the update_chart *function*, so a plain @@ -519,7 +519,7 @@ class TestUpdateChartDatasetAccess: class TestFindChart: - """Tests for _find_chart helper.""" + """Tests for find_chart_by_identifier helper (moved to chart_helpers).""" @patch("superset.daos.chart.ChartDAO.find_by_id") def test_find_chart_by_numeric_id(self, mock_find): @@ -527,7 +527,7 @@ class TestFindChart: mock_chart = Mock() mock_find.return_value = mock_chart - result = _find_chart(42) + result = find_chart_by_identifier(42) mock_find.assert_called_once_with(42) assert result is mock_chart @@ -538,7 +538,7 @@ class TestFindChart: mock_chart = Mock() mock_find.return_value = mock_chart - result = _find_chart("123") + result = find_chart_by_identifier("123") mock_find.assert_called_once_with(123) assert result is mock_chart @@ -550,7 +550,7 @@ class TestFindChart: mock_find.return_value = mock_chart uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" - result = _find_chart(uuid) + result = find_chart_by_identifier(uuid) mock_find.assert_called_once_with(uuid, id_column="uuid") assert result is mock_chart @@ -560,7 +560,7 @@ class TestFindChart: """Returns None when chart is not found.""" mock_find.return_value = None - result = _find_chart(999) + result = find_chart_by_identifier(999) assert result is None