mirror of
https://github.com/apache/superset.git
synced 2026-06-07 08:39:25 +00:00
chore(mcp): extract shared chart helpers and ASCII rendering into separate modules (#39438)
This commit is contained in:
863
superset/mcp_service/chart/ascii_charts.py
Normal file
863
superset/mcp_service/chart/ascii_charts.py
Normal file
@@ -0,0 +1,863 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
ASCII chart rendering functions for MCP chart previews.
|
||||
|
||||
Pure rendering functions that take data in and return strings out.
|
||||
No dependencies on MCP tools, FastMCP Context, or request objects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_ascii_chart(
|
||||
data: list[Any], chart_type: str, width: int = 80, height: int = 20
|
||||
) -> str:
|
||||
"""Generate ASCII art chart from data."""
|
||||
if not data:
|
||||
return "No data available for ASCII chart"
|
||||
|
||||
try:
|
||||
# Clamp to safe minimums to prevent negative plot sizes
|
||||
width = max(width, 21)
|
||||
height = max(height, 9)
|
||||
|
||||
logger.debug(
|
||||
"generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data)
|
||||
)
|
||||
|
||||
if chart_type in ["bar", "column", "echarts_timeseries_bar"]:
|
||||
return _generate_ascii_bar_chart(data, width, height)
|
||||
elif chart_type in ["line", "echarts_timeseries_line"]:
|
||||
return _generate_ascii_line_chart(data, width, height)
|
||||
elif chart_type in ["scatter", "echarts_timeseries_scatter"]:
|
||||
return _generate_ascii_scatter_chart(data, width, height)
|
||||
else:
|
||||
logger.debug(
|
||||
"Unsupported chart type '%s', falling back to table", chart_type
|
||||
)
|
||||
return generate_ascii_table(data, width)
|
||||
except (TypeError, ValueError, KeyError, IndexError) as e:
|
||||
logger.error("ASCII chart generation failed: %s", e, exc_info=True)
|
||||
return "ASCII chart generation failed"
|
||||
|
||||
|
||||
def _generate_ascii_bar_chart(data: list[Any], width: int, height: int) -> str:
|
||||
"""Generate enhanced ASCII bar chart with horizontal and vertical options."""
|
||||
if not data:
|
||||
return "No data for bar chart"
|
||||
|
||||
# Extract numeric values for bars
|
||||
values = []
|
||||
labels = []
|
||||
|
||||
for row in data[:12]: # Increased limit for better charts
|
||||
if isinstance(row, dict):
|
||||
# Find numeric and string values
|
||||
numeric_val = None
|
||||
label_val = None
|
||||
|
||||
for _key, val in row.items():
|
||||
if isinstance(val, (int, float)) and numeric_val is None:
|
||||
numeric_val = val
|
||||
elif isinstance(val, str) and label_val is None:
|
||||
label_val = val
|
||||
|
||||
if numeric_val is not None:
|
||||
values.append(numeric_val)
|
||||
labels.append(label_val or f"Item {len(values)}")
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for bar chart"
|
||||
|
||||
# Decide between horizontal and vertical based on label lengths
|
||||
avg_label_length = sum(len(str(label)) for label in labels) / len(labels)
|
||||
use_horizontal = avg_label_length > 8 or len(values) > 8
|
||||
|
||||
if use_horizontal:
|
||||
return _generate_horizontal_bar_chart(values, labels, width)
|
||||
else:
|
||||
return _generate_vertical_bar_chart(values, labels, width, height)
|
||||
|
||||
|
||||
def _generate_horizontal_bar_chart(
|
||||
values: list[float], labels: list[str], width: int
|
||||
) -> str:
|
||||
"""Generate horizontal ASCII bar chart."""
|
||||
lines = []
|
||||
lines.append("📊 Horizontal Bar Chart")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
max_val = max(values) if values else 1
|
||||
min_val = min(values) if values else 0
|
||||
max_bar_width = min(40, width - 20) # Leave space for labels and values
|
||||
|
||||
# Add scale indicator
|
||||
lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}")
|
||||
lines.append("")
|
||||
|
||||
for label, value in zip(labels, values, strict=False):
|
||||
# Calculate bar length
|
||||
if max_val > min_val:
|
||||
normalized = (value - min_val) / (max_val - min_val)
|
||||
bar_length = max(1, int(normalized * max_bar_width))
|
||||
else:
|
||||
bar_length = 1
|
||||
|
||||
# Create bar with gradient effect
|
||||
bar = _create_gradient_bar(bar_length, value, max_val)
|
||||
|
||||
# Format value
|
||||
if abs(value) >= 1000000:
|
||||
value_str = f"{value / 1000000:.1f}M"
|
||||
elif abs(value) >= 1000:
|
||||
value_str = f"{value / 1000:.1f}K"
|
||||
else:
|
||||
value_str = f"{value:.1f}"
|
||||
|
||||
# Truncate label if too long
|
||||
display_label = label[:15] if len(label) > 15 else label
|
||||
lines.append(f"{display_label:>15} ▐{bar:<{max_bar_width}} {value_str}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_vertical_bar_chart( # noqa: C901
|
||||
values: list[float], labels: list[str], width: int, height: int
|
||||
) -> str:
|
||||
"""Generate vertical ASCII bar chart."""
|
||||
lines = []
|
||||
lines.append("📊 Vertical Bar Chart")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
max_val = max(values) if values else 1
|
||||
min_val = min(values) if values else 0
|
||||
chart_height = min(15, height - 8) # Leave space for title and labels
|
||||
|
||||
# Create the chart grid
|
||||
grid = []
|
||||
for _ in range(chart_height):
|
||||
grid.append([" "] * len(values))
|
||||
|
||||
# Fill the bars
|
||||
for col, value in enumerate(values):
|
||||
if max_val > min_val:
|
||||
normalized = (value - min_val) / (max_val - min_val)
|
||||
bar_height = max(1, int(normalized * chart_height))
|
||||
else:
|
||||
bar_height = 1
|
||||
|
||||
# Fill from bottom up
|
||||
for row_idx in range(chart_height - bar_height, chart_height):
|
||||
if row_idx < len(grid):
|
||||
# Use different characters for height effect
|
||||
if row_idx == chart_height - bar_height:
|
||||
grid[row_idx][col] = "▀" # Top of bar
|
||||
elif row_idx == chart_height - 1:
|
||||
grid[row_idx][col] = "█" # Bottom of bar
|
||||
else:
|
||||
grid[row_idx][col] = "█" # Middle of bar
|
||||
|
||||
# Add Y-axis scale
|
||||
for i, row_data in enumerate(grid):
|
||||
y_val = (
|
||||
max_val - (i / (chart_height - 1)) * (max_val - min_val)
|
||||
if chart_height > 1
|
||||
else max_val
|
||||
)
|
||||
if abs(y_val) >= 1000:
|
||||
y_label = f"{y_val:.0f}"
|
||||
else:
|
||||
y_label = f"{y_val:.1f}"
|
||||
lines.append(f"{y_label:>6} ┤ " + "".join(f"{cell:^3}" for cell in row_data))
|
||||
|
||||
# Add X-axis
|
||||
lines.append(" └" + "───" * len(values))
|
||||
|
||||
# Add labels
|
||||
label_line = " "
|
||||
for label in labels:
|
||||
short_label = label[:3] if len(label) > 3 else label
|
||||
label_line += f"{short_label:^3}"
|
||||
lines.append(label_line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _create_gradient_bar(length: int, value: float, max_val: float) -> str:
|
||||
"""Create a gradient bar with visual effects."""
|
||||
if length <= 0:
|
||||
return ""
|
||||
|
||||
# Create gradient effect based on value intensity
|
||||
intensity = value / max_val if max_val > 0 else 0
|
||||
|
||||
if intensity > 0.8:
|
||||
# High values - solid bars
|
||||
return "█" * length
|
||||
elif intensity > 0.6:
|
||||
# Medium-high values - mostly solid with some texture
|
||||
return "█" * (length - 1) + "▉" if length > 1 else "█"
|
||||
elif intensity > 0.4:
|
||||
# Medium values - mixed texture
|
||||
return "▊" * length
|
||||
elif intensity > 0.2:
|
||||
# Low-medium values - lighter texture
|
||||
return "▋" * length
|
||||
else:
|
||||
# Low values - lightest texture
|
||||
return "▌" * length
|
||||
|
||||
|
||||
def _generate_ascii_line_chart(data: list[Any], width: int, height: int) -> str:
|
||||
"""Generate enhanced ASCII line chart with trend analysis."""
|
||||
if not data:
|
||||
return "No data for line chart"
|
||||
|
||||
lines = []
|
||||
lines.append("📈 Line Chart with Trend Analysis")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
# Extract values and labels for plotting
|
||||
values, labels = _extract_time_series_data(data)
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for line chart"
|
||||
|
||||
# Generate enhanced line chart
|
||||
if len(values) >= 3:
|
||||
lines.extend(_create_enhanced_line_chart(values, labels, width, height))
|
||||
else:
|
||||
# Fallback to sparkline for small datasets
|
||||
sparkline_data = _create_sparkline(values)
|
||||
lines.extend(sparkline_data)
|
||||
|
||||
# Add trend analysis
|
||||
trend_analysis = _analyze_trend(values)
|
||||
lines.append("")
|
||||
lines.append("📊 Trend Analysis:")
|
||||
lines.extend(trend_analysis)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_time_series_data(data: list[Any]) -> tuple[list[float], list[str]]:
|
||||
"""Extract time series data with labels."""
|
||||
values = []
|
||||
labels = []
|
||||
|
||||
for row in data[:20]: # Limit points for readability
|
||||
if isinstance(row, dict):
|
||||
# Find the first numeric value and first string/date value
|
||||
numeric_val = None
|
||||
label_val = None
|
||||
|
||||
for key, val in row.items():
|
||||
if isinstance(val, (int, float)) and numeric_val is None:
|
||||
numeric_val = val
|
||||
elif isinstance(val, str) and label_val is None:
|
||||
# Use the key name if it looks like a date/time field
|
||||
if any(
|
||||
date_word in key.lower()
|
||||
for date_word in ["date", "time", "month", "day", "year"]
|
||||
):
|
||||
label_val = str(val)[:10] # Truncate long dates
|
||||
else:
|
||||
label_val = str(val)[:8] # Truncate long strings
|
||||
|
||||
if numeric_val is not None:
|
||||
values.append(numeric_val)
|
||||
labels.append(label_val or f"P{len(values)}")
|
||||
|
||||
return values, labels
|
||||
|
||||
|
||||
def _create_enhanced_line_chart(
|
||||
values: list[float], labels: list[str], width: int, height: int
|
||||
) -> list[str]:
|
||||
"""Create an enhanced ASCII line chart with better visualization."""
|
||||
lines = []
|
||||
chart_width = min(50, width - 15)
|
||||
chart_height = min(12, height - 8)
|
||||
|
||||
if len(values) < 2:
|
||||
return ["Insufficient data for line chart"]
|
||||
|
||||
# Normalize values to chart height
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
val_range = max_val - min_val if max_val != min_val else 1
|
||||
|
||||
# Create chart grid
|
||||
grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)]
|
||||
|
||||
# Plot the line with connecting segments
|
||||
prev_x, prev_y = None, None
|
||||
|
||||
for i, value in enumerate(values):
|
||||
# Map to grid coordinates
|
||||
x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
|
||||
y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1))
|
||||
|
||||
# Ensure coordinates are within bounds
|
||||
x = max(0, min(x, chart_width - 1))
|
||||
y = max(0, min(y, chart_height - 1))
|
||||
|
||||
# Mark the point
|
||||
grid[y][x] = "●"
|
||||
|
||||
# Draw line segment to previous point
|
||||
if prev_x is not None and prev_y is not None:
|
||||
_draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height)
|
||||
|
||||
prev_x, prev_y = x, y
|
||||
|
||||
# Render the chart with Y-axis labels
|
||||
for i, row in enumerate(grid):
|
||||
y_val = (
|
||||
max_val - (i / (chart_height - 1)) * val_range
|
||||
if chart_height > 1
|
||||
else max_val
|
||||
)
|
||||
if abs(y_val) >= 1000:
|
||||
y_label = f"{y_val:.0f}"
|
||||
else:
|
||||
y_label = f"{y_val:.1f}"
|
||||
lines.append(f"{y_label:>8} ┤ " + "".join(row))
|
||||
|
||||
# Add X-axis
|
||||
lines.append(" └" + "─" * chart_width)
|
||||
|
||||
# Add selected X-axis labels (show every few labels to avoid crowding)
|
||||
label_line = " "
|
||||
step = max(1, len(labels) // 6) # Show max 6 labels
|
||||
for i in range(0, len(labels), step):
|
||||
pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
|
||||
# Add spacing to position the label
|
||||
while len(label_line) - 10 < pos:
|
||||
label_line += " "
|
||||
label_line += labels[i][:8]
|
||||
|
||||
lines.append(label_line)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _draw_line_segment(
|
||||
grid: list[list[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int
|
||||
) -> None:
|
||||
"""Draw a line segment between two points using Bresenham-like algorithm."""
|
||||
# Simple line drawing - connect points with line characters
|
||||
if x1 == x2: # Vertical line
|
||||
start_y, end_y = sorted([y1, y2])
|
||||
for y in range(start_y + 1, end_y):
|
||||
if 0 <= y < height and 0 <= x1 < width:
|
||||
grid[y][x1] = "│"
|
||||
elif y1 == y2: # Horizontal line
|
||||
start_x, end_x = sorted([x1, x2])
|
||||
for x in range(start_x + 1, end_x):
|
||||
if 0 <= y1 < height and 0 <= x < width:
|
||||
grid[y1][x] = "─"
|
||||
else: # Diagonal line - use simple interpolation
|
||||
steps = max(abs(x2 - x1), abs(y2 - y1))
|
||||
for step in range(1, steps):
|
||||
x = x1 + int((x2 - x1) * step / steps)
|
||||
y = y1 + int((y2 - y1) * step / steps)
|
||||
if 0 <= x < width and 0 <= y < height:
|
||||
if abs(x2 - x1) > abs(y2 - y1):
|
||||
grid[y][x] = "─"
|
||||
else:
|
||||
grid[y][x] = "│"
|
||||
|
||||
|
||||
def _analyze_trend(values: list[float]) -> list[str]:
|
||||
"""Analyze trend in the data."""
|
||||
if len(values) < 2:
|
||||
return ["• Insufficient data for trend analysis"]
|
||||
|
||||
analysis = []
|
||||
|
||||
# Calculate basic statistics
|
||||
first_val = values[0]
|
||||
last_val = values[-1]
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
avg_val = sum(values) / len(values)
|
||||
|
||||
# Overall trend
|
||||
if last_val > first_val * 1.1:
|
||||
trend_icon = "📈"
|
||||
trend_desc = "Strong upward trend"
|
||||
elif last_val > first_val * 1.05:
|
||||
trend_icon = "📊"
|
||||
trend_desc = "Moderate upward trend"
|
||||
elif last_val < first_val * 0.9:
|
||||
trend_icon = "📉"
|
||||
trend_desc = "Strong downward trend"
|
||||
elif last_val < first_val * 0.95:
|
||||
trend_icon = "📊"
|
||||
trend_desc = "Moderate downward trend"
|
||||
else:
|
||||
trend_icon = "➡️"
|
||||
trend_desc = "Relatively stable"
|
||||
|
||||
analysis.append(f"• {trend_icon} {trend_desc}")
|
||||
analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})")
|
||||
|
||||
# Volatility
|
||||
if len(values) >= 3:
|
||||
changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
|
||||
avg_change = sum(changes) / len(changes)
|
||||
volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low"
|
||||
analysis.append(f"• Volatility: {volatility}")
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
def _create_sparkline(values: list[float]) -> list[str]:
|
||||
"""Create sparkline visualization from values."""
|
||||
if len(values) <= 1:
|
||||
return []
|
||||
|
||||
max_val = max(values)
|
||||
min_val = min(values)
|
||||
range_val = max_val - min_val if max_val != min_val else 1
|
||||
|
||||
sparkline = ""
|
||||
for val in values:
|
||||
normalized = (val - min_val) / range_val
|
||||
if normalized < 0.2:
|
||||
sparkline += "▁"
|
||||
elif normalized < 0.4:
|
||||
sparkline += "▂"
|
||||
elif normalized < 0.6:
|
||||
sparkline += "▄"
|
||||
elif normalized < 0.8:
|
||||
sparkline += "▆"
|
||||
else:
|
||||
sparkline += "█"
|
||||
|
||||
# Safe formatting to avoid NaN display
|
||||
if _is_nan_value(min_val) or _is_nan_value(max_val):
|
||||
return ["Range: Unable to calculate from data", sparkline]
|
||||
else:
|
||||
return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline]
|
||||
|
||||
|
||||
def _is_nan_value(value: Any) -> bool:
|
||||
"""Check if a value is NaN or invalid."""
|
||||
try:
|
||||
return math.isnan(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return True
|
||||
|
||||
|
||||
def _generate_ascii_scatter_chart(data: list[Any], width: int, height: int) -> str:
|
||||
"""Generate ASCII scatter plot."""
|
||||
if not data:
|
||||
return "No data for scatter chart"
|
||||
|
||||
lines = []
|
||||
lines.append("ASCII Scatter Plot")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Extract data points
|
||||
x_values, y_values, x_column, y_column = _extract_scatter_data(data)
|
||||
|
||||
# Log debug info server-side only
|
||||
logger.debug(
|
||||
"Scatter chart: x_column=%s, y_column=%s, valid_pairs=%d",
|
||||
x_column,
|
||||
y_column,
|
||||
len(x_values),
|
||||
)
|
||||
|
||||
# Check if we have enough data
|
||||
if len(x_values) < 2:
|
||||
return generate_ascii_table(data, width)
|
||||
|
||||
# Add axis info
|
||||
lines.extend(_create_axis_info(x_values, y_values, x_column, y_column))
|
||||
|
||||
# Create and render grid
|
||||
grid = _create_scatter_grid(x_values, y_values, width, height)
|
||||
lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_scatter_data(
|
||||
data: list[Any],
|
||||
) -> tuple[list[float], list[float], str | None, str | None]:
|
||||
"""Extract X,Y data from scatter chart data."""
|
||||
x_values = []
|
||||
y_values = []
|
||||
x_column = None
|
||||
y_column = None
|
||||
numeric_columns = []
|
||||
|
||||
if data and isinstance(data[0], dict):
|
||||
# Find the first two numeric columns
|
||||
for key, val in data[0].items():
|
||||
if isinstance(val, (int, float)) and not (
|
||||
isinstance(val, float) and (val != val)
|
||||
): # Exclude NaN
|
||||
numeric_columns.append(key)
|
||||
|
||||
if len(numeric_columns) >= 2:
|
||||
x_column = numeric_columns[0]
|
||||
y_column = numeric_columns[1]
|
||||
|
||||
# Extract X,Y pairs
|
||||
for row in data[:50]: # Limit for ASCII display
|
||||
if isinstance(row, dict):
|
||||
x_val = row.get(x_column)
|
||||
y_val = row.get(y_column)
|
||||
# Check for valid numbers (not NaN)
|
||||
if (
|
||||
isinstance(x_val, (int, float))
|
||||
and isinstance(y_val, (int, float))
|
||||
and not (
|
||||
isinstance(x_val, float) and (x_val != x_val)
|
||||
) # Not NaN
|
||||
and not (isinstance(y_val, float) and (y_val != y_val))
|
||||
): # Not NaN
|
||||
x_values.append(x_val)
|
||||
y_values.append(y_val)
|
||||
|
||||
return x_values, y_values, x_column, y_column
|
||||
|
||||
|
||||
def _create_axis_info(
|
||||
x_values: list[float],
|
||||
y_values: list[float],
|
||||
x_column: str | None,
|
||||
y_column: str | None,
|
||||
) -> list[str]:
|
||||
"""Create axis information lines."""
|
||||
return [
|
||||
f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})",
|
||||
f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})",
|
||||
f"Showing {len(x_values)} data points",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _create_scatter_grid(
|
||||
x_values: list[float], y_values: list[float], width: int, height: int
|
||||
) -> list[list[str]]:
|
||||
"""Create and populate the scatter plot grid."""
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
# Normalize values to fit in grid
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
x_range = x_max - x_min if x_max != x_min else 1
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Create grid
|
||||
grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)]
|
||||
|
||||
# Plot points
|
||||
for x, y in zip(x_values, y_values, strict=False):
|
||||
try:
|
||||
grid_x = int(((x - x_min) / x_range) * (plot_width - 1))
|
||||
grid_y = int(((y - y_min) / y_range) * (plot_height - 1))
|
||||
except (ValueError, OverflowError):
|
||||
# Skip points that can't be converted to integers (NaN, inf, etc.)
|
||||
continue
|
||||
grid_y = plot_height - 1 - grid_y # Flip Y axis for display
|
||||
|
||||
if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height:
|
||||
if grid[grid_y][grid_x] == " ":
|
||||
grid[grid_y][grid_x] = "•"
|
||||
else:
|
||||
grid[grid_y][grid_x] = "█" # Multiple points
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def _render_scatter_grid(
|
||||
grid: list[list[str]],
|
||||
x_values: list[float],
|
||||
y_values: list[float],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> list[str]:
|
||||
"""Render the scatter plot grid with axes and labels."""
|
||||
lines = []
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Add Y-axis labels and plot
|
||||
for i, row in enumerate(grid):
|
||||
y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max
|
||||
y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
|
||||
lines.append(f"{y_label:>6} |{''.join(row)}")
|
||||
|
||||
# Add X-axis
|
||||
x_axis_line = " " * 7 + "+" + "-" * plot_width
|
||||
lines.append(x_axis_line)
|
||||
|
||||
# Add X-axis labels
|
||||
x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
|
||||
x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
|
||||
x_labels = (
|
||||
" " * 8
|
||||
+ x_left_label
|
||||
+ " " * (plot_width - len(x_left_label) - len(x_right_label))
|
||||
+ x_right_label
|
||||
)
|
||||
lines.append(x_labels)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def generate_ascii_table(data: list[Any], width: int) -> str:
|
||||
"""Generate enhanced ASCII table with better formatting."""
|
||||
if not data:
|
||||
return "No data for table"
|
||||
|
||||
lines = []
|
||||
lines.append("📋 Data Table")
|
||||
lines.append("═" * min(width, 70))
|
||||
|
||||
# Get column headers from first row
|
||||
if isinstance(data[0], dict):
|
||||
# Select best columns to display
|
||||
all_headers = list(data[0].keys())
|
||||
headers = _select_display_columns(all_headers, data, max_cols=6)
|
||||
|
||||
# Calculate optimal column widths
|
||||
col_widths = _calculate_column_widths(headers, data, width)
|
||||
|
||||
# Create enhanced header
|
||||
lines.append(_create_table_header(headers, col_widths))
|
||||
lines.append(_create_table_separator(col_widths))
|
||||
|
||||
# Add data rows with better formatting
|
||||
row_count = min(15, len(data)) # Show more rows
|
||||
for i, row in enumerate(data[:row_count]):
|
||||
formatted_row = _format_table_row(row, headers, col_widths)
|
||||
lines.append(formatted_row)
|
||||
|
||||
# Add separator every 5 rows for readability
|
||||
if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1:
|
||||
lines.append(_create_light_separator(col_widths))
|
||||
|
||||
# Add footer with stats
|
||||
lines.append(_create_table_separator(col_widths))
|
||||
lines.append(f"📊 Showing {row_count} of {len(data)} rows")
|
||||
|
||||
# Add column summaries for numeric columns
|
||||
numeric_summaries = _create_numeric_summaries(data, headers)
|
||||
if numeric_summaries:
|
||||
lines.append("")
|
||||
lines.append("📈 Numeric Summaries:")
|
||||
lines.extend(numeric_summaries)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _select_display_columns(
|
||||
all_headers: list[str], data: list[Any], max_cols: int = 6
|
||||
) -> list[str]:
|
||||
"""Select the most interesting columns to display."""
|
||||
if len(all_headers) <= max_cols:
|
||||
return all_headers
|
||||
|
||||
# Prioritize columns by interest level
|
||||
priority_scores = {}
|
||||
|
||||
for header in all_headers:
|
||||
score = 0
|
||||
header_lower = header.lower()
|
||||
|
||||
# Higher priority for common business fields
|
||||
if any(word in header_lower for word in ["name", "title", "id"]):
|
||||
score += 10
|
||||
if any(
|
||||
word in header_lower
|
||||
for word in ["amount", "price", "cost", "revenue", "sales"]
|
||||
):
|
||||
score += 8
|
||||
if any(word in header_lower for word in ["date", "time", "created", "updated"]):
|
||||
score += 6
|
||||
if any(word in header_lower for word in ["count", "total", "sum", "avg"]):
|
||||
score += 5
|
||||
|
||||
# Check data variety (more variety = more interesting)
|
||||
sample_values = [
|
||||
str(row.get(header, "")) for row in data[:10] if isinstance(row, dict)
|
||||
]
|
||||
unique_values = len(set(sample_values))
|
||||
if unique_values > 1:
|
||||
score += min(unique_values, 5)
|
||||
|
||||
priority_scores[header] = score
|
||||
|
||||
# Return top scoring columns
|
||||
sorted_headers = sorted(
|
||||
all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True
|
||||
)
|
||||
return sorted_headers[:max_cols]
|
||||
|
||||
|
||||
def _calculate_column_widths(
|
||||
headers: list[str], data: list[Any], total_width: int
|
||||
) -> list[int]:
|
||||
"""Calculate optimal column widths based on content."""
|
||||
if not headers:
|
||||
return []
|
||||
|
||||
# Start with minimum widths based on header lengths
|
||||
min_widths = [max(8, min(len(h) + 2, 15)) for h in headers]
|
||||
|
||||
# Check actual data to adjust widths
|
||||
for row in data[:10]: # Sample first 10 rows
|
||||
if isinstance(row, dict):
|
||||
for i, header in enumerate(headers):
|
||||
val = row.get(header, "")
|
||||
if isinstance(val, float):
|
||||
content_len = len(f"{val:.2f}")
|
||||
elif isinstance(val, int):
|
||||
content_len = len(str(val))
|
||||
else:
|
||||
content_len = len(str(val))
|
||||
|
||||
min_widths[i] = max(min_widths[i], min(content_len + 1, 20))
|
||||
|
||||
# Distribute remaining space proportionally
|
||||
used_width = sum(min_widths) + len(headers) * 3 # Account for separators
|
||||
available_width = min(total_width - 10, 80) # Leave margins
|
||||
|
||||
if used_width < available_width:
|
||||
# Distribute extra space
|
||||
extra_space = available_width - used_width
|
||||
for i in range(len(min_widths)):
|
||||
min_widths[i] += extra_space // len(min_widths)
|
||||
|
||||
return min_widths
|
||||
|
||||
|
||||
def _create_table_header(headers: list[str], widths: list[int]) -> str:
|
||||
"""Create formatted table header."""
|
||||
formatted_headers = []
|
||||
for header, width in zip(headers, widths, strict=False):
|
||||
# Truncate and center header
|
||||
display_header = header[: width - 2] if len(header) > width - 2 else header
|
||||
formatted_headers.append(f"{display_header:^{width}}")
|
||||
|
||||
return (
|
||||
"┌"
|
||||
+ "┬".join("─" * w for w in widths)
|
||||
+ "┐\n│"
|
||||
+ "│".join(formatted_headers)
|
||||
+ "│"
|
||||
)
|
||||
|
||||
|
||||
def _create_table_separator(widths: list[int]) -> str:
|
||||
"""Create table separator line."""
|
||||
return "├" + "┼".join("─" * w for w in widths) + "┤"
|
||||
|
||||
|
||||
def _create_light_separator(widths: list[int]) -> str:
|
||||
"""Create light separator line."""
|
||||
return "├" + "┼".join("┈" * w for w in widths) + "┤"
|
||||
|
||||
|
||||
def _format_table_row(
|
||||
row: dict[str, Any], headers: list[str], widths: list[int]
|
||||
) -> str:
|
||||
"""Format a single table row."""
|
||||
formatted_values = []
|
||||
|
||||
for header, width in zip(headers, widths, strict=False):
|
||||
val = row.get(header, "")
|
||||
|
||||
# Format value based on type
|
||||
if isinstance(val, float):
|
||||
if abs(val) >= 1000000:
|
||||
formatted_val = f"{val / 1000000:.1f}M"
|
||||
elif abs(val) >= 1000:
|
||||
formatted_val = f"{val / 1000:.1f}K"
|
||||
else:
|
||||
formatted_val = f"{val:.2f}"
|
||||
elif isinstance(val, int):
|
||||
if abs(val) >= 1000000:
|
||||
formatted_val = f"{val // 1000000}M"
|
||||
elif abs(val) >= 1000:
|
||||
formatted_val = f"{val // 1000}K"
|
||||
else:
|
||||
formatted_val = str(val)
|
||||
else:
|
||||
formatted_val = str(val)
|
||||
|
||||
# Truncate if too long
|
||||
if len(formatted_val) > width - 2:
|
||||
formatted_val = formatted_val[: width - 5] + "..."
|
||||
|
||||
# Align numbers right, text left
|
||||
if isinstance(val, (int, float)):
|
||||
formatted_values.append(f"{formatted_val:>{width - 2}}")
|
||||
else:
|
||||
formatted_values.append(f"{formatted_val:<{width - 2}}")
|
||||
|
||||
return "│ " + " │ ".join(formatted_values) + " │"
|
||||
|
||||
|
||||
def _create_numeric_summaries(data: list[Any], headers: list[str]) -> list[str]:
|
||||
"""Create summaries for numeric columns."""
|
||||
summaries = []
|
||||
|
||||
for header in headers:
|
||||
# Collect numeric values
|
||||
values = []
|
||||
for row in data:
|
||||
if isinstance(row, dict):
|
||||
val = row.get(header)
|
||||
if isinstance(val, (int, float)):
|
||||
values.append(val)
|
||||
|
||||
if len(values) >= 2:
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
avg_val = sum(values) / len(values)
|
||||
|
||||
if abs(avg_val) >= 1000:
|
||||
avg_str = f"{avg_val / 1000:.1f}K"
|
||||
else:
|
||||
avg_str = f"{avg_val:.1f}"
|
||||
|
||||
summaries.append(
|
||||
f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}"
|
||||
)
|
||||
|
||||
return summaries
|
||||
81
superset/mcp_service/chart/chart_helpers.py
Normal file
81
superset/mcp_service/chart/chart_helpers.py
Normal file
@@ -0,0 +1,81 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Shared helper functions for MCP chart tools.
|
||||
|
||||
This module contains reusable utility functions for common operations
|
||||
across chart tools: chart lookup, cached form data retrieval, and
|
||||
URL parameter extraction. Config mapping logic lives in chart_utils.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.slice import Slice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def find_chart_by_identifier(identifier: int | str) -> Slice | None:
|
||||
"""Find a chart by numeric ID or UUID string.
|
||||
|
||||
Accepts an integer ID, a string that looks like a digit (e.g. "123"),
|
||||
or a UUID string. Returns the Slice model instance or None.
|
||||
"""
|
||||
from superset.daos.chart import ChartDAO # avoid circular import
|
||||
|
||||
if isinstance(identifier, int) or (
|
||||
isinstance(identifier, str) and identifier.isdigit()
|
||||
):
|
||||
chart_id = int(identifier) if isinstance(identifier, str) else identifier
|
||||
return ChartDAO.find_by_id(chart_id)
|
||||
return ChartDAO.find_by_id(identifier, id_column="uuid")
|
||||
|
||||
|
||||
def get_cached_form_data(form_data_key: str) -> str | None:
|
||||
"""Retrieve form_data from cache using form_data_key.
|
||||
|
||||
Returns the JSON string of form_data if found, None otherwise.
|
||||
"""
|
||||
# avoid circular import — commands depend on app initialization
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.commands.explore.form_data.get import GetFormDataCommand
|
||||
from superset.commands.explore.form_data.parameters import CommandParameters
|
||||
|
||||
try:
|
||||
cmd_params = CommandParameters(key=form_data_key)
|
||||
return GetFormDataCommand(cmd_params).run()
|
||||
except (KeyError, ValueError, CommandException) as e:
|
||||
logger.warning("Failed to retrieve form_data from cache: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def extract_form_data_key_from_url(url: str | None) -> str | None:
|
||||
"""Extract the form_data_key query parameter from an explore URL.
|
||||
|
||||
Returns the form_data_key value or None if not found or URL is empty.
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
parsed = urlparse(url)
|
||||
values = parse_qs(parsed.query).get("form_data_key", [])
|
||||
return values[0] if values else None
|
||||
@@ -22,7 +22,6 @@ import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -32,6 +31,7 @@ from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.auth import has_dataset_access
|
||||
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
@@ -540,13 +540,8 @@ async def generate_chart( # noqa: C901
|
||||
explore_url = generate_explore_link(request.dataset_id, form_data)
|
||||
await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,))
|
||||
|
||||
# Extract form_data_key from the explore URL using proper URL parsing
|
||||
if explore_url:
|
||||
parsed = urlparse(explore_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
form_data_key_list = query_params.get("form_data_key", [])
|
||||
if form_data_key_list:
|
||||
form_data_key = form_data_key_list[0]
|
||||
# Extract form_data_key from the explore URL
|
||||
form_data_key = extract_form_data_key_from_url(explore_url)
|
||||
|
||||
# Compile check for preview-only mode
|
||||
# Validate dataset existence and user access before running queries
|
||||
|
||||
@@ -25,15 +25,19 @@ from typing import Any, Dict, List, TYPE_CHECKING
|
||||
|
||||
from fastmcp import Context
|
||||
from flask import current_app
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.slice import Slice
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.commands.explore.form_data.parameters import CommandParameters
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import (
|
||||
find_chart_by_identifier,
|
||||
get_cached_form_data,
|
||||
)
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartData,
|
||||
@@ -62,21 +66,6 @@ def _apply_extra_form_data(
|
||||
merge_extra_filters(form_data)
|
||||
|
||||
|
||||
def _get_cached_form_data(form_data_key: str) -> str | None:
|
||||
"""Retrieve form_data from cache using form_data_key.
|
||||
|
||||
Returns the JSON string of form_data if found, None otherwise.
|
||||
"""
|
||||
from superset.commands.explore.form_data.get import GetFormDataCommand
|
||||
|
||||
try:
|
||||
cmd_params = CommandParameters(key=form_data_key)
|
||||
return GetFormDataCommand(cmd_params).run()
|
||||
except (KeyError, ValueError, CommandException) as e:
|
||||
logger.warning("Failed to retrieve form_data from cache: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["data"],
|
||||
class_permission_name="Chart",
|
||||
@@ -127,7 +116,6 @@ async def get_chart_data( # noqa: C901
|
||||
|
||||
try:
|
||||
await ctx.report_progress(1, 4, "Looking up chart")
|
||||
from superset.daos.chart import ChartDAO
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
chart = None
|
||||
@@ -141,7 +129,7 @@ async def get_chart_data( # noqa: C901
|
||||
"No chart identifier - querying data from unsaved chart cache: "
|
||||
"form_data_key=%s" % (request.form_data_key,)
|
||||
)
|
||||
cached_form_data = _get_cached_form_data(request.form_data_key)
|
||||
cached_form_data = get_cached_form_data(request.form_data_key)
|
||||
if not cached_form_data:
|
||||
return ChartError(
|
||||
error="No cached chart data found for form_data_key. "
|
||||
@@ -166,25 +154,13 @@ async def get_chart_data( # noqa: C901
|
||||
|
||||
# Find the chart by identifier
|
||||
with event_logger.log_context(action="mcp.get_chart_data.chart_lookup"):
|
||||
if isinstance(request.identifier, int) or (
|
||||
isinstance(request.identifier, str) and request.identifier.isdigit()
|
||||
):
|
||||
chart_id = (
|
||||
int(request.identifier)
|
||||
if isinstance(request.identifier, str)
|
||||
else request.identifier
|
||||
await ctx.debug("Looking up chart: identifier=%s" % (request.identifier,))
|
||||
if request.identifier is None:
|
||||
return ChartError(
|
||||
error="Chart identifier is required",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
await ctx.debug(
|
||||
"Performing ID-based chart lookup: chart_id=%s" % (chart_id,)
|
||||
)
|
||||
chart = ChartDAO.find_by_id(chart_id)
|
||||
elif isinstance(request.identifier, str):
|
||||
await ctx.debug(
|
||||
"Performing UUID-based chart lookup: uuid=%s"
|
||||
% (request.identifier,)
|
||||
)
|
||||
# Try UUID lookup using DAO flexible method
|
||||
chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")
|
||||
chart = find_chart_by_identifier(request.identifier)
|
||||
|
||||
if not chart:
|
||||
await ctx.error("Chart not found: identifier=%s" % (request.identifier,))
|
||||
@@ -239,7 +215,7 @@ async def get_chart_data( # noqa: C901
|
||||
"Retrieving unsaved chart state from cache: form_data_key=%s"
|
||||
% (request.form_data_key,)
|
||||
)
|
||||
if cached_form_data := _get_cached_form_data(request.form_data_key):
|
||||
if cached_form_data := get_cached_form_data(request.form_data_key):
|
||||
try:
|
||||
parsed_form_data = utils_json.loads(cached_form_data)
|
||||
# Only use if it's actually a dict (not null, list, etc.)
|
||||
@@ -818,7 +794,15 @@ async def get_chart_data( # noqa: C901
|
||||
error=OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
error_type="OAUTH2_REDIRECT_ERROR",
|
||||
)
|
||||
except Exception as e:
|
||||
except (
|
||||
SupersetException,
|
||||
CommandException,
|
||||
SQLAlchemyError,
|
||||
KeyError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
await ctx.error(
|
||||
"Chart data retrieval failed: identifier=%s, error=%s, error_type=%s"
|
||||
% (
|
||||
|
||||
@@ -25,9 +25,8 @@ from fastmcp import Context
|
||||
from sqlalchemy.orm import subqueryload
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.commands.explore.form_data.parameters import CommandParameters
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import get_cached_form_data
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartError,
|
||||
@@ -41,26 +40,11 @@ from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_cached_form_data(form_data_key: str) -> str | None:
|
||||
"""Retrieve form_data from cache using form_data_key.
|
||||
|
||||
Returns the JSON string of form_data if found, None otherwise.
|
||||
"""
|
||||
from superset.commands.explore.form_data.get import GetFormDataCommand
|
||||
|
||||
try:
|
||||
cmd_params = CommandParameters(key=form_data_key)
|
||||
return GetFormDataCommand(cmd_params).run()
|
||||
except (KeyError, ValueError, CommandException) as e:
|
||||
logger.warning("Failed to retrieve form_data from cache: %s", e)
|
||||
return None
|
||||
|
||||
|
||||
def _build_unsaved_chart_info(form_data_key: str) -> ChartInfo | ChartError:
|
||||
"""Build a ChartInfo from cached form_data when no chart identifier exists."""
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
cached_form_data = _get_cached_form_data(form_data_key)
|
||||
cached_form_data = get_cached_form_data(form_data_key)
|
||||
if not cached_form_data:
|
||||
return ChartError(
|
||||
error="No cached chart data found for form_data_key. "
|
||||
@@ -94,7 +78,7 @@ def _apply_unsaved_state_override(result: ChartInfo, form_data_key: str) -> None
|
||||
"""Override a ChartInfo's form_data with cached unsaved state."""
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
if cached_form_data := _get_cached_form_data(form_data_key):
|
||||
if cached_form_data := get_cached_form_data(form_data_key):
|
||||
try:
|
||||
result.form_data = utils_json.loads(cached_form_data)
|
||||
result.form_data_key = form_data_key
|
||||
|
||||
@@ -23,11 +23,17 @@ import logging
|
||||
from typing import Any, Dict, List, Protocol
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.ascii_charts import (
|
||||
generate_ascii_chart,
|
||||
generate_ascii_table,
|
||||
)
|
||||
from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier
|
||||
from superset.mcp_service.chart.chart_utils import validate_chart_dataset
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AccessibilityMetadata,
|
||||
@@ -262,7 +268,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
|
||||
if result and "queries" in result and len(result["queries"]) > 0:
|
||||
data = result["queries"][0].get("data", [])
|
||||
|
||||
table_data = _generate_ascii_table(data, 120)
|
||||
table_data = generate_ascii_table(data, 120)
|
||||
|
||||
return TablePreview(
|
||||
table_data=table_data,
|
||||
@@ -964,874 +970,6 @@ class PreviewFormatGenerator:
|
||||
return strategy.generate()
|
||||
|
||||
|
||||
def generate_ascii_chart(
|
||||
data: List[Any], chart_type: str, width: int = 80, height: int = 20
|
||||
) -> str:
|
||||
"""Generate ASCII art chart from data."""
|
||||
if not data or len(data) == 0:
|
||||
return "No data available for ASCII chart"
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
"generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data)
|
||||
)
|
||||
|
||||
# Generate appropriate ASCII chart based on type
|
||||
if chart_type in ["bar", "column", "echarts_timeseries_bar"]:
|
||||
logger.info("Generating bar chart")
|
||||
return _generate_ascii_bar_chart(data, width, height)
|
||||
elif chart_type in ["line", "echarts_timeseries_line"]:
|
||||
logger.info("Generating line chart")
|
||||
return _generate_ascii_line_chart(data, width, height)
|
||||
elif chart_type in ["scatter", "echarts_timeseries_scatter"]:
|
||||
logger.info("Generating scatter chart")
|
||||
return _generate_ascii_scatter_chart(data, width, height)
|
||||
else:
|
||||
# Default to table format for unsupported chart types
|
||||
logger.info(
|
||||
"Unsupported chart type '%s', falling back to table", chart_type
|
||||
)
|
||||
return _generate_ascii_table(data, width)
|
||||
except (TypeError, ValueError, KeyError, IndexError) as e:
|
||||
logger.error("ASCII chart generation failed: %s", e)
|
||||
import traceback
|
||||
|
||||
logger.error("Traceback: %s", traceback.format_exc())
|
||||
return f"ASCII chart generation failed: {str(e)}"
|
||||
|
||||
|
||||
def _generate_ascii_bar_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate enhanced ASCII bar chart with horizontal and vertical options."""
|
||||
if not data:
|
||||
return "No data for bar chart"
|
||||
|
||||
# Extract numeric values for bars
|
||||
values = []
|
||||
labels = []
|
||||
|
||||
for row in data[:12]: # Increased limit for better charts
|
||||
if isinstance(row, dict):
|
||||
# Find numeric and string values
|
||||
numeric_val = None
|
||||
label_val = None
|
||||
|
||||
for _key, val in row.items():
|
||||
if isinstance(val, (int, float)) and numeric_val is None:
|
||||
numeric_val = val
|
||||
elif isinstance(val, str) and label_val is None:
|
||||
label_val = val
|
||||
|
||||
if numeric_val is not None:
|
||||
values.append(numeric_val)
|
||||
labels.append(label_val or f"Item {len(values)}")
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for bar chart"
|
||||
|
||||
# Decide between horizontal and vertical based on label lengths
|
||||
avg_label_length = sum(len(str(label)) for label in labels) / len(labels)
|
||||
use_horizontal = avg_label_length > 8 or len(values) > 8
|
||||
|
||||
if use_horizontal:
|
||||
return _generate_horizontal_bar_chart(values, labels, width)
|
||||
else:
|
||||
return _generate_vertical_bar_chart(values, labels, width, height)
|
||||
|
||||
|
||||
def _generate_horizontal_bar_chart(
|
||||
values: List[float], labels: List[str], width: int
|
||||
) -> str:
|
||||
"""Generate horizontal ASCII bar chart."""
|
||||
lines = []
|
||||
lines.append("📊 Horizontal Bar Chart")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
max_val = max(values) if values else 1
|
||||
min_val = min(values) if values else 0
|
||||
max_bar_width = min(40, width - 20) # Leave space for labels and values
|
||||
|
||||
# Add scale indicator
|
||||
lines.append(f"Scale: {min_val:.1f} ────────────── {max_val:.1f}")
|
||||
lines.append("")
|
||||
|
||||
for label, value in zip(labels, values, strict=False):
|
||||
# Calculate bar length
|
||||
if max_val > min_val:
|
||||
normalized = (value - min_val) / (max_val - min_val)
|
||||
bar_length = max(1, int(normalized * max_bar_width))
|
||||
else:
|
||||
bar_length = 1
|
||||
|
||||
# Create bar with gradient effect
|
||||
bar = _create_gradient_bar(bar_length, value, max_val)
|
||||
|
||||
# Format value
|
||||
if abs(value) >= 1000000:
|
||||
value_str = f"{value / 1000000:.1f}M"
|
||||
elif abs(value) >= 1000:
|
||||
value_str = f"{value / 1000:.1f}K"
|
||||
else:
|
||||
value_str = f"{value:.1f}"
|
||||
|
||||
# Truncate label if too long
|
||||
display_label = label[:15] if len(label) > 15 else label
|
||||
lines.append(f"{display_label:>15} ▐{bar:<{max_bar_width}} {value_str}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_vertical_bar_chart( # noqa: C901
|
||||
values: List[float], labels: List[str], width: int, height: int
|
||||
) -> str:
|
||||
"""Generate vertical ASCII bar chart."""
|
||||
lines = []
|
||||
lines.append("📊 Vertical Bar Chart")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
max_val = max(values) if values else 1
|
||||
min_val = min(values) if values else 0
|
||||
chart_height = min(15, height - 8) # Leave space for title and labels
|
||||
|
||||
# Create the chart grid
|
||||
grid = []
|
||||
for _ in range(chart_height):
|
||||
grid.append([" "] * len(values))
|
||||
|
||||
# Fill the bars
|
||||
for col, value in enumerate(values):
|
||||
if max_val > min_val:
|
||||
normalized = (value - min_val) / (max_val - min_val)
|
||||
bar_height = max(1, int(normalized * chart_height))
|
||||
else:
|
||||
bar_height = 1
|
||||
|
||||
# Fill from bottom up
|
||||
for row_idx in range(chart_height - bar_height, chart_height):
|
||||
if row_idx < len(grid):
|
||||
# Use different characters for height effect
|
||||
if row_idx == chart_height - bar_height:
|
||||
grid[row_idx][col] = "▀" # Top of bar
|
||||
elif row_idx == chart_height - 1:
|
||||
grid[row_idx][col] = "█" # Bottom of bar
|
||||
else:
|
||||
grid[row_idx][col] = "█" # Middle of bar
|
||||
|
||||
# Add Y-axis scale
|
||||
for i, row_data in enumerate(grid):
|
||||
y_val = (
|
||||
max_val - (i / (chart_height - 1)) * (max_val - min_val)
|
||||
if chart_height > 1
|
||||
else max_val
|
||||
)
|
||||
if abs(y_val) >= 1000:
|
||||
y_label = f"{y_val:.0f}"
|
||||
else:
|
||||
y_label = f"{y_val:.1f}"
|
||||
lines.append(f"{y_label:>6} ┤ " + "".join(f"{cell:^3}" for cell in row_data))
|
||||
|
||||
# Add X-axis
|
||||
lines.append(" └" + "───" * len(values))
|
||||
|
||||
# Add labels
|
||||
label_line = " "
|
||||
for label in labels:
|
||||
short_label = label[:3] if len(label) > 3 else label
|
||||
label_line += f"{short_label:^3}"
|
||||
lines.append(label_line)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _create_gradient_bar(length: int, value: float, max_val: float) -> str:
|
||||
"""Create a gradient bar with visual effects."""
|
||||
if length <= 0:
|
||||
return ""
|
||||
|
||||
# Create gradient effect based on value intensity
|
||||
intensity = value / max_val if max_val > 0 else 0
|
||||
|
||||
if intensity > 0.8:
|
||||
# High values - solid bars
|
||||
return "█" * length
|
||||
elif intensity > 0.6:
|
||||
# Medium-high values - mostly solid with some texture
|
||||
return "█" * (length - 1) + "▉" if length > 1 else "█"
|
||||
elif intensity > 0.4:
|
||||
# Medium values - mixed texture
|
||||
return "▊" * length
|
||||
elif intensity > 0.2:
|
||||
# Low-medium values - lighter texture
|
||||
return "▋" * length
|
||||
else:
|
||||
# Low values - lightest texture
|
||||
return "▌" * length
|
||||
|
||||
|
||||
def _generate_ascii_line_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate enhanced ASCII line chart with trend analysis."""
|
||||
if not data:
|
||||
return "No data for line chart"
|
||||
|
||||
lines = []
|
||||
lines.append("📈 Line Chart with Trend Analysis")
|
||||
lines.append("═" * min(width, 60))
|
||||
|
||||
# Extract values and labels for plotting
|
||||
values, labels = _extract_time_series_data(data)
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for line chart"
|
||||
|
||||
# Generate enhanced line chart
|
||||
if len(values) >= 3:
|
||||
lines.extend(_create_enhanced_line_chart(values, labels, width, height))
|
||||
else:
|
||||
# Fallback to sparkline for small datasets
|
||||
sparkline_data = _create_sparkline(values)
|
||||
lines.extend(sparkline_data)
|
||||
|
||||
# Add trend analysis
|
||||
trend_analysis = _analyze_trend(values)
|
||||
lines.append("")
|
||||
lines.append("📊 Trend Analysis:")
|
||||
lines.extend(trend_analysis)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_time_series_data(data: List[Any]) -> tuple[List[float], List[str]]:
|
||||
"""Extract time series data with labels."""
|
||||
values = []
|
||||
labels = []
|
||||
|
||||
for row in data[:20]: # Limit points for readability
|
||||
if isinstance(row, dict):
|
||||
# Find the first numeric value and first string/date value
|
||||
numeric_val = None
|
||||
label_val = None
|
||||
|
||||
for key, val in row.items():
|
||||
if isinstance(val, (int, float)) and numeric_val is None:
|
||||
numeric_val = val
|
||||
elif isinstance(val, str) and label_val is None:
|
||||
# Use the key name if it looks like a date/time field
|
||||
if any(
|
||||
date_word in key.lower()
|
||||
for date_word in ["date", "time", "month", "day", "year"]
|
||||
):
|
||||
label_val = str(val)[:10] # Truncate long dates
|
||||
else:
|
||||
label_val = str(val)[:8] # Truncate long strings
|
||||
|
||||
if numeric_val is not None:
|
||||
values.append(numeric_val)
|
||||
labels.append(label_val or f"P{len(values)}")
|
||||
|
||||
return values, labels
|
||||
|
||||
|
||||
def _create_enhanced_line_chart(
|
||||
values: List[float], labels: List[str], width: int, height: int
|
||||
) -> List[str]:
|
||||
"""Create an enhanced ASCII line chart with better visualization."""
|
||||
lines = []
|
||||
chart_width = min(50, width - 15)
|
||||
chart_height = min(12, height - 8)
|
||||
|
||||
if len(values) < 2:
|
||||
return ["Insufficient data for line chart"]
|
||||
|
||||
# Normalize values to chart height
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
val_range = max_val - min_val if max_val != min_val else 1
|
||||
|
||||
# Create chart grid
|
||||
grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)]
|
||||
|
||||
# Plot the line with connecting segments
|
||||
prev_x, prev_y = None, None
|
||||
|
||||
for i, value in enumerate(values):
|
||||
# Map to grid coordinates
|
||||
x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
|
||||
y = chart_height - 1 - int(((value - min_val) / val_range) * (chart_height - 1))
|
||||
|
||||
# Ensure coordinates are within bounds
|
||||
x = max(0, min(x, chart_width - 1))
|
||||
y = max(0, min(y, chart_height - 1))
|
||||
|
||||
# Mark the point
|
||||
grid[y][x] = "●"
|
||||
|
||||
# Draw line segment to previous point
|
||||
if prev_x is not None and prev_y is not None:
|
||||
_draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height)
|
||||
|
||||
prev_x, prev_y = x, y
|
||||
|
||||
# Render the chart with Y-axis labels
|
||||
for i, row in enumerate(grid):
|
||||
y_val = (
|
||||
max_val - (i / (chart_height - 1)) * val_range
|
||||
if chart_height > 1
|
||||
else max_val
|
||||
)
|
||||
if abs(y_val) >= 1000:
|
||||
y_label = f"{y_val:.0f}"
|
||||
else:
|
||||
y_label = f"{y_val:.1f}"
|
||||
lines.append(f"{y_label:>8} ┤ " + "".join(row))
|
||||
|
||||
# Add X-axis
|
||||
lines.append(" └" + "─" * chart_width)
|
||||
|
||||
# Add selected X-axis labels (show every few labels to avoid crowding)
|
||||
label_line = " "
|
||||
step = max(1, len(labels) // 6) # Show max 6 labels
|
||||
for i in range(0, len(labels), step):
|
||||
pos = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
|
||||
# Add spacing to position the label
|
||||
while len(label_line) - 10 < pos:
|
||||
label_line += " "
|
||||
label_line += labels[i][:8]
|
||||
|
||||
lines.append(label_line)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _draw_line_segment(
|
||||
grid: List[List[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int
|
||||
) -> None:
|
||||
"""Draw a line segment between two points using Bresenham-like algorithm."""
|
||||
# Simple line drawing - connect points with line characters
|
||||
if x1 == x2: # Vertical line
|
||||
start_y, end_y = sorted([y1, y2])
|
||||
for y in range(start_y + 1, end_y):
|
||||
if 0 <= y < height and 0 <= x1 < width:
|
||||
grid[y][x1] = "│"
|
||||
elif y1 == y2: # Horizontal line
|
||||
start_x, end_x = sorted([x1, x2])
|
||||
for x in range(start_x + 1, end_x):
|
||||
if 0 <= y1 < height and 0 <= x < width:
|
||||
grid[y1][x] = "─"
|
||||
else: # Diagonal line - use simple interpolation
|
||||
steps = max(abs(x2 - x1), abs(y2 - y1))
|
||||
for step in range(1, steps):
|
||||
x = x1 + int((x2 - x1) * step / steps)
|
||||
y = y1 + int((y2 - y1) * step / steps)
|
||||
if 0 <= x < width and 0 <= y < height:
|
||||
if abs(x2 - x1) > abs(y2 - y1):
|
||||
grid[y][x] = "─"
|
||||
else:
|
||||
grid[y][x] = "│"
|
||||
|
||||
|
||||
def _analyze_trend(values: List[float]) -> List[str]:
|
||||
"""Analyze trend in the data."""
|
||||
if len(values) < 2:
|
||||
return ["• Insufficient data for trend analysis"]
|
||||
|
||||
analysis = []
|
||||
|
||||
# Calculate basic statistics
|
||||
first_val = values[0]
|
||||
last_val = values[-1]
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
avg_val = sum(values) / len(values)
|
||||
|
||||
# Overall trend
|
||||
if last_val > first_val * 1.1:
|
||||
trend_icon = "📈"
|
||||
trend_desc = "Strong upward trend"
|
||||
elif last_val > first_val * 1.05:
|
||||
trend_icon = "📊"
|
||||
trend_desc = "Moderate upward trend"
|
||||
elif last_val < first_val * 0.9:
|
||||
trend_icon = "📉"
|
||||
trend_desc = "Strong downward trend"
|
||||
elif last_val < first_val * 0.95:
|
||||
trend_icon = "📊"
|
||||
trend_desc = "Moderate downward trend"
|
||||
else:
|
||||
trend_icon = "➡️"
|
||||
trend_desc = "Relatively stable"
|
||||
|
||||
analysis.append(f"• {trend_icon} {trend_desc}")
|
||||
analysis.append(f"• Range: {min_val:.1f} to {max_val:.1f} (avg: {avg_val:.1f})")
|
||||
|
||||
# Volatility
|
||||
if len(values) >= 3:
|
||||
changes = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
|
||||
avg_change = sum(changes) / len(changes)
|
||||
volatility = "High" if avg_change > (max_val - min_val) * 0.1 else "Low"
|
||||
analysis.append(f"• Volatility: {volatility}")
|
||||
|
||||
return analysis
|
||||
|
||||
|
||||
def _extract_numeric_values(data: List[Any]) -> List[float]:
|
||||
"""Extract numeric values from data for line chart."""
|
||||
values = []
|
||||
for row in data[:20]: # Limit points
|
||||
if isinstance(row, dict):
|
||||
for _key, val in row.items():
|
||||
if isinstance(val, (int, float)):
|
||||
values.append(val)
|
||||
break
|
||||
return values
|
||||
|
||||
|
||||
def _create_sparkline(values: List[float]) -> List[str]:
|
||||
"""Create sparkline visualization from values."""
|
||||
if len(values) <= 1:
|
||||
return []
|
||||
|
||||
max_val = max(values)
|
||||
min_val = min(values)
|
||||
range_val = max_val - min_val if max_val != min_val else 1
|
||||
|
||||
sparkline = ""
|
||||
for val in values:
|
||||
normalized = (val - min_val) / range_val
|
||||
if normalized < 0.2:
|
||||
sparkline += "▁"
|
||||
elif normalized < 0.4:
|
||||
sparkline += "▂"
|
||||
elif normalized < 0.6:
|
||||
sparkline += "▄"
|
||||
elif normalized < 0.8:
|
||||
sparkline += "▆"
|
||||
else:
|
||||
sparkline += "█"
|
||||
|
||||
# Safe formatting to avoid NaN display
|
||||
if _is_nan_value(min_val) or _is_nan_value(max_val):
|
||||
return ["Range: Unable to calculate from data", sparkline]
|
||||
else:
|
||||
return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline]
|
||||
|
||||
|
||||
def _is_nan_value(value: Any) -> bool:
|
||||
"""Check if a value is NaN or invalid."""
|
||||
try:
|
||||
import math
|
||||
|
||||
return math.isnan(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return True
|
||||
|
||||
|
||||
def _generate_ascii_scatter_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate ASCII scatter plot."""
|
||||
if not data:
|
||||
return "No data for scatter chart"
|
||||
|
||||
lines = []
|
||||
lines.append("ASCII Scatter Plot")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Extract data points
|
||||
x_values, y_values, x_column, y_column = _extract_scatter_data(data)
|
||||
|
||||
# Debug info
|
||||
lines.extend(_create_debug_info(x_values, y_values, x_column, y_column, data))
|
||||
|
||||
# Check if we have enough data
|
||||
if len(x_values) < 2:
|
||||
return _generate_ascii_table(data, width)
|
||||
|
||||
# Add axis info
|
||||
lines.extend(_create_axis_info(x_values, y_values, x_column, y_column))
|
||||
|
||||
# Create and render grid
|
||||
grid = _create_scatter_grid(x_values, y_values, width, height)
|
||||
lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_scatter_data(
|
||||
data: List[Any],
|
||||
) -> tuple[List[float], List[float], str | None, str | None]:
|
||||
"""Extract X,Y data from scatter chart data."""
|
||||
x_values = []
|
||||
y_values = []
|
||||
x_column = None
|
||||
y_column = None
|
||||
numeric_columns = []
|
||||
|
||||
if data and isinstance(data[0], dict):
|
||||
# Find the first two numeric columns
|
||||
for key, val in data[0].items():
|
||||
if isinstance(val, (int, float)) and not (
|
||||
isinstance(val, float) and (val != val)
|
||||
): # Exclude NaN
|
||||
numeric_columns.append(key)
|
||||
|
||||
if len(numeric_columns) >= 2:
|
||||
x_column = numeric_columns[0]
|
||||
y_column = numeric_columns[1]
|
||||
|
||||
# Extract X,Y pairs
|
||||
for row in data[:50]: # Limit for ASCII display
|
||||
if isinstance(row, dict):
|
||||
x_val = row.get(x_column)
|
||||
y_val = row.get(y_column)
|
||||
# Check for valid numbers (not NaN)
|
||||
if (
|
||||
isinstance(x_val, (int, float))
|
||||
and isinstance(y_val, (int, float))
|
||||
and not (
|
||||
isinstance(x_val, float) and (x_val != x_val)
|
||||
) # Not NaN
|
||||
and not (isinstance(y_val, float) and (y_val != y_val))
|
||||
): # Not NaN
|
||||
x_values.append(x_val)
|
||||
y_values.append(y_val)
|
||||
|
||||
return x_values, y_values, x_column, y_column
|
||||
|
||||
|
||||
def _create_debug_info(
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
x_column: str | None,
|
||||
y_column: str | None,
|
||||
data: List[Any],
|
||||
) -> List[str]:
|
||||
"""Create debug information lines for scatter chart."""
|
||||
numeric_columns = []
|
||||
if data and isinstance(data[0], dict):
|
||||
for key, val in data[0].items():
|
||||
if isinstance(val, (int, float)) and not (
|
||||
isinstance(val, float) and (val != val)
|
||||
):
|
||||
numeric_columns.append(key)
|
||||
|
||||
return [
|
||||
f"DEBUG: Found {len(numeric_columns)} numeric columns: {numeric_columns}",
|
||||
f"DEBUG: X column: {x_column}, Y column: {y_column}",
|
||||
f"DEBUG: Valid X,Y pairs: {len(x_values)}",
|
||||
]
|
||||
|
||||
|
||||
def _create_axis_info(
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
x_column: str | None,
|
||||
y_column: str | None,
|
||||
) -> List[str]:
|
||||
"""Create axis information lines."""
|
||||
return [
|
||||
f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})",
|
||||
f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})",
|
||||
f"Showing {len(x_values)} data points",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _create_scatter_grid(
|
||||
x_values: List[float], y_values: List[float], width: int, height: int
|
||||
) -> List[List[str]]:
|
||||
"""Create and populate the scatter plot grid."""
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
# Normalize values to fit in grid
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
x_range = x_max - x_min if x_max != x_min else 1
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Create grid
|
||||
grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)]
|
||||
|
||||
# Plot points
|
||||
for x, y in zip(x_values, y_values, strict=False):
|
||||
try:
|
||||
grid_x = int(((x - x_min) / x_range) * (plot_width - 1))
|
||||
grid_y = int(((y - y_min) / y_range) * (plot_height - 1))
|
||||
grid_y = plot_height - 1 - grid_y # Flip Y axis for display
|
||||
|
||||
if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height:
|
||||
if grid[grid_y][grid_x] == " ":
|
||||
grid[grid_y][grid_x] = "•"
|
||||
else:
|
||||
grid[grid_y][grid_x] = "█" # Multiple points
|
||||
except (ValueError, OverflowError):
|
||||
# Skip points that can't be converted to integers (NaN, inf, etc.)
|
||||
continue
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def _render_scatter_grid(
|
||||
grid: List[List[str]],
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> List[str]:
|
||||
"""Render the scatter plot grid with axes and labels."""
|
||||
lines = []
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Add Y-axis labels and plot
|
||||
for i, row in enumerate(grid):
|
||||
y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max
|
||||
y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
|
||||
lines.append(f"{y_label:>6} |{''.join(row)}")
|
||||
|
||||
# Add X-axis
|
||||
x_axis_line = " " * 7 + "+" + "-" * plot_width
|
||||
lines.append(x_axis_line)
|
||||
|
||||
# Add X-axis labels
|
||||
x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
|
||||
x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
|
||||
x_labels = (
|
||||
" " * 8
|
||||
+ x_left_label
|
||||
+ " " * (plot_width - len(x_left_label) - len(x_right_label))
|
||||
+ x_right_label
|
||||
)
|
||||
lines.append(x_labels)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _generate_ascii_table(data: List[Any], width: int) -> str:
|
||||
"""Generate enhanced ASCII table with better formatting."""
|
||||
if not data:
|
||||
return "No data for table"
|
||||
|
||||
lines = []
|
||||
lines.append("📋 Data Table")
|
||||
lines.append("═" * min(width, 70))
|
||||
|
||||
# Get column headers from first row
|
||||
if isinstance(data[0], dict):
|
||||
# Select best columns to display
|
||||
all_headers = list(data[0].keys())
|
||||
headers = _select_display_columns(all_headers, data, max_cols=6)
|
||||
|
||||
# Calculate optimal column widths
|
||||
col_widths = _calculate_column_widths(headers, data, width)
|
||||
|
||||
# Create enhanced header
|
||||
lines.append(_create_table_header(headers, col_widths))
|
||||
lines.append(_create_table_separator(col_widths))
|
||||
|
||||
# Add data rows with better formatting
|
||||
row_count = min(15, len(data)) # Show more rows
|
||||
for i, row in enumerate(data[:row_count]):
|
||||
formatted_row = _format_table_row(row, headers, col_widths)
|
||||
lines.append(formatted_row)
|
||||
|
||||
# Add separator every 5 rows for readability
|
||||
if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1:
|
||||
lines.append(_create_light_separator(col_widths))
|
||||
|
||||
# Add footer with stats
|
||||
lines.append(_create_table_separator(col_widths))
|
||||
lines.append(f"📊 Showing {row_count} of {len(data)} rows")
|
||||
|
||||
# Add column summaries for numeric columns
|
||||
numeric_summaries = _create_numeric_summaries(data, headers)
|
||||
if numeric_summaries:
|
||||
lines.append("")
|
||||
lines.append("📈 Numeric Summaries:")
|
||||
lines.extend(numeric_summaries)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _select_display_columns(
|
||||
all_headers: List[str], data: List[Any], max_cols: int = 6
|
||||
) -> List[str]:
|
||||
"""Select the most interesting columns to display."""
|
||||
if len(all_headers) <= max_cols:
|
||||
return all_headers
|
||||
|
||||
# Prioritize columns by interest level
|
||||
priority_scores = {}
|
||||
|
||||
for header in all_headers:
|
||||
score = 0
|
||||
header_lower = header.lower()
|
||||
|
||||
# Higher priority for common business fields
|
||||
if any(word in header_lower for word in ["name", "title", "id"]):
|
||||
score += 10
|
||||
if any(
|
||||
word in header_lower
|
||||
for word in ["amount", "price", "cost", "revenue", "sales"]
|
||||
):
|
||||
score += 8
|
||||
if any(word in header_lower for word in ["date", "time", "created", "updated"]):
|
||||
score += 6
|
||||
if any(word in header_lower for word in ["count", "total", "sum", "avg"]):
|
||||
score += 5
|
||||
|
||||
# Check data variety (more variety = more interesting)
|
||||
sample_values = [
|
||||
str(row.get(header, "")) for row in data[:10] if isinstance(row, dict)
|
||||
]
|
||||
unique_values = len(set(sample_values))
|
||||
if unique_values > 1:
|
||||
score += min(unique_values, 5)
|
||||
|
||||
priority_scores[header] = score
|
||||
|
||||
# Return top scoring columns
|
||||
sorted_headers = sorted(
|
||||
all_headers, key=lambda h: priority_scores.get(h, 0), reverse=True
|
||||
)
|
||||
return sorted_headers[:max_cols]
|
||||
|
||||
|
||||
def _calculate_column_widths(
|
||||
headers: List[str], data: List[Any], total_width: int
|
||||
) -> List[int]:
|
||||
"""Calculate optimal column widths based on content."""
|
||||
if not headers:
|
||||
return []
|
||||
|
||||
# Start with minimum widths based on header lengths
|
||||
min_widths = [max(8, min(len(h) + 2, 15)) for h in headers]
|
||||
|
||||
# Check actual data to adjust widths
|
||||
for row in data[:10]: # Sample first 10 rows
|
||||
if isinstance(row, dict):
|
||||
for i, header in enumerate(headers):
|
||||
val = row.get(header, "")
|
||||
if isinstance(val, float):
|
||||
content_len = len(f"{val:.2f}")
|
||||
elif isinstance(val, int):
|
||||
content_len = len(str(val))
|
||||
else:
|
||||
content_len = len(str(val))
|
||||
|
||||
min_widths[i] = max(min_widths[i], min(content_len + 1, 20))
|
||||
|
||||
# Distribute remaining space proportionally
|
||||
used_width = sum(min_widths) + len(headers) * 3 # Account for separators
|
||||
available_width = min(total_width - 10, 80) # Leave margins
|
||||
|
||||
if used_width < available_width:
|
||||
# Distribute extra space
|
||||
extra_space = available_width - used_width
|
||||
for i in range(len(min_widths)):
|
||||
min_widths[i] += extra_space // len(min_widths)
|
||||
|
||||
return min_widths
|
||||
|
||||
|
||||
def _create_table_header(headers: List[str], widths: List[int]) -> str:
|
||||
"""Create formatted table header."""
|
||||
formatted_headers = []
|
||||
for header, width in zip(headers, widths, strict=False):
|
||||
# Truncate and center header
|
||||
display_header = header[: width - 2] if len(header) > width - 2 else header
|
||||
formatted_headers.append(f"{display_header:^{width}}")
|
||||
|
||||
return (
|
||||
"┌"
|
||||
+ "┬".join("─" * w for w in widths)
|
||||
+ "┐\n│"
|
||||
+ "│".join(formatted_headers)
|
||||
+ "│"
|
||||
)
|
||||
|
||||
|
||||
def _create_table_separator(widths: List[int]) -> str:
|
||||
"""Create table separator line."""
|
||||
return "├" + "┼".join("─" * w for w in widths) + "┤"
|
||||
|
||||
|
||||
def _create_light_separator(widths: List[int]) -> str:
|
||||
"""Create light separator line."""
|
||||
return "├" + "┼".join("┈" * w for w in widths) + "┤"
|
||||
|
||||
|
||||
def _format_table_row(
|
||||
row: Dict[str, Any], headers: List[str], widths: List[int]
|
||||
) -> str:
|
||||
"""Format a single table row."""
|
||||
formatted_values = []
|
||||
|
||||
for header, width in zip(headers, widths, strict=False):
|
||||
val = row.get(header, "")
|
||||
|
||||
# Format value based on type
|
||||
if isinstance(val, float):
|
||||
if abs(val) >= 1000000:
|
||||
formatted_val = f"{val / 1000000:.1f}M"
|
||||
elif abs(val) >= 1000:
|
||||
formatted_val = f"{val / 1000:.1f}K"
|
||||
else:
|
||||
formatted_val = f"{val:.2f}"
|
||||
elif isinstance(val, int):
|
||||
if abs(val) >= 1000000:
|
||||
formatted_val = f"{val // 1000000}M"
|
||||
elif abs(val) >= 1000:
|
||||
formatted_val = f"{val // 1000}K"
|
||||
else:
|
||||
formatted_val = str(val)
|
||||
else:
|
||||
formatted_val = str(val)
|
||||
|
||||
# Truncate if too long
|
||||
if len(formatted_val) > width - 2:
|
||||
formatted_val = formatted_val[: width - 5] + "..."
|
||||
|
||||
# Align numbers right, text left
|
||||
if isinstance(val, (int, float)):
|
||||
formatted_values.append(f"{formatted_val:>{width - 2}}")
|
||||
else:
|
||||
formatted_values.append(f"{formatted_val:<{width - 2}}")
|
||||
|
||||
return "│ " + " │ ".join(formatted_values) + " │"
|
||||
|
||||
|
||||
def _create_numeric_summaries(data: List[Any], headers: List[str]) -> List[str]:
|
||||
"""Create summaries for numeric columns."""
|
||||
summaries = []
|
||||
|
||||
for header in headers:
|
||||
# Collect numeric values
|
||||
values = []
|
||||
for row in data:
|
||||
if isinstance(row, dict):
|
||||
val = row.get(header)
|
||||
if isinstance(val, (int, float)):
|
||||
values.append(val)
|
||||
|
||||
if len(values) >= 2:
|
||||
min_val = min(values)
|
||||
max_val = max(values)
|
||||
avg_val = sum(values) / len(values)
|
||||
|
||||
if abs(avg_val) >= 1000:
|
||||
avg_str = f"{avg_val / 1000:.1f}K"
|
||||
else:
|
||||
avg_str = f"{avg_val:.1f}"
|
||||
|
||||
summaries.append(
|
||||
f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}"
|
||||
)
|
||||
|
||||
return summaries
|
||||
|
||||
|
||||
async def _get_chart_preview_internal( # noqa: C901
|
||||
request: GetChartPreviewRequest,
|
||||
ctx: Context,
|
||||
@@ -1852,7 +990,6 @@ async def _get_chart_preview_internal( # noqa: C901
|
||||
"""
|
||||
try:
|
||||
await ctx.report_progress(1, 3, "Looking up chart")
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
# Find the chart
|
||||
with event_logger.log_context(action="mcp.get_chart_preview.chart_lookup"):
|
||||
@@ -1921,25 +1058,16 @@ async def _get_chart_preview_internal( # noqa: C901
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
elif isinstance(request.identifier, int) or (
|
||||
isinstance(request.identifier, str) and request.identifier.isdigit()
|
||||
):
|
||||
chart_id = (
|
||||
int(request.identifier)
|
||||
if isinstance(request.identifier, str)
|
||||
else request.identifier
|
||||
)
|
||||
else:
|
||||
await ctx.debug(
|
||||
"Performing ID-based chart lookup: chart_id=%s" % (chart_id,)
|
||||
"Looking up chart: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
chart = ChartDAO.find_by_id(chart_id)
|
||||
elif isinstance(request.identifier, str):
|
||||
await ctx.debug(
|
||||
"Performing UUID-based chart lookup: uuid=%s"
|
||||
% (request.identifier,)
|
||||
)
|
||||
# Try UUID lookup using DAO flexible method
|
||||
chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")
|
||||
if request.identifier is None:
|
||||
return ChartError(
|
||||
error="Chart identifier is required",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
chart = find_chart_by_identifier(request.identifier)
|
||||
|
||||
# If not found and looks like a form_data_key, try transient
|
||||
if (
|
||||
@@ -2268,7 +1396,15 @@ async def get_chart_preview(
|
||||
error=OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
error_type="OAUTH2_REDIRECT_ERROR",
|
||||
)
|
||||
except Exception as e:
|
||||
except (
|
||||
SupersetException,
|
||||
CommandException,
|
||||
SQLAlchemyError,
|
||||
KeyError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
await ctx.error(
|
||||
"Chart preview generation failed: identifier=%s, error=%s, error_type=%s"
|
||||
% (
|
||||
|
||||
@@ -22,7 +22,6 @@ MCP tool: update_chart
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -31,6 +30,10 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import (
|
||||
extract_form_data_key_from_url,
|
||||
find_chart_by_identifier,
|
||||
)
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
@@ -54,18 +57,6 @@ from superset.utils import json
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _find_chart(identifier: int | str) -> Any | None:
|
||||
"""Find a chart by numeric ID or UUID string."""
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
if isinstance(identifier, int) or (
|
||||
isinstance(identifier, str) and identifier.isdigit()
|
||||
):
|
||||
chart_id = int(identifier) if isinstance(identifier, str) else identifier
|
||||
return ChartDAO.find_by_id(chart_id)
|
||||
return ChartDAO.find_by_id(identifier, id_column="uuid")
|
||||
|
||||
|
||||
def _validation_error_response(message: str, details: str) -> GenerateChartResponse:
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
@@ -277,7 +268,7 @@ async def update_chart( # noqa: C901
|
||||
|
||||
try:
|
||||
with event_logger.log_context(action="mcp.update_chart.chart_lookup"):
|
||||
chart = _find_chart(request.identifier)
|
||||
chart = find_chart_by_identifier(request.identifier)
|
||||
|
||||
if not chart:
|
||||
return GenerateChartResponse.model_validate(
|
||||
@@ -412,15 +403,21 @@ async def update_chart( # noqa: C901
|
||||
if hasattr(preview_result, "content"):
|
||||
previews[format_type] = preview_result.content
|
||||
|
||||
except Exception as e:
|
||||
except (
|
||||
OAuth2RedirectError,
|
||||
OAuth2Error,
|
||||
CommandException,
|
||||
SQLAlchemyError,
|
||||
KeyError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
logger.warning("Preview generation failed: %s", e)
|
||||
|
||||
# Fallback: extract form_data_key from explore_url if not set
|
||||
if form_data_key is None and explore_url and "form_data_key=" in explore_url:
|
||||
parsed = urlparse(explore_url)
|
||||
values = parse_qs(parsed.query).get("form_data_key")
|
||||
if values:
|
||||
form_data_key = values[0]
|
||||
if form_data_key is None:
|
||||
form_data_key = extract_form_data_key_from_url(explore_url)
|
||||
|
||||
chart_id = updated_chart.id if saved and updated_chart else chart.id
|
||||
chart_uuid = (
|
||||
|
||||
@@ -24,10 +24,13 @@ import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
@@ -123,9 +126,19 @@ def update_chart_preview(
|
||||
explore_url = generate_explore_link(request.dataset_id, new_form_data)
|
||||
|
||||
# Extract new form_data_key from the explore URL
|
||||
new_form_data_key = None
|
||||
if "form_data_key=" in explore_url:
|
||||
new_form_data_key = explore_url.split("form_data_key=")[1].split("&")[0]
|
||||
new_form_data_key = extract_form_data_key_from_url(explore_url)
|
||||
if not new_form_data_key:
|
||||
return {
|
||||
"chart": None,
|
||||
"error": {
|
||||
"error_type": "PreviewError",
|
||||
"message": "Failed to generate preview: missing form_data_key",
|
||||
"details": "The explore URL did not contain a form_data_key",
|
||||
},
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
|
||||
# Generate semantic analysis
|
||||
@@ -199,7 +212,15 @@ def update_chart_preview(
|
||||
"error": OAUTH2_CONFIG_ERROR_MESSAGE,
|
||||
"success": False,
|
||||
}
|
||||
except Exception as e:
|
||||
except (
|
||||
SupersetException,
|
||||
CommandException,
|
||||
SQLAlchemyError,
|
||||
KeyError,
|
||||
ValueError,
|
||||
TypeError,
|
||||
AttributeError,
|
||||
) as e:
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
return {
|
||||
"chart": None,
|
||||
|
||||
@@ -23,12 +23,12 @@ chart configuration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
generate_explore_link as generate_url,
|
||||
map_config_to_form_data,
|
||||
@@ -175,14 +175,8 @@ async def generate_explore_link(
|
||||
dataset_id=request.dataset_id, form_data=form_data
|
||||
)
|
||||
|
||||
# Extract form_data_key from the explore URL using proper URL parsing
|
||||
form_data_key = None
|
||||
if explore_url:
|
||||
parsed = urlparse(explore_url)
|
||||
query_params = parse_qs(parsed.query)
|
||||
form_data_key_list = query_params.get("form_data_key", [])
|
||||
if form_data_key_list:
|
||||
form_data_key = form_data_key_list[0]
|
||||
# Extract form_data_key from the explore URL
|
||||
form_data_key = extract_form_data_key_from_url(explore_url)
|
||||
|
||||
await ctx.report_progress(4, 4, "URL generation complete")
|
||||
await ctx.info(
|
||||
|
||||
108
tests/unit_tests/mcp_service/chart/test_chart_helpers.py
Normal file
108
tests/unit_tests/mcp_service/chart/test_chart_helpers.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.mcp_service.chart.chart_helpers import (
|
||||
extract_form_data_key_from_url,
|
||||
find_chart_by_identifier,
|
||||
get_cached_form_data,
|
||||
)
|
||||
|
||||
|
||||
def test_extract_form_data_key_from_url_with_key():
|
||||
url = "http://localhost:8088/explore/?form_data_key=abc123&slice_id=1"
|
||||
assert extract_form_data_key_from_url(url) == "abc123"
|
||||
|
||||
|
||||
def test_extract_form_data_key_from_url_no_key():
|
||||
url = "http://localhost:8088/explore/?slice_id=1"
|
||||
assert extract_form_data_key_from_url(url) is None
|
||||
|
||||
|
||||
def test_extract_form_data_key_from_url_none():
|
||||
assert extract_form_data_key_from_url(None) is None
|
||||
|
||||
|
||||
def test_extract_form_data_key_from_url_empty():
|
||||
assert extract_form_data_key_from_url("") is None
|
||||
|
||||
|
||||
def test_extract_form_data_key_from_url_multiple_params():
|
||||
url = "http://localhost:8088/explore/?slice_id=5&form_data_key=xyz789&other=val"
|
||||
assert extract_form_data_key_from_url(url) == "xyz789"
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_find_chart_by_identifier_int(mock_find):
|
||||
mock_chart = MagicMock()
|
||||
mock_chart.id = 42
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
result = find_chart_by_identifier(42)
|
||||
mock_find.assert_called_once_with(42)
|
||||
assert result == mock_chart
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_find_chart_by_identifier_str_digit(mock_find):
|
||||
mock_chart = MagicMock()
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
result = find_chart_by_identifier("123")
|
||||
mock_find.assert_called_once_with(123)
|
||||
assert result == mock_chart
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_find_chart_by_identifier_uuid(mock_find):
|
||||
mock_chart = MagicMock()
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
uuid_str = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
result = find_chart_by_identifier(uuid_str)
|
||||
mock_find.assert_called_once_with(uuid_str, id_column="uuid")
|
||||
assert result == mock_chart
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_find_chart_by_identifier_not_found(mock_find):
|
||||
mock_find.return_value = None
|
||||
result = find_chart_by_identifier(999)
|
||||
assert result is None
|
||||
|
||||
|
||||
@patch(
|
||||
"superset.commands.explore.form_data.get.GetFormDataCommand.run",
|
||||
return_value='{"viz_type": "table"}',
|
||||
)
|
||||
@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__")
|
||||
def test_get_cached_form_data_success(mock_init, mock_run):
|
||||
mock_init.return_value = None
|
||||
result = get_cached_form_data("test_key")
|
||||
assert result == '{"viz_type": "table"}'
|
||||
|
||||
|
||||
@patch(
|
||||
"superset.commands.explore.form_data.get.GetFormDataCommand.run",
|
||||
side_effect=KeyError("not found"),
|
||||
)
|
||||
@patch("superset.commands.explore.form_data.get.GetFormDataCommand.__init__")
|
||||
def test_get_cached_form_data_key_error(mock_init, mock_run):
|
||||
mock_init.return_value = None
|
||||
result = get_cached_form_data("bad_key")
|
||||
assert result is None
|
||||
@@ -26,6 +26,7 @@ import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.chart.chart_helpers import find_chart_by_identifier
|
||||
from superset.mcp_service.chart.chart_utils import DatasetValidationResult
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
@@ -40,7 +41,6 @@ from superset.mcp_service.chart.schemas import (
|
||||
from superset.mcp_service.chart.tool.update_chart import (
|
||||
_build_preview_form_data,
|
||||
_build_update_payload,
|
||||
_find_chart,
|
||||
)
|
||||
|
||||
# The __init__.py re-exports the update_chart *function*, so a plain
|
||||
@@ -519,7 +519,7 @@ class TestUpdateChartDatasetAccess:
|
||||
|
||||
|
||||
class TestFindChart:
|
||||
"""Tests for _find_chart helper."""
|
||||
"""Tests for find_chart_by_identifier helper (moved to chart_helpers)."""
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_find_chart_by_numeric_id(self, mock_find):
|
||||
@@ -527,7 +527,7 @@ class TestFindChart:
|
||||
mock_chart = Mock()
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
result = _find_chart(42)
|
||||
result = find_chart_by_identifier(42)
|
||||
|
||||
mock_find.assert_called_once_with(42)
|
||||
assert result is mock_chart
|
||||
@@ -538,7 +538,7 @@ class TestFindChart:
|
||||
mock_chart = Mock()
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
result = _find_chart("123")
|
||||
result = find_chart_by_identifier("123")
|
||||
|
||||
mock_find.assert_called_once_with(123)
|
||||
assert result is mock_chart
|
||||
@@ -550,7 +550,7 @@ class TestFindChart:
|
||||
mock_find.return_value = mock_chart
|
||||
|
||||
uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
result = _find_chart(uuid)
|
||||
result = find_chart_by_identifier(uuid)
|
||||
|
||||
mock_find.assert_called_once_with(uuid, id_column="uuid")
|
||||
assert result is mock_chart
|
||||
@@ -560,7 +560,7 @@ class TestFindChart:
|
||||
"""Returns None when chart is not found."""
|
||||
mock_find.return_value = None
|
||||
|
||||
result = _find_chart(999)
|
||||
result = find_chart_by_identifier(999)
|
||||
|
||||
assert result is None
|
||||
|
||||
|
||||
Reference in New Issue
Block a user