mirror of
https://github.com/apache/superset.git
synced 2026-04-18 23:55:00 +00:00
fix(mcp): fix generate_dashboard cross-session SQLAlchemy error (#38827)
This commit is contained in:
@@ -25,6 +25,7 @@ import logging
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from fastmcp import Context
|
||||
from flask import g
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
@@ -200,10 +201,12 @@ def generate_dashboard(
|
||||
Returns:
|
||||
- Dashboard ID and URL
|
||||
"""
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
try:
|
||||
# Get chart objects from IDs (required for SQLAlchemy relationships)
|
||||
from superset import db
|
||||
from superset.commands.dashboard.create import CreateDashboardCommand
|
||||
from superset.models.slice import Slice
|
||||
|
||||
with event_logger.log_context(action="mcp.generate_dashboard.chart_validation"):
|
||||
@@ -235,65 +238,92 @@ def generate_dashboard(
|
||||
else _generate_title_from_charts(chart_objects)
|
||||
)
|
||||
|
||||
# Prepare dashboard data and create dashboard
|
||||
# Create the dashboard directly with db.session instead of using
|
||||
# CreateDashboardCommand. The command's @transaction decorator
|
||||
# may operate in a different SQLAlchemy scoped-session than the
|
||||
# one g.user and chart ORM objects are bound to in the MCP
|
||||
# context, causing "Object is already attached to session X
|
||||
# (this is Y)" errors. By re-querying all ORM objects in the
|
||||
# tool's own db.session we keep everything in a single session.
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
||||
with event_logger.log_context(action="mcp.generate_dashboard.db_write"):
|
||||
dashboard_data: Dict[str, Any] = {
|
||||
"dashboard_title": dashboard_title,
|
||||
"json_metadata": json.dumps(
|
||||
{
|
||||
"filter_scopes": {},
|
||||
"expanded_slices": {},
|
||||
"refresh_frequency": 0,
|
||||
"timed_refresh_immune_slices": [],
|
||||
"color_scheme": None,
|
||||
"label_colors": {},
|
||||
"shared_label_colors": {},
|
||||
"color_scheme_domain": [],
|
||||
"cross_filters_enabled": False,
|
||||
"native_filter_configuration": [],
|
||||
"global_chart_configuration": {
|
||||
"scope": {
|
||||
"rootPath": ["ROOT_ID"],
|
||||
"excluded": [],
|
||||
}
|
||||
},
|
||||
"chart_configuration": {},
|
||||
}
|
||||
),
|
||||
"position_json": json.dumps(layout),
|
||||
"published": request.published,
|
||||
"slices": chart_objects, # Pass ORM objects, not IDs
|
||||
}
|
||||
json_metadata = json.dumps(
|
||||
{
|
||||
"filter_scopes": {},
|
||||
"expanded_slices": {},
|
||||
"refresh_frequency": 0,
|
||||
"timed_refresh_immune_slices": [],
|
||||
"color_scheme": None,
|
||||
"label_colors": {},
|
||||
"shared_label_colors": {},
|
||||
"color_scheme_domain": [],
|
||||
"cross_filters_enabled": False,
|
||||
"native_filter_configuration": [],
|
||||
"global_chart_configuration": {
|
||||
"scope": {
|
||||
"rootPath": ["ROOT_ID"],
|
||||
"excluded": [],
|
||||
}
|
||||
},
|
||||
"chart_configuration": {},
|
||||
}
|
||||
)
|
||||
|
||||
if request.description:
|
||||
dashboard_data["description"] = request.description
|
||||
|
||||
# Create the dashboard using Superset's command pattern
|
||||
try:
|
||||
command = CreateDashboardCommand(dashboard_data)
|
||||
dashboard = command.run()
|
||||
except Exception as cmd_err:
|
||||
# Surface the root cause from @transaction's error wrapping
|
||||
root_cause = cmd_err.__cause__ or cmd_err
|
||||
dashboard = Dashboard()
|
||||
dashboard.dashboard_title = dashboard_title
|
||||
dashboard.json_metadata = json_metadata
|
||||
dashboard.position_json = json.dumps(layout)
|
||||
dashboard.published = request.published
|
||||
|
||||
if request.description:
|
||||
dashboard.description = request.description
|
||||
|
||||
# Re-query the current user and charts directly in the
|
||||
# current db.session. g.user was loaded in a Flask
|
||||
# app_context that has since been torn down (the
|
||||
# middleware's ``with flask_app.app_context()`` exits
|
||||
# before the tool function runs), so the User object
|
||||
# is bound to a dead/different scoped session.
|
||||
# Querying fresh avoids all cross-session errors.
|
||||
from superset.extensions import security_manager
|
||||
|
||||
current_user = (
|
||||
db.session.query(security_manager.user_model)
|
||||
.filter_by(id=g.user.id)
|
||||
.first()
|
||||
)
|
||||
if current_user:
|
||||
dashboard.owners = [current_user]
|
||||
|
||||
fresh_charts = (
|
||||
db.session.query(Slice)
|
||||
.filter(Slice.id.in_(request.chart_ids))
|
||||
.order_by(Slice.id)
|
||||
.all()
|
||||
)
|
||||
dashboard.slices = fresh_charts
|
||||
|
||||
db.session.add(dashboard)
|
||||
db.session.commit()
|
||||
except SQLAlchemyError as db_err:
|
||||
db.session.rollback()
|
||||
logger.error(
|
||||
"CreateDashboardCommand failed: %s (cause: %s)",
|
||||
cmd_err,
|
||||
root_cause,
|
||||
"Dashboard creation failed: %s",
|
||||
db_err,
|
||||
exc_info=True,
|
||||
)
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=None,
|
||||
dashboard_url=None,
|
||||
error=f"Failed to create dashboard: {root_cause}",
|
||||
error="Failed to create dashboard due to a database error.",
|
||||
)
|
||||
|
||||
# Re-fetch the dashboard with eager-loaded relationships to avoid
|
||||
# "Instance is not bound to a Session" errors when serializing
|
||||
# chart .tags and .owners.
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
# Re-fetch with eager-loaded relationships for serialization
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
from superset.models.dashboard import Dashboard
|
||||
|
||||
dashboard = (
|
||||
DashboardDAO.find_by_id(
|
||||
@@ -355,8 +385,8 @@ def generate_dashboard(
|
||||
dashboard=dashboard_info, dashboard_url=dashboard_url, error=None
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error("Error creating dashboard: %s", e)
|
||||
except (SQLAlchemyError, ValueError, AttributeError, ValidationError) as e:
|
||||
logger.error("Error creating dashboard: %s", e, exc_info=True)
|
||||
return GenerateDashboardResponse(
|
||||
dashboard=None,
|
||||
dashboard_url=None,
|
||||
|
||||
@@ -105,30 +105,59 @@ def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
|
||||
return dashboard
|
||||
|
||||
|
||||
def _setup_generate_dashboard_mocks(
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
charts,
|
||||
dashboard,
|
||||
):
|
||||
"""Set up common mocks for generate_dashboard tests.
|
||||
|
||||
The tool creates dashboards directly via db.session (bypassing
|
||||
CreateDashboardCommand) and re-queries user/charts in the tool's
|
||||
own session. This helper wires up the mock chain for that path.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_user.first_name = "Admin"
|
||||
mock_user.last_name = "User"
|
||||
mock_user.email = "admin@example.com"
|
||||
mock_user.active = True
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = charts
|
||||
mock_filter.first.return_value = mock_user
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard_cls.return_value = dashboard
|
||||
mock_find_by_id.return_value = dashboard
|
||||
|
||||
|
||||
class TestGenerateDashboard:
|
||||
"""Tests for generate_dashboard MCP tool."""
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_basic(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test basic dashboard generation with valid charts."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
charts = [
|
||||
_mock_chart(id=1, slice_name="Sales Chart"),
|
||||
_mock_chart(id=2, slice_name="Revenue Chart"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [1, 2],
|
||||
@@ -173,24 +202,19 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["dashboard"] is None
|
||||
assert result.structured_content["dashboard_url"] is None
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_single_chart(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with a single chart."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=5, slice_name="Single Chart")]
|
||||
mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [5],
|
||||
@@ -203,29 +227,22 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
assert result.structured_content["dashboard"]["chart_count"] == 1
|
||||
assert result.structured_content["dashboard"]["published"] is True
|
||||
assert result.structured_content["dashboard"]["published"] is False
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_many_charts(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with many charts (grid layout)."""
|
||||
chart_ids = list(range(1, 7))
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids]
|
||||
mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}
|
||||
|
||||
@@ -235,10 +252,14 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["error"] is None
|
||||
assert result.structured_content["dashboard"]["chart_count"] == 6
|
||||
|
||||
mock_create_command.assert_called_once()
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
# Verify db.session.add and commit were called
|
||||
# (commit is called multiple times: once by tool + event_logger contexts)
|
||||
mock_db_session.add.assert_called_once()
|
||||
assert mock_db_session.commit.call_count >= 1
|
||||
|
||||
position_json = json.loads(call_args["position_json"])
|
||||
# Verify layout was set on the dashboard object
|
||||
created_dashboard = mock_dashboard_cls.return_value
|
||||
position_json = json.loads(created_dashboard.position_json)
|
||||
assert "ROOT_ID" in position_json
|
||||
assert "GRID_ID" in position_json
|
||||
assert "DASHBOARD_VERSION_KEY" in position_json
|
||||
@@ -278,20 +299,33 @@ class TestGenerateDashboard:
|
||||
assert row_data["type"] == "ROW"
|
||||
assert column_key in row_data["children"]
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_creation_failure(
|
||||
self, mock_db_session, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test error handling when dashboard creation fails."""
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=1)]
|
||||
mock_filter.first.return_value = Mock(
|
||||
id=1,
|
||||
username="admin",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
email="admin@example.com",
|
||||
active=True,
|
||||
)
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_create_command.return_value.run.side_effect = Exception("Creation failed")
|
||||
mock_db_session.commit.side_effect = SQLAlchemyError("Creation failed")
|
||||
|
||||
mock_dashboard_cls.return_value = _mock_dashboard(id=99)
|
||||
|
||||
request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}
|
||||
|
||||
@@ -301,25 +335,22 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["error"] is not None
|
||||
assert "Failed to create dashboard" in result.structured_content["error"]
|
||||
assert result.structured_content["dashboard"] is None
|
||||
# rollback called by tool + event_logger error handling
|
||||
assert mock_db_session.rollback.call_count >= 1
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_minimal_request(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with minimal required parameters."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=3)]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=3)]
|
||||
mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [3],
|
||||
@@ -335,33 +366,26 @@ class TestGenerateDashboard:
|
||||
== "Minimal Dashboard"
|
||||
)
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["published"] is True
|
||||
assert (
|
||||
"description" not in call_args or call_args.get("description") is None
|
||||
)
|
||||
# Verify dashboard was created with default published=True
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.published is True
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_auto_title_from_charts(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test that omitting dashboard_title generates a title from chart names."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
charts = [
|
||||
_mock_chart(id=1, slice_name="Sales Revenue"),
|
||||
_mock_chart(id=2, slice_name="Customer Count"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard = _mock_dashboard(id=50, title="Sales Revenue & Customer Count")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
# No dashboard_title provided
|
||||
request = {"chart_ids": [1, 2]}
|
||||
@@ -371,29 +395,23 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["dashboard_title"] == "Sales Revenue & Customer Count"
|
||||
# Verify auto-generated title was set on dashboard
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.dashboard_title == "Sales Revenue & Customer Count"
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_empty_string_title_preserved(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test that an explicit empty-string title is NOT replaced by auto-gen."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
_mock_chart(id=1, slice_name="Sales Revenue"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=1, slice_name="Sales Revenue")]
|
||||
mock_dashboard = _mock_dashboard(id=60, title="")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
# Explicit empty string title
|
||||
request = {"chart_ids": [1], "dashboard_title": ""}
|
||||
@@ -403,8 +421,9 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["dashboard_title"] == ""
|
||||
# Verify empty string title was preserved (not replaced by auto-gen)
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.dashboard_title == ""
|
||||
|
||||
|
||||
class TestAddChartToExistingDashboard:
|
||||
|
||||
Reference in New Issue
Block a user