mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)
This commit is contained in:
160
tests/unit_tests/mcp_service/chart/test_chart_schemas.py
Normal file
160
tests/unit_tests/mcp_service/chart/test_chart_schemas.py
Normal file
@@ -0,0 +1,160 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP chart schema validation.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ColumnRef,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestTableChartConfig:
    """Validation behaviour of TableChartConfig."""

    def test_duplicate_labels_rejected(self) -> None:
        """Two columns resolving to the same label must fail validation."""
        colliding = [
            ColumnRef(name="product_line", label="product_line"),
            ColumnRef(name="sales", aggregate="SUM", label="product_line"),
        ]
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            TableChartConfig(columns=colliding)

    def test_unique_labels_accepted(self) -> None:
        """Columns with distinct labels validate cleanly."""
        cfg = TableChartConfig(
            columns=[
                ColumnRef(name="product_line", label="Product Line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ]
        )
        assert len(cfg.columns) == 2
|
||||
|
||||
|
||||
class TestXYChartConfig:
    """Validation behaviour of XYChartConfig."""

    def test_different_labels_accepted(self) -> None:
        """x and y may reference the same column when their labels differ."""
        # x resolves to label "product_line"; y to "COUNT(product_line)".
        cfg = XYChartConfig(
            x=ColumnRef(name="product_line"),
            y=[ColumnRef(name="product_line", aggregate="COUNT")],
        )
        assert cfg.x.name == "product_line"
        assert cfg.y[0].aggregate == "COUNT"

    def test_explicit_duplicate_label_rejected(self) -> None:
        """A y label explicitly colliding with the x label must fail."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="product_line"),
                y=[ColumnRef(name="sales", label="product_line")],
            )

    def test_duplicate_y_axis_labels_rejected(self) -> None:
        """Two y-series resolving to the same label must fail."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM"),
                    ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """Fully distinct labels across x and y validate cleanly."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
            ],
        )
        assert len(cfg.y) == 2

    def test_group_by_duplicate_with_x_rejected(self) -> None:
        """A group_by whose label collides with x must fail."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="region"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                group_by=ColumnRef(name="category", label="region"),
            )

    def test_realistic_chart_configurations(self) -> None:
        """COUNT(product_line) does not collide with bare product_line."""
        cfg = XYChartConfig(
            x=ColumnRef(name="product_line"),
            y=[
                ColumnRef(name="product_line", aggregate="COUNT"),
                ColumnRef(name="sales", aggregate="SUM"),
            ],
        )
        assert cfg.x.name == "product_line"
        assert len(cfg.y) == 2

    def test_time_series_chart_configuration(self) -> None:
        """A typical time-series line chart validates without label clashes."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert cfg.x.name == "order_date"
        assert cfg.kind == "line"

    def test_time_series_with_custom_x_axis_label(self) -> None:
        """A custom label on the x column is preserved."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert cfg.x.label == "Order Date"

    def test_area_chart_configuration(self) -> None:
        """Area charts validate like any other xy kind."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            kind="area",
        )
        assert cfg.kind == "area"
|
||||
465
tests/unit_tests/mcp_service/chart/test_chart_utils.py
Normal file
465
tests/unit_tests/mcp_service/chart/test_chart_utils.py
Normal file
@@ -0,0 +1,465 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for chart utilities module"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
create_metric_object,
|
||||
generate_chart_name,
|
||||
generate_explore_link,
|
||||
map_config_to_form_data,
|
||||
map_filter_operator,
|
||||
map_table_config,
|
||||
map_xy_config,
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateMetricObject:
    """Behaviour of create_metric_object."""

    def test_create_metric_object_with_aggregate(self) -> None:
        """An explicit aggregate and custom label are carried through."""
        metric = create_metric_object(
            ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
        )

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "revenue"
        assert metric["label"] == "Total Revenue"
        assert metric["optionName"] == "metric_revenue"
        assert metric["expressionType"] == "SIMPLE"

    def test_create_metric_object_default_aggregate(self) -> None:
        """Without an aggregate, SUM is assumed and the label is derived."""
        metric = create_metric_object(ColumnRef(name="orders"))

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "orders"
        assert metric["label"] == "SUM(orders)"
        assert metric["optionName"] == "metric_orders"
|
||||
|
||||
|
||||
class TestMapFilterOperator:
    """Behaviour of map_filter_operator."""

    def test_map_filter_operators(self) -> None:
        """Known comparison operators map to their Superset equivalents."""
        expected = {
            "=": "==",
            ">": ">",
            "<": "<",
            ">=": ">=",
            "<=": "<=",
            "!=": "!=",
        }
        for op, mapped in expected.items():
            assert map_filter_operator(op) == mapped

    def test_map_filter_operator_unknown(self) -> None:
        """Unknown operators pass through unchanged."""
        assert map_filter_operator("UNKNOWN") == "UNKNOWN"
|
||||
|
||||
|
||||
class TestMapTableConfig:
    """Behaviour of map_table_config."""

    def test_map_table_config_basic(self) -> None:
        """Aggregated columns produce aggregate mode with metrics only."""
        form_data = map_table_config(
            TableChartConfig(
                columns=[
                    ColumnRef(name="product", aggregate="COUNT"),
                    ColumnRef(name="revenue", aggregate="SUM"),
                ]
            )
        )

        assert form_data["viz_type"] == "table"
        assert form_data["query_mode"] == "aggregate"
        # Aggregated columns land in metrics, never in all_columns.
        assert "all_columns" not in form_data
        assert len(form_data["metrics"]) == 2
        assert form_data["metrics"][0]["aggregate"] == "COUNT"
        assert form_data["metrics"][1]["aggregate"] == "SUM"

    def test_map_table_config_raw_columns(self) -> None:
        """Columns without aggregates trigger raw mode with all_columns."""
        form_data = map_table_config(
            TableChartConfig(
                columns=[ColumnRef(name="product"), ColumnRef(name="category")]
            )
        )

        assert form_data["viz_type"] == "table"
        assert form_data["query_mode"] == "raw"
        assert form_data["all_columns"] == ["product", "category"]
        assert "metrics" not in form_data

    def test_map_table_config_with_filters(self) -> None:
        """Filters are translated into SIMPLE adhoc filters."""
        form_data = map_table_config(
            TableChartConfig(
                columns=[ColumnRef(name="product")],
                filters=[FilterConfig(column="status", op="=", value="active")],
            )
        )

        assert "adhoc_filters" in form_data
        assert len(form_data["adhoc_filters"]) == 1
        adhoc = form_data["adhoc_filters"][0]
        assert adhoc["subject"] == "status"
        assert adhoc["operator"] == "=="
        assert adhoc["comparator"] == "active"
        assert adhoc["expressionType"] == "SIMPLE"

    def test_map_table_config_with_sort(self) -> None:
        """sort_by entries become order_by_cols."""
        form_data = map_table_config(
            TableChartConfig(
                columns=[ColumnRef(name="product")], sort_by=["product", "revenue"]
            )
        )
        assert form_data["order_by_cols"] == ["product", "revenue"]
|
||||
|
||||
|
||||
class TestMapXYConfig:
    """Behaviour of map_xy_config."""

    def test_map_xy_config_line_chart(self) -> None:
        """A line chart maps to echarts_timeseries_line with one metric."""
        form_data = map_xy_config(
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="revenue", aggregate="SUM")],
                kind="line",
            )
        )

        assert form_data["viz_type"] == "echarts_timeseries_line"
        assert form_data["x_axis"] == "date"
        assert len(form_data["metrics"]) == 1
        assert form_data["metrics"][0]["aggregate"] == "SUM"

    def test_map_xy_config_with_groupby(self) -> None:
        """group_by turns into the groupby list."""
        form_data = map_xy_config(
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="revenue")],
                kind="bar",
                group_by=ColumnRef(name="region"),
            )
        )

        assert form_data["viz_type"] == "echarts_timeseries_bar"
        assert form_data["groupby"] == ["region"]

    def test_map_xy_config_with_axes(self) -> None:
        """Axis titles, formats and scale are copied into form data."""
        form_data = map_xy_config(
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="revenue")],
                kind="area",
                x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
                y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
            )
        )

        assert form_data["viz_type"] == "echarts_area"
        assert form_data["x_axis_title"] == "Date"
        assert form_data["x_axis_format"] == "%Y-%m-%d"
        assert form_data["y_axis_title"] == "Revenue"
        assert form_data["y_axis_format"] == "$,.2f"
        assert form_data["y_axis_scale"] == "log"

    def test_map_xy_config_with_legend(self) -> None:
        """Legend visibility and position are copied into form data."""
        form_data = map_xy_config(
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="revenue")],
                kind="scatter",
                legend=LegendConfig(show=False, position="top"),
            )
        )

        assert form_data["viz_type"] == "echarts_timeseries_scatter"
        assert form_data["show_legend"] is False
        assert form_data["legend_orientation"] == "top"
|
||||
|
||||
|
||||
class TestMapConfigToFormData:
    """Behaviour of the map_config_to_form_data dispatcher."""

    def test_map_table_config_type(self) -> None:
        """Table configs dispatch to the table mapper."""
        form_data = map_config_to_form_data(
            TableChartConfig(columns=[ColumnRef(name="test")])
        )
        assert form_data["viz_type"] == "table"

    def test_map_xy_config_type(self) -> None:
        """XY configs dispatch to the xy mapper."""
        form_data = map_config_to_form_data(
            XYChartConfig(
                x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line"
            )
        )
        assert form_data["viz_type"] == "echarts_timeseries_line"

    def test_map_unsupported_config_type(self) -> None:
        """Anything other than a known config type raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported config type"):
            map_config_to_form_data("invalid_config")  # type: ignore
