fix(mcp): add missing command.validate() to MCP chart data tools (#38521)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-10 09:50:41 +01:00
committed by GitHub
parent 0533ca9941
commit 2a876e8b86
5 changed files with 193 additions and 0 deletions

View File

@@ -103,6 +103,7 @@ def generate_preview_from_form_data(
# Execute query
command = ChartDataCommand(query_context_obj)
command.validate()
result = command.run()
if not result or not result.get("queries"):

View File

@@ -101,6 +101,7 @@ def _compile_chart(
)
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
warnings: List[str] = []

View File

@@ -462,6 +462,7 @@ async def get_chart_data( # noqa: C901
# Execute the query
with event_logger.log_context(action="mcp.get_chart_data.query_execution"):
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
# Handle empty query results for certain chart types

View File

@@ -160,6 +160,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
)
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
data = []
@@ -234,6 +235,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
)
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
data = []
@@ -340,6 +342,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
# Execute the query
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
# Extract data from result

View File

@@ -671,3 +671,190 @@ class TestGetChartDataRequestSchema:
assert data["identifier"] == 123
assert data["limit"] == 50
assert data["format"] == "json"
class TestChartDataCommandValidation:
"""Tests that ChartDataCommand.validate() is called before run().
These tests verify the security fix (CWE-862) that adds command.validate()
before command.run() in MCP chart data tools. The validate() call enforces
row-level security, guest user guards, and schema-level permissions.
"""
def test_preview_utils_calls_validate_before_run(self):
"""Test that generate_preview_from_form_data calls validate() before run()."""
from unittest.mock import MagicMock, patch
call_order: list[str] = []
mock_query_result = {"queries": [{"data": [{"col1": "a", "col2": 1}]}]}
mock_command = MagicMock()
mock_command.validate.side_effect = lambda: call_order.append("validate")
mock_command.run.side_effect = lambda: (
call_order.append("run"),
mock_query_result,
)[1]
mock_dataset = MagicMock()
mock_dataset.id = 10
# ChartDataCommand is module-level import in preview_utils;
# db and QueryContextFactory are local imports inside the function.
with (
patch("superset.extensions.db") as mock_db,
patch(
"superset.mcp_service.chart.preview_utils.ChartDataCommand",
return_value=mock_command,
),
patch(
"superset.common.query_context_factory.QueryContextFactory"
) as mock_factory,
):
mock_db.session.query.return_value.get.return_value = mock_dataset
mock_factory.return_value.create.return_value = MagicMock()
from superset.mcp_service.chart.preview_utils import (
generate_preview_from_form_data,
)
generate_preview_from_form_data(
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
dataset_id=10,
preview_format="table",
)
mock_command.validate.assert_called_once()
mock_command.run.assert_called_once()
assert call_order == ["validate", "run"]
def test_preview_utils_security_exception_from_validate(self):
"""Test that SupersetSecurityException from validate() is propagated."""
from unittest.mock import MagicMock, patch
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.chart.schemas import ChartError
security_error = SupersetSecurityException(
SupersetError(
message="Access denied",
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
level=ErrorLevel.ERROR,
)
)
mock_command = MagicMock()
mock_command.validate.side_effect = security_error
mock_dataset = MagicMock()
mock_dataset.id = 10
with (
patch("superset.extensions.db") as mock_db,
patch(
"superset.mcp_service.chart.preview_utils.ChartDataCommand",
return_value=mock_command,
),
patch(
"superset.common.query_context_factory.QueryContextFactory"
) as mock_factory,
):
mock_db.session.query.return_value.get.return_value = mock_dataset
mock_factory.return_value.create.return_value = MagicMock()
from superset.mcp_service.chart.preview_utils import (
generate_preview_from_form_data,
)
result = generate_preview_from_form_data(
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
dataset_id=10,
preview_format="table",
)
# SupersetSecurityException is caught by the broad except and
# returned as a ChartError
assert isinstance(result, ChartError)
assert "Access denied" in result.error
mock_command.run.assert_not_called()
def test_compile_chart_calls_validate_before_run(self):
"""Test that _compile_chart calls validate() before run()."""
from unittest.mock import MagicMock, patch
call_order: list[str] = []
mock_query_result = {"queries": [{"data": [{"col1": 1}]}]}
mock_command = MagicMock()
mock_command.validate.side_effect = lambda: call_order.append("validate")
mock_command.run.side_effect = lambda: (
call_order.append("run"),
mock_query_result,
)[1]
# Both ChartDataCommand and QueryContextFactory are local imports
# inside _compile_chart, so patch at source.
with (
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand",
return_value=mock_command,
),
patch(
"superset.common.query_context_factory.QueryContextFactory"
) as mock_factory,
):
mock_factory.return_value.create.return_value = MagicMock()
from superset.mcp_service.chart.tool.generate_chart import _compile_chart
result = _compile_chart(
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
dataset_id=10,
)
assert result.success is True
mock_command.validate.assert_called_once()
mock_command.run.assert_called_once()
assert call_order == ["validate", "run"]
def test_compile_chart_security_exception_from_validate(self):
"""Test that _compile_chart propagates security exception from validate()."""
from unittest.mock import MagicMock, patch
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
security_error = SupersetSecurityException(
SupersetError(
message="Row-level security violation",
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
level=ErrorLevel.ERROR,
)
)
mock_command = MagicMock()
mock_command.validate.side_effect = security_error
with (
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand",
return_value=mock_command,
),
patch(
"superset.common.query_context_factory.QueryContextFactory"
) as mock_factory,
):
mock_factory.return_value.create.return_value = MagicMock()
from superset.mcp_service.chart.tool.generate_chart import _compile_chart
# SupersetSecurityException is not caught by _compile_chart's
# specific except blocks, so it propagates to the caller
# (generate_chart's broad except handler).
with pytest.raises(SupersetSecurityException):
_compile_chart(
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
dataset_id=10,
)
mock_command.run.assert_not_called()