diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 361f3d105c6..5f6f867c709 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -118,6 +118,9 @@ Chart Types You Can CREATE with generate_chart/generate_explore_link: - chart_type="xy", kind="scatter": Scatter plot for correlation analysis - chart_type="table": Data table for detailed views - chart_type="table", viz_type="ag-grid-table": Interactive AG Grid table +- chart_type="pie": Pie chart for proportional data (set donut=True for donut) +- chart_type="pivot_table": Interactive pivot table for cross-tabulation +- chart_type="mixed_timeseries": Dual-series chart combining two chart types Time grain for temporal x-axis (time_grain parameter): - PT1H (hourly), P1D (daily), P1W (weekly), P1M (monthly), P1Y (yearly) diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index d798fe34009..600252f2cfa 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -30,6 +30,9 @@ from superset.mcp_service.chart.schemas import ( ChartCapabilities, ChartSemantics, ColumnRef, + MixedTimeseriesChartConfig, + PieChartConfig, + PivotTableChartConfig, TableChartConfig, XYChartConfig, ) @@ -301,7 +304,11 @@ def is_column_truly_temporal(column_name: str, dataset_id: int | str | None) -> def map_config_to_form_data( - config: TableChartConfig | XYChartConfig, + config: TableChartConfig + | XYChartConfig + | PieChartConfig + | PivotTableChartConfig + | MixedTimeseriesChartConfig, dataset_id: int | str | None = None, ) -> Dict[str, Any]: """Map chart config to Superset form_data.""" @@ -309,6 +316,12 @@ def map_config_to_form_data( return map_table_config(config) elif isinstance(config, XYChartConfig): return map_xy_config(config, dataset_id=dataset_id) + elif isinstance(config, PieChartConfig): + return map_pie_config(config) + elif isinstance(config, PivotTableChartConfig): + return map_pivot_table_config(config) + elif isinstance(config, MixedTimeseriesChartConfig): + return map_mixed_timeseries_config(config, dataset_id=dataset_id) else: raise ValueError(f"Unsupported config type: {type(config)}") @@ -567,6 +580,197 @@ def map_xy_config( return form_data +def map_pie_config(config: PieChartConfig) -> Dict[str, Any]: + """Map pie chart config to Superset form_data.""" + metric = create_metric_object(config.metric) + + form_data: Dict[str, Any] = { + "viz_type": "pie", + "groupby": [config.dimension.name], + "metric": metric, + "color_scheme": "supersetColors", + "show_labels": config.show_labels, + "show_legend": config.show_legend, + "label_type": config.label_type, + "number_format": config.number_format, + "sort_by_metric": config.sort_by_metric, + "row_limit": config.row_limit, + "donut": config.donut, + "show_total": config.show_total, + "labels_outside": config.labels_outside, + "outerRadius": config.outer_radius, + "innerRadius": config.inner_radius, + "date_format": "smart_date", + } + + if config.filters: + form_data["adhoc_filters"] = [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "subject": filter_config.column, + "operator": map_filter_operator(filter_config.op), + "comparator": filter_config.value, + } + for filter_config in config.filters + if filter_config is not None + ] + + return form_data + + +def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]: + """Map pivot table config to Superset form_data.""" + if not config.rows: + raise ValueError("Pivot table must have at least one row grouping column") + if not config.metrics: + raise ValueError("Pivot table must have at least one metric") + + metrics = [create_metric_object(col) for col in config.metrics] + + form_data: Dict[str, Any] = { + "viz_type": "pivot_table_v2", + "groupbyRows": [col.name for col in config.rows], + "groupbyColumns": [col.name for col in config.columns] + if config.columns + else [], + "metrics": metrics, + "aggregateFunction": config.aggregate_function, + "rowTotals": config.show_row_totals, + "colTotals": config.show_column_totals, + "transposePivot": config.transpose, + "combineMetric": config.combine_metric, + "valueFormat": config.value_format, + "metricsLayout": "COLUMNS", + "rowOrder": "key_a_to_z", + "colOrder": "key_a_to_z", + "row_limit": config.row_limit, + } + + if config.filters: + form_data["adhoc_filters"] = [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "subject": filter_config.column, + "operator": map_filter_operator(filter_config.op), + "comparator": filter_config.value, + } + for filter_config in config.filters + if filter_config is not None + ] + + return form_data + + +_MIXED_SERIES_TYPE_MAP = { + "line": "line", + "bar": "bar", + "area": "line", # area uses line type with area=True + "scatter": "scatter", +} + + +def _apply_axis_to_form_data( + form_data: Dict[str, Any], + axis_config: Any, + title_key: str, + format_key: str, + log_key: str | None = None, +) -> None: + """Apply a single axis configuration to form_data.""" + if not axis_config: + return + if axis_config.title: + form_data[title_key] = axis_config.title + if axis_config.format: + form_data[format_key] = axis_config.format + if log_key and axis_config.scale == "log": + form_data[log_key] = True + + +def _add_mixed_axis_config( + form_data: Dict[str, Any], + config: MixedTimeseriesChartConfig, +) -> None: + """Add axis configurations to mixed timeseries form_data.""" + _apply_axis_to_form_data( + form_data, config.x_axis, "xAxisTitle", "x_axis_time_format" + ) + _apply_axis_to_form_data( + form_data, config.y_axis, "yAxisTitle", "y_axis_format", "logAxis" + ) + _apply_axis_to_form_data( + form_data, + config.y_axis_secondary, + "yAxisTitleSecondary", + "y_axis_format_secondary", + "logAxisSecondary", + ) + + +def map_mixed_timeseries_config( + config: MixedTimeseriesChartConfig, + dataset_id: int | str | None = None, +) -> Dict[str, Any]: + """Map mixed timeseries chart config to Superset form_data.""" + if not config.y: + raise ValueError("Mixed timeseries must have at least one primary metric") + if not config.y_secondary: + raise ValueError("Mixed timeseries must have at least one secondary metric") + + # Check if x-axis column is truly temporal + x_is_temporal = is_column_truly_temporal(config.x.name, dataset_id) + + form_data: Dict[str, Any] = { + "viz_type": "mixed_timeseries", + "x_axis": config.x.name, + # Query A + "metrics": [create_metric_object(col) for col in config.y], + "seriesType": _MIXED_SERIES_TYPE_MAP.get(config.primary_kind, "line"), + "area": config.primary_kind == "area", + "yAxisIndex": 0, + # Query B + "metrics_b": [create_metric_object(col) for col in config.y_secondary], + "seriesTypeB": _MIXED_SERIES_TYPE_MAP.get(config.secondary_kind, "bar"), + "areaB": config.secondary_kind == "area", + "yAxisIndexB": 1, + # Display + "show_legend": config.show_legend, + "zoomable": True, + "rich_tooltip": True, + } + + # Configure temporal handling + configure_temporal_handling(form_data, x_is_temporal, config.time_grain) + + # Primary groupby (Query A) + if config.group_by and config.group_by.name != config.x.name: + form_data["groupby"] = [config.group_by.name] + + # Secondary groupby (Query B) + if config.group_by_secondary and config.group_by_secondary.name != config.x.name: + form_data["groupby_b"] = [config.group_by_secondary.name] + + _add_mixed_axis_config(form_data, config) + + # Filters + if config.filters: + form_data["adhoc_filters"] = [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "subject": filter_config.column, + "operator": map_filter_operator(filter_config.op), + "comparator": filter_config.value, + } + for filter_config in config.filters + if filter_config is not None + ] + + return form_data + + def map_filter_operator(op: str) -> str: """Map filter operator to Superset format.""" operator_map = { @@ -585,7 +789,13 @@ def map_filter_operator(op: str) -> str: return operator_map.get(op, op) -def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str: +def generate_chart_name( + config: TableChartConfig + | XYChartConfig + | PieChartConfig + | PivotTableChartConfig + | MixedTimeseriesChartConfig, +) -> str: """Generate a chart name based on the configuration.""" if isinstance(config, TableChartConfig): return f"Table Chart - {', '.join(col.name for col in config.columns)}" @@ -594,6 +804,16 @@ def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str: x_col = config.x.name y_cols = ", ".join(col.name for col in config.y) return f"{chart_type} Chart - {x_col} vs {y_cols}" + elif isinstance(config, PieChartConfig): + metric_label = config.metric.label or config.metric.name + return f"Pie Chart - {config.dimension.name} by {metric_label}" + elif isinstance(config, PivotTableChartConfig): + rows = ", ".join(col.name for col in config.rows) + return f"Pivot Table - {rows}" + elif isinstance(config, MixedTimeseriesChartConfig): + primary = ", ".join(col.name for col in config.y) + secondary = ", ".join(col.name for col in config.y_secondary) + return f"Mixed Chart - {primary} + {secondary}" else: return "Chart" @@ -603,22 +823,7 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit if chart: viz_type = getattr(chart, "viz_type", "unknown") else: - # Map config chart_type to viz_type - chart_type = getattr(config, "chart_type", "unknown") - if chart_type == "xy": - kind = getattr(config, "kind", "line") - viz_type_map = { - "line": "echarts_timeseries_line", - "bar": "echarts_timeseries_bar", - "area": "echarts_area", - "scatter": "echarts_timeseries_scatter", - } - viz_type = viz_type_map.get(kind, "echarts_timeseries_line") - elif chart_type == "table": - # Use the viz_type from config if available (table or ag-grid-table) - viz_type = getattr(config, "viz_type", "table") - else: - viz_type = "unknown" + viz_type = _resolve_viz_type(config) # Determine interaction capabilities based on chart type interactive_types = [ @@ -663,27 +868,35 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit ) +def _resolve_viz_type(config: Any) -> str: + """Resolve viz_type from a chart config object.""" + chart_type = getattr(config, "chart_type", "unknown") + if chart_type == "xy": + kind = getattr(config, "kind", "line") + viz_type_map = { + "line": "echarts_timeseries_line", + "bar": "echarts_timeseries_bar", + "area": "echarts_area", + "scatter": "echarts_timeseries_scatter", + } + return viz_type_map.get(kind, "echarts_timeseries_line") + elif chart_type == "table": + return getattr(config, "viz_type", "table") + elif chart_type == "pie": + return "pie" + elif chart_type == "pivot_table": + return "pivot_table_v2" + elif chart_type == "mixed_timeseries": + return "mixed_timeseries" + return "unknown" + + def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics: """Generate semantic understanding of the chart.""" if chart: viz_type = getattr(chart, "viz_type", "unknown") else: - # Map config chart_type to viz_type - chart_type = getattr(config, "chart_type", "unknown") - if chart_type == "xy": - kind = getattr(config, "kind", "line") - viz_type_map = { - "line": "echarts_timeseries_line", - "bar": "echarts_timeseries_bar", - "area": "echarts_area", - "scatter": "echarts_timeseries_scatter", - } - viz_type = viz_type_map.get(kind, "echarts_timeseries_line") - elif chart_type == "table": - # Use the viz_type from config if available (table or ag-grid-table) - viz_type = getattr(config, "viz_type", "table") - else: - viz_type = "unknown" + viz_type = _resolve_viz_type(config) # Generate primary insight based on chart type insights_map = { @@ -696,6 +909,14 @@ def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics: ), "pie": "Shows proportional relationships within a dataset", "echarts_area": "Emphasizes cumulative totals and part-to-whole relationships", + "pivot_table_v2": ( + "Cross-tabulates data with rows, columns, and aggregated metrics " + "for multi-dimensional analysis" + ), + "mixed_timeseries": ( + "Combines two different chart types on the same time axis " + "for comparing related metrics with different scales" + ), } primary_insight = insights_map.get( diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 2d0bb0cafe7..10b694d22c2 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -485,6 +485,183 @@ class FilterConfig(BaseModel): # Actual chart types +class PieChartConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + chart_type: Literal["pie"] = Field( + ..., + description=( + "Chart type discriminator - MUST be 'pie' for pie/donut charts. " + "This field is REQUIRED and tells Superset which chart " + "configuration schema to use." + ), + ) + dimension: ColumnRef = Field( + ..., description="Category column that defines the pie slices" + ) + metric: ColumnRef = Field( + ..., + description=( + "Value metric that determines slice sizes. " + "Must include an aggregate function (e.g., SUM, COUNT)." + ), + ) + donut: bool = Field(False, description="Render as a donut chart with a center hole") + show_labels: bool = Field(True, description="Display labels on slices") + label_type: Literal[ + "key", + "value", + "percent", + "key_value", + "key_percent", + "key_value_percent", + "value_percent", + ] = Field("key_value_percent", description="Type of labels to show on slices") + sort_by_metric: bool = Field(True, description="Sort slices by metric value") + show_legend: bool = Field(True, description="Whether to show legend") + filters: List[FilterConfig] | None = Field(None, description="Filters to apply") + row_limit: int = Field( + 100, + description="Maximum number of slices to display", + ge=1, + le=10000, + ) + number_format: str = Field( + "SMART_NUMBER", + description="Number format string", + max_length=50, + ) + show_total: bool = Field(False, description="Display aggregate count in center") + labels_outside: bool = Field(True, description="Place labels outside the pie") + outer_radius: int = Field( + 70, + description="Outer edge radius as a percentage (1-100)", + ge=1, + le=100, + ) + inner_radius: int = Field( + 30, + description="Inner radius as a percentage for donut (1-100)", + ge=1, + le=100, + ) + + +class PivotTableChartConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + chart_type: Literal["pivot_table"] = Field( + ..., + description=( + "Chart type discriminator - MUST be 'pivot_table' for interactive " + "pivot tables. This field is REQUIRED." + ), + ) + rows: List[ColumnRef] = Field( + ..., + min_length=1, + description="Row grouping columns (at least one required)", + ) + columns: List[ColumnRef] | None = Field( + None, + description="Column grouping columns (optional, for cross-tabulation)", + ) + metrics: List[ColumnRef] = Field( + ..., + min_length=1, + description=( + "Metrics to aggregate. Each must have an aggregate function " + "(e.g., SUM, COUNT, AVG)." + ), + ) + aggregate_function: Literal[ + "Sum", + "Average", + "Median", + "Sample Variance", + "Sample Standard Deviation", + "Minimum", + "Maximum", + "Count", + "Count Unique Values", + "First", + "Last", + ] = Field("Sum", description="Default aggregation function for the pivot table") + show_row_totals: bool = Field(True, description="Show row totals") + show_column_totals: bool = Field(True, description="Show column totals") + transpose: bool = Field(False, description="Swap rows and columns") + combine_metric: bool = Field( + False, + description="Display metrics side by side within columns", + ) + filters: List[FilterConfig] | None = Field(None, description="Filters to apply") + row_limit: int = Field( + 10000, + description="Maximum number of cells", + ge=1, + le=50000, + ) + value_format: str = Field( + "SMART_NUMBER", + description="Value format string", + max_length=50, + ) + + +class MixedTimeseriesChartConfig(BaseModel): + model_config = ConfigDict(extra="forbid") + + chart_type: Literal["mixed_timeseries"] = Field( + ..., + description=( + "Chart type discriminator - MUST be 'mixed_timeseries' for charts " + "that combine two different series types (e.g., line + bar). " + "This field is REQUIRED." + ), + ) + x: ColumnRef = Field(..., description="X-axis temporal column (shared)") + time_grain: TimeGrain | None = Field( + None, + description=( + "Time granularity for the x-axis. " + "Common values: PT1H (hourly), P1D (daily), P1W (weekly), " + "P1M (monthly), P1Y (yearly)." + ), + ) + # Primary series (Query A) + y: List[ColumnRef] = Field( + ..., + min_length=1, + description="Primary Y-axis metrics (Query A)", + ) + primary_kind: Literal["line", "bar", "area", "scatter"] = Field( + "line", description="Primary series chart type" + ) + group_by: ColumnRef | None = Field( + None, description="Group by column for primary series" + ) + # Secondary series (Query B) + y_secondary: List[ColumnRef] = Field( + ..., + min_length=1, + description="Secondary Y-axis metrics (Query B)", + ) + secondary_kind: Literal["line", "bar", "area", "scatter"] = Field( + "bar", description="Secondary series chart type" + ) + group_by_secondary: ColumnRef | None = Field( + None, description="Group by column for secondary series" + ) + # Display options + show_legend: bool = Field(True, description="Whether to show legend") + x_axis: AxisConfig | None = Field(None, description="X-axis configuration") + y_axis: AxisConfig | None = Field(None, description="Primary Y-axis configuration") + y_axis_secondary: AxisConfig | None = Field( + None, description="Secondary Y-axis configuration" + ) + filters: List[FilterConfig] | None = Field(None, description="Filters to apply") + + class TableChartConfig(BaseModel): model_config = ConfigDict(extra="forbid") @@ -631,10 +808,17 @@ class XYChartConfig(BaseModel): # Discriminated union entry point with custom error handling ChartConfig = Annotated[ - XYChartConfig | TableChartConfig, + XYChartConfig + | TableChartConfig + | PieChartConfig + | PivotTableChartConfig + | MixedTimeseriesChartConfig, Field( discriminator="chart_type", - description="Chart configuration - specify chart_type as 'xy' or 'table'", + description=( + "Chart configuration - specify chart_type as 'xy', 'table', " + "'pie', 'pivot_table', or 'mixed_timeseries'" + ), ), ] diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index 9ba30c0bd95..aa50d4c843e 100644 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -126,37 +126,52 @@ class SchemaValidator: return False, ChartGenerationError( error_type="missing_chart_type", message="Missing required field: chart_type", - details="Chart configuration must specify 'chart_type' as either 'xy' " - "or 'table'", + details="Chart configuration must specify 'chart_type'", suggestions=[ "Add 'chart_type': 'xy' for line/bar/area/scatter charts", "Add 'chart_type': 'table' for table visualizations", - "Example: 'config': {'chart_type': 'xy', ...}", + "Add 'chart_type': 'pie' for pie or donut charts", + "Add 'chart_type': 'pivot_table' for interactive pivot tables", + "Add 'chart_type': 'mixed_timeseries' for dual-series time charts", ], error_code="MISSING_CHART_TYPE", ) - if chart_type not in ["xy", "table"]: + return SchemaValidator._pre_validate_chart_type(chart_type, config) + + @staticmethod + def _pre_validate_chart_type( + chart_type: str, + config: Dict[str, Any], + ) -> Tuple[bool, ChartGenerationError | None]: + """Validate chart type and dispatch to type-specific pre-validation.""" + chart_type_validators = { + "xy": SchemaValidator._pre_validate_xy_config, + "table": SchemaValidator._pre_validate_table_config, + "pie": SchemaValidator._pre_validate_pie_config, + "pivot_table": SchemaValidator._pre_validate_pivot_table_config, + "mixed_timeseries": SchemaValidator._pre_validate_mixed_timeseries_config, + } + + if not isinstance(chart_type, str) or chart_type not in chart_type_validators: + valid_types = ", ".join(chart_type_validators.keys()) return False, ChartGenerationError( error_type="invalid_chart_type", message=f"Invalid chart_type: '{chart_type}'", - details=f"Chart type '{chart_type}' is not supported. Must be 'xy' or " - f"'table'", + details=f"Chart type '{chart_type}' is not supported. " + f"Must be one of: {valid_types}", suggestions=[ "Use 'chart_type': 'xy' for line, bar, area, or scatter charts", "Use 'chart_type': 'table' for tabular data display", + "Use 'chart_type': 'pie' for pie or donut charts", + "Use 'chart_type': 'pivot_table' for interactive pivot tables", + "Use 'chart_type': 'mixed_timeseries' for dual-series time charts", "Check spelling and ensure lowercase", ], error_code="INVALID_CHART_TYPE", ) - # Pre-validate structure based on chart type - if chart_type == "xy": - return SchemaValidator._pre_validate_xy_config(config) - elif chart_type == "table": - return SchemaValidator._pre_validate_table_config(config) - - return True, None + return chart_type_validators[chart_type](config) @staticmethod def _pre_validate_xy_config( @@ -237,6 +252,134 @@ class SchemaValidator: return True, None + @staticmethod + def _pre_validate_pie_config( + config: Dict[str, Any], + ) -> Tuple[bool, ChartGenerationError | None]: + """Pre-validate pie chart configuration.""" + missing_fields = [] + + if "dimension" not in config: + missing_fields.append("'dimension' (category column for slices)") + if "metric" not in config: + missing_fields.append("'metric' (value metric for slice sizes)") + + if missing_fields: + return False, ChartGenerationError( + error_type="missing_pie_fields", + message=f"Pie chart missing required " + f"fields: {', '.join(missing_fields)}", + details="Pie charts require a dimension (categories) and a metric " + "(values)", + suggestions=[ + "Add 'dimension' field: {'name': 'category_column'}", + "Add 'metric' field: {'name': 'value_column', 'aggregate': 'SUM'}", + "Example: {'chart_type': 'pie', 'dimension': {'name': " + "'product'}, 'metric': {'name': 'revenue', 'aggregate': 'SUM'}}", + ], + error_code="MISSING_PIE_FIELDS", + ) + + return True, None + + @staticmethod + def _pre_validate_pivot_table_config( + config: Dict[str, Any], + ) -> Tuple[bool, ChartGenerationError | None]: + """Pre-validate pivot table configuration.""" + missing_fields = [] + + if "rows" not in config: + missing_fields.append("'rows' (row grouping columns)") + if "metrics" not in config: + missing_fields.append("'metrics' (aggregation metrics)") + + if missing_fields: + return False, ChartGenerationError( + error_type="missing_pivot_fields", + message=f"Pivot table missing required " + f"fields: {', '.join(missing_fields)}", + details="Pivot tables require row groupings and metrics", + suggestions=[ + "Add 'rows' field: [{'name': 'category'}]", + "Add 'metrics' field: [{'name': 'sales', 'aggregate': 'SUM'}]", + "Optional 'columns' for cross-tabulation: [{'name': 'region'}]", + ], + error_code="MISSING_PIVOT_FIELDS", + ) + + if not isinstance(config.get("rows", []), list): + return False, ChartGenerationError( + error_type="invalid_rows_format", + message="Rows must be a list of columns", + details="The 'rows' field must be an array of column specifications", + suggestions=[ + "Wrap row columns in array: 'rows': [{'name': 'category'}]", + ], + error_code="INVALID_ROWS_FORMAT", + ) + + if not isinstance(config.get("metrics", []), list): + return False, ChartGenerationError( + error_type="invalid_metrics_format", + message="Metrics must be a list", + details="The 'metrics' field must be an array of metric specifications", + suggestions=[ + "Wrap metrics in array: 'metrics': [{'name': 'sales', " + "'aggregate': 'SUM'}]", + ], + error_code="INVALID_METRICS_FORMAT", + ) + + return True, None + + @staticmethod + def _pre_validate_mixed_timeseries_config( + config: Dict[str, Any], + ) -> Tuple[bool, ChartGenerationError | None]: + """Pre-validate mixed timeseries configuration.""" + missing_fields = [] + + if "x" not in config: + missing_fields.append("'x' (X-axis temporal column)") + if "y" not in config: + missing_fields.append("'y' (primary Y-axis metrics)") + if "y_secondary" not in config: + missing_fields.append("'y_secondary' (secondary Y-axis metrics)") + + if missing_fields: + return False, ChartGenerationError( + error_type="missing_mixed_timeseries_fields", + message=f"Mixed timeseries chart missing required " + f"fields: {', '.join(missing_fields)}", + details="Mixed timeseries charts require an x-axis, primary metrics, " + "and secondary metrics", + suggestions=[ + "Add 'x' field: {'name': 'date_column'}", + "Add 'y' field: [{'name': 'revenue', 'aggregate': 'SUM'}]", + "Add 'y_secondary' field: [{'name': 'orders', " + "'aggregate': 'COUNT'}]", + "Optional: 'primary_kind' and 'secondary_kind' for chart types", + ], + error_code="MISSING_MIXED_TIMESERIES_FIELDS", + ) + + for field_name in ["y", "y_secondary"]: + if not isinstance(config.get(field_name, []), list): + return False, ChartGenerationError( + error_type=f"invalid_{field_name}_format", + message=f"'{field_name}' must be a list of metrics", + details=f"The '{field_name}' field must be an array of metric " + "specifications", + suggestions=[ + f"Wrap in array: '{field_name}': " + "[{'name': 'col', 'aggregate': 'SUM'}]", + ], + error_code=f"INVALID_{field_name.upper()}_FORMAT", + ) + + return True, None + @staticmethod def _enhance_validation_error( error: PydanticValidationError, request_data: Dict[str, Any] diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py new file mode 100644 index 00000000000..e930ab83091 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -0,0 +1,929 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for new MCP chart types: pie, pivot_table, mixed_timeseries. + +Tests cover schema validation, form_data mapping, chart name generation, +and schema validator pre-validation for all three new chart types. +""" + +from unittest.mock import patch + +import pytest +from pydantic import ValidationError + +from superset.mcp_service.chart.chart_utils import ( + generate_chart_name, + map_config_to_form_data, + map_mixed_timeseries_config, + map_pie_config, + map_pivot_table_config, +) +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + FilterConfig, + MixedTimeseriesChartConfig, + PieChartConfig, + PivotTableChartConfig, +) +from superset.mcp_service.chart.validation.schema_validator import SchemaValidator + +# ============================================================ +# Pie Chart Schema Tests +# ============================================================ + + +class TestPieChartConfigSchema: + """Test PieChartConfig Pydantic schema validation.""" + + def test_basic_pie_config(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + assert config.chart_type == "pie" + assert config.dimension.name == "product" + assert config.metric.aggregate == "SUM" + assert config.donut is False + assert config.show_labels is True + assert config.label_type == "key_value_percent" + + def test_donut_chart_config(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="category"), + metric=ColumnRef(name="count", aggregate="COUNT"), + donut=True, + inner_radius=40, + outer_radius=80, + ) + assert config.donut is True + assert config.inner_radius == 40 + assert config.outer_radius == 80 + + def test_pie_config_with_all_options(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="region"), + metric=ColumnRef(name="sales", aggregate="SUM", label="Total Sales"), + donut=True, + show_labels=False, + label_type="percent", + sort_by_metric=False, + show_legend=False, + row_limit=50, + number_format="$,.2f", + show_total=True, + labels_outside=False, + outer_radius=90, + inner_radius=50, + filters=[FilterConfig(column="status", op="=", value="active")], + ) + assert config.show_labels is False + assert config.label_type == "percent" + assert config.row_limit == 50 + assert config.show_total is True + assert config.filters is not None + assert len(config.filters) == 1 + + def test_pie_config_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError): + PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + unknown_field="bad", + ) + + def test_pie_config_missing_dimension(self) -> None: + with pytest.raises(ValidationError): + PieChartConfig( + chart_type="pie", + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + + def test_pie_config_missing_metric(self) -> None: + with pytest.raises(ValidationError): + PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + ) + + def test_pie_config_row_limit_bounds(self) -> None: + with pytest.raises(ValidationError): + PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + row_limit=0, + ) + + def test_pie_config_valid_label_types(self) -> None: + for label_type in [ + "key", + "value", + "percent", + "key_value", + "key_percent", + "key_value_percent", + "value_percent", + ]: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + label_type=label_type, + ) + assert config.label_type == label_type + + +# ============================================================ +# Pie Chart Form Data Mapping Tests +# ============================================================ + + +class TestMapPieConfig: + """Test map_pie_config form_data generation.""" + + def test_basic_pie_form_data(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + result = map_pie_config(config) + + assert result["viz_type"] == "pie" + assert result["groupby"] == ["product"] + assert result["metric"]["aggregate"] == "SUM" + assert result["metric"]["column"]["column_name"] == "revenue" + assert result["show_labels"] is True + assert result["show_legend"] is True + assert result["label_type"] == "key_value_percent" + assert result["sort_by_metric"] is True + assert result["row_limit"] == 100 + assert result["donut"] is False + assert result["color_scheme"] == "supersetColors" + + def test_donut_form_data(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="category"), + metric=ColumnRef(name="count", aggregate="COUNT"), + donut=True, + inner_radius=40, + outer_radius=80, + ) + result = map_pie_config(config) + + assert result["donut"] is True + assert result["innerRadius"] == 40 + assert result["outerRadius"] == 80 + + def test_pie_form_data_with_filters(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + filters=[FilterConfig(column="region", op="=", value="US")], + ) + result = map_pie_config(config) + + assert "adhoc_filters" in result + assert len(result["adhoc_filters"]) == 1 + assert result["adhoc_filters"][0]["subject"] == "region" + assert result["adhoc_filters"][0]["operator"] == "==" + assert result["adhoc_filters"][0]["comparator"] == "US" + + def test_pie_form_data_custom_options(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="status"), + metric=ColumnRef(name="count", aggregate="COUNT"), + show_labels=False, + label_type="percent", + show_legend=False, + number_format="$,.2f", + show_total=True, + labels_outside=False, + ) + result = map_pie_config(config) + + assert result["show_labels"] is False + assert result["label_type"] == "percent" + assert result["show_legend"] is False + assert result["number_format"] == "$,.2f" + assert result["show_total"] is True + assert result["labels_outside"] is False + + def test_pie_form_data_custom_metric_label(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"), + ) + result = map_pie_config(config) + + assert result["metric"]["label"] == "Total Revenue" + assert result["metric"]["hasCustomLabel"] is True + + +# ============================================================ +# Pivot Table Schema Tests +# ============================================================ + + +class TestPivotTableChartConfigSchema: + """Test PivotTableChartConfig Pydantic schema validation.""" + + def test_basic_pivot_table_config(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + assert config.chart_type == "pivot_table" + assert len(config.rows) == 1 + assert len(config.metrics) == 1 + assert config.aggregate_function == "Sum" + assert config.show_row_totals is True + assert config.show_column_totals is True + + def test_pivot_table_with_columns(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + columns=[ColumnRef(name="region")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + assert config.columns is not None + assert len(config.columns) == 1 + assert config.columns[0].name == "region" + + def test_pivot_table_with_all_options(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product"), ColumnRef(name="category")], + columns=[ColumnRef(name="region")], + metrics=[ + ColumnRef(name="revenue", aggregate="SUM"), + ColumnRef(name="orders", aggregate="COUNT"), + ], + aggregate_function="Average", + show_row_totals=False, + show_column_totals=False, + transpose=True, + combine_metric=True, + row_limit=5000, + value_format="$,.2f", + filters=[FilterConfig(column="year", op="=", value=2024)], + ) + assert config.aggregate_function == "Average" + assert config.show_row_totals is False + assert config.transpose is True + assert config.combine_metric is True + assert config.row_limit == 5000 + + def test_pivot_table_missing_rows(self) -> None: + with pytest.raises(ValidationError): + PivotTableChartConfig( + chart_type="pivot_table", + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + + def test_pivot_table_missing_metrics(self) -> None: + with pytest.raises(ValidationError): + PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + ) + + def test_pivot_table_empty_rows(self) -> None: + with pytest.raises(ValidationError): + PivotTableChartConfig( + chart_type="pivot_table", + rows=[], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + + def test_pivot_table_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError): + PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + unknown_field="bad", + ) + + def test_pivot_table_valid_aggregate_functions(self) -> None: + for agg in ["Sum", "Average", "Median", "Count", "Minimum", "Maximum"]: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + aggregate_function=agg, + ) + assert config.aggregate_function == agg + + +# ============================================================ +# Pivot Table Form Data Mapping Tests +# ============================================================ + + +class TestMapPivotTableConfig: + """Test map_pivot_table_config form_data generation.""" + + def test_basic_pivot_form_data(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + result = map_pivot_table_config(config) + + assert result["viz_type"] == "pivot_table_v2" + assert result["groupbyRows"] == ["product"] + assert result["groupbyColumns"] == [] + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["aggregate"] == "SUM" + assert result["aggregateFunction"] == "Sum" + assert result["rowTotals"] is True + assert result["colTotals"] is True + assert result["metricsLayout"] == "COLUMNS" + + def test_pivot_form_data_with_columns(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + columns=[ColumnRef(name="region"), ColumnRef(name="quarter")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + result = map_pivot_table_config(config) + + assert result["groupbyRows"] == ["product"] + assert result["groupbyColumns"] == ["region", "quarter"] + + def test_pivot_form_data_with_filters(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + filters=[FilterConfig(column="year", op="=", value=2024)], + ) + result = map_pivot_table_config(config) + + assert "adhoc_filters" in result + assert len(result["adhoc_filters"]) == 1 + assert result["adhoc_filters"][0]["subject"] == "year" + + def test_pivot_form_data_custom_options(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + aggregate_function="Average", + show_row_totals=False, + show_column_totals=False, + transpose=True, + combine_metric=True, + value_format="$,.2f", + ) + result = map_pivot_table_config(config) + + assert result["aggregateFunction"] == "Average" + assert result["rowTotals"] is False + assert result["colTotals"] is False + assert result["transposePivot"] is True + assert result["combineMetric"] is True + assert result["valueFormat"] == "$,.2f" + + def test_pivot_form_data_multiple_metrics(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ + ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"), + ColumnRef(name="orders", aggregate="COUNT", label="Order Count"), + ], + ) + result = map_pivot_table_config(config) + + assert len(result["metrics"]) == 2 + assert result["metrics"][0]["label"] == "Total Revenue" + assert result["metrics"][1]["label"] == "Order Count" + + +# ============================================================ +# Mixed Timeseries Schema Tests +# ============================================================ + + +class TestMixedTimeseriesChartConfigSchema: + """Test MixedTimeseriesChartConfig Pydantic schema validation.""" + + def test_basic_mixed_timeseries_config(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="order_date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + assert config.chart_type == "mixed_timeseries" + assert config.x.name == "order_date" + assert config.primary_kind == "line" + assert config.secondary_kind == "bar" + assert config.show_legend is True + + def test_mixed_timeseries_with_all_options(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + time_grain="P1M", + y=[ColumnRef(name="revenue", aggregate="SUM")], + primary_kind="area", + group_by=ColumnRef(name="region"), + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + secondary_kind="scatter", + group_by_secondary=ColumnRef(name="channel"), + show_legend=False, + x_axis=AxisConfig(title="Date"), + y_axis=AxisConfig(title="Revenue", format="$,.2f"), + y_axis_secondary=AxisConfig(title="Orders", scale="log"), + filters=[FilterConfig(column="status", op="=", value="complete")], + ) + assert config.primary_kind == "area" + assert config.secondary_kind == "scatter" + assert config.time_grain == "P1M" + assert config.group_by is not None + assert config.group_by.name == "region" + assert config.group_by_secondary is not None + assert config.group_by_secondary.name == "channel" + + def test_mixed_timeseries_missing_y(self) -> None: + with pytest.raises(ValidationError): + MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + + def test_mixed_timeseries_missing_y_secondary(self) -> None: + with pytest.raises(ValidationError): + MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + ) + + def test_mixed_timeseries_empty_y(self) -> None: + with pytest.raises(ValidationError): + MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + + def test_mixed_timeseries_rejects_extra_fields(self) -> None: + with pytest.raises(ValidationError): + MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + unknown_field="bad", + ) + + +# ============================================================ +# Mixed Timeseries Form Data Mapping Tests +# ============================================================ + + +class TestMapMixedTimeseriesConfig: + """Test map_mixed_timeseries_config form_data generation.""" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_basic_mixed_form_data(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="order_date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["viz_type"] == "mixed_timeseries" + assert result["x_axis"] == "order_date" + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["aggregate"] == "SUM" + assert len(result["metrics_b"]) == 1 + assert result["metrics_b"][0]["aggregate"] == "COUNT" + assert result["seriesType"] == "line" + assert result["seriesTypeB"] == "bar" + assert result["yAxisIndex"] == 0 + assert result["yAxisIndexB"] == 1 + assert result["show_legend"] is True + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_with_time_grain(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + time_grain="P1W", + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["time_grain_sqla"] == "P1W" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_area_series(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + primary_kind="area", + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + secondary_kind="area", + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["seriesType"] == "line" + assert result["area"] is True + assert result["seriesTypeB"] == "line" + assert result["areaB"] is True + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_with_groupby(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + group_by=ColumnRef(name="region"), + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + group_by_secondary=ColumnRef(name="channel"), + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["groupby"] == ["region"] + assert result["groupby_b"] == ["channel"] + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_groupby_same_as_x_ignored(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + group_by=ColumnRef(name="date"), # same as x + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + group_by_secondary=ColumnRef(name="date"), # same as x + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert "groupby" not in result + assert "groupby_b" not in result + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_with_axis_config(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + x_axis=AxisConfig(title="Date"), + y_axis=AxisConfig(title="Revenue", format="$,.2f", scale="log"), + y_axis_secondary=AxisConfig(title="Orders", format=",d", scale="log"), + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["xAxisTitle"] == "Date" + assert result["yAxisTitle"] == "Revenue" + assert result["y_axis_format"] == "$,.2f" + assert result["logAxis"] is True + assert result["yAxisTitleSecondary"] == "Orders" + assert result["y_axis_format_secondary"] == ",d" + assert result["logAxisSecondary"] is True + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_with_filters(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + filters=[FilterConfig(column="status", op="=", value="complete")], + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert "adhoc_filters" in result + assert len(result["adhoc_filters"]) == 1 + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_non_temporal_x(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = False + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="year"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["time_grain_sqla"] is None + assert result["granularity_sqla"] is None + assert result["x_axis_sort_series_type"] == "name" + + +# ============================================================ +# map_config_to_form_data Dispatch Tests +# ============================================================ + + +class TestMapConfigToFormDataDispatch: + """Test map_config_to_form_data dispatches to correct mapping function.""" + + def test_dispatches_pie_config(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + result = map_config_to_form_data(config) + assert result["viz_type"] == "pie" + + def test_dispatches_pivot_table_config(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + result = map_config_to_form_data(config) + assert result["viz_type"] == "pivot_table_v2" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_dispatches_mixed_timeseries_config(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = map_config_to_form_data(config, dataset_id=1) + assert result["viz_type"] == "mixed_timeseries" + + +# ============================================================ +# Chart Name Generation Tests +# ============================================================ + + +class TestGenerateChartNameNewTypes: + """Test generate_chart_name for new chart types.""" + + def test_pie_chart_name(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM"), + ) + result = generate_chart_name(config) + assert result == "Pie Chart - product by revenue" + + def test_pie_chart_name_with_custom_label(self) -> None: + config = PieChartConfig( + chart_type="pie", + dimension=ColumnRef(name="product"), + metric=ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"), + ) + result = generate_chart_name(config) + assert result == "Pie Chart - product by Total Revenue" + + def test_pivot_table_chart_name(self) -> None: + config = PivotTableChartConfig( + chart_type="pivot_table", + rows=[ColumnRef(name="product"), ColumnRef(name="region")], + metrics=[ColumnRef(name="revenue", aggregate="SUM")], + ) + result = generate_chart_name(config) + assert result == "Pivot Table - product, region" + + def test_mixed_timeseries_chart_name(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = generate_chart_name(config) + assert result == "Mixed Chart - revenue + orders" + + +# ============================================================ +# Schema Validator Pre-Validation Tests +# ============================================================ + + +class TestSchemaValidatorNewTypes: + """Test SchemaValidator pre-validation for new chart types.""" + + def test_pie_chart_type_accepted(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pie", + "dimension": {"name": "product"}, + "metric": {"name": "revenue", "aggregate": "SUM"}, + }, + } + is_valid, request, error = SchemaValidator.validate_request(data) + assert is_valid is True + assert error is None + + def test_pivot_table_chart_type_accepted(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pivot_table", + "rows": [{"name": "product"}], + "metrics": [{"name": "revenue", "aggregate": "SUM"}], + }, + } + is_valid, request, error = SchemaValidator.validate_request(data) + assert is_valid is True + assert error is None + + def test_mixed_timeseries_chart_type_accepted(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "mixed_timeseries", + "x": {"name": "date"}, + "y": [{"name": "revenue", "aggregate": "SUM"}], + "y_secondary": [{"name": "orders", "aggregate": "COUNT"}], + }, + } + is_valid, request, error = SchemaValidator.validate_request(data) + assert is_valid is True + assert error is None + + def test_pie_missing_dimension_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pie", + "metric": {"name": "revenue", "aggregate": "SUM"}, + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + assert ( + "dimension" in error.message.lower() + or "dimension" in (error.details or "").lower() + ) + + def test_pie_missing_metric_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pie", + "dimension": {"name": "product"}, + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + + def test_pivot_table_missing_rows_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pivot_table", + "metrics": [{"name": "revenue", "aggregate": "SUM"}], + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + assert ( + "rows" in error.message.lower() or "rows" in (error.details or "").lower() + ) + + def test_pivot_table_missing_metrics_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "pivot_table", + "rows": [{"name": "product"}], + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + + def test_mixed_timeseries_missing_y_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "mixed_timeseries", + "x": {"name": "date"}, + "y_secondary": [{"name": "orders", "aggregate": "COUNT"}], + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + + def test_mixed_timeseries_missing_y_secondary_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "mixed_timeseries", + "x": {"name": "date"}, + "y": [{"name": "revenue", "aggregate": "SUM"}], + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + + def test_mixed_timeseries_missing_x_rejected(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "mixed_timeseries", + "y": [{"name": "revenue", "aggregate": "SUM"}], + "y_secondary": [{"name": "orders", "aggregate": "COUNT"}], + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + + def test_invalid_chart_type_lists_all_options(self) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": "invalid_type", + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + assert "pie" in (error.details or "").lower() + assert "pivot_table" in (error.details or "").lower() + assert "mixed_timeseries" in (error.details or "").lower() + + @pytest.mark.parametrize( + "bad_chart_type", + [["xy"], {"type": "xy"}, 123, True], + ) + def test_non_string_chart_type_rejected_gracefully( + self, bad_chart_type: object + ) -> None: + data = { + "dataset_id": 1, + "config": { + "chart_type": bad_chart_type, + }, + } + is_valid, _, error = SchemaValidator.validate_request(data) + assert is_valid is False + assert error is not None + assert error.error_code == "INVALID_CHART_TYPE"