fix(mcp): prevent PendingRollbackError from poisoned sessions after SSL drops (#38934)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
(cherry picked from commit d331a043a3)
This commit is contained in:
Amin Ghadersohi
2026-03-30 16:30:15 +02:00
committed by Michael S. Molina
parent f1f757b5c5
commit bda02a3fdc
5 changed files with 199 additions and 56 deletions

View File

@@ -742,7 +742,13 @@ async def generate_chart( # noqa: C901
chart.id,
exc_info=True,
)
db.session.rollback()
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during chart re-fetch error handling",
exc_info=True,
)
chart_data = {
"id": chart.id,
"slice_name": chart.slice_name,
@@ -805,6 +811,14 @@ async def generate_chart( # noqa: C901
return GenerateChartResponse.model_validate(result)
except (CommandException, SQLAlchemyError, KeyError, ValueError) as e:
from superset import db
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
await ctx.error(
"Chart generation failed: error=%s, execution_time_ms=%s"
% (

View File

@@ -23,6 +23,7 @@ import logging
import time
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
@@ -268,7 +269,21 @@ async def update_chart(
}
return GenerateChartResponse.model_validate(result)
except (CommandException, ValueError, KeyError, AttributeError) as e:
except (
CommandException,
SQLAlchemyError,
ValueError,
KeyError,
AttributeError,
) as e:
from superset import db
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
execution_time = int((time.time() - start_time) * 1000)
return GenerateChartResponse.model_validate(
{

View File

@@ -435,25 +435,62 @@ def add_chart_to_existing_dashboard(
# Re-fetch the dashboard with eager-loaded relationships to avoid
# "Instance is not bound to a Session" errors when serializing
# chart .tags and .owners.
# chart .tags and .owners. The preceding command.run() commit may
# invalidate the session in multi-tenant environments; on failure,
# return a minimal response using only scalar attributes that are
# already loaded — relationship fields (owners, tags, slices) would
# trigger lazy-loading on the same dead session.
from sqlalchemy.orm import subqueryload
from superset.daos.dashboard import DashboardDAO
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
updated_dashboard = (
DashboardDAO.find_by_id(
updated_dashboard.id,
query_options=[
subqueryload(Dashboard.slices).subqueryload(Slice.owners),
subqueryload(Dashboard.slices).subqueryload(Slice.tags),
subqueryload(Dashboard.owners),
subqueryload(Dashboard.tags),
],
try:
updated_dashboard = (
DashboardDAO.find_by_id(
updated_dashboard.id,
query_options=[
subqueryload(Dashboard.slices).subqueryload(Slice.owners),
subqueryload(Dashboard.slices).subqueryload(Slice.tags),
subqueryload(Dashboard.owners),
subqueryload(Dashboard.tags),
],
)
or updated_dashboard
)
except SQLAlchemyError:
logger.warning(
"Re-fetch of dashboard %s failed; returning minimal response",
updated_dashboard.id,
exc_info=True,
)
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during dashboard re-fetch error handling",
exc_info=True,
)
dashboard_url = (
f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/"
)
position_info = {
"row": row_key,
"chart_key": chart_key,
"row_key": row_key,
}
return AddChartToDashboardResponse(
dashboard=DashboardInfo(
id=updated_dashboard.id,
dashboard_title=updated_dashboard.dashboard_title,
published=updated_dashboard.published,
chart_count=len(all_chart_objects),
url=dashboard_url,
),
dashboard_url=dashboard_url,
position=position_info,
error=None,
)
or updated_dashboard
)
# Convert to response format
from superset.mcp_service.dashboard.schemas import (
@@ -511,6 +548,14 @@ def add_chart_to_existing_dashboard(
)
except (CommandException, SQLAlchemyError, KeyError, ValueError) as e:
from superset import db
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
logger.error("Error adding chart to dashboard: %s", e)
return AddChartToDashboardResponse(
dashboard=None,

View File

@@ -187,7 +187,7 @@ def _generate_title_from_charts(chart_objects: List[Any]) -> str:
destructiveHint=False,
),
)
def generate_dashboard(
def generate_dashboard( # noqa: C901
request: GenerateDashboardRequest, ctx: Context
) -> GenerateDashboardResponse:
"""Create dashboard from chart IDs.
@@ -323,9 +323,15 @@ def generate_dashboard(
dashboard.slices = fresh_charts
db.session.add(dashboard)
db.session.commit()
db.session.commit() # pylint: disable=consider-using-transaction
except SQLAlchemyError as db_err:
db.session.rollback()
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling",
exc_info=True,
)
logger.error(
"Dashboard creation failed: %s",
db_err,
@@ -365,7 +371,13 @@ def generate_dashboard(
dashboard.id,
exc_info=True,
)
db.session.rollback()
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during dashboard re-fetch error handling",
exc_info=True,
)
dashboard_url = (
f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"
)
@@ -429,6 +441,14 @@ def generate_dashboard(
)
except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e:
from superset import db
try:
db.session.rollback() # pylint: disable=consider-using-transaction
except SQLAlchemyError:
logger.warning(
"Database rollback failed during error handling", exc_info=True
)
logger.error("Error creating dashboard: %s", e, exc_info=True)
return GenerateDashboardResponse(
dashboard=None,

View File

@@ -809,8 +809,6 @@ class TestAddChartToExistingDashboard:
"DASHBOARD_VERSION_KEY": "v2",
}
)
mock_find_dashboard.return_value = mock_dashboard
mock_chart = _mock_chart(id=25, slice_name="Tab Chart")
mock_db_session.get.return_value = mock_chart
@@ -818,6 +816,10 @@ class TestAddChartToExistingDashboard:
updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=25)]
mock_update_command.return_value.run.return_value = updated_dashboard
# side_effect: first call returns initial dashboard (validation),
# second call returns updated dashboard (re-fetch after update)
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
request = {"dashboard_id": 3, "chart_id": 25}
async with Client(mcp_server) as client:
@@ -909,8 +911,6 @@ class TestAddChartToExistingDashboard:
"DASHBOARD_VERSION_KEY": "v2",
}
)
mock_find_dashboard.return_value = mock_dashboard
mock_chart = _mock_chart(id=30, slice_name="Customer Chart")
mock_db_session.get.return_value = mock_chart
@@ -918,6 +918,10 @@ class TestAddChartToExistingDashboard:
updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=30)]
mock_update_command.return_value.run.return_value = updated_dashboard
# side_effect: first call returns initial dashboard (validation),
# second call returns updated dashboard (re-fetch after update)
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
request = {"dashboard_id": 3, "chart_id": 30, "target_tab": "Customers"}
async with Client(mcp_server) as client:
@@ -1404,56 +1408,101 @@ class TestGenerateTitleFromCharts:
class TestDashboardSerializationEagerLoading:
"""Tests for eager loading fix in dashboard serialization paths."""
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
def test_generate_dashboard_refetches_via_dao(self, mock_find_by_id):
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_refetches_via_dao(
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id
with eager-loaded slice relationships before serialization."""
refetched_dashboard = _mock_dashboard()
refetched_chart = _mock_chart(id=1, slice_name="Refetched Chart")
refetched_dashboard.slices = [refetched_chart]
charts = [_mock_chart(id=1, slice_name="Refetched Chart")]
refetched_dashboard = _mock_dashboard(id=10)
refetched_dashboard.slices = charts
mock_find_by_id.return_value = refetched_dashboard
from superset.daos.dashboard import DashboardDAO
result = (
DashboardDAO.find_by_id(1, query_options=["dummy"]) or _mock_dashboard()
_setup_generate_dashboard_mocks(
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
charts,
refetched_dashboard,
)
assert result is refetched_dashboard
mock_find_by_id.assert_called_once_with(1, query_options=["dummy"])
request = {"chart_ids": [1]}
async with Client(mcp_server) as client:
result = await client.call_tool("generate_dashboard", {"request": request})
assert result.structured_content["error"] is None
# Verify DashboardDAO.find_by_id was called for re-fetch
mock_find_by_id.assert_called()
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
def test_add_chart_refetches_dashboard_via_dao(self, mock_find_by_id):
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_refetches_dashboard_via_dao(
self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
):
"""add_chart_to_existing_dashboard re-fetches dashboard via
DashboardDAO.find_by_id with eager-loaded slice relationships."""
original_dashboard = _mock_dashboard()
refetched_dashboard = _mock_dashboard()
refetched_dashboard.slices = [_mock_chart()]
mock_dashboard = _mock_dashboard(id=1)
mock_dashboard.slices = []
mock_dashboard.position_json = "{}"
mock_find_by_id.return_value = refetched_dashboard
mock_chart = _mock_chart(id=5, slice_name="New Chart")
mock_db_session.get.return_value = mock_chart
from superset.daos.dashboard import DashboardDAO
updated_dashboard = _mock_dashboard(id=1)
updated_dashboard.slices = [mock_chart]
mock_update_command.return_value.run.return_value = updated_dashboard
result = (
DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"])
or original_dashboard
)
# side_effect: first call returns initial dashboard (validation),
# second call returns updated dashboard (re-fetch with eager loading)
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
assert result is refetched_dashboard
request = {"dashboard_id": 1, "chart_id": 5}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
assert result.structured_content["error"] is None
# DashboardDAO.find_by_id called twice: validation + re-fetch
assert mock_find_dashboard.call_count == 2
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
def test_add_chart_falls_back_on_refetch_failure(self, mock_find_by_id):
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_falls_back_on_refetch_failure(
self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
):
"""add_chart_to_existing_dashboard falls back to original dashboard
if DashboardDAO.find_by_id returns None."""
original_dashboard = _mock_dashboard()
mock_find_by_id.return_value = None
if DashboardDAO.find_by_id returns None on re-fetch."""
mock_dashboard = _mock_dashboard(id=1)
mock_dashboard.slices = []
mock_dashboard.position_json = "{}"
from superset.daos.dashboard import DashboardDAO
mock_chart = _mock_chart(id=5, slice_name="New Chart")
mock_db_session.get.return_value = mock_chart
result = (
DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"])
or original_dashboard
)
updated_dashboard = _mock_dashboard(id=1)
updated_dashboard.slices = [mock_chart]
mock_update_command.return_value.run.return_value = updated_dashboard
assert result is original_dashboard
# side_effect: first call returns dashboard (validation),
# second call returns None (re-fetch fails, should fall back)
mock_find_dashboard.side_effect = [mock_dashboard, None]
request = {"dashboard_id": 1, "chart_id": 5}
async with Client(mcp_server) as client:
result = await client.call_tool(
"add_chart_to_existing_dashboard", {"request": request}
)
# Tool should still succeed using fallback dashboard
assert result.structured_content["error"] is None