mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)
This commit is contained in:
268
tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py
Normal file
268
tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP generate_chart tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
GenerateChartRequest,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateChart:
|
||||
"""Tests for generate_chart MCP tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_chart_request_structure(self):
|
||||
"""Test that chart generation request structures are properly formed."""
|
||||
# Table chart request
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region", label="Region"),
|
||||
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
|
||||
],
|
||||
filters=[FilterConfig(column="year", op="=", value="2024")],
|
||||
sort_by=["sales"],
|
||||
)
|
||||
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
|
||||
assert table_request.dataset_id == "1"
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
assert table_request.config.columns[1].aggregate == "SUM"
|
||||
|
||||
# XY chart request
|
||||
xy_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="value", aggregate="SUM")],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="category"),
|
||||
x_axis=AxisConfig(title="Date", format="smart_date"),
|
||||
y_axis=AxisConfig(title="Value", format="$,.2f"),
|
||||
legend=LegendConfig(show=True, position="top"),
|
||||
)
|
||||
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.y[0].aggregate == "SUM"
|
||||
assert xy_request.config.kind == "line"
|
||||
assert xy_request.config.x_axis.title == "Date"
|
||||
assert xy_request.config.legend.show is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_chart_validation_error_handling(self):
|
||||
"""Test that validation errors are properly structured."""
|
||||
|
||||
# Create a validation error with the correct structure
|
||||
validation_error_entry = {
|
||||
"field": "x_axis",
|
||||
"provided_value": "invalid_col",
|
||||
"error_type": "column_not_found",
|
||||
"message": "Column 'invalid_col' not found",
|
||||
}
|
||||
|
||||
# Test that validation error structure is correct
|
||||
assert validation_error_entry["field"] == "x_axis"
|
||||
assert validation_error_entry["error_type"] == "column_not_found"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chart_config_variations(self):
|
||||
"""Test various chart configuration options."""
|
||||
# Test all chart types
|
||||
chart_types = ["line", "bar", "area", "scatter"]
|
||||
for chart_type in chart_types:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x_col"),
|
||||
y=[ColumnRef(name="y_col")],
|
||||
kind=chart_type,
|
||||
)
|
||||
assert config.kind == chart_type
|
||||
|
||||
# Test multiple Y columns
|
||||
multi_y_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="sales", aggregate="SUM"),
|
||||
ColumnRef(name="profit", aggregate="AVG"),
|
||||
ColumnRef(name="quantity", aggregate="COUNT"),
|
||||
],
|
||||
kind="line",
|
||||
)
|
||||
assert len(multi_y_config.y) == 3
|
||||
assert multi_y_config.y[1].aggregate == "AVG"
|
||||
|
||||
# Test filter operators
|
||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
|
||||
for i, f in enumerate(filters):
|
||||
assert f.op == operators[i]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_chart_response_structure(self):
|
||||
"""Test the expected response structure for chart generation."""
|
||||
# The response should contain these fields
|
||||
_ = {
|
||||
"chart": {
|
||||
"id": int,
|
||||
"slice_name": str,
|
||||
"viz_type": str,
|
||||
"url": str,
|
||||
"uuid": str,
|
||||
"saved": bool,
|
||||
},
|
||||
"error": None,
|
||||
"success": bool,
|
||||
"schema_version": str,
|
||||
"api_version": str,
|
||||
}
|
||||
|
||||
# When chart creation succeeds, these additional fields may be present
|
||||
_ = [
|
||||
"previews",
|
||||
"capabilities",
|
||||
"semantics",
|
||||
"explore_url",
|
||||
"form_data_key",
|
||||
"api_endpoints",
|
||||
"performance",
|
||||
"accessibility",
|
||||
]
|
||||
|
||||
# This is just a structural test - actual integration tests would verify
|
||||
# the tool returns data matching this structure
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_dataset_id_flexibility(self):
|
||||
"""Test that dataset_id can be string or int."""
|
||||
configs = [
|
||||
GenerateChartRequest(
|
||||
dataset_id="123",
|
||||
config=TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="col1")]
|
||||
),
|
||||
),
|
||||
GenerateChartRequest(
|
||||
dataset_id="uuid-string-here",
|
||||
config=TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="col1")]
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
for config in configs:
|
||||
assert isinstance(config.dataset_id, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_chart_flag(self):
|
||||
"""Test save_chart flag behavior."""
|
||||
# Default should be True (save chart)
|
||||
request1 = GenerateChartRequest(
|
||||
dataset_id="1",
|
||||
config=TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="col1")]
|
||||
),
|
||||
)
|
||||
assert request1.save_chart is True
|
||||
|
||||
# Explicit False (preview only)
|
||||
request2 = GenerateChartRequest(
|
||||
dataset_id="1",
|
||||
config=TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="col1")]
|
||||
),
|
||||
save_chart=False,
|
||||
)
|
||||
assert request2.save_chart is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_formats(self):
|
||||
"""Test preview format options."""
|
||||
formats = ["url", "ascii", "table"]
|
||||
request = GenerateChartRequest(
|
||||
dataset_id="1",
|
||||
config=TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="col1")]
|
||||
),
|
||||
generate_preview=True,
|
||||
preview_formats=formats,
|
||||
)
|
||||
assert request.generate_preview is True
|
||||
assert set(request.preview_formats) == set(formats)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_column_ref_features(self):
|
||||
"""Test ColumnRef features like aggregation and labels."""
|
||||
# Simple column
|
||||
col1 = ColumnRef(name="region")
|
||||
assert col1.name == "region"
|
||||
assert col1.label is None
|
||||
assert col1.aggregate is None
|
||||
|
||||
# Column with aggregation
|
||||
col2 = ColumnRef(name="sales", aggregate="SUM", label="Total Sales")
|
||||
assert col2.name == "sales"
|
||||
assert col2.aggregate == "SUM"
|
||||
assert col2.label == "Total Sales"
|
||||
|
||||
# All supported aggregations
|
||||
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
|
||||
for agg in aggs:
|
||||
col = ColumnRef(name="value", aggregate=agg)
|
||||
assert col.aggregate == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_axis_config_options(self):
|
||||
"""Test axis configuration options."""
|
||||
axis = AxisConfig(
|
||||
title="Sales Amount",
|
||||
format="$,.2f",
|
||||
scale="linear",
|
||||
)
|
||||
assert axis.title == "Sales Amount"
|
||||
assert axis.format == "$,.2f"
|
||||
assert axis.scale == "linear"
|
||||
|
||||
# Test different formats
|
||||
formats = ["SMART_NUMBER", "$,.2f", ",.0f", "smart_date", ".3%"]
|
||||
for fmt in formats:
|
||||
axis = AxisConfig(format=fmt)
|
||||
assert axis.format == fmt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_legend_config_options(self):
|
||||
"""Test legend configuration options."""
|
||||
positions = ["top", "bottom", "left", "right"]
|
||||
for pos in positions:
|
||||
legend = LegendConfig(show=True, position=pos)
|
||||
assert legend.position == pos
|
||||
|
||||
# Hidden legend
|
||||
legend = LegendConfig(show=False)
|
||||
assert legend.show is False
|
||||
@@ -0,0 +1,290 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for get_chart_preview MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ASCIIPreview,
|
||||
GetChartPreviewRequest,
|
||||
TablePreview,
|
||||
URLPreview,
|
||||
)
|
||||
|
||||
|
||||
class TestGetChartPreview:
|
||||
"""Tests for get_chart_preview MCP tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_chart_preview_request_structure(self):
|
||||
"""Test that preview request structures are properly formed."""
|
||||
# Numeric ID request
|
||||
request1 = GetChartPreviewRequest(identifier=123, format="url")
|
||||
assert request1.identifier == 123
|
||||
assert request1.format == "url"
|
||||
# Default dimensions are set
|
||||
assert request1.width == 800
|
||||
assert request1.height == 600
|
||||
|
||||
# String ID request
|
||||
request2 = GetChartPreviewRequest(identifier="456", format="ascii")
|
||||
assert request2.identifier == "456"
|
||||
assert request2.format == "ascii"
|
||||
|
||||
# UUID request
|
||||
request3 = GetChartPreviewRequest(
|
||||
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", format="table"
|
||||
)
|
||||
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert request3.format == "table"
|
||||
|
||||
# Default format
|
||||
request4 = GetChartPreviewRequest(identifier=789)
|
||||
assert request4.format == "url" # default
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_format_types(self):
|
||||
"""Test different preview format types."""
|
||||
formats = ["url", "ascii", "table"]
|
||||
for fmt in formats:
|
||||
request = GetChartPreviewRequest(identifier=1, format=fmt)
|
||||
assert request.format == fmt
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_url_preview_structure(self):
|
||||
"""Test URLPreview response structure."""
|
||||
preview = URLPreview(
|
||||
preview_url="http://localhost:5008/screenshot/chart/123.png",
|
||||
width=800,
|
||||
height=600,
|
||||
supports_interaction=False,
|
||||
)
|
||||
assert preview.type == "url"
|
||||
assert preview.preview_url == "http://localhost:5008/screenshot/chart/123.png"
|
||||
assert preview.width == 800
|
||||
assert preview.height == 600
|
||||
assert preview.supports_interaction is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ascii_preview_structure(self):
|
||||
"""Test ASCIIPreview response structure."""
|
||||
ascii_art = """
|
||||
┌─────────────────────────┐
|
||||
│ Sales by Region │
|
||||
├─────────────────────────┤
|
||||
│ North ████████ 45% │
|
||||
│ South ██████ 30% │
|
||||
│ East ████ 20% │
|
||||
│ West ██ 5% │
|
||||
└─────────────────────────┘
|
||||
"""
|
||||
preview = ASCIIPreview(
|
||||
ascii_content=ascii_art.strip(),
|
||||
width=25,
|
||||
height=8,
|
||||
)
|
||||
assert preview.type == "ascii"
|
||||
assert "Sales by Region" in preview.ascii_content
|
||||
assert preview.width == 25
|
||||
assert preview.height == 8
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_table_preview_structure(self):
|
||||
"""Test TablePreview response structure."""
|
||||
table_content = """
|
||||
| Region | Sales | Profit |
|
||||
|--------|--------|--------|
|
||||
| North | 45000 | 12000 |
|
||||
| South | 30000 | 8000 |
|
||||
| East | 20000 | 5000 |
|
||||
| West | 5000 | 1000 |
|
||||
"""
|
||||
preview = TablePreview(
|
||||
table_data=table_content.strip(),
|
||||
row_count=4,
|
||||
supports_sorting=True,
|
||||
)
|
||||
assert preview.type == "table"
|
||||
assert "Region" in preview.table_data
|
||||
assert "North" in preview.table_data
|
||||
assert preview.row_count == 4
|
||||
assert preview.supports_sorting is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chart_preview_response_structure(self):
|
||||
"""Test the expected response structure for chart preview."""
|
||||
# Core fields that should always be present
|
||||
_ = [
|
||||
"chart_id",
|
||||
"chart_name",
|
||||
"chart_type",
|
||||
"explore_url",
|
||||
"content", # Union of URLPreview | ASCIIPreview | TablePreview
|
||||
"chart_description",
|
||||
"accessibility",
|
||||
"performance",
|
||||
]
|
||||
|
||||
# Additional fields that may be present for backward compatibility
|
||||
_ = [
|
||||
"format",
|
||||
"preview_url",
|
||||
"ascii_chart",
|
||||
"table_data",
|
||||
"width",
|
||||
"height",
|
||||
"schema_version",
|
||||
"api_version",
|
||||
]
|
||||
|
||||
# This is a structural test - actual integration tests would verify
|
||||
# the tool returns data matching this structure
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_dimensions(self):
|
||||
"""Test preview dimensions in response."""
|
||||
# Standard dimensions that may appear in preview responses
|
||||
standard_sizes = [
|
||||
(800, 600), # Default
|
||||
(1200, 800), # Large
|
||||
(400, 300), # Small
|
||||
(1920, 1080), # Full HD
|
||||
]
|
||||
|
||||
for width, height in standard_sizes:
|
||||
# URL preview with dimensions
|
||||
url_preview = URLPreview(
|
||||
preview_url="http://example.com/chart.png",
|
||||
width=width,
|
||||
height=height,
|
||||
supports_interaction=False,
|
||||
)
|
||||
assert url_preview.width == width
|
||||
assert url_preview.height == height
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_error_response_structures(self):
|
||||
"""Test error response structures."""
|
||||
# Error responses typically follow this structure
|
||||
error_response = {
|
||||
"error_type": "not_found",
|
||||
"message": "Chart not found",
|
||||
"details": "No chart found with ID 999",
|
||||
"chart_id": 999,
|
||||
}
|
||||
assert error_response["error_type"] == "not_found"
|
||||
assert error_response["chart_id"] == 999
|
||||
|
||||
# Preview generation error structure
|
||||
preview_error = {
|
||||
"error_type": "preview_error",
|
||||
"message": "Failed to generate preview",
|
||||
"details": "Screenshot service unavailable",
|
||||
}
|
||||
assert preview_error["error_type"] == "preview_error"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_accessibility_metadata(self):
|
||||
"""Test accessibility metadata structure."""
|
||||
from superset.mcp_service.chart.schemas import AccessibilityMetadata
|
||||
|
||||
metadata = AccessibilityMetadata(
|
||||
color_blind_safe=True,
|
||||
alt_text="Bar chart showing sales by region",
|
||||
high_contrast_available=False,
|
||||
)
|
||||
assert metadata.color_blind_safe is True
|
||||
assert "sales by region" in metadata.alt_text
|
||||
assert metadata.high_contrast_available is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_performance_metadata(self):
|
||||
"""Test performance metadata structure."""
|
||||
from superset.mcp_service.chart.schemas import PerformanceMetadata
|
||||
|
||||
metadata = PerformanceMetadata(
|
||||
query_duration_ms=150,
|
||||
cache_status="hit",
|
||||
optimization_suggestions=["Consider adding an index on date column"],
|
||||
)
|
||||
assert metadata.query_duration_ms == 150
|
||||
assert metadata.cache_status == "hit"
|
||||
assert len(metadata.optimization_suggestions) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chart_types_support(self):
|
||||
"""Test that various chart types are supported."""
|
||||
chart_types = [
|
||||
"echarts_timeseries_line",
|
||||
"echarts_timeseries_bar",
|
||||
"echarts_area",
|
||||
"echarts_timeseries_scatter",
|
||||
"table",
|
||||
"pie",
|
||||
"big_number",
|
||||
"big_number_total",
|
||||
"pivot_table_v2",
|
||||
"dist_bar",
|
||||
"box_plot",
|
||||
]
|
||||
|
||||
# All chart types should be previewable
|
||||
for _chart_type in chart_types:
|
||||
# This would be tested in integration tests
|
||||
pass
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ascii_art_variations(self):
|
||||
"""Test ASCII art generation for different chart types."""
|
||||
# Line chart ASCII
|
||||
_ = """
|
||||
Sales Trend
|
||||
│
|
||||
│ ╱╲
|
||||
│ ╱ ╲
|
||||
│ ╱ ╲
|
||||
│ ╱ ╲
|
||||
│ ╱ ╲
|
||||
└────────────
|
||||
Jan Feb Mar
|
||||
"""
|
||||
|
||||
# Bar chart ASCII
|
||||
_ = """
|
||||
Sales by Region
|
||||
│
|
||||
│ ████ North
|
||||
│ ███ South
|
||||
│ ██ East
|
||||
│ █ West
|
||||
└────────────
|
||||
"""
|
||||
|
||||
# Pie chart ASCII
|
||||
_ = """
|
||||
Market Share
|
||||
╭─────╮
|
||||
╱ ╲
|
||||
│ 45% │
|
||||
│ North │
|
||||
╰─────────╯
|
||||
"""
|
||||
|
||||
# These demonstrate the expected ASCII formats for different chart types
|
||||
385
tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
Normal file
385
tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for update_chart MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
UpdateChartRequest,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateChart:
|
||||
"""Tests for update_chart MCP tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_request_structure(self):
|
||||
"""Test that chart update request structures are properly formed."""
|
||||
# Table chart update with numeric ID
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region", label="Region"),
|
||||
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
|
||||
],
|
||||
filters=[FilterConfig(column="year", op="=", value="2024")],
|
||||
sort_by=["sales"],
|
||||
)
|
||||
table_request = UpdateChartRequest(identifier=123, config=table_config)
|
||||
assert table_request.identifier == 123
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
assert table_request.config.columns[1].aggregate == "SUM"
|
||||
|
||||
# XY chart update with UUID
|
||||
xy_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="value", aggregate="SUM")],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="category"),
|
||||
x_axis=AxisConfig(title="Date", format="smart_date"),
|
||||
y_axis=AxisConfig(title="Value", format="$,.2f"),
|
||||
legend=LegendConfig(show=True, position="top"),
|
||||
)
|
||||
xy_request = UpdateChartRequest(
|
||||
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
|
||||
)
|
||||
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.y[0].aggregate == "SUM"
|
||||
assert xy_request.config.kind == "line"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_with_chart_name(self):
|
||||
"""Test updating chart with custom chart name."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Without custom name
|
||||
request1 = UpdateChartRequest(identifier=123, config=config)
|
||||
assert request1.chart_name is None
|
||||
|
||||
# With custom name
|
||||
request2 = UpdateChartRequest(
|
||||
identifier=123, config=config, chart_name="Updated Sales Report"
|
||||
)
|
||||
assert request2.chart_name == "Updated Sales Report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_generation(self):
|
||||
"""Test preview generation options in update request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default preview generation
|
||||
request1 = UpdateChartRequest(identifier=123, config=config)
|
||||
assert request1.generate_preview is True
|
||||
assert request1.preview_formats == ["url"]
|
||||
|
||||
# Custom preview formats
|
||||
request2 = UpdateChartRequest(
|
||||
identifier=123,
|
||||
config=config,
|
||||
generate_preview=True,
|
||||
preview_formats=["url", "ascii", "table"],
|
||||
)
|
||||
assert request2.generate_preview is True
|
||||
assert set(request2.preview_formats) == {"url", "ascii", "table"}
|
||||
|
||||
# Disable preview generation
|
||||
request3 = UpdateChartRequest(
|
||||
identifier=123, config=config, generate_preview=False
|
||||
)
|
||||
assert request3.generate_preview is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_identifier_types(self):
|
||||
"""Test that identifier can be int or string (UUID)."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Integer ID
|
||||
request1 = UpdateChartRequest(identifier=123, config=config)
|
||||
assert request1.identifier == 123
|
||||
assert isinstance(request1.identifier, int)
|
||||
|
||||
# String numeric ID
|
||||
request2 = UpdateChartRequest(identifier="456", config=config)
|
||||
assert request2.identifier == "456"
|
||||
assert isinstance(request2.identifier, str)
|
||||
|
||||
# UUID string
|
||||
request3 = UpdateChartRequest(
|
||||
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=config
|
||||
)
|
||||
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert isinstance(request3.identifier, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_config_variations(self):
|
||||
"""Test various chart configuration options in updates."""
|
||||
# Test all XY chart types
|
||||
chart_types = ["line", "bar", "area", "scatter"]
|
||||
for chart_type in chart_types:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x_col"),
|
||||
y=[ColumnRef(name="y_col")],
|
||||
kind=chart_type,
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.kind == chart_type
|
||||
|
||||
# Test multiple Y columns
|
||||
multi_y_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="sales", aggregate="SUM"),
|
||||
ColumnRef(name="profit", aggregate="AVG"),
|
||||
ColumnRef(name="quantity", aggregate="COUNT"),
|
||||
],
|
||||
kind="line",
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=multi_y_config)
|
||||
assert len(request.config.y) == 3
|
||||
assert request.config.y[1].aggregate == "AVG"
|
||||
|
||||
# Test filter operators
|
||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
filters=filters,
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=table_config)
|
||||
assert len(request.config.filters) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_response_structure(self):
|
||||
"""Test the expected response structure for chart updates."""
|
||||
# The response should contain these fields
|
||||
expected_response = {
|
||||
"chart": {
|
||||
"id": int,
|
||||
"slice_name": str,
|
||||
"viz_type": str,
|
||||
"url": str,
|
||||
"uuid": str,
|
||||
"updated": bool,
|
||||
},
|
||||
"error": None,
|
||||
"success": bool,
|
||||
"schema_version": str,
|
||||
"api_version": str,
|
||||
}
|
||||
|
||||
# When chart update succeeds, these additional fields may be present
|
||||
optional_fields = [
|
||||
"previews",
|
||||
"capabilities",
|
||||
"semantics",
|
||||
"explore_url",
|
||||
"api_endpoints",
|
||||
"performance",
|
||||
"accessibility",
|
||||
]
|
||||
|
||||
# Validate structure expectations
|
||||
assert "chart" in expected_response
|
||||
assert "success" in expected_response
|
||||
assert len(optional_fields) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_axis_configurations(self):
|
||||
"""Test axis configuration updates."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales")],
|
||||
x_axis=AxisConfig(
|
||||
title="Date",
|
||||
format="smart_date",
|
||||
scale="linear",
|
||||
),
|
||||
y_axis=AxisConfig(
|
||||
title="Sales Amount",
|
||||
format="$,.2f",
|
||||
scale="log",
|
||||
),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.x_axis.title == "Date"
|
||||
assert request.config.x_axis.format == "smart_date"
|
||||
assert request.config.x_axis.scale == "linear"
|
||||
assert request.config.y_axis.title == "Sales Amount"
|
||||
assert request.config.y_axis.format == "$,.2f"
|
||||
assert request.config.y_axis.scale == "log"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_legend_configurations(self):
|
||||
"""Test legend configuration updates."""
|
||||
positions = ["top", "bottom", "left", "right"]
|
||||
for pos in positions:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=True, position=pos),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.legend.position == pos
|
||||
assert request.config.legend.show is True
|
||||
|
||||
# Hidden legend
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=False),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.legend.show is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_aggregation_functions(self):
|
||||
"""Test all supported aggregation functions in updates."""
|
||||
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
|
||||
for agg in aggs:
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="value", aggregate=agg)],
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.columns[0].aggregate == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_error_responses(self):
|
||||
"""Test expected error response structures."""
|
||||
# Chart not found error
|
||||
error_response = {
|
||||
"chart": None,
|
||||
"error": "No chart found with identifier: 999",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert error_response["success"] is False
|
||||
assert error_response["chart"] is None
|
||||
assert "chart found" in error_response["error"].lower()
|
||||
|
||||
# General update error
|
||||
update_error = {
|
||||
"chart": None,
|
||||
"error": "Chart update failed: Permission denied",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert update_error["success"] is False
|
||||
assert "failed" in update_error["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chart_name_sanitization(self):
|
||||
"""Test that chart names are properly sanitized."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Test with potentially problematic characters
|
||||
test_names = [
|
||||
"Normal Chart Name",
|
||||
"Chart with 'quotes'",
|
||||
'Chart with "double quotes"',
|
||||
"Chart with <brackets>",
|
||||
]
|
||||
|
||||
for name in test_names:
|
||||
request = UpdateChartRequest(identifier=1, config=config, chart_name=name)
|
||||
# Chart name should be set (sanitization happens in the validator)
|
||||
assert request.chart_name is not None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_with_filters(self):
|
||||
"""Test updating chart with various filter configurations."""
|
||||
filters = [
|
||||
FilterConfig(column="region", op="=", value="North"),
|
||||
FilterConfig(column="sales", op=">=", value=1000),
|
||||
FilterConfig(column="date", op=">", value="2024-01-01"),
|
||||
]
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales"),
|
||||
ColumnRef(name="date"),
|
||||
],
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert len(request.config.filters) == 3
|
||||
assert request.config.filters[0].column == "region"
|
||||
assert request.config.filters[1].op == ">="
|
||||
assert request.config.filters[2].value == "2024-01-01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_cache_control(self):
|
||||
"""Test cache control parameters in update request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default cache settings
|
||||
request1 = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request1.use_cache is True
|
||||
assert request1.force_refresh is False
|
||||
assert request1.cache_timeout is None
|
||||
|
||||
# Custom cache settings
|
||||
request2 = UpdateChartRequest(
|
||||
identifier=1,
|
||||
config=config,
|
||||
use_cache=False,
|
||||
force_refresh=True,
|
||||
cache_timeout=300,
|
||||
)
|
||||
assert request2.use_cache is False
|
||||
assert request2.force_refresh is True
|
||||
assert request2.cache_timeout == 300
|
||||
@@ -0,0 +1,474 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for update_chart_preview MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
UpdateChartPreviewRequest,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateChartPreview:
|
||||
"""Tests for update_chart_preview MCP tool."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_request_structure(self):
|
||||
"""Test that chart preview update request structures are properly formed."""
|
||||
# Table chart preview update
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region", label="Region"),
|
||||
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
|
||||
],
|
||||
filters=[FilterConfig(column="year", op="=", value="2024")],
|
||||
sort_by=["sales"],
|
||||
)
|
||||
table_request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123def456", dataset_id=1, config=table_config
|
||||
)
|
||||
assert table_request.form_data_key == "abc123def456"
|
||||
assert table_request.dataset_id == 1
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
|
||||
# XY chart preview update
|
||||
xy_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="value", aggregate="SUM")],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="category"),
|
||||
x_axis=AxisConfig(title="Date", format="smart_date"),
|
||||
y_axis=AxisConfig(title="Value", format="$,.2f"),
|
||||
legend=LegendConfig(show=True, position="top"),
|
||||
)
|
||||
xy_request = UpdateChartPreviewRequest(
|
||||
form_data_key="xyz789ghi012", dataset_id="2", config=xy_config
|
||||
)
|
||||
assert xy_request.form_data_key == "xyz789ghi012"
|
||||
assert xy_request.dataset_id == "2"
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.kind == "line"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_dataset_id_types(self):
|
||||
"""Test that dataset_id can be int or string (UUID)."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Integer dataset_id
|
||||
request1 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=123, config=config
|
||||
)
|
||||
assert request1.dataset_id == 123
|
||||
assert isinstance(request1.dataset_id, int)
|
||||
|
||||
# String numeric dataset_id
|
||||
request2 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id="456", config=config
|
||||
)
|
||||
assert request2.dataset_id == "456"
|
||||
assert isinstance(request2.dataset_id, str)
|
||||
|
||||
# UUID string dataset_id
|
||||
request3 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
config=config,
|
||||
)
|
||||
assert request3.dataset_id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert isinstance(request3.dataset_id, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_generation_options(self):
|
||||
"""Test preview generation options in update preview request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default preview generation
|
||||
request1 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request1.generate_preview is True
|
||||
assert request1.preview_formats == ["url"]
|
||||
|
||||
# Custom preview formats
|
||||
request2 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
generate_preview=True,
|
||||
preview_formats=["url", "ascii", "table"],
|
||||
)
|
||||
assert request2.generate_preview is True
|
||||
assert set(request2.preview_formats) == {"url", "ascii", "table"}
|
||||
|
||||
# Disable preview generation
|
||||
request3 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
generate_preview=False,
|
||||
)
|
||||
assert request3.generate_preview is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_config_variations(self):
|
||||
"""Test various chart configuration options in preview updates."""
|
||||
# Test all XY chart types
|
||||
chart_types = ["line", "bar", "area", "scatter"]
|
||||
for chart_type in chart_types:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x_col"),
|
||||
y=[ColumnRef(name="y_col")],
|
||||
kind=chart_type,
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.kind == chart_type
|
||||
|
||||
# Test multiple Y columns
|
||||
multi_y_config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="sales", aggregate="SUM"),
|
||||
ColumnRef(name="profit", aggregate="AVG"),
|
||||
ColumnRef(name="quantity", aggregate="COUNT"),
|
||||
],
|
||||
kind="line",
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=multi_y_config
|
||||
)
|
||||
assert len(request.config.y) == 3
|
||||
assert request.config.y[1].aggregate == "AVG"
|
||||
|
||||
# Test filter operators
|
||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
filters=filters,
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=table_config
|
||||
)
|
||||
assert len(request.config.filters) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_response_structure(self):
|
||||
"""Test the expected response structure for chart preview updates."""
|
||||
# The response should contain these fields
|
||||
expected_response = {
|
||||
"chart": {
|
||||
"id": None, # No ID for unsaved previews
|
||||
"slice_name": str,
|
||||
"viz_type": str,
|
||||
"url": str,
|
||||
"uuid": None, # No UUID for unsaved previews
|
||||
"saved": bool,
|
||||
"updated": bool,
|
||||
},
|
||||
"error": None,
|
||||
"success": bool,
|
||||
"schema_version": str,
|
||||
"api_version": str,
|
||||
}
|
||||
|
||||
# When preview update succeeds, these additional fields may be present
|
||||
optional_fields = [
|
||||
"previews",
|
||||
"capabilities",
|
||||
"semantics",
|
||||
"explore_url",
|
||||
"form_data_key",
|
||||
"previous_form_data_key",
|
||||
"api_endpoints",
|
||||
"performance",
|
||||
"accessibility",
|
||||
]
|
||||
|
||||
# Validate structure expectations
|
||||
assert "chart" in expected_response
|
||||
assert "success" in expected_response
|
||||
assert len(optional_fields) > 0
|
||||
assert expected_response["chart"]["id"] is None
|
||||
assert expected_response["chart"]["uuid"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_axis_configurations(self):
|
||||
"""Test axis configuration updates in preview."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales")],
|
||||
x_axis=AxisConfig(
|
||||
title="Date",
|
||||
format="smart_date",
|
||||
scale="linear",
|
||||
),
|
||||
y_axis=AxisConfig(
|
||||
title="Sales Amount",
|
||||
format="$,.2f",
|
||||
scale="log",
|
||||
),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.x_axis.title == "Date"
|
||||
assert request.config.x_axis.format == "smart_date"
|
||||
assert request.config.y_axis.title == "Sales Amount"
|
||||
assert request.config.y_axis.format == "$,.2f"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_legend_configurations(self):
|
||||
"""Test legend configuration updates in preview."""
|
||||
positions = ["top", "bottom", "left", "right"]
|
||||
for pos in positions:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=True, position=pos),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.position == pos
|
||||
assert request.config.legend.show is True
|
||||
|
||||
# Hidden legend
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=False),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.show is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_aggregation_functions(self):
|
||||
"""Test all supported aggregation functions in preview updates."""
|
||||
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
|
||||
for agg in aggs:
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="value", aggregate=agg)],
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.columns[0].aggregate == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_error_responses(self):
|
||||
"""Test expected error response structures for preview updates."""
|
||||
# General update error
|
||||
error_response = {
|
||||
"chart": None,
|
||||
"error": "Chart preview update failed: Invalid form_data_key",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert error_response["success"] is False
|
||||
assert error_response["chart"] is None
|
||||
assert "failed" in error_response["error"].lower()
|
||||
|
||||
# Missing dataset error
|
||||
dataset_error = {
|
||||
"chart": None,
|
||||
"error": "Chart preview update failed: Dataset not found",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert dataset_error["success"] is False
|
||||
assert "dataset" in dataset_error["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_with_filters(self):
|
||||
"""Test updating preview with various filter configurations."""
|
||||
filters = [
|
||||
FilterConfig(column="region", op="=", value="North"),
|
||||
FilterConfig(column="sales", op=">=", value=1000),
|
||||
FilterConfig(column="date", op=">", value="2024-01-01"),
|
||||
]
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales"),
|
||||
ColumnRef(name="date"),
|
||||
],
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.filters) == 3
|
||||
assert request.config.filters[0].column == "region"
|
||||
assert request.config.filters[1].op == ">="
|
||||
assert request.config.filters[2].value == "2024-01-01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_form_data_key_handling(self):
|
||||
"""Test form_data_key handling in preview updates."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Various form_data_key formats
|
||||
form_data_keys = [
|
||||
"abc123def456",
|
||||
"xyz-789-ghi-012",
|
||||
"key_with_underscore",
|
||||
"UPPERCASE_KEY",
|
||||
]
|
||||
|
||||
for key in form_data_keys:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key=key, dataset_id=1, config=config
|
||||
)
|
||||
assert request.form_data_key == key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_cache_control(self):
|
||||
"""Test cache control parameters in update preview request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default cache settings
|
||||
request1 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request1.use_cache is True
|
||||
assert request1.force_refresh is False
|
||||
assert request1.cache_form_data is True
|
||||
|
||||
# Custom cache settings
|
||||
request2 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
use_cache=False,
|
||||
force_refresh=True,
|
||||
cache_form_data=False,
|
||||
)
|
||||
assert request2.use_cache is False
|
||||
assert request2.force_refresh is True
|
||||
assert request2.cache_form_data is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_no_save_behavior(self):
|
||||
"""Test that preview updates don't create permanent charts."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
|
||||
# Preview updates should never create permanent charts
|
||||
# This is validated by checking the response structure
|
||||
expected_unsaved_fields = {
|
||||
"id": None, # No chart ID
|
||||
"uuid": None, # No UUID
|
||||
"saved": False, # Not saved
|
||||
}
|
||||
|
||||
# These expectations are validated in the response, not the request
|
||||
assert request.form_data_key == "abc123"
|
||||
assert expected_unsaved_fields["id"] is None
|
||||
assert expected_unsaved_fields["uuid"] is None
|
||||
assert expected_unsaved_fields["saved"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_multiple_y_columns(self):
|
||||
"""Test preview updates with multiple Y-axis columns."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
|
||||
ColumnRef(name="cost", aggregate="SUM", label="Total Cost"),
|
||||
ColumnRef(name="profit", aggregate="SUM", label="Total Profit"),
|
||||
ColumnRef(name="orders", aggregate="COUNT", label="Order Count"),
|
||||
],
|
||||
kind="line",
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.y) == 4
|
||||
assert request.config.y[0].name == "revenue"
|
||||
assert request.config.y[1].name == "cost"
|
||||
assert request.config.y[2].name == "profit"
|
||||
assert request.config.y[3].name == "orders"
|
||||
assert request.config.y[3].aggregate == "COUNT"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_table_sorting(self):
|
||||
"""Test table chart sorting in preview updates."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales", aggregate="SUM"),
|
||||
ColumnRef(name="profit", aggregate="AVG"),
|
||||
],
|
||||
sort_by=["sales", "profit"],
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.sort_by == ["sales", "profit"]
|
||||
assert len(request.config.columns) == 3
|
||||
Reference in New Issue
Block a user