diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 2ba89cd71a8..c3cc82cf7bb 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -742,7 +742,13 @@ async def generate_chart( # noqa: C901 chart.id, exc_info=True, ) - db.session.rollback() + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during chart re-fetch error handling", + exc_info=True, + ) chart_data = { "id": chart.id, "slice_name": chart.slice_name, @@ -805,6 +811,14 @@ async def generate_chart( # noqa: C901 return GenerateChartResponse.model_validate(result) except (CommandException, SQLAlchemyError, KeyError, ValueError) as e: + from superset import db + + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during error handling", exc_info=True + ) await ctx.error( "Chart generation failed: error=%s, execution_time_ms=%s" % ( diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py index a8f522c8758..767ef615f8b 100644 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -23,6 +23,7 @@ import logging import time from fastmcp import Context +from sqlalchemy.exc import SQLAlchemyError from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException @@ -268,7 +269,21 @@ async def update_chart( } return GenerateChartResponse.model_validate(result) - except (CommandException, ValueError, KeyError, AttributeError) as e: + except ( + CommandException, + SQLAlchemyError, + ValueError, + KeyError, + AttributeError, + ) as e: + from superset import db + + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during error handling", exc_info=True + ) execution_time = int((time.time() - start_time) * 1000) return GenerateChartResponse.model_validate( { diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py index 395e970ef80..888f2423fcc 100644 --- a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -435,25 +435,62 @@ def add_chart_to_existing_dashboard( # Re-fetch the dashboard with eager-loaded relationships to avoid # "Instance is not bound to a Session" errors when serializing - # chart .tags and .owners. + # chart .tags and .owners. The preceding command.run() commit may + # invalidate the session in multi-tenant environments; on failure, + # return a minimal response using only scalar attributes that are + # already loaded — relationship fields (owners, tags, slices) would + # trigger lazy-loading on the same dead session. from sqlalchemy.orm import subqueryload - from superset.daos.dashboard import DashboardDAO from superset.models.dashboard import Dashboard from superset.models.slice import Slice - updated_dashboard = ( - DashboardDAO.find_by_id( - updated_dashboard.id, - query_options=[ - subqueryload(Dashboard.slices).subqueryload(Slice.owners), - subqueryload(Dashboard.slices).subqueryload(Slice.tags), - subqueryload(Dashboard.owners), - subqueryload(Dashboard.tags), - ], + try: + updated_dashboard = ( + DashboardDAO.find_by_id( + updated_dashboard.id, + query_options=[ + subqueryload(Dashboard.slices).subqueryload(Slice.owners), + subqueryload(Dashboard.slices).subqueryload(Slice.tags), + subqueryload(Dashboard.owners), + subqueryload(Dashboard.tags), + ], + ) + or updated_dashboard + ) + except SQLAlchemyError: + logger.warning( + "Re-fetch of dashboard %s failed; returning minimal response", + updated_dashboard.id, + exc_info=True, + ) + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during dashboard re-fetch error handling", + exc_info=True, + ) + dashboard_url = ( + f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/" + ) + position_info = { + "row": row_key, + "chart_key": chart_key, + "row_key": row_key, + } + return AddChartToDashboardResponse( + dashboard=DashboardInfo( + id=updated_dashboard.id, + dashboard_title=updated_dashboard.dashboard_title, + published=updated_dashboard.published, + chart_count=len(all_chart_objects), + url=dashboard_url, + ), + dashboard_url=dashboard_url, + position=position_info, + error=None, ) - or updated_dashboard - ) # Convert to response format from superset.mcp_service.dashboard.schemas import ( @@ -511,6 +548,14 @@ def add_chart_to_existing_dashboard( ) except (CommandException, SQLAlchemyError, KeyError, ValueError) as e: + from superset import db + + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during error handling", exc_info=True + ) logger.error("Error adding chart to dashboard: %s", e) return AddChartToDashboardResponse( dashboard=None, diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py index e4559d9f0f8..1b0f457771f 100644 --- a/superset/mcp_service/dashboard/tool/generate_dashboard.py +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -187,7 +187,7 @@ def _generate_title_from_charts(chart_objects: List[Any]) -> str: destructiveHint=False, ), ) -def generate_dashboard( +def generate_dashboard( # noqa: C901 request: GenerateDashboardRequest, ctx: Context ) -> GenerateDashboardResponse: """Create dashboard from chart IDs. @@ -323,9 +323,15 @@ def generate_dashboard( dashboard.slices = fresh_charts db.session.add(dashboard) - db.session.commit() + db.session.commit() # pylint: disable=consider-using-transaction except SQLAlchemyError as db_err: - db.session.rollback() + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during error handling", + exc_info=True, + ) logger.error( "Dashboard creation failed: %s", db_err, @@ -365,7 +371,13 @@ def generate_dashboard( dashboard.id, exc_info=True, ) - db.session.rollback() + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during dashboard re-fetch error handling", + exc_info=True, + ) dashboard_url = ( f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/" ) @@ -429,6 +441,14 @@ def generate_dashboard( ) except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e: + from superset import db + + try: + db.session.rollback() # pylint: disable=consider-using-transaction + except SQLAlchemyError: + logger.warning( + "Database rollback failed during error handling", exc_info=True + ) logger.error("Error creating dashboard: %s", e, exc_info=True) return GenerateDashboardResponse( dashboard=None, diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py index cccfbd0cccf..71e4b83b5ec 100644 --- a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -809,8 +809,6 @@ class TestAddChartToExistingDashboard: "DASHBOARD_VERSION_KEY": "v2", } ) - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=25, slice_name="Tab Chart") mock_db_session.get.return_value = mock_chart @@ -818,6 +816,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=25)] mock_update_command.return_value.run.return_value = updated_dashboard + # side_effect: first call returns initial dashboard (validation), + # second call returns updated dashboard (re-fetch after update) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 3, "chart_id": 25} async with Client(mcp_server) as client: @@ -909,8 +911,6 @@ class TestAddChartToExistingDashboard: "DASHBOARD_VERSION_KEY": "v2", } ) - mock_find_dashboard.return_value = mock_dashboard - mock_chart = _mock_chart(id=30, slice_name="Customer Chart") mock_db_session.get.return_value = mock_chart @@ -918,6 +918,10 @@ class TestAddChartToExistingDashboard: updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=30)] mock_update_command.return_value.run.return_value = updated_dashboard + # side_effect: first call returns initial dashboard (validation), + # second call returns updated dashboard (re-fetch after update) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] + request = {"dashboard_id": 3, "chart_id": 30, "target_tab": "Customers"} async with Client(mcp_server) as client: @@ -1404,56 +1408,101 @@ class TestGenerateTitleFromCharts: class TestDashboardSerializationEagerLoading: """Tests for eager loading fix in dashboard serialization paths.""" + @patch("superset.models.dashboard.Dashboard") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_generate_dashboard_refetches_via_dao(self, mock_find_by_id): + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_generate_dashboard_refetches_via_dao( + self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server + ): """generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id with eager-loaded slice relationships before serialization.""" - refetched_dashboard = _mock_dashboard() - refetched_chart = _mock_chart(id=1, slice_name="Refetched Chart") - refetched_dashboard.slices = [refetched_chart] + charts = [_mock_chart(id=1, slice_name="Refetched Chart")] + refetched_dashboard = _mock_dashboard(id=10) + refetched_dashboard.slices = charts - mock_find_by_id.return_value = refetched_dashboard - - from superset.daos.dashboard import DashboardDAO - - result = ( - DashboardDAO.find_by_id(1, query_options=["dummy"]) or _mock_dashboard() + _setup_generate_dashboard_mocks( + mock_db_session, + mock_find_by_id, + mock_dashboard_cls, + charts, + refetched_dashboard, ) - assert result is refetched_dashboard - mock_find_by_id.assert_called_once_with(1, query_options=["dummy"]) + request = {"chart_ids": [1]} + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + + assert result.structured_content["error"] is None + # Verify DashboardDAO.find_by_id was called for re-fetch + mock_find_by_id.assert_called() + + @patch("superset.commands.dashboard.update.UpdateDashboardCommand") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_add_chart_refetches_dashboard_via_dao(self, mock_find_by_id): + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_add_chart_refetches_dashboard_via_dao( + self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server + ): """add_chart_to_existing_dashboard re-fetches dashboard via DashboardDAO.find_by_id with eager-loaded slice relationships.""" - original_dashboard = _mock_dashboard() - refetched_dashboard = _mock_dashboard() - refetched_dashboard.slices = [_mock_chart()] + mock_dashboard = _mock_dashboard(id=1) + mock_dashboard.slices = [] + mock_dashboard.position_json = "{}" - mock_find_by_id.return_value = refetched_dashboard + mock_chart = _mock_chart(id=5, slice_name="New Chart") + mock_db_session.get.return_value = mock_chart - from superset.daos.dashboard import DashboardDAO + updated_dashboard = _mock_dashboard(id=1) + updated_dashboard.slices = [mock_chart] + mock_update_command.return_value.run.return_value = updated_dashboard - result = ( - DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"]) - or original_dashboard - ) + # side_effect: first call returns initial dashboard (validation), + # second call returns updated dashboard (re-fetch with eager loading) + mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard] - assert result is refetched_dashboard + request = {"dashboard_id": 1, "chart_id": 5} + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", {"request": request} + ) + + assert result.structured_content["error"] is None + # DashboardDAO.find_by_id called twice: validation + re-fetch + assert mock_find_dashboard.call_count == 2 + + @patch("superset.commands.dashboard.update.UpdateDashboardCommand") @patch("superset.daos.dashboard.DashboardDAO.find_by_id") - def test_add_chart_falls_back_on_refetch_failure(self, mock_find_by_id): + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_add_chart_falls_back_on_refetch_failure( + self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server + ): """add_chart_to_existing_dashboard falls back to original dashboard - if DashboardDAO.find_by_id returns None.""" - original_dashboard = _mock_dashboard() - mock_find_by_id.return_value = None + if DashboardDAO.find_by_id returns None on re-fetch.""" + mock_dashboard = _mock_dashboard(id=1) + mock_dashboard.slices = [] + mock_dashboard.position_json = "{}" - from superset.daos.dashboard import DashboardDAO + mock_chart = _mock_chart(id=5, slice_name="New Chart") + mock_db_session.get.return_value = mock_chart - result = ( - DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"]) - or original_dashboard - ) + updated_dashboard = _mock_dashboard(id=1) + updated_dashboard.slices = [mock_chart] + mock_update_command.return_value.run.return_value = updated_dashboard - assert result is original_dashboard + # side_effect: first call returns dashboard (validation), + # second call returns None (re-fetch fails, should fall back) + mock_find_dashboard.side_effect = [mock_dashboard, None] + + request = {"dashboard_id": 1, "chart_id": 5} + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", {"request": request} + ) + + # Tool should still succeed using fallback dashboard + assert result.structured_content["error"] is None