feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)

This commit is contained in:
Amin Ghadersohi
2025-11-01 02:33:21 +11:00
committed by GitHub
parent 30d584afd1
commit fee4e7d8e2
106 changed files with 21826 additions and 223 deletions

View File

@@ -0,0 +1,160 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP chart schema validation.
"""
import pytest
from pydantic import ValidationError
from superset.mcp_service.chart.schemas import (
ColumnRef,
TableChartConfig,
XYChartConfig,
)
class TestTableChartConfig:
"""Test TableChartConfig validation."""
def test_duplicate_labels_rejected(self) -> None:
"""Test that TableChartConfig rejects duplicate labels."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
TableChartConfig(
columns=[
ColumnRef(name="product_line", label="product_line"),
ColumnRef(name="sales", aggregate="SUM", label="product_line"),
]
)
def test_unique_labels_accepted(self) -> None:
"""Test that TableChartConfig accepts unique labels."""
config = TableChartConfig(
columns=[
ColumnRef(name="product_line", label="Product Line"),
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
]
)
assert len(config.columns) == 2
class TestXYChartConfig:
"""Test XYChartConfig validation."""
def test_different_labels_accepted(self) -> None:
"""Test that different labels for x and y are accepted."""
config = XYChartConfig(
x=ColumnRef(name="product_line"), # Label: "product_line"
y=[
ColumnRef(
name="product_line", aggregate="COUNT"
), # Label: "COUNT(product_line)"
],
)
assert config.x.name == "product_line"
assert config.y[0].aggregate == "COUNT"
def test_explicit_duplicate_label_rejected(self) -> None:
"""Test that explicit duplicate labels are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="product_line"),
y=[ColumnRef(name="sales", label="product_line")],
)
def test_duplicate_y_axis_labels_rejected(self) -> None:
"""Test that duplicate y-axis labels are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
],
)
def test_unique_labels_accepted(self) -> None:
"""Test that unique labels are accepted."""
config = XYChartConfig(
x=ColumnRef(name="date", label="Order Date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
],
)
assert len(config.y) == 2
def test_group_by_duplicate_with_x_rejected(self) -> None:
"""Test that group_by conflicts with x are rejected."""
with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
XYChartConfig(
x=ColumnRef(name="region"),
y=[ColumnRef(name="sales", aggregate="SUM")],
group_by=ColumnRef(name="category", label="region"),
)
def test_realistic_chart_configurations(self) -> None:
"""Test realistic chart configurations."""
# This should work - COUNT(product_line) != product_line
config = XYChartConfig(
x=ColumnRef(name="product_line"),
y=[
ColumnRef(name="product_line", aggregate="COUNT"),
ColumnRef(name="sales", aggregate="SUM"),
],
)
assert config.x.name == "product_line"
assert len(config.y) == 2
def test_time_series_chart_configuration(self) -> None:
"""Test time series chart configuration works."""
# This should PASS now - the chart creation logic fixes duplicates
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
],
kind="line",
)
assert config.x.name == "order_date"
assert config.kind == "line"
def test_time_series_with_custom_x_axis_label(self) -> None:
"""Test time series chart with custom x-axis label."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="order_date", label="Order Date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
],
kind="line",
)
assert config.x.label == "Order Date"
def test_area_chart_configuration(self) -> None:
"""Test area chart configuration."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="category"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
kind="area",
)
assert config.kind == "area"

View File

@@ -0,0 +1,465 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for chart utilities module"""
from unittest.mock import patch
import pytest
from superset.mcp_service.chart.chart_utils import (
create_metric_object,
generate_chart_name,
generate_explore_link,
map_config_to_form_data,
map_filter_operator,
map_table_config,
map_xy_config,
)
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
XYChartConfig,
)
class TestCreateMetricObject:
"""Test create_metric_object function"""
def test_create_metric_object_with_aggregate(self) -> None:
"""Test creating metric object with specified aggregate"""
col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
result = create_metric_object(col)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "revenue"
assert result["label"] == "Total Revenue"
assert result["optionName"] == "metric_revenue"
assert result["expressionType"] == "SIMPLE"
def test_create_metric_object_default_aggregate(self) -> None:
"""Test creating metric object with default aggregate"""
col = ColumnRef(name="orders")
result = create_metric_object(col)
assert result["aggregate"] == "SUM"
assert result["column"]["column_name"] == "orders"
assert result["label"] == "SUM(orders)"
assert result["optionName"] == "metric_orders"
class TestMapFilterOperator:
"""Test map_filter_operator function"""
def test_map_filter_operators(self) -> None:
"""Test mapping of various filter operators"""
assert map_filter_operator("=") == "=="
assert map_filter_operator(">") == ">"
assert map_filter_operator("<") == "<"
assert map_filter_operator(">=") == ">="
assert map_filter_operator("<=") == "<="
assert map_filter_operator("!=") == "!="
def test_map_filter_operator_unknown(self) -> None:
"""Test mapping of unknown operator returns original"""
assert map_filter_operator("UNKNOWN") == "UNKNOWN"
class TestMapTableConfig:
"""Test map_table_config function"""
def test_map_table_config_basic(self) -> None:
"""Test basic table config mapping with aggregated columns"""
config = TableChartConfig(
columns=[
ColumnRef(name="product", aggregate="COUNT"),
ColumnRef(name="revenue", aggregate="SUM"),
]
)
result = map_table_config(config)
assert result["viz_type"] == "table"
assert result["query_mode"] == "aggregate"
# Aggregated columns should be in metrics, not all_columns
assert "all_columns" not in result
assert len(result["metrics"]) == 2
assert result["metrics"][0]["aggregate"] == "COUNT"
assert result["metrics"][1]["aggregate"] == "SUM"
def test_map_table_config_raw_columns(self) -> None:
"""Test table config mapping with raw columns (no aggregates)"""
config = TableChartConfig(
columns=[
ColumnRef(name="product"),
ColumnRef(name="category"),
]
)
result = map_table_config(config)
assert result["viz_type"] == "table"
assert result["query_mode"] == "raw"
# Raw columns should be in all_columns
assert result["all_columns"] == ["product", "category"]
assert "metrics" not in result
def test_map_table_config_with_filters(self) -> None:
"""Test table config mapping with filters"""
config = TableChartConfig(
columns=[ColumnRef(name="product")],
filters=[FilterConfig(column="status", op="=", value="active")],
)
result = map_table_config(config)
assert "adhoc_filters" in result
assert len(result["adhoc_filters"]) == 1
filter_obj = result["adhoc_filters"][0]
assert filter_obj["subject"] == "status"
assert filter_obj["operator"] == "=="
assert filter_obj["comparator"] == "active"
assert filter_obj["expressionType"] == "SIMPLE"
def test_map_table_config_with_sort(self) -> None:
"""Test table config mapping with sort"""
config = TableChartConfig(
columns=[ColumnRef(name="product")], sort_by=["product", "revenue"]
)
result = map_table_config(config)
assert result["order_by_cols"] == ["product", "revenue"]
class TestMapXYConfig:
"""Test map_xy_config function"""
def test_map_xy_config_line_chart(self) -> None:
"""Test XY config mapping for line chart"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue", aggregate="SUM")],
kind="line",
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_line"
assert result["x_axis"] == "date"
assert len(result["metrics"]) == 1
assert result["metrics"][0]["aggregate"] == "SUM"
def test_map_xy_config_with_groupby(self) -> None:
"""Test XY config mapping with group by"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="bar",
group_by=ColumnRef(name="region"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_bar"
assert result["groupby"] == ["region"]
def test_map_xy_config_with_axes(self) -> None:
"""Test XY config mapping with axis configurations"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="area",
x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_area"
assert result["x_axis_title"] == "Date"
assert result["x_axis_format"] == "%Y-%m-%d"
assert result["y_axis_title"] == "Revenue"
assert result["y_axis_format"] == "$,.2f"
assert result["y_axis_scale"] == "log"
def test_map_xy_config_with_legend(self) -> None:
"""Test XY config mapping with legend configuration"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue")],
kind="scatter",
legend=LegendConfig(show=False, position="top"),
)
result = map_xy_config(config)
assert result["viz_type"] == "echarts_timeseries_scatter"
assert result["show_legend"] is False
assert result["legend_orientation"] == "top"
class TestMapConfigToFormData:
"""Test map_config_to_form_data function"""
def test_map_table_config_type(self) -> None:
"""Test mapping table config type"""
config = TableChartConfig(columns=[ColumnRef(name="test")])
result = map_config_to_form_data(config)
assert result["viz_type"] == "table"
def test_map_xy_config_type(self) -> None:
"""Test mapping XY config type"""
config = XYChartConfig(
x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line"
)
result = map_config_to_form_data(config)
assert result["viz_type"] == "echarts_timeseries_line"
def test_map_unsupported_config_type(self) -> None:
"""Test mapping unsupported config type raises error"""
with pytest.raises(ValueError, match="Unsupported config type"):
map_config_to_form_data("invalid_config") # type: ignore
class TestGenerateChartName:
"""Test generate_chart_name function"""
def test_generate_table_chart_name(self) -> None:
"""Test generating name for table chart"""
config = TableChartConfig(
columns=[
ColumnRef(name="product"),
ColumnRef(name="revenue"),
]
)
result = generate_chart_name(config)
assert result == "Table Chart - product, revenue"
def test_generate_xy_chart_name(self) -> None:
"""Test generating name for XY chart"""
config = XYChartConfig(
x=ColumnRef(name="date"),
y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
kind="line",
)
result = generate_chart_name(config)
assert result == "Line Chart - date vs revenue, orders"
def test_generate_chart_name_unsupported(self) -> None:
"""Test generating name for unsupported config type"""
result = generate_chart_name("invalid_config") # type: ignore
assert result == "Chart"
class TestGenerateExploreLink:
"""Test generate_explore_link function"""
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
"""Test that generate_explore_link uses the configured base URL"""
from urllib.parse import urlparse
mock_get_base_url.return_value = "https://superset.example.com"
form_data = {"viz_type": "table", "metrics": ["count"]}
result = generate_explore_link("123", form_data)
# Should use the configured base URL - use urlparse to avoid CodeQL warning
parsed_url = urlparse(result)
expected_netloc = "superset.example.com"
assert parsed_url.scheme == "https"
assert parsed_url.netloc == expected_netloc
assert "/explore/" in parsed_url.path
assert "datasource_id=123" in result
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
"""Test generate_explore_link returns fallback URL when dataset not found"""
mock_get_base_url.return_value = "http://localhost:9001"
form_data = {"viz_type": "table"}
# Mock dataset not found scenario
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
result = generate_explore_link("999", form_data)
assert (
result
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
)
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
@patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
def test_generate_explore_link_with_form_data_key(
self, mock_command, mock_get_base_url
) -> None:
"""Test generate_explore_link creates form_data_key when dataset exists"""
mock_get_base_url.return_value = "http://localhost:9001"
mock_command.return_value.run.return_value = "test_form_data_key"
# Mock dataset exists
mock_dataset = type("Dataset", (), {"id": 123})()
with patch(
"superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_dataset
):
result = generate_explore_link(123, {"viz_type": "table"})
assert (
result == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
)
mock_command.assert_called_once()
@patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
"""Test generate_explore_link handles exceptions gracefully"""
mock_get_base_url.return_value = "http://localhost:9001"
# Mock exception during form_data creation
with patch(
"superset.daos.dataset.DatasetDAO.find_by_id",
side_effect=Exception("DB error"),
):
result = generate_explore_link("123", {"viz_type": "table"})
# Should fallback to basic URL
assert (
result
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
)
class TestCriticalBugFixes:
"""Test critical bug fixes for chart utilities."""
def test_time_series_aggregation_fix(self) -> None:
"""Test that time series charts preserve temporal dimension."""
# Create a time series chart configuration
config = XYChartConfig(
chart_type="xy",
kind="line",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
)
form_data = map_xy_config(config)
# Verify the fix: x_axis should be set correctly
assert form_data["x_axis"] == "order_date"
# Verify the fix: groupby should not duplicate x_axis
# This prevents the "Duplicate column/metric labels" error
assert "groupby" not in form_data or "order_date" not in form_data.get(
"groupby", []
)
# Verify chart type mapping
assert form_data["viz_type"] == "echarts_timeseries_line"
def test_time_series_with_explicit_group_by(self) -> None:
"""Test time series with explicit group_by different from x_axis."""
config = XYChartConfig(
chart_type="xy",
kind="line",
x=ColumnRef(name="order_date"),
y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
group_by=ColumnRef(name="category"),
)
form_data = map_xy_config(config)
# Verify x_axis is set
assert form_data["x_axis"] == "order_date"
# Verify groupby only contains the explicit group_by, not x_axis
assert form_data.get("groupby") == ["category"]
assert "order_date" not in form_data.get("groupby", [])
def test_duplicate_label_prevention(self) -> None:
"""Test that duplicate column/metric labels are prevented."""
# This configuration would previously cause:
# "Duplicate column/metric labels: 'price_each'"
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="price_each", label="Price Each"), # Custom label
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
],
group_by=ColumnRef(name="price_each"), # Same column as x_axis
kind="line",
)
form_data = map_xy_config(config)
# Verify the fix: x_axis is set
assert form_data["x_axis"] == "price_each"
# Verify the fix: groupby is empty because group_by == x_axis
# This prevents the duplicate label error
assert "groupby" not in form_data or not form_data["groupby"]
def test_metric_object_creation_with_labels(self) -> None:
"""Test that metric objects are created correctly with proper labels."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
ColumnRef(name="profit", aggregate="AVG"), # No custom label
],
kind="bar",
)
form_data = map_xy_config(config)
# Verify metrics are created correctly
metrics = form_data["metrics"]
assert len(metrics) == 2
# First metric with custom label
assert metrics[0]["label"] == "Total Sales"
assert metrics[0]["aggregate"] == "SUM"
assert metrics[0]["column"]["column_name"] == "sales"
# Second metric with auto-generated label
assert metrics[1]["label"] == "AVG(profit)"
assert metrics[1]["aggregate"] == "AVG"
assert metrics[1]["column"]["column_name"] == "profit"
def test_chart_type_mapping_comprehensive(self) -> None:
"""Test that chart types are mapped correctly to Superset viz types."""
test_cases = [
("line", "echarts_timeseries_line"),
("bar", "echarts_timeseries_bar"),
("area", "echarts_area"),
("scatter", "echarts_timeseries_scatter"),
]
for kind, expected_viz_type in test_cases:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind=kind,
)
form_data = map_xy_config(config)
assert form_data["viz_type"] == expected_viz_type

