mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(mcp): add compile check to validate chart queries before returning (#38408)
This commit is contained in:
@@ -20,6 +20,8 @@ MCP tool: generate_chart (simplified schema)
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Dict, List
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from fastmcp import Context
|
||||
@@ -48,6 +50,73 @@ from superset.utils import json
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CompileResult:
|
||||
"""Result of a chart compile check (test query execution)."""
|
||||
|
||||
success: bool
|
||||
error: str | None = None
|
||||
warnings: List[str] = field(default_factory=list)
|
||||
row_count: int | None = None
|
||||
|
||||
|
||||
def _compile_chart(
    form_data: Dict[str, Any],
    dataset_id: int,
) -> CompileResult:
    """Execute the chart's query to verify it renders without errors.

    Builds a ``QueryContext`` from *form_data* and runs it through
    ``ChartDataCommand``. A small ``row_limit`` is used so the check is
    fast — we only need to know the query compiles and returns data, not
    fetch the full result set.

    Args:
        form_data: Chart form data; ``metrics``, ``orderby``,
            ``adhoc_filters`` and ``time_range`` are read from it.
        dataset_id: Numeric id of the (table-type) dataset to query.

    Returns a :class:`CompileResult` with ``success=True`` when the
    query executes cleanly; on any failure the exception or embedded
    query error is stringified into ``CompileResult.error``.
    """
    # Imported lazily to keep module import light and avoid cycles.
    from superset.commands.chart.data.get_data_command import ChartDataCommand
    from superset.commands.chart.exceptions import (
        ChartDataCacheLoadError,
        ChartDataQueryFailedError,
    )
    from superset.common.query_context_factory import QueryContextFactory
    from superset.mcp_service.chart.preview_utils import _build_query_columns

    try:
        columns = _build_query_columns(form_data)
        query_context = QueryContextFactory().create(
            datasource={"id": dataset_id, "type": "table"},
            queries=[
                {
                    "columns": columns,
                    "metrics": form_data.get("metrics", []),
                    "orderby": form_data.get("orderby", []),
                    # Tiny limit: we only need proof the query runs.
                    "row_limit": 2,
                    "filters": form_data.get("adhoc_filters", []),
                    "time_range": form_data.get("time_range", "No filter"),
                }
            ],
            form_data=form_data,
        )

        result = ChartDataCommand(query_context).run()

        # Some failures are reported inside the result payload rather than
        # raised, so inspect each query entry for an embedded error.
        row_count = 0
        for query in result.get("queries", []):
            if query.get("error"):
                return CompileResult(success=False, error=str(query["error"]))
            row_count += len(query.get("data", []))

        return CompileResult(success=True, warnings=[], row_count=row_count)
    except (
        ChartDataQueryFailedError,
        ChartDataCacheLoadError,
        CommandException,
        ValueError,
        KeyError,
    ) as exc:
        # Both chart-data errors and generic command/config errors map to the
        # same failure result, so a single handler replaces the two duplicate
        # except blocks.
        return CompileResult(success=False, error=str(exc))
|
||||
|
||||
|
||||
@tool(tags=["mutate"])
|
||||
@parse_request(GenerateChartRequest)
|
||||
async def generate_chart( # noqa: C901
|
||||
@@ -321,6 +390,62 @@ async def generate_chart( # noqa: C901
|
||||
# Add any validation warnings (e.g., virtual dataset warnings)
|
||||
response_warnings.extend(dataset_check.warnings)
|
||||
|
||||
# Compile check: execute the chart query to catch runtime errors
|
||||
await ctx.report_progress(3, 5, "Running compile check (test query)")
|
||||
with event_logger.log_context(
|
||||
action="mcp.generate_chart.compile_check"
|
||||
):
|
||||
compile_result = _compile_chart(form_data, dataset.id)
|
||||
if not compile_result.success:
|
||||
# Query failed — delete the broken chart and return an error
|
||||
logger.warning(
|
||||
"Compile check failed for chart %s: %s",
|
||||
chart.id,
|
||||
compile_result.error,
|
||||
)
|
||||
await ctx.error(
|
||||
"Chart compile check failed: error=%s" % (compile_result.error,)
|
||||
)
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
ChartDAO.delete([chart])
|
||||
from superset.mcp_service.common.error_schemas import (
|
||||
ChartGenerationError,
|
||||
)
|
||||
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
error = ChartGenerationError(
|
||||
error_type="compile_error",
|
||||
message=(
|
||||
"Chart query failed to execute. The chart was not saved."
|
||||
),
|
||||
details=str(compile_result.error) or "",
|
||||
suggestions=[
|
||||
"Check that all columns exist in the dataset",
|
||||
"Verify aggregate functions are compatible "
|
||||
"with column types",
|
||||
"Ensure filters reference valid columns",
|
||||
"Try simplifying the chart configuration",
|
||||
],
|
||||
error_code="CHART_COMPILE_FAILED",
|
||||
)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"error": error.model_dump(),
|
||||
"form_data": form_data,
|
||||
"performance": {
|
||||
"query_duration_ms": execution_time,
|
||||
"cache_status": "error",
|
||||
"optimization_suggestions": [],
|
||||
},
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
)
|
||||
response_warnings.extend(compile_result.warnings)
|
||||
|
||||
except CommandException as e:
|
||||
logger.error("Chart creation failed: %s", e)
|
||||
await ctx.error("Chart creation failed: error=%s" % (str(e),))
|
||||
@@ -384,6 +509,75 @@ async def generate_chart( # noqa: C901
|
||||
if form_data_key_list:
|
||||
form_data_key = form_data_key_list[0]
|
||||
|
||||
# Compile check for preview-only mode
|
||||
# Validate dataset existence and user access before running queries
|
||||
await ctx.report_progress(3, 5, "Running compile check (test query)")
|
||||
numeric_dataset_id: int | None = None
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
if isinstance(request.dataset_id, int) or (
|
||||
isinstance(request.dataset_id, str) and request.dataset_id.isdigit()
|
||||
):
|
||||
candidate_id = (
|
||||
int(request.dataset_id)
|
||||
if isinstance(request.dataset_id, str)
|
||||
else request.dataset_id
|
||||
)
|
||||
ds = DatasetDAO.find_by_id(candidate_id)
|
||||
if ds and has_dataset_access(ds):
|
||||
numeric_dataset_id = ds.id
|
||||
else:
|
||||
ds = DatasetDAO.find_by_id(request.dataset_id, id_column="uuid")
|
||||
if ds and has_dataset_access(ds):
|
||||
numeric_dataset_id = ds.id
|
||||
|
||||
if numeric_dataset_id is not None:
|
||||
with event_logger.log_context(
|
||||
action="mcp.generate_chart.compile_check"
|
||||
):
|
||||
compile_result = _compile_chart(form_data, numeric_dataset_id)
|
||||
if not compile_result.success:
|
||||
await ctx.error(
|
||||
"Chart compile check failed: error=%s" % (compile_result.error,)
|
||||
)
|
||||
from superset.mcp_service.common.error_schemas import (
|
||||
ChartGenerationError,
|
||||
)
|
||||
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
error = ChartGenerationError(
|
||||
error_type="compile_error",
|
||||
message=(
|
||||
"Chart query failed to execute. "
|
||||
"The chart configuration is invalid."
|
||||
),
|
||||
details=str(compile_result.error) or "",
|
||||
suggestions=[
|
||||
"Check that all columns exist in the dataset",
|
||||
"Verify aggregate functions are compatible "
|
||||
"with column types",
|
||||
"Ensure filters reference valid columns",
|
||||
"Try simplifying the chart configuration",
|
||||
],
|
||||
error_code="CHART_COMPILE_FAILED",
|
||||
)
|
||||
return GenerateChartResponse.model_validate(
|
||||
{
|
||||
"chart": None,
|
||||
"error": error.model_dump(),
|
||||
"form_data": form_data,
|
||||
"performance": {
|
||||
"query_duration_ms": execution_time,
|
||||
"cache_status": "error",
|
||||
"optimization_suggestions": [],
|
||||
},
|
||||
"success": False,
|
||||
"schema_version": "2.0",
|
||||
"api_version": "v1",
|
||||
}
|
||||
)
|
||||
response_warnings.extend(compile_result.warnings)
|
||||
|
||||
# Generate semantic analysis
|
||||
capabilities = analyze_chart_capabilities(chart, request.config)
|
||||
semantics = analyze_chart_semantics(chart, request.config)
|
||||
|
||||
@@ -19,6 +19,8 @@
|
||||
Unit tests for MCP generate_chart tool
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
@@ -30,6 +32,10 @@ from superset.mcp_service.chart.schemas import (
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.chart.tool.generate_chart import (
|
||||
_compile_chart,
|
||||
CompileResult,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerateChart:
|
||||
@@ -277,3 +283,71 @@ class TestGenerateChart:
|
||||
# Hidden legend
|
||||
legend = LegendConfig(show=False)
|
||||
assert legend.show is False
|
||||
|
||||
|
||||
class TestCompileChart:
    """Tests for _compile_chart helper."""

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_success(self, factory_mock, command_mock):
        """Test _compile_chart returns success when query executes cleanly."""
        factory_mock.return_value.create.return_value = MagicMock()
        payload = {"queries": [{"data": [{"col": 1}, {"col": 2}]}]}
        command_mock.return_value.run.return_value = payload

        outcome = _compile_chart(
            {
                "viz_type": "echarts_timeseries_bar",
                "metrics": [{"label": "count", "expressionType": "SIMPLE"}],
                "groupby": ["region"],
            },
            dataset_id=1,
        )

        assert isinstance(outcome, CompileResult)
        assert outcome.success is True
        assert outcome.error is None
        assert outcome.row_count == 2

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_query_error_in_payload(self, factory_mock, command_mock):
        """Test _compile_chart detects errors embedded in query results."""
        factory_mock.return_value.create.return_value = MagicMock()
        command_mock.return_value.run.return_value = {
            "queries": [{"error": "column 'bad_col' does not exist"}]
        }

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "bad_col" in (outcome.error or "")

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_command_exception(self, factory_mock, command_mock):
        """Test _compile_chart handles ChartDataQueryFailedError."""
        from superset.commands.chart.exceptions import (
            ChartDataQueryFailedError,
        )

        factory_mock.return_value.create.return_value = MagicMock()
        failure = ChartDataQueryFailedError("syntax error near FROM")
        command_mock.return_value.run.side_effect = failure

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "syntax error" in (outcome.error or "")

    @patch("superset.commands.chart.data.get_data_command.ChartDataCommand")
    @patch("superset.common.query_context_factory.QueryContextFactory")
    def test_compile_chart_value_error(self, factory_mock, command_mock):
        """Test _compile_chart handles ValueError from bad config."""
        factory_mock.return_value.create.side_effect = ValueError("invalid metric")

        outcome = _compile_chart({"metrics": []}, dataset_id=1)

        assert outcome.success is False
        assert "invalid metric" in (outcome.error or "")
|
||||
|
||||
Reference in New Issue
Block a user