feat(mcp): add compile check to validate chart queries before returning (#38408)

This commit is contained in:
Amin Ghadersohi
2026-03-06 02:10:58 -05:00
committed by GitHub
parent 7d2efd8c1a
commit 3609cd9544
2 changed files with 268 additions and 0 deletions

View File

@@ -20,6 +20,8 @@ MCP tool: generate_chart (simplified schema)
import logging
import time
from dataclasses import dataclass, field
from typing import Any, Dict, List
from urllib.parse import parse_qs, urlparse
from fastmcp import Context
@@ -48,6 +50,73 @@ from superset.utils import json
logger = logging.getLogger(__name__)
@dataclass
class CompileResult:
"""Result of a chart compile check (test query execution)."""
success: bool
error: str | None = None
warnings: List[str] = field(default_factory=list)
row_count: int | None = None
def _compile_chart(
    form_data: Dict[str, Any],
    dataset_id: int,
) -> CompileResult:
    """Execute the chart's query to verify it renders without errors.

    Builds a ``QueryContext`` from *form_data* and runs it through
    ``ChartDataCommand``. A small ``row_limit`` is used so the check is
    fast — we only need to know the query compiles and returns data, not
    fetch the full result set.

    :param form_data: Chart form data (metrics, filters, time range, ...).
    :param dataset_id: Numeric ID of the table dataset the chart queries.
    :returns: A :class:`CompileResult` with ``success=True`` when the
        query executes cleanly, otherwise ``success=False`` and the
        error message.
    """
    # Imported lazily so module import stays light and to avoid import cycles.
    from superset.commands.chart.data.get_data_command import ChartDataCommand
    from superset.commands.chart.exceptions import (
        ChartDataCacheLoadError,
        ChartDataQueryFailedError,
    )
    from superset.common.query_context_factory import QueryContextFactory
    from superset.mcp_service.chart.preview_utils import _build_query_columns

    try:
        columns = _build_query_columns(form_data)
        factory = QueryContextFactory()
        query_context = factory.create(
            datasource={"id": dataset_id, "type": "table"},
            queries=[
                {
                    "columns": columns,
                    "metrics": form_data.get("metrics", []),
                    "orderby": form_data.get("orderby", []),
                    # Tiny limit: we only need proof the query runs,
                    # not the full result set.
                    "row_limit": 2,
                    "filters": form_data.get("adhoc_filters", []),
                    "time_range": form_data.get("time_range", "No filter"),
                }
            ],
            form_data=form_data,
        )
        command = ChartDataCommand(query_context)
        result = command.run()

        warnings: list[str] = []
        row_count = 0
        for query in result.get("queries", []):
            # A per-query "error" key means the SQL failed even though
            # the command itself did not raise.
            if query.get("error"):
                return CompileResult(success=False, error=str(query["error"]))
            row_count += len(query.get("data", []))
        return CompileResult(success=True, warnings=warnings, row_count=row_count)
    except (
        # All failure modes map to the same result; the original two
        # except clauses had identical bodies, so they are merged here.
        ChartDataQueryFailedError,
        ChartDataCacheLoadError,
        CommandException,
        ValueError,
        KeyError,
    ) as exc:
        return CompileResult(success=False, error=str(exc))
@tool(tags=["mutate"])
@parse_request(GenerateChartRequest)
async def generate_chart( # noqa: C901
@@ -321,6 +390,62 @@ async def generate_chart( # noqa: C901
# Add any validation warnings (e.g., virtual dataset warnings)
response_warnings.extend(dataset_check.warnings)
# Compile check: execute the chart query to catch runtime errors
await ctx.report_progress(3, 5, "Running compile check (test query)")
with event_logger.log_context(
action="mcp.generate_chart.compile_check"
):
compile_result = _compile_chart(form_data, dataset.id)
if not compile_result.success:
# Query failed — delete the broken chart and return an error
logger.warning(
"Compile check failed for chart %s: %s",
chart.id,
compile_result.error,
)
await ctx.error(
"Chart compile check failed: error=%s" % (compile_result.error,)
)
from superset.daos.chart import ChartDAO
ChartDAO.delete([chart])
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
)
execution_time = int((time.time() - start_time) * 1000)
error = ChartGenerationError(
error_type="compile_error",
message=(
"Chart query failed to execute. The chart was not saved."
),
details=str(compile_result.error) or "",
suggestions=[
"Check that all columns exist in the dataset",
"Verify aggregate functions are compatible "
"with column types",
"Ensure filters reference valid columns",
"Try simplifying the chart configuration",
],
error_code="CHART_COMPILE_FAILED",
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"error": error.model_dump(),
"form_data": form_data,
"performance": {
"query_duration_ms": execution_time,
"cache_status": "error",
"optimization_suggestions": [],
},
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
)
response_warnings.extend(compile_result.warnings)
except CommandException as e:
logger.error("Chart creation failed: %s", e)
await ctx.error("Chart creation failed: error=%s" % (str(e),))
@@ -384,6 +509,75 @@ async def generate_chart( # noqa: C901
if form_data_key_list:
form_data_key = form_data_key_list[0]
# Compile check for preview-only mode
# Validate dataset existence and user access before running queries
await ctx.report_progress(3, 5, "Running compile check (test query)")
numeric_dataset_id: int | None = None
from superset.daos.dataset import DatasetDAO
if isinstance(request.dataset_id, int) or (
isinstance(request.dataset_id, str) and request.dataset_id.isdigit()
):
candidate_id = (
int(request.dataset_id)
if isinstance(request.dataset_id, str)
else request.dataset_id
)
ds = DatasetDAO.find_by_id(candidate_id)
if ds and has_dataset_access(ds):
numeric_dataset_id = ds.id
else:
ds = DatasetDAO.find_by_id(request.dataset_id, id_column="uuid")
if ds and has_dataset_access(ds):
numeric_dataset_id = ds.id
if numeric_dataset_id is not None:
with event_logger.log_context(
action="mcp.generate_chart.compile_check"
):
compile_result = _compile_chart(form_data, numeric_dataset_id)
if not compile_result.success:
await ctx.error(
"Chart compile check failed: error=%s" % (compile_result.error,)
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
)
execution_time = int((time.time() - start_time) * 1000)
error = ChartGenerationError(
error_type="compile_error",
message=(
"Chart query failed to execute. "
"The chart configuration is invalid."
),
details=str(compile_result.error) or "",
suggestions=[
"Check that all columns exist in the dataset",
"Verify aggregate functions are compatible "
"with column types",
"Ensure filters reference valid columns",
"Try simplifying the chart configuration",
],
error_code="CHART_COMPILE_FAILED",
)
return GenerateChartResponse.model_validate(
{
"chart": None,
"error": error.model_dump(),
"form_data": form_data,
"performance": {
"query_duration_ms": execution_time,
"cache_status": "error",
"optimization_suggestions": [],
},
"success": False,
"schema_version": "2.0",
"api_version": "v1",
}
)
response_warnings.extend(compile_result.warnings)
# Generate semantic analysis
capabilities = analyze_chart_capabilities(chart, request.config)
semantics = analyze_chart_semantics(chart, request.config)

View File

@@ -19,6 +19,8 @@
Unit tests for MCP generate_chart tool
"""
from unittest.mock import MagicMock, patch
import pytest
from superset.mcp_service.chart.schemas import (
@@ -30,6 +32,10 @@ from superset.mcp_service.chart.schemas import (
TableChartConfig,
XYChartConfig,
)
from superset.mcp_service.chart.tool.generate_chart import (
_compile_chart,
CompileResult,
)
class TestGenerateChart:
@@ -277,3 +283,71 @@ class TestGenerateChart:
# Hidden legend
legend = LegendConfig(show=False)
assert legend.show is False
class TestCompileChart:
    """Tests for _compile_chart helper."""

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_success(self, factory_cls, command_cls):
        """A cleanly executing query yields a successful CompileResult."""
        factory_cls.return_value.create.return_value = MagicMock()
        command_cls.return_value.run.return_value = {
            "queries": [{"data": [{"col": 1}, {"col": 2}]}]
        }

        outcome = _compile_chart(
            {
                "viz_type": "echarts_timeseries_bar",
                "metrics": [{"label": "count", "expressionType": "SIMPLE"}],
                "groupby": ["region"],
            },
            dataset_id=1,
        )

        assert isinstance(outcome, CompileResult)
        assert outcome.success is True
        assert outcome.error is None
        assert outcome.row_count == 2

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_query_error_in_payload(self, factory_cls, command_cls):
        """An error embedded in the query payload is surfaced as failure."""
        factory_cls.return_value.create.return_value = MagicMock()
        command_cls.return_value.run.return_value = {
            "queries": [{"error": "column 'bad_col' does not exist"}]
        }

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "bad_col" in (outcome.error or "")

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_command_exception(self, factory_cls, command_cls):
        """A ChartDataQueryFailedError raised by the command is caught."""
        from superset.commands.chart.exceptions import (
            ChartDataQueryFailedError,
        )

        factory_cls.return_value.create.return_value = MagicMock()
        failure = ChartDataQueryFailedError("syntax error near FROM")
        command_cls.return_value.run.side_effect = failure

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "syntax error" in (outcome.error or "")

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_value_error(self, factory_cls, command_cls):
        """A ValueError from a bad configuration is caught and reported."""
        factory_cls.return_value.create.side_effect = ValueError("invalid metric")

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "invalid metric" in (outcome.error or "")