|
||||
|
||||
|
||||
class TestGenerateChartName:
    """Behaviour of generate_chart_name."""

    def test_generate_table_chart_name(self) -> None:
        """Table chart names list the selected columns."""
        cfg = TableChartConfig(
            columns=[ColumnRef(name="product"), ColumnRef(name="revenue")]
        )
        assert generate_chart_name(cfg) == "Table Chart - product, revenue"

    def test_generate_xy_chart_name(self) -> None:
        """XY chart names combine the kind, x column and y columns."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
            kind="line",
        )
        assert generate_chart_name(cfg) == "Line Chart - date vs revenue, orders"

    def test_generate_chart_name_unsupported(self) -> None:
        """Unknown config types fall back to the generic name."""
        assert generate_chart_name("invalid_config") == "Chart"  # type: ignore
|
||||
|
||||
|
||||
class TestGenerateExploreLink:
    """Behaviour of generate_explore_link."""

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
        """The configured base URL is used to build the explore link."""
        from urllib.parse import urlparse

        mock_get_base_url.return_value = "https://superset.example.com"

        link = generate_explore_link(
            "123", {"viz_type": "table", "metrics": ["count"]}
        )

        # Parse rather than substring-match the host (avoids CodeQL warning).
        parsed = urlparse(link)
        assert parsed.scheme == "https"
        assert parsed.netloc == "superset.example.com"
        assert "/explore/" in parsed.path
        assert "datasource_id=123" in link

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
        """A plain datasource URL is returned when the dataset is missing."""
        mock_get_base_url.return_value = "http://localhost:9001"

        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            link = generate_explore_link("999", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
        )

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    @patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
    def test_generate_explore_link_with_form_data_key(
        self, mock_command, mock_get_base_url
    ) -> None:
        """An existing dataset yields a form_data_key-style link."""
        mock_get_base_url.return_value = "http://localhost:9001"
        mock_command.return_value.run.return_value = "test_form_data_key"

        # Minimal stand-in object with just the attribute the code reads.
        fake_dataset = type("Dataset", (), {"id": 123})()
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id", return_value=fake_dataset
        ):
            link = generate_explore_link(123, {"viz_type": "table"})

        assert (
            link == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
        )
        mock_command.assert_called_once()

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
        """Lookup failures fall back to the basic datasource URL."""
        mock_get_base_url.return_value = "http://localhost:9001"

        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id",
            side_effect=Exception("DB error"),
        ):
            link = generate_explore_link("123", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
        )
|
||||
|
||||
|
||||
class TestCriticalBugFixes:
    """Regression tests for chart-utility bug fixes."""

    def test_time_series_aggregation_fix(self) -> None:
        """Time-series charts keep the temporal column on the x axis."""
        form_data = map_xy_config(
            XYChartConfig(
                chart_type="xy",
                kind="line",
                x=ColumnRef(name="order_date"),
                y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            )
        )

        assert form_data["x_axis"] == "order_date"
        # The x column must not leak into groupby — that previously caused
        # "Duplicate column/metric labels".
        assert "groupby" not in form_data or "order_date" not in form_data.get(
            "groupby", []
        )
        assert form_data["viz_type"] == "echarts_timeseries_line"

    def test_time_series_with_explicit_group_by(self) -> None:
        """An explicit group_by distinct from x is preserved as-is."""
        form_data = map_xy_config(
            XYChartConfig(
                chart_type="xy",
                kind="line",
                x=ColumnRef(name="order_date"),
                y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
                group_by=ColumnRef(name="category"),
            )
        )

        assert form_data["x_axis"] == "order_date"
        # groupby holds only the explicit dimension, never the x column.
        assert form_data.get("groupby") == ["category"]
        assert "order_date" not in form_data.get("groupby", [])

    def test_duplicate_label_prevention(self) -> None:
        """A group_by equal to the x column is dropped to avoid label clashes."""
        # This shape previously produced:
        # "Duplicate column/metric labels: 'price_each'"
        form_data = map_xy_config(
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="price_each", label="Price Each"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                    ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
                ],
                group_by=ColumnRef(name="price_each"),  # same column as x_axis
                kind="line",
            )
        )

        assert form_data["x_axis"] == "price_each"
        # groupby is empty because group_by duplicated the x column.
        assert "groupby" not in form_data or not form_data["groupby"]

    def test_metric_object_creation_with_labels(self) -> None:
        """Metrics keep custom labels and derive missing ones."""
        form_data = map_xy_config(
            XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="date"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                    ColumnRef(name="profit", aggregate="AVG"),  # no custom label
                ],
                kind="bar",
            )
        )

        metrics = form_data["metrics"]
        assert len(metrics) == 2

        # Custom label passes through untouched.
        assert metrics[0]["label"] == "Total Sales"
        assert metrics[0]["aggregate"] == "SUM"
        assert metrics[0]["column"]["column_name"] == "sales"

        # Missing label is auto-generated as AGG(column).
        assert metrics[1]["label"] == "AVG(profit)"
        assert metrics[1]["aggregate"] == "AVG"
        assert metrics[1]["column"]["column_name"] == "profit"

    def test_chart_type_mapping_comprehensive(self) -> None:
        """Every xy kind maps to its Superset viz type."""
        expected_viz = {
            "line": "echarts_timeseries_line",
            "bar": "echarts_timeseries_bar",
            "area": "echarts_area",
            "scatter": "echarts_timeseries_scatter",
        }

        for kind, viz_type in expected_viz.items():
            form_data = map_xy_config(
                XYChartConfig(
                    chart_type="xy",
                    x=ColumnRef(name="date"),
                    y=[ColumnRef(name="value", aggregate="SUM")],
                    kind=kind,
                )
            )
            assert form_data["viz_type"] == viz_type
|
||||
268
tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py
Normal file
268
tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py
Normal file
@@ -0,0 +1,268 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP generate_chart tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
GenerateChartRequest,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateChart:
    """Tests for the generate_chart MCP tool's request/config schemas.

    NOTE: every test here only builds and inspects Pydantic objects — nothing
    is awaited — so they are plain synchronous tests. The previous
    ``async def`` + ``@pytest.mark.asyncio`` markers added a pytest-asyncio
    dependency for no benefit (and without the plugin the coroutines would
    never actually run).
    """

    def test_generate_chart_request_structure(self):
        """Chart generation request structures are properly formed."""
        # Table chart request
        table_config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region", label="Region"),
                ColumnRef(name="sales", label="Sales", aggregate="SUM"),
            ],
            filters=[FilterConfig(column="year", op="=", value="2024")],
            sort_by=["sales"],
        )
        table_request = GenerateChartRequest(dataset_id="1", config=table_config)
        assert table_request.dataset_id == "1"
        assert table_request.config.chart_type == "table"
        assert len(table_request.config.columns) == 2
        assert table_request.config.columns[0].name == "region"
        assert table_request.config.columns[1].aggregate == "SUM"

        # XY chart request
        xy_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
            group_by=ColumnRef(name="category"),
            x_axis=AxisConfig(title="Date", format="smart_date"),
            y_axis=AxisConfig(title="Value", format="$,.2f"),
            legend=LegendConfig(show=True, position="top"),
        )
        xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
        assert xy_request.config.chart_type == "xy"
        assert xy_request.config.x.name == "date"
        assert xy_request.config.y[0].aggregate == "SUM"
        assert xy_request.config.kind == "line"
        assert xy_request.config.x_axis.title == "Date"
        assert xy_request.config.legend.show is True

    def test_generate_chart_validation_error_handling(self):
        """Validation error entries carry the expected keys."""
        # A validation error entry in the shape the tool reports.
        validation_error_entry = {
            "field": "x_axis",
            "provided_value": "invalid_col",
            "error_type": "column_not_found",
            "message": "Column 'invalid_col' not found",
        }

        assert validation_error_entry["field"] == "x_axis"
        assert validation_error_entry["error_type"] == "column_not_found"

    def test_chart_config_variations(self):
        """Various chart configuration options all validate."""
        # Every supported xy kind
        for chart_kind in ("line", "bar", "area", "scatter"):
            config = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="x_col"),
                y=[ColumnRef(name="y_col")],
                kind=chart_kind,
            )
            assert config.kind == chart_kind

        # Multiple Y columns
        multi_y_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM"),
                ColumnRef(name="profit", aggregate="AVG"),
                ColumnRef(name="quantity", aggregate="COUNT"),
            ],
            kind="line",
        )
        assert len(multi_y_config.y) == 3
        assert multi_y_config.y[1].aggregate == "AVG"

        # Every supported filter operator
        operators = ["=", "!=", ">", ">=", "<", "<="]
        filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
        for op, filt in zip(operators, filters):
            assert filt.op == op

    def test_generate_chart_response_structure(self):
        """Document the expected response structure for chart generation."""
        # The response should contain these fields
        _ = {
            "chart": {
                "id": int,
                "slice_name": str,
                "viz_type": str,
                "url": str,
                "uuid": str,
                "saved": bool,
            },
            "error": None,
            "success": bool,
            "schema_version": str,
            "api_version": str,
        }

        # When chart creation succeeds, these additional fields may be present
        _ = [
            "previews",
            "capabilities",
            "semantics",
            "explore_url",
            "form_data_key",
            "api_endpoints",
            "performance",
            "accessibility",
        ]

        # This is just a structural test - actual integration tests would verify
        # the tool returns data matching this structure

    def test_dataset_id_flexibility(self):
        """dataset_id accepts numeric strings and UUID-like strings."""
        requests = [
            GenerateChartRequest(
                dataset_id="123",
                config=TableChartConfig(
                    chart_type="table", columns=[ColumnRef(name="col1")]
                ),
            ),
            GenerateChartRequest(
                dataset_id="uuid-string-here",
                config=TableChartConfig(
                    chart_type="table", columns=[ColumnRef(name="col1")]
                ),
            ),
        ]

        for request in requests:
            assert isinstance(request.dataset_id, str)

    def test_save_chart_flag(self):
        """save_chart defaults to True and can be disabled for preview-only."""
        default_request = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
        )
        assert default_request.save_chart is True

        preview_request = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
            save_chart=False,
        )
        assert preview_request.save_chart is False

    def test_preview_formats(self):
        """Preview format options round-trip through the request."""
        formats = ["url", "ascii", "table"]
        request = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
            generate_preview=True,
            preview_formats=formats,
        )
        assert request.generate_preview is True
        assert set(request.preview_formats) == set(formats)

    def test_column_ref_features(self):
        """ColumnRef supports bare columns, aggregates and labels."""
        # Simple column: no label, no aggregate
        plain = ColumnRef(name="region")
        assert plain.name == "region"
        assert plain.label is None
        assert plain.aggregate is None

        # Column with aggregation and a custom label
        labeled = ColumnRef(name="sales", aggregate="SUM", label="Total Sales")
        assert labeled.name == "sales"
        assert labeled.aggregate == "SUM"
        assert labeled.label == "Total Sales"

        # All supported aggregations
        for agg in ("SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"):
            assert ColumnRef(name="value", aggregate=agg).aggregate == agg

    def test_axis_config_options(self):
        """Axis configuration accepts title, format and scale."""
        axis = AxisConfig(
            title="Sales Amount",
            format="$,.2f",
            scale="linear",
        )
        assert axis.title == "Sales Amount"
        assert axis.format == "$,.2f"
        assert axis.scale == "linear"

        # A spread of d3/Superset format strings
        for fmt in ("SMART_NUMBER", "$,.2f", ",.0f", "smart_date", ".3%"):
            assert AxisConfig(format=fmt).format == fmt

    def test_legend_config_options(self):
        """Legend configuration accepts every position and a hidden state."""
        for pos in ("top", "bottom", "left", "right"):
            assert LegendConfig(show=True, position=pos).position == pos

        hidden = LegendConfig(show=False)
        assert hidden.show is False
|
||||
@@ -0,0 +1,290 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for get_chart_preview MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ASCIIPreview,
|
||||
GetChartPreviewRequest,
|
||||
TablePreview,
|
||||
URLPreview,
|
||||
)
|
||||
|
||||
|
||||
class TestGetChartPreview:
    """Tests for get_chart_preview MCP tool.

    These are schema-level unit tests: they exercise the pydantic request and
    response models (GetChartPreviewRequest, URLPreview, ASCIIPreview,
    TablePreview) without calling the tool itself.
    """

    # NOTE(review): the asyncio marks look unnecessary (no awaits in these
    # bodies) but are kept for consistency with the rest of the suite.
    @pytest.mark.asyncio
    async def test_get_chart_preview_request_structure(self):
        """Test that preview request structures are properly formed."""
        # Numeric ID request
        request1 = GetChartPreviewRequest(identifier=123, format="url")
        assert request1.identifier == 123
        assert request1.format == "url"
        # Default dimensions are set
        assert request1.width == 800
        assert request1.height == 600

        # String ID request
        request2 = GetChartPreviewRequest(identifier="456", format="ascii")
        assert request2.identifier == "456"
        assert request2.format == "ascii"

        # UUID request
        request3 = GetChartPreviewRequest(
            identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", format="table"
        )
        assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
        assert request3.format == "table"

        # Default format
        request4 = GetChartPreviewRequest(identifier=789)
        assert request4.format == "url"  # default

    @pytest.mark.asyncio
    async def test_preview_format_types(self):
        """Test different preview format types."""
        formats = ["url", "ascii", "table"]
        for fmt in formats:
            request = GetChartPreviewRequest(identifier=1, format=fmt)
            assert request.format == fmt

    @pytest.mark.asyncio
    async def test_url_preview_structure(self):
        """Test URLPreview response structure."""
        preview = URLPreview(
            preview_url="http://localhost:5008/screenshot/chart/123.png",
            width=800,
            height=600,
            supports_interaction=False,
        )
        # `type` is a fixed discriminator on the model, not a constructor arg.
        assert preview.type == "url"
        assert preview.preview_url == "http://localhost:5008/screenshot/chart/123.png"
        assert preview.width == 800
        assert preview.height == 600
        assert preview.supports_interaction is False

    @pytest.mark.asyncio
    async def test_ascii_preview_structure(self):
        """Test ASCIIPreview response structure."""
        # Representative box-drawing output; only the title substring and the
        # declared dimensions are asserted, so exact spacing is not significant.
        ascii_art = """
