fix(mcp): warn on invalid chart preview form data key (#39891)

Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Richard Fogaca Nienkotter
2026-05-05 16:40:00 -03:00
committed by GitHub
parent cb53745d43
commit 9459bc7bf4
2 changed files with 191 additions and 8 deletions

View File

@@ -51,9 +51,16 @@ from superset.utils import json as utils_json
logger = logging.getLogger(__name__)
INVALID_FORM_DATA_KEY_WARNING = (
"Previous cached chart state could not be loaded from the previous "
"form_data_key. The preview was generated from the supplied "
"configuration only; the previous form_data_key may be invalid or "
"expired."
)
def _get_old_adhoc_filters(form_data_key: str) -> list[Dict[str, Any]] | None:
"""Retrieve adhoc_filters from the previously cached form_data."""
def _get_previous_form_data(form_data_key: str) -> dict[str, Any] | None:
"""Retrieve the previously cached form_data."""
from superset.commands.exceptions import CommandException
from superset.commands.explore.form_data.get import GetFormDataCommand
from superset.commands.explore.form_data.parameters import CommandParameters
@@ -65,11 +72,9 @@ def _get_old_adhoc_filters(form_data_key: str) -> list[Dict[str, Any]] | None:
if isinstance(cached_data, str):
cached_data = utils_json.loads(cached_data)
if isinstance(cached_data, dict):
adhoc_filters = cached_data.get("adhoc_filters")
if adhoc_filters:
return adhoc_filters
return cached_data
except (KeyError, ValueError, TypeError, CommandException):
logger.debug("Could not retrieve old form_data for filter preservation")
logger.debug("Could not retrieve previous form_data from cache")
return None
@@ -113,11 +118,18 @@ def update_chart_preview(
config, dataset_id=request.dataset_id
)
new_form_data.pop("_mcp_warnings", None)
warnings: list[str] = []
previous_form_data: dict[str, Any] | None = None
if request.form_data_key:
previous_form_data = _get_previous_form_data(request.form_data_key)
if previous_form_data is None:
warnings.append(INVALID_FORM_DATA_KEY_WARNING)
# Preserve adhoc filters from the previous cached form_data
# when the new config doesn't explicitly specify filters
if getattr(config, "filters", None) is None and request.form_data_key:
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
if getattr(config, "filters", None) is None and previous_form_data:
old_adhoc_filters = previous_form_data.get("adhoc_filters")
if old_adhoc_filters:
new_form_data["adhoc_filters"] = old_adhoc_filters
@@ -183,6 +195,7 @@ def update_chart_preview(
"explore_url": explore_url,
"form_data_key": new_form_data_key,
"previous_form_data_key": request.form_data_key, # For reference
"warnings": warnings,
"api_endpoints": {}, # No API endpoints for unsaved charts
"performance": performance.model_dump() if performance else None,
"accessibility": accessibility.model_dump() if accessibility else None,

View File

@@ -19,6 +19,9 @@
Unit tests for update_chart_preview MCP tool
"""
import importlib
from unittest.mock import Mock, patch
import pytest
from superset.mcp_service.chart.schemas import (
@@ -31,6 +34,10 @@ from superset.mcp_service.chart.schemas import (
XYChartConfig,
)
update_chart_preview_module = importlib.import_module(
"superset.mcp_service.chart.tool.update_chart_preview"
)
class TestUpdateChartPreview:
"""Tests for update_chart_preview MCP tool."""
@@ -218,6 +225,7 @@ class TestUpdateChartPreview:
"explore_url",
"form_data_key",
"previous_form_data_key",
"warnings",
"api_endpoints",
"performance",
"accessibility",
@@ -472,3 +480,165 @@ class TestUpdateChartPreview:
)
assert request.config.sort_by == ["sales", "profit"]
assert len(request.config.columns) == 3
@patch("superset.commands.explore.form_data.get.GetFormDataCommand")
def test_get_previous_form_data_parses_json_cache_hit(
self,
mock_get_form_data_command,
) -> None:
"""Previous form_data lookup parses JSON strings from the cache."""
cached_adhoc_filters = [
{
"clause": "WHERE",
"comparator": "North",
"expressionType": "SIMPLE",
"operator": "==",
"subject": "region",
}
]
mock_get_form_data_command.return_value.run.return_value = (
'{"adhoc_filters": ['
'{"clause": "WHERE", "comparator": "North", '
'"expressionType": "SIMPLE", "operator": "==", "subject": "region"}'
'], "viz_type": "table"}'
)
result = update_chart_preview_module._get_previous_form_data("valid_key_12345")
assert result == {
"adhoc_filters": cached_adhoc_filters,
"viz_type": "table",
}
command_params = mock_get_form_data_command.call_args.args[0]
assert command_params.key == "valid_key_12345"
@patch("superset.commands.explore.form_data.get.GetFormDataCommand")
def test_get_previous_form_data_returns_none_for_cache_failure(
self,
mock_get_form_data_command,
) -> None:
"""Previous form_data lookup treats command failures as cache misses."""
mock_get_form_data_command.return_value.run.side_effect = (
update_chart_preview_module.CommandException("cache read failed")
)
result = update_chart_preview_module._get_previous_form_data(
"missing_key_12345"
)
assert result is None
@patch.object(update_chart_preview_module, "analyze_chart_semantics")
@patch.object(update_chart_preview_module, "analyze_chart_capabilities")
@patch.object(update_chart_preview_module, "generate_explore_link")
@patch.object(update_chart_preview_module, "_get_previous_form_data")
@patch("superset.mcp_service.auth.get_user_from_request")
@pytest.mark.asyncio
async def test_warns_when_previous_form_data_key_is_missing(
self,
mock_get_user_from_request,
mock_get_previous_form_data,
mock_generate_explore_link,
mock_analyze_chart_capabilities,
mock_analyze_chart_semantics,
) -> None:
"""Invalid previous form_data_key is warning-only for preview updates."""
mock_user = Mock()
mock_user.id = 1
mock_get_user_from_request.return_value = mock_user
mock_get_previous_form_data.return_value = None
mock_generate_explore_link.return_value = (
"http://localhost:8088/explore/?form_data_key=new_preview_key"
)
mock_analyze_chart_capabilities.return_value = None
mock_analyze_chart_semantics.return_value = None
request = UpdateChartPreviewRequest(
form_data_key="nonexistent_key_12345",
dataset_id=3,
config=TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="country", label="Country"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
sort_by=["sales"],
),
generate_preview=True,
preview_formats=["table"],
)
result = update_chart_preview_module.update_chart_preview(
request=request, ctx=Mock()
)
assert result["success"] is True
assert result["error"] is None
assert result["previous_form_data_key"] == "nonexistent_key_12345"
assert result["form_data_key"] == "new_preview_key"
assert result["warnings"] == [
update_chart_preview_module.INVALID_FORM_DATA_KEY_WARNING
]
mock_get_previous_form_data.assert_called_once_with("nonexistent_key_12345")
@patch.object(update_chart_preview_module, "analyze_chart_semantics")
@patch.object(update_chart_preview_module, "analyze_chart_capabilities")
@patch.object(update_chart_preview_module, "generate_explore_link")
@patch.object(update_chart_preview_module, "_get_previous_form_data")
@patch("superset.mcp_service.auth.get_user_from_request")
@pytest.mark.asyncio
async def test_preserves_previous_adhoc_filters_without_warning(
self,
mock_get_user_from_request,
mock_get_previous_form_data,
mock_generate_explore_link,
mock_analyze_chart_capabilities,
mock_analyze_chart_semantics,
) -> None:
"""Valid previous form_data preserves filters without a cache warning."""
mock_user = Mock()
mock_user.id = 1
mock_get_user_from_request.return_value = mock_user
cached_adhoc_filters = [
{
"clause": "WHERE",
"comparator": "North",
"expressionType": "SIMPLE",
"operator": "==",
"subject": "region",
}
]
mock_get_previous_form_data.return_value = {
"adhoc_filters": cached_adhoc_filters
}
mock_generate_explore_link.return_value = (
"http://localhost:8088/explore/?form_data_key=new_preview_key"
)
mock_analyze_chart_capabilities.return_value = None
mock_analyze_chart_semantics.return_value = None
request = UpdateChartPreviewRequest(
form_data_key="valid_key_12345",
dataset_id=3,
config=TableChartConfig(
chart_type="table",
columns=[
ColumnRef(name="country", label="Country"),
ColumnRef(name="sales", label="Sales", aggregate="SUM"),
],
sort_by=["sales"],
),
generate_preview=True,
preview_formats=["table"],
)
result = update_chart_preview_module.update_chart_preview(
request=request, ctx=Mock()
)
generated_form_data = mock_generate_explore_link.call_args.args[1]
assert generated_form_data["adhoc_filters"] == cached_adhoc_filters
assert result["success"] is True
assert result["error"] is None
assert result["warnings"] == []
mock_get_previous_form_data.assert_called_once_with("valid_key_12345")