# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unit tests for MCP chart schema validation.

The tests build ``TableChartConfig`` and ``XYChartConfig`` objects directly
and check that label-uniqueness, ``viz_type`` and unknown-field validation
either accept the configuration or raise ``pydantic.ValidationError``.
"""

import pytest
from pydantic import ValidationError

from superset.mcp_service.chart.schemas import (
    ColumnRef,
    FilterConfig,
    TableChartConfig,
    XYChartConfig,
)


class TestTableChartConfig:
    """Test TableChartConfig validation."""

    def test_duplicate_labels_rejected(self) -> None:
        """Test that TableChartConfig rejects duplicate labels."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            TableChartConfig(
                chart_type="table",
                columns=[
                    ColumnRef(name="product_line", label="product_line"),
                    # Different column, but the explicit label repeats the one above.
                    ColumnRef(name="sales", aggregate="SUM", label="product_line"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """Test that TableChartConfig accepts unique labels."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product_line", label="Product Line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )
        assert len(config.columns) == 2

    def test_default_viz_type_is_table(self) -> None:
        """Test that viz_type is 'table' when it is not passed explicitly."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
        )
        assert config.viz_type == "table"

    def test_ag_grid_table_viz_type_accepted(self) -> None:
        """Test that viz_type='ag-grid-table' is accepted for AG Grid table."""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )
        assert config.viz_type == "ag-grid-table"
        assert len(config.columns) == 2

    def test_ag_grid_table_with_all_options(self) -> None:
        """Test AG Grid table with filters and sorting."""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="quantity", aggregate="SUM", label="Total Quantity"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
            filters=[FilterConfig(column="status", op="=", value="active")],
            sort_by=["product_line"],
        )
        assert config.viz_type == "ag-grid-table"
        assert len(config.columns) == 3
        assert config.filters is not None
        assert len(config.filters) == 1
        assert config.sort_by == ["product_line"]

    def test_invalid_viz_type_rejected(self) -> None:
        """Test that invalid viz_type values are rejected."""
        with pytest.raises(ValidationError):
            TableChartConfig(
                chart_type="table",
                viz_type="invalid-type",
                columns=[ColumnRef(name="product")],
            )


class TestXYChartConfig:
    """Test XYChartConfig validation."""

    def test_different_labels_accepted(self) -> None:
        """Test that different labels for x and y are accepted."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="product_line"),  # Label: "product_line"
            y=[
                ColumnRef(
                    name="product_line", aggregate="COUNT"
                ),  # Label: "COUNT(product_line)"
            ],
        )
        assert config.x.name == "product_line"
        assert config.y[0].aggregate == "COUNT"

    def test_explicit_duplicate_label_rejected(self) -> None:
        """Test that explicit duplicate labels are rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="product_line"),
                # Explicit y label repeats the name used for the unlabelled x column.
                y=[ColumnRef(name="sales", label="product_line")],
            )

    def test_duplicate_y_axis_labels_rejected(self) -> None:
        """Test that duplicate y-axis labels are rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="date"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM"),
                    # The explicit label is expected to clash with the label the
                    # unlabelled SUM metric above ends up with.
                    ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """Test that unique labels are accepted."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
            ],
        )
        assert len(config.y) == 2

    def test_group_by_duplicate_with_x_rejected(self) -> None:
        """Test that group_by conflicts with x are rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="region"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                # group_by is a different column but is labelled like the x column.
                group_by=ColumnRef(name="category", label="region"),
            )

    def test_realistic_chart_configurations(self) -> None:
        """Test realistic chart configurations."""
        # This should work - COUNT(product_line) != product_line
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="product_line"),
            y=[
                ColumnRef(name="product_line", aggregate="COUNT"),
                ColumnRef(name="sales", aggregate="SUM"),
            ],
        )
        assert config.x.name == "product_line"
        assert len(config.y) == 2

    def test_time_series_chart_configuration(self) -> None:
        """Test time series chart configuration works."""
        # The x column and the two explicitly labelled metrics use three
        # different names/labels, so this configuration is expected to validate.
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert config.x.name == "order_date"
        assert config.kind == "line"

    def test_time_series_with_custom_x_axis_label(self) -> None:
        """Test time series chart with custom x-axis label."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert config.x.label == "Order Date"

    def test_area_chart_configuration(self) -> None:
        """Test area chart configuration."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            kind="area",
        )
        assert config.kind == "area"

    def test_unknown_fields_rejected(self) -> None:
        """Test that unknown fields like 'series' are rejected."""
        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="territory"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                kind="bar",
                series=ColumnRef(name="year"),
            )

    def test_group_by_accepted(self) -> None:
        """Test that group_by is the correct field for series grouping."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="territory"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
            group_by=ColumnRef(name="year"),
        )
        assert config.group_by is not None
        assert config.group_by.name == "year"


class TestTableChartConfigExtraFields:
    """Test TableChartConfig rejects unknown fields."""

    def test_unknown_fields_rejected(self) -> None:
        """Test that unknown fields are rejected."""
        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
            TableChartConfig(
                chart_type="table",
                columns=[ColumnRef(name="product")],
                foo="bar",
            )