┌─────────────────────────┐
│ Sales by Region │
├─────────────────────────┤
│ North ████████ 45% │
│ South ██████ 30% │
│ East ████ 20% │
│ West ██ 5% │
└─────────────────────────┘
"""
        preview = ASCIIPreview(
            ascii_content=ascii_art.strip(),
            width=25,
            height=8,
        )
        assert preview.type == "ascii"
        assert "Sales by Region" in preview.ascii_content
        assert preview.width == 25
        assert preview.height == 8

    @pytest.mark.asyncio
    async def test_table_preview_structure(self):
        """Test TablePreview response structure."""
        # Markdown-style table; assertions only check substrings and counts.
        table_content = """
| Region | Sales | Profit |
|--------|--------|--------|
| North | 45000 | 12000 |
| South | 30000 | 8000 |
| East | 20000 | 5000 |
| West | 5000 | 1000 |
"""
        preview = TablePreview(
            table_data=table_content.strip(),
            row_count=4,
            supports_sorting=True,
        )
        assert preview.type == "table"
        assert "Region" in preview.table_data
        assert "North" in preview.table_data
        assert preview.row_count == 4
        assert preview.supports_sorting is True

    @pytest.mark.asyncio
    async def test_chart_preview_response_structure(self):
        """Test the expected response structure for chart preview."""
        # Core fields that should always be present
        _ = [
            "chart_id",
            "chart_name",
            "chart_type",
            "explore_url",
            "content",  # Union of URLPreview | ASCIIPreview | TablePreview
            "chart_description",
            "accessibility",
            "performance",
        ]

        # Additional fields that may be present for backward compatibility
        _ = [
            "format",
            "preview_url",
            "ascii_chart",
            "table_data",
            "width",
            "height",
            "schema_version",
            "api_version",
        ]

        # This is a structural test - actual integration tests would verify
        # the tool returns data matching this structure

    @pytest.mark.asyncio
    async def test_preview_dimensions(self):
        """Test preview dimensions in response."""
        # Standard dimensions that may appear in preview responses
        standard_sizes = [
            (800, 600),  # Default
            (1200, 800),  # Large
            (400, 300),  # Small
            (1920, 1080),  # Full HD
        ]

        for width, height in standard_sizes:
            # URL preview with dimensions
            url_preview = URLPreview(
                preview_url="http://example.com/chart.png",
                width=width,
                height=height,
                supports_interaction=False,
            )
            assert url_preview.width == width
            assert url_preview.height == height

    @pytest.mark.asyncio
    async def test_error_response_structures(self):
        """Test error response structures."""
        # These dicts document the expected error payload shape; the asserts
        # are self-referential and serve as executable documentation only.
        error_response = {
            "error_type": "not_found",
            "message": "Chart not found",
            "details": "No chart found with ID 999",
            "chart_id": 999,
        }
        assert error_response["error_type"] == "not_found"
        assert error_response["chart_id"] == 999

        # Preview generation error structure
        preview_error = {
            "error_type": "preview_error",
            "message": "Failed to generate preview",
            "details": "Screenshot service unavailable",
        }
        assert preview_error["error_type"] == "preview_error"

    @pytest.mark.asyncio
    async def test_accessibility_metadata(self):
        """Test accessibility metadata structure."""
        from superset.mcp_service.chart.schemas import AccessibilityMetadata

        metadata = AccessibilityMetadata(
            color_blind_safe=True,
            alt_text="Bar chart showing sales by region",
            high_contrast_available=False,
        )
        assert metadata.color_blind_safe is True
        assert "sales by region" in metadata.alt_text
        assert metadata.high_contrast_available is False

    @pytest.mark.asyncio
    async def test_performance_metadata(self):
        """Test performance metadata structure."""
        from superset.mcp_service.chart.schemas import PerformanceMetadata

        metadata = PerformanceMetadata(
            query_duration_ms=150,
            cache_status="hit",
            optimization_suggestions=["Consider adding an index on date column"],
        )
        assert metadata.query_duration_ms == 150
        assert metadata.cache_status == "hit"
        assert len(metadata.optimization_suggestions) == 1

    @pytest.mark.asyncio
    async def test_chart_types_support(self):
        """Test that various chart types are supported."""
        chart_types = [
            "echarts_timeseries_line",
            "echarts_timeseries_bar",
            "echarts_area",
            "echarts_timeseries_scatter",
            "table",
            "pie",
            "big_number",
            "big_number_total",
            "pivot_table_v2",
            "dist_bar",
            "box_plot",
        ]

        # All chart types should be previewable
        for _chart_type in chart_types:
            # This would be tested in integration tests
            pass

    @pytest.mark.asyncio
    async def test_ascii_art_variations(self):
        """Test ASCII art generation for different chart types.

        The literals below are illustrative only (assigned to ``_``); nothing
        is asserted against them.
        """
        # Line chart ASCII
        _ = """
Sales Trend
│
│ ╱╲
│ ╱ ╲
│ ╱ ╲
│ ╱ ╲
│ ╱ ╲
└────────────
Jan Feb Mar
"""

        # Bar chart ASCII
        _ = """
Sales by Region
│
│ ████ North
│ ███ South
│ ██ East
│ █ West
└────────────
"""

        # Pie chart ASCII
        _ = """
