Files
superset2/tests/unit_tests/mcp_service/chart/test_chart_schemas.py

258 lines
9.6 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP chart schema validation.
"""
import pytest
from pydantic import ValidationError

from superset.mcp_service.chart.schemas import (
    ColumnRef,
    FilterConfig,
    TableChartConfig,
    XYChartConfig,
)
class TestTableChartConfig:
    """Test TableChartConfig validation."""

    def test_duplicate_labels_rejected(self) -> None:
        """Test that TableChartConfig rejects duplicate labels."""
        # Both columns resolve to the label "product_line", which must be refused.
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            TableChartConfig(
                chart_type="table",
                columns=[
                    ColumnRef(name="product_line", label="product_line"),
                    ColumnRef(name="sales", aggregate="SUM", label="product_line"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """Test that TableChartConfig accepts unique labels."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product_line", label="Product Line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )
        assert len(config.columns) == 2

    def test_default_viz_type_is_table(self) -> None:
        """Test that viz_type defaults to 'table' when not supplied."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
        )
        assert config.viz_type == "table"

    def test_ag_grid_table_viz_type_accepted(self) -> None:
        """Test that viz_type='ag-grid-table' is accepted for AG Grid table."""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )
        assert config.viz_type == "ag-grid-table"
        assert len(config.columns) == 2

    def test_ag_grid_table_with_all_options(self) -> None:
        """Test AG Grid table combined with filters and sorting."""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="quantity", aggregate="SUM", label="Total Quantity"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
            filters=[FilterConfig(column="status", op="=", value="active")],
            sort_by=["product_line"],
        )
        assert config.viz_type == "ag-grid-table"
        assert len(config.columns) == 3
        assert config.filters is not None
        assert len(config.filters) == 1
        assert config.sort_by == ["product_line"]

    def test_invalid_viz_type_rejected(self) -> None:
        """Test that invalid viz_type values are rejected."""
        # Pin the failure to the viz_type field so that an unrelated
        # validation error (e.g. in columns) cannot make this test pass.
        with pytest.raises(ValidationError, match="viz_type"):
            TableChartConfig(
                chart_type="table",
                viz_type="invalid-type",
                columns=[ColumnRef(name="product")],
            )
class TestXYChartConfig:
    """Test XYChartConfig validation."""

    def test_different_labels_accepted(self) -> None:
        """Test that the same column may be used for x and an aggregated y.

        The x-axis label is "product_line" while the y-axis label is
        "COUNT(product_line)", so the labels do not collide.
        """
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="product_line"),
            y=[ColumnRef(name="product_line", aggregate="COUNT")],
        )
        assert config.x.name == "product_line"
        assert config.y[0].aggregate == "COUNT"

    def test_explicit_duplicate_label_rejected(self) -> None:
        """Test that an explicit y label equal to the x label is rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="product_line"),
                y=[ColumnRef(name="sales", label="product_line")],
            )

    def test_duplicate_y_axis_labels_rejected(self) -> None:
        """Test that duplicate y-axis labels are rejected."""
        # The second metric's explicit label collides with the first
        # metric's implicit "SUM(sales)" label.
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="date"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM"),
                    ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """Test that unique labels are accepted."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
            ],
        )
        assert len(config.y) == 2

    def test_group_by_duplicate_with_x_rejected(self) -> None:
        """Test that a group_by label colliding with the x label is rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="region"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                group_by=ColumnRef(name="category", label="region"),
            )

    def test_realistic_chart_configurations(self) -> None:
        """Test a realistic mix of aggregated and raw-column metrics."""
        # Valid: the label COUNT(product_line) differs from product_line.
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="product_line"),
            y=[
                ColumnRef(name="product_line", aggregate="COUNT"),
                ColumnRef(name="sales", aggregate="SUM"),
            ],
        )
        assert config.x.name == "product_line"
        assert len(config.y) == 2

    def test_time_series_chart_configuration(self) -> None:
        """Test that a line chart over a date column validates successfully."""
        # All labels here ("order_date", "Total Sales", "Avg Price") are
        # distinct, so schema validation must accept the config as-is.
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert config.x.name == "order_date"
        assert config.kind == "line"

    def test_time_series_with_custom_x_axis_label(self) -> None:
        """Test time series chart with custom x-axis label."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert config.x.label == "Order Date"

    def test_area_chart_configuration(self) -> None:
        """Test area chart configuration."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            kind="area",
        )
        assert config.kind == "area"

    def test_unknown_fields_rejected(self) -> None:
        """Test that unknown fields like 'series' are rejected."""
        with pytest.raises(ValidationError, match="Extra inputs are not permitted"):
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="territory"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                kind="bar",
                series=ColumnRef(name="year"),
            )

    def test_group_by_accepted(self) -> None:
        """Test that group_by is the correct field for series grouping."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="territory"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
            group_by=ColumnRef(name="year"),
        )
        assert config.group_by is not None
        assert config.group_by.name == "year"
class TestTableChartConfigExtraFields:
    """Ensure TableChartConfig refuses keyword arguments outside its schema."""

    def test_unknown_fields_rejected(self) -> None:
        """An unrecognised keyword (``foo``) must fail model validation."""
        expected_error = "Extra inputs are not permitted"
        with pytest.raises(ValidationError, match=expected_error):
            single_column = [ColumnRef(name="product")]
            TableChartConfig(
                chart_type="table",
                columns=single_column,
                foo="bar",
            )