View File

@@ -0,0 +1,268 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP generate_chart tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
GenerateChartRequest,
LegendConfig,
TableChartConfig,
XYChartConfig,
)
class TestGenerateChart:
"""Tests for generate_chart MCP tool."""
@pytest.mark.asyncio
async def test_generate_chart_request_structure(self):
"""Test that chart generation request structures are properly formed."""
# Table chart request
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
assert table_request.dataset_id == "1"
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# XY chart request
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
assert xy_request.config.x_axis.title == "Date"
assert xy_request.config.legend.show is True
@pytest.mark.asyncio
async def test_generate_chart_validation_error_handling(self):
"""Test that validation errors are properly structured."""
# Create a validation error with the correct structure
validation_error_entry = {
"field": "x_axis",
"provided_value": "invalid_col",
"error_type": "column_not_found",
"message": "Column 'invalid_col' not found",
}
# Test that validation error structure is correct
assert validation_error_entry["field"] == "x_axis"
assert validation_error_entry["error_type"] == "column_not_found"
@pytest.mark.asyncio
async def test_chart_config_variations(self):
"""Test various chart configuration options."""
# Test all chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
assert config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
assert len(multi_y_config.y) == 3
assert multi_y_config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
for i, f in enumerate(filters):
assert f.op == operators[i]
@pytest.mark.asyncio
async def test_generate_chart_response_structure(self):
"""Test the expected response structure for chart generation."""
# The response should contain these fields
_ = {
"chart": {
"id": int,
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": str,
"saved": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When chart creation succeeds, these additional fields may be present
_ = [
"previews",
"capabilities",
"semantics",
"explore_url",
"form_data_key",
"api_endpoints",
"performance",
"accessibility",
]
# This is just a structural test - actual integration tests would verify
# the tool returns data matching this structure
@pytest.mark.asyncio
async def test_dataset_id_flexibility(self):
"""Test that dataset_id can be string or int."""
configs = [
GenerateChartRequest(
dataset_id="123",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
),
GenerateChartRequest(
dataset_id="uuid-string-here",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
),
]
for config in configs:
assert isinstance(config.dataset_id, str)
@pytest.mark.asyncio
async def test_save_chart_flag(self):
"""Test save_chart flag behavior."""
# Default should be True (save chart)
request1 = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
)
assert request1.save_chart is True
# Explicit False (preview only)
request2 = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
save_chart=False,
)
assert request2.save_chart is False
@pytest.mark.asyncio
async def test_preview_formats(self):
"""Test preview format options."""
formats = ["url", "ascii", "table"]
request = GenerateChartRequest(
dataset_id="1",
config=TableChartConfig(
chart_type="table", columns=[ColumnRef(name="col1")]
),
generate_preview=True,
preview_formats=formats,
)
assert request.generate_preview is True
assert set(request.preview_formats) == set(formats)
@pytest.mark.asyncio
async def test_column_ref_features(self):
"""Test ColumnRef features like aggregation and labels."""
# Simple column
col1 = ColumnRef(name="region")
assert col1.name == "region"
assert col1.label is None
assert col1.aggregate is None
# Column with aggregation
col2 = ColumnRef(name="sales", aggregate="SUM", label="Total Sales")
assert col2.name == "sales"
assert col2.aggregate == "SUM"
assert col2.label == "Total Sales"
# All supported aggregations
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
col = ColumnRef(name="value", aggregate=agg)
assert col.aggregate == agg
@pytest.mark.asyncio
async def test_axis_config_options(self):
"""Test axis configuration options."""
axis = AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="linear",
)
assert axis.title == "Sales Amount"
assert axis.format == "$,.2f"
assert axis.scale == "linear"
# Test different formats
formats = ["SMART_NUMBER", "$,.2f", ",.0f", "smart_date", ".3%"]
for fmt in formats:
axis = AxisConfig(format=fmt)
assert axis.format == fmt
@pytest.mark.asyncio
async def test_legend_config_options(self):
"""Test legend configuration options."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
legend = LegendConfig(show=True, position=pos)
assert legend.position == pos
# Hidden legend
legend = LegendConfig(show=False)
assert legend.show is False

View File

@@ -0,0 +1,290 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for get_chart_preview MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
ASCIIPreview,
GetChartPreviewRequest,
TablePreview,
URLPreview,
)
class TestGetChartPreview:
"""Tests for get_chart_preview MCP tool."""
@pytest.mark.asyncio
async def test_get_chart_preview_request_structure(self):
"""Test that preview request structures are properly formed."""
# Numeric ID request
request1 = GetChartPreviewRequest(identifier=123, format="url")
assert request1.identifier == 123
assert request1.format == "url"
# Default dimensions are set
assert request1.width == 800
assert request1.height == 600
# String ID request
request2 = GetChartPreviewRequest(identifier="456", format="ascii")
assert request2.identifier == "456"
assert request2.format == "ascii"
# UUID request
request3 = GetChartPreviewRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", format="table"
)
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert request3.format == "table"
# Default format
request4 = GetChartPreviewRequest(identifier=789)
assert request4.format == "url" # default
@pytest.mark.asyncio
async def test_preview_format_types(self):
"""Test different preview format types."""
formats = ["url", "ascii", "table"]
for fmt in formats:
request = GetChartPreviewRequest(identifier=1, format=fmt)
assert request.format == fmt
@pytest.mark.asyncio
async def test_url_preview_structure(self):
"""Test URLPreview response structure."""
preview = URLPreview(
preview_url="http://localhost:5008/screenshot/chart/123.png",
width=800,
height=600,
supports_interaction=False,
)
assert preview.type == "url"
assert preview.preview_url == "http://localhost:5008/screenshot/chart/123.png"
assert preview.width == 800
assert preview.height == 600
assert preview.supports_interaction is False
@pytest.mark.asyncio
async def test_ascii_preview_structure(self):
"""Test ASCIIPreview response structure."""
ascii_art = """
┌─────────────────────────┐
│ Sales by Region │
├─────────────────────────┤
│ North ████████ 45%
│ South ██████ 30%
│ East ████ 20%
│ West ██ 5%
└─────────────────────────┘
"""
preview = ASCIIPreview(
ascii_content=ascii_art.strip(),
width=25,
height=8,
)
assert preview.type == "ascii"
assert "Sales by Region" in preview.ascii_content
assert preview.width == 25
assert preview.height == 8
@pytest.mark.asyncio
async def test_table_preview_structure(self):
"""Test TablePreview response structure."""
table_content = """
| Region | Sales | Profit |
|--------|--------|--------|
| North | 45000 | 12000 |
| South | 30000 | 8000 |
| East | 20000 | 5000 |
| West | 5000 | 1000 |
"""
preview = TablePreview(
table_data=table_content.strip(),
row_count=4,
supports_sorting=True,
)
assert preview.type == "table"
assert "Region" in preview.table_data
assert "North" in preview.table_data
assert preview.row_count == 4
assert preview.supports_sorting is True
@pytest.mark.asyncio
async def test_chart_preview_response_structure(self):
"""Test the expected response structure for chart preview."""
# Core fields that should always be present
_ = [
"chart_id",
"chart_name",
"chart_type",
"explore_url",
"content", # Union of URLPreview | ASCIIPreview | TablePreview
"chart_description",
"accessibility",
"performance",
]
# Additional fields that may be present for backward compatibility
_ = [
"format",
"preview_url",
"ascii_chart",
"table_data",
"width",
"height",
"schema_version",
"api_version",
]
# This is a structural test - actual integration tests would verify
# the tool returns data matching this structure
@pytest.mark.asyncio
async def test_preview_dimensions(self):
"""Test preview dimensions in response."""
# Standard dimensions that may appear in preview responses
standard_sizes = [
(800, 600), # Default
(1200, 800), # Large
(400, 300), # Small
(1920, 1080), # Full HD
]
for width, height in standard_sizes:
# URL preview with dimensions
url_preview = URLPreview(
preview_url="http://example.com/chart.png",
width=width,
height=height,
supports_interaction=False,
)
assert url_preview.width == width
assert url_preview.height == height
@pytest.mark.asyncio
async def test_error_response_structures(self):
"""Test error response structures."""
# Error responses typically follow this structure
error_response = {
"error_type": "not_found",
"message": "Chart not found",
"details": "No chart found with ID 999",
"chart_id": 999,
}
assert error_response["error_type"] == "not_found"
assert error_response["chart_id"] == 999
# Preview generation error structure
preview_error = {
"error_type": "preview_error",
"message": "Failed to generate preview",
"details": "Screenshot service unavailable",
}
assert preview_error["error_type"] == "preview_error"
@pytest.mark.asyncio
async def test_accessibility_metadata(self):
"""Test accessibility metadata structure."""
from superset.mcp_service.chart.schemas import AccessibilityMetadata
metadata = AccessibilityMetadata(
color_blind_safe=True,
alt_text="Bar chart showing sales by region",
high_contrast_available=False,
)
assert metadata.color_blind_safe is True
assert "sales by region" in metadata.alt_text
assert metadata.high_contrast_available is False
@pytest.mark.asyncio
async def test_performance_metadata(self):
"""Test performance metadata structure."""
from superset.mcp_service.chart.schemas import PerformanceMetadata
metadata = PerformanceMetadata(
query_duration_ms=150,
cache_status="hit",
optimization_suggestions=["Consider adding an index on date column"],
)
assert metadata.query_duration_ms == 150
assert metadata.cache_status == "hit"
assert len(metadata.optimization_suggestions) == 1
@pytest.mark.asyncio
async def test_chart_types_support(self):
"""Test that various chart types are supported."""
chart_types = [
"echarts_timeseries_line",
"echarts_timeseries_bar",
"echarts_area",
"echarts_timeseries_scatter",
"table",
"pie",
"big_number",
"big_number_total",
"pivot_table_v2",
"dist_bar",
"box_plot",
]
# All chart types should be previewable
for _chart_type in chart_types:
# This would be tested in integration tests
pass
@pytest.mark.asyncio
async def test_ascii_art_variations(self):
"""Test ASCII art generation for different chart types."""
# Line chart ASCII
_ = """
Sales Trend
│ ╱╲
└────────────
Jan Feb Mar
"""
# Bar chart ASCII
_ = """
Sales by Region
│ ████ North
│ ███ South
│ ██ East
│ █ West
└────────────
"""
# Pie chart ASCII
_ = """
Market Share
╭─────╮
│ 45%
│ North │
╰─────────╯
"""
# These demonstrate the expected ASCII formats for different chart types

View File

@@ -0,0 +1,385 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for update_chart MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
UpdateChartRequest,
XYChartConfig,
)
class TestUpdateChart:
"""Tests for update_chart MCP tool."""
@pytest.mark.asyncio
async def test_update_chart_request_structure(self):
"""Test that chart update request structures are properly formed."""
# Table chart update with numeric ID
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = UpdateChartRequest(identifier=123, config=table_config)
assert table_request.identifier == 123
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# XY chart update with UUID
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = UpdateChartRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
)
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
@pytest.mark.asyncio
async def test_update_chart_with_chart_name(self):
"""Test updating chart with custom chart name."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Without custom name
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.chart_name is None
# With custom name
request2 = UpdateChartRequest(
identifier=123, config=config, chart_name="Updated Sales Report"
)
assert request2.chart_name == "Updated Sales Report"
@pytest.mark.asyncio
async def test_update_chart_preview_generation(self):
"""Test preview generation options in update request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default preview generation
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.generate_preview is True
assert request1.preview_formats == ["url"]
# Custom preview formats
request2 = UpdateChartRequest(
identifier=123,
config=config,
generate_preview=True,
preview_formats=["url", "ascii", "table"],
)
assert request2.generate_preview is True
assert set(request2.preview_formats) == {"url", "ascii", "table"}
# Disable preview generation
request3 = UpdateChartRequest(
identifier=123, config=config, generate_preview=False
)
assert request3.generate_preview is False
@pytest.mark.asyncio
async def test_update_chart_identifier_types(self):
"""Test that identifier can be int or string (UUID)."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Integer ID
request1 = UpdateChartRequest(identifier=123, config=config)
assert request1.identifier == 123
assert isinstance(request1.identifier, int)
# String numeric ID
request2 = UpdateChartRequest(identifier="456", config=config)
assert request2.identifier == "456"
assert isinstance(request2.identifier, str)
# UUID string
request3 = UpdateChartRequest(
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=config
)
assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert isinstance(request3.identifier, str)
@pytest.mark.asyncio
async def test_update_chart_config_variations(self):
"""Test various chart configuration options in updates."""
# Test all XY chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
request = UpdateChartRequest(identifier=1, config=multi_y_config)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
table_config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
filters=filters,
)
request = UpdateChartRequest(identifier=1, config=table_config)
assert len(request.config.filters) == 6
@pytest.mark.asyncio
async def test_update_chart_response_structure(self):
"""Test the expected response structure for chart updates."""
# The response should contain these fields
expected_response = {
"chart": {
"id": int,
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": str,
"updated": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When chart update succeeds, these additional fields may be present
optional_fields = [
"previews",
"capabilities",
"semantics",
"explore_url",
"api_endpoints",
"performance",
"accessibility",
]
# Validate structure expectations
assert "chart" in expected_response
assert "success" in expected_response
assert len(optional_fields) > 0
@pytest.mark.asyncio
async def test_update_chart_axis_configurations(self):
"""Test axis configuration updates."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales")],
x_axis=AxisConfig(
title="Date",
format="smart_date",
scale="linear",
),
y_axis=AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="log",
),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.x_axis.scale == "linear"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
assert request.config.y_axis.scale == "log"
@pytest.mark.asyncio
async def test_update_chart_legend_configurations(self):
"""Test legend configuration updates."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=True, position=pos),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.position == pos
assert request.config.legend.show is True
# Hidden legend
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=False),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.show is False
@pytest.mark.asyncio
async def test_update_chart_aggregation_functions(self):
"""Test all supported aggregation functions in updates."""
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="value", aggregate=agg)],
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.columns[0].aggregate == agg
@pytest.mark.asyncio
async def test_update_chart_error_responses(self):
"""Test expected error response structures."""
# Chart not found error
error_response = {
"chart": None,
"error": "No chart found with identifier: 999",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert error_response["success"] is False
assert error_response["chart"] is None
assert "chart found" in error_response["error"].lower()
# General update error
update_error = {
"chart": None,
"error": "Chart update failed: Permission denied",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert update_error["success"] is False
assert "failed" in update_error["error"].lower()
@pytest.mark.asyncio
async def test_chart_name_sanitization(self):
"""Test that chart names are properly sanitized."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Test with potentially problematic characters
test_names = [
"Normal Chart Name",
"Chart with 'quotes'",
'Chart with "double quotes"',
"Chart with <brackets>",
]
for name in test_names:
request = UpdateChartRequest(identifier=1, config=config, chart_name=name)
# Chart name should be set (sanitization happens in the validator)
assert request.chart_name is not None
@pytest.mark.asyncio
async def test_update_chart_with_filters(self):
"""Test updating chart with various filter configurations."""
filters = [
FilterConfig(column="region", op="=", value="North"),
FilterConfig(column="sales", op=">=", value=1000),
FilterConfig(column="date", op=">", value="2024-01-01"),
]
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales"),
ColumnRef(name="date"),
],
filters=filters,
)
request = UpdateChartRequest(identifier=1, config=config)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_cache_control(self):
"""Test cache control parameters in update request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default cache settings
request1 = UpdateChartRequest(identifier=1, config=config)
assert request1.use_cache is True
assert request1.force_refresh is False
assert request1.cache_timeout is None
# Custom cache settings
request2 = UpdateChartRequest(
identifier=1,
config=config,
use_cache=False,
force_refresh=True,
cache_timeout=300,
)
assert request2.use_cache is False
assert request2.force_refresh is True
assert request2.cache_timeout == 300

View File

@@ -0,0 +1,474 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for update_chart_preview MCP tool
"""
import pytest
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
LegendConfig,
TableChartConfig,
UpdateChartPreviewRequest,
XYChartConfig,
)
class TestUpdateChartPreview:
"""Tests for update_chart_preview MCP tool."""
@pytest.mark.asyncio
async def test_update_chart_preview_request_structure(self):
"""Test that chart preview update request structures are properly formed."""
# Table chart preview update
table_config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Region"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
filters=[FilterConfig(column="year", op="=", value="2024")],
sort_by=["sales"],
)
table_request = UpdateChartPreviewRequest(
form_data_key="abc123def456", dataset_id=1, config=table_config
)
assert table_request.form_data_key == "abc123def456"
assert table_request.dataset_id == 1
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
# XY chart preview update
xy_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="value", aggregate="SUM")],
kind="line",
group_by=ColumnRef(name="category"),
x_axis=AxisConfig(title="Date", format="smart_date"),
y_axis=AxisConfig(title="Value", format="$,.2f"),
legend=LegendConfig(show=True, position="top"),
)
xy_request = UpdateChartPreviewRequest(
form_data_key="xyz789ghi012", dataset_id="2", config=xy_config
)
assert xy_request.form_data_key == "xyz789ghi012"
assert xy_request.dataset_id == "2"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.kind == "line"
@pytest.mark.asyncio
async def test_update_chart_preview_dataset_id_types(self):
"""Test that dataset_id can be int or string (UUID)."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Integer dataset_id
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=123, config=config
)
assert request1.dataset_id == 123
assert isinstance(request1.dataset_id, int)
# String numeric dataset_id
request2 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id="456", config=config
)
assert request2.dataset_id == "456"
assert isinstance(request2.dataset_id, str)
# UUID string dataset_id
request3 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890",
config=config,
)
assert request3.dataset_id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert isinstance(request3.dataset_id, str)
@pytest.mark.asyncio
async def test_update_chart_preview_generation_options(self):
"""Test preview generation options in update preview request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default preview generation
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request1.generate_preview is True
assert request1.preview_formats == ["url"]
# Custom preview formats
request2 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
generate_preview=True,
preview_formats=["url", "ascii", "table"],
)
assert request2.generate_preview is True
assert set(request2.preview_formats) == {"url", "ascii", "table"}
# Disable preview generation
request3 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
generate_preview=False,
)
assert request3.generate_preview is False
@pytest.mark.asyncio
async def test_update_chart_preview_config_variations(self):
"""Test various chart configuration options in preview updates."""
# Test all XY chart types
chart_types = ["line", "bar", "area", "scatter"]
for chart_type in chart_types:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind=chart_type,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.kind == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
ColumnRef(name="quantity", aggregate="COUNT"),
],
kind="line",
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=multi_y_config
)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
filters = [FilterConfig(column="col", op=op, value="val") for op in operators]
table_config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
filters=filters,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=table_config
)
assert len(request.config.filters) == 6
@pytest.mark.asyncio
async def test_update_chart_preview_response_structure(self):
"""Test the expected response structure for chart preview updates."""
# The response should contain these fields
expected_response = {
"chart": {
"id": None, # No ID for unsaved previews
"slice_name": str,
"viz_type": str,
"url": str,
"uuid": None, # No UUID for unsaved previews
"saved": bool,
"updated": bool,
},
"error": None,
"success": bool,
"schema_version": str,
"api_version": str,
}
# When preview update succeeds, these additional fields may be present
optional_fields = [
"previews",
"capabilities",
"semantics",
"explore_url",
"form_data_key",
"previous_form_data_key",
"api_endpoints",
"performance",
"accessibility",
]
# Validate structure expectations
assert "chart" in expected_response
assert "success" in expected_response
assert len(optional_fields) > 0
assert expected_response["chart"]["id"] is None
assert expected_response["chart"]["uuid"] is None
@pytest.mark.asyncio
async def test_update_chart_preview_axis_configurations(self):
"""Test axis configuration updates in preview."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales")],
x_axis=AxisConfig(
title="Date",
format="smart_date",
scale="linear",
),
y_axis=AxisConfig(
title="Sales Amount",
format="$,.2f",
scale="log",
),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
@pytest.mark.asyncio
async def test_update_chart_preview_legend_configurations(self):
"""Test legend configuration updates in preview."""
positions = ["top", "bottom", "left", "right"]
for pos in positions:
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=True, position=pos),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.position == pos
assert request.config.legend.show is True
# Hidden legend
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
legend=LegendConfig(show=False),
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.show is False
@pytest.mark.asyncio
async def test_update_chart_preview_aggregation_functions(self):
"""Test all supported aggregation functions in preview updates."""
aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"]
for agg in aggs:
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="value", aggregate=agg)],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.columns[0].aggregate == agg
@pytest.mark.asyncio
async def test_update_chart_preview_error_responses(self):
"""Test expected error response structures for preview updates."""
# General update error
error_response = {
"chart": None,
"error": "Chart preview update failed: Invalid form_data_key",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert error_response["success"] is False
assert error_response["chart"] is None
assert "failed" in error_response["error"].lower()
# Missing dataset error
dataset_error = {
"chart": None,
"error": "Chart preview update failed: Dataset not found",
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
assert dataset_error["success"] is False
assert "dataset" in dataset_error["error"].lower()
@pytest.mark.asyncio
async def test_update_chart_preview_with_filters(self):
"""Test updating preview with various filter configurations."""
filters = [
FilterConfig(column="region", op="=", value="North"),
FilterConfig(column="sales", op=">=", value=1000),
FilterConfig(column="date", op=">", value="2024-01-01"),
]
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales"),
ColumnRef(name="date"),
],
filters=filters,
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_preview_form_data_key_handling(self):
"""Test form_data_key handling in preview updates."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Various form_data_key formats
form_data_keys = [
"abc123def456",
"xyz-789-ghi-012",
"key_with_underscore",
"UPPERCASE_KEY",
]
for key in form_data_keys:
request = UpdateChartPreviewRequest(
form_data_key=key, dataset_id=1, config=config
)
assert request.form_data_key == key
@pytest.mark.asyncio
async def test_update_chart_preview_cache_control(self):
"""Test cache control parameters in update preview request."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
# Default cache settings
request1 = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request1.use_cache is True
assert request1.force_refresh is False
assert request1.cache_form_data is True
# Custom cache settings
request2 = UpdateChartPreviewRequest(
form_data_key="abc123",
dataset_id=1,
config=config,
use_cache=False,
force_refresh=True,
cache_form_data=False,
)
assert request2.use_cache is False
assert request2.force_refresh is True
assert request2.cache_form_data is False
@pytest.mark.asyncio
async def test_update_chart_preview_no_save_behavior(self):
"""Test that preview updates don't create permanent charts."""
config = TableChartConfig(
chart_type="table",
columns=[ColumnRef(name="col1")],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
# Preview updates should never create permanent charts
# This is validated by checking the response structure
expected_unsaved_fields = {
"id": None, # No chart ID
"uuid": None, # No UUID
"saved": False, # Not saved
}
# These expectations are validated in the response, not the request
assert request.form_data_key == "abc123"
assert expected_unsaved_fields["id"] is None
assert expected_unsaved_fields["uuid"] is None
assert expected_unsaved_fields["saved"] is False
@pytest.mark.asyncio
async def test_update_chart_preview_multiple_y_columns(self):
"""Test preview updates with multiple Y-axis columns."""
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[
ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"),
ColumnRef(name="cost", aggregate="SUM", label="Total Cost"),
ColumnRef(name="profit", aggregate="SUM", label="Total Profit"),
ColumnRef(name="orders", aggregate="COUNT", label="Order Count"),
],
kind="line",
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.y) == 4
assert request.config.y[0].name == "revenue"
assert request.config.y[1].name == "cost"
assert request.config.y[2].name == "profit"
assert request.config.y[3].name == "orders"
assert request.config.y[3].aggregate == "COUNT"
@pytest.mark.asyncio
async def test_update_chart_preview_table_sorting(self):
"""Test table chart sorting in preview updates."""
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region"),
ColumnRef(name="sales", aggregate="SUM"),
ColumnRef(name="profit", aggregate="AVG"),
],
sort_by=["sales", "profit"],
)
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.sort_by == ["sales", "profit"]
assert len(request.config.columns) == 3

