diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index 1a25e1120b2..6daa750659c 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -25,6 +25,7 @@ import logging from typing import Any, Dict, List from fastmcp import Context +from flask import g from superset_core.mcp.decorators import tool, ToolAnnotations from superset.extensions import event_logger @@ -200,10 +201,12 @@ def generate_dashboard( Returns: - Dashboard ID and URL """ + from pydantic import ValidationError + from sqlalchemy.exc import SQLAlchemyError + try: # Get chart objects from IDs (required for SQLAlchemy relationships) from superset import db - from superset.commands.dashboard.create import CreateDashboardCommand from superset.models.slice import Slice with event_logger.log_context(action="mcp.generate_dashboard.chart_validation"): @@ -235,65 +238,92 @@ def generate_dashboard( else _generate_title_from_charts(chart_objects) ) - # Prepare dashboard data and create dashboard + # Create the dashboard directly with db.session instead of using + # CreateDashboardCommand. The command's @transaction decorator + # may operate in a different SQLAlchemy scoped-session than the + # one g.user and chart ORM objects are bound to in the MCP + # context, causing "Object is already attached to session X + # (this is Y)" errors. By re-querying all ORM objects in the + # tool's own db.session we keep everything in a single session. + from sqlalchemy.orm import subqueryload + + from superset.models.dashboard import Dashboard + with event_logger.log_context(action="mcp.generate_dashboard.db_write"): - dashboard_data: Dict[str, Any] = { - "dashboard_title": dashboard_title, - "json_metadata": json.dumps( - { - "filter_scopes": {}, - "expanded_slices": {}, - "refresh_frequency": 0, - "timed_refresh_immune_slices": [], - "color_scheme": None, - "label_colors": {}, - "shared_label_colors": {}, - "color_scheme_domain": [], - "cross_filters_enabled": False, - "native_filter_configuration": [], - "global_chart_configuration": { - "scope": { - "rootPath": ["ROOT_ID"], - "excluded": [], - } - }, - "chart_configuration": {}, - } - ), - "position_json": json.dumps(layout), - "published": request.published, - "slices": chart_objects, # Pass ORM objects, not IDs - } + json_metadata = json.dumps( + { + "filter_scopes": {}, + "expanded_slices": {}, + "refresh_frequency": 0, + "timed_refresh_immune_slices": [], + "color_scheme": None, + "label_colors": {}, + "shared_label_colors": {}, + "color_scheme_domain": [], + "cross_filters_enabled": False, + "native_filter_configuration": [], + "global_chart_configuration": { + "scope": { + "rootPath": ["ROOT_ID"], + "excluded": [], + } + }, + "chart_configuration": {}, + } + ) - if request.description: - dashboard_data["description"] = request.description - - # Create the dashboard using Superset's command pattern try: - command = CreateDashboardCommand(dashboard_data) - dashboard = command.run() - except Exception as cmd_err: - # Surface the root cause from @transaction's error wrapping - root_cause = cmd_err.__cause__ or cmd_err + dashboard = Dashboard() + dashboard.dashboard_title = dashboard_title + dashboard.json_metadata = json_metadata + dashboard.position_json = json.dumps(layout) + dashboard.published = request.published + + if request.description: + dashboard.description = request.description + + # Re-query the current user and charts 
directly in the + # current db.session. g.user was loaded in a Flask + # app_context that has since been torn down (the + # middleware's ``with flask_app.app_context()`` exits + # before the tool function runs), so the User object + # is bound to a dead/different scoped session. + # Querying fresh avoids all cross-session errors. + from superset.extensions import security_manager + + current_user = ( + db.session.query(security_manager.user_model) + .filter_by(id=g.user.id) + .first() + ) + if current_user: + dashboard.owners = [current_user] + + fresh_charts = ( + db.session.query(Slice) + .filter(Slice.id.in_(request.chart_ids)) + .order_by(Slice.id) + .all() + ) + dashboard.slices = fresh_charts + + db.session.add(dashboard) + db.session.commit() + except SQLAlchemyError as db_err: + db.session.rollback() logger.error( - "CreateDashboardCommand failed: %s (cause: %s)", - cmd_err, - root_cause, + "Dashboard creation failed: %s", + db_err, exc_info=True, ) return GenerateDashboardResponse( dashboard=None, dashboard_url=None, - error=f"Failed to create dashboard: {root_cause}", + error="Failed to create dashboard due to a database error.", ) - # Re-fetch the dashboard with eager-loaded relationships to avoid - # "Instance is not bound to a Session" errors when serializing - # chart .tags and .owners. - from sqlalchemy.orm import subqueryload - + # Re-fetch with eager-loaded relationships for serialization from superset.daos.dashboard import DashboardDAO - from superset.models.dashboard import Dashboard dashboard = ( DashboardDAO.find_by_id( @@ -355,8 +385,8 @@ def generate_dashboard( dashboard=dashboard_info, dashboard_url=dashboard_url, error=None ) - except Exception as e: - logger.error("Error creating dashboard: %s", e) + except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e: + logger.error("Error creating dashboard: %s", e, exc_info=True) return GenerateDashboardResponse( dashboard=None, dashboard_url=None, diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py index 0f07432eea1..b3928e4248c 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -105,30 +105,59 @@ def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock: return dashboard +def _setup_generate_dashboard_mocks( + mock_db_session, + mock_find_by_id, + mock_dashboard_cls, + charts, + dashboard, +): + """Set up common mocks for generate_dashboard tests. + + The tool creates dashboards directly via db.session (bypassing + CreateDashboardCommand) and re-queries user/charts in the tool's + own session. This helper wires up the mock chain for that path. 
+ """ + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_user.first_name = "Admin" + mock_user.last_name = "User" + mock_user.email = "admin@example.com" + mock_user.active = True + + mock_query = Mock() + mock_filter = Mock() + mock_query.filter.return_value = mock_filter + mock_query.filter_by.return_value = mock_filter + mock_filter.order_by.return_value = mock_filter + mock_filter.all.return_value = charts + mock_filter.first.return_value = mock_user + mock_db_session.query.return_value = mock_query + + mock_dashboard_cls.return_value = dashboard + mock_find_by_id.return_value = dashboard + + class TestGenerateDashboard: """Tests for generate_dashboard MCP tool.""" - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_basic( - self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test basic dashboard generation with valid charts.""" - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [ + charts = [ _mock_chart(id=1, slice_name="Sales Chart"), _mock_chart(id=2, slice_name="Revenue Chart"), ] - mock_db_session.query.return_value = mock_query - mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) request = { "chart_ids": [1, 2], @@ -173,24 +202,19 @@ class TestGenerateDashboard: assert result.structured_content["dashboard"] is None assert result.structured_content["dashboard_url"] is None - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_single_chart( - self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test dashboard generation with a single chart.""" - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")] - mock_db_session.query.return_value = mock_query - + charts = [_mock_chart(id=5, slice_name="Single Chart")] mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) request = { "chart_ids": [5], @@ -203,29 +227,22 @@ class TestGenerateDashboard: assert result.structured_content["error"] is None assert result.structured_content["dashboard"]["chart_count"] == 1 - assert result.structured_content["dashboard"]["published"] is True + assert result.structured_content["dashboard"]["published"] is False - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + 
@patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_many_charts( - self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test dashboard generation with many charts (grid layout).""" chart_ids = list(range(1, 7)) - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [ - _mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids - ] - mock_db_session.query.return_value = mock_query - + charts = [_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids] mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"} @@ -235,10 +252,14 @@ class TestGenerateDashboard: assert result.structured_content["error"] is None assert result.structured_content["dashboard"]["chart_count"] == 6 - mock_create_command.assert_called_once() - call_args = mock_create_command.call_args[0][0] + # Verify db.session.add and commit were called + # (commit is called multiple times: once by tool + event_logger contexts) + mock_db_session.add.assert_called_once() + assert mock_db_session.commit.call_count >= 1 - position_json = json.loads(call_args["position_json"]) + # Verify layout was set on the dashboard object + created_dashboard = mock_dashboard_cls.return_value + position_json = json.loads(created_dashboard.position_json) assert "ROOT_ID" in position_json assert "GRID_ID" in position_json assert "DASHBOARD_VERSION_KEY" in position_json @@ -278,20 +299,33 @@ class TestGenerateDashboard: assert row_data["type"] == "ROW" assert column_key in row_data["children"] - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_creation_failure( - self, mock_db_session, mock_create_command, mcp_server + self, mock_db_session, mock_dashboard_cls, mcp_server ): """Test error handling when dashboard creation fails.""" + from sqlalchemy.exc import SQLAlchemyError + mock_query = Mock() mock_filter = Mock() mock_query.filter.return_value = mock_filter + mock_query.filter_by.return_value = mock_filter mock_filter.order_by.return_value = mock_filter mock_filter.all.return_value = [_mock_chart(id=1)] + mock_filter.first.return_value = Mock( + id=1, + username="admin", + first_name="Admin", + last_name="User", + email="admin@example.com", + active=True, + ) mock_db_session.query.return_value = mock_query - mock_create_command.return_value.run.side_effect = Exception("Creation failed") + mock_db_session.commit.side_effect = SQLAlchemyError("Creation failed") + + mock_dashboard_cls.return_value = _mock_dashboard(id=99) request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"} @@ -301,25 +335,22 @@ class TestGenerateDashboard: assert result.structured_content["error"] is not None assert "Failed to create dashboard" in result.structured_content["error"] assert result.structured_content["dashboard"] is None + # 
rollback called by tool + event_logger error handling + assert mock_db_session.rollback.call_count >= 1 - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_minimal_request( - self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test dashboard generation with minimal required parameters.""" - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [_mock_chart(id=3)] - mock_db_session.query.return_value = mock_query - + charts = [_mock_chart(id=3)] mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) request = { "chart_ids": [3], @@ -335,33 +366,26 @@ class TestGenerateDashboard: == "Minimal Dashboard" ) - call_args = mock_create_command.call_args[0][0] - assert call_args["published"] is True - assert ( - "description" not in call_args or call_args.get("description") is None - ) + # Verify dashboard was created with default published=True + created = mock_dashboard_cls.return_value + assert created.published is True - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_auto_title_from_charts( - self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test that omitting dashboard_title generates a title from chart names.""" - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [ + charts = [ _mock_chart(id=1, slice_name="Sales Revenue"), _mock_chart(id=2, slice_name="Customer Count"), ] - mock_db_session.query.return_value = mock_query - mock_dashboard = _mock_dashboard(id=50, title="Sales Revenue & Customer Count") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) # No dashboard_title provided request = {"chart_ids": [1, 2]} @@ -371,29 +395,23 @@ class TestGenerateDashboard: assert result.structured_content["error"] is None - call_args = mock_create_command.call_args[0][0] - assert call_args["dashboard_title"] == "Sales Revenue & Customer Count" + # Verify auto-generated title was set on dashboard + created = mock_dashboard_cls.return_value + assert created.dashboard_title == "Sales Revenue & Customer Count" - @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") @patch("superset.db.session") @pytest.mark.asyncio async def test_generate_dashboard_empty_string_title_preserved( - self, mock_db_session, mock_find_by_id, 
mock_create_command, mcp_server + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server ): """Test that an explicit empty-string title is NOT replaced by auto-gen.""" - mock_query = Mock() - mock_filter = Mock() - mock_query.filter.return_value = mock_filter - mock_filter.order_by.return_value = mock_filter - mock_filter.all.return_value = [ - _mock_chart(id=1, slice_name="Sales Revenue"), - ] - mock_db_session.query.return_value = mock_query - + charts = [_mock_chart(id=1, slice_name="Sales Revenue")] mock_dashboard = _mock_dashboard(id=60, title="") - mock_create_command.return_value.run.return_value = mock_dashboard - mock_find_by_id.return_value = mock_dashboard + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard + ) # Explicit empty string title request = {"chart_ids": [1], "dashboard_title": ""} @@ -403,8 +421,9 @@ class TestGenerateDashboard: assert result.structured_content["error"] is None - call_args = mock_create_command.call_args[0][0] - assert call_args["dashboard_title"] == "" + # Verify empty string title was preserved (not replaced by auto-gen) + created = mock_dashboard_cls.return_value + assert created.dashboard_title == "" class TestAddChartToExistingDashboard:
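
Reviewer note: the session-binding rationale in the comments above is the crux of this change, so here is a minimal standalone sketch of the failure mode and the fix. Everything in it (the toy `User`/`Dashboard` models, the in-memory engine, the single shared sessionmaker) is illustrative only and is not Superset code; it assumes nothing beyond stock SQLAlchemy 1.4+ behavior.

```python
from sqlalchemy import Column, ForeignKey, Integer, String, Table, create_engine
from sqlalchemy.exc import InvalidRequestError
from sqlalchemy.orm import declarative_base, relationship, sessionmaker

Base = declarative_base()

# Minimal many-to-many association, mirroring Dashboard.owners in shape only.
dashboard_user = Table(
    "dashboard_user",
    Base.metadata,
    Column("dashboard_id", ForeignKey("dashboards.id"), primary_key=True),
    Column("user_id", ForeignKey("users.id"), primary_key=True),
)


class User(Base):
    __tablename__ = "users"
    id = Column(Integer, primary_key=True)
    name = Column(String)


class Dashboard(Base):
    __tablename__ = "dashboards"
    id = Column(Integer, primary_key=True)
    title = Column(String)
    owners = relationship("User", secondary=dashboard_user)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
make_session = sessionmaker(bind=engine)

# Session A stands in for the torn-down app-context session g.user was
# loaded in; session B stands in for the tool's own db.session.
session_a = make_session()
session_a.add(User(id=1, name="admin"))
session_a.commit()
stale_user = session_a.get(User, 1)  # still attached to session A

session_b = make_session()
dash = Dashboard(title="broken")
dash.owners = [stale_user]  # relationship now mixes two live sessions
try:
    # save-update cascade tries to adopt stale_user into session B and
    # raises: "Object ... is already attached to session 'A' (this is 'B')"
    session_b.add(dash)
    session_b.commit()
except InvalidRequestError as err:
    session_b.rollback()
    print(f"cross-session add failed: {err}")

# The fix taken in this diff: re-query by primary key so every ORM object
# involved in the write is bound to the same session before flushing.
fresh_user = session_b.get(User, stale_user.id)
fixed = Dashboard(title="works", owners=[fresh_user])
session_b.add(fixed)
session_b.commit()
```

An alternative worth noting is `session.merge(stale_user)`, which copies the foreign instance's state onto an instance owned by the target session; re-querying by primary key, as this diff does, is equally valid and arguably clearer when the originating session may already be dead.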
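Reviewer note on the test changes: the patch target moves from `superset.commands.dashboard.create.CreateDashboardCommand` to `superset.models.dashboard.Dashboard`, and that works only because `generate_dashboard` imports `Dashboard` inside the function body, so the name is resolved at call time against the (patched) module attribute. A hypothetical, self-contained sketch of that mechanism, using a throwaway `fake_models` module in place of `superset.models.dashboard`:

```python
import sys
import types
from unittest.mock import Mock, patch

# Build a throwaway module to stand in for superset.models.dashboard.
fake_models = types.ModuleType("fake_models")


class Dashboard:
    """Stand-in for the real Dashboard model class."""


fake_models.Dashboard = Dashboard
sys.modules["fake_models"] = fake_models


def make_dashboard():
    # Late import, mirroring generate_dashboard: the name is resolved at
    # call time, so it sees whatever fake_models.Dashboard is *right now*.
    from fake_models import Dashboard
    return Dashboard()


with patch("fake_models.Dashboard") as mock_cls:
    mock_cls.return_value = Mock(id=99)
    assert make_dashboard().id == 99  # the late import picked up the mock

assert isinstance(make_dashboard(), Dashboard)  # patch undone on exit
```

Had the tool imported `Dashboard` at module scope, the tests would instead have to patch the name where it is *used* (the tool module), not where it is defined.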