Market Share
╭─────╮
╱ ╲
│ 45% │
│ North │
╰─────────╯
"""

        # These demonstrate the expected ASCII formats for different chart types
|
||||
385
tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
Normal file
385
tests/unit_tests/mcp_service/chart/tool/test_update_chart.py
Normal file
@@ -0,0 +1,385 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for update_chart MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
UpdateChartRequest,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateChart:
    """Tests for update_chart MCP tool.

    Schema-level unit tests for UpdateChartRequest and the chart-config
    models it embeds; the tool itself is exercised in integration tests.
    """

    # NOTE(review): asyncio marks are kept for suite consistency even though
    # these bodies contain no awaits.
    @pytest.mark.asyncio
    async def test_update_chart_request_structure(self):
        """Test that chart update request structures are properly formed."""
        # Table chart update with numeric ID
        table_config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region", label="Region"),
                ColumnRef(name="sales", label="Sales", aggregate="SUM"),
            ],
            filters=[FilterConfig(column="year", op="=", value="2024")],
            sort_by=["sales"],
        )
        table_request = UpdateChartRequest(identifier=123, config=table_config)
        assert table_request.identifier == 123
        assert table_request.config.chart_type == "table"
        assert len(table_request.config.columns) == 2
        assert table_request.config.columns[0].name == "region"
        assert table_request.config.columns[1].aggregate == "SUM"

        # XY chart update with UUID
        xy_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
            group_by=ColumnRef(name="category"),
            x_axis=AxisConfig(title="Date", format="smart_date"),
            y_axis=AxisConfig(title="Value", format="$,.2f"),
            legend=LegendConfig(show=True, position="top"),
        )
        xy_request = UpdateChartRequest(
            identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
        )
        assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
        assert xy_request.config.chart_type == "xy"
        assert xy_request.config.x.name == "date"
        assert xy_request.config.y[0].aggregate == "SUM"
        assert xy_request.config.kind == "line"

    @pytest.mark.asyncio
    async def test_update_chart_with_chart_name(self):
        """Test updating chart with custom chart name."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Without custom name
        request1 = UpdateChartRequest(identifier=123, config=config)
        assert request1.chart_name is None

        # With custom name
        request2 = UpdateChartRequest(
            identifier=123, config=config, chart_name="Updated Sales Report"
        )
        assert request2.chart_name == "Updated Sales Report"

    @pytest.mark.asyncio
    async def test_update_chart_preview_generation(self):
        """Test preview generation options in update request."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Default preview generation
        request1 = UpdateChartRequest(identifier=123, config=config)
        assert request1.generate_preview is True
        assert request1.preview_formats == ["url"]

        # Custom preview formats
        request2 = UpdateChartRequest(
            identifier=123,
            config=config,
            generate_preview=True,
            preview_formats=["url", "ascii", "table"],
        )
        assert request2.generate_preview is True
        assert set(request2.preview_formats) == {"url", "ascii", "table"}

        # Disable preview generation
        request3 = UpdateChartRequest(
            identifier=123, config=config, generate_preview=False
        )
        assert request3.generate_preview is False

    @pytest.mark.asyncio
    async def test_update_chart_identifier_types(self):
        """Test that identifier can be int or string (UUID)."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Integer ID
        request1 = UpdateChartRequest(identifier=123, config=config)
        assert request1.identifier == 123
        assert isinstance(request1.identifier, int)

        # String numeric ID — note it is NOT coerced to int.
        request2 = UpdateChartRequest(identifier="456", config=config)
        assert request2.identifier == "456"
        assert isinstance(request2.identifier, str)

        # UUID string
        request3 = UpdateChartRequest(
            identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=config
        )
        assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
        assert isinstance(request3.identifier, str)

    @pytest.mark.asyncio
    async def test_update_chart_config_variations(self):
        """Test various chart configuration options in updates."""
        # Test all XY chart types
        chart_types = ["line", "bar", "area", "scatter"]
        for chart_type in chart_types:
            config = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="x_col"),
                y=[ColumnRef(name="y_col")],
                kind=chart_type,
            )
            request = UpdateChartRequest(identifier=1, config=config)
            assert request.config.kind == chart_type

        # Test multiple Y columns
        multi_y_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM"),
                ColumnRef(name="profit", aggregate="AVG"),
                ColumnRef(name="quantity", aggregate="COUNT"),
            ],
            kind="line",
        )
        request = UpdateChartRequest(identifier=1, config=multi_y_config)
        assert len(request.config.y) == 3
        assert request.config.y[1].aggregate == "AVG"

        # Test filter operators
        operators = ["=", "!=", ">", ">=", "<", "<="]
        filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
        table_config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
            filters=filters,
        )
        request = UpdateChartRequest(identifier=1, config=table_config)
        assert len(request.config.filters) == 6

    @pytest.mark.asyncio
    async def test_update_chart_response_structure(self):
        """Test the expected response structure for chart updates.

        Purely structural documentation — the asserts check the literal dict
        built here, not a real tool response.
        """
        # The response should contain these fields
        expected_response = {
            "chart": {
                "id": int,
                "slice_name": str,
                "viz_type": str,
                "url": str,
                "uuid": str,
                "updated": bool,
            },
            "error": None,
            "success": bool,
            "schema_version": str,
            "api_version": str,
        }

        # When chart update succeeds, these additional fields may be present
        optional_fields = [
            "previews",
            "capabilities",
            "semantics",
            "explore_url",
            "api_endpoints",
            "performance",
            "accessibility",
        ]

        # Validate structure expectations
        assert "chart" in expected_response
        assert "success" in expected_response
        assert len(optional_fields) > 0

    @pytest.mark.asyncio
    async def test_update_chart_axis_configurations(self):
        """Test axis configuration updates."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="sales")],
            x_axis=AxisConfig(
                title="Date",
                format="smart_date",
                scale="linear",
            ),
            y_axis=AxisConfig(
                title="Sales Amount",
                format="$,.2f",
                scale="log",
            ),
        )
        request = UpdateChartRequest(identifier=1, config=config)
        assert request.config.x_axis.title == "Date"
        assert request.config.x_axis.format == "smart_date"
        assert request.config.x_axis.scale == "linear"
        assert request.config.y_axis.title == "Sales Amount"
        assert request.config.y_axis.format == "$,.2f"
        assert request.config.y_axis.scale == "log"

    @pytest.mark.asyncio
    async def test_update_chart_legend_configurations(self):
        """Test legend configuration updates."""
        positions = ["top", "bottom", "left", "right"]
        for pos in positions:
            config = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="x"),
                y=[ColumnRef(name="y")],
                legend=LegendConfig(show=True, position=pos),
            )
            request = UpdateChartRequest(identifier=1, config=config)
            assert request.config.legend.position == pos
            assert request.config.legend.show is True

        # Hidden legend
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="x"),
            y=[ColumnRef(name="y")],
            legend=LegendConfig(show=False),
        )
        request = UpdateChartRequest(identifier=1, config=config)
        assert request.config.legend.show is False

    @pytest.mark.asyncio
    async def test_update_chart_aggregation_functions(self):
        """Test all supported aggregation functions in updates."""
        aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
        for agg in aggs:
            config = TableChartConfig(
                chart_type="table",
                columns=[ColumnRef(name="value", aggregate=agg)],
            )
            request = UpdateChartRequest(identifier=1, config=config)
            assert request.config.columns[0].aggregate == agg

    @pytest.mark.asyncio
    async def test_update_chart_error_responses(self):
        """Test expected error response structures."""
        # Chart not found error — documents the payload shape via literals.
        error_response = {
            "chart": None,
            "error": "No chart found with identifier: 999",
            "success": False,
            "schema_version": "2.0",
            "api_version": "v1",
        }
        assert error_response["success"] is False
        assert error_response["chart"] is None
        assert "chart found" in error_response["error"].lower()

        # General update error
        update_error = {
            "chart": None,
            "error": "Chart update failed: Permission denied",
            "success": False,
            "schema_version": "2.0",
            "api_version": "v1",
        }
        assert update_error["success"] is False
        assert "failed" in update_error["error"].lower()

    @pytest.mark.asyncio
    async def test_chart_name_sanitization(self):
        """Test that chart names are properly sanitized."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Test with potentially problematic characters
        test_names = [
            "Normal Chart Name",
            "Chart with 'quotes'",
            'Chart with "double quotes"',
            "Chart with <brackets>",
        ]

        for name in test_names:
            request = UpdateChartRequest(identifier=1, config=config, chart_name=name)
            # Chart name should be set (sanitization happens in the validator)
            assert request.chart_name is not None

    @pytest.mark.asyncio
    async def test_update_chart_with_filters(self):
        """Test updating chart with various filter configurations."""
        filters = [
            FilterConfig(column="region", op="=", value="North"),
            FilterConfig(column="sales", op=">=", value=1000),
            FilterConfig(column="date", op=">", value="2024-01-01"),
        ]

        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region"),
                ColumnRef(name="sales"),
                ColumnRef(name="date"),
            ],
            filters=filters,
        )

        request = UpdateChartRequest(identifier=1, config=config)
        assert len(request.config.filters) == 3
        assert request.config.filters[0].column == "region"
        assert request.config.filters[1].op == ">="
        assert request.config.filters[2].value == "2024-01-01"

    @pytest.mark.asyncio
    async def test_update_chart_cache_control(self):
        """Test cache control parameters in update request."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Default cache settings
        request1 = UpdateChartRequest(identifier=1, config=config)
        assert request1.use_cache is True
        assert request1.force_refresh is False
        assert request1.cache_timeout is None

        # Custom cache settings
        request2 = UpdateChartRequest(
            identifier=1,
            config=config,
            use_cache=False,
            force_refresh=True,
            cache_timeout=300,
        )
        assert request2.use_cache is False
        assert request2.force_refresh is True
        assert request2.cache_timeout == 300
|
||||
@@ -0,0 +1,474 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for update_chart_preview MCP tool
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
UpdateChartPreviewRequest,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateChartPreview:
|
||||
"""Tests for update_chart_preview MCP tool."""
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_update_chart_preview_request_structure(self):
        """Test that chart preview update request structures are properly formed."""
        # Table chart preview update
        table_config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region", label="Region"),
                ColumnRef(name="sales", label="Sales", aggregate="SUM"),
            ],
            filters=[FilterConfig(column="year", op="=", value="2024")],
            sort_by=["sales"],
        )
        # Previews are keyed by form_data_key rather than a saved chart ID.
        table_request = UpdateChartPreviewRequest(
            form_data_key="abc123def456", dataset_id=1, config=table_config
        )
        assert table_request.form_data_key == "abc123def456"
        assert table_request.dataset_id == 1
        assert table_request.config.chart_type == "table"
        assert len(table_request.config.columns) == 2
        assert table_request.config.columns[0].name == "region"

        # XY chart preview update
        xy_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
            group_by=ColumnRef(name="category"),
            x_axis=AxisConfig(title="Date", format="smart_date"),
            y_axis=AxisConfig(title="Value", format="$,.2f"),
            legend=LegendConfig(show=True, position="top"),
        )
        xy_request = UpdateChartPreviewRequest(
            form_data_key="xyz789ghi012", dataset_id="2", config=xy_config
        )
        assert xy_request.form_data_key == "xyz789ghi012"
        assert xy_request.dataset_id == "2"
        assert xy_request.config.chart_type == "xy"
        assert xy_request.config.x.name == "date"
        assert xy_request.config.kind == "line"
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_update_chart_preview_dataset_id_types(self):
        """Test that dataset_id can be int or string (UUID)."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
        )

        # Integer dataset_id
        request1 = UpdateChartPreviewRequest(
            form_data_key="abc123", dataset_id=123, config=config
        )
        assert request1.dataset_id == 123
        assert isinstance(request1.dataset_id, int)

        # String numeric dataset_id — preserved as a string, not coerced.
        request2 = UpdateChartPreviewRequest(
            form_data_key="abc123", dataset_id="456", config=config
        )
        assert request2.dataset_id == "456"
        assert isinstance(request2.dataset_id, str)

        # UUID string dataset_id
        request3 = UpdateChartPreviewRequest(
            form_data_key="abc123",
            dataset_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
            config=config,
        )
        assert request3.dataset_id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
        assert isinstance(request3.dataset_id, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_generation_options(self):
|
||||
"""Test preview generation options in update preview request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default preview generation
|
||||
request1 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request1.generate_preview is True
|
||||
assert request1.preview_formats == ["url"]
|
||||
|
||||
# Custom preview formats
|
||||
request2 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
generate_preview=True,
|
||||
preview_formats=["url", "ascii", "table"],
|
||||
)
|
||||
assert request2.generate_preview is True
|
||||
assert set(request2.preview_formats) == {"url", "ascii", "table"}
|
||||
|
||||
# Disable preview generation
|
||||
request3 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
generate_preview=False,
|
||||
)
|
||||
assert request3.generate_preview is False
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_update_chart_preview_config_variations(self):
        """Test various chart configuration options in preview updates."""
        # Test all XY chart types
        chart_types = ["line", "bar", "area", "scatter"]
        for chart_type in chart_types:
            config = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="x_col"),
                y=[ColumnRef(name="y_col")],
                kind=chart_type,
            )
            request = UpdateChartPreviewRequest(
                form_data_key="abc123", dataset_id=1, config=config
            )
            assert request.config.kind == chart_type

        # Test multiple Y columns
        multi_y_config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM"),
                ColumnRef(name="profit", aggregate="AVG"),
                ColumnRef(name="quantity", aggregate="COUNT"),
            ],
            kind="line",
        )
        request = UpdateChartPreviewRequest(
            form_data_key="abc123", dataset_id=1, config=multi_y_config
        )
        assert len(request.config.y) == 3
        assert request.config.y[1].aggregate == "AVG"

        # Test filter operators
        operators = ["=", "!=", ">", ">=", "<", "<="]
        filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
        table_config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="col1")],
            filters=filters,
        )
        request = UpdateChartPreviewRequest(
            form_data_key="abc123", dataset_id=1, config=table_config
        )
        assert len(request.config.filters) == 6
|
||||
|
||||
    @pytest.mark.asyncio
    async def test_update_chart_preview_response_structure(self):
        """Test the expected response structure for chart preview updates.

        Structural documentation only: the asserts check the literal dict
        built here, not a real tool response.
        """
        # The response should contain these fields
        expected_response = {
            "chart": {
                "id": None,  # No ID for unsaved previews
                "slice_name": str,
                "viz_type": str,
                "url": str,
                "uuid": None,  # No UUID for unsaved previews
                "saved": bool,
                "updated": bool,
            },
            "error": None,
            "success": bool,
            "schema_version": str,
            "api_version": str,
        }

        # When preview update succeeds, these additional fields may be present
        optional_fields = [
            "previews",
            "capabilities",
            "semantics",
            "explore_url",
            "form_data_key",
            "previous_form_data_key",
            "api_endpoints",
            "performance",
            "accessibility",
        ]

        # Validate structure expectations
        assert "chart" in expected_response
        assert "success" in expected_response
        assert len(optional_fields) > 0
        assert expected_response["chart"]["id"] is None
        assert expected_response["chart"]["uuid"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_axis_configurations(self):
|
||||
"""Test axis configuration updates in preview."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales")],
|
||||
x_axis=AxisConfig(
|
||||
title="Date",
|
||||
format="smart_date",
|
||||
scale="linear",
|
||||
),
|
||||
y_axis=AxisConfig(
|
||||
title="Sales Amount",
|
||||
format="$,.2f",
|
||||
scale="log",
|
||||
),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.x_axis.title == "Date"
|
||||
assert request.config.x_axis.format == "smart_date"
|
||||
assert request.config.y_axis.title == "Sales Amount"
|
||||
assert request.config.y_axis.format == "$,.2f"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_legend_configurations(self):
|
||||
"""Test legend configuration updates in preview."""
|
||||
positions = ["top", "bottom", "left", "right"]
|
||||
for pos in positions:
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=True, position=pos),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.position == pos
|
||||
assert request.config.legend.show is True
|
||||
|
||||
# Hidden legend
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
legend=LegendConfig(show=False),
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.show is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_aggregation_functions(self):
|
||||
"""Test all supported aggregation functions in preview updates."""
|
||||
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
|
||||
for agg in aggs:
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="value", aggregate=agg)],
|
||||
)
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.columns[0].aggregate == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_error_responses(self):
|
||||
"""Test expected error response structures for preview updates."""
|
||||
# General update error
|
||||
error_response = {
|
||||
"chart": None,
|
||||
"error": "Chart preview update failed: Invalid form_data_key",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert error_response["success"] is False
|
||||
assert error_response["chart"] is None
|
||||
assert "failed" in error_response["error"].lower()
|
||||
|
||||
# Missing dataset error
|
||||
dataset_error = {
|
||||
"chart": None,
|
||||
"error": "Chart preview update failed: Dataset not found",
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
assert dataset_error["success"] is False
|
||||
assert "dataset" in dataset_error["error"].lower()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_with_filters(self):
|
||||
"""Test updating preview with various filter configurations."""
|
||||
filters = [
|
||||
FilterConfig(column="region", op="=", value="North"),
|
||||
FilterConfig(column="sales", op=">=", value=1000),
|
||||
FilterConfig(column="date", op=">", value="2024-01-01"),
|
||||
]
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales"),
|
||||
ColumnRef(name="date"),
|
||||
],
|
||||
filters=filters,
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.filters) == 3
|
||||
assert request.config.filters[0].column == "region"
|
||||
assert request.config.filters[1].op == ">="
|
||||
assert request.config.filters[2].value == "2024-01-01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_form_data_key_handling(self):
|
||||
"""Test form_data_key handling in preview updates."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Various form_data_key formats
|
||||
form_data_keys = [
|
||||
"abc123def456",
|
||||
"xyz-789-ghi-012",
|
||||
"key_with_underscore",
|
||||
"UPPERCASE_KEY",
|
||||
]
|
||||
|
||||
for key in form_data_keys:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key=key, dataset_id=1, config=config
|
||||
)
|
||||
assert request.form_data_key == key
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_cache_control(self):
|
||||
"""Test cache control parameters in update preview request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default cache settings
|
||||
request1 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request1.use_cache is True
|
||||
assert request1.force_refresh is False
|
||||
assert request1.cache_form_data is True
|
||||
|
||||
# Custom cache settings
|
||||
request2 = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123",
|
||||
dataset_id=1,
|
||||
config=config,
|
||||
use_cache=False,
|
||||
force_refresh=True,
|
||||
cache_form_data=False,
|
||||
)
|
||||
assert request2.use_cache is False
|
||||
assert request2.force_refresh is True
|
||||
assert request2.cache_form_data is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_no_save_behavior(self):
|
||||
"""Test that preview updates don't create permanent charts."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
|
||||
# Preview updates should never create permanent charts
|
||||
# This is validated by checking the response structure
|
||||
expected_unsaved_fields = {
|
||||
"id": None, # No chart ID
|
||||
"uuid": None, # No UUID
|
||||
"saved": False, # Not saved
|
||||
}
|
||||
|
||||
# These expectations are validated in the response, not the request
|
||||
assert request.form_data_key == "abc123"
|
||||
assert expected_unsaved_fields["id"] is None
|
||||
assert expected_unsaved_fields["uuid"] is None
|
||||
assert expected_unsaved_fields["saved"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_multiple_y_columns(self):
|
||||
"""Test preview updates with multiple Y-axis columns."""
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
|
||||
ColumnRef(name="cost", aggregate="SUM", label="Total Cost"),
|
||||
ColumnRef(name="profit", aggregate="SUM", label="Total Profit"),
|
||||
ColumnRef(name="orders", aggregate="COUNT", label="Order Count"),
|
||||
],
|
||||
kind="line",
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.y) == 4
|
||||
assert request.config.y[0].name == "revenue"
|
||||
assert request.config.y[1].name == "cost"
|
||||
assert request.config.y[2].name == "profit"
|
||||
assert request.config.y[3].name == "orders"
|
||||
assert request.config.y[3].aggregate == "COUNT"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_table_sorting(self):
|
||||
"""Test table chart sorting in preview updates."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales", aggregate="SUM"),
|
||||
ColumnRef(name="profit", aggregate="AVG"),
|
||||
],
|
||||
sort_by=["sales", "profit"],
|
||||
)
|
||||
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.sort_by == ["sales", "profit"]
|
||||
assert len(request.config.columns) == 3
|
||||
23
tests/unit_tests/mcp_service/conftest.py
Normal file
23
tests/unit_tests/mcp_service/conftest.py
Normal file
@@ -0,0 +1,23 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP service test configuration.
|
||||
|
||||
Tool imports are handled by app.py, not here.
|
||||
This conftest is empty to prevent test pollution.
|
||||
"""
|
||||
16
tests/unit_tests/mcp_service/dashboard/__init__.py
Normal file
16
tests/unit_tests/mcp_service/dashboard/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/dashboard/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/dashboard/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,450 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for dashboard generation MCP tools
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the module-level MCP server instance to the tests below."""
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch request authentication so every test runs as an admin user."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        fake_admin = Mock()
        fake_admin.id = 1
        fake_admin.username = "admin"
        mock_get_user.return_value = fake_admin
        yield mock_get_user