View File

@@ -0,0 +1,23 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MCP service test configuration.
Tool imports are handled by app.py, not here.
This conftest is empty to prevent test pollution.
"""

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,450 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for dashboard generation MCP tools
"""
import logging
from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.utils import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _mock_chart(id: int = 1, slice_name: str = "Test Chart") -> Mock:
"""Create a mock chart object."""
chart = Mock()
chart.id = id
chart.slice_name = slice_name
chart.uuid = f"chart-uuid-{id}"
return chart
def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
"""Create a mock dashboard object."""
dashboard = Mock()
dashboard.id = id
dashboard.dashboard_title = title
dashboard.slug = f"test-dashboard-{id}"
dashboard.description = "Test dashboard description"
dashboard.published = True
dashboard.created_on = "2024-01-01"
dashboard.changed_on = "2024-01-01"
dashboard.created_by = Mock()
dashboard.created_by.username = "test_user"
dashboard.changed_by = Mock()
dashboard.changed_by.username = "test_user"
dashboard.uuid = f"dashboard-uuid-{id}"
dashboard.slices = []
dashboard.owners = [] # Add missing owners attribute
dashboard.tags = [] # Add missing tags attribute
return dashboard
class TestGenerateDashboard:
"""Tests for generate_dashboard MCP tool."""
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_basic(
self, mock_db_session, mock_create_command, mcp_server
):
"""Test basic dashboard generation with valid charts."""
# Mock database query for charts
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [
_mock_chart(id=1, slice_name="Sales Chart"),
_mock_chart(id=2, slice_name="Revenue Chart"),
]
mock_db_session.query.return_value = mock_query
# Mock dashboard creation
mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
request = {
"chart_ids": [1, 2],
"dashboard_title": "Analytics Dashboard",
"description": "Dashboard for analytics",
"published": True,
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is None
assert result.data.dashboard is not None
assert result.data.dashboard.id == 10
assert result.data.dashboard.dashboard_title == "Analytics Dashboard"
assert result.data.dashboard.chart_count == 2
assert "/superset/dashboard/10/" in result.data.dashboard_url
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_missing_charts(self, mock_db_session, mcp_server):
"""Test error handling when some charts don't exist."""
# Mock database query returning only chart 1 (chart 2 missing)
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [
_mock_chart(id=1),
# Chart 2 is missing from the result
]
mock_db_session.query.return_value = mock_query
request = {"chart_ids": [1, 2], "dashboard_title": "Test Dashboard"}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is not None
assert "Charts not found: [2]" in result.data.error
assert result.data.dashboard is None
assert result.data.dashboard_url is None
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_single_chart(
self, mock_db_session, mock_create_command, mcp_server
):
"""Test dashboard generation with a single chart."""
# Mock database query for single chart
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")]
mock_db_session.query.return_value = mock_query
mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
request = {
"chart_ids": [5],
"dashboard_title": "Single Chart Dashboard",
"published": False,
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is None
assert result.data.dashboard.chart_count == 1
assert result.data.dashboard.published is True # From mock
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_many_charts(
self, mock_db_session, mock_create_command, mcp_server
):
"""Test dashboard generation with many charts (grid layout)."""
# Mock 6 charts
chart_ids = list(range(1, 7))
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [
_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids
]
mock_db_session.query.return_value = mock_query
mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is None
assert result.data.dashboard.chart_count == 6
# Verify CreateDashboardCommand was called with proper layout
mock_create_command.assert_called_once()
call_args = mock_create_command.call_args[0][0]
# Check position_json contains proper layout
position_json = json.loads(call_args["position_json"])
assert "ROOT_ID" in position_json
assert "GRID_ID" in position_json
assert "DASHBOARD_VERSION_KEY" in position_json
assert position_json["DASHBOARD_VERSION_KEY"] == "v2"
# ROOT should only contain GRID
assert position_json["ROOT_ID"]["children"] == ["GRID_ID"]
# GRID should contain rows (6 charts = 3 rows in 2-chart layout)
grid_children = position_json["GRID_ID"]["children"]
assert len(grid_children) == 3
# Check each chart has proper structure
for i, chart_id in enumerate(chart_ids):
chart_key = f"CHART-{chart_id}"
row_index = i // 2 # 2 charts per row
row_key = f"ROW-{row_index}"
# Chart should exist
assert chart_key in position_json
chart_data = position_json[chart_key]
assert chart_data["type"] == "CHART"
assert "meta" in chart_data
assert chart_data["meta"]["chartId"] == chart_id
# Row should exist and contain charts (up to 2 per row)
assert row_key in position_json
row_data = position_json[row_key]
assert row_data["type"] == "ROW"
assert chart_key in row_data["children"]
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_creation_failure(
self, mock_db_session, mock_create_command, mcp_server
):
"""Test error handling when dashboard creation fails."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=1)]
mock_db_session.query.return_value = mock_query
mock_create_command.return_value.run.side_effect = Exception("Creation failed")
request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is not None
assert "Failed to create dashboard" in result.data.error
assert result.data.dashboard is None
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_minimal_request(
self, mock_db_session, mock_create_command, mcp_server
):
"""Test dashboard generation with minimal required parameters."""
# Mock database query for single chart
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=3)]
mock_db_session.query.return_value = mock_query
mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
request = {
"chart_ids": [3],
"dashboard_title": "Minimal Dashboard",
# No description, published defaults to True
}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.data.error is None
assert result.data.dashboard.dashboard_title == "Minimal Dashboard"
# Check that description was not included in call
call_args = mock_create_command.call_args[0][0]
assert call_args["published"] is True # Default value
assert (
"description" not in call_args or call_args.get("description") is None
)
class TestAddChartToExistingDashboard:
"""Tests for add_chart_to_existing_dashboard MCP tool."""
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_to_dashboard_basic(
self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
):
"""Test adding a chart to an existing dashboard."""
# Mock existing dashboard with some charts
mock_dashboard = _mock_dashboard(id=1, title="Existing Dashboard")
mock_dashboard.slices = [Mock(id=10), Mock(id=20)] # Existing charts
mock_dashboard.position_json = json.dumps(
{
"ROOT_ID": {
"children": ["CHART-10", "CHART-20"],
"id": "ROOT_ID",
"type": "ROOT",
},
"CHART-10": {"id": "CHART-10", "type": "CHART", "parents": ["ROOT_ID"]},
"CHART-10_POSITION": {"h": 16, "w": 24, "x": 0, "y": 0},
"CHART-20": {"id": "CHART-20", "type": "CHART", "parents": ["ROOT_ID"]},
"CHART-20_POSITION": {"h": 16, "w": 24, "x": 24, "y": 0},
}
)
mock_find_dashboard.return_value = mock_dashboard
# Mock chart to add
mock_chart = _mock_chart(id=30, slice_name="New Chart")
mock_db_session.get.return_value = mock_chart
# Mock updated dashboard
updated_dashboard = _mock_dashboard(id=1, title="Existing Dashboard")
updated_dashboard.slices = [Mock(id=10), Mock(id=20), Mock(id=30)]
mock_update_command.return_value.run.return_value = updated_dashboard
request = {"dashboard_id": 1, "chart_id": 30}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.data.error is None
assert result.data.dashboard is not None
assert result.data.dashboard.chart_count == 3
assert result.data.position is not None
assert "row" in result.data.position # Should have row info
assert "chart_key" in result.data.position
assert "/superset/dashboard/1/" in result.data.dashboard_url
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_add_chart_dashboard_not_found(self, mock_find_dashboard, mcp_server):
"""Test error when dashboard doesn't exist."""
mock_find_dashboard.return_value = None
request = {"dashboard_id": 999, "chart_id": 1}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.data.error is not None
assert "Dashboard with ID 999 not found" in result.data.error
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_chart_not_found(
self, mock_db_session, mock_find_dashboard, mcp_server
):
"""Test error when chart doesn't exist."""
mock_find_dashboard.return_value = _mock_dashboard()
mock_db_session.get.return_value = None
request = {"dashboard_id": 1, "chart_id": 999}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.data.error is not None
assert "Chart with ID 999 not found" in result.data.error
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_already_in_dashboard(
self, mock_db_session, mock_find_dashboard, mcp_server
):
"""Test error when chart is already in dashboard."""
mock_dashboard = _mock_dashboard()
mock_dashboard.slices = [Mock(id=5)] # Chart 5 already exists
mock_find_dashboard.return_value = mock_dashboard
mock_db_session.get.return_value = _mock_chart(id=5)
request = {"dashboard_id": 1, "chart_id": 5}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.data.error is not None
assert "Chart 5 is already in dashboard 1" in result.data.error
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_empty_dashboard(
self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
):
"""Test adding chart to dashboard with no existing layout."""
mock_dashboard = _mock_dashboard(id=2)
mock_dashboard.slices = []
mock_dashboard.position_json = "{}" # Empty layout
mock_find_dashboard.return_value = mock_dashboard
mock_chart = _mock_chart(id=15)
mock_db_session.get.return_value = mock_chart
updated_dashboard = _mock_dashboard(id=2)
updated_dashboard.slices = [Mock(id=15)]
mock_update_command.return_value.run.return_value = updated_dashboard
request = {"dashboard_id": 2, "chart_id": 15}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.data.error is None
assert "row" in result.data.position # Should have row info
assert result.data.position.get("row") == 0 # First row
# Verify update was called with proper layout structure
call_args = mock_update_command.call_args[0][1]
layout = json.loads(call_args["position_json"])
assert "ROOT_ID" in layout
assert "GRID_ID" in layout
assert "ROW-0" in layout
assert "CHART-15" in layout

View File

@@ -0,0 +1,573 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for MCP dashboard tools (list_dashboards, get_dashboard_info)
"""
import logging
from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from superset.mcp_service.app import mcp
from superset.mcp_service.dashboard.schemas import (
ListDashboardsRequest,
)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_basic(mock_list, mcp_server):
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Test Dashboard"
dashboard.slug = "test-dashboard"
dashboard.url = "/dashboard/1"
dashboard.published = True
dashboard.changed_by_name = "admin"
dashboard.changed_on = None
dashboard.changed_on_humanized = None
dashboard.created_by_name = "admin"
dashboard.created_on = None
dashboard.created_on_humanized = None
dashboard.tags = []
dashboard.owners = []
dashboard.slices = []
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = None
dashboard.position_json = None
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.uuid = "test-dashboard-uuid-1"
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
"dashboard_title": dashboard.dashboard_title,
"slug": dashboard.slug,
"url": dashboard.url,
"published": dashboard.published,
"changed_by_name": dashboard.changed_by_name,
"changed_on": dashboard.changed_on,
"changed_on_humanized": dashboard.changed_on_humanized,
"created_by_name": dashboard.created_by_name,
"created_on": dashboard.created_on,
"created_on_humanized": dashboard.created_on_humanized,
"tags": dashboard.tags,
"owners": dashboard.owners,
"charts": [],
}
mock_list.return_value = ([dashboard], 1)
async with Client(mcp_server) as client:
request = ListDashboardsRequest(page=1, page_size=10)
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
dashboards = result.data.dashboards
assert len(dashboards) == 1
assert dashboards[0].dashboard_title == "Test Dashboard"
assert dashboards[0].uuid == "test-dashboard-uuid-1"
assert dashboards[0].slug == "test-dashboard"
assert dashboards[0].published is True
# Verify UUID and slug are in default columns
assert "uuid" in result.data.columns_requested
assert "slug" in result.data.columns_requested
assert "uuid" in result.data.columns_loaded
assert "slug" in result.data.columns_loaded
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_filters(mock_list, mcp_server):
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Filtered Dashboard"
dashboard.slug = "filtered-dashboard"
dashboard.url = "/dashboard/2"
dashboard.published = True
dashboard.changed_by_name = "admin"
dashboard.changed_on = None
dashboard.changed_on_humanized = None
dashboard.created_by_name = "admin"
dashboard.created_on = None
dashboard.created_on_humanized = None
dashboard.tags = []
dashboard.owners = []
dashboard.slices = []
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = None
dashboard.position_json = None
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
"dashboard_title": dashboard.dashboard_title,
"slug": dashboard.slug,
"url": dashboard.url,
"published": dashboard.published,
"changed_by_name": dashboard.changed_by_name,
"changed_on": dashboard.changed_on,
"changed_on_humanized": dashboard.changed_on_humanized,
"created_by_name": dashboard.created_by_name,
"created_on": dashboard.created_on,
"created_on_humanized": dashboard.created_on_humanized,
"tags": dashboard.tags,
"owners": dashboard.owners,
"charts": [],
}
mock_list.return_value = ([dashboard], 1)
async with Client(mcp_server) as client:
filters = [
{"col": "dashboard_title", "opr": "sw", "value": "Sales"},
{"col": "published", "opr": "eq", "value": True},
]
request = ListDashboardsRequest(
filters=filters,
select_columns=["id", "dashboard_title"],
order_column="changed_on",
order_direction="desc",
page=1,
page_size=50,
)
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
assert result.data.count == 1
assert result.data.dashboards[0].dashboard_title == "Filtered Dashboard"
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_string_filters(mock_list, mcp_server):
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client: # noqa: F841
filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]'
# Test that string filters cause validation error at schema level
with pytest.raises(ValueError, match="validation error"):
ListDashboardsRequest(filters=filters) # noqa: F841
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_api_error(mock_list, mcp_server):
mock_list.side_effect = ToolError("API request failed")
async with Client(mcp_server) as client:
with pytest.raises(ToolError) as excinfo: # noqa: PT012
request = ListDashboardsRequest()
await client.call_tool("list_dashboards", {"request": request.model_dump()})
assert "API request failed" in str(excinfo.value)
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_search(mock_list, mcp_server):
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "search_dashboard"
dashboard.slug = "search-dashboard"
dashboard.url = "/dashboard/1"
dashboard.published = True
dashboard.changed_by_name = "admin"
dashboard.changed_on = None
dashboard.changed_on_humanized = None
dashboard.created_by_name = "admin"
dashboard.created_on = None
dashboard.created_on_humanized = None
dashboard.tags = []
dashboard.owners = []
dashboard.slices = []
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = None
dashboard.position_json = None
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.uuid = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
"dashboard_title": dashboard.dashboard_title,
"slug": dashboard.slug,
"url": dashboard.url,
"published": dashboard.published,
"changed_by_name": dashboard.changed_by_name,
"changed_on": dashboard.changed_on,
"changed_on_humanized": dashboard.changed_on_humanized,
"created_by_name": dashboard.created_by_name,
"created_on": dashboard.created_on,
"created_on_humanized": dashboard.created_on_humanized,
"tags": dashboard.tags,
"owners": dashboard.owners,
"charts": [],
}
mock_list.return_value = ([dashboard], 1)
async with Client(mcp_server) as client:
request = ListDashboardsRequest(search="search_dashboard")
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
assert result.data.count == 1
assert result.data.dashboards[0].dashboard_title == "search_dashboard"
args, kwargs = mock_list.call_args
assert kwargs["search"] == "search_dashboard"
assert "dashboard_title" in kwargs["search_columns"]
assert "slug" in kwargs["search_columns"]
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_simple_filters(mock_list, mcp_server):
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
filters = [
{"col": "dashboard_title", "opr": "eq", "value": "Sales"},
{"col": "published", "opr": "eq", "value": True},
]
request = ListDashboardsRequest(filters=filters)
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
assert hasattr(result.data, "count")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_success(mock_info, mcp_server):
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Test Dashboard"
dashboard.slug = "test-dashboard"
dashboard.description = "Test description"
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = None
dashboard.position_json = None
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_by = None
dashboard.changed_by = None
dashboard.uuid = None
dashboard.url = "/dashboard/1"
dashboard.thumbnail_url = None
dashboard.created_on_humanized = None
dashboard.changed_on_humanized = None
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
"dashboard_title": dashboard.dashboard_title,
"slug": dashboard.slug,
"url": dashboard.url,
"published": dashboard.published,
"changed_by_name": dashboard.changed_by_name,
"changed_on": dashboard.changed_on,
"changed_on_humanized": dashboard.changed_on_humanized,
"created_by_name": dashboard.created_by_name,
"created_on": dashboard.created_on,
"created_on_humanized": dashboard.created_on_humanized,
"tags": dashboard.tags,
"owners": dashboard.owners,
"charts": [],
}
mock_info.return_value = dashboard # Only the dashboard object
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 1}}
)
assert result.data["dashboard_title"] == "Test Dashboard"
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_not_found(mock_info, mcp_server):
mock_info.return_value = None # Not found returns None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 999}}
)
assert result.data["error_type"] == "not_found"
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_access_denied(mock_info, mcp_server):
mock_info.return_value = None # Access denied returns None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": 1}}
)
assert result.data["error_type"] == "not_found"
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
"""Test getting dashboard info using UUID identifier."""
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Test Dashboard UUID"
dashboard.slug = "test-dashboard-uuid"
dashboard.description = "Test description"
dashboard.css = ""
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.position_json = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_by = None
dashboard.changed_by = None
dashboard.uuid = "c3d4e5f6-g7h8-9012-cdef-gh3456789012"
dashboard.url = "/dashboard/1"
dashboard.thumbnail_url = None
dashboard.created_on_humanized = "2 days ago"
dashboard.changed_on_humanized = "1 day ago"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
uuid_str = "c3d4e5f6-g7h8-9012-cdef-gh3456789012"
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": uuid_str}}
)
assert result.data["dashboard_title"] == "Test Dashboard UUID"
@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
"""Test getting dashboard info using slug identifier."""
dashboard = Mock()
dashboard.id = 2
dashboard.dashboard_title = "Test Dashboard Slug"
dashboard.slug = "test-dashboard-slug"
dashboard.description = "Test description"
dashboard.css = ""
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = "{}"
dashboard.position_json = "{}"
dashboard.published = True
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.created_on = None
dashboard.changed_on = None
dashboard.created_by = None
dashboard.changed_by = None
dashboard.uuid = "d4e5f6g7-h8i9-0123-defg-hi4567890123"
dashboard.url = "/dashboard/2"
dashboard.thumbnail_url = None
dashboard.created_on_humanized = "2 days ago"
dashboard.changed_on_humanized = "1 day ago"
dashboard.slices = []
dashboard.owners = []
dashboard.tags = []
dashboard.roles = []
mock_find_object.return_value = dashboard
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_info", {"request": {"identifier": "test-dashboard-slug"}}
)
assert result.data["dashboard_title"] == "Test Dashboard Slug"
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
"""Test that custom column selection includes UUID and slug when explicitly
requested."""
dashboard = Mock()
dashboard.id = 1
dashboard.dashboard_title = "Custom Columns Dashboard"
dashboard.slug = "custom-dashboard"
dashboard.uuid = "test-custom-uuid-123"
dashboard.url = "/dashboard/1"
dashboard.published = True
dashboard.changed_by_name = "admin"
dashboard.changed_on = None
dashboard.changed_on_humanized = None
dashboard.created_by_name = "admin"
dashboard.created_on = None
dashboard.created_on_humanized = None
dashboard.tags = []
dashboard.owners = []
dashboard.slices = []
dashboard.description = None
dashboard.css = None
dashboard.certified_by = None
dashboard.certification_details = None
dashboard.json_metadata = None
dashboard.position_json = None
dashboard.is_managed_externally = False
dashboard.external_url = None
dashboard.thumbnail_url = None
dashboard.roles = []
dashboard.charts = []
dashboard._mapping = {
"id": dashboard.id,
"dashboard_title": dashboard.dashboard_title,
"slug": dashboard.slug,
"uuid": dashboard.uuid,
"url": dashboard.url,
"published": dashboard.published,
"changed_by_name": dashboard.changed_by_name,
"changed_on": dashboard.changed_on,
"changed_on_humanized": dashboard.changed_on_humanized,
"created_by_name": dashboard.created_by_name,
"created_on": dashboard.created_on,
"created_on_humanized": dashboard.created_on_humanized,
"tags": dashboard.tags,
"owners": dashboard.owners,
"charts": [],
}
mock_list.return_value = ([dashboard], 1)
async with Client(mcp_server) as client:
request = ListDashboardsRequest(
select_columns=["id", "dashboard_title", "uuid", "slug"],
page=1,
page_size=10,
)
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
dashboards = result.data.dashboards
assert len(dashboards) == 1
assert dashboards[0].uuid == "test-custom-uuid-123"
assert dashboards[0].slug == "custom-dashboard"
# Verify custom columns include UUID and slug
assert "uuid" in result.data.columns_requested
assert "slug" in result.data.columns_requested
assert "uuid" in result.data.columns_loaded
assert "slug" in result.data.columns_loaded
class TestDashboardSortableColumns:
"""Test sortable columns configuration for dashboard tools."""
def test_dashboard_sortable_columns_definition(self):
"""Test that dashboard sortable columns are properly defined."""
from superset.mcp_service.dashboard.tool.list_dashboards import (
SORTABLE_DASHBOARD_COLUMNS,
)
assert SORTABLE_DASHBOARD_COLUMNS == [
"id",
"dashboard_title",
"slug",
"published",
"changed_on",
"created_on",
]
# Ensure no computed properties are included
assert "changed_on_delta_humanized" not in SORTABLE_DASHBOARD_COLUMNS
assert "changed_by_name" not in SORTABLE_DASHBOARD_COLUMNS
assert "uuid" not in SORTABLE_DASHBOARD_COLUMNS
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_valid_order_column(self, mock_list, mcp_server):
"""Test list_dashboards with valid order column."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
# Test with valid sortable column
request = ListDashboardsRequest(
order_column="dashboard_title", order_direction="desc"
)
result = await client.call_tool(
"list_dashboards", {"request": request.model_dump()}
)
# Verify the DAO was called with the correct order column
mock_list.assert_called_once()
call_args = mock_list.call_args[1]
assert call_args["order_column"] == "dashboard_title"
assert call_args["order_direction"] == "desc"
# Verify the result
assert result.data.count == 0
assert result.data.dashboards == []
def test_sortable_columns_in_docstring(self):
"""Test that sortable columns are documented in tool docstring."""
from superset.mcp_service.dashboard.tool.list_dashboards import (
list_dashboards,
SORTABLE_DASHBOARD_COLUMNS,
)
# Check list_dashboards docstring (stored in description after @mcp.tool)
assert hasattr(list_dashboards, "description")
assert "Sortable columns for order_column:" in list_dashboards.description
for col in SORTABLE_DASHBOARD_COLUMNS:
assert col in list_dashboards.description

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,580 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Comprehensive unit tests for MCP generate_explore_link tool
"""
import logging
from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.chart.schemas import (
AxisConfig,
ColumnRef,
FilterConfig,
GenerateExploreLinkRequest,
LegendConfig,
TableChartConfig,
XYChartConfig,
)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _mock_dataset(id: int = 1) -> Mock:
"""Create a mock dataset object."""
dataset = Mock()
dataset.id = id
return dataset
class TestGenerateExploreLink:
"""Comprehensive tests for generate_explore_link MCP tool."""
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_table_explore_link_minimal(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for minimal table chart."""
mock_create_form_data.return_value = "test_form_data_key_123"
mock_find_dataset.return_value = _mock_dataset(id=1)
config = TableChartConfig(
chart_type="table", columns=[ColumnRef(name="region")]
)
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=test_form_data_key_123"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_table_explore_link_with_features(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for table chart with features."""
mock_create_form_data.return_value = "comprehensive_key_456"
mock_find_dataset.return_value = _mock_dataset(id=5)
config = TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="region", label="Sales Region"),
ColumnRef(name="revenue", label="Total Revenue", aggregate="SUM"),
ColumnRef(name="orders", label="Order Count", aggregate="COUNT"),
],
filters=[
FilterConfig(column="year", op="=", value="2024"),
FilterConfig(column="status", op="!=", value="cancelled"),
],
sort_by=["revenue", "orders"],
)
request = GenerateExploreLinkRequest(dataset_id="5", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=comprehensive_key_456"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_line_chart_explore_link(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for line chart."""
mock_create_form_data.return_value = "line_chart_key_789"
mock_find_dataset.return_value = _mock_dataset(id=3)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date", label="Date"),
y=[
ColumnRef(name="sales", label="Daily Sales", aggregate="SUM"),
ColumnRef(name="orders", label="Order Count", aggregate="COUNT"),
],
kind="line",
group_by=ColumnRef(name="region", label="Sales Region"),
x_axis=AxisConfig(title="Time Period", format="smart_date"),
y_axis=AxisConfig(title="Sales Metrics", format="$,.2f"),
legend=LegendConfig(show=True, position="bottom"),
)
request = GenerateExploreLinkRequest(dataset_id="3", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=line_chart_key_789"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_bar_chart_explore_link(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for bar chart."""
mock_create_form_data.return_value = "bar_chart_key_abc"
mock_find_dataset.return_value = _mock_dataset(id=7)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="product_category", label="Category"),
y=[ColumnRef(name="revenue", label="Revenue", aggregate="SUM")],
kind="bar",
group_by=ColumnRef(name="quarter", label="Quarter"),
y_axis=AxisConfig(title="Revenue ($)", format="$,.0f"),
)
request = GenerateExploreLinkRequest(dataset_id="7", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=bar_chart_key_abc"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_area_chart_explore_link(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for area chart."""
mock_create_form_data.return_value = "area_chart_key_def"
mock_find_dataset.return_value = _mock_dataset(id=2)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="month", label="Month"),
y=[
ColumnRef(
name="cumulative_sales", label="Cumulative Sales", aggregate="SUM"
)
],
kind="area",
legend=LegendConfig(show=False),
)
request = GenerateExploreLinkRequest(dataset_id="2", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=area_chart_key_def"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_scatter_chart_explore_link(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link for scatter chart."""
mock_create_form_data.return_value = "scatter_chart_key_ghi"
mock_find_dataset.return_value = _mock_dataset(id=4)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="price", label="Unit Price"),
y=[ColumnRef(name="quantity", label="Quantity Sold", aggregate="SUM")],
kind="scatter",
group_by=ColumnRef(name="product_type", label="Product Type"),
x_axis=AxisConfig(title="Price ($)", format="$,.2f"),
y_axis=AxisConfig(title="Quantity", format=",.0f"),
)
request = GenerateExploreLinkRequest(dataset_id="4", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=scatter_chart_key_ghi"
)
mock_create_form_data.assert_called_once()
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_cache_failure_fallback(
self, mock_create_form_data, mcp_server
):
"""Test fallback when form_data cache creation fails."""
mock_create_form_data.side_effect = Exception("Cache storage failed")
config = TableChartConfig(
chart_type="table", columns=[ColumnRef(name="test_col")]
)
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
# Should fallback to basic URL format
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=1"
)
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_database_lock_fallback(
self, mock_create_form_data, mcp_server
):
"""Test fallback when database is locked."""
from sqlalchemy.exc import OperationalError
mock_create_form_data.side_effect = OperationalError(
"database is locked", None, None
)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="date"),
y=[ColumnRef(name="sales")],
kind="line",
)
request = GenerateExploreLinkRequest(dataset_id="5", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
# Should fallback to basic dataset URL
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?datasource_type=table&datasource_id=5"
)
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_with_many_columns(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link with many columns."""
mock_create_form_data.return_value = "many_columns_key"
mock_find_dataset.return_value = _mock_dataset(id=1)
# Create 15 columns
columns = [
ColumnRef(
name=f"metric_{i}",
label=f"Metric {i}",
aggregate="SUM" if i % 2 == 0 else "COUNT",
)
for i in range(15)
]
config = TableChartConfig(chart_type="table", columns=columns)
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=many_columns_key"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_with_many_filters(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test generating explore link with many filters."""
mock_create_form_data.return_value = "many_filters_key"
mock_find_dataset.return_value = _mock_dataset(id=1)
# Create 12 filters
filters = [
FilterConfig(
column=f"filter_col_{i}",
op="=" if i % 3 == 0 else "!=",
value=f"value_{i}",
)
for i in range(12)
]
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x_col"),
y=[ColumnRef(name="y_col")],
kind="bar",
filters=filters,
)
request = GenerateExploreLinkRequest(dataset_id="1", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=many_filters_key"
)
mock_create_form_data.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_explore_link_url_format_consistency(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test that all generated URLs follow consistent format."""
mock_create_form_data.return_value = "consistency_test_key"
mock_find_dataset.return_value = _mock_dataset(id=1)
configs = [
TableChartConfig(chart_type="table", columns=[ColumnRef(name="col1")]),
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
kind="line",
),
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
kind="bar",
),
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
kind="area",
),
XYChartConfig(
chart_type="xy",
x=ColumnRef(name="x"),
y=[ColumnRef(name="y")],
kind="scatter",
),
]
for i, config in enumerate(configs):
request = GenerateExploreLinkRequest(dataset_id=str(i + 1), config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
# All URLs should follow the same format
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=consistency_test_key"
)
assert result.data["error"] is None
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_dataset_id_types(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test explore link generation with different dataset_id formats."""
mock_create_form_data.return_value = "dataset_test_key"
mock_find_dataset.return_value = _mock_dataset(id=1)
config = TableChartConfig(
chart_type="table", columns=[ColumnRef(name="test_col")]
)
# Test various dataset_id formats
dataset_ids = ["1", "42", "999", "123456789"]
for dataset_id in dataset_ids:
request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=dataset_test_key"
)
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_generate_explore_link_complex_configuration(
self, mock_create_form_data, mock_find_dataset, mcp_server
):
"""Test explore link generation with complex chart configuration."""
mock_create_form_data.return_value = "complex_config_key"
mock_find_dataset.return_value = _mock_dataset(id=10)
config = XYChartConfig(
chart_type="xy",
x=ColumnRef(name="timestamp", label="Time"),
y=[
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
ColumnRef(name="orders", label="Orders", aggregate="COUNT"),
ColumnRef(name="profit", label="Profit", aggregate="AVG"),
],
kind="line",
group_by=ColumnRef(name="region", label="Region"),
x_axis=AxisConfig(title="Time Period", format="smart_date"),
y_axis=AxisConfig(title="Metrics", format="$,.2f", scale="linear"),
legend=LegendConfig(show=True, position="bottom"),
filters=[
FilterConfig(column="status", op="=", value="active"),
FilterConfig(column="date", op=">=", value="2024-01-01"),
FilterConfig(column="revenue", op=">", value="1000"),
],
)
request = GenerateExploreLinkRequest(dataset_id="10", config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
assert result.data["error"] is None
assert (
result.data["url"]
== "http://localhost:9001/explore/?form_data_key=complex_config_key"
)
mock_create_form_data.assert_called_once()
@patch(
"superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
)
@pytest.mark.asyncio
async def test_fallback_url_different_datasets(
self, mock_create_form_data, mcp_server
):
"""Test fallback URLs are correct for different dataset IDs."""
mock_create_form_data.side_effect = Exception(
"Always fail for fallback testing"
)
config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="col")])
dataset_ids = ["1", "5", "100", "999"]
for dataset_id in dataset_ids:
request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config)
async with Client(mcp_server) as client:
result = await client.call_tool(
"generate_explore_link", {"request": request.model_dump()}
)
# Should fallback to basic URL with correct dataset_id
expected_url = f"http://localhost:9001/explore/?datasource_type=table&datasource_id={dataset_id}"
assert result.data["error"] is None
assert result.data["url"] == expected_url

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Helper function to extract row data from MCP responses.
The MCP client seems to wrap dict rows in Root objects.
This helper handles the extraction properly.
"""
def extract_row_data(row):
"""Extract dictionary data from a row object."""
# Handle different possible formats
if isinstance(row, dict):
return row
# Check for Pydantic Root object
if hasattr(row, "__root__"):
return row.__root__
# Check if it's a Pydantic model with model_dump
if hasattr(row, "model_dump"):
return row.model_dump()
# Try to access __dict__ directly
if hasattr(row, "__dict__"):
# Filter out private attributes
return {k: v for k, v in row.__dict__.items() if not k.startswith("_")}
# Last resort - convert to string and parse
# This is for the Root object issue
row_str = str(row)
if row_str == "Root()":
# Empty Root object - the actual data might be elsewhere
# Let's check all attributes
attrs = dir(row)
for attr in attrs:
if not attr.startswith("_") and attr not in [
"model_dump",
"model_validate",
]:
try:
val = getattr(row, attr)
if isinstance(val, dict):
return val
except AttributeError:
pass
raise ValueError(f"Cannot extract data from row of type {type(row)}: {row}")

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for execute_sql MCP tool
"""
import logging
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from superset.mcp_service.app import mcp
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _mock_database(
id: int = 1,
database_name: str = "test_db",
allow_dml: bool = False,
) -> Mock:
"""Create a mock database object."""
database = Mock()
database.id = id
database.database_name = database_name
database.allow_dml = allow_dml
# Mock raw connection context manager
mock_cursor = Mock()
mock_cursor.description = [
("id", "INTEGER", None, None, None, None, False),
("name", "VARCHAR", None, None, None, None, True),
]
mock_cursor.fetchmany.return_value = [(1, "test_name")]
mock_cursor.rowcount = 1
mock_conn = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.commit = Mock()
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_conn
mock_context.__exit__.return_value = None
database.get_raw_connection.return_value = mock_context
return database
class TestExecuteSql:
"""Tests for execute_sql MCP tool."""
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_basic_select(
self, mock_db, mock_security_manager, mcp_server
):
"""Test basic SELECT query execution."""
# Setup mocks
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT id, name FROM users LIMIT 10",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 1
assert len(result.data.rows) == 1
assert result.data.rows[0]["id"] == 1
assert result.data.rows[0]["name"] == "test_name"
assert len(result.data.columns) == 2
assert result.data.columns[0].name == "id"
assert result.data.columns[0].type == "INTEGER"
assert result.data.execution_time > 0
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_parameters(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with parameter substitution."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
"parameters": {"table": "orders", "status": "active", "limit": "5"},
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify parameter substitution happened
mock_database.get_raw_connection.assert_called_once()
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
# Check that the SQL was formatted with parameters
executed_sql = cursor.execute.call_args[0][0]
assert "orders" in executed_sql
assert "active" in executed_sql
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_database_not_found(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when database is not found."""
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
None
)
request = {
"database_id": 999,
"sql": "SELECT 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Database with ID 999 not found" in result.data.error
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
assert result.data.rows is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_access_denied(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when user lacks database access."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
# Use Mock instead of AsyncMock for synchronous call
from unittest.mock import Mock
mock_security_manager.can_access_database = Mock(return_value=False)
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Access denied to database" in result.data.error
assert result.data.error_type == "SECURITY_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_not_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when DML operations are not allowed."""
mock_database = _mock_database(allow_dml=False)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert result.data.error_type == "DML_NOT_ALLOWED"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test successful DML execution when allowed."""
mock_database = _mock_database(allow_dml=True)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock cursor for DML operation
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.rowcount = 3 # 3 rows affected
request = {
"database_id": 1,
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.affected_rows == 3
assert result.data.rows == [] # Empty rows for DML
assert result.data.row_count == 0
# Verify commit was called
(
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
)
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_results(
self, mock_db, mock_security_manager, mcp_server
):
"""Test query that returns no results."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock empty results
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.fetchmany.return_value = []
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 999999",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 0
assert len(result.data.rows) == 0
assert len(result.data.columns) == 2 # Column metadata still returned
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_missing_parameter(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when required parameter is missing."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
"parameters": {"table_name": "users"}, # Missing user_id
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "user_id" in result.data.error # Error contains parameter name
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_parameters_with_placeholders(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when empty parameters dict is provided but SQL has
placeholders."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} LIMIT 5",
"parameters": {}, # Empty dict but SQL has {table_name}
"limit": 5,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Missing parameter: table_name" in result.data.error
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_schema(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with schema specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT COUNT(*) as total FROM orders",
"schema": "sales",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify schema was passed to get_raw_connection
# Verify schema was passed
call_args = mock_database.get_raw_connection.call_args
assert call_args[1]["schema"] == "sales"
assert call_args[1]["catalog"] is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_limit_enforcement(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that LIMIT is added to SELECT queries without one."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM users", # No LIMIT
"limit": 50,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
# Verify LIMIT was added
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
executed_sql = cursor.execute.call_args[0][0]
assert "LIMIT 50" in executed_sql
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_sql_injection_prevention(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that SQL injection attempts are handled safely."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock execute to raise an exception
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.execute.side_effect = Exception("Syntax error")
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Syntax error" in result.data.error # Contains actual error
assert result.data.error_type == "EXECUTION_ERROR"
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
"""Test validation of empty SQL query."""
request = {
"database_id": 1,
"sql": " ", # Empty/whitespace only
"limit": 10,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="SQL query cannot be empty"):
await client.call_tool("execute_sql", {"request": request})
@pytest.mark.asyncio
async def test_execute_sql_invalid_limit(self, mcp_server):
"""Test validation of invalid limit values."""
# Test limit too low
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 0,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="minimum of 1"):
await client.call_tool("execute_sql", {"request": request})
# Test limit too high
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 20000,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="maximum of 10000"):
await client.call_tool("execute_sql", {"request": request})