feat(mcp): add pie, pivot table, and mixed timeseries chart creation support (#38375)

This commit is contained in:
Amin Ghadersohi
2026-03-06 02:13:47 -05:00
committed by GitHub
parent 3609cd9544
commit 84a53eab31
5 changed files with 1529 additions and 49 deletions

View File

@@ -118,6 +118,9 @@ Chart Types You Can CREATE with generate_chart/generate_explore_link:
- chart_type="xy", kind="scatter": Scatter plot for correlation analysis
- chart_type="table": Data table for detailed views
- chart_type="table", viz_type="ag-grid-table": Interactive AG Grid table
- chart_type="pie": Pie chart for proportional data (set donut=True for donut)
- chart_type="pivot_table": Interactive pivot table for cross-tabulation
- chart_type="mixed_timeseries": Dual-series chart combining two chart types
Time grain for temporal x-axis (time_grain parameter):
- PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly)

View File

@@ -30,6 +30,9 @@ from superset.mcp_service.chart.schemas import (
ChartCapabilities,
ChartSemantics,
ColumnRef,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
TableChartConfig,
XYChartConfig,
)
@@ -301,7 +304,11 @@ def is_column_truly_temporal(column_name: str, dataset_id: int | str | None) ->
def map_config_to_form_data(
config: TableChartConfig | XYChartConfig,
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
"""Map chart config to Superset form_data."""
@@ -309,6 +316,12 @@ def map_config_to_form_data(
return map_table_config(config)
elif isinstance(config, XYChartConfig):
return map_xy_config(config, dataset_id=dataset_id)
elif isinstance(config, PieChartConfig):
return map_pie_config(config)
elif isinstance(config, PivotTableChartConfig):
return map_pivot_table_config(config)
elif isinstance(config, MixedTimeseriesChartConfig):
return map_mixed_timeseries_config(config, dataset_id=dataset_id)
else:
raise ValueError(f"Unsupported config type: {type(config)}")
@@ -567,6 +580,197 @@ def map_xy_config(
return form_data
def map_pie_config(config: PieChartConfig) -> Dict[str, Any]:
"""Map pie chart config to Superset form_data."""
metric = create_metric_object(config.metric)
form_data: Dict[str, Any] = {
"viz_type": "pie",
"groupby": [config.dimension.name],
"metric": metric,
"color_scheme": "supersetColors",
"show_labels": config.show_labels,
"show_legend": config.show_legend,
"label_type": config.label_type,
"number_format": config.number_format,
"sort_by_metric": config.sort_by_metric,
"row_limit": config.row_limit,
"donut": config.donut,
"show_total": config.show_total,
"labels_outside": config.labels_outside,
"outerRadius": config.outer_radius,
"innerRadius": config.inner_radius,
"date_format": "smart_date",
}
if config.filters:
form_data["adhoc_filters"] = [
{
"clause": "WHERE",
"expressionType": "SIMPLE",
"subject": filter_config.column,
"operator": map_filter_operator(filter_config.op),
"comparator": filter_config.value,
}
for filter_config in config.filters
if filter_config is not None
]
return form_data
def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]:
"""Map pivot table config to Superset form_data."""
if not config.rows:
raise ValueError("Pivot table must have at least one row grouping column")
if not config.metrics:
raise ValueError("Pivot table must have at least one metric")
metrics = [create_metric_object(col) for col in config.metrics]
form_data: Dict[str, Any] = {
"viz_type": "pivot_table_v2",
"groupbyRows": [col.name for col in config.rows],
"groupbyColumns": [col.name for col in config.columns]
if config.columns
else [],
"metrics": metrics,
"aggregateFunction": config.aggregate_function,
"rowTotals": config.show_row_totals,
"colTotals": config.show_column_totals,
"transposePivot": config.transpose,
"combineMetric": config.combine_metric,
"valueFormat": config.value_format,
"metricsLayout": "COLUMNS",
"rowOrder": "key_a_to_z",
"colOrder": "key_a_to_z",
"row_limit": config.row_limit,
}
if config.filters:
form_data["adhoc_filters"] = [
{
"clause": "WHERE",
"expressionType": "SIMPLE",
"subject": filter_config.column,
"operator": map_filter_operator(filter_config.op),
"comparator": filter_config.value,
}
for filter_config in config.filters
if filter_config is not None
]
return form_data
_MIXED_SERIES_TYPE_MAP = {
"line": "line",
"bar": "bar",
"area": "line", # area uses line type with area=True
"scatter": "scatter",
}
def _apply_axis_to_form_data(
form_data: Dict[str, Any],
axis_config: Any,
title_key: str,
format_key: str,
log_key: str | None = None,
) -> None:
"""Apply a single axis configuration to form_data."""
if not axis_config:
return
if axis_config.title:
form_data[title_key] = axis_config.title
if axis_config.format:
form_data[format_key] = axis_config.format
if log_key and axis_config.scale == "log":
form_data[log_key] = True
def _add_mixed_axis_config(
form_data: Dict[str, Any],
config: MixedTimeseriesChartConfig,
) -> None:
"""Add axis configurations to mixed timeseries form_data."""
_apply_axis_to_form_data(
form_data, config.x_axis, "xAxisTitle", "x_axis_time_format"
)
_apply_axis_to_form_data(
form_data, config.y_axis, "yAxisTitle", "y_axis_format", "logAxis"
)
_apply_axis_to_form_data(
form_data,
config.y_axis_secondary,
"yAxisTitleSecondary",
"y_axis_format_secondary",
"logAxisSecondary",
)
def map_mixed_timeseries_config(
config: MixedTimeseriesChartConfig,
dataset_id: int | str | None = None,
) -> Dict[str, Any]:
"""Map mixed timeseries chart config to Superset form_data."""
if not config.y:
raise ValueError("Mixed timeseries must have at least one primary metric")
if not config.y_secondary:
raise ValueError("Mixed timeseries must have at least one secondary metric")
# Check if x-axis column is truly temporal
x_is_temporal = is_column_truly_temporal(config.x.name, dataset_id)
form_data: Dict[str, Any] = {
"viz_type": "mixed_timeseries",
"x_axis": config.x.name,
# Query A
"metrics": [create_metric_object(col) for col in config.y],
"seriesType": _MIXED_SERIES_TYPE_MAP.get(config.primary_kind, "line"),
"area": config.primary_kind == "area",
"yAxisIndex": 0,
# Query B
"metrics_b": [create_metric_object(col) for col in config.y_secondary],
"seriesTypeB": _MIXED_SERIES_TYPE_MAP.get(config.secondary_kind, "bar"),
"areaB": config.secondary_kind == "area",
"yAxisIndexB": 1,
# Display
"show_legend": config.show_legend,
"zoomable": True,
"rich_tooltip": True,
}
# Configure temporal handling
configure_temporal_handling(form_data, x_is_temporal, config.time_grain)
# Primary groupby (Query A)
if config.group_by and config.group_by.name != config.x.name:
form_data["groupby"] = [config.group_by.name]
# Secondary groupby (Query B)
if config.group_by_secondary and config.group_by_secondary.name != config.x.name:
form_data["groupby_b"] = [config.group_by_secondary.name]
_add_mixed_axis_config(form_data, config)
# Filters
if config.filters:
form_data["adhoc_filters"] = [
{
"clause": "WHERE",
"expressionType": "SIMPLE",
"subject": filter_config.column,
"operator": map_filter_operator(filter_config.op),
"comparator": filter_config.value,
}
for filter_config in config.filters
if filter_config is not None
]
return form_data
def map_filter_operator(op: str) -> str:
"""Map filter operator to Superset format."""
operator_map = {
@@ -585,7 +789,13 @@ def map_filter_operator(op: str) -> str:
return operator_map.get(op, op)
def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str:
def generate_chart_name(
config: TableChartConfig
| XYChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig,
) -> str:
"""Generate a chart name based on the configuration."""
if isinstance(config, TableChartConfig):
return f"Table Chart - {', '.join(col.name for col in config.columns)}"
@@ -594,6 +804,16 @@ def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str:
x_col = config.x.name
y_cols = ", ".join(col.name for col in config.y)
return f"{chart_type} Chart - {x_col} vs {y_cols}"
elif isinstance(config, PieChartConfig):
metric_label = config.metric.label or config.metric.name
return f"Pie Chart - {config.dimension.name} by {metric_label}"
elif isinstance(config, PivotTableChartConfig):
rows = ", ".join(col.name for col in config.rows)
return f"Pivot Table - {rows}"
elif isinstance(config, MixedTimeseriesChartConfig):
primary = ", ".join(col.name for col in config.y)
secondary = ", ".join(col.name for col in config.y_secondary)
return f"Mixed Chart - {primary} + {secondary}"
else:
return "Chart"
@@ -603,22 +823,7 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit
if chart:
viz_type = getattr(chart, "viz_type", "unknown")
else:
# Map config chart_type to viz_type
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
viz_type = viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
# Use the viz_type from config if available (table or ag-grid-table)
viz_type = getattr(config, "viz_type", "table")
else:
viz_type = "unknown"
viz_type = _resolve_viz_type(config)
# Determine interaction capabilities based on chart type
interactive_types = [
@@ -663,27 +868,35 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit
)
def _resolve_viz_type(config: Any) -> str:
"""Resolve viz_type from a chart config object."""
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
return viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
return getattr(config, "viz_type", "table")
elif chart_type == "pie":
return "pie"
elif chart_type == "pivot_table":
return "pivot_table_v2"
elif chart_type == "mixed_timeseries":
return "mixed_timeseries"
return "unknown"
def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics:
"""Generate semantic understanding of the chart."""
if chart:
viz_type = getattr(chart, "viz_type", "unknown")
else:
# Map config chart_type to viz_type
chart_type = getattr(config, "chart_type", "unknown")
if chart_type == "xy":
kind = getattr(config, "kind", "line")
viz_type_map = {
"line": "echarts_timeseries_line",
"bar": "echarts_timeseries_bar",
"area": "echarts_area",
"scatter": "echarts_timeseries_scatter",
}
viz_type = viz_type_map.get(kind, "echarts_timeseries_line")
elif chart_type == "table":
# Use the viz_type from config if available (table or ag-grid-table)
viz_type = getattr(config, "viz_type", "table")
else:
viz_type = "unknown"
viz_type = _resolve_viz_type(config)
# Generate primary insight based on chart type
insights_map = {
@@ -696,6 +909,14 @@ def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics:
),
"pie": "Shows proportional relationships within a dataset",
"echarts_area": "Emphasizes cumulative totals and part-to-whole relationships",
"pivot_table_v2": (
"Cross-tabulates data with rows, columns, and aggregated metrics "
"for multi-dimensional analysis"
),
"mixed_timeseries": (
"Combines two different chart types on the same time axis "
"for comparing related metrics with different scales"
),
}
primary_insight = insights_map.get(

View File

@@ -485,6 +485,183 @@ class FilterConfig(BaseModel):
# Actual chart types
class PieChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["pie"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'pie' for pie/donut charts. "
"This field is REQUIRED and tells Superset which chart "
"configuration schema to use."
),
)
dimension: ColumnRef = Field(
..., description="Category column that defines the pie slices"
)
metric: ColumnRef = Field(
...,
description=(
"Value metric that determines slice sizes. "
"Must include an aggregate function (e.g., SUM, COUNT)."
),
)
donut: bool = Field(False, description="Render as a donut chart with a center hole")
show_labels: bool = Field(True, description="Display labels on slices")
label_type: Literal[
"key",
"value",
"percent",
"key_value",
"key_percent",
"key_value_percent",
"value_percent",
] = Field("key_value_percent", description="Type of labels to show on slices")
sort_by_metric: bool = Field(True, description="Sort slices by metric value")
show_legend: bool = Field(True, description="Whether to show legend")
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
row_limit: int = Field(
100,
description="Maximum number of slices to display",
ge=1,
le=10000,
)
number_format: str = Field(
"SMART_NUMBER",
description="Number format string",
max_length=50,
)
show_total: bool = Field(False, description="Display aggregate count in center")
labels_outside: bool = Field(True, description="Place labels outside the pie")
outer_radius: int = Field(
70,
description="Outer edge radius as a percentage (1-100)",
ge=1,
le=100,
)
inner_radius: int = Field(
30,
description="Inner radius as a percentage for donut (1-100)",
ge=1,
le=100,
)
class PivotTableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["pivot_table"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'pivot_table' for interactive "
"pivot tables. This field is REQUIRED."
),
)
rows: List[ColumnRef] = Field(
...,
min_length=1,
description="Row grouping columns (at least one required)",
)
columns: List[ColumnRef] | None = Field(
None,
description="Column grouping columns (optional, for cross-tabulation)",
)
metrics: List[ColumnRef] = Field(
...,
min_length=1,
description=(
"Metrics to aggregate. Each must have an aggregate function "
"(e.g., SUM, COUNT, AVG)."
),
)
aggregate_function: Literal[
"Sum",
"Average",
"Median",
"Sample Variance",
"Sample Standard Deviation",
"Minimum",
"Maximum",
"Count",
"Count Unique Values",
"First",
"Last",
] = Field("Sum", description="Default aggregation function for the pivot table")
show_row_totals: bool = Field(True, description="Show row totals")
show_column_totals: bool = Field(True, description="Show column totals")
transpose: bool = Field(False, description="Swap rows and columns")
combine_metric: bool = Field(
False,
description="Display metrics side by side within columns",
)
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
row_limit: int = Field(
10000,
description="Maximum number of cells",
ge=1,
le=50000,
)
value_format: str = Field(
"SMART_NUMBER",
description="Value format string",
max_length=50,
)
class MixedTimeseriesChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
chart_type: Literal["mixed_timeseries"] = Field(
...,
description=(
"Chart type discriminator - MUST be 'mixed_timeseries' for charts "
"that combine two different series types (e.g., line + bar). "
"This field is REQUIRED."
),
)
x: ColumnRef = Field(..., description="X-axis temporal column (shared)")
time_grain: TimeGrain | None = Field(
None,
description=(
"Time granularity for the x-axis. "
"Common values: PT1H (hourly), P1D (daily), P1W (weekly), "
"P1M (monthly), P1Y (yearly)."
),
)
# Primary series (Query A)
y: List[ColumnRef] = Field(
...,
min_length=1,
description="Primary Y-axis metrics (Query A)",
)
primary_kind: Literal["line", "bar", "area", "scatter"] = Field(
"line", description="Primary series chart type"
)
group_by: ColumnRef | None = Field(
None, description="Group by column for primary series"
)
# Secondary series (Query B)
y_secondary: List[ColumnRef] = Field(
...,
min_length=1,
description="Secondary Y-axis metrics (Query B)",
)
secondary_kind: Literal["line", "bar", "area", "scatter"] = Field(
"bar", description="Secondary series chart type"
)
group_by_secondary: ColumnRef | None = Field(
None, description="Group by column for secondary series"
)
# Display options
show_legend: bool = Field(True, description="Whether to show legend")
x_axis: AxisConfig | None = Field(None, description="X-axis configuration")
y_axis: AxisConfig | None = Field(None, description="Primary Y-axis configuration")
y_axis_secondary: AxisConfig | None = Field(
None, description="Secondary Y-axis configuration"
)
filters: List[FilterConfig] | None = Field(None, description="Filters to apply")
class TableChartConfig(BaseModel):
model_config = ConfigDict(extra="forbid")
@@ -631,10 +808,17 @@ class XYChartConfig(BaseModel):
# Discriminated union entry point with custom error handling
ChartConfig = Annotated[
XYChartConfig | TableChartConfig,
XYChartConfig
| TableChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig,
Field(
discriminator="chart_type",
description="Chart configuration - specify chart_type as 'xy' or 'table'",
description=(
"Chart configuration - specify chart_type as 'xy', 'table', "
"'pie', 'pivot_table', or 'mixed_timeseries'"
),
),
]

View File

@@ -126,37 +126,52 @@ class SchemaValidator:
return False, ChartGenerationError(
error_type="missing_chart_type",
message="Missing required field: chart_type",
details="Chart configuration must specify 'chart_type' as either 'xy' "
"or 'table'",
details="Chart configuration must specify 'chart_type'",
suggestions=[
"Add 'chart_type': 'xy' for line/bar/area/scatter charts",
"Add 'chart_type': 'table' for table visualizations",
"Example: 'config': {'chart_type': 'xy', ...}",
"Add 'chart_type': 'pie' for pie or donut charts",
"Add 'chart_type': 'pivot_table' for interactive pivot tables",
"Add 'chart_type': 'mixed_timeseries' for dual-series time charts",
],
error_code="MISSING_CHART_TYPE",
)
if chart_type not in ["xy", "table"]:
return SchemaValidator._pre_validate_chart_type(chart_type, config)
@staticmethod
def _pre_validate_chart_type(
chart_type: str,
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Validate chart type and dispatch to type-specific pre-validation."""
chart_type_validators = {
"xy": SchemaValidator._pre_validate_xy_config,
"table": SchemaValidator._pre_validate_table_config,
"pie": SchemaValidator._pre_validate_pie_config,
"pivot_table": SchemaValidator._pre_validate_pivot_table_config,
"mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config,
}
if not isinstance(chart_type, str) or chart_type not in chart_type_validators:
valid_types = ", ".join(chart_type_validators.keys())
return False, ChartGenerationError(
error_type="invalid_chart_type",
message=f"Invalid chart_type: '{chart_type}'",
details=f"Chart type '{chart_type}' is not supported. Must be 'xy' or "
f"'table'",
details=f"Chart type '{chart_type}' is not supported. "
f"Must be one of: {valid_types}",
suggestions=[
"Use 'chart_type': 'xy' for line, bar, area, or scatter charts",
"Use 'chart_type': 'table' for tabular data display",
"Use 'chart_type': 'pie' for pie or donut charts",
"Use 'chart_type': 'pivot_table' for interactive pivot tables",
"Use 'chart_type': 'mixed_timeseries' for dual-series time charts",
"Check spelling and ensure lowercase",
],
error_code="INVALID_CHART_TYPE",
)
# Pre-validate structure based on chart type
if chart_type == "xy":
return SchemaValidator._pre_validate_xy_config(config)
elif chart_type == "table":
return SchemaValidator._pre_validate_table_config(config)
return True, None
return chart_type_validators[chart_type](config)
@staticmethod
def _pre_validate_xy_config(
@@ -237,6 +252,134 @@ class SchemaValidator:
return True, None
@staticmethod
def _pre_validate_pie_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pie chart configuration."""
missing_fields = []
if "dimension" not in config:
missing_fields.append("'dimension' (category column for slices)")
if "metric" not in config:
missing_fields.append("'metric' (value metric for slice sizes)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pie_fields",
message=f"Pie chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Pie charts require a dimension (categories) and a metric "
"(values)",
suggestions=[
"Add 'dimension' field: {'name': 'category_column'}",
"Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}",
"Example: {'chart_type': 'pie', 'dimension': {'name': "
"'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}",
],
error_code="MISSING_PIE_FIELDS",
)
return True, None
@staticmethod
def _pre_validate_pivot_table_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate pivot table configuration."""
missing_fields = []
if "rows" not in config:
missing_fields.append("'rows' (row grouping columns)")
if "metrics" not in config:
missing_fields.append("'metrics' (aggregation metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_pivot_fields",
message=f"Pivot table missing required "
f"fields: {', '.join(missing_fields)}",
details="Pivot tables require row groupings and metrics",
suggestions=[
"Add 'rows' field: [{'name': 'category'}]",
"Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]",
"Optional 'columns' for cross-tabulation: [{'name': 'region'}]",
],
error_code="MISSING_PIVOT_FIELDS",
)
if not isinstance(config.get("rows", []), list):
return False, ChartGenerationError(
error_type="invalid_rows_format",
message="Rows must be a list of columns",
details="The 'rows' field must be an array of column specifications",
suggestions=[
"Wrap row columns in array: 'rows': [{'name': 'category'}]",
],
error_code="INVALID_ROWS_FORMAT",
)
if not isinstance(config.get("metrics", []), list):
return False, ChartGenerationError(
error_type="invalid_metrics_format",
message="Metrics must be a list",
details="The 'metrics' field must be an array of metric specifications",
suggestions=[
"Wrap metrics in array: 'metrics': [{'name': 'sales', "
"'aggregate': 'SUM'}]",
],
error_code="INVALID_METRICS_FORMAT",
)
return True, None
@staticmethod
def _pre_validate_mixed_timeseries_config(
config: Dict[str, Any],
) -> Tuple[bool, ChartGenerationError | None]:
"""Pre-validate mixed timeseries configuration."""
missing_fields = []
if "x" not in config:
missing_fields.append("'x' (X-axis temporal column)")
if "y" not in config:
missing_fields.append("'y' (primary Y-axis metrics)")
if "y_secondary" not in config:
missing_fields.append("'y_secondary' (secondary Y-axis metrics)")
if missing_fields:
return False, ChartGenerationError(
error_type="missing_mixed_timeseries_fields",
message=f"Mixed timeseries chart missing required "
f"fields: {', '.join(missing_fields)}",
details="Mixed timeseries charts require an x-axis, primary metrics, "
"and secondary metrics",
suggestions=[
"Add 'x' field: {'name': 'date_column'}",
"Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]",
"Add 'y_secondary' field: [{'name': 'orders', "
"'aggregate': 'COUNT'}]",
"Optional: 'primary_kind' and 'secondary_kind' for chart types",
],
error_code="MISSING_MIXED_TIMESERIES_FIELDS",
)
for field_name in ["y", "y_secondary"]:
if not isinstance(config.get(field_name, []), list):
return False, ChartGenerationError(
error_type=f"invalid_{field_name}_format",
message=f"'{field_name}' must be a list of metrics",
details=f"The '{field_name}' field must be an array of metric "
"specifications",
suggestions=[
f"Wrap in array: '{field_name}': "
"[{'name': 'col', 'aggregate': 'SUM'}]",
],
error_code=f"INVALID_{field_name.upper()}_FORMAT",
)
return True, None
@staticmethod
def _enhance_validation_error(
error: PydanticValidationError, request_data: Dict[str, Any]

View File

@@ -0,0 +1,929 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for new MCP chart types: pie, pivot_table, mixed_timeseries.
Tests cover schema validation, form_data mapping, chart name generation,
and schema validator pre-validation for all three new chart types.
"""
from unittest.mock import patch
import pytest
from pydantic import ValidationError
from superset.mcp_service.chart.chart_utils import (
generate_chart_name,
map_config_to_form_data,
map_mixed_timeseries_config,
map_pie_config,
map_pivot_table_config,
)
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
MixedTimeseriesChartConfig,
PieChartConfig,
PivotTableChartConfig,
)
from superset.mcp_service.chart.validation.schema_validator import SchemaValidator
# ============================================================
# Pie Chart Schema Tests
# ============================================================
class TestPieChartConfigSchema:
"""Test PieChartConfig Pydantic schema validation."""
def test_basic_pie_config(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
assert config.chart_type == "pie"
assert config.dimension.name == "product"
assert config.metric.aggregate == "SUM"
assert config.donut is False
assert config.show_labels is True
assert config.label_type == "key_value_percent"
def test_donut_chart_config(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="category"),
metric=ColumnRef(name="count", aggregate="COUNT"),
donut=True,
inner_radius=40,
outer_radius=80,
)
assert config.donut is True
assert config.inner_radius == 40
assert config.outer_radius == 80
def test_pie_config_with_all_options(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="region"),
metric=ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
donut=True,
show_labels=False,
label_type="percent",
sort_by_metric=False,
show_legend=False,
row_limit=50,
number_format="$,.2f",
show_total=True,
labels_outside=False,
outer_radius=90,
inner_radius=50,
filters=[FilterConfig(column="status", op="=", value="active")],
)
assert config.show_labels is False
assert config.label_type == "percent"
assert config.row_limit == 50
assert config.show_total is True
assert config.filters is not None
assert len(config.filters) == 1
def test_pie_config_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
unknown_field="bad",
)
def test_pie_config_missing_dimension(self) -> None:
with pytest.raises(ValidationError):
PieChartConfig(
chart_type="pie",
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
def test_pie_config_missing_metric(self) -> None:
with pytest.raises(ValidationError):
PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
)
def test_pie_config_row_limit_bounds(self) -> None:
with pytest.raises(ValidationError):
PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
row_limit=0,
)
def test_pie_config_valid_label_types(self) -> None:
for label_type in [
"key",
"value",
"percent",
"key_value",
"key_percent",
"key_value_percent",
"value_percent",
]:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
label_type=label_type,
)
assert config.label_type == label_type
# ============================================================
# Pie Chart Form Data Mapping Tests
# ============================================================
class TestMapPieConfig:
"""Test map_pie_config form_data generation."""
def test_basic_pie_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_pie_config(config)
assert result["viz_type"] == "pie"
assert result["groupby"] == ["product"]
assert result["metric"]["aggregate"] == "SUM"
assert result["metric"]["column"]["column_name"] == "revenue"
assert result["show_labels"] is True
assert result["show_legend"] is True
assert result["label_type"] == "key_value_percent"
assert result["sort_by_metric"] is True
assert result["row_limit"] == 100
assert result["donut"] is False
assert result["color_scheme"] == "supersetColors"
def test_donut_form_data(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="category"),
metric=ColumnRef(name="count", aggregate="COUNT"),
donut=True,
inner_radius=40,
outer_radius=80,
)
result = map_pie_config(config)
assert result["donut"] is True
assert result["innerRadius"] == 40
assert result["outerRadius"] == 80
def test_pie_form_data_with_filters(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
filters=[FilterConfig(column="region", op="=", value="US")],
)
result = map_pie_config(config)
assert "adhoc_filters" in result
assert len(result["adhoc_filters"]) == 1
assert result["adhoc_filters"][0]["subject"] == "region"
assert result["adhoc_filters"][0]["operator"] == "=="
assert result["adhoc_filters"][0]["comparator"] == "US"
def test_pie_form_data_custom_options(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="status"),
metric=ColumnRef(name="count", aggregate="COUNT"),
show_labels=False,
label_type="percent",
show_legend=False,
number_format="$,.2f",
show_total=True,
labels_outside=False,
)
result = map_pie_config(config)
assert result["show_labels"] is False
assert result["label_type"] == "percent"
assert result["show_legend"] is False
assert result["number_format"] == "$,.2f"
assert result["show_total"] is True
assert result["labels_outside"] is False
def test_pie_form_data_custom_metric_label(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
)
result = map_pie_config(config)
assert result["metric"]["label"] == "Total Revenue"
assert result["metric"]["hasCustomLabel"] is True
# ============================================================
# Pivot Table Schema Tests
# ============================================================
class TestPivotTableChartConfigSchema:
"""Test PivotTableChartConfig Pydantic schema validation."""
def test_basic_pivot_table_config(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
assert config.chart_type == "pivot_table"
assert len(config.rows) == 1
assert len(config.metrics) == 1
assert config.aggregate_function == "Sum"
assert config.show_row_totals is True
assert config.show_column_totals is True
def test_pivot_table_with_columns(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
columns=[ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
assert config.columns is not None
assert len(config.columns) == 1
assert config.columns[0].name == "region"
def test_pivot_table_with_all_options(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product"), ColumnRef(name="category")],
columns=[ColumnRef(name="region")],
metrics=[
ColumnRef(name="revenue", aggregate="SUM"),
ColumnRef(name="orders", aggregate="COUNT"),
],
aggregate_function="Average",
show_row_totals=False,
show_column_totals=False,
transpose=True,
combine_metric=True,
row_limit=5000,
value_format="$,.2f",
filters=[FilterConfig(column="year", op="=", value=2024)],
)
assert config.aggregate_function == "Average"
assert config.show_row_totals is False
assert config.transpose is True
assert config.combine_metric is True
assert config.row_limit == 5000
def test_pivot_table_missing_rows(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(
chart_type="pivot_table",
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
def test_pivot_table_missing_metrics(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
)
def test_pivot_table_empty_rows(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(
chart_type="pivot_table",
rows=[],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
def test_pivot_table_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
unknown_field="bad",
)
def test_pivot_table_valid_aggregate_functions(self) -> None:
for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
aggregate_function=agg,
)
assert config.aggregate_function == agg
# ============================================================
# Pivot Table Form Data Mapping Tests
# ============================================================
class TestMapPivotTableConfig:
"""Test map_pivot_table_config form_data generation."""
def test_basic_pivot_form_data(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = map_pivot_table_config(config)
assert result["viz_type"] == "pivot_table_v2"
assert result["groupbyRows"] == ["product"]
assert result["groupbyColumns"] == []
assert len(result["metrics"]) == 1
assert result["metrics"][0]["aggregate"] == "SUM"
assert result["aggregateFunction"] == "Sum"
assert result["rowTotals"] is True
assert result["colTotals"] is True
assert result["metricsLayout"] == "COLUMNS"
def test_pivot_form_data_with_columns(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
columns=[ColumnRef(name="region"), ColumnRef(name="quarter")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = map_pivot_table_config(config)
assert result["groupbyRows"] == ["product"]
assert result["groupbyColumns"] == ["region", "quarter"]
def test_pivot_form_data_with_filters(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
filters=[FilterConfig(column="year", op="=", value=2024)],
)
result = map_pivot_table_config(config)
assert "adhoc_filters" in result
assert len(result["adhoc_filters"]) == 1
assert result["adhoc_filters"][0]["subject"] == "year"
def test_pivot_form_data_custom_options(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
aggregate_function="Average",
show_row_totals=False,
show_column_totals=False,
transpose=True,
combine_metric=True,
value_format="$,.2f",
)
result = map_pivot_table_config(config)
assert result["aggregateFunction"] == "Average"
assert result["rowTotals"] is False
assert result["colTotals"] is False
assert result["transposePivot"] is True
assert result["combineMetric"] is True
assert result["valueFormat"] == "$,.2f"
def test_pivot_form_data_multiple_metrics(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[
ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
ColumnRef(name="orders", aggregate="COUNT", label="Order Count"),
],
)
result = map_pivot_table_config(config)
assert len(result["metrics"]) == 2
assert result["metrics"][0]["label"] == "Total Revenue"
assert result["metrics"][1]["label"] == "Order Count"
# ============================================================
# Mixed Timeseries Schema Tests
# ============================================================
class TestMixedTimeseriesChartConfigSchema:
"""Test MixedTimeseriesChartConfig Pydantic schema validation."""
def test_basic_mixed_timeseries_config(self) -> None:
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
assert config.chart_type == "mixed_timeseries"
assert config.x.name == "order_date"
assert config.primary_kind == "line"
assert config.secondary_kind == "bar"
assert config.show_legend is True
def test_mixed_timeseries_with_all_options(self) -> None:
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
time_grain="P1M",
y=[ColumnRef(name="revenue", aggregate="SUM")],
primary_kind="area",
group_by=ColumnRef(name="region"),
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
secondary_kind="scatter",
group_by_secondary=ColumnRef(name="channel"),
show_legend=False,
x_axis=AxisConfig(title="Date"),
y_axis=AxisConfig(title="Revenue", format="$,.2f"),
y_axis_secondary=AxisConfig(title="Orders", scale="log"),
filters=[FilterConfig(column="status", op="=", value="complete")],
)
assert config.primary_kind == "area"
assert config.secondary_kind == "scatter"
assert config.time_grain == "P1M"
assert config.group_by is not None
assert config.group_by.name == "region"
assert config.group_by_secondary is not None
assert config.group_by_secondary.name == "channel"
def test_mixed_timeseries_missing_y(self) -> None:
with pytest.raises(ValidationError):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
def test_mixed_timeseries_missing_y_secondary(self) -> None:
with pytest.raises(ValidationError):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
)
def test_mixed_timeseries_empty_y(self) -> None:
with pytest.raises(ValidationError):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
def test_mixed_timeseries_rejects_extra_fields(self) -> None:
with pytest.raises(ValidationError):
MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
unknown_field="bad",
)
# ============================================================
# Mixed Timeseries Form Data Mapping Tests
# ============================================================
class TestMapMixedTimeseriesConfig:
"""Test map_mixed_timeseries_config form_data generation."""
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_basic_mixed_form_data(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["viz_type"] == "mixed_timeseries"
assert result["x_axis"] == "order_date"
assert len(result["metrics"]) == 1
assert result["metrics"][0]["aggregate"] == "SUM"
assert len(result["metrics_b"]) == 1
assert result["metrics_b"][0]["aggregate"] == "COUNT"
assert result["seriesType"] == "line"
assert result["seriesTypeB"] == "bar"
assert result["yAxisIndex"] == 0
assert result["yAxisIndexB"] == 1
assert result["show_legend"] is True
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_with_time_grain(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
time_grain="P1W",
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["time_grain_sqla"] == "P1W"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_area_series(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
primary_kind="area",
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
secondary_kind="area",
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["seriesType"] == "line"
assert result["area"] is True
assert result["seriesTypeB"] == "line"
assert result["areaB"] is True
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_with_groupby(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
group_by=ColumnRef(name="region"),
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
group_by_secondary=ColumnRef(name="channel"),
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["groupby"] == ["region"]
assert result["groupby_b"] == ["channel"]
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_groupby_same_as_x_ignored(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
group_by=ColumnRef(name="date"), # same as x
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
group_by_secondary=ColumnRef(name="date"), # same as x
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert "groupby" not in result
assert "groupby_b" not in result
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_with_axis_config(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
x_axis=AxisConfig(title="Date"),
y_axis=AxisConfig(title="Revenue", format="$,.2f", scale="log"),
y_axis_secondary=AxisConfig(title="Orders", format=",d", scale="log"),
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["xAxisTitle"] == "Date"
assert result["yAxisTitle"] == "Revenue"
assert result["y_axis_format"] == "$,.2f"
assert result["logAxis"] is True
assert result["yAxisTitleSecondary"] == "Orders"
assert result["y_axis_format_secondary"] == ",d"
assert result["logAxisSecondary"] is True
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_with_filters(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
filters=[FilterConfig(column="status", op="=", value="complete")],
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert "adhoc_filters" in result
assert len(result["adhoc_filters"]) == 1
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_mixed_form_data_non_temporal_x(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = False
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="year"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = map_mixed_timeseries_config(config, dataset_id=1)
assert result["time_grain_sqla"] is None
assert result["granularity_sqla"] is None
assert result["x_axis_sort_series_type"] == "name"
# ============================================================
# map_config_to_form_data Dispatch Tests
# ============================================================
class TestMapConfigToFormDataDispatch:
"""Test map_config_to_form_data dispatches to correct mapping function."""
def test_dispatches_pie_config(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = map_config_to_form_data(config)
assert result["viz_type"] == "pie"
def test_dispatches_pivot_table_config(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = map_config_to_form_data(config)
assert result["viz_type"] == "pivot_table_v2"
@patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
def test_dispatches_mixed_timeseries_config(self, mock_is_temporal) -> None:
mock_is_temporal.return_value = True
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = map_config_to_form_data(config, dataset_id=1)
assert result["viz_type"] == "mixed_timeseries"
# ============================================================
# Chart Name Generation Tests
# ============================================================
class TestGenerateChartNameNewTypes:
"""Test generate_chart_name for new chart types."""
def test_pie_chart_name(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM"),
)
result = generate_chart_name(config)
assert result == "Pie Chart - product by revenue"
def test_pie_chart_name_with_custom_label(self) -> None:
config = PieChartConfig(
chart_type="pie",
dimension=ColumnRef(name="product"),
metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
)
result = generate_chart_name(config)
assert result == "Pie Chart - product by Total Revenue"
def test_pivot_table_chart_name(self) -> None:
config = PivotTableChartConfig(
chart_type="pivot_table",
rows=[ColumnRef(name="product"), ColumnRef(name="region")],
metrics=[ColumnRef(name="revenue", aggregate="SUM")],
)
result = generate_chart_name(config)
assert result == "Pivot Table - product, region"
def test_mixed_timeseries_chart_name(self) -> None:
config = MixedTimeseriesChartConfig(
chart_type="mixed_timeseries",
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
y_secondary=[ColumnRef(name="orders", aggregate="COUNT")],
)
result = generate_chart_name(config)
assert result == "Mixed Chart - revenue + orders"
# ============================================================
# Schema Validator Pre-Validation Tests
# ============================================================
class TestSchemaValidatorNewTypes:
"""Test SchemaValidator pre-validation for new chart types."""
def test_pie_chart_type_accepted(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pie",
"dimension": {"name": "product"},
"metric": {"name": "revenue", "aggregate": "SUM"},
},
}
is_valid, request, error = SchemaValidator.validate_request(data)
assert is_valid is True
assert error is None
def test_pivot_table_chart_type_accepted(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pivot_table",
"rows": [{"name": "product"}],
"metrics": [{"name": "revenue", "aggregate": "SUM"}],
},
}
is_valid, request, error = SchemaValidator.validate_request(data)
assert is_valid is True
assert error is None
def test_mixed_timeseries_chart_type_accepted(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "date"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
"y_secondary": [{"name": "orders", "aggregate": "COUNT"}],
},
}
is_valid, request, error = SchemaValidator.validate_request(data)
assert is_valid is True
assert error is None
def test_pie_missing_dimension_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pie",
"metric": {"name": "revenue", "aggregate": "SUM"},
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
assert (
"dimension" in error.message.lower()
or "dimension" in (error.details or "").lower()
)
def test_pie_missing_metric_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pie",
"dimension": {"name": "product"},
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
def test_pivot_table_missing_rows_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pivot_table",
"metrics": [{"name": "revenue", "aggregate": "SUM"}],
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
assert (
"rows" in error.message.lower() or "rows" in (error.details or "").lower()
)
def test_pivot_table_missing_metrics_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "pivot_table",
"rows": [{"name": "product"}],
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
def test_mixed_timeseries_missing_y_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "date"},
"y_secondary": [{"name": "orders", "aggregate": "COUNT"}],
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
def test_mixed_timeseries_missing_y_secondary_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "mixed_timeseries",
"x": {"name": "date"},
"y": [{"name": "revenue", "aggregate": "SUM"}],
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
def test_mixed_timeseries_missing_x_rejected(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "mixed_timeseries",
"y": [{"name": "revenue", "aggregate": "SUM"}],
"y_secondary": [{"name": "orders", "aggregate": "COUNT"}],
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
def test_invalid_chart_type_lists_all_options(self) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": "invalid_type",
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
assert "pie" in (error.details or "").lower()
assert "pivot_table" in (error.details or "").lower()
assert "mixed_timeseries" in (error.details or "").lower()
@pytest.mark.parametrize(
"bad_chart_type",
[["xy"], {"type": "xy"}, 123, True],
)
def test_non_string_chart_type_rejected_gracefully(
self, bad_chart_type: object
) -> None:
data = {
"dataset_id": 1,
"config": {
"chart_type": bad_chart_type,
},
}
is_valid, _, error = SchemaValidator.validate_request(data)
assert is_valid is False
assert error is not None
assert error.error_code == "INVALID_CHART_TYPE"