feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)

This commit is contained in:
Amin Ghadersohi
2025-11-01 02:33:21 +11:00
committed by GitHub
parent 30d584afd1
commit fee4e7d8e2
106 changed files with 21826 additions and 223 deletions

View File

@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP chart schema validation.
"""
import pytest
from pydantic import ValidationError
from superset.mcp_service.chart.schemas import (
ColumnRef,
TableChartConfig,
XYChartConfig,
)
class TestTableChartConfig:
"""Test TableChartConfig validation."""
def test_duplicate_labels_rejected(self) -> None:
"""Test that TableChartConfig rejects duplicate labels."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
TableChartConfig(
columns=[
ColumnRef(name="product_line", label="product_line"),
ColumnRef(name="sales", aggregate="SUM", label="product_line"),
]
)
def test_unique_labels_accepted(self) -> None:
"""Test that TableChartConfig accepts unique labels."""
config = TableChartConfig(
columns=[
ColumnRef(name="product_line", label="Product Line"),
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
]
)
assert len(config.columns) == 2
class TestXYChartConfig:
"""Test XYChartConfig validation."""
def test_different_labels_accepted(self) -> None:
"""Test that different labels for x and y are accepted."""
config = XYChartConfig(
x=ColumnRef(name="product_line"), # Label: "product_line"
y=[
ColumnRef(
name="product_line", aggregate="COUNT"
), # Label: "COUNT(product_line)"
],
)
assert config.x.name == "product_line"
assert config.y[0].aggregate == "COUNT"
def test_explicit_duplicate_label_rejected(self) -> None:
"""Test that explicit duplicate labels are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="product_line"),
y=[ColumnRef(name="sales", label="product_line")],
)
def test_duplicate_y_axis_labels_rejected(self) -> None:
"""Test that duplicate y-axis labels are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
],
)
def test_unique_labels_accepted(self) -> None:
"""Test that unique labels are accepted."""
config = XYChartConfig(
x=ColumnRef(name="date", label="Order Date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
],
)
assert len(config.y) == 2
def test_group_by_duplicate_with_x_rejected(self) -> None:
"""Test that group_by conflicts with x are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="region"),
y=[ColumnRef(name="sales", aggregate="SUM")],
group_by=ColumnRef(name="category", label="region"),
)
def test_realistic_chart_configurations(self) -> None:
"""Test realistic chart configurations."""
# This should work - COUNT(product_line) != product_line
config = XYChartConfig(
x=ColumnRef(name="product_line"),
y=[
ColumnRef(name="product_line", aggregate="COUNT"),
ColumnRef(name="sales", aggregate="SUM"),
],
)
assert config.x.name == "product_line"
assert len(config.y) == 2
def test_time_series_chart_configuration(self) -> None:
"""Test time series chart configuration works."""
# This should PASS now - the chart creation logic fixes duplicates
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
],
kind="line",
)
assert config.x.name == "order_date"
assert config.kind == "line"
def test_time_series_with_custom_x_axis_label(self) -> None:
"""Test time series chart with custom x-axis label."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date", label="Order Date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
],
kind="line",
)
assert config.x.label == "Order Date"
def test_area_chart_configuration(self) -> None:
"""Test area chart configuration."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="category"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
kind="area",
)
assert config.kind == "area"

View File

@@ -0,0 +1,465 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for chart utilities module"""
from unittest.mock import patch
import pytest
from superset.mcp_service.chart.chart_utils import (
create_metric_object,
generate_chart_name,
generate_explore_link,
map_config_to_form_data,
map_filter_operator,
map_table_config,
map_xy_config,
)
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
XYChartConfig,
)
class TestCreateMetricObject:
"""Test create_metric_object function"""
def test_create_metric_object_with_aggregate(self) -> None:
"""Test creating metric object with specified aggregate"""
col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
result = create_metric_object(col)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "revenue"
assert result["label"] == "Total Revenue"
assert result["optionName"] == "metric_revenue"
assert result["expressionType"] == "SIMPLE"
def test_create_metric_object_default_aggregate(self) -> None:
"""Test creating metric object with default aggregate"""
col = ColumnRef(name="orders")
result = create_metric_object(col)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "orders"
assert result["label"] == "SUM(orders)"
assert result["optionName"] == "metric_orders"
class TestMapFilterOperator:
"""Test map_filter_operator function"""
def test_map_filter_operators(self) -> None:
"""Test mapping of various filter operators"""
assert map_filter_operator("=") == "=="
assert map_filter_operator(">") == ">"
assert map_filter_operator("<") == "<"
assert map_filter_operator(">=") == ">="
assert map_filter_operator("<=") == "<="
assert map_filter_operator("!=") == "!="
def test_map_filter_operator_unknown(self) -> None:
"""Test mapping of unknown operator returns original"""
assert map_filter_operator("UNKNOWN") == "UNKNOWN"
class TestMapTableConfig:
"""Test map_table_config function"""
def test_map_table_config_basic(self) -> None:
"""Test basic table config mapping with aggregated columns"""
config = TableChartConfig(
columns=[
ColumnRef(name="product", aggregate="COUNT"),
ColumnRef(name="revenue", aggregate="SUM"),
]
)
result = map_table_config(config)
assert result["viz_type"] == "table"
assert result["query_mode"] == "aggregate"
# Aggregated columns should be in metrics, not all_columns
assert "all_columns" not in result
assert len(result["metrics"]) == 2
assert result["metrics"][0]["aggregate"] == "COUNT"
assert result["metrics"][1]["aggregate"] == "SUM"
def test_map_table_config_raw_columns(self) -> None:
"""Test table config mapping with raw columns (no aggregates)"""
config = TableChartConfig(
columns=[
ColumnRef(name="product"),
ColumnRef(name="category"),
]
)
result = map_table_config(config)
assert result["viz_type"] == "table"
assert result["query_mode"] == "raw"
# Raw columns should be in all_columns
assert result["all_columns"] == ["product", "category"]
assert "metrics" not in result
def test_map_table_config_with_filters(self) -> None:
"""Test table config mapping with filters"""
config = TableChartConfig(
columns=[ColumnRef(name="product")],
filters=[FilterConfig(column="status", op="=", value="active")],
)
result = map_table_config(config)
assert "adhoc_filters" in result
assert len(result["adhoc_filters"]) == 1
filter_obj = result["adhoc_filters"][0]
assert filter_obj["subject"] == "status"
assert filter_obj["operator"] == "=="
assert filter_obj["comparator"] == "active"
assert filter_obj["expressionType"] == "SIMPLE"
def test_map_table_config_with_sort(self) -> None:
"""Test table config mapping with sort"""
config = TableChartConfig(
columns=[ColumnRef(name="product")], sort_by=["product", "revenue"]
)
result = map_table_config(config)
assert result["order_by_cols"] == ["product", "revenue"]
class TestMapXYConfig:
"""Test map_xy_config function"""
def test_map_xy_config_line_chart(self) -> None:
"""Test XY config mapping for line chart"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
kind="line",
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_line"
assert result["x_axis"] == "date"
assert len(result["metrics"]) == 1
assert result["metrics"][0]["aggregate"] == "SUM"
def test_map_xy_config_with_groupby(self) -> None:
"""Test XY config mapping with group by"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="bar",
group_by=ColumnRef(name="region"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_bar"
assert result["groupby"] == ["region"]
def test_map_xy_config_with_axes(self) -> None:
"""Test XY config mapping with axis configurations"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="area",
x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_area"
assert result["x_axis_title"] == "Date"
assert result["x_axis_format"] == "%Y-%m-%d"
assert result["y_axis_title"] == "Revenue"
assert result["y_axis_format"] == "$,.2f"
assert result["y_axis_scale"] == "log"
def test_map_xy_config_with_legend(self) -> None:
"""Test XY config mapping with legend configuration"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="scatter",
legend=LegendConfig(show=False, position="top"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_scatter"
assert result["show_legend"] is False
assert result["legend_orientation"] == "top"
class TestMapConfigToFormData:
"""Test map_config_to_form_data function"""
def test_map_table_config_type(self) -> None:
"""Test mapping table config type"""
config = TableChartConfig(columns=[ColumnRef(name="test")])
result = map_config_to_form_data(config)
assert result["viz_type"] == "table"
def test_map_xy_config_type(self) -> None:
"""Test mapping XY config type"""
config = XYChartConfig(
x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line"
)
result = map_config_to_form_data(config)
assert result["viz_type"] == "echarts_timeseries_line"
def test_map_unsupported_config_type(self) -> None:
"""Test mapping unsupported config type raises error"""
with pytest.raises(ValueError, match="Unsupported config type"):
map_config_to_form_data("invalid_config") # type: ignore
class TestGenerateChartName:
"""Test generate_chart_name function"""
def test_generate_table_chart_name(self) -> None:
"""Test generating name for table chart"""
config = TableChartConfig(
columns=[
ColumnRef(name="product"),
ColumnRef(name="revenue"),
]
)
result = generate_chart_name(config)
assert result == "Table Chart - product, revenue"
def test_generate_xy_chart_name(self) -> None:
"""Test generating name for XY chart"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
kind="line",
)
result = generate_chart_name(config)
assert result == "Line Chart - date vs revenue, orders"
def test_generate_chart_name_unsupported(self) -> None:
"""Test generating name for unsupported config type"""
result = generate_chart_name("invalid_config") # type: ignore
assert result == "Chart"
class TestGenerateExploreLink:
"""Test generate_explore_link function"""
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
"""Test that generate_explore_link uses the configured base URL"""
from urllib.parse import urlparse
mock_get_base_url.return_value = "https://superset.example.com"
form_data = {"viz_type": "table", "metrics": ["count"]}
result = generate_explore_link("123", form_data)
# Should use the configured base URL - use urlparse to avoid CodeQL warning
parsed_url = urlparse(result)
expected_netloc = "superset.example.com"
assert parsed_url.scheme == "https"
assert parsed_url.netloc == expected_netloc
assert "/explore/" in parsed_url.path
assert "datasource_id=123" in result
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
"""Test generate_explore_link returns fallback URL when dataset not found"""
mock_get_base_url.return_value = "http://localhost:9001"
form_data = {"viz_type": "table"}
# Mock dataset not found scenario
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
result = generate_explore_link("999", form_data)
assert (
result
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
)
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
@patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
def test_generate_explore_link_with_form_data_key(
self, mock_command, mock_get_base_url
) -> None:
"""Test generate_explore_link creates form_data_key when dataset exists"""
mock_get_base_url.return_value = "http://localhost:9001"
mock_command.return_value.run.return_value = "test_form_data_key"
# Mock dataset exists
mock_dataset = type("Dataset", (), {"id": 123})()
with patch(
"superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_dataset
):
result = generate_explore_link(123, {"viz_type": "table"})
assert (
result == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
)
mock_command.assert_called_once()
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
"""Test generate_explore_link handles exceptions gracefully"""
mock_get_base_url.return_value = "http://localhost:9001"
# Mock exception during form_data creation
with patch(
"superset.daos.dataset.DatasetDAO.find_by_id",
side_effect=Exception("DB error"),
):
result = generate_explore_link("123", {"viz_type": "table"})
# Should fallback to basic URL
assert (
result
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
)
class TestCriticalBugFixes:
"""Test critical bug fixes for chart utilities."""
def test_time_series_aggregation_fix(self) -> None:
"""Test that time series charts preserve temporal dimension."""
# Create a time series chart configuration
config = XYChartConfig(
chart_type="xy",
kind="line",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
)
form_data = map_xy_config(config)
# Verify the fix: x_axis should be set correctly
assert form_data["x_axis"] == "order_date"
# Verify the fix: groupby should not duplicate x_axis
# This prevents the "Duplicate column/metric labels" error
assert "groupby" not in form_data or "order_date" not in form_data.get(
"groupby", []
)
# Verify chart type mapping
assert form_data["viz_type"] == "echarts_timeseries_line"
def test_time_series_with_explicit_group_by(self) -> None:
"""Test time series with explicit group_by different from x_axis."""
config = XYChartConfig(
chart_type="xy",
kind="line",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
group_by=ColumnRef(name="category"),
)
form_data = map_xy_config(config)
# Verify x_axis is set
assert form_data["x_axis"] == "order_date"
# Verify groupby only contains the explicit group_by, not x_axis
assert form_data.get("groupby") == ["category"]
assert "order_date" not in form_data.get("groupby", [])
def test_duplicate_label_prevention(self) -> None:
"""Test that duplicate column/metric labels are prevented."""
# This configuration would previously cause:
# "Duplicate column/metric labels: 'price_each'"
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="price_each", label="Price Each"), # Custom label
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
],
group_by=ColumnRef(name="price_each"), # Same column as x_axis
kind="line",
)
form_data = map_xy_config(config)
# Verify the fix: x_axis is set
assert form_data["x_axis"] == "price_each"
# Verify the fix: groupby is empty because group_by == x_axis
# This prevents the duplicate label error
assert "groupby" not in form_data or not form_data["groupby"]
def test_metric_object_creation_with_labels(self) -> None:
"""Test that metric objects are created correctly with proper labels."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="profit", aggregate="AVG"), # No custom label
],
kind="bar",
)
form_data = map_xy_config(config)
# Verify metrics are created correctly
metrics = form_data["metrics"]
assert len(metrics) == 2
# First metric with custom label
assert metrics[0]["label"] == "Total Sales"
assert metrics[0]["aggregate"] == "SUM"
assert metrics[0]["column"]["column_name"] == "sales"
# Second metric with auto-generated label
assert metrics[1]["label"] == "AVG(profit)"
assert metrics[1]["aggregate"] == "AVG"
assert metrics[1]["column"]["column_name"] == "profit"
def test_chart_type_mapping_comprehensive(self) -> None:
"""Test that chart types are mapped correctly to Superset viz types."""
test_cases = [
("line", "echarts_timeseries_line"),
("bar", "echarts_timeseries_bar"),
("area", "echarts_area"),
("scatter", "echarts_timeseries_scatter"),
]
for kind, expected_viz_type in test_cases:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind=kind,
)
form_data = map_xy_config(config)
assert form_data["viz_type"] == expected_viz_type

