mirror of
https://github.com/apache/superset.git
synced 2026-04-11 12:26:05 +00:00
562 lines
17 KiB
Python
562 lines
17 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
||
# or more contributor license agreements. See the NOTICE file
|
||
# distributed with this work for additional information
|
||
# regarding copyright ownership. The ASF licenses this file
|
||
# to you under the Apache License, Version 2.0 (the
|
||
# "License"); you may not use this file except in compliance
|
||
# with the License. You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing,
|
||
# software distributed under the License is distributed on an
|
||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||
# KIND, either express or implied. See the License for the
|
||
# specific language governing permissions and limitations
|
||
# under the License.
|
||
|
||
"""
|
||
Preview utilities for chart generation without saving.
|
||
|
||
This module provides utilities for generating chart previews
|
||
from form data without requiring a saved chart object.
|
||
"""
|
||
|
||
import logging
|
||
from typing import Any, Dict, List
|
||
|
||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||
from superset.mcp_service.chart.schemas import (
|
||
ASCIIPreview,
|
||
ChartError,
|
||
TablePreview,
|
||
VegaLitePreview,
|
||
)
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
def generate_preview_from_form_data(
|
||
form_data: Dict[str, Any], dataset_id: int, preview_format: str
|
||
) -> Any:
|
||
"""
|
||
Generate preview from form data without a saved chart.
|
||
|
||
Args:
|
||
form_data: Chart configuration form data
|
||
dataset_id: Dataset ID
|
||
preview_format: Preview format (ascii, table, etc.)
|
||
|
||
Returns:
|
||
Preview object or ChartError
|
||
"""
|
||
try:
|
||
# Execute query to get data
|
||
from superset.connectors.sqla.models import SqlaTable
|
||
from superset.extensions import db
|
||
|
||
dataset = db.session.query(SqlaTable).get(dataset_id)
|
||
if not dataset:
|
||
return ChartError(
|
||
error=f"Dataset {dataset_id} not found", error_type="DatasetNotFound"
|
||
)
|
||
|
||
# Create query context from form data using factory
|
||
from superset.common.query_context_factory import QueryContextFactory
|
||
|
||
factory = QueryContextFactory()
|
||
query_context_obj = factory.create(
|
||
datasource={"id": dataset_id, "type": "table"},
|
||
queries=[
|
||
{
|
||
"columns": form_data.get("columns", []),
|
||
"metrics": form_data.get("metrics", []),
|
||
"orderby": form_data.get("orderby", []),
|
||
"row_limit": form_data.get("row_limit", 100),
|
||
"filters": form_data.get("adhoc_filters", []),
|
||
"time_range": form_data.get("time_range", "No filter"),
|
||
}
|
||
],
|
||
form_data=form_data,
|
||
)
|
||
|
||
# Execute query
|
||
command = ChartDataCommand(query_context_obj)
|
||
result = command.run()
|
||
|
||
if not result or not result.get("queries"):
|
||
return ChartError(
|
||
error="No data returned from query", error_type="EmptyResult"
|
||
)
|
||
|
||
query_result = result["queries"][0]
|
||
data = query_result.get("data", [])
|
||
|
||
# Generate preview based on format
|
||
if preview_format == "ascii":
|
||
return _generate_ascii_preview_from_data(data, form_data)
|
||
elif preview_format == "table":
|
||
return _generate_table_preview_from_data(data, form_data)
|
||
elif preview_format == "vega_lite":
|
||
return _generate_vega_lite_preview_from_data(data, form_data)
|
||
else:
|
||
return ChartError(
|
||
error=f"Unsupported preview format: {preview_format}",
|
||
error_type="UnsupportedFormat",
|
||
)
|
||
|
||
except Exception as e:
|
||
logger.error("Preview generation from form data failed: %s", e)
|
||
return ChartError(
|
||
error=f"Failed to generate preview: {str(e)}", error_type="PreviewError"
|
||
)
|
||
|
||
|
||
def _generate_ascii_preview_from_data(
|
||
data: List[Dict[str, Any]], form_data: Dict[str, Any]
|
||
) -> ASCIIPreview:
|
||
"""Generate ASCII preview from raw data."""
|
||
viz_type = form_data.get("viz_type", "table")
|
||
|
||
# Handle different chart types
|
||
if viz_type in ["bar", "dist_bar", "column"]:
|
||
content = _generate_safe_ascii_bar_chart(data)
|
||
elif viz_type in ["line", "area"]:
|
||
content = _generate_safe_ascii_line_chart(data)
|
||
elif viz_type == "pie":
|
||
content = _generate_safe_ascii_pie_chart(data)
|
||
else:
|
||
content = _generate_safe_ascii_table(data)
|
||
|
||
return ASCIIPreview(
|
||
ascii_content=content, width=80, height=20, supports_color=False
|
||
)
|
||
|
||
|
||
def _calculate_column_widths(
|
||
display_columns: List[str], data: List[Dict[str, Any]]
|
||
) -> Dict[str, int]:
|
||
"""Calculate optimal width for each column."""
|
||
column_widths = {}
|
||
for col in display_columns:
|
||
# Start with column name length
|
||
max_width = len(str(col))
|
||
|
||
# Check data values to determine width
|
||
for row in data[:20]: # Sample first 20 rows
|
||
val = row.get(col, "")
|
||
if isinstance(val, float):
|
||
val_str = f"{val:.2f}"
|
||
elif isinstance(val, int):
|
||
val_str = str(val)
|
||
else:
|
||
val_str = str(val)
|
||
max_width = max(max_width, len(val_str))
|
||
|
||
# Set reasonable bounds
|
||
column_widths[col] = min(max(max_width, 8), 25)
|
||
return column_widths
|
||
|
||
|
||
def _format_value(val: Any, width: int) -> str:
|
||
"""Format a value based on its type."""
|
||
if isinstance(val, float):
|
||
if abs(val) >= 1000000:
|
||
val_str = f"{val:.2e}" # Scientific notation for large numbers
|
||
elif abs(val) >= 1000:
|
||
val_str = f"{val:,.2f}" # Thousands separator
|
||
else:
|
||
val_str = f"{val:.2f}"
|
||
elif isinstance(val, int):
|
||
if abs(val) >= 1000:
|
||
val_str = f"{val:,}" # Thousands separator
|
||
else:
|
||
val_str = str(val)
|
||
elif val is None:
|
||
val_str = "NULL"
|
||
else:
|
||
val_str = str(val)
|
||
|
||
# Truncate if too long
|
||
if len(val_str) > width:
|
||
val_str = val_str[: width - 2] + ".."
|
||
return val_str
|
||
|
||
|
||
def _generate_table_preview_from_data(
|
||
data: List[Dict[str, Any]], form_data: Dict[str, Any]
|
||
) -> TablePreview:
|
||
"""Generate table preview from raw data with improved formatting."""
|
||
if not data:
|
||
return TablePreview(
|
||
table_data="No data available", row_count=0, supports_sorting=False
|
||
)
|
||
|
||
# Get columns
|
||
columns = list(data[0].keys()) if data else []
|
||
|
||
# Determine optimal column widths and how many columns to show
|
||
max_columns = 8 # Show more columns than before
|
||
display_columns = columns[:max_columns]
|
||
|
||
# Calculate optimal width for each column
|
||
column_widths = _calculate_column_widths(display_columns, data)
|
||
|
||
# Format table with proper alignment
|
||
lines = ["Table Preview", "=" * 80]
|
||
|
||
# Header with dynamic width
|
||
header_parts = []
|
||
separator_parts = []
|
||
for col in display_columns:
|
||
width = column_widths[col]
|
||
col_name = str(col)
|
||
if len(col_name) > width:
|
||
col_name = col_name[: width - 2] + ".."
|
||
header_parts.append(f"{col_name:<{width}}")
|
||
separator_parts.append("-" * width)
|
||
|
||
lines.append(" | ".join(header_parts))
|
||
lines.append("-+-".join(separator_parts))
|
||
|
||
# Data rows with proper formatting
|
||
rows_shown = min(len(data), 15) # Show more rows
|
||
for row in data[:rows_shown]:
|
||
row_parts = []
|
||
for col in display_columns:
|
||
width = column_widths[col]
|
||
val = row.get(col, "")
|
||
val_str = _format_value(val, width)
|
||
row_parts.append(f"{val_str:<{width}}")
|
||
lines.append(" | ".join(row_parts))
|
||
|
||
# Summary information
|
||
if len(data) > rows_shown:
|
||
lines.append(f"... and {len(data) - rows_shown} more rows")
|
||
|
||
if len(columns) > max_columns:
|
||
lines.append(f"... and {len(columns) - max_columns} more columns")
|
||
|
||
lines.append("")
|
||
lines.append(f"Total: {len(data)} rows × {len(columns)} columns")
|
||
|
||
return TablePreview(
|
||
table_data="\n".join(lines), row_count=len(data), supports_sorting=True
|
||
)
|
||
|
||
|
||
def _generate_safe_ascii_bar_chart(data: List[Dict[str, Any]]) -> str:
|
||
"""Generate ASCII bar chart with proper error handling."""
|
||
if not data:
|
||
return "No data available for bar chart"
|
||
|
||
lines = ["ASCII Bar Chart", "=" * 50]
|
||
|
||
# Extract values safely
|
||
values = []
|
||
labels = []
|
||
|
||
for row in data[:10]:
|
||
label = None
|
||
value = None
|
||
|
||
for _, val in row.items():
|
||
if isinstance(val, (int, float)) and not _is_nan(val) and value is None:
|
||
value = val
|
||
elif isinstance(val, str) and label is None:
|
||
label = val
|
||
|
||
if value is not None:
|
||
values.append(value)
|
||
labels.append(label or f"Item {len(values)}")
|
||
|
||
if not values:
|
||
return "No numeric data found for bar chart"
|
||
|
||
# Generate bars
|
||
max_val = max(values)
|
||
if max_val == 0:
|
||
return "All values are zero"
|
||
|
||
for label, value in zip(labels, values, strict=False):
|
||
bar_length = int((value / max_val) * 30)
|
||
bar = "█" * bar_length
|
||
lines.append(f"{label[:10]:>10} |{bar:<30} {value:.2f}")
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _generate_safe_ascii_line_chart(data: List[Dict[str, Any]]) -> str:
|
||
"""Generate ASCII line chart with proper NaN handling."""
|
||
if not data:
|
||
return "No data available for line chart"
|
||
|
||
lines = ["ASCII Line Chart", "=" * 50]
|
||
values = _extract_numeric_values_safe(data)
|
||
|
||
if not values:
|
||
return "No valid numeric data found for line chart"
|
||
|
||
range_str = _format_range_display(values)
|
||
lines.append(range_str)
|
||
|
||
sparkline = _generate_sparkline_safe(values)
|
||
lines.append(sparkline)
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _extract_numeric_values_safe(data: List[Dict[str, Any]]) -> List[float]:
|
||
"""Extract numeric values safely from data."""
|
||
values = []
|
||
for row in data[:20]:
|
||
for _, val in row.items():
|
||
if isinstance(val, (int, float)) and not _is_nan(val):
|
||
values.append(val)
|
||
break
|
||
return values
|
||
|
||
|
||
def _format_range_display(values: List[float]) -> str:
|
||
"""Format range display safely."""
|
||
min_val = min(values)
|
||
max_val = max(values)
|
||
|
||
if _is_nan(min_val) or _is_nan(max_val):
|
||
return "Range: Unable to calculate"
|
||
else:
|
||
return f"Range: {min_val:.2f} to {max_val:.2f}"
|
||
|
||
|
||
def _generate_sparkline_safe(values: List[float]) -> str:
|
||
"""Generate sparkline from values."""
|
||
if not values:
|
||
return ""
|
||
|
||
min_val = min(values)
|
||
|
||
if (max_val := max(values)) != min_val:
|
||
sparkline = ""
|
||
for val in values:
|
||
normalized = (val - min_val) / (max_val - min_val)
|
||
if normalized < 0.2:
|
||
sparkline += "▁"
|
||
elif normalized < 0.4:
|
||
sparkline += "▂"
|
||
elif normalized < 0.6:
|
||
sparkline += "▄"
|
||
elif normalized < 0.8:
|
||
sparkline += "▆"
|
||
else:
|
||
sparkline += "█"
|
||
return sparkline
|
||
else:
|
||
return "─" * len(values) # Flat line if all values are same
|
||
|
||
|
||
def _generate_safe_ascii_pie_chart(data: List[Dict[str, Any]]) -> str:
|
||
"""Generate ASCII pie chart representation."""
|
||
if not data:
|
||
return "No data available for pie chart"
|
||
|
||
lines = ["ASCII Pie Chart", "=" * 50]
|
||
|
||
# Extract values and labels
|
||
values = []
|
||
labels = []
|
||
|
||
for row in data[:8]: # Limit slices
|
||
label = None
|
||
value = None
|
||
|
||
for _, val in row.items():
|
||
if isinstance(val, (int, float)) and not _is_nan(val) and value is None:
|
||
value = val
|
||
elif isinstance(val, str) and label is None:
|
||
label = val
|
||
|
||
if value is not None and value > 0:
|
||
values.append(value)
|
||
labels.append(label or f"Slice {len(values)}")
|
||
|
||
if not values:
|
||
return "No valid data for pie chart"
|
||
|
||
# Calculate percentages
|
||
total = sum(values)
|
||
if total == 0:
|
||
return "Total is zero"
|
||
|
||
for label, value in zip(labels, values, strict=False):
|
||
percentage = (value / total) * 100
|
||
bar_length = int(percentage / 3) # Scale to fit
|
||
bar = "●" * bar_length
|
||
lines.append(f"{label[:15]:>15}: {bar} {percentage:.1f}%")
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _generate_safe_ascii_table(data: List[Dict[str, Any]]) -> str:
|
||
"""Generate ASCII table with safe formatting."""
|
||
if not data:
|
||
return "No data available"
|
||
|
||
lines = ["Data Table", "=" * 50]
|
||
|
||
# Get columns
|
||
columns = list(data[0].keys()) if data else []
|
||
|
||
# Format header
|
||
header = " | ".join(str(col)[:10] for col in columns[:5])
|
||
lines.append(header)
|
||
lines.append("-" * len(header))
|
||
|
||
# Format rows
|
||
for row in data[:10]:
|
||
row_str = " | ".join(str(row.get(col, ""))[:10] for col in columns[:5])
|
||
lines.append(row_str)
|
||
|
||
if len(data) > 10:
|
||
lines.append(f"... {len(data) - 10} more rows")
|
||
|
||
return "\n".join(lines)
|
||
|
||
|
||
def _is_nan(value: Any) -> bool:
|
||
"""Check if a value is NaN."""
|
||
try:
|
||
import math
|
||
|
||
return math.isnan(float(value))
|
||
except (ValueError, TypeError):
|
||
return False
|
||
|
||
|
||
def _generate_vega_lite_preview_from_data( # noqa: C901
|
||
data: List[Dict[str, Any]], form_data: Dict[str, Any]
|
||
) -> VegaLitePreview:
|
||
"""Generate Vega-Lite preview from raw data and form_data."""
|
||
viz_type = form_data.get("viz_type", "table")
|
||
|
||
# Map Superset viz types to Vega-Lite marks
|
||
viz_to_mark = {
|
||
"echarts_timeseries_line": "line",
|
||
"echarts_timeseries_bar": "bar",
|
||
"echarts_area": "area",
|
||
"echarts_timeseries_scatter": "point",
|
||
"bar": "bar",
|
||
"line": "line",
|
||
"area": "area",
|
||
"scatter": "point",
|
||
"pie": "arc",
|
||
"table": "text",
|
||
}
|
||
|
||
mark = viz_to_mark.get(viz_type, "bar")
|
||
|
||
# Basic Vega-Lite spec
|
||
spec = {
|
||
"$schema": "https://vega.github.io/schema/vega-lite/v5.json",
|
||
"data": {"values": data},
|
||
"mark": mark,
|
||
}
|
||
|
||
# Get x_axis and metrics from form_data
|
||
x_axis = form_data.get("x_axis")
|
||
metrics = form_data.get("metrics", [])
|
||
groupby = form_data.get("groupby", [])
|
||
|
||
# Build encoding based on available fields
|
||
encoding = {}
|
||
|
||
# Handle X-axis
|
||
if x_axis and x_axis in (data[0] if data else {}):
|
||
# Detect field type from data
|
||
field_type = "nominal" # default
|
||
if data and len(data) > 0:
|
||
sample_val = data[0].get(x_axis)
|
||
if isinstance(sample_val, str):
|
||
# Check if it's a date/time
|
||
if any(char in str(sample_val) for char in ["-", "/", ":"]):
|
||
field_type = "temporal"
|
||
else:
|
||
field_type = "nominal"
|
||
elif isinstance(sample_val, (int, float)):
|
||
field_type = "quantitative"
|
||
|
||
encoding["x"] = {
|
||
"field": x_axis,
|
||
"type": field_type,
|
||
"title": x_axis,
|
||
}
|
||
|
||
# Handle Y-axis (metrics)
|
||
if metrics and data:
|
||
# Find the first metric column in the data
|
||
metric_col = None
|
||
for col in data[0].keys():
|
||
# Check if this is a metric column (usually has aggregation in name)
|
||
if any(
|
||
agg in str(col).upper()
|
||
for agg in ["SUM", "AVG", "COUNT", "MIN", "MAX", "TOTAL"]
|
||
):
|
||
metric_col = col
|
||
break
|
||
# Or check if it's numeric
|
||
elif isinstance(data[0].get(col), (int, float)):
|
||
metric_col = col
|
||
break
|
||
|
||
if metric_col:
|
||
encoding["y"] = {
|
||
"field": metric_col,
|
||
"type": "quantitative",
|
||
"title": metric_col,
|
||
}
|
||
|
||
# Handle color encoding for groupby
|
||
if groupby and len(groupby) > 0 and groupby[0] in (data[0] if data else {}):
|
||
encoding["color"] = {
|
||
"field": groupby[0],
|
||
"type": "nominal",
|
||
"title": groupby[0],
|
||
}
|
||
|
||
# Special handling for pie charts
|
||
if mark == "arc" and data:
|
||
# For pie charts, we need theta encoding
|
||
if "y" in encoding:
|
||
encoding["theta"] = encoding.pop("y")
|
||
encoding["theta"]["stack"] = True
|
||
if "x" in encoding:
|
||
# Use x as color for pie
|
||
encoding["color"] = {
|
||
"field": encoding["x"]["field"],
|
||
"type": "nominal",
|
||
}
|
||
del encoding["x"]
|
||
|
||
# Add encoding to spec
|
||
if encoding:
|
||
spec["encoding"] = encoding
|
||
|
||
# Add responsive sizing - Vega-Lite supports "container" as a special width value
|
||
spec["width"] = "container"
|
||
spec["height"] = 400 # type: ignore
|
||
|
||
# Add interactivity
|
||
if mark in ["line", "point", "bar", "area"]:
|
||
spec["selection"] = {
|
||
"highlight": {
|
||
"type": "single",
|
||
"on": "mouseover",
|
||
"empty": "none",
|
||
}
|
||
}
|
||
|
||
return VegaLitePreview(
|
||
specification=spec,
|
||
data_url=None,
|
||
supports_streaming=False,
|
||
)
|