|
||||
|
||||
|
||||
def _mock_chart(id: int = 1, slice_name: str = "Test Chart") -> Mock:
|
||||
"""Create a mock chart object."""
|
||||
chart = Mock()
|
||||
chart.id = id
|
||||
chart.slice_name = slice_name
|
||||
chart.uuid = f"chart-uuid-{id}"
|
||||
return chart
|
||||
|
||||
|
||||
def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
|
||||
"""Create a mock dashboard object."""
|
||||
dashboard = Mock()
|
||||
dashboard.id = id
|
||||
dashboard.dashboard_title = title
|
||||
dashboard.slug = f"test-dashboard-{id}"
|
||||
dashboard.description = "Test dashboard description"
|
||||
dashboard.published = True
|
||||
dashboard.created_on = "2024-01-01"
|
||||
dashboard.changed_on = "2024-01-01"
|
||||
dashboard.created_by = Mock()
|
||||
dashboard.created_by.username = "test_user"
|
||||
dashboard.changed_by = Mock()
|
||||
dashboard.changed_by.username = "test_user"
|
||||
dashboard.uuid = f"dashboard-uuid-{id}"
|
||||
dashboard.slices = []
|
||||
dashboard.owners = [] # Add missing owners attribute
|
||||
dashboard.tags = [] # Add missing tags attribute
|
||||
return dashboard
|
||||
|
||||
|
||||
class TestGenerateDashboard:
|
||||
"""Tests for generate_dashboard MCP tool."""
|
||||
|
||||
    @patch("superset.commands.dashboard.create.CreateDashboardCommand")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_basic(
        self, mock_db_session, mock_create_command, mcp_server
    ):
        """Test basic dashboard generation with valid charts."""
        # Mock database query for charts
        # (db.session.query(...).filter(...).all() returns both requested charts)
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [
            _mock_chart(id=1, slice_name="Sales Chart"),
            _mock_chart(id=2, slice_name="Revenue Chart"),
        ]
        mock_db_session.query.return_value = mock_query

        # Mock dashboard creation
        # CreateDashboardCommand(...).run() returns the new dashboard object.
        mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
        mock_create_command.return_value.run.return_value = mock_dashboard

        request = {
            "chart_ids": [1, 2],
            "dashboard_title": "Analytics Dashboard",
            "description": "Dashboard for analytics",
            "published": True,
        }

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        # Tool should succeed and echo the created dashboard's metadata.
        assert result.data.error is None
        assert result.data.dashboard is not None
        assert result.data.dashboard.id == 10
        assert result.data.dashboard.dashboard_title == "Analytics Dashboard"
        assert result.data.dashboard.chart_count == 2
        assert "/superset/dashboard/10/" in result.data.dashboard_url
|
||||
|
||||
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_missing_charts(self, mock_db_session, mcp_server):
        """Test error handling when some charts don't exist."""
        # Mock database query returning only chart 1 (chart 2 missing)
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [
            _mock_chart(id=1),
            # Chart 2 is missing from the result
        ]
        mock_db_session.query.return_value = mock_query

        request = {"chart_ids": [1, 2], "dashboard_title": "Test Dashboard"}

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        # Missing chart IDs should be reported and no dashboard created.
        assert result.data.error is not None
        assert "Charts not found: [2]" in result.data.error
        assert result.data.dashboard is None
        assert result.data.dashboard_url is None
|
||||
|
||||
    @patch("superset.commands.dashboard.create.CreateDashboardCommand")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_single_chart(
        self, mock_db_session, mock_create_command, mcp_server
    ):
        """Test dashboard generation with a single chart."""
        # Mock database query for single chart
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")]
        mock_db_session.query.return_value = mock_query

        mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
        mock_create_command.return_value.run.return_value = mock_dashboard

        request = {
            "chart_ids": [5],
            "dashboard_title": "Single Chart Dashboard",
            "published": False,
        }

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        assert result.data.error is None
        assert result.data.dashboard.chart_count == 1
        # NOTE(review): the request asked for published=False, but the response
        # reflects the mocked dashboard object (published=True) — the assertion
        # below checks the serialization path, not the request flag.
        assert result.data.dashboard.published is True  # From mock
|
||||
|
||||
    @patch("superset.commands.dashboard.create.CreateDashboardCommand")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_many_charts(
        self, mock_db_session, mock_create_command, mcp_server
    ):
        """Test dashboard generation with many charts (grid layout)."""
        # Mock 6 charts
        chart_ids = list(range(1, 7))
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [
            _mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids
        ]
        mock_db_session.query.return_value = mock_query

        mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard")
        mock_create_command.return_value.run.return_value = mock_dashboard

        request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        assert result.data.error is None
        assert result.data.dashboard.chart_count == 6

        # Verify CreateDashboardCommand was called with proper layout
        # (first positional argument is the command's data payload)
        mock_create_command.assert_called_once()
        call_args = mock_create_command.call_args[0][0]

        # Check position_json contains proper layout
        position_json = json.loads(call_args["position_json"])
        assert "ROOT_ID" in position_json
        assert "GRID_ID" in position_json
        assert "DASHBOARD_VERSION_KEY" in position_json
        assert position_json["DASHBOARD_VERSION_KEY"] == "v2"

        # ROOT should only contain GRID
        assert position_json["ROOT_ID"]["children"] == ["GRID_ID"]

        # GRID should contain rows (6 charts = 3 rows in 2-chart layout)
        grid_children = position_json["GRID_ID"]["children"]
        assert len(grid_children) == 3

        # Check each chart has proper structure
        for i, chart_id in enumerate(chart_ids):
            chart_key = f"CHART-{chart_id}"
            row_index = i // 2  # 2 charts per row
            row_key = f"ROW-{row_index}"

            # Chart should exist
            assert chart_key in position_json
            chart_data = position_json[chart_key]
            assert chart_data["type"] == "CHART"
            assert "meta" in chart_data
            assert chart_data["meta"]["chartId"] == chart_id

            # Row should exist and contain charts (up to 2 per row)
            assert row_key in position_json
            row_data = position_json[row_key]
            assert row_data["type"] == "ROW"
            assert chart_key in row_data["children"]
|
||||
|
||||
    @patch("superset.commands.dashboard.create.CreateDashboardCommand")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_creation_failure(
        self, mock_db_session, mock_create_command, mcp_server
    ):
        """Test error handling when dashboard creation fails."""
        # All requested charts exist, but the create command itself raises.
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [_mock_chart(id=1)]
        mock_db_session.query.return_value = mock_query
        mock_create_command.return_value.run.side_effect = Exception("Creation failed")

        request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        # The tool wraps the exception in an error response instead of raising.
        assert result.data.error is not None
        assert "Failed to create dashboard" in result.data.error
        assert result.data.dashboard is None
|
||||
|
||||
    @patch("superset.commands.dashboard.create.CreateDashboardCommand")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_generate_dashboard_minimal_request(
        self, mock_db_session, mock_create_command, mcp_server
    ):
        """Test dashboard generation with minimal required parameters."""
        # Mock database query for single chart
        mock_query = Mock()
        mock_filter = Mock()
        mock_query.filter.return_value = mock_filter
        mock_filter.all.return_value = [_mock_chart(id=3)]
        mock_db_session.query.return_value = mock_query

        mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard")
        mock_create_command.return_value.run.return_value = mock_dashboard

        request = {
            "chart_ids": [3],
            "dashboard_title": "Minimal Dashboard",
            # No description, published defaults to True
        }

        async with Client(mcp_server) as client:
            result = await client.call_tool("generate_dashboard", {"request": request})

        assert result.data.error is None
        assert result.data.dashboard.dashboard_title == "Minimal Dashboard"

        # Check that description was not included in call
        # (first positional argument is the command's data payload)
        call_args = mock_create_command.call_args[0][0]
        assert call_args["published"] is True  # Default value
        assert (
            "description" not in call_args or call_args.get("description") is None
        )
|
||||
|
||||
|
||||
class TestAddChartToExistingDashboard:
|
||||
"""Tests for add_chart_to_existing_dashboard MCP tool."""
|
||||
|
||||
    @patch("superset.commands.dashboard.update.UpdateDashboardCommand")
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_add_chart_to_dashboard_basic(
        self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
    ):
        """Test adding a chart to an existing dashboard."""
        # Mock existing dashboard with some charts
        mock_dashboard = _mock_dashboard(id=1, title="Existing Dashboard")
        mock_dashboard.slices = [Mock(id=10), Mock(id=20)]  # Existing charts
        # Pre-existing v2 layout: two charts already placed side by side.
        mock_dashboard.position_json = json.dumps(
            {
                "ROOT_ID": {
                    "children": ["CHART-10", "CHART-20"],
                    "id": "ROOT_ID",
                    "type": "ROOT",
                },
                "CHART-10": {"id": "CHART-10", "type": "CHART", "parents": ["ROOT_ID"]},
                "CHART-10_POSITION": {"h": 16, "w": 24, "x": 0, "y": 0},
                "CHART-20": {"id": "CHART-20", "type": "CHART", "parents": ["ROOT_ID"]},
                "CHART-20_POSITION": {"h": 16, "w": 24, "x": 24, "y": 0},
            }
        )
        mock_find_dashboard.return_value = mock_dashboard

        # Mock chart to add
        mock_chart = _mock_chart(id=30, slice_name="New Chart")
        mock_db_session.get.return_value = mock_chart

        # Mock updated dashboard
        updated_dashboard = _mock_dashboard(id=1, title="Existing Dashboard")
        updated_dashboard.slices = [Mock(id=10), Mock(id=20), Mock(id=30)]
        mock_update_command.return_value.run.return_value = updated_dashboard

        request = {"dashboard_id": 1, "chart_id": 30}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": request}
            )

        # Success: three charts total, plus placement info for the new one.
        assert result.data.error is None
        assert result.data.dashboard is not None
        assert result.data.dashboard.chart_count == 3
        assert result.data.position is not None
        assert "row" in result.data.position  # Should have row info
        assert "chart_key" in result.data.position
        assert "/superset/dashboard/1/" in result.data.dashboard_url
