mirror of
https://github.com/apache/superset.git
synced 2026-04-18 15:44:57 +00:00
fix(mcp): add missing command.validate() to MCP chart data tools (#38521)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -103,6 +103,7 @@ def generate_preview_from_form_data(
|
||||
|
||||
# Execute query
|
||||
command = ChartDataCommand(query_context_obj)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
if not result or not result.get("queries"):
|
||||
|
||||
@@ -101,6 +101,7 @@ def _compile_chart(
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
warnings: List[str] = []
|
||||
|
||||
@@ -462,6 +462,7 @@ async def get_chart_data( # noqa: C901
|
||||
# Execute the query
|
||||
with event_logger.log_context(action="mcp.get_chart_data.query_execution"):
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
# Handle empty query results for certain chart types
|
||||
|
||||
@@ -160,6 +160,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy):
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
data = []
|
||||
@@ -234,6 +235,7 @@ class TablePreviewStrategy(PreviewFormatStrategy):
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
data = []
|
||||
@@ -340,6 +342,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy):
|
||||
|
||||
# Execute the query
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
# Extract data from result
|
||||
|
||||
@@ -671,3 +671,190 @@ class TestGetChartDataRequestSchema:
|
||||
assert data["identifier"] == 123
|
||||
assert data["limit"] == 50
|
||||
assert data["format"] == "json"
|
||||
|
||||
|
||||
class TestChartDataCommandValidation:
|
||||
"""Tests that ChartDataCommand.validate() is called before run().
|
||||
|
||||
These tests verify the security fix (CWE-862) that adds command.validate()
|
||||
before command.run() in MCP chart data tools. The validate() call enforces
|
||||
row-level security, guest user guards, and schema-level permissions.
|
||||
"""
|
||||
|
||||
def test_preview_utils_calls_validate_before_run(self):
|
||||
"""Test that generate_preview_from_form_data calls validate() before run()."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
call_order: list[str] = []
|
||||
mock_query_result = {"queries": [{"data": [{"col1": "a", "col2": 1}]}]}
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.validate.side_effect = lambda: call_order.append("validate")
|
||||
mock_command.run.side_effect = lambda: (
|
||||
call_order.append("run"),
|
||||
mock_query_result,
|
||||
)[1]
|
||||
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.id = 10
|
||||
|
||||
# ChartDataCommand is module-level import in preview_utils;
|
||||
# db and QueryContextFactory are local imports inside the function.
|
||||
with (
|
||||
patch("superset.extensions.db") as mock_db,
|
||||
patch(
|
||||
"superset.mcp_service.chart.preview_utils.ChartDataCommand",
|
||||
return_value=mock_command,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory"
|
||||
) as mock_factory,
|
||||
):
|
||||
mock_db.session.query.return_value.get.return_value = mock_dataset
|
||||
mock_factory.return_value.create.return_value = MagicMock()
|
||||
|
||||
from superset.mcp_service.chart.preview_utils import (
|
||||
generate_preview_from_form_data,
|
||||
)
|
||||
|
||||
generate_preview_from_form_data(
|
||||
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
|
||||
dataset_id=10,
|
||||
preview_format="table",
|
||||
)
|
||||
|
||||
mock_command.validate.assert_called_once()
|
||||
mock_command.run.assert_called_once()
|
||||
assert call_order == ["validate", "run"]
|
||||
|
||||
def test_preview_utils_security_exception_from_validate(self):
|
||||
"""Test that SupersetSecurityException from validate() is propagated."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.mcp_service.chart.schemas import ChartError
|
||||
|
||||
security_error = SupersetSecurityException(
|
||||
SupersetError(
|
||||
message="Access denied",
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.validate.side_effect = security_error
|
||||
|
||||
mock_dataset = MagicMock()
|
||||
mock_dataset.id = 10
|
||||
|
||||
with (
|
||||
patch("superset.extensions.db") as mock_db,
|
||||
patch(
|
||||
"superset.mcp_service.chart.preview_utils.ChartDataCommand",
|
||||
return_value=mock_command,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory"
|
||||
) as mock_factory,
|
||||
):
|
||||
mock_db.session.query.return_value.get.return_value = mock_dataset
|
||||
mock_factory.return_value.create.return_value = MagicMock()
|
||||
|
||||
from superset.mcp_service.chart.preview_utils import (
|
||||
generate_preview_from_form_data,
|
||||
)
|
||||
|
||||
result = generate_preview_from_form_data(
|
||||
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
|
||||
dataset_id=10,
|
||||
preview_format="table",
|
||||
)
|
||||
|
||||
# SupersetSecurityException is caught by the broad except and
|
||||
# returned as a ChartError
|
||||
assert isinstance(result, ChartError)
|
||||
assert "Access denied" in result.error
|
||||
mock_command.run.assert_not_called()
|
||||
|
||||
def test_compile_chart_calls_validate_before_run(self):
|
||||
"""Test that _compile_chart calls validate() before run()."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
call_order: list[str] = []
|
||||
mock_query_result = {"queries": [{"data": [{"col1": 1}]}]}
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.validate.side_effect = lambda: call_order.append("validate")
|
||||
mock_command.run.side_effect = lambda: (
|
||||
call_order.append("run"),
|
||||
mock_query_result,
|
||||
)[1]
|
||||
|
||||
# Both ChartDataCommand and QueryContextFactory are local imports
|
||||
# inside _compile_chart, so patch at source.
|
||||
with (
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand",
|
||||
return_value=mock_command,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory"
|
||||
) as mock_factory,
|
||||
):
|
||||
mock_factory.return_value.create.return_value = MagicMock()
|
||||
|
||||
from superset.mcp_service.chart.tool.generate_chart import _compile_chart
|
||||
|
||||
result = _compile_chart(
|
||||
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
|
||||
dataset_id=10,
|
||||
)
|
||||
|
||||
assert result.success is True
|
||||
mock_command.validate.assert_called_once()
|
||||
mock_command.run.assert_called_once()
|
||||
assert call_order == ["validate", "run"]
|
||||
|
||||
def test_compile_chart_security_exception_from_validate(self):
|
||||
"""Test that _compile_chart propagates security exception from validate()."""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
security_error = SupersetSecurityException(
|
||||
SupersetError(
|
||||
message="Row-level security violation",
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.validate.side_effect = security_error
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand",
|
||||
return_value=mock_command,
|
||||
),
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory"
|
||||
) as mock_factory,
|
||||
):
|
||||
mock_factory.return_value.create.return_value = MagicMock()
|
||||
|
||||
from superset.mcp_service.chart.tool.generate_chart import _compile_chart
|
||||
|
||||
# SupersetSecurityException is not caught by _compile_chart's
|
||||
# specific except blocks, so it propagates to the caller
|
||||
# (generate_chart's broad except handler).
|
||||
with pytest.raises(SupersetSecurityException):
|
||||
_compile_chart(
|
||||
form_data={"metrics": [{"label": "count"}], "viz_type": "table"},
|
||||
dataset_id=10,
|
||||
)
|
||||
|
||||
mock_command.run.assert_not_called()
|
||||
|
||||
Reference in New Issue
Block a user