fix(mcp): fix generate_dashboard cross-session SQLAlchemy error (#38827)

This commit is contained in:
Amin Ghadersohi
2026-03-24 11:39:37 -04:00
committed by GitHub
parent e2bb20121e
commit 09594b32f9
2 changed files with 185 additions and 136 deletions

View File

@@ -25,6 +25,7 @@ import logging
from typing import Any, Dict, List
from fastmcp import Context
from flask import g
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
@@ -200,10 +201,12 @@ def generate_dashboard(
Returns:
- Dashboard ID and URL
"""
from pydantic import ValidationError
from sqlalchemy.exc import SQLAlchemyError
try:
# Get chart objects from IDs (required for SQLAlchemy relationships)
from superset import db
from superset.commands.dashboard.create import CreateDashboardCommand
from superset.models.slice import Slice
with event_logger.log_context(action="mcp.generate_dashboard.chart_validation"):
@@ -235,65 +238,92 @@ def generate_dashboard(
else _generate_title_from_charts(chart_objects)
)
# Prepare dashboard data and create dashboard
# Create the dashboard directly with db.session instead of using
# CreateDashboardCommand. The command's @transaction decorator
# may operate in a different SQLAlchemy scoped-session than the
# one g.user and chart ORM objects are bound to in the MCP
# context, causing "Object is already attached to session X
# (this is Y)" errors. By re-querying all ORM objects in the
# tool's own db.session we keep everything in a single session.
from sqlalchemy.orm import subqueryload
from superset.models.dashboard import Dashboard
with event_logger.log_context(action="mcp.generate_dashboard.db_write"):
dashboard_data: Dict[str, Any] = {
"dashboard_title": dashboard_title,
"json_metadata": json.dumps(
{
"filter_scopes": {},
"expanded_slices": {},
"refresh_frequency": 0,
"timed_refresh_immune_slices": [],
"color_scheme": None,
"label_colors": {},
"shared_label_colors": {},
"color_scheme_domain": [],
"cross_filters_enabled": False,
"native_filter_configuration": [],
"global_chart_configuration": {
"scope": {
"rootPath": ["ROOT_ID"],
"excluded": [],
}
},
"chart_configuration": {},
}
),
"position_json": json.dumps(layout),
"published": request.published,
"slices": chart_objects, # Pass ORM objects, not IDs
}
json_metadata = json.dumps(
{
"filter_scopes": {},
"expanded_slices": {},
"refresh_frequency": 0,
"timed_refresh_immune_slices": [],
"color_scheme": None,
"label_colors": {},
"shared_label_colors": {},
"color_scheme_domain": [],
"cross_filters_enabled": False,
"native_filter_configuration": [],
"global_chart_configuration": {
"scope": {
"rootPath": ["ROOT_ID"],
"excluded": [],
}
},
"chart_configuration": {},
}
)
if request.description:
dashboard_data["description"] = request.description
# Create the dashboard using Superset's command pattern
try:
command = CreateDashboardCommand(dashboard_data)
dashboard = command.run()
except Exception as cmd_err:
# Surface the root cause from @transaction's error wrapping
root_cause = cmd_err.__cause__ or cmd_err
dashboard = Dashboard()
dashboard.dashboard_title = dashboard_title
dashboard.json_metadata = json_metadata
dashboard.position_json = json.dumps(layout)
dashboard.published = request.published
if request.description:
dashboard.description = request.description
# Re-query the current user and charts directly in the
# current db.session. g.user was loaded in a Flask
# app_context that has since been torn down (the
# middleware's ``with flask_app.app_context()`` exits
# before the tool function runs), so the User object
# is bound to a dead/different scoped session.
# Querying fresh avoids all cross-session errors.
from superset.extensions import security_manager
current_user = (
db.session.query(security_manager.user_model)
.filter_by(id=g.user.id)
.first()
)
if current_user:
dashboard.owners = [current_user]
fresh_charts = (
db.session.query(Slice)
.filter(Slice.id.in_(request.chart_ids))
.order_by(Slice.id)
.all()
)
dashboard.slices = fresh_charts
db.session.add(dashboard)
db.session.commit()
except SQLAlchemyError as db_err:
db.session.rollback()
logger.error(
"CreateDashboardCommand failed: %s (cause: %s)",
cmd_err,
root_cause,
"Dashboard creation failed: %s",
db_err,
exc_info=True,
)
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,
error=f"Failed to create dashboard: {root_cause}",
error="Failed to create dashboard due to a database error.",
)
# Re-fetch the dashboard with eager-loaded relationships to avoid
# "Instance is not bound to a Session" errors when serializing
# chart .tags and .owners.
from sqlalchemy.orm import subqueryload
# Re-fetch with eager-loaded relationships for serialization
from superset.daos.dashboard import DashboardDAO
from superset.models.dashboard import Dashboard
dashboard = (
DashboardDAO.find_by_id(
@@ -355,8 +385,8 @@ def generate_dashboard(
dashboard=dashboard_info, dashboard_url=dashboard_url, error=None
)
except Exception as e:
logger.error("Error creating dashboard: %s", e)
except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e:
logger.error("Error creating dashboard: %s", e, exc_info=True)
return GenerateDashboardResponse(
dashboard=None,
dashboard_url=None,

View File

@@ -105,30 +105,59 @@ def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
return dashboard
def _setup_generate_dashboard_mocks(
mock_db_session,
mock_find_by_id,
mock_dashboard_cls,
charts,
dashboard,
):
"""Set up common mocks for generate_dashboard tests.
The tool creates dashboards directly via db.session (bypassing
CreateDashboardCommand) and re-queries user/charts in the tool's
own session. This helper wires up the mock chain for that path.
"""
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_user.first_name = "Admin"
mock_user.last_name = "User"
mock_user.email = "admin@example.com"
mock_user.active = True
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_query.filter_by.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = charts
mock_filter.first.return_value = mock_user
mock_db_session.query.return_value = mock_query
mock_dashboard_cls.return_value = dashboard
mock_find_by_id.return_value = dashboard
class TestGenerateDashboard:
"""Tests for generate_dashboard MCP tool."""
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_basic(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test basic dashboard generation with valid charts."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [
charts = [
_mock_chart(id=1, slice_name="Sales Chart"),
_mock_chart(id=2, slice_name="Revenue Chart"),
]
mock_db_session.query.return_value = mock_query
mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
request = {
"chart_ids": [1, 2],
@@ -173,24 +202,19 @@ class TestGenerateDashboard:
assert result.structured_content["dashboard"] is None
assert result.structured_content["dashboard_url"] is None
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_single_chart(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test dashboard generation with a single chart."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")]
mock_db_session.query.return_value = mock_query
charts = [_mock_chart(id=5, slice_name="Single Chart")]
mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
request = {
"chart_ids": [5],
@@ -203,29 +227,22 @@ class TestGenerateDashboard:
assert result.structured_content["error"] is None
assert result.structured_content["dashboard"]["chart_count"] == 1
assert result.structured_content["dashboard"]["published"] is True
assert result.structured_content["dashboard"]["published"] is False
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_many_charts(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test dashboard generation with many charts (grid layout)."""
chart_ids = list(range(1, 7))
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [
_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids
]
mock_db_session.query.return_value = mock_query
charts = [_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids]
mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}
@@ -235,10 +252,14 @@ class TestGenerateDashboard:
assert result.structured_content["error"] is None
assert result.structured_content["dashboard"]["chart_count"] == 6
mock_create_command.assert_called_once()
call_args = mock_create_command.call_args[0][0]
# Verify db.session.add and commit were called
# (commit is called multiple times: once by tool + event_logger contexts)
mock_db_session.add.assert_called_once()
assert mock_db_session.commit.call_count >= 1
position_json = json.loads(call_args["position_json"])
# Verify layout was set on the dashboard object
created_dashboard = mock_dashboard_cls.return_value
position_json = json.loads(created_dashboard.position_json)
assert "ROOT_ID" in position_json
assert "GRID_ID" in position_json
assert "DASHBOARD_VERSION_KEY" in position_json
@@ -278,20 +299,33 @@ class TestGenerateDashboard:
assert row_data["type"] == "ROW"
assert column_key in row_data["children"]
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_creation_failure(
self, mock_db_session, mock_create_command, mcp_server
self, mock_db_session, mock_dashboard_cls, mcp_server
):
"""Test error handling when dashboard creation fails."""
from sqlalchemy.exc import SQLAlchemyError
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_query.filter_by.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=1)]
mock_filter.first.return_value = Mock(
id=1,
username="admin",
first_name="Admin",
last_name="User",
email="admin@example.com",
active=True,
)
mock_db_session.query.return_value = mock_query
mock_create_command.return_value.run.side_effect = Exception("Creation failed")
mock_db_session.commit.side_effect = SQLAlchemyError("Creation failed")
mock_dashboard_cls.return_value = _mock_dashboard(id=99)
request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}
@@ -301,25 +335,22 @@ class TestGenerateDashboard:
assert result.structured_content["error"] is not None
assert "Failed to create dashboard" in result.structured_content["error"]
assert result.structured_content["dashboard"] is None
# rollback called by tool + event_logger error handling
assert mock_db_session.rollback.call_count >= 1
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_minimal_request(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test dashboard generation with minimal required parameters."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [_mock_chart(id=3)]
mock_db_session.query.return_value = mock_query
charts = [_mock_chart(id=3)]
mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
request = {
"chart_ids": [3],
@@ -335,33 +366,26 @@ class TestGenerateDashboard:
== "Minimal Dashboard"
)
call_args = mock_create_command.call_args[0][0]
assert call_args["published"] is True
assert (
"description" not in call_args or call_args.get("description") is None
)
# Verify dashboard was created with default published=True
created = mock_dashboard_cls.return_value
assert created.published is True
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_auto_title_from_charts(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test that omitting dashboard_title generates a title from chart names."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [
charts = [
_mock_chart(id=1, slice_name="Sales Revenue"),
_mock_chart(id=2, slice_name="Customer Count"),
]
mock_db_session.query.return_value = mock_query
mock_dashboard = _mock_dashboard(id=50, title="Sales Revenue & Customer Count")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
# No dashboard_title provided
request = {"chart_ids": [1, 2]}
@@ -371,29 +395,23 @@ class TestGenerateDashboard:
assert result.structured_content["error"] is None
call_args = mock_create_command.call_args[0][0]
assert call_args["dashboard_title"] == "Sales Revenue & Customer Count"
# Verify auto-generated title was set on dashboard
created = mock_dashboard_cls.return_value
assert created.dashboard_title == "Sales Revenue & Customer Count"
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.models.dashboard.Dashboard")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_empty_string_title_preserved(
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
):
"""Test that an explicit empty-string title is NOT replaced by auto-gen."""
mock_query = Mock()
mock_filter = Mock()
mock_query.filter.return_value = mock_filter
mock_filter.order_by.return_value = mock_filter
mock_filter.all.return_value = [
_mock_chart(id=1, slice_name="Sales Revenue"),
]
mock_db_session.query.return_value = mock_query
charts = [_mock_chart(id=1, slice_name="Sales Revenue")]
mock_dashboard = _mock_dashboard(id=60, title="")
mock_create_command.return_value.run.return_value = mock_dashboard
mock_find_by_id.return_value = mock_dashboard
_setup_generate_dashboard_mocks(
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
)
# Explicit empty string title
request = {"chart_ids": [1], "dashboard_title": ""}
@@ -403,8 +421,9 @@ class TestGenerateDashboard:
assert result.structured_content["error"] is None
call_args = mock_create_command.call_args[0][0]
assert call_args["dashboard_title"] == ""
# Verify empty string title was preserved (not replaced by auto-gen)
created = mock_dashboard_cls.return_value
assert created.dashboard_title == ""
class TestAddChartToExistingDashboard: