mirror of
https://github.com/apache/superset.git
synced 2026-04-28 04:25:07 +00:00
feat(mcp): add a preview flow to mcp chart updates (#39383)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
(cherry picked from commit 69f062b804)
This commit is contained in:
committed by
Michael S. Molina
parent
02e6d671b4
commit
4e9a6db9b3
@@ -1386,9 +1386,21 @@ class UpdateChartRequest(QueryCacheControl):
|
||||
chart_name: str | None = Field(
|
||||
None, description="Auto-generates if omitted", max_length=255
|
||||
)
|
||||
generate_preview: bool = True
|
||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
generate_preview: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"When True (default), returns a preview explore URL so the user "
|
||||
"can review changes before saving. When False, persists the "
|
||||
"update immediately."
|
||||
),
|
||||
)
|
||||
preview_formats: list[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
default_factory=lambda: ["url"],
|
||||
description=(
|
||||
"Extra preview formats to render after saving. Only used when "
|
||||
"generate_preview=False. When generate_preview=True, the preview "
|
||||
"is always an explore URL."
|
||||
),
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
|
||||
@@ -22,6 +22,7 @@ MCP tool: update_chart
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -65,6 +66,32 @@ def _find_chart(identifier: int | str) -> Any | None:
|
||||
return ChartDAO.find_by_id(identifier, id_column="uuid")
|
||||
|
||||
|
||||
def _validation_error_response(message: str, details: str) -> GenerateChartResponse:
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"error": {
|
||||
"error_type": "ValidationError",
|
||||
"message": message,
|
||||
"details": details,
|
||||
},
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _missing_config_or_name_error() -> GenerateChartResponse:
|
||||
return _validation_error_response(
|
||||
message="Either 'config' or 'chart_name' must be provided.",
|
||||
details=(
|
||||
"Either 'config' or 'chart_name' must be provided. "
|
||||
"Use config for visualization changes, chart_name for renaming."
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _build_update_payload(
|
||||
request: UpdateChartRequest,
|
||||
chart: Any,
|
||||
@@ -94,26 +121,94 @@ def _build_update_payload(
|
||||
|
||||
# Name-only update: keep existing visualization, just rename
|
||||
if not request.chart_name:
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"error": {
|
||||
"error_type": "ValidationError",
|
||||
"message": ("Either 'config' or 'chart_name' must be provided."),
|
||||
"details": (
|
||||
"Either 'config' or 'chart_name' must be provided. "
|
||||
"Use config for visualization changes, chart_name "
|
||||
"for renaming."
|
||||
),
|
||||
},
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
)
|
||||
return _missing_config_or_name_error()
|
||||
return {"slice_name": request.chart_name}
|
||||
|
||||
|
||||
def _build_preview_form_data(
|
||||
request: UpdateChartRequest,
|
||||
chart: Any,
|
||||
) -> dict[str, Any] | GenerateChartResponse:
|
||||
"""Merge the existing chart's form_data with the requested changes.
|
||||
|
||||
Used by the preview-first flow so the user can review edits in Explore
|
||||
before clicking Save. Returns the merged form_data dict on success, or a
|
||||
GenerateChartResponse error when neither config nor chart_name is given.
|
||||
"""
|
||||
existing_form_data: dict[str, Any] = {}
|
||||
if getattr(chart, "params", None):
|
||||
try:
|
||||
existing_form_data = json.loads(chart.params) or {}
|
||||
except (ValueError, TypeError):
|
||||
logger.warning(
|
||||
"Failed to parse existing chart.params for chart %s", chart.id
|
||||
)
|
||||
existing_form_data = {}
|
||||
|
||||
if request.config is not None:
|
||||
config = parse_chart_config(request.config)
|
||||
dataset_id = chart.datasource_id if chart.datasource_id else None
|
||||
new_form_data = map_config_to_form_data(config, dataset_id=dataset_id)
|
||||
new_form_data.pop("_mcp_warnings", None)
|
||||
merged = {**existing_form_data, **new_form_data}
|
||||
else:
|
||||
if not request.chart_name:
|
||||
return _missing_config_or_name_error()
|
||||
merged = dict(existing_form_data)
|
||||
|
||||
if request.chart_name:
|
||||
merged["slice_name"] = request.chart_name
|
||||
elif chart.slice_name:
|
||||
merged["slice_name"] = chart.slice_name
|
||||
|
||||
merged["slice_id"] = chart.id
|
||||
if chart.datasource_id:
|
||||
merged["datasource"] = f"{chart.datasource_id}__table"
|
||||
|
||||
return merged
|
||||
|
||||
|
||||
def _create_preview_url(
|
||||
chart: Any, form_data: dict[str, Any]
|
||||
) -> tuple[str, str | None, list[str]]:
|
||||
"""Cache form_data and return (explore_url, form_data_key, warnings).
|
||||
|
||||
The URL includes both ``slice_id`` and ``form_data_key`` so that when the
|
||||
user clicks Save in Explore, the edits overwrite the original chart.
|
||||
"""
|
||||
from superset.commands.explore.form_data.parameters import CommandParameters
|
||||
from superset.mcp_service.commands.create_form_data import MCPCreateFormDataCommand
|
||||
from superset.utils.core import DatasourceType
|
||||
|
||||
base_url = get_superset_base_url()
|
||||
|
||||
if not chart.datasource_id:
|
||||
warning = (
|
||||
"Chart has no datasource; the preview URL shows the saved chart "
|
||||
"state, not the pending changes. Open the URL and apply the "
|
||||
"changes manually."
|
||||
)
|
||||
logger.warning(
|
||||
"Chart %s has no datasource_id; preview URL cannot embed "
|
||||
"form_data — user will see saved state.",
|
||||
chart.id,
|
||||
)
|
||||
return f"{base_url}/explore/?slice_id={chart.id}", None, [warning]
|
||||
|
||||
cmd_params = CommandParameters(
|
||||
datasource_type=DatasourceType.TABLE,
|
||||
datasource_id=chart.datasource_id,
|
||||
chart_id=chart.id,
|
||||
tab_id=None,
|
||||
form_data=json.dumps(form_data),
|
||||
)
|
||||
form_data_key = MCPCreateFormDataCommand(cmd_params).run()
|
||||
explore_url = (
|
||||
f"{base_url}/explore/?form_data_key={form_data_key}&slice_id={chart.id}"
|
||||
)
|
||||
return explore_url, form_data_key, []
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["mutate"],
|
||||
class_permission_name="Chart",
|
||||
@@ -128,13 +223,16 @@ async def update_chart( # noqa: C901
|
||||
) -> GenerateChartResponse:
|
||||
"""Update existing chart with new configuration.
|
||||
|
||||
IMPORTANT:
|
||||
- Chart must already be saved (from generate_chart with save_chart=True)
|
||||
- LLM clients MUST display updated chart URL to users
|
||||
- Use numeric ID or UUID string to identify the chart (NOT chart name)
|
||||
- MUST include chart_type in config (either 'xy' or 'table')
|
||||
IMPORTANT BEHAVIOR:
|
||||
- By default (generate_preview=True), a preview explore URL is returned
|
||||
so the user can review changes and click Save to overwrite the original
|
||||
chart (if they have permission).
|
||||
- Set generate_preview=False to persist the update immediately.
|
||||
- LLM clients MUST display the returned explore URL to users.
|
||||
- Use numeric ID or UUID string to identify the chart (NOT chart name).
|
||||
- MUST include chart_type in config (either 'xy' or 'table').
|
||||
|
||||
Example usage:
|
||||
Example usage (preview, default):
|
||||
```json
|
||||
{
|
||||
"identifier": 123,
|
||||
@@ -147,17 +245,12 @@ async def update_chart( # noqa: C901
|
||||
}
|
||||
```
|
||||
|
||||
Or with UUID:
|
||||
Example usage (persist immediately):
|
||||
```json
|
||||
{
|
||||
"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "product_name"},
|
||||
{"name": "revenue", "aggregate": "SUM"}
|
||||
]
|
||||
}
|
||||
"identifier": 123,
|
||||
"generate_preview": false,
|
||||
"config": {"chart_type": "table", "columns": [{"name": "region"}]}
|
||||
}
|
||||
```
|
||||
|
||||
@@ -218,25 +311,43 @@ async def update_chart( # noqa: C901
|
||||
}
|
||||
)
|
||||
|
||||
# Build update payload (config update or name-only rename)
|
||||
from superset.commands.chart.update import UpdateChartCommand
|
||||
updated_chart: Any = None
|
||||
explore_url: str
|
||||
form_data_key: str | None = None
|
||||
warnings: list[str] = []
|
||||
saved = False
|
||||
|
||||
payload_or_error = _build_update_payload(request, chart)
|
||||
if isinstance(payload_or_error, GenerateChartResponse):
|
||||
return payload_or_error
|
||||
if not request.generate_preview:
|
||||
from superset.commands.chart.update import UpdateChartCommand
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart.db_write"):
|
||||
command = UpdateChartCommand(chart.id, payload_or_error)
|
||||
updated_chart = command.run()
|
||||
payload_or_error = _build_update_payload(request, chart)
|
||||
if isinstance(payload_or_error, GenerateChartResponse):
|
||||
return payload_or_error
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart.db_write"):
|
||||
command = UpdateChartCommand(chart.id, payload_or_error)
|
||||
updated_chart = command.run()
|
||||
saved = True
|
||||
explore_url = (
|
||||
f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}"
|
||||
)
|
||||
else:
|
||||
preview_or_error = _build_preview_form_data(request, chart)
|
||||
if isinstance(preview_or_error, GenerateChartResponse):
|
||||
return preview_or_error
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart.preview_link"):
|
||||
explore_url, form_data_key, warnings = _create_preview_url(
|
||||
chart, preview_or_error
|
||||
)
|
||||
|
||||
# Parse config for analysis (may be None for name-only updates)
|
||||
config = parse_chart_config(request.config) if request.config else None
|
||||
|
||||
# Generate semantic analysis
|
||||
capabilities = analyze_chart_capabilities(updated_chart, config)
|
||||
semantics = analyze_chart_semantics(updated_chart, config)
|
||||
chart_for_analysis = updated_chart if saved else chart
|
||||
capabilities = analyze_chart_capabilities(chart_for_analysis, config)
|
||||
semantics = analyze_chart_semantics(chart_for_analysis, config)
|
||||
|
||||
# Create performance metadata
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
performance = PerformanceMetadata(
|
||||
query_duration_ms=execution_time,
|
||||
@@ -244,21 +355,29 @@ async def update_chart( # noqa: C901
|
||||
optimization_suggestions=[],
|
||||
)
|
||||
|
||||
# Create accessibility metadata
|
||||
chart_name = (
|
||||
updated_chart.slice_name
|
||||
if updated_chart and hasattr(updated_chart, "slice_name")
|
||||
else (generate_chart_name(config) if config else "Updated chart")
|
||||
if saved and updated_chart and hasattr(updated_chart, "slice_name")
|
||||
else (
|
||||
request.chart_name
|
||||
or (chart.slice_name if hasattr(chart, "slice_name") else None)
|
||||
or (generate_chart_name(config) if config else "Updated chart")
|
||||
)
|
||||
)
|
||||
accessibility = AccessibilityMetadata(
|
||||
color_blind_safe=True, # Would need actual analysis
|
||||
alt_text=f"Updated chart showing {chart_name}",
|
||||
color_blind_safe=True,
|
||||
alt_text=(
|
||||
f"Updated chart showing {chart_name}"
|
||||
if saved
|
||||
else f"Updated chart preview showing {chart_name}"
|
||||
),
|
||||
high_contrast_available=False,
|
||||
)
|
||||
|
||||
# Generate previews if requested
|
||||
previews = {}
|
||||
if request.generate_preview:
|
||||
# Generate previews for saved charts only. Unsaved previews rely on
|
||||
# the explore URL for interactive viewing.
|
||||
previews: dict[str, Any] = {}
|
||||
if saved and updated_chart and request.preview_formats:
|
||||
try:
|
||||
with event_logger.log_context(action="mcp.update_chart.preview"):
|
||||
from superset.mcp_service.chart.tool.get_chart_preview import (
|
||||
@@ -279,35 +398,44 @@ async def update_chart( # noqa: C901
|
||||
previews[format_type] = preview_result.content
|
||||
|
||||
except Exception as e:
|
||||
# Log warning but don't fail the entire request
|
||||
logger.warning("Preview generation failed: %s", e)
|
||||
|
||||
# Return enhanced data
|
||||
# Fallback: extract form_data_key from explore_url if not set
|
||||
if form_data_key is None and explore_url and "form_data_key=" in explore_url:
|
||||
parsed = urlparse(explore_url)
|
||||
values = parse_qs(parsed.query).get("form_data_key")
|
||||
if values:
|
||||
form_data_key = values[0]
|
||||
|
||||
chart_id = updated_chart.id if saved and updated_chart else chart.id
|
||||
chart_uuid = (
|
||||
str(updated_chart.uuid)
|
||||
if saved and updated_chart and updated_chart.uuid
|
||||
else (str(chart.uuid) if chart.uuid else None)
|
||||
)
|
||||
viz_type = updated_chart.viz_type if saved and updated_chart else chart.viz_type
|
||||
|
||||
result = {
|
||||
"chart": {
|
||||
"id": updated_chart.id,
|
||||
"slice_name": updated_chart.slice_name,
|
||||
"viz_type": updated_chart.viz_type,
|
||||
"url": (
|
||||
f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}"
|
||||
),
|
||||
"uuid": str(updated_chart.uuid) if updated_chart.uuid else None,
|
||||
"updated": True,
|
||||
"id": chart_id,
|
||||
"slice_name": chart_name,
|
||||
"viz_type": viz_type,
|
||||
"url": explore_url,
|
||||
"uuid": chart_uuid,
|
||||
"form_data_key": form_data_key,
|
||||
"is_unsaved_state": not saved,
|
||||
},
|
||||
"error": None,
|
||||
# Enhanced fields for better LLM integration
|
||||
"warnings": warnings,
|
||||
"previews": previews,
|
||||
"capabilities": capabilities.model_dump() if capabilities else None,
|
||||
"semantics": semantics.model_dump() if semantics else None,
|
||||
"explore_url": (
|
||||
f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}"
|
||||
),
|
||||
"explore_url": explore_url,
|
||||
"form_data_key": form_data_key,
|
||||
"api_endpoints": {
|
||||
"data": (
|
||||
f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/data/"
|
||||
),
|
||||
"data": f"{get_superset_base_url()}/api/v1/chart/{chart_id}/data/",
|
||||
"export": (
|
||||
f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/export/"
|
||||
f"{get_superset_base_url()}/api/v1/chart/{chart_id}/export/"
|
||||
),
|
||||
},
|
||||
"performance": performance.model_dump() if performance else None,
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
Unit tests for update_chart MCP tool
|
||||
"""
|
||||
|
||||
import importlib
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
@@ -37,10 +38,18 @@ from superset.mcp_service.chart.schemas import (
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.chart.tool.update_chart import (
|
||||
_build_preview_form_data,
|
||||
_build_update_payload,
|
||||
_find_chart,
|
||||
)
|
||||
|
||||
# The __init__.py re-exports the update_chart *function*, so a plain
|
||||
# `from ... import update_chart` gives the function, not the module.
|
||||
# Use importlib to get the module for patch.object().
|
||||
update_chart_module = importlib.import_module(
|
||||
"superset.mcp_service.chart.tool.update_chart"
|
||||
)
|
||||
|
||||
|
||||
class TestUpdateChart:
|
||||
"""Tests for update_chart MCP tool."""
|
||||
@@ -105,33 +114,28 @@ class TestUpdateChart:
|
||||
assert request2.chart_name == "Updated Sales Report"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_generation(self):
|
||||
"""Test preview generation options in update request."""
|
||||
async def test_update_chart_preview_formats(self):
|
||||
"""Test preview_formats options in update request."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
|
||||
# Default preview generation
|
||||
# Default preview formats
|
||||
request1 = UpdateChartRequest(identifier=123, config=config)
|
||||
assert request1.generate_preview is True
|
||||
assert request1.preview_formats == ["url"]
|
||||
|
||||
# Custom preview formats
|
||||
request2 = UpdateChartRequest(
|
||||
identifier=123,
|
||||
config=config,
|
||||
generate_preview=True,
|
||||
preview_formats=["url", "ascii", "table"],
|
||||
)
|
||||
assert request2.generate_preview is True
|
||||
assert set(request2.preview_formats) == {"url", "ascii", "table"}
|
||||
|
||||
# Disable preview generation
|
||||
request3 = UpdateChartRequest(
|
||||
identifier=123, config=config, generate_preview=False
|
||||
)
|
||||
assert request3.generate_preview is False
|
||||
# Empty preview formats (no extra previews after save)
|
||||
request3 = UpdateChartRequest(identifier=123, config=config, preview_formats=[])
|
||||
assert request3.preview_formats == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_identifier_types(self):
|
||||
@@ -685,6 +689,7 @@ class TestUpdateChartNameOnly:
|
||||
|
||||
assert result.structured_content["success"] is True
|
||||
assert result.structured_content["chart"]["slice_name"] == "Renamed Chart"
|
||||
assert result.structured_content["chart"]["is_unsaved_state"] is False
|
||||
|
||||
# Verify UpdateChartCommand was called with name-only payload
|
||||
mock_update_cmd_cls.assert_called_once_with(
|
||||
@@ -729,3 +734,391 @@ class TestUpdateChartNameOnly:
|
||||
assert error["error_type"] == "ValidationError"
|
||||
assert "config" in error["message"].lower()
|
||||
assert "chart_name" in error["message"].lower()
|
||||
|
||||
|
||||
class TestUpdateChartPreviewFirst:
|
||||
"""Integration-style tests for the preview-first default flow."""
|
||||
|
||||
@patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
|
||||
@patch(
|
||||
"superset.commands.chart.update.UpdateChartCommand",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch(
|
||||
"superset.mcp_service.auth.check_chart_data_access",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_generates_preview_without_saving(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_check_access,
|
||||
mock_update_cmd_cls,
|
||||
mock_create_preview,
|
||||
mcp_server,
|
||||
):
|
||||
"""Default update flow returns a preview URL and does NOT save."""
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 1
|
||||
mock_chart.datasource_id = 10
|
||||
mock_chart.slice_name = "Existing Chart"
|
||||
mock_chart.viz_type = "table"
|
||||
mock_chart.uuid = "abc-123"
|
||||
mock_chart.params = '{"viz_type": "table", "datasource": "10__table"}'
|
||||
mock_find_by_id.return_value = mock_chart
|
||||
|
||||
mock_check_access.return_value = DatasetValidationResult(
|
||||
is_valid=True,
|
||||
dataset_id=10,
|
||||
dataset_name="my_dataset",
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
preview_url = (
|
||||
"http://localhost:8088/explore/?form_data_key=preview_key&slice_id=1"
|
||||
)
|
||||
mock_create_preview.return_value = (preview_url, "preview_key", [])
|
||||
|
||||
request = {
|
||||
"identifier": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "col1"}],
|
||||
},
|
||||
}
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool("update_chart", {"request": request})
|
||||
|
||||
assert result.structured_content["success"] is True
|
||||
assert result.structured_content["chart"]["is_unsaved_state"] is True
|
||||
assert result.structured_content["chart"]["id"] == 1
|
||||
assert result.structured_content["chart"]["form_data_key"] == "preview_key"
|
||||
assert result.structured_content["explore_url"] == preview_url
|
||||
assert result.structured_content["form_data_key"] == "preview_key"
|
||||
|
||||
# Ensure the chart was NOT persisted
|
||||
mock_update_cmd_cls.assert_not_called()
|
||||
mock_create_preview.assert_called_once()
|
||||
|
||||
@patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
|
||||
@patch(
|
||||
"superset.mcp_service.auth.check_chart_data_access",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_missing_config_and_name_returns_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_check_access,
|
||||
mock_create_preview,
|
||||
mcp_server,
|
||||
):
|
||||
"""Preview flow also errors when neither config nor chart_name given."""
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 1
|
||||
mock_chart.datasource_id = 10
|
||||
mock_chart.params = "{}"
|
||||
mock_find_by_id.return_value = mock_chart
|
||||
|
||||
mock_check_access.return_value = DatasetValidationResult(
|
||||
is_valid=True,
|
||||
dataset_id=10,
|
||||
dataset_name="my_dataset",
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool(
|
||||
"update_chart", {"request": {"identifier": 1}}
|
||||
)
|
||||
|
||||
assert result.structured_content["success"] is False
|
||||
error = result.structured_content["error"]
|
||||
assert error["error_type"] == "ValidationError"
|
||||
mock_create_preview.assert_not_called()
|
||||
|
||||
|
||||
class TestBuildPreviewFormData:
|
||||
"""Unit tests for _build_preview_form_data helper."""
|
||||
|
||||
def test_merges_existing_params_with_new_config(self):
|
||||
"""New config values override existing form_data keys."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="region")],
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
chart = Mock()
|
||||
chart.id = 42
|
||||
chart.datasource_id = 7
|
||||
chart.slice_name = "Existing"
|
||||
chart.params = '{"viz_type": "line", "custom_flag": true}'
|
||||
|
||||
result = _build_preview_form_data(request, chart)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
# Existing keys not touched by the new config are preserved
|
||||
assert result["custom_flag"] is True
|
||||
# New config overrides existing keys
|
||||
assert result["viz_type"] == "table"
|
||||
# slice_id and datasource are always stamped onto the preview
|
||||
assert result["slice_id"] == 42
|
||||
assert result["datasource"] == "7__table"
|
||||
assert result["slice_name"] == "Existing"
|
||||
|
||||
def test_name_only_preview_keeps_existing_form_data(self):
|
||||
"""Name-only preview preserves existing form_data and renames."""
|
||||
request = UpdateChartRequest(identifier=1, chart_name="Brand New Name")
|
||||
chart = Mock()
|
||||
chart.id = 5
|
||||
chart.datasource_id = 3
|
||||
chart.slice_name = "Old"
|
||||
chart.params = '{"viz_type": "big_number", "metric": "count"}'
|
||||
|
||||
result = _build_preview_form_data(request, chart)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["viz_type"] == "big_number"
|
||||
assert result["metric"] == "count"
|
||||
assert result["slice_name"] == "Brand New Name"
|
||||
assert result["slice_id"] == 5
|
||||
|
||||
def test_missing_config_and_name_returns_validation_error(self):
|
||||
"""Matches the _build_update_payload validation behavior."""
|
||||
request = UpdateChartRequest(identifier=1)
|
||||
chart = Mock()
|
||||
chart.id = 1
|
||||
chart.datasource_id = 10
|
||||
chart.params = "{}"
|
||||
|
||||
result = _build_preview_form_data(request, chart)
|
||||
|
||||
assert isinstance(result, GenerateChartResponse)
|
||||
assert result.success is False
|
||||
assert result.error is not None
|
||||
assert result.error.error_type == "ValidationError"
|
||||
|
||||
def test_handles_invalid_existing_params(self):
|
||||
"""Gracefully recovers when chart.params is not valid JSON."""
|
||||
config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[ColumnRef(name="col1")],
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
chart = Mock()
|
||||
chart.id = 9
|
||||
chart.datasource_id = 4
|
||||
chart.slice_name = "Broken"
|
||||
chart.params = "not-json"
|
||||
|
||||
result = _build_preview_form_data(request, chart)
|
||||
|
||||
assert isinstance(result, dict)
|
||||
assert result["slice_id"] == 9
|
||||
assert result["slice_name"] == "Broken"
|
||||
|
||||
|
||||
class TestUpdateChartSaveWithConfig:
|
||||
"""Save-path integration tests for update_chart with a full config payload."""
|
||||
|
||||
@patch(
|
||||
"superset.commands.chart.update.UpdateChartCommand",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch(
|
||||
"superset.mcp_service.auth.check_chart_data_access",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_save_chart_with_config_success(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_check_access,
|
||||
mock_update_cmd_cls,
|
||||
mcp_server,
|
||||
):
|
||||
"""generate_preview=False with config persists and returns saved chart."""
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 77
|
||||
mock_chart.datasource_id = 10
|
||||
mock_chart.slice_name = "Pre-save"
|
||||
mock_chart.viz_type = "table"
|
||||
mock_chart.uuid = "uuid-77"
|
||||
mock_chart.params = '{"viz_type": "table"}'
|
||||
mock_find_by_id.return_value = mock_chart
|
||||
|
||||
mock_check_access.return_value = DatasetValidationResult(
|
||||
is_valid=True,
|
||||
dataset_id=10,
|
||||
dataset_name="my_dataset",
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
updated_chart = Mock()
|
||||
updated_chart.id = 77
|
||||
updated_chart.slice_name = "After-save"
|
||||
updated_chart.viz_type = "table"
|
||||
updated_chart.uuid = "uuid-77"
|
||||
mock_update_cmd_cls.return_value.run.return_value = updated_chart
|
||||
|
||||
request = {
|
||||
"identifier": 77,
|
||||
"generate_preview": False,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "col1"}],
|
||||
},
|
||||
}
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool("update_chart", {"request": request})
|
||||
|
||||
assert result.structured_content["success"] is True
|
||||
chart = result.structured_content["chart"]
|
||||
assert chart["is_unsaved_state"] is False
|
||||
assert chart["id"] == 77
|
||||
assert chart["slice_name"] == "After-save"
|
||||
assert "slice_id=77" in result.structured_content["explore_url"]
|
||||
mock_update_cmd_cls.assert_called_once()
|
||||
|
||||
|
||||
class TestUpdateChartErrorPaths:
|
||||
"""Integration tests for error-handling branches in update_chart."""
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_chart_not_found_returns_notfound_error(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mcp_server,
|
||||
):
|
||||
"""Missing chart returns a structured NotFound error without raising."""
|
||||
mock_find_by_id.return_value = None
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool(
|
||||
"update_chart", {"request": {"identifier": 9999}}
|
||||
)
|
||||
|
||||
assert result.structured_content["success"] is False
|
||||
error = result.structured_content["error"]
|
||||
assert error["error_type"] == "NotFound"
|
||||
assert "9999" in error["message"]
|
||||
|
||||
@patch(
|
||||
"superset.commands.chart.update.UpdateChartCommand",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch(
|
||||
"superset.mcp_service.auth.check_chart_data_access",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_command_exception_is_caught(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_check_access,
|
||||
mock_update_cmd_cls,
|
||||
mcp_server,
|
||||
):
|
||||
"""CommandException from UpdateChartCommand.run() is captured and returned."""
|
||||
from superset.commands.exceptions import CommandException
|
||||
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 5
|
||||
mock_chart.datasource_id = 10
|
||||
mock_chart.slice_name = "Name"
|
||||
mock_chart.viz_type = "table"
|
||||
mock_chart.uuid = "uuid-5"
|
||||
mock_chart.params = "{}"
|
||||
mock_find_by_id.return_value = mock_chart
|
||||
|
||||
mock_check_access.return_value = DatasetValidationResult(
|
||||
is_valid=True,
|
||||
dataset_id=10,
|
||||
dataset_name="my_dataset",
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
mock_update_cmd_cls.return_value.run.side_effect = CommandException("boom")
|
||||
|
||||
request = {
|
||||
"identifier": 5,
|
||||
"generate_preview": False,
|
||||
"chart_name": "Retry",
|
||||
}
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool("update_chart", {"request": request})
|
||||
|
||||
assert result.structured_content["success"] is False
|
||||
error = result.structured_content["error"]
|
||||
assert error["error_type"] == "CommandException"
|
||||
assert "boom" in error["details"]
|
||||
|
||||
@patch.object(update_chart_module, "_create_preview_url", new_callable=Mock)
|
||||
@patch(
|
||||
"superset.mcp_service.auth.check_chart_data_access",
|
||||
new_callable=Mock,
|
||||
)
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock)
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_preview_extracts_form_data_key_from_url_fallback(
|
||||
self,
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_check_access,
|
||||
mock_create_preview,
|
||||
mcp_server,
|
||||
):
|
||||
"""If _create_preview_url returns (url, None), form_data_key comes from url."""
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 8
|
||||
mock_chart.datasource_id = 10
|
||||
mock_chart.slice_name = "Chart"
|
||||
mock_chart.viz_type = "table"
|
||||
mock_chart.uuid = "uuid-8"
|
||||
mock_chart.params = '{"viz_type": "table"}'
|
||||
mock_find_by_id.return_value = mock_chart
|
||||
|
||||
mock_check_access.return_value = DatasetValidationResult(
|
||||
is_valid=True,
|
||||
dataset_id=10,
|
||||
dataset_name="my_dataset",
|
||||
warnings=[],
|
||||
)
|
||||
|
||||
preview_url = (
|
||||
"http://localhost:8088/explore/?form_data_key=url_embedded_key&slice_id=8"
|
||||
)
|
||||
mock_create_preview.return_value = (preview_url, None, [])
|
||||
|
||||
request = {
|
||||
"identifier": 8,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "col1"}],
|
||||
},
|
||||
}
|
||||
|
||||
async with Client(mcp) as client:
|
||||
result = await client.call_tool("update_chart", {"request": request})
|
||||
|
||||
assert result.structured_content["success"] is True
|
||||
assert result.structured_content["form_data_key"] == "url_embedded_key"
|
||||
|
||||
Reference in New Issue
Block a user