|
||||
|
||||
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_dashboard_not_found(self, mock_find_dashboard, mcp_server):
        """Test error when dashboard doesn't exist."""
        # DAO lookup misses -> tool should report a not-found error.
        mock_find_dashboard.return_value = None

        request = {"dashboard_id": 999, "chart_id": 1}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": request}
            )

        assert result.data.error is not None
        assert "Dashboard with ID 999 not found" in result.data.error
|
||||
|
||||
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_add_chart_chart_not_found(
        self, mock_db_session, mock_find_dashboard, mcp_server
    ):
        """Test error when chart doesn't exist."""
        # Dashboard exists, but the chart lookup (db.session.get) misses.
        mock_find_dashboard.return_value = _mock_dashboard()
        mock_db_session.get.return_value = None

        request = {"dashboard_id": 1, "chart_id": 999}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": request}
            )

        assert result.data.error is not None
        assert "Chart with ID 999 not found" in result.data.error
|
||||
|
||||
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_add_chart_already_in_dashboard(
        self, mock_db_session, mock_find_dashboard, mcp_server
    ):
        """Test error when chart is already in dashboard."""
        # The dashboard's slices already include chart 5.
        mock_dashboard = _mock_dashboard()
        mock_dashboard.slices = [Mock(id=5)]  # Chart 5 already exists
        mock_find_dashboard.return_value = mock_dashboard

        mock_db_session.get.return_value = _mock_chart(id=5)

        request = {"dashboard_id": 1, "chart_id": 5}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": request}
            )

        # Duplicate additions are rejected rather than silently ignored.
        assert result.data.error is not None
        assert "Chart 5 is already in dashboard 1" in result.data.error
|
||||
|
||||
    @patch("superset.commands.dashboard.update.UpdateDashboardCommand")
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.db.session")
    @pytest.mark.asyncio
    async def test_add_chart_empty_dashboard(
        self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
    ):
        """Test adding chart to dashboard with no existing layout."""
        mock_dashboard = _mock_dashboard(id=2)
        mock_dashboard.slices = []
        mock_dashboard.position_json = "{}"  # Empty layout
        mock_find_dashboard.return_value = mock_dashboard

        mock_chart = _mock_chart(id=15)
        mock_db_session.get.return_value = mock_chart

        updated_dashboard = _mock_dashboard(id=2)
        updated_dashboard.slices = [Mock(id=15)]
        mock_update_command.return_value.run.return_value = updated_dashboard

        request = {"dashboard_id": 2, "chart_id": 15}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": request}
            )

        assert result.data.error is None
        assert "row" in result.data.position  # Should have row info
        assert result.data.position.get("row") == 0  # First row

        # Verify update was called with proper layout structure
        # (second positional argument carries the update payload)
        call_args = mock_update_command.call_args[0][1]
        layout = json.loads(call_args["position_json"])
        assert "ROOT_ID" in layout
        assert "GRID_ID" in layout
        assert "ROW-0" in layout
        assert "CHART-15" in layout
