mirror of
https://github.com/apache/superset.git
synced 2026-04-20 08:34:37 +00:00
fix(mcp): add try/except around DAO re-fetch to handle session errors in multi-tenant (#38859)
This commit is contained in:
@@ -410,38 +410,68 @@ class TestChartSerializationEagerLoading:
|
||||
with pytest.raises(DetachedInstanceError):
|
||||
serialize_chart_object(chart)
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_generate_chart_refetches_via_dao(self, mock_find_by_id):
|
||||
"""The serialization path re-fetches the chart via ChartDAO.find_by_id
|
||||
with joinedload query_options for owners and tags."""
|
||||
def test_generate_chart_refetches_via_dao(self):
|
||||
"""The serialization path re-fetches the chart via
|
||||
ChartDAO.find_by_id() with query_options for owners and tags."""
|
||||
refetched_chart = _make_mock_chart()
|
||||
refetched_chart.tags = [Mock(id=1, name="tag1", type="custom")]
|
||||
refetched_chart.tags[0].description = ""
|
||||
|
||||
mock_find_by_id.return_value = refetched_chart
|
||||
|
||||
from superset.daos.chart import ChartDAO
|
||||
mock_dao = MagicMock()
|
||||
mock_dao.find_by_id.return_value = refetched_chart
|
||||
|
||||
chart = (
|
||||
ChartDAO.find_by_id(42, query_options=["dummy_option"])
|
||||
mock_dao.find_by_id(42, query_options=[Mock(), Mock()])
|
||||
or _make_mock_chart()
|
||||
)
|
||||
|
||||
assert chart is refetched_chart
|
||||
mock_find_by_id.assert_called_once_with(42, query_options=["dummy_option"])
|
||||
mock_dao.find_by_id.assert_called()
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
def test_generate_chart_falls_back_to_original_on_refetch_failure(
|
||||
self, mock_find_by_id
|
||||
):
|
||||
"""Falls back to original chart if ChartDAO.find_by_id returns None."""
|
||||
def test_generate_chart_falls_back_to_original_on_dao_none(self):
|
||||
"""Falls back to original chart if ChartDAO.find_by_id()
|
||||
returns None."""
|
||||
original_chart = _make_mock_chart()
|
||||
mock_find_by_id.return_value = None
|
||||
|
||||
from superset.daos.chart import ChartDAO
|
||||
mock_dao = MagicMock()
|
||||
mock_dao.find_by_id.return_value = None
|
||||
|
||||
chart = (
|
||||
ChartDAO.find_by_id(original_chart.id, query_options=[]) or original_chart
|
||||
)
|
||||
chart = mock_dao.find_by_id(42, query_options=[Mock()]) or original_chart
|
||||
|
||||
assert chart is original_chart
|
||||
|
||||
def test_generate_chart_refetch_sqlalchemy_error_rollback(self):
|
||||
"""When the DAO re-fetch raises SQLAlchemyError, the session is
|
||||
rolled back and a minimal chart_data dict is built from scalar
|
||||
attributes instead of calling serialize_chart_object (which would
|
||||
trigger lazy-loading on the same dead session)."""
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
original_chart = _make_mock_chart()
|
||||
mock_dao = MagicMock()
|
||||
mock_dao.find_by_id.side_effect = SQLAlchemyError("session error")
|
||||
mock_session = MagicMock()
|
||||
explore_url = "http://example.com/explore/?slice_id=42"
|
||||
|
||||
chart_data = None
|
||||
try:
|
||||
mock_dao.find_by_id(42, query_options=[Mock()])
|
||||
except SQLAlchemyError:
|
||||
mock_session.rollback()
|
||||
chart_data = {
|
||||
"id": original_chart.id,
|
||||
"slice_name": original_chart.slice_name,
|
||||
"viz_type": original_chart.viz_type,
|
||||
"url": explore_url,
|
||||
"uuid": str(original_chart.uuid) if original_chart.uuid else None,
|
||||
}
|
||||
|
||||
mock_session.rollback.assert_called()
|
||||
# Minimal chart_data should contain scalar fields only
|
||||
assert chart_data is not None
|
||||
assert chart_data["id"] == original_chart.id
|
||||
assert chart_data["slice_name"] == original_chart.slice_name
|
||||
assert chart_data["url"] == explore_url
|
||||
# No tags/owners keys — those would require relationship access
|
||||
assert "tags" not in chart_data
|
||||
assert "owners" not in chart_data
|
||||
|
||||
@@ -20,7 +20,7 @@ Unit tests for dashboard generation MCP tools
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
@@ -133,7 +133,8 @@ def _setup_generate_dashboard_mocks(
|
||||
|
||||
The tool creates dashboards directly via db.session (bypassing
|
||||
CreateDashboardCommand) and re-queries user/charts in the tool's
|
||||
own session. This helper wires up the mock chain for that path.
|
||||
own session. The re-fetch uses DashboardDAO.find_by_id() with
|
||||
query_options for eager loading of slice relationships.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
@@ -143,8 +144,8 @@ def _setup_generate_dashboard_mocks(
|
||||
mock_user.email = "admin@example.com"
|
||||
mock_user.active = True
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query = MagicMock()
|
||||
mock_filter = MagicMock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
@@ -153,6 +154,7 @@ def _setup_generate_dashboard_mocks(
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard_cls.return_value = dashboard
|
||||
# DashboardDAO.find_by_id is used for the re-fetch with eager loading
|
||||
mock_find_by_id.return_value = dashboard
|
||||
|
||||
|
||||
@@ -556,15 +558,15 @@ class TestAddChartToExistingDashboard:
|
||||
_mock_chart(id=20),
|
||||
_mock_chart(id=30),
|
||||
]
|
||||
# First call: initial validation returns original dashboard
|
||||
# Second call: re-fetch after update returns updated dashboard
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
mock_chart = _mock_chart(id=30, slice_name="New Chart")
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
mock_update_command.return_value.run.return_value = updated_dashboard
|
||||
|
||||
# First DAO call returns initial dashboard (validation),
|
||||
# second DAO call returns updated dashboard (re-fetch with eager loading)
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
request = {"dashboard_id": 1, "chart_id": 30}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -700,8 +702,6 @@ class TestAddChartToExistingDashboard:
|
||||
mock_dashboard = _mock_dashboard(id=2)
|
||||
mock_dashboard.slices = []
|
||||
mock_dashboard.position_json = "{}"
|
||||
mock_find_dashboard.return_value = mock_dashboard
|
||||
|
||||
mock_chart = _mock_chart(id=15)
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
@@ -709,6 +709,10 @@ class TestAddChartToExistingDashboard:
|
||||
updated_dashboard.slices = [_mock_chart(id=15)]
|
||||
mock_update_command.return_value.run.return_value = updated_dashboard
|
||||
|
||||
# First DAO call returns initial dashboard (validation),
|
||||
# second returns updated dashboard (re-fetch)
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
request = {"dashboard_id": 2, "chart_id": 15}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -809,8 +813,6 @@ class TestAddChartToExistingDashboard:
|
||||
"DASHBOARD_VERSION_KEY": "v2",
|
||||
}
|
||||
)
|
||||
mock_find_dashboard.return_value = mock_dashboard
|
||||
|
||||
mock_chart = _mock_chart(id=25, slice_name="Tab Chart")
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
@@ -818,6 +820,10 @@ class TestAddChartToExistingDashboard:
|
||||
updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=25)]
|
||||
mock_update_command.return_value.run.return_value = updated_dashboard
|
||||
|
||||
# First DAO call returns initial dashboard (validation),
|
||||
# second returns updated dashboard (re-fetch)
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
request = {"dashboard_id": 3, "chart_id": 25}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -909,8 +915,6 @@ class TestAddChartToExistingDashboard:
|
||||
"DASHBOARD_VERSION_KEY": "v2",
|
||||
}
|
||||
)
|
||||
mock_find_dashboard.return_value = mock_dashboard
|
||||
|
||||
mock_chart = _mock_chart(id=30, slice_name="Customer Chart")
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
@@ -918,6 +922,10 @@ class TestAddChartToExistingDashboard:
|
||||
updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=30)]
|
||||
mock_update_command.return_value.run.return_value = updated_dashboard
|
||||
|
||||
# First DAO call returns initial dashboard (validation),
|
||||
# second returns updated dashboard (re-fetch)
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
request = {"dashboard_id": 3, "chart_id": 30, "target_tab": "Customers"}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -984,8 +992,6 @@ class TestAddChartToExistingDashboard:
|
||||
"DASHBOARD_VERSION_KEY": "v2",
|
||||
}
|
||||
)
|
||||
mock_find_dashboard.return_value = mock_dashboard
|
||||
|
||||
mock_chart = _mock_chart(id=50, slice_name="New Nanoid Chart")
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
@@ -993,6 +999,10 @@ class TestAddChartToExistingDashboard:
|
||||
updated_dashboard.slices = [_mock_chart(id=10), _mock_chart(id=50)]
|
||||
mock_update_command.return_value.run.return_value = updated_dashboard
|
||||
|
||||
# First DAO call returns initial dashboard (validation),
|
||||
# second returns updated dashboard (re-fetch)
|
||||
mock_find_dashboard.side_effect = [mock_dashboard, updated_dashboard]
|
||||
|
||||
request = {"dashboard_id": 4, "chart_id": 50}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -1218,58 +1228,122 @@ class TestGenerateTitleFromCharts:
|
||||
|
||||
|
||||
class TestDashboardSerializationEagerLoading:
|
||||
"""Tests for eager loading fix in dashboard serialization paths."""
|
||||
"""Tests for eager loading fix in dashboard serialization paths.
|
||||
|
||||
The re-fetch uses DashboardDAO.find_by_id() with query_options for
|
||||
eager loading. A try/except around the DAO call handles "Can't
|
||||
reconnect until invalid transaction is rolled back" errors in
|
||||
multi-tenant environments by rolling back and falling back to the
|
||||
original dashboard object.
|
||||
"""
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
def test_generate_dashboard_refetches_via_dao(self, mock_find_by_id):
|
||||
"""generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_refetches_via_dao(
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""generate_dashboard re-fetches dashboard via DashboardDAO.find_by_id()
|
||||
with eager-loaded slice relationships before serialization."""
|
||||
refetched_dashboard = _mock_dashboard()
|
||||
refetched_chart = _mock_chart(id=1, slice_name="Refetched Chart")
|
||||
refetched_dashboard.slices = [refetched_chart]
|
||||
|
||||
mock_find_by_id.return_value = refetched_dashboard
|
||||
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
result = (
|
||||
DashboardDAO.find_by_id(1, query_options=["dummy"]) or _mock_dashboard()
|
||||
charts = [_mock_chart(id=1, slice_name="Chart 1")]
|
||||
dashboard = _mock_dashboard(id=10, title="Refetch Test")
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, dashboard
|
||||
)
|
||||
|
||||
assert result is refetched_dashboard
|
||||
mock_find_by_id.assert_called_once_with(1, query_options=["dummy"])
|
||||
request = {"chart_ids": [1], "dashboard_title": "Refetch Test"}
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
# Verify DashboardDAO.find_by_id was called for re-fetch
|
||||
mock_find_by_id.assert_called()
|
||||
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
def test_add_chart_refetches_dashboard_via_dao(self, mock_find_by_id):
|
||||
"""add_chart_to_existing_dashboard re-fetches dashboard via
|
||||
DashboardDAO.find_by_id with eager-loaded slice relationships."""
|
||||
original_dashboard = _mock_dashboard()
|
||||
refetched_dashboard = _mock_dashboard()
|
||||
refetched_dashboard.slices = [_mock_chart()]
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_refetch_sqlalchemy_error_rollback(
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""When the DAO re-fetch raises SQLAlchemyError, the session is
|
||||
rolled back and a minimal response is returned with only scalar
|
||||
attributes (no owners/tags/charts that would trigger lazy-loading)."""
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
mock_find_by_id.return_value = refetched_dashboard
|
||||
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
|
||||
result = (
|
||||
DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"])
|
||||
or original_dashboard
|
||||
charts = [_mock_chart(id=1, slice_name="Chart 1")]
|
||||
dashboard = _mock_dashboard(id=10, title="Rollback Test")
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, dashboard
|
||||
)
|
||||
# Make the DAO re-fetch raise SQLAlchemyError
|
||||
mock_find_by_id.side_effect = SQLAlchemyError("session error")
|
||||
|
||||
assert result is refetched_dashboard
|
||||
request = {"chart_ids": [1], "dashboard_title": "Rollback Test"}
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
data = result.structured_content
|
||||
assert data["error"] is None
|
||||
mock_db_session.rollback.assert_called()
|
||||
# Minimal response should have scalar fields
|
||||
dash = data["dashboard"]
|
||||
assert dash["id"] == 10
|
||||
assert dash["dashboard_title"] == "Rollback Test"
|
||||
assert "/superset/dashboard/10/" in data["dashboard_url"]
|
||||
# Relationship fields should be empty (defaults)
|
||||
assert dash["owners"] == []
|
||||
assert dash["tags"] == []
|
||||
assert dash["charts"] == []
|
||||
|
||||
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
def test_add_chart_falls_back_on_refetch_failure(self, mock_find_by_id):
|
||||
"""add_chart_to_existing_dashboard falls back to original dashboard
|
||||
if DashboardDAO.find_by_id returns None."""
|
||||
original_dashboard = _mock_dashboard()
|
||||
mock_find_by_id.return_value = None
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_chart_refetch_sqlalchemy_error_rollback(
|
||||
self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
|
||||
):
|
||||
"""When the DAO re-fetch raises SQLAlchemyError after adding a chart,
|
||||
the session is rolled back and a minimal response is returned with
|
||||
only scalar attributes and position info."""
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
mock_dashboard = _mock_dashboard(id=1, title="Dashboard")
|
||||
mock_dashboard.slices = []
|
||||
mock_dashboard.position_json = "{}"
|
||||
|
||||
result = (
|
||||
DashboardDAO.find_by_id(original_dashboard.id, query_options=["dummy"])
|
||||
or original_dashboard
|
||||
)
|
||||
mock_chart = _mock_chart(id=15)
|
||||
mock_db_session.get.return_value = mock_chart
|
||||
|
||||
assert result is original_dashboard
|
||||
updated = _mock_dashboard(id=1, title="Dashboard")
|
||||
updated.slices = [_mock_chart(id=15)]
|
||||
mock_update_command.return_value.run.return_value = updated
|
||||
|
||||
# First call returns dashboard (validation), second raises (re-fetch)
|
||||
mock_find_dashboard.side_effect = [
|
||||
mock_dashboard,
|
||||
SQLAlchemyError("session error"),
|
||||
]
|
||||
|
||||
request = {"dashboard_id": 1, "chart_id": 15}
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"add_chart_to_existing_dashboard", {"request": request}
|
||||
)
|
||||
|
||||
data = result.structured_content
|
||||
assert data["error"] is None
|
||||
mock_db_session.rollback.assert_called()
|
||||
# Minimal response should have scalar fields
|
||||
dash = data["dashboard"]
|
||||
assert dash["id"] == 1
|
||||
assert dash["dashboard_title"] == "Dashboard"
|
||||
assert "/superset/dashboard/1/" in data["dashboard_url"]
|
||||
# Position info should still be returned
|
||||
assert data["position"] is not None
|
||||
assert "chart_key" in data["position"]
|
||||
# Relationship fields should be empty (defaults)
|
||||
assert dash["owners"] == []
|
||||
assert dash["tags"] == []
|
||||
assert dash["charts"] == []
|
||||
|
||||
Reference in New Issue
Block a user