View File

@@ -0,0 +1,268 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP generate_chart tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
GenerateChartRequest,
LegendConfig,
TableChartConfig,
XYChartConfig,
)
class TestGenerateChart:
"""Tests for generate_chart MCP tool."""
@pytest.mark.asyncio
async def test_generate_chart_request_structure(self):
"""Test that chart generation request structures are properly formed."""
# Table chart request
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
assert table_request.dataset_id == "1"
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# XY chart request
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
assert xy_request.config.x_axis.title == "Date"
assert xy_request.config.legend.show is True
@pytest.mark.asyncio
async def test_generate_chart_validation_error_handling(self):
"""Test that validation errors are properly structured."""
# Create a validation error with the correct structure
validation_error_entry = {
"field": "x_axis",
"provided_value": "invalid_col",
"error_type": "column_not_found",
"message": "Column 'invalid_col' not found",
}
# Test that validation error structure is correct
assert validation_error_entry["field"] == "x_axis"
assert validation_error_entry["error_type"] == "column_not_found"
@pytest.mark.asyncio
async def test_chart_config_variations(self):
"""Test various chart configuration options."""
# Test all chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
assert config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
assert len(multi_y_config.y) == 3
assert multi_y_config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
for i, f in enumerate(filters):
assert f.op == operators[i]
@pytest.mark.asyncio
async def test_generate_chart_response_structure(self):
"""Test the expected response structure for chart generation."""
# The response should contain these fields
_ = {
"chart": {
"id": int,
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": str,
"saved": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When chart creation succeeds, these additional fields may be present
_ = [
"previews",
"capabilities",
"semantics",
"explore_url",
"form_data_key",
"api_endpoints",
"performance",
"accessibility",
]
# This is just a structural test - actual integration tests would verify
# the tool returns data matching this structure
@pytest.mark.asyncio
async def test_dataset_id_flexibility(self):
"""Test that dataset_id can be string or int."""
configs = [
GenerateChartRequest(
dataset_id="123",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
),
GenerateChartRequest(
dataset_id="uuid-string-here",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
),
]
for config in configs:
assert isinstance(config.dataset_id, str)
@pytest.mark.asyncio
async def test_save_chart_flag(self):
"""Test save_chart flag behavior."""
# Default should be True (save chart)
request1 = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
)
assert request1.save_chart is True
# Explicit False (preview only)
request2 = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
save_chart=False,
)
assert request2.save_chart is False
@pytest.mark.asyncio
async def test_preview_formats(self):
"""Test preview format options."""
formats = ["url", "ascii", "table"]
request = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
generate_preview=True,
preview_formats=formats,
)
assert request.generate_preview is True
assert set(request.preview_formats) == set(formats)
@pytest.mark.asyncio
async def test_column_ref_features(self):
"""Test ColumnRef features like aggregation and labels."""
# Simple column
col1 = ColumnRef(name="region")
assert col1.name == "region"
assert col1.label is None
assert col1.aggregate is None
# Column with aggregation
col2 = ColumnRef(name="sales", aggregate="SUM", label="Total Sales")
assert col2.name == "sales"
assert col2.aggregate == "SUM"
assert col2.label == "Total Sales"
# All supported aggregations
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
col = ColumnRef(name="value", aggregate=agg)
assert col.aggregate == agg
@pytest.mark.asyncio
async def test_axis_config_options(self):
"""Test axis configuration options."""
axis = AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="linear",
)
assert axis.title == "Sales Amount"
assert axis.format == "$,.2f"
assert axis.scale == "linear"
# Test different formats
formats = ["SMART_NUMBER", "$,.2f", ",.0f", "smart_date", ".3%"]
for fmt in formats:
axis = AxisConfig(format=fmt)
assert axis.format == fmt
@pytest.mark.asyncio
async def test_legend_config_options(self):
"""Test legend configuration options."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
legend = LegendConfig(show=True, position=pos)
assert legend.position == pos
# Hidden legend
legend = LegendConfig(show=False)
assert legend.show is False

View File

@@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for get_chart_preview MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
ASCIIPreview,
GetChartPreviewRequest,
TablePreview,
URLPreview,
)
class TestGetChartPreview:
"""Tests for get_chart_preview MCP tool."""
@pytest.mark.asyncio
async def test_get_chart_preview_request_structure(self):
"""Test that preview request structures are properly formed."""
# Numeric ID request
request1 = GetChartPreviewRequest(identifier=123, format="url")
assert request1.identifier == 123
assert request1.format == "url"
# Default dimensions are set
assert request1.width == 800
assert request1.height == 600
# String ID request
request2 = GetChartPreviewRequest(identifier="456", format="ascii")
assert request2.identifier == "456"
assert request2.format == "ascii"
# UUID request
request3 = GetChartPreviewRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", format="table"
)
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert request3.format == "table"
# Default format
request4 = GetChartPreviewRequest(identifier=789)
assert request4.format == "url" # default
@pytest.mark.asyncio
async def test_preview_format_types(self):
"""Test different preview format types."""
formats = ["url", "ascii", "table"]
for fmt in formats:
request = GetChartPreviewRequest(identifier=1, format=fmt)
assert request.format == fmt
@pytest.mark.asyncio
async def test_url_preview_structure(self):
"""Test URLPreview response structure."""
preview = URLPreview(
preview_url="http://localhost:5008/screenshot/chart/123.png",
width=800,
height=600,
supports_interaction=False,
)
assert preview.type == "url"
assert preview.preview_url == "http://localhost:5008/screenshot/chart/123.png"
assert preview.width == 800
assert preview.height == 600
assert preview.supports_interaction is False
@pytest.mark.asyncio
async def test_ascii_preview_structure(self):
"""Test ASCIIPreview response structure."""
ascii_art = """
┌─────────────────────────┐
│ Sales by Region │
├─────────────────────────┤
│ North ████████ 45%
│ South ██████ 30%
│ East ████ 20%
│ West ██ 5%
└─────────────────────────┘
"""
preview = ASCIIPreview(
ascii_content=ascii_art.strip(),
width=25,
height=8,
)
assert preview.type == "ascii"
assert "Sales by Region" in preview.ascii_content
assert preview.width == 25
assert preview.height == 8
@pytest.mark.asyncio
async def test_table_preview_structure(self):
"""Test TablePreview response structure."""
table_content = """
| Region | Sales | Profit |
|--------|--------|--------|
| North | 45000 | 12000 |
| South | 30000 | 8000 |
| East | 20000 | 5000 |
| West | 5000 | 1000 |
"""
preview = TablePreview(
table_data=table_content.strip(),
row_count=4,
supports_sorting=True,
)
assert preview.type == "table"
assert "Region" in preview.table_data
assert "North" in preview.table_data
assert preview.row_count == 4
assert preview.supports_sorting is True
@pytest.mark.asyncio
async def test_chart_preview_response_structure(self):
"""Test the expected response structure for chart preview."""
# Core fields that should always be present
_ = [
"chart_id",
"chart_name",
"chart_type",
"explore_url",
"content", # Union of URLPreview | ASCIIPreview | TablePreview
"chart_description",
"accessibility",
"performance",
]
# Additional fields that may be present for backward compatibility
_ = [
"format",
"preview_url",
"ascii_chart",
"table_data",
"width",
"height",
"schema_version",
"api_version",
]
# This is a structural test - actual integration tests would verify
# the tool returns data matching this structure
@pytest.mark.asyncio
async def test_preview_dimensions(self):
"""Test preview dimensions in response."""
# Standard dimensions that may appear in preview responses
standard_sizes = [
(800, 600), # Default
(1200, 800), # Large
(400, 300), # Small
(1920, 1080), # Full HD
]
for width, height in standard_sizes:
# URL preview with dimensions
url_preview = URLPreview(
preview_url="http://example.com/chart.png",
width=width,
height=height,
supports_interaction=False,
)
assert url_preview.width == width
assert url_preview.height == height
@pytest.mark.asyncio
async def test_error_response_structures(self):
"""Test error response structures."""
# Error responses typically follow this structure
error_response = {
"error_type": "not_found",
"message": "Chart not found",
"details": "No chart found with ID 999",
"chart_id": 999,
}
assert error_response["error_type"] == "not_found"
assert error_response["chart_id"] == 999
# Preview generation error structure
preview_error = {
"error_type": "preview_error",
"message": "Failed to generate preview",
"details": "Screenshot service unavailable",
}
assert preview_error["error_type"] == "preview_error"
@pytest.mark.asyncio
async def test_accessibility_metadata(self):
"""Test accessibility metadata structure."""
from superset.mcp_service.chart.schemas import AccessibilityMetadata
metadata = AccessibilityMetadata(
color_blind_safe=True,
alt_text="Bar chart showing sales by region",
high_contrast_available=False,
)
assert metadata.color_blind_safe is True
assert "sales by region" in metadata.alt_text
assert metadata.high_contrast_available is False
@pytest.mark.asyncio
async def test_performance_metadata(self):
"""Test performance metadata structure."""
from superset.mcp_service.chart.schemas import PerformanceMetadata
metadata = PerformanceMetadata(
query_duration_ms=150,
cache_status="hit",
optimization_suggestions=["Consider adding an index on date column"],
)
assert metadata.query_duration_ms == 150
assert metadata.cache_status == "hit"
assert len(metadata.optimization_suggestions) == 1
@pytest.mark.asyncio
async def test_chart_types_support(self):
"""Test that various chart types are supported."""
chart_types = [
"echarts_timeseries_line",
"echarts_timeseries_bar",
"echarts_area",
"echarts_timeseries_scatter",
"table",
"pie",
"big_number",
"big_number_total",
"pivot_table_v2",
"dist_bar",
"box_plot",
]
# All chart types should be previewable
for _chart_type in chart_types:
# This would be tested in integration tests
pass
@pytest.mark.asyncio
async def test_ascii_art_variations(self):
"""Test ASCII art generation for different chart types."""
# Line chart ASCII
_ = """
Sales Trend
│ ╱╲
└────────────
Jan Feb Mar
"""
# Bar chart ASCII
_ = """
Sales by Region
│ ████ North
│ ███ South
│ ██ East
│ █ West
└────────────
"""
# Pie chart ASCII
_ = """
Market Share
╭─────╮
│ 45%
│ North │
╰─────────╯
"""
# These demonstrate the expected ASCII formats for different chart types

View File

@@ -0,0 +1,385 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for update_chart MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
UpdateChartRequest,
XYChartConfig,
)
class TestUpdateChart:
"""Tests for update_chart MCP tool."""
@pytest.mark.asyncio
async def test_update_chart_request_structure(self):
"""Test that chart update request structures are properly formed."""
# Table chart update with numeric ID
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = UpdateChartRequest(identifier=123, config=table_config)
assert table_request.identifier == 123
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# XY chart update with UUID
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = UpdateChartRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
)
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
@pytest.mark.asyncio
async def test_update_chart_with_chart_name(self):
"""Test updating chart with custom chart name."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Without custom name
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.chart_name is None
# With custom name
request2 = UpdateChartRequest(
identifier=123, config=config, chart_name="Updated Sales Report"
)
assert request2.chart_name == "Updated Sales Report"
@pytest.mark.asyncio
async def test_update_chart_preview_generation(self):
"""Test preview generation options in update request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default preview generation
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.generate_preview is True
assert request1.preview_formats == ["url"]
# Custom preview formats
request2 = UpdateChartRequest(
identifier=123,
config=config,
generate_preview=True,
preview_formats=["url", "ascii", "table"],
)
assert request2.generate_preview is True
assert set(request2.preview_formats) == {"url", "ascii", "table"}
# Disable preview generation
request3 = UpdateChartRequest(
identifier=123, config=config, generate_preview=False
)
assert request3.generate_preview is False
@pytest.mark.asyncio
async def test_update_chart_identifier_types(self):
"""Test that identifier can be int or string (UUID)."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Integer ID
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.identifier == 123
assert isinstance(request1.identifier, int)
# String numeric ID
request2 = UpdateChartRequest(identifier="456", config=config)
assert request2.identifier == "456"
assert isinstance(request2.identifier, str)
# UUID string
request3 = UpdateChartRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=config
)
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert isinstance(request3.identifier, str)
@pytest.mark.asyncio
async def test_update_chart_config_variations(self):
"""Test various chart configuration options in updates."""
# Test all XY chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
request = UpdateChartRequest(identifier=1, config=multi_y_config)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
table_config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
filters=filters,
)
request = UpdateChartRequest(identifier=1, config=table_config)
assert len(request.config.filters) == 6
@pytest.mark.asyncio
async def test_update_chart_response_structure(self):
"""Test the expected response structure for chart updates."""
# The response should contain these fields
expected_response = {
"chart": {
"id": int,
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": str,
"updated": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When chart update succeeds, these additional fields may be present
optional_fields = [
"previews",
"capabilities",
"semantics",
"explore_url",
"api_endpoints",
"performance",
"accessibility",
]
# Validate structure expectations
assert "chart" in expected_response
assert "success" in expected_response
assert len(optional_fields) > 0
@pytest.mark.asyncio
async def test_update_chart_axis_configurations(self):
"""Test axis configuration updates."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales")],
x_axis=AxisConfig(
title="Date",
format="smart_date",
scale="linear",
),
y_axis=AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="log",
),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.x_axis.scale == "linear"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
assert request.config.y_axis.scale == "log"
@pytest.mark.asyncio
async def test_update_chart_legend_configurations(self):
"""Test legend configuration updates."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=True, position=pos),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.position == pos
assert request.config.legend.show is True
# Hidden legend
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=False),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.show is False
@pytest.mark.asyncio
async def test_update_chart_aggregation_functions(self):
"""Test all supported aggregation functions in updates."""
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="value", aggregate=agg)],
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.columns[0].aggregate == agg
@pytest.mark.asyncio
async def test_update_chart_error_responses(self):
"""Test expected error response structures."""
# Chart not found error
error_response = {
"chart": None,
"error": "No chart found with identifier: 999",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert error_response["success"] is False
assert error_response["chart"] is None
assert "chart found" in error_response["error"].lower()
# General update error
update_error = {
"chart": None,
"error": "Chart update failed: Permission denied",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert update_error["success"] is False
assert "failed" in update_error["error"].lower()
@pytest.mark.asyncio
async def test_chart_name_sanitization(self):
"""Test that chart names are properly sanitized."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Test with potentially problematic characters
test_names = [
"Normal Chart Name",
"Chart with 'quotes'",
'Chart with "double quotes"',
"Chart with <brackets>",
]
for name in test_names:
request = UpdateChartRequest(identifier=1, config=config, chart_name=name)
# Chart name should be set (sanitization happens in the validator)
assert request.chart_name is not None
@pytest.mark.asyncio
async def test_update_chart_with_filters(self):
"""Test updating chart with various filter configurations."""
filters = [
FilterConfig(column="region", op="=", value="North"),
FilterConfig(column="sales", op=">=", value=1000),
FilterConfig(column="date", op=">", value="2024-01-01"),
]
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales"),
ColumnRef(name="date"),
],
filters=filters,
)
request = UpdateChartRequest(identifier=1, config=config)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_cache_control(self):
"""Test cache control parameters in update request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default cache settings
request1 = UpdateChartRequest(identifier=1, config=config)
assert request1.use_cache is True
assert request1.force_refresh is False
assert request1.cache_timeout is None
# Custom cache settings
request2 = UpdateChartRequest(
identifier=1,
config=config,
use_cache=False,
force_refresh=True,
cache_timeout=300,
)
assert request2.use_cache is False
assert request2.force_refresh is True
assert request2.cache_timeout == 300

View File

@@ -0,0 +1,474 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for update_chart_preview MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
UpdateChartPreviewRequest,
XYChartConfig,
)
class TestUpdateChartPreview:
"""Tests for update_chart_preview MCP tool."""
@pytest.mark.asyncio
async def test_update_chart_preview_request_structure(self):
"""Test that chart preview update request structures are properly formed."""
# Table chart preview update
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = UpdateChartPreviewRequest(
form_data_key="abc123def456", dataset_id=1, config=table_config
)
assert table_request.form_data_key == "abc123def456"
assert table_request.dataset_id == 1
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
# XY chart preview update
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = UpdateChartPreviewRequest(
form_data_key="xyz789ghi012", dataset_id="2", config=xy_config
)
assert xy_request.form_data_key == "xyz789ghi012"
assert xy_request.dataset_id == "2"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.kind == "line"
@pytest.mark.asyncio
async def test_update_chart_preview_dataset_id_types(self):
"""Test that dataset_id can be int or string (UUID)."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Integer dataset_id
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=123, config=config
)
assert request1.dataset_id == 123
assert isinstance(request1.dataset_id, int)
# String numeric dataset_id
request2 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id="456", config=config
)
assert request2.dataset_id == "456"
assert isinstance(request2.dataset_id, str)
# UUID string dataset_id
request3 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
config=config,
)
assert request3.dataset_id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert isinstance(request3.dataset_id, str)
@pytest.mark.asyncio
async def test_update_chart_preview_generation_options(self):
"""Test preview generation options in update preview request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default preview generation
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request1.generate_preview is True
assert request1.preview_formats == ["url"]
# Custom preview formats
request2 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
generate_preview=True,
preview_formats=["url", "ascii", "table"],
)
assert request2.generate_preview is True
assert set(request2.preview_formats) == {"url", "ascii", "table"}
# Disable preview generation
request3 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
generate_preview=False,
)
assert request3.generate_preview is False
@pytest.mark.asyncio
async def test_update_chart_preview_config_variations(self):
"""Test various chart configuration options in preview updates."""
# Test all XY chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=multi_y_config
)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
table_config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
filters=filters,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=table_config
)
assert len(request.config.filters) == 6
@pytest.mark.asyncio
async def test_update_chart_preview_response_structure(self):
"""Test the expected response structure for chart preview updates."""
# The response should contain these fields
expected_response = {
"chart": {
"id": None, # No ID for unsaved previews
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": None, # No UUID for unsaved previews
"saved": bool,
"updated": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When preview update succeeds, these additional fields may be present
optional_fields = [
"previews",
"capabilities",
"semantics",
"explore_url",
"form_data_key",
"previous_form_data_key",
"api_endpoints",
"performance",
"accessibility",
]
# Validate structure expectations
assert "chart" in expected_response
assert "success" in expected_response
assert len(optional_fields) > 0
assert expected_response["chart"]["id"] is None
assert expected_response["chart"]["uuid"] is None
@pytest.mark.asyncio
async def test_update_chart_preview_axis_configurations(self):
"""Test axis configuration updates in preview."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales")],
x_axis=AxisConfig(
title="Date",
format="smart_date",
scale="linear",
),
y_axis=AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="log",
),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
@pytest.mark.asyncio
async def test_update_chart_preview_legend_configurations(self):
"""Test legend configuration updates in preview."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=True, position=pos),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.position == pos
assert request.config.legend.show is True
# Hidden legend
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=False),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.show is False
@pytest.mark.asyncio
async def test_update_chart_preview_aggregation_functions(self):
"""Test all supported aggregation functions in preview updates."""
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="value", aggregate=agg)],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.columns[0].aggregate == agg
@pytest.mark.asyncio
async def test_update_chart_preview_error_responses(self):
"""Test expected error response structures for preview updates."""
# General update error
error_response = {
"chart": None,
"error": "Chart preview update failed: Invalid form_data_key",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert error_response["success"] is False
assert error_response["chart"] is None
assert "failed" in error_response["error"].lower()
# Missing dataset error
dataset_error = {
"chart": None,
"error": "Chart preview update failed: Dataset not found",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert dataset_error["success"] is False
assert "dataset" in dataset_error["error"].lower()
@pytest.mark.asyncio
async def test_update_chart_preview_with_filters(self):
"""Test updating preview with various filter configurations."""
filters = [
FilterConfig(column="region", op="=", value="North"),
FilterConfig(column="sales", op=">=", value=1000),
FilterConfig(column="date", op=">", value="2024-01-01"),
]
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales"),
ColumnRef(name="date"),
],
filters=filters,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_preview_form_data_key_handling(self):
"""Test form_data_key handling in preview updates."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Various form_data_key formats
form_data_keys = [
"abc123def456",
"xyz-789-ghi-012",
"key_with_underscore",
"UPPERCASE_KEY",
]
for key in form_data_keys:
request = UpdateChartPreviewRequest(
form_data_key=key, dataset_id=1, config=config
)
assert request.form_data_key == key
@pytest.mark.asyncio
async def test_update_chart_preview_cache_control(self):
"""Test cache control parameters in update preview request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default cache settings
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request1.use_cache is True
assert request1.force_refresh is False
assert request1.cache_form_data is True
# Custom cache settings
request2 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
use_cache=False,
force_refresh=True,
cache_form_data=False,
)
assert request2.use_cache is False
assert request2.force_refresh is True
assert request2.cache_form_data is False
@pytest.mark.asyncio
async def test_update_chart_preview_no_save_behavior(self):
"""Test that preview updates don't create permanent charts."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
# Preview updates should never create permanent charts
# This is validated by checking the response structure
expected_unsaved_fields = {
"id": None, # No chart ID
"uuid": None, # No UUID
"saved": False, # Not saved
}
# These expectations are validated in the response, not the request
assert request.form_data_key == "abc123"
assert expected_unsaved_fields["id"] is None
assert expected_unsaved_fields["uuid"] is None
assert expected_unsaved_fields["saved"] is False
@pytest.mark.asyncio
async def test_update_chart_preview_multiple_y_columns(self):
"""Test preview updates with multiple Y-axis columns."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
ColumnRef(name="cost", aggregate="SUM", label="Total Cost"),
ColumnRef(name="profit", aggregate="SUM", label="Total Profit"),
ColumnRef(name="orders", aggregate="COUNT", label="Order Count"),
],
kind="line",
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.y) == 4
assert request.config.y[0].name == "revenue"
assert request.config.y[1].name == "cost"
assert request.config.y[2].name == "profit"
assert request.config.y[3].name == "orders"
assert request.config.y[3].aggregate == "COUNT"
@pytest.mark.asyncio
async def test_update_chart_preview_table_sorting(self):
"""Test table chart sorting in preview updates."""
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
],
sort_by=["sales", "profit"],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.sort_by == ["sales", "profit"]
assert len(request.config.columns) == 3