|
||||
@@ -0,0 +1,573 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP dashboard tools (list_dashboards, get_dashboard_info)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.dashboard.schemas import (
|
||||
ListDashboardsRequest,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the module-level MCP application to each test."""
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch request authentication so every test runs as an admin user."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        fake_admin = Mock()
        fake_admin.id = 1
        fake_admin.username = "admin"
        mock_get_user.return_value = fake_admin
        yield mock_get_user
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_basic(mock_list, mcp_server):
    """list_dashboards returns the mocked dashboard with default columns.

    Also verifies that ``uuid`` and ``slug`` are part of the default
    column set, both as requested and as actually loaded.
    """
    # Single source of truth for the values mirrored into ``_mapping``;
    # previously each value was written twice (attribute + mapping), so
    # the two copies could silently drift apart.
    row = {
        "id": 1,
        "dashboard_title": "Test Dashboard",
        "slug": "test-dashboard",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
    }
    dashboard = Mock()
    for attr, value in row.items():
        setattr(dashboard, attr, value)
    # Attributes the serializer may touch that are not part of the row.
    dashboard.slices = []
    dashboard.description = None
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.uuid = "test-dashboard-uuid-1"
    dashboard.thumbnail_url = None
    dashboard.roles = []
    dashboard.charts = []
    # SQLAlchemy Row-style mapping used by the column-selection path.
    dashboard._mapping = {**row, "charts": []}
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(page=1, page_size=10)
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        dashboards = result.data.dashboards
        assert len(dashboards) == 1
        assert dashboards[0].dashboard_title == "Test Dashboard"
        assert dashboards[0].uuid == "test-dashboard-uuid-1"
        assert dashboards[0].slug == "test-dashboard"
        assert dashboards[0].published is True

        # Verify UUID and slug are in default columns
        assert "uuid" in result.data.columns_requested
        assert "slug" in result.data.columns_requested
        assert "uuid" in result.data.columns_loaded
        assert "slug" in result.data.columns_loaded
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_filters(mock_list, mcp_server):
    """Filters, column selection and ordering are accepted end to end."""
    # One dict feeds both the Mock attributes and ``_mapping`` so the two
    # representations of the row cannot drift apart (previously every
    # value was duplicated by hand).
    row = {
        "id": 1,
        "dashboard_title": "Filtered Dashboard",
        "slug": "filtered-dashboard",
        "url": "/dashboard/2",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
    }
    dashboard = Mock()
    for attr, value in row.items():
        setattr(dashboard, attr, value)
    # Non-row attributes the serializer may read.
    dashboard.slices = []
    dashboard.description = None
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.uuid = None
    dashboard.thumbnail_url = None
    dashboard.roles = []
    dashboard.charts = []
    dashboard._mapping = {**row, "charts": []}
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        filters = [
            {"col": "dashboard_title", "opr": "sw", "value": "Sales"},
            {"col": "published", "opr": "eq", "value": True},
        ]
        request = ListDashboardsRequest(
            filters=filters,
            select_columns=["id", "dashboard_title"],
            order_column="changed_on",
            order_direction="desc",
            page=1,
            page_size=50,
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        assert result.data.count == 1
        assert result.data.dashboards[0].dashboard_title == "Filtered Dashboard"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_string_filters(mock_list, mcp_server):
    """A JSON-encoded string is rejected for ``filters``; a list is required.

    Validation happens at the Pydantic schema level, before any tool call,
    so no MCP client round-trip is needed — the original version opened a
    client it never used (hence its ``# noqa: F841`` markers).
    """
    mock_list.return_value = ([], 0)
    filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]'

    with pytest.raises(ValueError, match="validation error"):
        ListDashboardsRequest(filters=filters)
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_api_error(mock_list, mcp_server):
    """DAO failures surface to the MCP client as a ToolError."""
    mock_list.side_effect = ToolError("API request failed")
    async with Client(mcp_server) as client:
        # Build the request outside the raises block so only the tool call
        # itself may raise (fixes the PT012 smell the original silenced).
        request = ListDashboardsRequest()
        with pytest.raises(ToolError) as excinfo:
            await client.call_tool(
                "list_dashboards", {"request": request.model_dump()}
            )
        assert "API request failed" in str(excinfo.value)
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_search(mock_list, mcp_server):
    """A search term is forwarded to the DAO with the searchable columns."""
    # One dict feeds both the Mock attributes and ``_mapping`` (the
    # original duplicated every value by hand, risking drift).
    row = {
        "id": 1,
        "dashboard_title": "search_dashboard",
        "slug": "search-dashboard",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
    }
    dashboard = Mock()
    for attr, value in row.items():
        setattr(dashboard, attr, value)
    # Non-row attributes the serializer may read.
    dashboard.slices = []
    dashboard.description = None
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.uuid = None
    dashboard.thumbnail_url = None
    dashboard.roles = []
    dashboard.charts = []
    dashboard._mapping = {**row, "charts": []}
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(search="search_dashboard")
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        assert result.data.count == 1
        assert result.data.dashboards[0].dashboard_title == "search_dashboard"
        # The search term and searchable columns must reach the DAO.
        args, kwargs = mock_list.call_args
        assert kwargs["search"] == "search_dashboard"
        assert "dashboard_title" in kwargs["search_columns"]
        assert "slug" in kwargs["search_columns"]
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_simple_filters(mock_list, mcp_server):
    """Plain equality filters are accepted and produce a counted response."""
    mock_list.return_value = ([], 0)
    simple_filters = [
        {"col": "dashboard_title", "opr": "eq", "value": "Sales"},
        {"col": "published", "opr": "eq", "value": True},
    ]
    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(filters=simple_filters)
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        assert hasattr(result.data, "count")
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_success(mock_info, mcp_server):
    """get_dashboard_info returns the serialized dashboard on a DAO hit."""
    dashboard = Mock()
    dashboard.id = 1
    dashboard.dashboard_title = "Test Dashboard"
    dashboard.slug = "test-dashboard"
    dashboard.description = "Test description"
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.published = True
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.created_on = None
    dashboard.changed_on = None
    dashboard.created_by = None
    dashboard.changed_by = None
    # FIX: these two were never assigned before being read in ``_mapping``
    # below, so Mock auto-created Mock objects where strings were expected.
    dashboard.changed_by_name = "admin"
    dashboard.created_by_name = "admin"
    dashboard.uuid = None
    dashboard.url = "/dashboard/1"
    dashboard.thumbnail_url = None
    dashboard.created_on_humanized = None
    dashboard.changed_on_humanized = None
    dashboard.slices = []
    dashboard.owners = []
    dashboard.tags = []
    dashboard.roles = []
    dashboard.charts = []
    dashboard._mapping = {
        "id": dashboard.id,
        "dashboard_title": dashboard.dashboard_title,
        "slug": dashboard.slug,
        "url": dashboard.url,
        "published": dashboard.published,
        "changed_by_name": dashboard.changed_by_name,
        "changed_on": dashboard.changed_on,
        "changed_on_humanized": dashboard.changed_on_humanized,
        "created_by_name": dashboard.created_by_name,
        "created_on": dashboard.created_on,
        "created_on_humanized": dashboard.created_on_humanized,
        "tags": dashboard.tags,
        "owners": dashboard.owners,
        "charts": [],
    }
    mock_info.return_value = dashboard  # Only the dashboard object
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": 1}}
        )
        assert result.data["dashboard_title"] == "Test Dashboard"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_not_found(mock_info, mcp_server):
    """An unknown dashboard id yields a structured not_found error."""
    mock_info.return_value = None  # DAO miss
    async with Client(mcp_server) as client:
        payload = {"request": {"identifier": 999}}
        result = await client.call_tool("get_dashboard_info", payload)
        assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_access_denied(mock_info, mcp_server):
    """Access-denied lookups behave like not-found: the DAO yields None."""
    mock_info.return_value = None
    async with Client(mcp_server) as client:
        payload = {"request": {"identifier": 1}}
        result = await client.call_tool("get_dashboard_info", payload)
        assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool
|
||||
|
||||
|
||||
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
    """Test getting dashboard info using UUID identifier.

    FIX: the previous placeholder ("c3d4e5f6-g7h8-...") contained non-hex
    characters, so it was not a parseable UUID and identifier resolution
    would have fallen back to slug handling; use a valid hex UUID so the
    UUID path is actually exercised.
    """
    uuid_str = "c3d4e5f6-a7b8-9012-cdef-ab3456789012"
    dashboard = Mock()
    dashboard.id = 1
    dashboard.dashboard_title = "Test Dashboard UUID"
    dashboard.slug = "test-dashboard-uuid"
    dashboard.description = "Test description"
    dashboard.css = ""
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = "{}"
    dashboard.position_json = "{}"
    dashboard.published = True
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.created_on = None
    dashboard.changed_on = None
    dashboard.created_by = None
    dashboard.changed_by = None
    dashboard.uuid = uuid_str
    dashboard.url = "/dashboard/1"
    dashboard.thumbnail_url = None
    dashboard.created_on_humanized = "2 days ago"
    dashboard.changed_on_humanized = "1 day ago"
    dashboard.slices = []
    dashboard.owners = []
    dashboard.tags = []
    dashboard.roles = []

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": uuid_str}}
        )
        assert result.data["dashboard_title"] == "Test Dashboard UUID"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
    """Test getting dashboard info using slug identifier.

    FIX: the previous uuid placeholder contained non-hex characters
    ("g", "h", "i") and was not a valid UUID; replaced with a parseable
    hex string so the fixture resembles real model data.
    """
    dashboard = Mock()
    dashboard.id = 2
    dashboard.dashboard_title = "Test Dashboard Slug"
    dashboard.slug = "test-dashboard-slug"
    dashboard.description = "Test description"
    dashboard.css = ""
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = "{}"
    dashboard.position_json = "{}"
    dashboard.published = True
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.created_on = None
    dashboard.changed_on = None
    dashboard.created_by = None
    dashboard.changed_by = None
    dashboard.uuid = "d4e5f6a7-b8c9-0123-defa-bc4567890123"
    dashboard.url = "/dashboard/2"
    dashboard.thumbnail_url = None
    dashboard.created_on_humanized = "2 days ago"
    dashboard.changed_on_humanized = "1 day ago"
    dashboard.slices = []
    dashboard.owners = []
    dashboard.tags = []
    dashboard.roles = []

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": "test-dashboard-slug"}}
        )
        assert result.data["dashboard_title"] == "Test Dashboard Slug"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
    """Custom column selection includes UUID and slug when explicitly
    requested."""
    # One dict feeds both the Mock attributes and ``_mapping`` so the two
    # copies of each value cannot drift apart (the original wrote every
    # value twice by hand).
    row = {
        "id": 1,
        "dashboard_title": "Custom Columns Dashboard",
        "slug": "custom-dashboard",
        "uuid": "test-custom-uuid-123",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
    }
    dashboard = Mock()
    for attr, value in row.items():
        setattr(dashboard, attr, value)
    # Non-row attributes the serializer may read.
    dashboard.slices = []
    dashboard.description = None
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.thumbnail_url = None
    dashboard.roles = []
    dashboard.charts = []
    dashboard._mapping = {**row, "charts": []}
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            select_columns=["id", "dashboard_title", "uuid", "slug"],
            page=1,
            page_size=10,
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        dashboards = result.data.dashboards
        assert len(dashboards) == 1
        assert dashboards[0].uuid == "test-custom-uuid-123"
        assert dashboards[0].slug == "custom-dashboard"

        # Verify custom columns include UUID and slug
        assert "uuid" in result.data.columns_requested
        assert "slug" in result.data.columns_requested
        assert "uuid" in result.data.columns_loaded
        assert "slug" in result.data.columns_loaded
|
||||
|
||||
|
||||
class TestDashboardSortableColumns:
    """Sortable-column configuration for the dashboard listing tool."""

    def test_dashboard_sortable_columns_definition(self):
        """The whitelist contains exactly the real, sortable model columns."""
        from superset.mcp_service.dashboard.tool.list_dashboards import (
            SORTABLE_DASHBOARD_COLUMNS,
        )

        expected = [
            "id",
            "dashboard_title",
            "slug",
            "published",
            "changed_on",
            "created_on",
        ]
        assert SORTABLE_DASHBOARD_COLUMNS == expected
        # Computed properties are not valid ORDER BY targets.
        for not_sortable in (
            "changed_on_delta_humanized",
            "changed_by_name",
            "uuid",
        ):
            assert not_sortable not in SORTABLE_DASHBOARD_COLUMNS

    @patch("superset.daos.dashboard.DashboardDAO.list")
    @pytest.mark.asyncio
    async def test_list_dashboards_with_valid_order_column(self, mock_list, mcp_server):
        """A valid order column is forwarded verbatim to the DAO."""
        mock_list.return_value = ([], 0)

        async with Client(mcp_server) as client:
            request = ListDashboardsRequest(
                order_column="dashboard_title", order_direction="desc"
            )
            result = await client.call_tool(
                "list_dashboards", {"request": request.model_dump()}
            )

            # Ordering parameters must reach the DAO unchanged.
            mock_list.assert_called_once()
            dao_kwargs = mock_list.call_args[1]
            assert dao_kwargs["order_column"] == "dashboard_title"
            assert dao_kwargs["order_direction"] == "desc"

            # Empty DAO response yields an empty, zero-count result.
            assert result.data.count == 0
            assert result.data.dashboards == []

    def test_sortable_columns_in_docstring(self):
        """Every sortable column is documented in the tool description."""
        from superset.mcp_service.dashboard.tool.list_dashboards import (
            list_dashboards,
            SORTABLE_DASHBOARD_COLUMNS,
        )

        # The @mcp.tool decorator stores the docstring as ``description``.
        assert hasattr(list_dashboards, "description")
        assert "Sortable columns for order_column:" in list_dashboards.description
        undocumented = [
            col
            for col in SORTABLE_DASHBOARD_COLUMNS
            if col not in list_dashboards.description
        ]
        assert not undocumented
|
||||
16
tests/unit_tests/mcp_service/dataset/__init__.py
Normal file
16
tests/unit_tests/mcp_service/dataset/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/dataset/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/dataset/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
1231
tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py
Normal file
1231
tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
16
tests/unit_tests/mcp_service/explore/__init__.py
Normal file
16
tests/unit_tests/mcp_service/explore/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/explore/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/explore/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,580 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Comprehensive unit tests for MCP generate_explore_link tool
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
GenerateExploreLinkRequest,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the module-level MCP application to each test."""
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch request authentication so every test runs as an admin user."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        fake_admin = Mock()
        fake_admin.id = 1
        fake_admin.username = "admin"
        mock_get_user.return_value = fake_admin
        yield mock_get_user
|
||||
|
||||
|
||||
def _mock_dataset(id: int = 1) -> Mock:
    """Build a minimal dataset stub exposing only an ``id``.

    NOTE(review): the parameter shadows the builtin ``id``; it is kept
    for keyword compatibility with existing callers (``_mock_dataset(id=5)``).
    """
    stub = Mock()
    stub.id = id
    return stub
|
||||
|
||||
|
||||
class TestGenerateExploreLink:
|
||||
"""Comprehensive tests for generate_explore_link MCP tool."""
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_table_explore_link_minimal(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for minimal table chart."""
|
||||
mock_create_form_data.return_value = "test_form_data_key_123"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=1)
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="region")]
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=test_form_data_key_123"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_table_explore_link_with_features(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for table chart with features."""
|
||||
mock_create_form_data.return_value = "comprehensive_key_456"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=5)
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region", label="Sales Region"),
|
||||
ColumnRef(name="revenue", label="Total Revenue", aggregate="SUM"),
|
||||
ColumnRef(name="orders", label="Order Count", aggregate="COUNT"),
|
||||
],
|
||||
filters=[
|
||||
FilterConfig(column="year", op="=", value="2024"),
|
||||
FilterConfig(column="status", op="!=", value="cancelled"),
|
||||
],
|
||||
sort_by=["revenue", "orders"],
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="5", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=comprehensive_key_456"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_line_chart_explore_link(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for line chart."""
|
||||
mock_create_form_data.return_value = "line_chart_key_789"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=3)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date", label="Date"),
|
||||
y=[
|
||||
ColumnRef(name="sales", label="Daily Sales", aggregate="SUM"),
|
||||
ColumnRef(name="orders", label="Order Count", aggregate="COUNT"),
|
||||
],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="region", label="Sales Region"),
|
||||
x_axis=AxisConfig(title="Time Period", format="smart_date"),
|
||||
y_axis=AxisConfig(title="Sales Metrics", format="$,.2f"),
|
||||
legend=LegendConfig(show=True, position="bottom"),
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="3", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=line_chart_key_789"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_bar_chart_explore_link(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for bar chart."""
|
||||
mock_create_form_data.return_value = "bar_chart_key_abc"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=7)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="product_category", label="Category"),
|
||||
y=[ColumnRef(name="revenue", label="Revenue", aggregate="SUM")],
|
||||
kind="bar",
|
||||
group_by=ColumnRef(name="quarter", label="Quarter"),
|
||||
y_axis=AxisConfig(title="Revenue ($)", format="$,.0f"),
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="7", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=bar_chart_key_abc"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_area_chart_explore_link(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for area chart."""
|
||||
mock_create_form_data.return_value = "area_chart_key_def"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=2)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="month", label="Month"),
|
||||
y=[
|
||||
ColumnRef(
|
||||
name="cumulative_sales", label="Cumulative Sales", aggregate="SUM"
|
||||
)
|
||||
],
|
||||
kind="area",
|
||||
legend=LegendConfig(show=False),
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="2", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=area_chart_key_def"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_scatter_chart_explore_link(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link for scatter chart."""
|
||||
mock_create_form_data.return_value = "scatter_chart_key_ghi"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=4)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="price", label="Unit Price"),
|
||||
y=[ColumnRef(name="quantity", label="Quantity Sold", aggregate="SUM")],
|
||||
kind="scatter",
|
||||
group_by=ColumnRef(name="product_type", label="Product Type"),
|
||||
x_axis=AxisConfig(title="Price ($)", format="$,.2f"),
|
||||
y_axis=AxisConfig(title="Quantity", format=",.0f"),
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="4", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=scatter_chart_key_ghi"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_cache_failure_fallback(
|
||||
self, mock_create_form_data, mcp_server
|
||||
):
|
||||
"""Test fallback when form_data cache creation fails."""
|
||||
mock_create_form_data.side_effect = Exception("Cache storage failed")
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="test_col")]
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
# Should fallback to basic URL format
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=1"
|
||||
)
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_database_lock_fallback(
|
||||
self, mock_create_form_data, mcp_server
|
||||
):
|
||||
"""Test fallback when database is locked."""
|
||||
from sqlalchemy.exc import OperationalError
|
||||
|
||||
mock_create_form_data.side_effect = OperationalError(
|
||||
"database is locked", None, None
|
||||
)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales")],
|
||||
kind="line",
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="5", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
# Should fallback to basic dataset URL
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=5"
|
||||
)
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_with_many_columns(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link with many columns."""
|
||||
mock_create_form_data.return_value = "many_columns_key"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=1)
|
||||
|
||||
# Create 15 columns
|
||||
columns = [
|
||||
ColumnRef(
|
||||
name=f"metric_{i}",
|
||||
label=f"Metric {i}",
|
||||
aggregate="SUM" if i % 2 == 0 else "COUNT",
|
||||
)
|
||||
for i in range(15)
|
||||
]
|
||||
|
||||
config = TableChartConfig(chart_type="table", columns=columns)
|
||||
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=many_columns_key"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_with_many_filters(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test generating explore link with many filters."""
|
||||
mock_create_form_data.return_value = "many_filters_key"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=1)
|
||||
|
||||
# Create 12 filters
|
||||
filters = [
|
||||
FilterConfig(
|
||||
column=f"filter_col_{i}",
|
||||
op="=" if i % 3 == 0 else "!=",
|
||||
value=f"value_{i}",
|
||||
)
|
||||
for i in range(12)
|
||||
]
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x_col"),
|
||||
y=[ColumnRef(name="y_col")],
|
||||
kind="bar",
|
||||
filters=filters,
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=many_filters_key"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_explore_link_url_format_consistency(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test that all generated URLs follow consistent format."""
|
||||
mock_create_form_data.return_value = "consistency_test_key"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=1)
|
||||
|
||||
configs = [
|
||||
TableChartConfig(chart_type="table", columns=[ColumnRef(name="col1")]),
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
kind="line",
|
||||
),
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
kind="bar",
|
||||
),
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
kind="area",
|
||||
),
|
||||
XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="x"),
|
||||
y=[ColumnRef(name="y")],
|
||||
kind="scatter",
|
||||
),
|
||||
]
|
||||
|
||||
for i, config in enumerate(configs):
|
||||
request = GenerateExploreLinkRequest(dataset_id=str(i + 1), config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
# All URLs should follow the same format
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=consistency_test_key"
|
||||
)
|
||||
assert result.data["error"] is None
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_dataset_id_types(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test explore link generation with different dataset_id formats."""
|
||||
mock_create_form_data.return_value = "dataset_test_key"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=1)
|
||||
|
||||
config = TableChartConfig(
|
||||
chart_type="table", columns=[ColumnRef(name="test_col")]
|
||||
)
|
||||
|
||||
# Test various dataset_id formats
|
||||
dataset_ids = ["1", "42", "999", "123456789"]
|
||||
|
||||
for dataset_id in dataset_ids:
|
||||
request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=dataset_test_key"
|
||||
)
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_explore_link_complex_configuration(
|
||||
self, mock_create_form_data, mock_find_dataset, mcp_server
|
||||
):
|
||||
"""Test explore link generation with complex chart configuration."""
|
||||
mock_create_form_data.return_value = "complex_config_key"
|
||||
mock_find_dataset.return_value = _mock_dataset(id=10)
|
||||
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="timestamp", label="Time"),
|
||||
y=[
|
||||
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
|
||||
ColumnRef(name="orders", label="Orders", aggregate="COUNT"),
|
||||
ColumnRef(name="profit", label="Profit", aggregate="AVG"),
|
||||
],
|
||||
kind="line",
|
||||
group_by=ColumnRef(name="region", label="Region"),
|
||||
x_axis=AxisConfig(title="Time Period", format="smart_date"),
|
||||
y_axis=AxisConfig(title="Metrics", format="$,.2f", scale="linear"),
|
||||
legend=LegendConfig(show=True, position="bottom"),
|
||||
filters=[
|
||||
FilterConfig(column="status", op="=", value="active"),
|
||||
FilterConfig(column="date", op=">=", value="2024-01-01"),
|
||||
FilterConfig(column="revenue", op=">", value="1000"),
|
||||
],
|
||||
)
|
||||
request = GenerateExploreLinkRequest(dataset_id="10", config=config)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
assert result.data["error"] is None
|
||||
assert (
|
||||
result.data["url"]
|
||||
== "http://localhost:9001/explore/?form_data_key=complex_config_key"
|
||||
)
|
||||
mock_create_form_data.assert_called_once()
|
||||
|
||||
@patch(
|
||||
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
|
||||
)
|
||||
@pytest.mark.asyncio
|
||||
async def test_fallback_url_different_datasets(
|
||||
self, mock_create_form_data, mcp_server
|
||||
):
|
||||
"""Test fallback URLs are correct for different dataset IDs."""
|
||||
mock_create_form_data.side_effect = Exception(
|
||||
"Always fail for fallback testing"
|
||||
)
|
||||
|
||||
config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="col")])
|
||||
|
||||
dataset_ids = ["1", "5", "100", "999"]
|
||||
|
||||
for dataset_id in dataset_ids:
|
||||
request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"generate_explore_link", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
# Should fallback to basic URL with correct dataset_id
|
||||
expected_url = f"http://localhost:9001/explore/?datasource_type=table&datasource_id={dataset_id}"
|
||||
assert result.data["error"] is None
|
||||
assert result.data["url"] == expected_url
|
||||
16
tests/unit_tests/mcp_service/sql_lab/__init__.py
Normal file
16
tests/unit_tests/mcp_service/sql_lab/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,64 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Helper function to extract row data from MCP responses.
|
||||
|
||||
The MCP client seems to wrap dict rows in Root objects.
|
||||
This helper handles the extraction properly.
|
||||
"""
|
||||
|
||||
|
||||
def extract_row_data(row):
    """Extract dictionary data from a row object.

    Tries, in order: plain dict, pydantic v1 custom-root payload,
    pydantic ``model_dump()``, public ``__dict__`` attributes, and
    finally attribute-probing an opaque ``Root()`` wrapper.

    Raises:
        ValueError: if no dictionary payload can be recovered.
    """
    # Plain dicts pass straight through.
    if isinstance(row, dict):
        return row

    # Pydantic v1 custom-root models keep their payload on __root__.
    if hasattr(row, "__root__"):
        return row.__root__

    # Pydantic models expose their data through model_dump().
    if hasattr(row, "model_dump"):
        return row.model_dump()

    # Ordinary objects: copy the public attributes out of __dict__.
    if hasattr(row, "__dict__"):
        return {
            key: value
            for key, value in row.__dict__.items()
            if not key.startswith("_")
        }

    # Last resort for the opaque Root() wrapper observed from the MCP
    # client: probe public attributes for a dict payload.
    if str(row) == "Root()":
        for attr in dir(row):
            if attr.startswith("_") or attr in ("model_dump", "model_validate"):
                continue
            try:
                candidate = getattr(row, attr)
            except AttributeError:
                continue
            if isinstance(candidate, dict):
                return candidate

    raise ValueError(f"Cannot extract data from row of type {type(row)}: {row}")
|
||||
16
tests/unit_tests/mcp_service/sql_lab/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/sql_lab/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
497
tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
Normal file
497
tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
Normal file
@@ -0,0 +1,497 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for execute_sql MCP tool
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    # Shared FastMCP app instance; tests connect to it via fastmcp.Client.
    return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_auth():
    """Mock authentication for all tests.

    Autouse fixture: every test in this module sees an "admin" user
    (id=1) from the patched request-auth helper.
    """
    fake_user = Mock()
    fake_user.id = 1
    fake_user.username = "admin"
    with patch("superset.mcp_service.auth.get_user_from_request") as patched:
        patched.return_value = fake_user
        yield patched
|
||||
|
||||
|
||||
def _mock_database(
|
||||
id: int = 1,
|
||||
database_name: str = "test_db",
|
||||
allow_dml: bool = False,
|
||||
) -> Mock:
|
||||
"""Create a mock database object."""
|
||||
database = Mock()
|
||||
database.id = id
|
||||
database.database_name = database_name
|
||||
database.allow_dml = allow_dml
|
||||
|
||||
# Mock raw connection context manager
|
||||
mock_cursor = Mock()
|
||||
mock_cursor.description = [
|
||||
("id", "INTEGER", None, None, None, None, False),
|
||||
("name", "VARCHAR", None, None, None, None, True),
|
||||
]
|
||||
mock_cursor.fetchmany.return_value = [(1, "test_name")]
|
||||
mock_cursor.rowcount = 1
|
||||
|
||||
mock_conn = Mock()
|
||||
mock_conn.cursor.return_value = mock_cursor
|
||||
mock_conn.commit = Mock()
|
||||
|
||||
mock_context = MagicMock()
|
||||
mock_context.__enter__.return_value = mock_conn
|
||||
mock_context.__exit__.return_value = None
|
||||
|
||||
database.get_raw_connection.return_value = mock_context
|
||||
|
||||
return database
|
||||
|
||||
|
||||
class TestExecuteSql:
|
||||
"""Tests for execute_sql MCP tool."""
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_basic_select(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test basic SELECT query execution."""
|
||||
# Setup mocks
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT id, name FROM users LIMIT 10",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.row_count == 1
|
||||
assert len(result.data.rows) == 1
|
||||
assert result.data.rows[0]["id"] == 1
|
||||
assert result.data.rows[0]["name"] == "test_name"
|
||||
assert len(result.data.columns) == 2
|
||||
assert result.data.columns[0].name == "id"
|
||||
assert result.data.columns[0].type == "INTEGER"
|
||||
assert result.data.execution_time > 0
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_with_parameters(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test SQL execution with parameter substitution."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
|
||||
"parameters": {"table": "orders", "status": "active", "limit": "5"},
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
# Verify parameter substitution happened
|
||||
mock_database.get_raw_connection.assert_called_once()
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
# Check that the SQL was formatted with parameters
|
||||
executed_sql = cursor.execute.call_args[0][0]
|
||||
assert "orders" in executed_sql
|
||||
assert "active" in executed_sql
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_database_not_found(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when database is not found."""
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
None
|
||||
)
|
||||
|
||||
request = {
|
||||
"database_id": 999,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Database with ID 999 not found" in result.data.error
|
||||
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
|
||||
assert result.data.rows is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_access_denied(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when user lacks database access."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
# Use Mock instead of AsyncMock for synchronous call
|
||||
from unittest.mock import Mock
|
||||
|
||||
mock_security_manager.can_access_database = Mock(return_value=False)
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Access denied to database" in result.data.error
|
||||
assert result.data.error_type == "SECURITY_ERROR"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_dml_not_allowed(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when DML operations are not allowed."""
|
||||
mock_database = _mock_database(allow_dml=False)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert result.data.error_type == "DML_NOT_ALLOWED"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_dml_allowed(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test successful DML execution when allowed."""
|
||||
mock_database = _mock_database(allow_dml=True)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock cursor for DML operation
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.rowcount = 3 # 3 rows affected
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.affected_rows == 3
|
||||
assert result.data.rows == [] # Empty rows for DML
|
||||
assert result.data.row_count == 0
|
||||
# Verify commit was called
|
||||
(
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
|
||||
)
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_results(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test query that returns no results."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock empty results
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.fetchmany.return_value = []
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users WHERE id = 999999",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
assert result.data.row_count == 0
|
||||
assert len(result.data.rows) == 0
|
||||
assert len(result.data.columns) == 2 # Column metadata still returned
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_missing_parameter(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when required parameter is missing."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
|
||||
"parameters": {"table_name": "users"}, # Missing user_id
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "user_id" in result.data.error # Error contains parameter name
|
||||
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_parameters_with_placeholders(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test error when empty parameters dict is provided but SQL has
|
||||
placeholders."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM {table_name} LIMIT 5",
|
||||
"parameters": {}, # Empty dict but SQL has {table_name}
|
||||
"limit": 5,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Missing parameter: table_name" in result.data.error
|
||||
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_with_schema(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test SQL execution with schema specification."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT COUNT(*) as total FROM orders",
|
||||
"schema": "sales",
|
||||
"limit": 1,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
assert result.data.error is None
|
||||
# Verify schema was passed to get_raw_connection
|
||||
# Verify schema was passed
|
||||
call_args = mock_database.get_raw_connection.call_args
|
||||
assert call_args[1]["schema"] == "sales"
|
||||
assert call_args[1]["catalog"] is None
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_limit_enforcement(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that LIMIT is added to SELECT queries without one."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users", # No LIMIT
|
||||
"limit": 50,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is True
|
||||
# Verify LIMIT was added
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
executed_sql = cursor.execute.call_args[0][0]
|
||||
assert "LIMIT 50" in executed_sql
|
||||
|
||||
@patch("superset.security_manager")
|
||||
@patch("superset.db")
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_sql_injection_prevention(
|
||||
self, mock_db, mock_security_manager, mcp_server
|
||||
):
|
||||
"""Test that SQL injection attempts are handled safely."""
|
||||
mock_database = _mock_database()
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
|
||||
mock_database
|
||||
)
|
||||
mock_security_manager.can_access_database.return_value = True
|
||||
|
||||
# Mock execute to raise an exception
|
||||
cursor = ( # fmt: skip
|
||||
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
|
||||
)
|
||||
cursor.execute.side_effect = Exception("Syntax error")
|
||||
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
assert result.data.success is False
|
||||
assert result.data.error is not None
|
||||
assert "Syntax error" in result.data.error # Contains actual error
|
||||
assert result.data.error_type == "EXECUTION_ERROR"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_empty_query_validation(self, mcp_server):
|
||||
"""Test validation of empty SQL query."""
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": " ", # Empty/whitespace only
|
||||
"limit": 10,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="SQL query cannot be empty"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_sql_invalid_limit(self, mcp_server):
|
||||
"""Test validation of invalid limit values."""
|
||||
# Test limit too low
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 0,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="minimum of 1"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
|
||||
# Test limit too high
|
||||
request = {
|
||||
"database_id": 1,
|
||||
"sql": "SELECT 1",
|
||||
"limit": 20000,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="maximum of 10000"):
|
||||
await client.call_tool("execute_sql", {"request": request})
|
||||
Reference in New Issue
Block a user