diff --git a/superset/daos/base.py b/superset/daos/base.py index 4897f796cf8..136184304b5 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -248,6 +248,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]): column_name: str, value: str | int, skip_base_filter: bool = False, + query_options: list[Any] | None = None, ) -> T | None: """ Private method to find a model by any column value. @@ -256,6 +257,8 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]): column_name: Name of the column to search by value: Value to search for skip_base_filter: Whether to skip base filtering + query_options: SQLAlchemy query options (e.g., joinedload, + subqueryload) to apply to the query for eager loading Returns: Model instance or None if not found @@ -263,6 +266,9 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]): query = db.session.query(cls.model_cls) query = cls._apply_base_filter(query, skip_base_filter) + if query_options: + query = query.options(*query_options) + if not hasattr(cls.model_cls, column_name): return None @@ -283,6 +289,7 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]): model_id: str | int, skip_base_filter: bool = False, id_column: str | None = None, + query_options: list[Any] | None = None, ) -> T | None: """ Find a model by ID using specified or default ID column. @@ -291,12 +298,14 @@ class BaseDAO(CoreBaseDAO[T], Generic[T]): model_id: ID value to search for skip_base_filter: Whether to skip base filtering id_column: Column name to use (defaults to cls.id_column_name) + query_options: SQLAlchemy query options (e.g., joinedload, + subqueryload) to apply to the query for eager loading Returns: Model instance or None if not found """ column = id_column or cls.id_column_name - return cls._find_by_column(column, model_id, skip_base_filter) + return cls._find_by_column(column, model_id, skip_base_filter, query_options) @classmethod def find_by_ids( diff --git a/superset/daos/database.py b/superset/daos/database.py index cd1bc3d51b3..5b1b33b3839 100644 --- a/superset/daos/database.py +++ b/superset/daos/database.py @@ -70,11 +70,15 @@ class DatabaseDAO(BaseDAO[Database]): model_id: str | int, skip_base_filter: bool = False, id_column: str | None = None, + query_options: list[Any] | None = None, ) -> Database | None: """ Find a database by id, eagerly loading the SSH tunnel relationship. """ - query = db.session.query(cls.model_cls).options(joinedload(Database.ssh_tunnel)) + all_options = [joinedload(Database.ssh_tunnel)] + if query_options: + all_options.extend(query_options) + query = db.session.query(cls.model_cls).options(*all_options) query = cls._apply_base_filter(query, skip_base_filter) column_name = id_column or cls.id_column_name diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index de9dc314e6f..d3d92967408 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -22,6 +22,7 @@ MCP tool: get_chart_info import logging from fastmcp import Context +from sqlalchemy.orm import subqueryload from superset_core.mcp import tool from superset.commands.exceptions import CommandException @@ -93,6 +94,7 @@ async def get_chart_info( Returns chart details including name, type, and URL. """ from superset.daos.chart import ChartDAO + from superset.models.slice import Slice from superset.utils import json as utils_json await ctx.info( @@ -100,6 +102,12 @@ async def get_chart_info( % (request.identifier, request.form_data_key) ) + # Eager load owners and tags to avoid N+1 queries during serialization + eager_options = [ + subqueryload(Slice.owners), + subqueryload(Slice.tags), + ] + with event_logger.log_context(action="mcp.get_chart_info.lookup"): tool = ModelGetInfoCore( dao_class=ChartDAO, @@ -108,6 +116,7 @@ async def get_chart_info( serializer=serialize_chart_object, supports_slug=False, # Charts don't have slugs logger=logger, + query_options=eager_options, ) result = tool.run_tool(request.identifier) diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py index ebca60a7bb7..31646db3753 100644 --- a/superset/mcp_service/dashboard/tool/get_dashboard_info.py +++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py @@ -26,6 +26,7 @@ import logging from datetime import datetime, timezone from fastmcp import Context +from sqlalchemy.orm import subqueryload from superset_core.mcp import tool from superset.dashboards.permalink.exceptions import DashboardPermalinkGetFailedError @@ -98,6 +99,19 @@ async def get_dashboard_info( try: from superset.daos.dashboard import DashboardDAO + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + + # Eager load slices (charts), owners, tags, and roles to avoid N+1 + # queries. Also eager load owners/tags on each slice since the + # dashboard serializer calls serialize_chart_object for every chart. + eager_options = [ + subqueryload(Dashboard.slices).subqueryload(Slice.owners), + subqueryload(Dashboard.slices).subqueryload(Slice.tags), + subqueryload(Dashboard.owners), + subqueryload(Dashboard.tags), + subqueryload(Dashboard.roles), + ] with event_logger.log_context(action="mcp.get_dashboard_info.lookup"): tool = ModelGetInfoCore( @@ -107,6 +121,7 @@ async def get_dashboard_info( serializer=dashboard_serializer, supports_slug=True, # Dashboards support slugs logger=logger, + query_options=eager_options, ) result = tool.run_tool(request.identifier) diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py index e9e8817d2d9..35c963eb2bd 100644 --- a/superset/mcp_service/dataset/tool/get_dataset_info.py +++ b/superset/mcp_service/dataset/tool/get_dataset_info.py @@ -26,6 +26,7 @@ import logging from datetime import datetime, timezone from fastmcp import Context +from sqlalchemy.orm import joinedload, subqueryload from superset_core.mcp import tool from superset.extensions import event_logger @@ -82,8 +83,18 @@ async def get_dataset_info( ) try: + from superset.connectors.sqla.models import SqlaTable from superset.daos.dataset import DatasetDAO + # Eager load columns, metrics, and database to avoid N+1 queries. + # Without this, serialize_dataset_object triggers lazy loads for each + # relationship, which can time out on datasets with many columns. + eager_options = [ + subqueryload(SqlaTable.columns), + subqueryload(SqlaTable.metrics), + joinedload(SqlaTable.database), + ] + with event_logger.log_context(action="mcp.get_dataset_info.lookup"): tool = ModelGetInfoCore( dao_class=DatasetDAO, @@ -92,6 +103,7 @@ async def get_dataset_info( serializer=serialize_dataset_object, supports_slug=False, # Datasets don't have slugs logger=logger, + query_options=eager_options, ) result = tool.run_tool(request.identifier) diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index ef0612553c1..3ac0f008ca4 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -240,6 +240,7 @@ class ModelGetInfoCore(BaseCore): serializer: Callable[[T], BaseModel], supports_slug: bool = False, logger: logging.Logger | None = None, + query_options: list[Any] | None = None, ) -> None: super().__init__(logger) self.dao_class = dao_class @@ -247,29 +248,35 @@ class ModelGetInfoCore(BaseCore): self.error_schema = error_schema self.serializer = serializer self.supports_slug = supports_slug + self.query_options = query_options or [] def _find_object(self, identifier: int | str) -> Any: """Find object by identifier using appropriate method.""" + opts = self.query_options or None # If it's an integer or string that can be converted to int, use find_by_id if isinstance(identifier, int): - return self.dao_class.find_by_id(identifier) + return self.dao_class.find_by_id(identifier, query_options=opts) try: # Try to convert string to int id_val = int(identifier) - return self.dao_class.find_by_id(id_val) + return self.dao_class.find_by_id(id_val, query_options=opts) except ValueError: pass # Check if it's a UUID if _is_uuid(identifier): # Use the new flexible find_by_id with uuid column - return self.dao_class.find_by_id(identifier, id_column="uuid") + return self.dao_class.find_by_id( + identifier, id_column="uuid", query_options=opts + ) # For dashboards, also check slug if self.supports_slug: # Try to find by slug using the new flexible method - result = self.dao_class.find_by_id(identifier, id_column="slug") + result = self.dao_class.find_by_id( + identifier, id_column="slug", query_options=opts + ) if result: return result @@ -278,11 +285,10 @@ class ModelGetInfoCore(BaseCore): from superset.models.dashboard import id_or_slug_filter model_class = self.dao_class.model_cls - return ( - db.session.query(model_class) - .filter(id_or_slug_filter(identifier)) - .one_or_none() - ) + query = db.session.query(model_class).filter(id_or_slug_filter(identifier)) + if opts: + query = query.options(*opts) + return query.one_or_none() # If we get here, it's an invalid identifier return None diff --git a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py index ad3ece3d498..aefdd8c8433 100644 --- a/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py +++ b/tests/unit_tests/mcp_service/system/tool/test_mcp_core.py @@ -69,7 +69,7 @@ class DummyDAO: return [SimpleNamespace(id=1, name="foo"), SimpleNamespace(id=2, name="bar")], 2 @classmethod - def find_by_id(cls, id): + def find_by_id(cls, id, **kwargs): if id == 1: return SimpleNamespace(id=1, name="foo") return None @@ -196,7 +196,7 @@ def test_model_get_info_tool_not_found(): def test_model_get_info_tool_exception(): class FailingDAO: @classmethod - def find_by_id(cls, id): + def find_by_id(cls, id, **kwargs): raise Exception("fail") tool = ModelGetInfoCore(