diff --git a/superset/mcp_service/chart/preview_utils.py b/superset/mcp_service/chart/preview_utils.py index 677d3034fd4..c28681a9da0 100644 --- a/superset/mcp_service/chart/preview_utils.py +++ b/superset/mcp_service/chart/preview_utils.py @@ -103,6 +103,7 @@ def generate_preview_from_form_data( # Execute query command = ChartDataCommand(query_context_obj) + command.validate() result = command.run() if not result or not result.get("queries"): diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 6ff3bee59cc..b5711d5e5f7 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -101,6 +101,7 @@ def _compile_chart( ) command = ChartDataCommand(query_context) + command.validate() result = command.run() warnings: List[str] = [] diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py index eaa50ca6b45..9239a278d9e 100644 --- a/superset/mcp_service/chart/tool/get_chart_data.py +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -462,6 +462,7 @@ async def get_chart_data( # noqa: C901 # Execute the query with event_logger.log_context(action="mcp.get_chart_data.query_execution"): command = ChartDataCommand(query_context) + command.validate() result = command.run() # Handle empty query results for certain chart types diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py index ee7c70a481b..6ee3ece4df4 100644 --- a/superset/mcp_service/chart/tool/get_chart_preview.py +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -160,6 +160,7 @@ class ASCIIPreviewStrategy(PreviewFormatStrategy): ) command = ChartDataCommand(query_context) + command.validate() result = command.run() data = [] @@ -234,6 +235,7 @@ class TablePreviewStrategy(PreviewFormatStrategy): ) command = ChartDataCommand(query_context) + command.validate() result = command.run() data = [] @@ -340,6 +342,7 @@ class VegaLitePreviewStrategy(PreviewFormatStrategy): # Execute the query command = ChartDataCommand(query_context) + command.validate() result = command.run() # Extract data from result diff --git a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py index 7850ef82a9a..0230d5edcf8 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_get_chart_data.py @@ -671,3 +671,190 @@ class TestGetChartDataRequestSchema: assert data["identifier"] == 123 assert data["limit"] == 50 assert data["format"] == "json" + + +class TestChartDataCommandValidation: + """Tests that ChartDataCommand.validate() is called before run(). + + These tests verify the security fix (CWE-862) that adds command.validate() + before command.run() in MCP chart data tools. The validate() call enforces + row-level security, guest user guards, and schema-level permissions. + """ + + def test_preview_utils_calls_validate_before_run(self): + """Test that generate_preview_from_form_data calls validate() before run().""" + from unittest.mock import MagicMock, patch + + call_order: list[str] = [] + mock_query_result = {"queries": [{"data": [{"col1": "a", "col2": 1}]}]} + + mock_command = MagicMock() + mock_command.validate.side_effect = lambda: call_order.append("validate") + mock_command.run.side_effect = lambda: ( + call_order.append("run"), + mock_query_result, + )[1] + + mock_dataset = MagicMock() + mock_dataset.id = 10 + + # ChartDataCommand is module-level import in preview_utils; + # db and QueryContextFactory are local imports inside the function. + with ( + patch("superset.extensions.db") as mock_db, + patch( + "superset.mcp_service.chart.preview_utils.ChartDataCommand", + return_value=mock_command, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory" + ) as mock_factory, + ): + mock_db.session.query.return_value.get.return_value = mock_dataset + mock_factory.return_value.create.return_value = MagicMock() + + from superset.mcp_service.chart.preview_utils import ( + generate_preview_from_form_data, + ) + + generate_preview_from_form_data( + form_data={"metrics": [{"label": "count"}], "viz_type": "table"}, + dataset_id=10, + preview_format="table", + ) + + mock_command.validate.assert_called_once() + mock_command.run.assert_called_once() + assert call_order == ["validate", "run"] + + def test_preview_utils_security_exception_from_validate(self): + """Test that SupersetSecurityException from validate() is propagated.""" + from unittest.mock import MagicMock, patch + + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetSecurityException + from superset.mcp_service.chart.schemas import ChartError + + security_error = SupersetSecurityException( + SupersetError( + message="Access denied", + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + mock_command = MagicMock() + mock_command.validate.side_effect = security_error + + mock_dataset = MagicMock() + mock_dataset.id = 10 + + with ( + patch("superset.extensions.db") as mock_db, + patch( + "superset.mcp_service.chart.preview_utils.ChartDataCommand", + return_value=mock_command, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory" + ) as mock_factory, + ): + mock_db.session.query.return_value.get.return_value = mock_dataset + mock_factory.return_value.create.return_value = MagicMock() + + from superset.mcp_service.chart.preview_utils import ( + generate_preview_from_form_data, + ) + + result = generate_preview_from_form_data( + form_data={"metrics": [{"label": "count"}], "viz_type": "table"}, + dataset_id=10, + preview_format="table", + ) + + # SupersetSecurityException is caught by the broad except and + # returned as a ChartError + assert isinstance(result, ChartError) + assert "Access denied" in result.error + mock_command.run.assert_not_called() + + def test_compile_chart_calls_validate_before_run(self): + """Test that _compile_chart calls validate() before run().""" + from unittest.mock import MagicMock, patch + + call_order: list[str] = [] + mock_query_result = {"queries": [{"data": [{"col1": 1}]}]} + + mock_command = MagicMock() + mock_command.validate.side_effect = lambda: call_order.append("validate") + mock_command.run.side_effect = lambda: ( + call_order.append("run"), + mock_query_result, + )[1] + + # Both ChartDataCommand and QueryContextFactory are local imports + # inside _compile_chart, so patch at source. + with ( + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand", + return_value=mock_command, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory" + ) as mock_factory, + ): + mock_factory.return_value.create.return_value = MagicMock() + + from superset.mcp_service.chart.tool.generate_chart import _compile_chart + + result = _compile_chart( + form_data={"metrics": [{"label": "count"}], "viz_type": "table"}, + dataset_id=10, + ) + + assert result.success is True + mock_command.validate.assert_called_once() + mock_command.run.assert_called_once() + assert call_order == ["validate", "run"] + + def test_compile_chart_security_exception_from_validate(self): + """Test that _compile_chart propagates security exception from validate().""" + from unittest.mock import MagicMock, patch + + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetSecurityException + + security_error = SupersetSecurityException( + SupersetError( + message="Row-level security violation", + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + mock_command = MagicMock() + mock_command.validate.side_effect = security_error + + with ( + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand", + return_value=mock_command, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory" + ) as mock_factory, + ): + mock_factory.return_value.create.return_value = MagicMock() + + from superset.mcp_service.chart.tool.generate_chart import _compile_chart + + # SupersetSecurityException is not caught by _compile_chart's + # specific except blocks, so it propagates to the caller + # (generate_chart's broad except handler). + with pytest.raises(SupersetSecurityException): + _compile_chart( + form_data={"metrics": [{"label": "count"}], "viz_type": "table"}, + dataset_id=10, + ) + + mock_command.run.assert_not_called()