chore(mcp): extract shared chart helpers and ASCII rendering into separate modules (#39438)

This commit is contained in:
Amin Ghadersohi
2026-04-21 20:10:49 -04:00
committed by GitHub
parent 05fc5bb424
commit e6853894ab
11 changed files with 1156 additions and 993 deletions

View File

@@ -0,0 +1,863 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
ASCII chart rendering functions for MCP chart previews.
Pure rendering functions that take data in and return strings out.
No dependencies on MCP tools, FastMCP Context, or request objects.
"""
from __future__ import annotations
import logging
import math
from typing import Any
logger = logging.getLogger(__name__)
def generate_ascii_chart(
    data: list[Any], chart_type: str, width: int = 80, height: int = 20
) -> str:
    """Render ``data`` as an ASCII chart selected by ``chart_type``.

    Bar, line and scatter chart types are routed to their dedicated
    renderers; any other chart type is rendered as a table. Rendering
    errors of type TypeError, ValueError, KeyError or IndexError are logged
    and reported as a fixed failure message instead of being raised.
    """
    if not data:
        return "No data available for ASCII chart"

    bar_types = ("bar", "column", "echarts_timeseries_bar")
    line_types = ("line", "echarts_timeseries_line")
    scatter_types = ("scatter", "echarts_timeseries_scatter")

    try:
        # Enforce minimum dimensions so the renderers never compute a
        # negative plot size.
        width = max(width, 21)
        height = max(height, 9)
        logger.debug(
            "generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data)
        )

        if chart_type in bar_types:
            renderer = _generate_ascii_bar_chart
        elif chart_type in line_types:
            renderer = _generate_ascii_line_chart
        elif chart_type in scatter_types:
            renderer = _generate_ascii_scatter_chart
        else:
            renderer = None

        if renderer is not None:
            return renderer(data, width, height)

        logger.debug("Unsupported chart type '%s', falling back to table", chart_type)
        return generate_ascii_table(data, width)
    except (TypeError, ValueError, KeyError, IndexError) as e:
        logger.error("ASCII chart generation failed: %s", e, exc_info=True)
        return "ASCII chart generation failed"
def _generate_ascii_bar_chart(data: list[Any], width: int, height: int) -> str:
"""Generate enhanced ASCII bar chart with horizontal and vertical options."""
if not data:
return "No data for bar chart"
# Extract numeric values for bars
values = []
labels = []
for row in data[:12]: # Increased limit for better charts
if isinstance(row, dict):
# Find numeric and string values
numeric_val = None
label_val = None
for _key, val in row.items():
if isinstance(val, (int, float)) and numeric_val is None:
numeric_val = val
elif isinstance(val, str) and label_val is None:
label_val = val
if numeric_val is not None:
values.append(numeric_val)
labels.append(label_val or f"Item {len(values)}")
if not values:
return "No numeric data found for bar chart"
# Decide between horizontal and vertical based on label lengths
avg_label_length = sum(len(str(label)) for label in labels) / len(labels)
use_horizontal = avg_label_length > 8 or len(values) > 8
if use_horizontal:
return _generate_horizontal_bar_chart(values, labels, width)
else:
return _generate_vertical_bar_chart(values, labels, width, height)
def _generate_horizontal_bar_chart(
values: list[float], labels: list[str], width: int
) -> str:
"""Generate horizontal ASCII bar chart."""
lines = []
lines.append("📊 Horizontal Bar Chart")
lines.append("" * min(width, 60))
max_val = max(values) if values else 1
min_val = min(values) if values else 0
max_bar_width = min(40, width - 20) # Leave space for labels and values
# Add scale indicator
lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}")
lines.append("")
for label, value in zip(labels, values, strict=False):
# Calculate bar length
if max_val > min_val:
normalized = (value - min_val) / (max_val - min_val)
bar_length = max(1, int(normalized * max_bar_width))
else:
bar_length = 1
# Create bar with gradient effect
bar = _create_gradient_bar(bar_length, value, max_val)
# Format value
if abs(value) >= 1000000:
value_str = f"{value / 1000000:.1f}M"
elif abs(value) >= 1000:
value_str = f"{value / 1000:.1f}K"
else:
value_str = f"{value:.1f}"
# Truncate label if too long
display_label = label[:15] if len(label) > 15 else label
lines.append(f"{display_label:>15}{bar:<{max_bar_width}} {value_str}")
return "\n".join(lines)
def _generate_vertical_bar_chart(  # noqa: C901
    values: list[float], labels: list[str], width: int, height: int
) -> str:
    """Generate vertical ASCII bar chart.

    Builds a character grid with one column per value, fills each column
    from the bottom up in proportion to the value, then prefixes each grid
    row with a Y-axis label and appends an X-axis line and a label line.

    Args:
        values: Bar values, one per column.
        labels: Bar labels; only the first three characters of each are shown.
        width: Total width budget; only caps the rule under the title.
        height: Total height budget; the plot uses at most ``height - 8`` rows
            and never more than 15.

    Returns:
        The chart as a single newline-joined string.
    """
    lines = []
    lines.append("📊 Vertical Bar Chart")
    lines.append("" * min(width, 60))
    max_val = max(values) if values else 1
    min_val = min(values) if values else 0
    chart_height = min(15, height - 8)  # Leave space for title and labels
    # Create the chart grid
    grid = []
    for _ in range(chart_height):
        grid.append([" "] * len(values))
    # Fill the bars
    for col, value in enumerate(values):
        if max_val > min_val:
            normalized = (value - min_val) / (max_val - min_val)
            # Every bar is at least one cell tall, including the minimum value
            bar_height = max(1, int(normalized * chart_height))
        else:
            # All values are equal: draw the minimum bar height
            bar_height = 1
        # Fill from bottom up
        for row_idx in range(chart_height - bar_height, chart_height):
            if row_idx < len(grid):
                # Use different characters for height effect
                if row_idx == chart_height - bar_height:
                    grid[row_idx][col] = ""  # Top of bar
                elif row_idx == chart_height - 1:
                    grid[row_idx][col] = ""  # Bottom of bar
                else:
                    grid[row_idx][col] = ""  # Middle of bar
    # Add Y-axis scale
    for i, row_data in enumerate(grid):
        # Row 0 is the top of the chart, so it maps to max_val
        y_val = (
            max_val - (i / (chart_height - 1)) * (max_val - min_val)
            if chart_height > 1
            else max_val
        )
        if abs(y_val) >= 1000:
            y_label = f"{y_val:.0f}"
        else:
            y_label = f"{y_val:.1f}"
        # Each grid cell is centered in a three-character column
        lines.append(f"{y_label:>6}" + "".join(f"{cell:^3}" for cell in row_data))
    # Add X-axis
    lines.append("" + "───" * len(values))
    # Add labels
    label_line = " "
    for label in labels:
        short_label = label[:3] if len(label) > 3 else label
        label_line += f"{short_label:^3}"
    lines.append(label_line)
    return "\n".join(lines)
def _create_gradient_bar(length: int, value: float, max_val: float) -> str:
"""Create a gradient bar with visual effects."""
if length <= 0:
return ""
# Create gradient effect based on value intensity
intensity = value / max_val if max_val > 0 else 0
if intensity > 0.8:
# High values - solid bars
return "" * length
elif intensity > 0.6:
# Medium-high values - mostly solid with some texture
return "" * (length - 1) + "" if length > 1 else ""
elif intensity > 0.4:
# Medium values - mixed texture
return "" * length
elif intensity > 0.2:
# Low-medium values - lighter texture
return "" * length
else:
# Low values - lightest texture
return "" * length
def _generate_ascii_line_chart(data: list[Any], width: int, height: int) -> str:
"""Generate enhanced ASCII line chart with trend analysis."""
if not data:
return "No data for line chart"
lines = []
lines.append("📈 Line Chart with Trend Analysis")
lines.append("" * min(width, 60))
# Extract values and labels for plotting
values, labels = _extract_time_series_data(data)
if not values:
return "No numeric data found for line chart"
# Generate enhanced line chart
if len(values) >= 3:
lines.extend(_create_enhanced_line_chart(values, labels, width, height))
else:
# Fallback to sparkline for small datasets
sparkline_data = _create_sparkline(values)
lines.extend(sparkline_data)
# Add trend analysis
trend_analysis = _analyze_trend(values)
lines.append("")
lines.append("📊 Trend Analysis:")
lines.extend(trend_analysis)
return "\n".join(lines)
def _extract_time_series_data(data: list[Any]) -> tuple[list[float], list[str]]:
"""Extract time series data with labels."""
values = []
labels = []
for row in data[:20]: # Limit points for readability
if isinstance(row, dict):
# Find the first numeric value and first string/date value
numeric_val = None
label_val = None
for key, val in row.items():
if isinstance(val, (int, float)) and numeric_val is None:
numeric_val = val
elif isinstance(val, str) and label_val is None:
# Use the key name if it looks like a date/time field
if any(
date_word in key.lower()
for date_word in ["date", "time", "month", "day", "year"]
):
label_val = str(val)[:10] # Truncate long dates
else:
label_val = str(val)[:8] # Truncate long strings
if numeric_val is not None:
values.append(numeric_val)
labels.append(label_val or f"P{len(values)}")
return values, labels
def _create_enhanced_line_chart(
    values: list[float], labels: list[str], width: int, height: int
) -> list[str]:
    """Create an enhanced ASCII line chart with better visualization.

    Maps each value onto a character grid (at most 50 columns by 12 rows),
    connects consecutive points via ``_draw_line_segment``, and returns the
    grid rows prefixed with Y-axis labels, followed by an X-axis line and a
    line of selected X-axis labels.

    Args:
        values: Series values, plotted left to right in order.
        labels: One label per value; roughly every ``len(labels) // 6``-th
            label is shown, each cut to 8 characters.
        width: Total width budget; the plot uses at most ``width - 15`` columns.
        height: Total height budget; the plot uses at most ``height - 8`` rows.

    Returns:
        The chart lines, or a single "Insufficient data" line when fewer
        than two values are given.
    """
    lines = []
    chart_width = min(50, width - 15)
    chart_height = min(12, height - 8)
    if len(values) < 2:
        return ["Insufficient data for line chart"]
    # Normalize values to chart height
    min_val = min(values)
    max_val = max(values)
    # A flat series uses a range of 1 to avoid dividing by zero
    val_range = max_val - min_val if max_val != min_val else 1
    # Create chart grid
    grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)]
    # Plot the line with connecting segments
    prev_x, prev_y = None, None
    for i, value in enumerate(values):
        # Map to grid coordinates
        x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
        # Row 0 is the top of the grid, so larger values get smaller y
        y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1))
        # Ensure coordinates are within bounds
        x = max(0, min(x, chart_width - 1))
        y = max(0, min(y, chart_height - 1))
        # Mark the point
        grid[y][x] = ""
        # Draw line segment to previous point
        if prev_x is not None and prev_y is not None:
            _draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height)
        prev_x, prev_y = x, y
    # Render the chart with Y-axis labels
    for i, row in enumerate(grid):
        y_val = (
            max_val - (i / (chart_height - 1)) * val_range
            if chart_height > 1
            else max_val
        )
        if abs(y_val) >= 1000:
            y_label = f"{y_val:.0f}"
        else:
            y_label = f"{y_val:.1f}"
        lines.append(f"{y_label:>8}" + "".join(row))
    # Add X-axis
    lines.append("" + "" * chart_width)
    # Add selected X-axis labels (show every few labels to avoid crowding)
    label_line = " "
    step = max(1, len(labels) // 6)  # Show max 6 labels
    for i in range(0, len(labels), step):
        # Same horizontal mapping as used for the plotted points
        pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
        # Add spacing to position the label
        while len(label_line) - 10 < pos:
            label_line += " "
        label_line += labels[i][:8]
    lines.append(label_line)
    return lines
def _draw_line_segment(
    grid: list[list[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int
) -> None:
    """Draw a line segment between two grid points, mutating ``grid`` in place.

    Only the cells strictly between the two endpoints are written; the
    endpoints themselves are left untouched. Vertical and horizontal
    segments are filled directly. Any other segment is sampled by linear
    interpolation with truncation toward zero, not a true Bresenham walk.

    Args:
        grid: Character grid indexed as ``grid[y][x]``.
        x1: Column of the first endpoint.
        y1: Row of the first endpoint.
        x2: Column of the second endpoint.
        y2: Row of the second endpoint.
        width: Number of grid columns, used for bounds checks.
        height: Number of grid rows, used for bounds checks.
    """
    # Simple line drawing - connect points with line characters
    if x1 == x2:  # Vertical line
        start_y, end_y = sorted([y1, y2])
        for y in range(start_y + 1, end_y):
            if 0 <= y < height and 0 <= x1 < width:
                grid[y][x1] = ""
    elif y1 == y2:  # Horizontal line
        start_x, end_x = sorted([x1, x2])
        for x in range(start_x + 1, end_x):
            if 0 <= y1 < height and 0 <= x < width:
                grid[y1][x] = ""
    else:  # Diagonal line - use simple interpolation
        # One sample per cell along the longer axis
        steps = max(abs(x2 - x1), abs(y2 - y1))
        for step in range(1, steps):
            x = x1 + int((x2 - x1) * step / steps)
            y = y1 + int((y2 - y1) * step / steps)
            if 0 <= x < width and 0 <= y < height:
                # The glyph depends on whether the segment is wider than tall
                if abs(x2 - x1) > abs(y2 - y1):
                    grid[y][x] = ""
                else:
                    grid[y][x] = ""
def _analyze_trend(values: list[float]) -> list[str]:
"""Analyze trend in the data."""
if len(values) < 2:
return ["• Insufficient data for trend analysis"]
analysis = []
# Calculate basic statistics
first_val = values[0]
last_val = values[-1]
min_val = min(values)
max_val = max(values)
avg_val = sum(values) / len(values)
# Overall trend
if last_val > first_val * 1.1:
trend_icon = "📈"
trend_desc = "Strong upward trend"
elif last_val > first_val * 1.05:
trend_icon = "📊"
trend_desc = "Moderate upward trend"
elif last_val < first_val * 0.9:
trend_icon = "📉"
trend_desc = "Strong downward trend"
elif last_val < first_val * 0.95:
trend_icon = "📊"
trend_desc = "Moderate downward trend"
else:
trend_icon = "➡️"
trend_desc = "Relatively stable"
analysis.append(f"{trend_icon} {trend_desc}")
analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})")
# Volatility
if len(values) >= 3:
changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
avg_change = sum(changes) / len(changes)
volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low"
analysis.append(f"• Volatility: {volatility}")
return analysis
def _create_sparkline(values: list[float]) -> list[str]:
"""Create sparkline visualization from values."""
if len(values) <= 1:
return []
max_val = max(values)
min_val = min(values)
range_val = max_val - min_val if max_val != min_val else 1
sparkline = ""
for val in values:
normalized = (val - min_val) / range_val
if normalized < 0.2:
sparkline += ""
elif normalized < 0.4:
sparkline += ""
elif normalized < 0.6:
sparkline += ""
elif normalized < 0.8:
sparkline += ""
else:
sparkline += ""
# Safe formatting to avoid NaN display
if _is_nan_value(min_val) or _is_nan_value(max_val):
return ["Range: Unable to calculate from data", sparkline]
else:
return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline]
def _is_nan_value(value: Any) -> bool:
"""Check if a value is NaN or invalid."""
try:
return math.isnan(float(value))
except (ValueError, TypeError):
return True
def _generate_ascii_scatter_chart(data: list[Any], width: int, height: int) -> str:
"""Generate ASCII scatter plot."""
if not data:
return "No data for scatter chart"
lines = []
lines.append("ASCII Scatter Plot")
lines.append("=" * min(width, 50))
# Extract data points
x_values, y_values, x_column, y_column = _extract_scatter_data(data)
# Log debug info server-side only
logger.debug(
"Scatter chart: x_column=%s, y_column=%s, valid_pairs=%d",
x_column,
y_column,
len(x_values),
)
# Check if we have enough data
if len(x_values) < 2:
return generate_ascii_table(data, width)
# Add axis info
lines.extend(_create_axis_info(x_values, y_values, x_column, y_column))
# Create and render grid
grid = _create_scatter_grid(x_values, y_values, width, height)
lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height))
return "\n".join(lines)
def _extract_scatter_data(
data: list[Any],
) -> tuple[list[float], list[float], str | None, str | None]:
"""Extract X,Y data from scatter chart data."""
x_values = []
y_values = []
x_column = None
y_column = None
numeric_columns = []
if data and isinstance(data[0], dict):
# Find the first two numeric columns
for key, val in data[0].items():
if isinstance(val, (int, float)) and not (
isinstance(val, float) and (val != val)
): # Exclude NaN
numeric_columns.append(key)
if len(numeric_columns) >= 2:
x_column = numeric_columns[0]
y_column = numeric_columns[1]
# Extract X,Y pairs
for row in data[:50]: # Limit for ASCII display
if isinstance(row, dict):
x_val = row.get(x_column)
y_val = row.get(y_column)
# Check for valid numbers (not NaN)
if (
isinstance(x_val, (int, float))
and isinstance(y_val, (int, float))
and not (
isinstance(x_val, float) and (x_val != x_val)
) # Not NaN
and not (isinstance(y_val, float) and (y_val != y_val))
): # Not NaN
x_values.append(x_val)
y_values.append(y_val)
return x_values, y_values, x_column, y_column
def _create_axis_info(
x_values: list[float],
y_values: list[float],
x_column: str | None,
y_column: str | None,
) -> list[str]:
"""Create axis information lines."""
return [
f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})",
f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})",
f"Showing {len(x_values)} data points",
"",
]
def _create_scatter_grid(
x_values: list[float], y_values: list[float], width: int, height: int
) -> list[list[str]]:
"""Create and populate the scatter plot grid."""
plot_width = min(40, width - 10)
plot_height = min(15, height - 8)
# Normalize values to fit in grid
x_min, x_max = min(x_values), max(x_values)
y_min, y_max = min(y_values), max(y_values)
x_range = x_max - x_min if x_max != x_min else 1
y_range = y_max - y_min if y_max != y_min else 1
# Create grid
grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)]
# Plot points
for x, y in zip(x_values, y_values, strict=False):
try:
grid_x = int(((x - x_min) / x_range) * (plot_width - 1))
grid_y = int(((y - y_min) / y_range) * (plot_height - 1))
except (ValueError, OverflowError):
# Skip points that can't be converted to integers (NaN, inf, etc.)
continue
grid_y = plot_height - 1 - grid_y # Flip Y axis for display
if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height:
if grid[grid_y][grid_x] == " ":
grid[grid_y][grid_x] = ""
else:
grid[grid_y][grid_x] = "" # Multiple points
return grid
def _render_scatter_grid(
grid: list[list[str]],
x_values: list[float],
y_values: list[float],
width: int,
height: int,
) -> list[str]:
"""Render the scatter plot grid with axes and labels."""
lines = []
plot_width = min(40, width - 10)
plot_height = min(15, height - 8)
x_min, x_max = min(x_values), max(x_values)
y_min, y_max = min(y_values), max(y_values)
y_range = y_max - y_min if y_max != y_min else 1
# Add Y-axis labels and plot
for i, row in enumerate(grid):
y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max
y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
lines.append(f"{y_label:>6} |{''.join(row)}")
# Add X-axis
x_axis_line = " " * 7 + "+" + "-" * plot_width
lines.append(x_axis_line)
# Add X-axis labels
x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
x_labels = (
" " * 8
+ x_left_label
+ " " * (plot_width - len(x_left_label) - len(x_right_label))
+ x_right_label
)
lines.append(x_labels)
return lines
def generate_ascii_table(data: list[Any], width: int) -> str:
    """Render ``data`` as a text table with a footer and numeric summaries.

    Columns come from the first row's keys (at most six are shown) and up
    to 15 rows are printed. When the first row is not a dict, only the
    title lines are returned.
    """
    if not data:
        return "No data for table"

    output = ["📋 Data Table", "" * min(width, 70)]

    first_row = data[0]
    if not isinstance(first_row, dict):
        return "\n".join(output)

    headers = _select_display_columns(list(first_row.keys()), data, max_cols=6)
    col_widths = _calculate_column_widths(headers, data, width)

    output.append(_create_table_header(headers, col_widths))
    output.append(_create_table_separator(col_widths))

    shown = min(15, len(data))
    for position, row in enumerate(data[:shown], start=1):
        output.append(_format_table_row(row, headers, col_widths))
        # Light rule after every fifth row, except after the last shown row.
        if position % 5 == 0 and position < shown:
            output.append(_create_light_separator(col_widths))

    output.append(_create_table_separator(col_widths))
    output.append(f"📊 Showing {shown} of {len(data)} rows")

    summaries = _create_numeric_summaries(data, headers)
    if summaries:
        output += ["", "📈 Numeric Summaries:"]
        output.extend(summaries)

    return "\n".join(output)
def _select_display_columns(
all_headers: list[str], data: list[Any], max_cols: int = 6
) -> list[str]:
"""Select the most interesting columns to display."""
if len(all_headers) <= max_cols:
return all_headers
# Prioritize columns by interest level
priority_scores = {}
for header in all_headers:
score = 0
header_lower = header.lower()
# Higher priority for common business fields
if any(word in header_lower for word in ["name", "title", "id"]):
score += 10
if any(
word in header_lower
for word in ["amount", "price", "cost", "revenue", "sales"]
):
score += 8
if any(word in header_lower for word in ["date", "time", "created", "updated"]):
score += 6
if any(word in header_lower for word in ["count", "total", "sum", "avg"]):
score += 5
# Check data variety (more variety = more interesting)
sample_values = [
str(row.get(header, "")) for row in data[:10] if isinstance(row, dict)
]
unique_values = len(set(sample_values))
if unique_values > 1:
score += min(unique_values, 5)
priority_scores[header] = score
# Return top scoring columns
sorted_headers = sorted(
all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True
)
return sorted_headers[:max_cols]
def _calculate_column_widths(
headers: list[str], data: list[Any], total_width: int
) -> list[int]:
"""Calculate optimal column widths based on content."""
if not headers:
return []
# Start with minimum widths based on header lengths
min_widths = [max(8, min(len(h) + 2, 15)) for h in headers]
# Check actual data to adjust widths
for row in data[:10]: # Sample first 10 rows
if isinstance(row, dict):
for i, header in enumerate(headers):
val = row.get(header, "")
if isinstance(val, float):
content_len = len(f"{val:.2f}")
elif isinstance(val, int):
content_len = len(str(val))
else:
content_len = len(str(val))
min_widths[i] = max(min_widths[i], min(content_len + 1, 20))
# Distribute remaining space proportionally
used_width = sum(min_widths) + len(headers) * 3 # Account for separators
available_width = min(total_width - 10, 80) # Leave margins
if used_width < available_width:
# Distribute extra space
extra_space = available_width - used_width
for i in range(len(min_widths)):
min_widths[i] += extra_space // len(min_widths)
return min_widths
def _create_table_header(headers: list[str], widths: list[int]) -> str:
"""Create formatted table header."""
formatted_headers = []
for header, width in zip(headers, widths, strict=False):
# Truncate and center header
display_header = header[: width - 2] if len(header) > width - 2 else header
formatted_headers.append(f"{display_header:^{width}}")
return (
""
+ "".join("" * w for w in widths)
+ "\n"
+ "".join(formatted_headers)
+ ""
)
def _create_table_separator(widths: list[int]) -> str:
"""Create table separator line."""
return "" + "".join("" * w for w in widths) + ""
def _create_light_separator(widths: list[int]) -> str:
"""Create light separator line."""
return "" + "".join("" * w for w in widths) + ""
def _format_table_row(
row: dict[str, Any], headers: list[str], widths: list[int]
) -> str:
"""Format a single table row."""
formatted_values = []
for header, width in zip(headers, widths, strict=False):
val = row.get(header, "")
# Format value based on type
if isinstance(val, float):
if abs(val) >= 1000000:
formatted_val = f"{val / 1000000:.1f}M"
elif abs(val) >= 1000:
formatted_val = f"{val / 1000:.1f}K"
else:
formatted_val = f"{val:.2f}"
elif isinstance(val, int):
if abs(val) >= 1000000:
formatted_val = f"{val // 1000000}M"
elif abs(val) >= 1000:
formatted_val = f"{val // 1000}K"
else:
formatted_val = str(val)
else:
formatted_val = str(val)
# Truncate if too long
if len(formatted_val) > width - 2:
formatted_val = formatted_val[: width - 5] + "..."
# Align numbers right, text left
if isinstance(val, (int, float)):
formatted_values.append(f"{formatted_val:>{width - 2}}")
else:
formatted_values.append(f"{formatted_val:<{width - 2}}")
return "" + "".join(formatted_values) + ""
def _create_numeric_summaries(data: list[Any], headers: list[str]) -> list[str]:
"""Create summaries for numeric columns."""
summaries = []
for header in headers:
# Collect numeric values
values = []
for row in data:
if isinstance(row, dict):
val = row.get(header)
if isinstance(val, (int, float)):
values.append(val)
if len(values) >= 2:
min_val = min(values)
max_val = max(values)
avg_val = sum(values) / len(values)
if abs(avg_val) >= 1000:
avg_str = f"{avg_val / 1000:.1f}K"
else:
avg_str = f"{avg_val:.1f}"
summaries.append(
f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}"
)
return summaries

View File

@@ -0,0 +1,81 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Shared helper functions for MCP chart tools.
This module contains reusable utility functions for common operations
across chart tools: chart lookup, cached form data retrieval, and
URL parameter extraction. Config mapping logic lives in chart_utils.py.
"""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from urllib.parse import parse_qs, urlparse
if TYPE_CHECKING:
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
def find_chart_by_identifier(identifier: int | str) -> Slice | None:
    """Find a chart by numeric ID or UUID string.

    Accepts an integer ID, a string made up only of decimal digits
    (e.g. "123"), or any other string, which is treated as a UUID.

    Args:
        identifier: Numeric chart ID, numeric string, or UUID string.

    Returns:
        Whatever ``ChartDAO.find_by_id`` returns for the lookup: the Slice
        model instance, or None.
    """
    from superset.daos.chart import ChartDAO  # avoid circular import

    if isinstance(identifier, int):
        return ChartDAO.find_by_id(identifier)
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as superscript digits ("²") that int() rejects with ValueError.
    if isinstance(identifier, str) and identifier.isdecimal():
        return ChartDAO.find_by_id(int(identifier))
    return ChartDAO.find_by_id(identifier, id_column="uuid")
def get_cached_form_data(form_data_key: str) -> str | None:
    """Look up cached form_data by its ``form_data_key``.

    Returns the result of ``GetFormDataCommand.run()`` (the form_data JSON
    string), or None when the lookup raises KeyError, ValueError or
    CommandException; the failure is logged as a warning.
    """
    # avoid circular import — commands depend on app initialization
    from superset.commands.exceptions import CommandException
    from superset.commands.explore.form_data.get import GetFormDataCommand
    from superset.commands.explore.form_data.parameters import CommandParameters

    recoverable_errors = (KeyError, ValueError, CommandException)
    try:
        command = GetFormDataCommand(CommandParameters(key=form_data_key))
        return command.run()
    except recoverable_errors as e:
        logger.warning("Failed to retrieve form_data from cache: %s", e)
    return None
def extract_form_data_key_from_url(url: str | None) -> str | None:
    """Return the first ``form_data_key`` query parameter of ``url``.

    Returns None when the URL is empty or None, or when the parameter is
    absent from the query string.
    """
    if not url:
        return None

    query_string = urlparse(url).query
    matches = parse_qs(query_string).get("form_data_key")
    if not matches:
        return None
    return matches[0]

View File

@@ -22,7 +22,6 @@ import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List
from urllib.parse import parse_qs, urlparse
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
@@ -32,6 +31,7 @@ from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.auth import has_dataset_access
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
from superset.mcp_service.chart.chart_utils import (
analyze_chart_capabilities,
analyze_chart_semantics,
@@ -540,13 +540,8 @@ async def generate_chart( # noqa: C901
explore_url = generate_explore_link(request.dataset_id, form_data)
await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,))
# Extract form_data_key from the explore URL using proper URL parsing
if explore_url:
parsed = urlparse(explore_url)
query_params = parse_qs(parsed.query)
form_data_key_list = query_params.get("form_data_key", [])
if form_data_key_list:
form_data_key = form_data_key_list[0]
# Extract form_data_key from the explore URL
form_data_key = extract_form_data_key_from_url(explore_url)
# Compile check for preview-only mode
# Validate dataset existence and user access before running queries

View File

@@ -25,15 +25,19 @@ from typing import Any, Dict, List, TYPE_CHECKING
from fastmcp import Context
from flask import current_app
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
if TYPE_CHECKING:
from superset.models.slice import Slice
from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.parameters import CommandParameters
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import (
find_chart_by_identifier,
get_cached_form_data,
)
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
ChartData,
@@ -62,21 +66,6 @@ def _apply_extra_form_data(
merge_extra_filters(form_data)
def _get_cached_form_data(form_data_key: str) -> str | None:
"""Retrieve form_data from cache using form_data_key.
Returns the JSON string of form_data if found, None otherwise.
"""
from superset.commands.explore.form_data.get import GetFormDataCommand
try:
cmd_params = CommandParameters(key=form_data_key)
return GetFormDataCommand(cmd_params).run()
except (KeyError, ValueError, CommandException) as e:
logger.warning("Failed to retrieve form_data from cache: %s", e)
return None
@tool(
tags=["data"],
class_permission_name="Chart",
@@ -127,7 +116,6 @@ async def get_chart_data( # noqa: C901
try:
await ctx.report_progress(1, 4, "Looking up chart")
from superset.daos.chart import ChartDAO
from superset.utils import json as utils_json
chart = None
@@ -141,7 +129,7 @@ async def get_chart_data( # noqa: C901
"No chart identifier - querying data from unsaved chart cache: "
"form_data_key=%s" % (request.form_data_key,)
)
cached_form_data = _get_cached_form_data(request.form_data_key)
cached_form_data = get_cached_form_data(request.form_data_key)
if not cached_form_data:
return ChartError(
error="No cached chart data found for form_data_key. "
@@ -166,25 +154,13 @@ async def get_chart_data( # noqa: C901
# Find the chart by identifier
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
if isinstance(request.identifier, int) or (
isinstance(request.identifier, str) and request.identifier.isdigit()
):
chart_id = (
int(request.identifier)
if isinstance(request.identifier, str)
else request.identifier
await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,))
if request.identifier is None:
return ChartError(
error="Chart identifier is required",
error_type="ValidationError",
)
await ctx.debug(
"Performing ID-based chart lookup: chart_id=%s" % (chart_id,)
)
chart = ChartDAO.find_by_id(chart_id)
elif isinstance(request.identifier, str):
await ctx.debug(
"Performing UUID-based chart lookup: uuid=%s"
% (request.identifier,)
)
# Try UUID lookup using DAO flexible method
chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")
chart = find_chart_by_identifier(request.identifier)
if not chart:
await ctx.error("Chart not found: identifier=%s" % (request.identifier,))
@@ -239,7 +215,7 @@ async def get_chart_data( # noqa: C901
"Retrieving unsaved chart state from cache: form_data_key=%s"
% (request.form_data_key,)
)
if cached_form_data := _get_cached_form_data(request.form_data_key):
if cached_form_data := get_cached_form_data(request.form_data_key):
try:
parsed_form_data = utils_json.loads(cached_form_data)
# Only use if it's actually a dict (not null, list, etc.)
@@ -818,7 +794,15 @@ async def get_chart_data( # noqa: C901
error=OAUTH2_CONFIG_ERROR_MESSAGE,
error_type="OAUTH2_REDIRECT_ERROR",
)
except Exception as e:
except (
SupersetException,
CommandException,
SQLAlchemyError,
KeyError,
ValueError,
TypeError,
AttributeError,
) as e:
await ctx.error(
"Chart data retrieval failed: identifier=%s, error=%s, error_type=%s"
% (

View File

@@ -25,9 +25,8 @@ from fastmcp import Context
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.parameters import CommandParameters
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
ChartError,
@@ -41,26 +40,11 @@ from superset.mcp_service.mcp_core import ModelGetInfoCore
logger = logging.getLogger(__name__)
def _get_cached_form_data(form_data_key: str) -> str | None:
    """Look up cached form_data for ``form_data_key``.

    Returns whatever ``GetFormDataCommand(...).run()`` yields for the key.
    If building or running the command raises ``KeyError``, ``ValueError``
    or ``CommandException``, a warning is logged and ``None`` is returned.
    """
    from superset.commands.explore.form_data.get import GetFormDataCommand

    try:
        params = CommandParameters(key=form_data_key)
        cached = GetFormDataCommand(params).run()
    except (KeyError, ValueError, CommandException) as exc:
        logger.warning("Failed to retrieve form_data from cache: %s", exc)
        return None
    return cached
def _build_unsaved_chart_info(form_data_key: str) -> ChartInfo | ChartError:
"""Build a ChartInfo from cached form_data when no chart identifier exists."""
from superset.utils import json as utils_json
cached_form_data = _get_cached_form_data(form_data_key)
cached_form_data = get_cached_form_data(form_data_key)
if not cached_form_data:
return ChartError(
error="No cached chart data found for form_data_key. "
@@ -94,7 +78,7 @@ def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None
"""Override a ChartInfo's form_data with cached unsaved state."""
from superset.utils import json as utils_json
if cached_form_data := _get_cached_form_data(form_data_key):
if cached_form_data := get_cached_form_data(form_data_key):
try:
result.form_data = utils_json.loads(cached_form_data)
result.form_data_key = form_data_key

View File

@@ -23,11 +23,17 @@ import logging
from typing import Any, Dict, List, Protocol
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.ascii_charts import (
generate_ascii_chart,
generate_ascii_table,
)
from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
from superset.mcp_service.chart.schemas import (
AccessibilityMetadata,
@@ -262,7 +268,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
if result and "queries" in result and len(result["queries"]) > 0:
data = result["queries"][0].get("data", [])
table_data = _generate_ascii_table(data, 120)
table_data = generate_ascii_table(data, 120)
return TablePreview(
table_data=table_data,
@@ -964,874 +970,6 @@ class PreviewFormatGenerator:
return strategy.generate()
def generate_ascii_chart(
    data: List[Any], chart_type: str, width: int = 80, height: int = 20
) -> str:
    """Render ``data`` as an ASCII chart selected by ``chart_type``.

    Args:
        data: Rows to plot; the renderers look for ``dict`` rows.
        chart_type: Visualization type. Bar, line and scatter variants have
            dedicated renderers; anything else falls back to a table.
        width: Requested width in characters (raised to at least 21).
        height: Requested height in rows (raised to at least 9).

    Returns:
        The rendered chart, or a short message when ``data`` is empty or
        rendering fails with one of the handled exceptions.
    """
    if not data:
        return "No data available for ASCII chart"

    try:
        # The renderers subtract fixed margins (up to 20 columns and 8 rows)
        # from these values; clamp so the plot area is never zero/negative.
        width = max(width, 21)
        height = max(height, 9)

        logger.debug(
            "generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data)
        )

        if chart_type in ("bar", "column", "echarts_timeseries_bar"):
            return _generate_ascii_bar_chart(data, width, height)
        if chart_type in ("line", "echarts_timeseries_line"):
            return _generate_ascii_line_chart(data, width, height)
        if chart_type in ("scatter", "echarts_timeseries_scatter"):
            return _generate_ascii_scatter_chart(data, width, height)

        # Default to table format for unsupported chart types
        logger.debug("Unsupported chart type '%s', falling back to table", chart_type)
        return _generate_ascii_table(data, width)
    except (TypeError, ValueError, KeyError, IndexError) as e:
        # logger.exception records the traceback alongside the message.
        logger.exception("ASCII chart generation failed")
        return f"ASCII chart generation failed: {str(e)}"
def _generate_ascii_bar_chart(data: List[Any], width: int, height: int) -> str:
"""Generate enhanced ASCII bar chart with horizontal and vertical options."""
if not data:
return "No data for bar chart"
# Extract numeric values for bars
values = []
labels = []
for row in data[:12]: # Increased limit for better charts
if isinstance(row, dict):
# Find numeric and string values
numeric_val = None
label_val = None
for _key, val in row.items():
if isinstance(val, (int, float)) and numeric_val is None:
numeric_val = val
elif isinstance(val, str) and label_val is None:
label_val = val
if numeric_val is not None:
values.append(numeric_val)
labels.append(label_val or f"Item {len(values)}")
if not values:
return "No numeric data found for bar chart"
# Decide between horizontal and vertical based on label lengths
avg_label_length = sum(len(str(label)) for label in labels) / len(labels)
use_horizontal = avg_label_length > 8 or len(values) > 8
if use_horizontal:
return _generate_horizontal_bar_chart(values, labels, width)
else:
return _generate_vertical_bar_chart(values, labels, width, height)
def _generate_horizontal_bar_chart(
values: List[float], labels: List[str], width: int
) -> str:
"""Generate horizontal ASCII bar chart."""
lines = []
lines.append("📊 Horizontal Bar Chart")
lines.append("" * min(width, 60))
max_val = max(values) if values else 1
min_val = min(values) if values else 0
max_bar_width = min(40, width - 20) # Leave space for labels and values
# Add scale indicator
lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}")
lines.append("")
for label, value in zip(labels, values, strict=False):
# Calculate bar length
if max_val > min_val:
normalized = (value - min_val) / (max_val - min_val)
bar_length = max(1, int(normalized * max_bar_width))
else:
bar_length = 1
# Create bar with gradient effect
bar = _create_gradient_bar(bar_length, value, max_val)
# Format value
if abs(value) >= 1000000:
value_str = f"{value / 1000000:.1f}M"
elif abs(value) >= 1000:
value_str = f"{value / 1000:.1f}K"
else:
value_str = f"{value:.1f}"
# Truncate label if too long
display_label = label[:15] if len(label) > 15 else label
lines.append(f"{display_label:>15}{bar:<{max_bar_width}} {value_str}")
return "\n".join(lines)
def _generate_vertical_bar_chart( # noqa: C901
values: List[float], labels: List[str], width: int, height: int
) -> str:
"""Generate vertical ASCII bar chart."""
lines = []
lines.append("📊 Vertical Bar Chart")
lines.append("" * min(width, 60))
max_val = max(values) if values else 1
min_val = min(values) if values else 0
chart_height = min(15, height - 8) # Leave space for title and labels
# Create the chart grid
grid = []
for _ in range(chart_height):
grid.append([" "] * len(values))
# Fill the bars
for col, value in enumerate(values):
if max_val > min_val:
normalized = (value - min_val) / (max_val - min_val)
bar_height = max(1, int(normalized * chart_height))
else:
bar_height = 1
# Fill from bottom up
for row_idx in range(chart_height - bar_height, chart_height):
if row_idx < len(grid):
# Use different characters for height effect
if row_idx == chart_height - bar_height:
grid[row_idx][col] = "" # Top of bar
elif row_idx == chart_height - 1:
grid[row_idx][col] = "" # Bottom of bar
else:
grid[row_idx][col] = "" # Middle of bar
# Add Y-axis scale
for i, row_data in enumerate(grid):
y_val = (
max_val - (i / (chart_height - 1)) * (max_val - min_val)
if chart_height > 1
else max_val
)
if abs(y_val) >= 1000:
y_label = f"{y_val:.0f}"
else:
y_label = f"{y_val:.1f}"
lines.append(f"{y_label:>6}" + "".join(f"{cell:^3}" for cell in row_data))
# Add X-axis
lines.append("" + "───" * len(values))
# Add labels
label_line = " "
for label in labels:
short_label = label[:3] if len(label) > 3 else label
label_line += f"{short_label:^3}"
lines.append(label_line)
return "\n".join(lines)
def _create_gradient_bar(length: int, value: float, max_val: float) -> str:
"""Create a gradient bar with visual effects."""
if length <= 0:
return ""
# Create gradient effect based on value intensity
intensity = value / max_val if max_val > 0 else 0
if intensity > 0.8:
# High values - solid bars
return "" * length
elif intensity > 0.6:
# Medium-high values - mostly solid with some texture
return "" * (length - 1) + "" if length > 1 else ""
elif intensity > 0.4:
# Medium values - mixed texture
return "" * length
elif intensity > 0.2:
# Low-medium values - lighter texture
return "" * length
else:
# Low values - lightest texture
return "" * length
def _generate_ascii_line_chart(data: List[Any], width: int, height: int) -> str:
"""Generate enhanced ASCII line chart with trend analysis."""
if not data:
return "No data for line chart"
lines = []
lines.append("📈 Line Chart with Trend Analysis")
lines.append("" * min(width, 60))
# Extract values and labels for plotting
values, labels = _extract_time_series_data(data)
if not values:
return "No numeric data found for line chart"
# Generate enhanced line chart
if len(values) >= 3:
lines.extend(_create_enhanced_line_chart(values, labels, width, height))
else:
# Fallback to sparkline for small datasets
sparkline_data = _create_sparkline(values)
lines.extend(sparkline_data)
# Add trend analysis
trend_analysis = _analyze_trend(values)
lines.append("")
lines.append("📊 Trend Analysis:")
lines.extend(trend_analysis)
return "\n".join(lines)
def _extract_time_series_data(data: List[Any]) -> tuple[List[float], List[str]]:
"""Extract time series data with labels."""
values = []
labels = []
for row in data[:20]: # Limit points for readability
if isinstance(row, dict):
# Find the first numeric value and first string/date value
numeric_val = None
label_val = None
for key, val in row.items():
if isinstance(val, (int, float)) and numeric_val is None:
numeric_val = val
elif isinstance(val, str) and label_val is None:
# Use the key name if it looks like a date/time field
if any(
date_word in key.lower()
for date_word in ["date", "time", "month", "day", "year"]
):
label_val = str(val)[:10] # Truncate long dates
else:
label_val = str(val)[:8] # Truncate long strings
if numeric_val is not None:
values.append(numeric_val)
labels.append(label_val or f"P{len(values)}")
return values, labels
def _create_enhanced_line_chart(
values: List[float], labels: List[str], width: int, height: int
) -> List[str]:
"""Create an enhanced ASCII line chart with better visualization."""
lines = []
chart_width = min(50, width - 15)
chart_height = min(12, height - 8)
if len(values) < 2:
return ["Insufficient data for line chart"]
# Normalize values to chart height
min_val = min(values)
max_val = max(values)
val_range = max_val - min_val if max_val != min_val else 1
# Create chart grid
grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)]
# Plot the line with connecting segments
prev_x, prev_y = None, None
for i, value in enumerate(values):
# Map to grid coordinates
x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1))
# Ensure coordinates are within bounds
x = max(0, min(x, chart_width - 1))
y = max(0, min(y, chart_height - 1))
# Mark the point
grid[y][x] = ""
# Draw line segment to previous point
if prev_x is not None and prev_y is not None:
_draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height)
prev_x, prev_y = x, y
# Render the chart with Y-axis labels
for i, row in enumerate(grid):
y_val = (
max_val - (i / (chart_height - 1)) * val_range
if chart_height > 1
else max_val
)
if abs(y_val) >= 1000:
y_label = f"{y_val:.0f}"
else:
y_label = f"{y_val:.1f}"
lines.append(f"{y_label:>8}" + "".join(row))
# Add X-axis
lines.append("" + "" * chart_width)
# Add selected X-axis labels (show every few labels to avoid crowding)
label_line = " "
step = max(1, len(labels) // 6) # Show max 6 labels
for i in range(0, len(labels), step):
pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
# Add spacing to position the label
while len(label_line) - 10 < pos:
label_line += " "
label_line += labels[i][:8]
lines.append(label_line)
return lines
def _draw_line_segment(
grid: List[List[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int
) -> None:
"""Draw a line segment between two points using Bresenham-like algorithm."""
# Simple line drawing - connect points with line characters
if x1 == x2: # Vertical line
start_y, end_y = sorted([y1, y2])
for y in range(start_y + 1, end_y):
if 0 <= y < height and 0 <= x1 < width:
grid[y][x1] = ""
elif y1 == y2: # Horizontal line
start_x, end_x = sorted([x1, x2])
for x in range(start_x + 1, end_x):
if 0 <= y1 < height and 0 <= x < width:
grid[y1][x] = ""
else: # Diagonal line - use simple interpolation
steps = max(abs(x2 - x1), abs(y2 - y1))
for step in range(1, steps):
x = x1 + int((x2 - x1) * step / steps)
y = y1 + int((y2 - y1) * step / steps)
if 0 <= x < width and 0 <= y < height:
if abs(x2 - x1) > abs(y2 - y1):
grid[y][x] = ""
else:
grid[y][x] = ""
def _analyze_trend(values: List[float]) -> List[str]:
"""Analyze trend in the data."""
if len(values) < 2:
return ["• Insufficient data for trend analysis"]
analysis = []
# Calculate basic statistics
first_val = values[0]
last_val = values[-1]
min_val = min(values)
max_val = max(values)
avg_val = sum(values) / len(values)
# Overall trend
if last_val > first_val * 1.1:
trend_icon = "📈"
trend_desc = "Strong upward trend"
elif last_val > first_val * 1.05:
trend_icon = "📊"
trend_desc = "Moderate upward trend"
elif last_val < first_val * 0.9:
trend_icon = "📉"
trend_desc = "Strong downward trend"
elif last_val < first_val * 0.95:
trend_icon = "📊"
trend_desc = "Moderate downward trend"
else:
trend_icon = "➡️"
trend_desc = "Relatively stable"
analysis.append(f"{trend_icon} {trend_desc}")
analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})")
# Volatility
if len(values) >= 3:
changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
avg_change = sum(changes) / len(changes)
volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low"
analysis.append(f"• Volatility: {volatility}")
return analysis
def _extract_numeric_values(data: List[Any]) -> List[float]:
"""Extract numeric values from data for line chart."""
values = []
for row in data[:20]: # Limit points
if isinstance(row, dict):
for _key, val in row.items():
if isinstance(val, (int, float)):
values.append(val)
break
return values
def _create_sparkline(values: List[float]) -> List[str]:
"""Create sparkline visualization from values."""
if len(values) <= 1:
return []
max_val = max(values)
min_val = min(values)
range_val = max_val - min_val if max_val != min_val else 1
sparkline = ""
for val in values:
normalized = (val - min_val) / range_val
if normalized < 0.2:
sparkline += ""
elif normalized < 0.4:
sparkline += ""
elif normalized < 0.6:
sparkline += ""
elif normalized < 0.8:
sparkline += ""
else:
sparkline += ""
# Safe formatting to avoid NaN display
if _is_nan_value(min_val) or _is_nan_value(max_val):
return ["Range: Unable to calculate from data", sparkline]
else:
return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline]
def _is_nan_value(value: Any) -> bool:
"""Check if a value is NaN or invalid."""
try:
import math
return math.isnan(float(value))
except (ValueError, TypeError):
return True
def _generate_ascii_scatter_chart(data: List[Any], width: int, height: int) -> str:
"""Generate ASCII scatter plot."""
if not data:
return "No data for scatter chart"
lines = []
lines.append("ASCII Scatter Plot")
lines.append("=" * min(width, 50))
# Extract data points
x_values, y_values, x_column, y_column = _extract_scatter_data(data)
# Debug info
lines.extend(_create_debug_info(x_values, y_values, x_column, y_column, data))
# Check if we have enough data
if len(x_values) < 2:
return _generate_ascii_table(data, width)
# Add axis info
lines.extend(_create_axis_info(x_values, y_values, x_column, y_column))
# Create and render grid
grid = _create_scatter_grid(x_values, y_values, width, height)
lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height))
return "\n".join(lines)
def _extract_scatter_data(
data: List[Any],
) -> tuple[List[float], List[float], str | None, str | None]:
"""Extract X,Y data from scatter chart data."""
x_values = []
y_values = []
x_column = None
y_column = None
numeric_columns = []
if data and isinstance(data[0], dict):
# Find the first two numeric columns
for key, val in data[0].items():
if isinstance(val, (int, float)) and not (
isinstance(val, float) and (val != val)
): # Exclude NaN
numeric_columns.append(key)
if len(numeric_columns) >= 2:
x_column = numeric_columns[0]
y_column = numeric_columns[1]
# Extract X,Y pairs
for row in data[:50]: # Limit for ASCII display
if isinstance(row, dict):
x_val = row.get(x_column)
y_val = row.get(y_column)
# Check for valid numbers (not NaN)
if (
isinstance(x_val, (int, float))
and isinstance(y_val, (int, float))
and not (
isinstance(x_val, float) and (x_val != x_val)
) # Not NaN
and not (isinstance(y_val, float) and (y_val != y_val))
): # Not NaN
x_values.append(x_val)
y_values.append(y_val)
return x_values, y_values, x_column, y_column
def _create_debug_info(
x_values: List[float],
y_values: List[float],
x_column: str | None,
y_column: str | None,
data: List[Any],
) -> List[str]:
"""Create debug information lines for scatter chart."""
numeric_columns = []
if data and isinstance(data[0], dict):
for key, val in data[0].items():
if isinstance(val, (int, float)) and not (
isinstance(val, float) and (val != val)
):
numeric_columns.append(key)
return [
f"DEBUG: Found {len(numeric_columns)} numeric columns: {numeric_columns}",
f"DEBUG: X column: {x_column}, Y column: {y_column}",
f"DEBUG: Valid X,Y pairs: {len(x_values)}",
]
def _create_axis_info(
x_values: List[float],
y_values: List[float],
x_column: str | None,
y_column: str | None,
) -> List[str]:
"""Create axis information lines."""
return [
f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})",
f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})",
f"Showing {len(x_values)} data points",
"",
]
def _create_scatter_grid(
x_values: List[float], y_values: List[float], width: int, height: int
) -> List[List[str]]:
"""Create and populate the scatter plot grid."""
plot_width = min(40, width - 10)
plot_height = min(15, height - 8)
# Normalize values to fit in grid
x_min, x_max = min(x_values), max(x_values)
y_min, y_max = min(y_values), max(y_values)
x_range = x_max - x_min if x_max != x_min else 1
y_range = y_max - y_min if y_max != y_min else 1
# Create grid
grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)]
# Plot points
for x, y in zip(x_values, y_values, strict=False):
try:
grid_x = int(((x - x_min) / x_range) * (plot_width - 1))
grid_y = int(((y - y_min) / y_range) * (plot_height - 1))
grid_y = plot_height - 1 - grid_y # Flip Y axis for display
if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height:
if grid[grid_y][grid_x] == " ":
grid[grid_y][grid_x] = ""
else:
grid[grid_y][grid_x] = "" # Multiple points
except (ValueError, OverflowError):
# Skip points that can't be converted to integers (NaN, inf, etc.)
continue
return grid
def _render_scatter_grid(
grid: List[List[str]],
x_values: List[float],
y_values: List[float],
width: int,
height: int,
) -> List[str]:
"""Render the scatter plot grid with axes and labels."""
lines = []
plot_width = min(40, width - 10)
plot_height = min(15, height - 8)
x_min, x_max = min(x_values), max(x_values)
y_min, y_max = min(y_values), max(y_values)
y_range = y_max - y_min if y_max != y_min else 1
# Add Y-axis labels and plot
for i, row in enumerate(grid):
y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max
y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
lines.append(f"{y_label:>6} |{''.join(row)}")
# Add X-axis
x_axis_line = " " * 7 + "+" + "-" * plot_width
lines.append(x_axis_line)
# Add X-axis labels
x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
x_labels = (
" " * 8
+ x_left_label
+ " " * (plot_width - len(x_left_label) - len(x_right_label))
+ x_right_label
)
lines.append(x_labels)
return lines
def _generate_ascii_table(data: List[Any], width: int) -> str:
"""Generate enhanced ASCII table with better formatting."""
if not data:
return "No data for table"
lines = []
lines.append("📋 Data Table")
lines.append("" * min(width, 70))
# Get column headers from first row
if isinstance(data[0], dict):
# Select best columns to display
all_headers = list(data[0].keys())
headers = _select_display_columns(all_headers, data, max_cols=6)
# Calculate optimal column widths
col_widths = _calculate_column_widths(headers, data, width)
# Create enhanced header
lines.append(_create_table_header(headers, col_widths))
lines.append(_create_table_separator(col_widths))
# Add data rows with better formatting
row_count = min(15, len(data)) # Show more rows
for i, row in enumerate(data[:row_count]):
formatted_row = _format_table_row(row, headers, col_widths)
lines.append(formatted_row)
# Add separator every 5 rows for readability
if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1:
lines.append(_create_light_separator(col_widths))
# Add footer with stats
lines.append(_create_table_separator(col_widths))
lines.append(f"📊 Showing {row_count} of {len(data)} rows")
# Add column summaries for numeric columns
numeric_summaries = _create_numeric_summaries(data, headers)
if numeric_summaries:
lines.append("")
lines.append("📈 Numeric Summaries:")
lines.extend(numeric_summaries)
return "\n".join(lines)
def _select_display_columns(
all_headers: List[str], data: List[Any], max_cols: int = 6
) -> List[str]:
"""Select the most interesting columns to display."""
if len(all_headers) <= max_cols:
return all_headers
# Prioritize columns by interest level
priority_scores = {}
for header in all_headers:
score = 0
header_lower = header.lower()
# Higher priority for common business fields
if any(word in header_lower for word in ["name", "title", "id"]):
score += 10
if any(
word in header_lower
for word in ["amount", "price", "cost", "revenue", "sales"]
):
score += 8
if any(word in header_lower for word in ["date", "time", "created", "updated"]):
score += 6
if any(word in header_lower for word in ["count", "total", "sum", "avg"]):
score += 5
# Check data variety (more variety = more interesting)
sample_values = [
str(row.get(header, "")) for row in data[:10] if isinstance(row, dict)
]
unique_values = len(set(sample_values))
if unique_values > 1:
score += min(unique_values, 5)
priority_scores[header] = score
# Return top scoring columns
sorted_headers = sorted(
all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True
)
return sorted_headers[:max_cols]
def _calculate_column_widths(
headers: List[str], data: List[Any], total_width: int
) -> List[int]:
"""Calculate optimal column widths based on content."""
if not headers:
return []
# Start with minimum widths based on header lengths
min_widths = [max(8, min(len(h) + 2, 15)) for h in headers]
# Check actual data to adjust widths
for row in data[:10]: # Sample first 10 rows
if isinstance(row, dict):
for i, header in enumerate(headers):
val = row.get(header, "")
if isinstance(val, float):
content_len = len(f"{val:.2f}")
elif isinstance(val, int):
content_len = len(str(val))
else:
content_len = len(str(val))
min_widths[i] = max(min_widths[i], min(content_len + 1, 20))
# Distribute remaining space proportionally
used_width = sum(min_widths) + len(headers) * 3 # Account for separators
available_width = min(total_width - 10, 80) # Leave margins
if used_width < available_width:
# Distribute extra space
extra_space = available_width - used_width
for i in range(len(min_widths)):
min_widths[i] += extra_space // len(min_widths)
return min_widths
def _create_table_header(headers: List[str], widths: List[int]) -> str:
"""Create formatted table header."""
formatted_headers = []
for header, width in zip(headers, widths, strict=False):
# Truncate and center header
display_header = header[: width - 2] if len(header) > width - 2 else header
formatted_headers.append(f"{display_header:^{width}}")
return (
""
+ "".join("" * w for w in widths)
+ "\n"
+ "".join(formatted_headers)
+ ""
)
def _create_table_separator(widths: List[int]) -> str:
"""Create table separator line."""
return "" + "".join("" * w for w in widths) + ""
def _create_light_separator(widths: List[int]) -> str:
"""Create light separator line."""
return "" + "".join("" * w for w in widths) + ""
def _format_table_row(
row: Dict[str, Any], headers: List[str], widths: List[int]
) -> str:
"""Format a single table row."""
formatted_values = []
for header, width in zip(headers, widths, strict=False):
val = row.get(header, "")
# Format value based on type
if isinstance(val, float):
if abs(val) >= 1000000:
formatted_val = f"{val / 1000000:.1f}M"
elif abs(val) >= 1000:
formatted_val = f"{val / 1000:.1f}K"
else:
formatted_val = f"{val:.2f}"
elif isinstance(val, int):
if abs(val) >= 1000000:
formatted_val = f"{val // 1000000}M"
elif abs(val) >= 1000:
formatted_val = f"{val // 1000}K"
else:
formatted_val = str(val)
else:
formatted_val = str(val)
# Truncate if too long
if len(formatted_val) > width - 2:
formatted_val = formatted_val[: width - 5] + "..."
# Align numbers right, text left
if isinstance(val, (int, float)):
formatted_values.append(f"{formatted_val:>{width - 2}}")
else:
formatted_values.append(f"{formatted_val:<{width - 2}}")
return "" + "".join(formatted_values) + ""
def _create_numeric_summaries(data: List[Any], headers: List[str]) -> List[str]:
"""Create summaries for numeric columns."""
summaries = []
for header in headers:
# Collect numeric values
values = []
for row in data:
if isinstance(row, dict):
val = row.get(header)
if isinstance(val, (int, float)):
values.append(val)
if len(values) >= 2:
min_val = min(values)
max_val = max(values)
avg_val = sum(values) / len(values)
if abs(avg_val) >= 1000:
avg_str = f"{avg_val / 1000:.1f}K"
else:
avg_str = f"{avg_val:.1f}"
summaries.append(
f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}"
)
return summaries
async def _get_chart_preview_internal( # noqa: C901
request: GetChartPreviewRequest,
ctx: Context,
@@ -1852,7 +990,6 @@ async def _get_chart_preview_internal( # noqa: C901
"""
try:
await ctx.report_progress(1, 3, "Looking up chart")
from superset.daos.chart import ChartDAO
# Find the chart
with event_logger.log_context(action="mcp.get_chart_preview.chart_lookup"):
@@ -1921,25 +1058,16 @@ async def _get_chart_preview_internal( # noqa: C901
error_type="NotFound",
)
elif isinstance(request.identifier, int) or (
isinstance(request.identifier, str) and request.identifier.isdigit()
):
chart_id = (
int(request.identifier)
if isinstance(request.identifier, str)
else request.identifier
)
else:
await ctx.debug(
"Performing ID-based chart lookup: chart_id=%s" % (chart_id,)
"Looking up chart: identifier=%s" % (request.identifier,)
)
chart = ChartDAO.find_by_id(chart_id)
elif isinstance(request.identifier, str):
await ctx.debug(
"Performing UUID-based chart lookup: uuid=%s"
% (request.identifier,)
)
# Try UUID lookup using DAO flexible method
chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")
if request.identifier is None:
return ChartError(
error="Chart identifier is required",
error_type="ValidationError",
)
chart = find_chart_by_identifier(request.identifier)
# If not found and looks like a form_data_key, try transient
if (
@@ -2268,7 +1396,15 @@ async def get_chart_preview(
error=OAUTH2_CONFIG_ERROR_MESSAGE,
error_type="OAUTH2_REDIRECT_ERROR",
)
except Exception as e:
except (
SupersetException,
CommandException,
SQLAlchemyError,
KeyError,
ValueError,
TypeError,
AttributeError,
) as e:
await ctx.error(
"Chart preview generation failed: identifier=%s, error=%s, error_type=%s"
% (

View File

@@ -22,7 +22,6 @@ MCP tool: update_chart
import logging
import time
from typing import Any
from urllib.parse import parse_qs, urlparse
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
@@ -31,6 +30,10 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import (
extract_form_data_key_from_url,
find_chart_by_identifier,
)
from superset.mcp_service.chart.chart_utils import (
analyze_chart_capabilities,
analyze_chart_semantics,
@@ -54,18 +57,6 @@ from superset.utils import json
logger = logging.getLogger(__name__)
def _find_chart(identifier: int | str) -> Any | None:
    """Look up a chart through ``ChartDAO`` by numeric ID or by UUID.

    Integers and all-digit strings are looked up as IDs via
    ``ChartDAO.find_by_id``; any other value is looked up against the
    ``uuid`` column. Returns whatever the DAO returns.
    """
    from superset.daos.chart import ChartDAO

    if isinstance(identifier, int):
        return ChartDAO.find_by_id(identifier)
    if isinstance(identifier, str) and identifier.isdigit():
        # Digit-only strings are numeric IDs passed as text.
        return ChartDAO.find_by_id(int(identifier))
    return ChartDAO.find_by_id(identifier, id_column="uuid")
def _validation_error_response(message: str, details: str) -> GenerateChartResponse:
return GenerateChartResponse.model_validate(
{
@@ -277,7 +268,7 @@ async def update_chart( # noqa: C901
try:
with event_logger.log_context(action="mcp.update_chart.chart_lookup"):
chart = _find_chart(request.identifier)
chart = find_chart_by_identifier(request.identifier)
if not chart:
return GenerateChartResponse.model_validate(
@@ -412,15 +403,21 @@ async def update_chart( # noqa: C901
if hasattr(preview_result, "content"):
previews[format_type] = preview_result.content
except Exception as e:
except (
OAuth2RedirectError,
OAuth2Error,
CommandException,
SQLAlchemyError,
KeyError,
ValueError,
TypeError,
AttributeError,
) as e:
logger.warning("Preview generation failed: %s", e)
# Fallback: extract form_data_key from explore_url if not set
if form_data_key is None and explore_url and "form_data_key=" in explore_url:
parsed = urlparse(explore_url)
values = parse_qs(parsed.query).get("form_data_key")
if values:
form_data_key = values[0]
if form_data_key is None:
form_data_key = extract_form_data_key_from_url(explore_url)
chart_id = updated_chart.id if saved and updated_chart else chart.id
chart_uuid = (

View File

@@ -24,10 +24,13 @@ import time
from typing import Any, Dict
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.exceptions import OAuth2Error, OAuth2RedirectError
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
from superset.mcp_service.chart.chart_utils import (
analyze_chart_capabilities,
analyze_chart_semantics,
@@ -123,9 +126,19 @@ def update_chart_preview(
explore_url = generate_explore_link(request.dataset_id, new_form_data)
# Extract new form_data_key from the explore URL
new_form_data_key = None
if "form_data_key=" in explore_url:
new_form_data_key = explore_url.split("form_data_key=")[1].split("&")[0]
new_form_data_key = extract_form_data_key_from_url(explore_url)
if not new_form_data_key:
return {
"chart": None,
"error": {
"error_type": "PreviewError",
"message": "Failed to generate preview: missing form_data_key",
"details": "The explore URL did not contain a form_data_key",
},
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
# Generate semantic analysis
@@ -199,7 +212,15 @@ def update_chart_preview(
"error": OAUTH2_CONFIG_ERROR_MESSAGE,
"success": False,
}
except Exception as e:
except (
SupersetException,
CommandException,
SQLAlchemyError,
KeyError,
ValueError,
TypeError,
AttributeError,
) as e:
execution_time = int((time.time() - start_time) * 1000)
return {
"chart": None,

View File

@@ -23,12 +23,12 @@ chart configuration.
"""
from typing import Any, Dict
from urllib.parse import parse_qs, urlparse
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
from superset.mcp_service.chart.chart_utils import (
generate_explore_link as generate_url,
map_config_to_form_data,
@@ -175,14 +175,8 @@ async def generate_explore_link(
dataset_id=request.dataset_id, form_data=form_data
)
# Extract form_data_key from the explore URL using proper URL parsing
form_data_key = None
if explore_url:
parsed = urlparse(explore_url)
query_params = parse_qs(parsed.query)
form_data_key_list = query_params.get("form_data_key", [])
if form_data_key_list:
form_data_key = form_data_key_list[0]
# Extract form_data_key from the explore URL
form_data_key = extract_form_data_key_from_url(explore_url)
await ctx.report_progress(4, 4, "URL generation complete")
await ctx.info(

View File

@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest.mock import MagicMock, patch
from superset.mcp_service.chart.chart_helpers import (
extract_form_data_key_from_url,
find_chart_by_identifier,
get_cached_form_data,
)
def test_extract_form_data_key_from_url_with_key():
    """The key is extracted when it is the first query parameter."""
    explore_url = "http://localhost:8088/explore/?form_data_key=abc123&slice_id=1"
    extracted = extract_form_data_key_from_url(explore_url)
    assert extracted == "abc123"
def test_extract_form_data_key_from_url_no_key():
    """A URL whose query string has no ``form_data_key`` yields ``None``."""
    explore_url = "http://localhost:8088/explore/?slice_id=1"
    extracted = extract_form_data_key_from_url(explore_url)
    assert extracted is None
def test_extract_form_data_key_from_url_none():
    """Passing ``None`` instead of a URL yields ``None``."""
    extracted = extract_form_data_key_from_url(None)
    assert extracted is None
def test_extract_form_data_key_from_url_empty():
    """An empty URL string yields ``None``."""
    extracted = extract_form_data_key_from_url("")
    assert extracted is None
def test_extract_form_data_key_from_url_multiple_params():
    """The key is extracted when it sits between other query parameters."""
    link = "http://localhost:8088/explore/?slice_id=5&form_data_key=xyz789&other=val"
    extracted = extract_form_data_key_from_url(link)
    assert extracted == "xyz789"
@patch("superset.daos.chart.ChartDAO.find_by_id")
def test_find_chart_by_identifier_int(mock_find):
    """An integer identifier is forwarded to ``ChartDAO.find_by_id`` as-is."""
    # Arrange: the patched DAO hands back a stand-in chart with id 42.
    expected_chart = MagicMock(id=42)
    mock_find.return_value = expected_chart

    # Act
    found = find_chart_by_identifier(42)

    # Assert: a single positional lookup, and the DAO result is returned.
    mock_find.assert_called_once_with(42)
    assert found == expected_chart
@patch("superset.daos.chart.ChartDAO.find_by_id")
def test_find_chart_by_identifier_str_digit(mock_find):
    """A digits-only string is expected to reach the DAO as an ``int``."""
    expected_chart = MagicMock()
    mock_find.return_value = expected_chart

    found = find_chart_by_identifier("123")

    # The DAO must receive the integer 123, not the string "123".
    mock_find.assert_called_once_with(123)
    assert found == expected_chart
@patch("superset.daos.chart.ChartDAO.find_by_id")
def test_find_chart_by_identifier_uuid(mock_find):
    """A UUID-shaped string is expected to be looked up via ``id_column="uuid"``."""
    expected_chart = MagicMock()
    mock_find.return_value = expected_chart
    chart_uuid = "a1b2c3d4-5678-90ab-cdef-1234567890ab"

    found = find_chart_by_identifier(chart_uuid)

    # The UUID is passed through untouched, with the uuid column selected.
    mock_find.assert_called_once_with(chart_uuid, id_column="uuid")
    assert found == expected_chart
@patch("superset.daos.chart.ChartDAO.find_by_id")
def test_find_chart_by_identifier_not_found(mock_find):
    """``None`` from the DAO is propagated to the caller unchanged."""
    mock_find.return_value = None

    result = find_chart_by_identifier(999)

    # Verify the lookup really happened; otherwise this test would also pass
    # for an implementation that returns None without consulting the DAO.
    mock_find.assert_called_once_with(999)
    assert result is None
@patch(
    "superset.commands.explore.form_data.get.GetFormDataCommand.run",
    return_value='{"viz_type": "table"}',
)
@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__")
def test_get_cached_form_data_success(mock_init, mock_run):
    """The value produced by ``GetFormDataCommand.run`` is returned as-is."""
    # A patched ``__init__`` must return None or instantiation raises TypeError.
    mock_init.return_value = None

    cached = get_cached_form_data("test_key")

    assert cached == '{"viz_type": "table"}'
@patch(
    "superset.commands.explore.form_data.get.GetFormDataCommand.run",
    side_effect=KeyError("not found"),
)
@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__")
def test_get_cached_form_data_key_error(mock_init, mock_run):
    """A ``KeyError`` raised by the command is expected to surface as ``None``."""
    # A patched ``__init__`` must return None or instantiation raises TypeError.
    mock_init.return_value = None

    cached = get_cached_form_data("bad_key")

    assert cached is None

View File

@@ -26,6 +26,7 @@ import pytest
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier
from superset.mcp_service.chart.chart_utils import DatasetValidationResult
from superset.mcp_service.chart.schemas import (
AxisConfig,
@@ -40,7 +41,6 @@ from superset.mcp_service.chart.schemas import (
from superset.mcp_service.chart.tool.update_chart import (
_build_preview_form_data,
_build_update_payload,
_find_chart,
)
# The __init__.py re-exports the update_chart *function*, so a plain
@@ -519,7 +519,7 @@ class TestUpdateChartDatasetAccess:
class TestFindChart:
"""Tests for _find_chart helper."""
"""Tests for find_chart_by_identifier helper (moved to chart_helpers)."""
@patch("superset.daos.chart.ChartDAO.find_by_id")
def test_find_chart_by_numeric_id(self, mock_find):
@@ -527,7 +527,7 @@ class TestFindChart:
mock_chart = Mock()
mock_find.return_value = mock_chart
result = _find_chart(42)
result = find_chart_by_identifier(42)
mock_find.assert_called_once_with(42)
assert result is mock_chart
@@ -538,7 +538,7 @@ class TestFindChart:
mock_chart = Mock()
mock_find.return_value = mock_chart
result = _find_chart("123")
result = find_chart_by_identifier("123")
mock_find.assert_called_once_with(123)
assert result is mock_chart
@@ -550,7 +550,7 @@ class TestFindChart:
mock_find.return_value = mock_chart
uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
result = _find_chart(uuid)
result = find_chart_by_identifier(uuid)
mock_find.assert_called_once_with(uuid, id_column="uuid")
assert result is mock_chart
@@ -560,7 +560,7 @@ class TestFindChart:
"""Returns None when chart is not found."""
mock_find.return_value = None
result = _find_chart(999)
result = find_chart_by_identifier(999)
assert result is None