mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
fix(mcp): fix generate_dashboard cross-session SQLAlchemy error (#38827)
This commit is contained in:
@@ -105,30 +105,59 @@ def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
|
||||
return dashboard
|
||||
|
||||
|
||||
def _setup_generate_dashboard_mocks(
|
||||
mock_db_session,
|
||||
mock_find_by_id,
|
||||
mock_dashboard_cls,
|
||||
charts,
|
||||
dashboard,
|
||||
):
|
||||
"""Set up common mocks for generate_dashboard tests.
|
||||
|
||||
The tool creates dashboards directly via db.session (bypassing
|
||||
CreateDashboardCommand) and re-queries user/charts in the tool's
|
||||
own session. This helper wires up the mock chain for that path.
|
||||
"""
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_user.first_name = "Admin"
|
||||
mock_user.last_name = "User"
|
||||
mock_user.email = "admin@example.com"
|
||||
mock_user.active = True
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = charts
|
||||
mock_filter.first.return_value = mock_user
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard_cls.return_value = dashboard
|
||||
mock_find_by_id.return_value = dashboard
|
||||
|
||||
|
||||
class TestGenerateDashboard:
|
||||
"""Tests for generate_dashboard MCP tool."""
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_basic(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test basic dashboard generation with valid charts."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
charts = [
|
||||
_mock_chart(id=1, slice_name="Sales Chart"),
|
||||
_mock_chart(id=2, slice_name="Revenue Chart"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [1, 2],
|
||||
@@ -173,24 +202,19 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["dashboard"] is None
|
||||
assert result.structured_content["dashboard_url"] is None
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_single_chart(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with a single chart."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=5, slice_name="Single Chart")]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=5, slice_name="Single Chart")]
|
||||
mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [5],
|
||||
@@ -203,29 +227,22 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
assert result.structured_content["dashboard"]["chart_count"] == 1
|
||||
assert result.structured_content["dashboard"]["published"] is True
|
||||
assert result.structured_content["dashboard"]["published"] is False
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_many_charts(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with many charts (grid layout)."""
|
||||
chart_ids = list(range(1, 7))
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=i, slice_name=f"Chart {i}") for i in chart_ids]
|
||||
mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}
|
||||
|
||||
@@ -235,10 +252,14 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["error"] is None
|
||||
assert result.structured_content["dashboard"]["chart_count"] == 6
|
||||
|
||||
mock_create_command.assert_called_once()
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
# Verify db.session.add and commit were called
|
||||
# (commit is called multiple times: once by tool + event_logger contexts)
|
||||
mock_db_session.add.assert_called_once()
|
||||
assert mock_db_session.commit.call_count >= 1
|
||||
|
||||
position_json = json.loads(call_args["position_json"])
|
||||
# Verify layout was set on the dashboard object
|
||||
created_dashboard = mock_dashboard_cls.return_value
|
||||
position_json = json.loads(created_dashboard.position_json)
|
||||
assert "ROOT_ID" in position_json
|
||||
assert "GRID_ID" in position_json
|
||||
assert "DASHBOARD_VERSION_KEY" in position_json
|
||||
@@ -278,20 +299,33 @@ class TestGenerateDashboard:
|
||||
assert row_data["type"] == "ROW"
|
||||
assert column_key in row_data["children"]
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_creation_failure(
|
||||
self, mock_db_session, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test error handling when dashboard creation fails."""
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_query.filter_by.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=1)]
|
||||
mock_filter.first.return_value = Mock(
|
||||
id=1,
|
||||
username="admin",
|
||||
first_name="Admin",
|
||||
last_name="User",
|
||||
email="admin@example.com",
|
||||
active=True,
|
||||
)
|
||||
mock_db_session.query.return_value = mock_query
|
||||
mock_create_command.return_value.run.side_effect = Exception("Creation failed")
|
||||
mock_db_session.commit.side_effect = SQLAlchemyError("Creation failed")
|
||||
|
||||
mock_dashboard_cls.return_value = _mock_dashboard(id=99)
|
||||
|
||||
request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}
|
||||
|
||||
@@ -301,25 +335,22 @@ class TestGenerateDashboard:
|
||||
assert result.structured_content["error"] is not None
|
||||
assert "Failed to create dashboard" in result.structured_content["error"]
|
||||
assert result.structured_content["dashboard"] is None
|
||||
# rollback called by tool + event_logger error handling
|
||||
assert mock_db_session.rollback.call_count >= 1
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_minimal_request(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with minimal required parameters."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [_mock_chart(id=3)]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=3)]
|
||||
mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
request = {
|
||||
"chart_ids": [3],
|
||||
@@ -335,33 +366,26 @@ class TestGenerateDashboard:
|
||||
== "Minimal Dashboard"
|
||||
)
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["published"] is True
|
||||
assert (
|
||||
"description" not in call_args or call_args.get("description") is None
|
||||
)
|
||||
# Verify dashboard was created with default published=True
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.published is True
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_auto_title_from_charts(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test that omitting dashboard_title generates a title from chart names."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
charts = [
|
||||
_mock_chart(id=1, slice_name="Sales Revenue"),
|
||||
_mock_chart(id=2, slice_name="Customer Count"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
mock_dashboard = _mock_dashboard(id=50, title="Sales Revenue & Customer Count")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
# No dashboard_title provided
|
||||
request = {"chart_ids": [1, 2]}
|
||||
@@ -371,29 +395,23 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["dashboard_title"] == "Sales Revenue & Customer Count"
|
||||
# Verify auto-generated title was set on dashboard
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.dashboard_title == "Sales Revenue & Customer Count"
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.models.dashboard.Dashboard")
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
|
||||
@patch("superset.db.session")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_empty_string_title_preserved(
|
||||
self, mock_db_session, mock_find_by_id, mock_create_command, mcp_server
|
||||
self, mock_db_session, mock_find_by_id, mock_dashboard_cls, mcp_server
|
||||
):
|
||||
"""Test that an explicit empty-string title is NOT replaced by auto-gen."""
|
||||
mock_query = Mock()
|
||||
mock_filter = Mock()
|
||||
mock_query.filter.return_value = mock_filter
|
||||
mock_filter.order_by.return_value = mock_filter
|
||||
mock_filter.all.return_value = [
|
||||
_mock_chart(id=1, slice_name="Sales Revenue"),
|
||||
]
|
||||
mock_db_session.query.return_value = mock_query
|
||||
|
||||
charts = [_mock_chart(id=1, slice_name="Sales Revenue")]
|
||||
mock_dashboard = _mock_dashboard(id=60, title="")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
mock_find_by_id.return_value = mock_dashboard
|
||||
_setup_generate_dashboard_mocks(
|
||||
mock_db_session, mock_find_by_id, mock_dashboard_cls, charts, mock_dashboard
|
||||
)
|
||||
|
||||
# Explicit empty string title
|
||||
request = {"chart_ids": [1], "dashboard_title": ""}
|
||||
@@ -403,8 +421,9 @@ class TestGenerateDashboard:
|
||||
|
||||
assert result.structured_content["error"] is None
|
||||
|
||||
call_args = mock_create_command.call_args[0][0]
|
||||
assert call_args["dashboard_title"] == ""
|
||||
# Verify empty string title was preserved (not replaced by auto-gen)
|
||||
created = mock_dashboard_cls.return_value
|
||||
assert created.dashboard_title == ""
|
||||
|
||||
|
||||
class TestAddChartToExistingDashboard:
|
||||
|
||||
Reference in New Issue
Block a user