diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 4a439d3bd74..c637d3d60b2 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -25,6 +25,7 @@ from typing import Any, Dict, List from urllib.parse import parse_qs, urlparse from fastmcp import Context +from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException @@ -710,36 +711,55 @@ async def generate_chart( # noqa: C901 # Build chart info using serialize_chart_object for saved charts chart_info = None + chart_data = None if request.save_chart and chart: from sqlalchemy.orm import joinedload + from superset import db from superset.daos.chart import ChartDAO from superset.mcp_service.chart.schemas import serialize_chart_object from superset.models.slice import Slice # Re-fetch with eager-loaded relationships to avoid detached # instance errors when serialize_chart_object accesses .tags - # and .owners. Use joinedload (single JOIN query) since we - # are fetching a single chart. - chart = ( - ChartDAO.find_by_id( - chart.id, - query_options=[ - joinedload(Slice.owners), - joinedload(Slice.tags), - ], + # and .owners. The preceding commit may invalidate the session + # in multi-tenant environments; on failure, build a minimal + # chart_data dict from scalar attributes that are already loaded + # — relationship fields (owners, tags) would trigger + # lazy-loading on the same dead session. + try: + chart = ( + ChartDAO.find_by_id( + chart.id, + query_options=[ + joinedload(Slice.owners), + joinedload(Slice.tags), + ], + ) + or chart ) - or chart - ) + except SQLAlchemyError: + logger.warning( + "Re-fetch of chart %s failed; returning minimal response", + chart.id, + exc_info=True, + ) + db.session.rollback() + chart_data = { + "id": chart.id, + "slice_name": chart.slice_name, + "viz_type": chart.viz_type, + "url": explore_url, + "uuid": str(chart.uuid) if chart.uuid else None, + } - chart_info = serialize_chart_object(chart) - if chart_info: - # Override the URL with explore_url - chart_info.url = explore_url + if chart_data is None: + chart_info = serialize_chart_object(chart) + if chart_info: + chart_info.url = explore_url # Safely serialize chart_info - handle both Pydantic models and dicts - chart_data = None - if chart_info: + if chart_data is None and chart_info is not None: if hasattr(chart_info, "model_dump"): chart_data = chart_info.model_dump() elif isinstance(chart_info, dict): @@ -786,7 +806,7 @@ async def generate_chart( # noqa: C901 ) return GenerateChartResponse.model_validate(result) - except Exception as e: + except (CommandException, SQLAlchemyError, KeyError, ValueError) as e: await ctx.error( "Chart generation failed: error=%s, execution_time_ms=%s" % ( diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py index 0c371d7be30..2262bdb722f 100644 --- a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -322,6 +322,10 @@ def add_chart_to_existing_dashboard( Add chart to existing dashboard. Auto-positions in 2-column grid. Returns updated dashboard info. """ + from sqlalchemy.exc import SQLAlchemyError + + from superset.commands.exceptions import CommandException + try: from superset.commands.dashboard.update import UpdateDashboardCommand from superset.daos.dashboard import DashboardDAO @@ -426,25 +430,54 @@ def add_chart_to_existing_dashboard( # Re-fetch the dashboard with eager-loaded relationships to avoid # "Instance is not bound to a Session" errors when serializing - # chart .tags and .owners. + # chart .tags and .owners. The preceding command.run() commit may + # invalidate the session in multi-tenant environments; on failure, + # return a minimal response using only scalar attributes that are + # already loaded — relationship fields (owners, tags, slices) would + # trigger lazy-loading on the same dead session. from sqlalchemy.orm import subqueryload - from superset.daos.dashboard import DashboardDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice - updated_dashboard = ( - DashboardDAO.find_by_id( - updated_dashboard.id, - query_options=[ - subqueryload(Dashboard.slices).subqueryload(Slice.owners), - subqueryload(Dashboard.slices).subqueryload(Slice.tags), - subqueryload(Dashboard.owners), - subqueryload(Dashboard.tags), - ], + try: + updated_dashboard = ( + DashboardDAO.find_by_id( + updated_dashboard.id, + query_options=[ + subqueryload(Dashboard.slices).subqueryload(Slice.owners), + subqueryload(Dashboard.slices).subqueryload(Slice.tags), + subqueryload(Dashboard.owners), + subqueryload(Dashboard.tags), + ], + ) + or updated_dashboard + ) + except SQLAlchemyError: + logger.warning( + "Re-fetch of dashboard %s failed; returning minimal response", + updated_dashboard.id, + exc_info=True, + ) + db.session.rollback() + dashboard_url = ( + f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/" + ) + position_info = { + "row": row_key, + "chart_key": chart_key, + "row_key": row_key, + } + return AddChartToDashboardResponse( + dashboard=DashboardInfo( + id=updated_dashboard.id, + dashboard_title=updated_dashboard.dashboard_title, + url=dashboard_url, + ), + dashboard_url=dashboard_url, + position=position_info, + error=None, ) - or updated_dashboard - ) # Convert to response format from superset.mcp_service.dashboard.schemas import ( @@ -477,8 +510,9 @@ def add_chart_to_existing_dashboard( ], roles=[], charts=[ - serialize_chart_object(chart) + obj for chart in getattr(updated_dashboard, "slices", []) + if (obj := serialize_chart_object(chart)) is not None ], ) @@ -500,7 +534,7 @@ def add_chart_to_existing_dashboard( error=None, ) - except Exception as e: + except (CommandException, SQLAlchemyError, KeyError, ValueError) as e: logger.error("Error adding chart to dashboard: %s", e) return AddChartToDashboardResponse( dashboard=None, diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index 875106ac364..bd9ce924709 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -339,21 +339,49 @@ def generate_dashboard( error="Failed to create dashboard due to a database error.", ) - # Re-fetch with eager-loaded relationships for serialization + # Re-fetch with eager-loaded relationships for serialization. + # The preceding commit may invalidate the session in multi-tenant + # environments, causing "Can't reconnect until invalid transaction + # is rolled back". Wrap the DAO re-fetch in try/except; on failure, + # return a minimal response using only scalar attributes that are + # already loaded — relationship fields (owners, tags, slices) would + # trigger lazy-loading on the same dead session. from superset.daos.dashboard import DashboardDAO - dashboard = ( - DashboardDAO.find_by_id( - dashboard.id, - query_options=[ - subqueryload(Dashboard.slices).subqueryload(Slice.owners), - subqueryload(Dashboard.slices).subqueryload(Slice.tags), - subqueryload(Dashboard.owners), - subqueryload(Dashboard.tags), - ], + try: + dashboard = ( + DashboardDAO.find_by_id( + dashboard.id, + query_options=[ + subqueryload(Dashboard.slices).subqueryload(Slice.owners), + subqueryload(Dashboard.slices).subqueryload(Slice.tags), + subqueryload(Dashboard.owners), + subqueryload(Dashboard.tags), + ], + ) + or dashboard + ) + except SQLAlchemyError: + logger.warning( + "Re-fetch of dashboard %s failed; returning minimal response", + dashboard.id, + exc_info=True, + ) + db.session.rollback() + dashboard_url = ( + f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/" + ) + return GenerateDashboardResponse( + dashboard=DashboardInfo( + id=dashboard.id, + dashboard_title=dashboard.dashboard_title, + url=dashboard_url, + chart_count=len(request.chart_ids), + published=dashboard.published, + ), + dashboard_url=dashboard_url, + error=None, ) - or dashboard - ) # Convert to our response format from superset.mcp_service.dashboard.schemas import ( diff --git a/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py index ef6ce5383f3..8fba05e0ece 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py @@ -410,38 +410,68 @@ class TestChartSerializationEagerLoading: with pytest.raises(DetachedInstanceError): serialize_chart_object(chart) - @patch("superset.daos.chart.ChartDAO.find_by_id") - def test_generate_chart_refetches_via_dao(self, mock_find_by_id): - """The serialization path re-fetches the chart via ChartDAO.find_by_id - with joinedload query_options for owners and tags.""" + def test_generate_chart_refetches_via_dao(self): + """The serialization path re-fetches the chart via + ChartDAO.find_by_id() with query_options for owners and tags.""" refetched_chart = _make_mock_chart() refetched_chart.tags = [Mock(id=1, name="tag1", type="custom")] refetched_chart.tags[0].description = "" - mock_find_by_id.return_value = refetched_chart - - from superset.daos.chart import ChartDAO + mock_dao = MagicMock() + mock_dao.find_by_id.return_value = refetched_chart chart = ( - ChartDAO.find_by_id(42, query_options=["dummy_option"]) + mock_dao.find_by_id(42, query_options=[Mock(), Mock()]) or _make_mock_chart() ) assert chart is refetched_chart - mock_find_by_id.assert_called_once_with(42, query_options=["dummy_option"]) + mock_dao.find_by_id.assert_called() - @patch("superset.daos.chart.ChartDAO.find_by_id") - def test_generate_chart_falls_back_to_original_on_refetch_failure( - self, mock_find_by_id - ): - """Falls back to original chart if ChartDAO.find_by_id returns None.""" + def test_generate_chart_falls_back_to_original_on_dao_none(self): + """Falls back to original chart if ChartDAO.find_by_id() + returns None.""" original_chart = _make_mock_chart() - mock_find_by_id.return_value = None - from superset.daos.chart import ChartDAO + mock_dao = MagicMock() + mock_dao.find_by_id.return_value = None - chart = ( - ChartDAO.find_by_id(original_chart.id, query_options=[]) or original_chart - ) + chart = mock_dao.find_by_id(42, query_options=[Mock()]) or original_chart assert chart is original_chart + + def test_generate_chart_refetch_sqlalchemy_error_rollback(self): + """When the DAO re-fetch raises SQLAlchemyError, the session is + rolled back and a minimal chart_data dict is built from scalar + attributes instead of calling serialize_chart_object (which would + trigger lazy-loading on the same dead session).""" + from sqlalchemy.exc import SQLAlchemyError + + original_chart = _make_mock_chart() + mock_dao = MagicMock() + mock_dao.find_by_id.side_effect = SQLAlchemyError("session error") + mock_session = MagicMock() + explore_url = "http://example.com/explore/?slice_id=42" + + chart_data = None + try: + mock_dao.find_by_id(42, query_options=[Mock()]) + except SQLAlchemyError: + mock_session.rollback() + chart_data = { + "id": original_chart.id, + "slice_name": original_chart.slice_name, + "viz_type": original_chart.viz_type, + "url": explore_url, + "uuid": str(original_chart.uuid) if original_chart.uuid else None, + } + + mock_session.rollback.assert_called() + # Minimal chart_data should contain scalar fields only + assert chart_data is not None + assert chart_data["id"] == original_chart.id + assert chart_data["slice_name"] == original_chart.slice_name + assert chart_data["url"] == explore_url + # No tags/owners keys — those would require relationship access + assert "tags" not in chart_data + assert "owners" not in chart_data diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py index 24146a35252..02eebddeb00 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -20,7 +20,7 @@ Unit tests for dashboard generation MCP tools """ import logging -from unittest.mock import Mock, patch +from unittest.mock import MagicMock, Mock, patch import pytest from fastmcp import Client @@ -133,7 +133,8 @@ def _setup_generate_dashboard_mocks( The tool creates dashboards directly via db.session (bypassing CreateDashboardCommand) and re-queries user/charts in the tool's - own session. This helper wires up the mock chain for that path. + own session. The re-fetch uses DashboardDAO.find_by_id() with + query_options for eager loading of slice relationships. """ mock_user = Mock() mock_user.id = 1 @@ -143,8 +144,8 @@ def _setup_generate_dashboard_mocks( mock_user.email = "admin@example.com" mock_user.active = True - mock_query = Mock() - mock_filter = Mock() + mock_query = MagicMock() + mock_filter = MagicMock() mock_query.filter.return_value = mock_filter mock_query.filter_by.return_value = mock_filter mock_filter.order_by.return_value = mock_filter @@ -153,6 +154,7 @@ def _setup_generate_dashboard_mocks( mock_db_session.query.return_value = mock_query mock_dashboard_cls.return_value = dashboard + # DashboardDAO.find_by_id is used for the re-fetch with eager loading mock_find_by_id.return_value = dashboard @@ -556,15 +558,15 @@ class TestAddChartToExistingDashboard: _mock_chart(id=20), _mock_chart(id=30), ] - # First call: initial validation returns original dashboard - # Second call: re-fetch after update returns updated dashboard - mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] - mock_chart = _mock_chart(id=30, slice_name="New Chart") mock_db_session.get.return_value = mock_chart mock_update_command.return_value.run.return_value = updated_dashboard + # First DAO call returns initial dashboard (validation), + # second DAO call returns updated dashboard (re-fetch with eager loading) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 1, "chart_id": 30} async with Client(mcp_server) as client: @@ -700,8 +702,6 @@ class TestAddChartToExistingDashboard: mock_dashboard = _mock_dashboard(id=2) mock_dashboard.slices = [] mock_dashboard.position_json = "{}" - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=15) mock_db_session.get.return_value = mock_chart @@ -709,6 +709,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=15)] mock_update_command.return_value.run.return_value = updated_dashboard + # First DAO call returns initial dashboard (validation), + # second returns updated dashboard (re-fetch) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 2, "chart_id": 15} async with Client(mcp_server) as client: @@ -809,8 +813,6 @@ class TestAddChartToExistingDashboard: "DASHBOARD_VERSION_KEY": "v2", } ) - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=25, slice_name="Tab Chart") mock_db_session.get.return_value = mock_chart @@ -818,6 +820,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=25)] mock_update_command.return_value.run.return_value = updated_dashboard + # First DAO call returns initial dashboard (validation), + # second returns updated dashboard (re-fetch) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 3, "chart_id": 25} async with Client(mcp_server) as client: @@ -909,8 +915,6 @@ class TestAddChartToExistingDashboard: "DASHBOARD_VERSION_KEY": "v2", } ) - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=30, slice_name="Customer Chart") mock_db_session.get.return_value = mock_chart @@ -918,6 +922,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=30)] mock_update_command.return_value.run.return_value = updated_dashboard + # First DAO call returns initial dashboard (validation), + # second returns updated dashboard (re-fetch) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 3, "chart_id": 30, "target_tab": "Customers"} async with Client(mcp_server) as client: @@ -984,8 +992,6 @@ class TestAddChartToExistingDashboard: "DASHBOARD_VERSION_KEY": "v2", } ) - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=50, slice_name="New Nanoid Chart") mock_db_session.get.return_value = mock_chart @@ -993,6 +999,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=50)] mock_update_command.return_value.run.return_value = updated_dashboard + # First DAO call returns initial dashboard (validation), + # second returns updated dashboard (re-fetch) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 4, "chart_id": 50} async with Client(mcp_server) as client: @@ -1218,58 +1228,122 @@ class TestGenerateTitleFromCharts: class TestDashboardSerializationEagerLoading: - """Tests for eager loading fix in dashboard serialization paths.""" + """Tests for eager loading fix in dashboard serialization paths. + The re-fetch uses DashboardDAO.find_by_id() with query_options for + eager loading. A try/except around the DAO call handles "Can't + reconnect until invalid transaction is rolled back" errors in + multi-tenant environments by rolling back and falling back to the + original dashboard object. + """ + + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_generate_dashboard_refetches_via_dao(self, mock_find_by_id): - """generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_generate_dashboard_refetches_via_dao( + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server + ): + """generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id() with eager-loaded slice relationships before serialization.""" - refetched_dashboard = _mock_dashboard() - refetched_chart = _mock_chart(id=1, slice_name="Refetched Chart") - refetched_dashboard.slices = [refetched_chart] - - mock_find_by_id.return_value = refetched_dashboard - - from superset.daos.dashboard import DashboardDAO - - result = ( - DashboardDAO.find_by_id(1, query_options=["dummy"]) or _mock_dashboard() + charts = [_mock_chart(id=1, slice_name="Chart 1")] + dashboard = _mock_dashboard(id=10, title="Refetch Test") + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, dashboard ) - assert result is refetched_dashboard - mock_find_by_id.assert_called_once_with(1, query_options=["dummy"]) + request = {"chart_ids": [1], "dashboard_title": "Refetch Test"} + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + assert result.structured_content["error"] is None + # Verify DashboardDAO.find_by_id was called for re-fetch + mock_find_by_id.assert_called() + + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_add_chart_refetches_dashboard_via_dao(self, mock_find_by_id): - """add_chart_to_existing_dashboard re-fetches dashboard via - DashboardDAO.find_by_id with eager-loaded slice relationships.""" - original_dashboard = _mock_dashboard() - refetched_dashboard = _mock_dashboard() - refetched_dashboard.slices = [_mock_chart()] + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_generate_dashboard_refetch_sqlalchemy_error_rollback( + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server + ): + """When the DAO re-fetch raises SQLAlchemyError, the session is + rolled back and a minimal response is returned with only scalar + attributes (no owners/tags/charts that would trigger lazy-loading).""" + from sqlalchemy.exc import SQLAlchemyError - mock_find_by_id.return_value = refetched_dashboard - - from superset.daos.dashboard import DashboardDAO - - result = ( - DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"]) - or original_dashboard + charts = [_mock_chart(id=1, slice_name="Chart 1")] + dashboard = _mock_dashboard(id=10, title="Rollback Test") + _setup_generate_dashboard_mocks( + mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, dashboard ) + # Make the DAO re-fetch raise SQLAlchemyError + mock_find_by_id.side_effect = SQLAlchemyError("session error") - assert result is refetched_dashboard + request = {"chart_ids": [1], "dashboard_title": "Rollback Test"} + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + data = result.structured_content + assert data["error"] is None + mock_db_session.rollback.assert_called() + # Minimal response should have scalar fields + dash = data["dashboard"] + assert dash["id"] == 10 + assert dash["dashboard_title"] == "Rollback Test" + assert "/superset/dashboard/10/" in data["dashboard_url"] + # Relationship fields should be empty (defaults) + assert dash["owners"] == [] + assert dash["tags"] == [] + assert dash["charts"] == [] + + @patch("superset.commands.dashboard.update.UpdateDashboardCommand") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_add_chart_falls_back_on_refetch_failure(self, mock_find_by_id): - """add_chart_to_existing_dashboard falls back to original dashboard - if DashboardDAO.find_by_id returns None.""" - original_dashboard = _mock_dashboard() - mock_find_by_id.return_value = None + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_add_chart_refetch_sqlalchemy_error_rollback( + self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server + ): + """When the DAO re-fetch raises SQLAlchemyError after adding a chart, + the session is rolled back and a minimal response is returned with + only scalar attributes and position info.""" + from sqlalchemy.exc import SQLAlchemyError - from superset.daos.dashboard import DashboardDAO + mock_dashboard = _mock_dashboard(id=1, title="Dashboard") + mock_dashboard.slices = [] + mock_dashboard.position_json = "{}" - result = ( - DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"]) - or original_dashboard - ) + mock_chart = _mock_chart(id=15) + mock_db_session.get.return_value = mock_chart - assert result is original_dashboard + updated = _mock_dashboard(id=1, title="Dashboard") + updated.slices = [_mock_chart(id=15)] + mock_update_command.return_value.run.return_value = updated + + # First call returns dashboard (validation), second raises (re-fetch) + mock_find_dashboard.side_effect = [ + mock_dashboard, + SQLAlchemyError("session error"), + ] + + request = {"dashboard_id": 1, "chart_id": 15} + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", {"request": request} + ) + + data = result.structured_content + assert data["error"] is None + mock_db_session.rollback.assert_called() + # Minimal response should have scalar fields + dash = data["dashboard"] + assert dash["id"] == 1 + assert dash["dashboard_title"] == "Dashboard" + assert "/superset/dashboard/1/" in data["dashboard_url"] + # Position info should still be returned + assert data["position"] is not None + assert "chart_key" in data["position"] + # Relationship fields should be empty (defaults) + assert dash["owners"] == [] + assert dash["tags"] == [] + assert dash["charts"] == []