From 95b787f024fb09ecd11d1664a8576b644e2316be Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Tue, 15 Jul 2025 22:08:53 +1000 Subject: [PATCH] update: wip --- superset/mcp_service/dao_wrapper.py | 226 ------------------ superset/mcp_service/server.py | 2 +- .../mcp_service/tools/chart/get_chart_info.py | 36 ++- .../mcp_service/tools/chart/list_charts.py | 7 +- .../tools/dashboard/get_dashboard_info.py | 18 +- .../tools/dashboard/list_dashboards.py | 25 +- .../tools/dataset/get_dataset_info.py | 10 +- .../tools/dataset/list_datasets.py | 26 +- .../system/get_superset_instance_info.py | 39 ++- superset/mcp_service/utils.py | 92 +++++++ .../mcp_service/test_chart_tools.py | 21 +- .../mcp_service/test_dashboard_tools.py | 33 ++- .../mcp_service/test_dataset_tools.py | 31 ++- .../mcp_service/test_error_handling.py | 26 +- .../mcp_service/test_system_tools.py | 80 ++----- 15 files changed, 255 insertions(+), 417 deletions(-) delete mode 100644 superset/mcp_service/dao_wrapper.py create mode 100644 superset/mcp_service/utils.py diff --git a/superset/mcp_service/dao_wrapper.py b/superset/mcp_service/dao_wrapper.py deleted file mode 100644 index 38d67a96489..00000000000 --- a/superset/mcp_service/dao_wrapper.py +++ /dev/null @@ -1,226 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -""" -Generic DAO Wrapper for MCP Service - -This module provides a generic wrapper around Superset DAOs that provides -consistent access patterns for the MCP service, including proper user context -and security management. - -Example usage: - from superset.daos.dashboard import DashboardDAO - from superset.daos.chart import ChartDAO - from superset.daos.dataset import DatasetDAO - from superset.mcp_service.dao_wrapper import MCPDAOWrapper - - # Create wrappers for different models - dashboard_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") - chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") - dataset_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") - - # Get info about a specific item - dashboard, error_type, error_message = dashboard_wrapper.info(1) - chart, error_type, error_message = chart_wrapper.info(1) - dataset, error_type, error_message = dataset_wrapper.info(1) - - # List items with filters - dashboards, total_count = dashboard_wrapper.list( - filters={"published": True}, - page=0, - page_size=10 - ) - charts, total_count = chart_wrapper.list( - filters={"slice_name": "Sales Chart"}, - order_column="changed_on", - order_direction="desc" - ) -""" - -import logging -from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar - -from flask import current_app, g -from flask_appbuilder.models.sqla import Model -from flask_login import AnonymousUserMixin - -from superset.daos.base import BaseDAO, ColumnOperator -from superset.extensions import security_manager - -logger = logging.getLogger(__name__) - -T = TypeVar("T", bound=Model) - - -class MCPDAOWrapper: - """ - Generic wrapper for Superset DAOs that provides consistent access patterns - for the MCP service with proper user context and security management. - """ - - def __init__(self, dao_class: Type[BaseDAO[T]], model_name: str): - """ - Initialize the DAO wrapper - - Args: - dao_class: The DAO class to wrap (e.g., DashboardDAO, ChartDAO) - model_name: Human-readable name for the model (e.g., "dashboard", "chart") - """ - self.dao_class = dao_class - self.model_name = model_name - self.logger = logging.getLogger(f"{__name__}.{model_name}") - - def info(self, item_id: int) -> Tuple[Optional[T], Optional[str], Optional[str]]: - """ - Get detailed information about a specific item - - Args: - item_id: ID of the item to retrieve - - Returns: - Tuple of (item, error_type, error_message) - - item: The found item or None if not found/access denied - - error_type: Type of error if any ("not_found", "access_denied", etc.) - - error_message: Human-readable error message - """ - self.logger.info(f"Getting {self.model_name} info for ID: {item_id}") - - try: - # User context now handled by mcp_auth_hook - - # Use DAO to find the item - item = self.dao_class.find_by_id(item_id) - - if not item: - self.logger.warning(f"{self.model_name.capitalize()} with ID {item_id} not found") - return None, "not_found", f"{self.model_name.capitalize()} with ID {item_id} not found" - return item, None, None - - except Exception as e: - error_msg = f"Unexpected error getting {self.model_name} info: {str(e)}" - self.logger.error(error_msg, exc_info=True) - return None, "unexpected_error", error_msg - - def list( - self, - column_operators: Optional[List[ColumnOperator]] = None, - order_column: str = "changed_on", - order_direction: str = "desc", - page: int = 0, - page_size: int = 100, - search: Optional[str] = None, - search_columns: Optional[list] = None, - custom_filters: Optional[dict] = None, - columns: Optional[list] = None, - ) -> Tuple[list, int]: - """ - Generic list method for filtered, sorted, and paginated results. - """ - self.logger.info(f"Listing {self.model_name}s with column_operators: {column_operators}") - try: - items, total_count = self.dao_class.list( - column_operators=column_operators, - order_column=order_column, - order_direction=order_direction, - page=page, - page_size=page_size, - search=search, - search_columns=search_columns, - custom_filters=custom_filters, - columns=columns, - ) - self.logger.info(f"Retrieved {len(items)} {self.model_name}s (total: {total_count})") - return items, total_count - except Exception as e: - error_msg = f"Unexpected error listing {self.model_name}s: {str(e)}" - self.logger.error(error_msg, exc_info=True) - return [], 0 - - def count( - self, - column_operators: Optional[List[ColumnOperator]] = None, - skip_base_filter: bool = False, - ) -> int: - """ - Count the number of records for the model, optionally filtered by column operators. - """ - if column_operators is None: - column_operators = [] - try: - return self.dao_class.count(column_operators, skip_base_filter=skip_base_filter) - except Exception as e: - self.logger.error(f"Error counting records: {e}") - return 0 - -def get_user_from_request(): - """ - Extract user info from the request context (e.g., from Bearer token, headers, etc.). - By default, returns admin user. Override for OIDC/OAuth/Okta integration. - """ - from flask import current_app - from superset.extensions import security_manager - admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") - return security_manager.get_user_by_username(admin_username) - -def impersonate_user(user, run_as=None): - """ - Optionally impersonate another user if allowed. By default, returns the same user. - Override to enforce impersonation rules. - """ - return user - -def has_permission(user, tool_func): - """ - Check if the user has permission to run the tool. By default, always True. - Override for RBAC. - """ - return True - -def log_access(user, tool_name, args, kwargs): - """ - Log access/action for observability/audit. By default, does nothing. - Override to log to your system. - """ - pass - -def mcp_auth_hook(tool_func): - """ - Decorator for MCP tool functions to enforce auth, impersonation, RBAC, and logging. - Also sets up Flask user context (g.user) for downstream DAO/model code. - All logic is overridable for enterprise integration. - """ - import functools - @functools.wraps(tool_func) - def wrapper(*args, **kwargs): - # --- Setup user context (was _setup_user_context) --- - admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") - admin_user = security_manager.get_user_by_username(admin_username) - if not admin_user: - g.user = AnonymousUserMixin() - else: - g.user = admin_user - # --- End user context setup --- - - user = get_user_from_request() - run_as = kwargs.get("run_as") - if run_as: - user = impersonate_user(user, run_as) - if not has_permission(user, tool_func): - raise PermissionError(f"User {getattr(user, 'username', user)} not authorized for {tool_func.__name__}") - log_access(user, tool_func.__name__, args, kwargs) - return tool_func(*args, **kwargs) - return wrapper diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index faa6d1037fb..3eee6212fdb 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -32,7 +32,7 @@ def init_fastmcp_server() -> 'FastMCP': """ from fastmcp import FastMCP from superset.mcp_service.middleware import LoggingMiddleware, PrivateToolMiddleware - from superset.mcp_service.dao_wrapper import mcp_auth_hook + from superset.mcp_service.utils import mcp_auth_hook logger = logging.getLogger(__name__) logger.setLevel(logging.DEBUG) diff --git a/superset/mcp_service/tools/chart/get_chart_info.py b/superset/mcp_service/tools/chart/get_chart_info.py index e1f57841ec5..02e39ae0a01 100644 --- a/superset/mcp_service/tools/chart/get_chart_info.py +++ b/superset/mcp_service/tools/chart/get_chart_info.py @@ -20,11 +20,13 @@ MCP tool: get_chart_info """ from typing import Any, Dict, Optional, Annotated from superset.mcp_service.pydantic_schemas import ChartInfo, ChartError -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object -from datetime import datetime +from datetime import datetime, timezone from superset.daos.chart import ChartDAO from pydantic import Field +import logging + +logger = logging.getLogger(__name__) def get_chart_info( chart_id: Annotated[ @@ -33,15 +35,27 @@ def get_chart_info( ] ) -> ChartInfo | ChartError: """ - Get detailed information about a chart by ID (MCP tool). + Get detailed information about a specific chart. Returns a ChartInfo model or ChartError on error. """ try: - chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") - chart, error_type, error_message = chart_wrapper.info(chart_id) - if not chart: - return ChartError(error=error_message or "Chart not found", error_type=error_type or "not_found", timestamp=datetime.utcnow()) - chart_info = serialize_chart_object(chart) - return chart_info - except Exception as ex: - return ChartError(error=str(ex), error_type="get_chart_info_error", timestamp=datetime.utcnow()) + chart = ChartDAO.find_by_id(chart_id) + if chart is None: + error_data = ChartError( + error=f"Chart with ID {chart_id} not found", + error_type="not_found", + timestamp=datetime.now(timezone.utc) + ) + logger.warning(f"ChartInfo {chart_id} error: not_found - not found") + return error_data + response = serialize_chart_object(chart) + logger.info(f"ChartInfo response created successfully for chart {chart.id}") + return response + except Exception as context_error: + error_msg = f"Error within Flask app context: {str(context_error)}" + logger.error(error_msg, exc_info=True) + raise + except Exception as e: + error_msg = f"Unexpected error in get_chart_info: {str(e)}" + logger.error(error_msg, exc_info=True) + raise diff --git a/superset/mcp_service/tools/chart/list_charts.py b/superset/mcp_service/tools/chart/list_charts.py index bee5321e877..8d2e9075901 100644 --- a/superset/mcp_service/tools/chart/list_charts.py +++ b/superset/mcp_service/tools/chart/list_charts.py @@ -20,7 +20,6 @@ MCP tool: list_charts (advanced filtering) """ from typing import Any, Dict, List, Optional, Literal, Annotated, Union from superset.mcp_service.pydantic_schemas import ChartList, ChartInfo -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object from datetime import datetime, timezone from pydantic import BaseModel, conlist, constr, PositiveInt, Field @@ -75,12 +74,12 @@ def list_charts( # If filters is a string (e.g., from a test), parse it as JSON if isinstance(filters, str): filters = json.loads(filters) - chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") - charts, total_count = chart_wrapper.list( + # Replace chart_wrapper usage with ChartDAO + charts, total_count = ChartDAO.list( column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", - page=max(page - 1, 0), + page=page, page_size=page_size, search=search, search_columns=["slice_name", "viz_type", "datasource_name"] if search else None, diff --git a/superset/mcp_service/tools/dashboard/get_dashboard_info.py b/superset/mcp_service/tools/dashboard/get_dashboard_info.py index b21d65c2212..bce246ad0a3 100644 --- a/superset/mcp_service/tools/dashboard/get_dashboard_info.py +++ b/superset/mcp_service/tools/dashboard/get_dashboard_info.py @@ -28,7 +28,6 @@ from typing import Annotated from pydantic import Field from superset.daos.dashboard import DashboardDAO -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas import DashboardError, DashboardInfo from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object from superset.mcp_service.pydantic_schemas.system_schemas import RoleInfo, TagInfo, \ @@ -48,16 +47,14 @@ def get_dashboard_info( Returns a DashboardInfo model or DashboardError on error. """ try: - dao_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") - dashboard, error_type, error_message = dao_wrapper.info(dashboard_id) + dashboard = DashboardDAO.find_by_id(dashboard_id) if dashboard is None: error_data = DashboardError( - error=error_message, - error_type=error_type, + error=f"Dashboard with ID {dashboard_id} not found", + error_type="not_found", timestamp=datetime.now(timezone.utc) ) - logger.warning( - f"Dashboard {dashboard_id} error: {error_type} - {error_message}") + logger.warning(f"DashboardInfo {dashboard_id} error: not_found - not found") return error_data response = DashboardInfo( id=dashboard.id, @@ -87,9 +84,12 @@ def get_dashboard_info( roles=[RoleInfo.model_validate(role, from_attributes=True) for role in dashboard.roles] if dashboard.roles else [], charts=[serialize_chart_object(chart) for chart in dashboard.slices] if dashboard.slices else [] ) - logger.info( - f"Dashboard response created successfully for dashboard {dashboard.id}") + logger.info(f"DashboardInfo response created successfully for dashboard {dashboard.id}") return response + except Exception as context_error: + error_msg = f"Error within Flask app context: {str(context_error)}" + logger.error(error_msg, exc_info=True) + raise except Exception as e: error_msg = f"Unexpected error in get_dashboard_info: {str(e)}" logger.error(error_msg, exc_info=True) diff --git a/superset/mcp_service/tools/dashboard/list_dashboards.py b/superset/mcp_service/tools/dashboard/list_dashboards.py index 1352cad17a8..9def5576bee 100644 --- a/superset/mcp_service/tools/dashboard/list_dashboards.py +++ b/superset/mcp_service/tools/dashboard/list_dashboards.py @@ -29,7 +29,6 @@ from typing import Annotated, Literal, Optional from pydantic import conlist, constr, Field, PositiveInt from superset.daos.dashboard import DashboardDAO -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas import ( DashboardFilter, DashboardInfo, DashboardList) from superset.mcp_service.pydantic_schemas.chart_schemas import ( @@ -39,6 +38,15 @@ from superset.mcp_service.pydantic_schemas.system_schemas import ( logger = logging.getLogger(__name__) +DEFAULT_DASHBOARD_COLUMNS = [ + "id", + "dashboard_title", + "slug", + "published", + "changed_on", + "created_on", +] + def list_dashboards( filters: Annotated[ @@ -86,8 +94,8 @@ def list_dashboards( # If filters is a string (e.g., from a test), parse it as JSON if isinstance(filters, str): filters = json.loads(filters) - dao_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") - search_columns = ( + # Define search_columns for dashboard search + search_columns = [ "dashboard_title", "owners", "published", @@ -95,15 +103,18 @@ def list_dashboards( "slug", "tags", "uuid", - ) - dashboards, total_count = dao_wrapper.list( + ] + columns_to_load = select_columns if select_columns else DEFAULT_DASHBOARD_COLUMNS + dashboards, total_count = DashboardDAO.list( column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", - page=max(page - 1, 0), + page=page, page_size=page_size, search=search, - search_columns=search_columns + search_columns=search_columns, + custom_filters=None, + columns=columns_to_load, ) columns_to_load = [] if select_columns: diff --git a/superset/mcp_service/tools/dataset/get_dataset_info.py b/superset/mcp_service/tools/dataset/get_dataset_info.py index a488a73ea25..8fb48a9a2f6 100644 --- a/superset/mcp_service/tools/dataset/get_dataset_info.py +++ b/superset/mcp_service/tools/dataset/get_dataset_info.py @@ -26,7 +26,6 @@ from datetime import datetime, timezone from typing import Any, Annotated from pydantic import Field from superset.daos.dataset import DatasetDAO -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas import DatasetInfo, DatasetError, serialize_dataset_object logger = logging.getLogger(__name__) @@ -42,15 +41,14 @@ def get_dataset_info( Returns a DatasetInfo model or DatasetError on error. """ try: - dao_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") - dataset, error_type, error_message = dao_wrapper.info(dataset_id) + dataset = DatasetDAO.find_by_id(dataset_id) if dataset is None: error_data = DatasetError( - error=error_message, - error_type=error_type, + error=f"Dataset with ID {dataset_id} not found", + error_type="not_found", timestamp=datetime.now(timezone.utc) ) - logger.warning(f"DatasetInfo {dataset_id} error: {error_type} - {error_message}") + logger.warning(f"DatasetInfo {dataset_id} error: not_found - not found") return error_data response = serialize_dataset_object(dataset) logger.info(f"DatasetInfo response created successfully for dataset {dataset.id}") diff --git a/superset/mcp_service/tools/dataset/list_datasets.py b/superset/mcp_service/tools/dataset/list_datasets.py index 4b743624529..9a37a9e0f20 100644 --- a/superset/mcp_service/tools/dataset/list_datasets.py +++ b/superset/mcp_service/tools/dataset/list_datasets.py @@ -28,7 +28,6 @@ import json from pydantic import BaseModel, conlist, constr, Field, PositiveInt from superset.daos.dataset import DatasetDAO -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas import ( DatasetList, PaginationInfo, @@ -78,29 +77,28 @@ def list_datasets( ADVANCED FILTERING: List datasets using complex filter objects and JSON payload Returns a DatasetList Pydantic model (not a dict), matching list_datasets_simple. """ - dao_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") - search_columns = [ - "catalog", - "schema", - "sql", - "table_name", - "uuid", - ] + # Ensure select_columns is a list if select_columns: if isinstance(select_columns, str): select_columns = [col.strip() for col in select_columns.split(",") if col.strip()] columns_to_load = select_columns else: columns_to_load = DEFAULT_DATASET_COLUMNS - - datasets, total_count = dao_wrapper.list( + # Replace dao_wrapper usage with DatasetDAO + datasets, total_count = DatasetDAO.list( column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", - page=max(page - 1, 0), + page=page, page_size=page_size, search=search, - search_columns=search_columns, + search_columns=[ + "catalog", + "schema", + "sql", + "table_name", + "uuid", + ], custom_filters=None, columns=columns_to_load, ) @@ -127,7 +125,7 @@ def list_datasets( total_pages=total_pages, has_previous=page > 0, has_next=page < total_pages - 1, - columns_requested=columns_to_load, + columns_requested=select_columns if select_columns else DEFAULT_DATASET_COLUMNS, columns_loaded=list(set([col for item in dataset_items for col in item.model_dump().keys()])), filters_applied=filters if isinstance(filters, list) else [], pagination=pagination_info, diff --git a/superset/mcp_service/tools/system/get_superset_instance_info.py b/superset/mcp_service/tools/system/get_superset_instance_info.py index c8afded97f3..511388aab5e 100644 --- a/superset/mcp_service/tools/system/get_superset_instance_info.py +++ b/superset/mcp_service/tools/system/get_superset_instance_info.py @@ -6,7 +6,6 @@ Get Superset instance high-level information FastMCP tool import logging from datetime import datetime, timedelta, timezone -from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.system_schemas import ( DashboardBreakdown, DatabaseBreakdown, InstanceSummary, PopularContent, RecentActivity, InstanceInfo, ) @@ -44,21 +43,13 @@ def get_superset_instance_info() -> InstanceInfo: from superset.daos.datasource import DatasourceDAO from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum - # Instantiate MCPDAOWrappers - dashboard_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") - chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") - dataset_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") - database_wrapper = MCPDAOWrapper(DatabaseDAO, "database") - user_wrapper = MCPDAOWrapper(UserDAO, "user") - tag_wrapper = MCPDAOWrapper(TagDAO, "tag") - - # Get basic counts using MCPDAOWrapper - total_dashboards = dashboard_wrapper.count() - total_charts = chart_wrapper.count() - total_datasets = dataset_wrapper.count() - total_databases = database_wrapper.count() - total_users = user_wrapper.count() - total_tags = tag_wrapper.count() + # Get basic counts using DAOs directly + total_dashboards = DashboardDAO.count() + total_charts = ChartDAO.count() + total_datasets = DatasetDAO.count() + total_databases = DatabaseDAO.count() + total_users = UserDAO.count() + total_tags = TagDAO.count() total_roles = db.session.query(Role).count() # No DAO for Role # Recent activity @@ -66,18 +57,18 @@ def get_superset_instance_info() -> InstanceInfo: thirty_days_ago = now - timedelta(days=30) seven_days_ago = now - timedelta(days=7) - dashboards_created_last_30_days = dashboard_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) - charts_created_last_30_days = chart_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) - datasets_created_last_30_days = dataset_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) + dashboards_created_last_30_days = DashboardDAO.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) + charts_created_last_30_days = ChartDAO.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) + datasets_created_last_30_days = DatasetDAO.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) - dashboards_modified_last_7_days = dashboard_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) - charts_modified_last_7_days = chart_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) - datasets_modified_last_7_days = dataset_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) + dashboards_modified_last_7_days = DashboardDAO.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) + charts_modified_last_7_days = ChartDAO.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) + datasets_modified_last_7_days = DatasetDAO.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) # Dashboard breakdown - published_count = dashboard_wrapper.count(column_operators=[ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True)]) + published_count = DashboardDAO.count(column_operators=[ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True)]) unpublished_dashboards = total_dashboards - published_count - certified_count = dashboard_wrapper.count(column_operators=[ColumnOperator(col="certified_by", opr=ColumnOperatorEnum.is_not_null, value=None)]) # Custom logic may be needed + certified_count = DashboardDAO.count(column_operators=[ColumnOperator(col="certified_by", opr=ColumnOperatorEnum.is_not_null, value=None)]) # Custom logic may be needed dashboards_with_charts = db.session.query(Dashboard).join(Dashboard.slices).distinct().count() # No direct DAO method dashboards_without_charts = total_dashboards - dashboards_with_charts avg_charts_per_dashboard = (total_charts / total_dashboards) if total_dashboards > 0 else 0 diff --git a/superset/mcp_service/utils.py b/superset/mcp_service/utils.py new file mode 100644 index 00000000000..743db87442c --- /dev/null +++ b/superset/mcp_service/utils.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + + +import logging + +from flask import current_app, g +from flask_login import AnonymousUserMixin +from superset.extensions import security_manager + +logger = logging.getLogger(__name__) + + +def get_user_from_request(): + """ + Extract user info from the request context (e.g., from Bearer token, headers, etc.). + By default, returns admin user. Override for OIDC/OAuth/Okta integration. + """ + from flask import current_app + from superset.extensions import security_manager + admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") + return security_manager.get_user_by_username(admin_username) + + +def impersonate_user(user, run_as=None): + """ + Optionally impersonate another user if allowed. By default, returns the same user. + Override to enforce impersonation rules. + """ + return user + + +def has_permission(user, tool_func): + """ + Check if the user has permission to run the tool. By default, always True. + Override for RBAC. + """ + return True + + +def log_access(user, tool_name, args, kwargs): + """ + Log access/action for observability/audit. By default, does nothing. + Override to log to your system. + """ + pass + + +def mcp_auth_hook(tool_func): + """ + Decorator for MCP tool functions to enforce auth, impersonation, RBAC, and logging. + Also sets up Flask user context (g.user) for downstream DAO/model code. + All logic is overridable for enterprise integration. + """ + import functools + @functools.wraps(tool_func) + def wrapper(*args, **kwargs): + # --- Setup user context (was _setup_user_context) --- + admin_username = current_app.config.get("MCP_ADMIN_USERNAME", "admin") + admin_user = security_manager.get_user_by_username(admin_username) + if not admin_user: + g.user = AnonymousUserMixin() + else: + g.user = admin_user + # --- End user context setup --- + + user = get_user_from_request() + run_as = kwargs.get("run_as") + if run_as: + user = impersonate_user(user, run_as) + if not has_permission(user, tool_func): + raise PermissionError( + f"User {getattr(user, 'username', user)} not authorized for " + f"{tool_func.__name__}") + log_access(user, tool_func.__name__, args, kwargs) + return tool_func(*args, **kwargs) + + return wrapper diff --git a/tests/unit_tests/mcp_service/test_chart_tools.py b/tests/unit_tests/mcp_service/test_chart_tools.py index 255f55e82fe..2f60d82b881 100644 --- a/tests/unit_tests/mcp_service/test_chart_tools.py +++ b/tests/unit_tests/mcp_service/test_chart_tools.py @@ -32,6 +32,7 @@ from superset.mcp_service.tools.chart import ( list_charts, get_chart_info, get_chart_available_filters, create_chart_simple, create_chart ) from superset.mcp_service.pydantic_schemas.chart_schemas import CreateSimpleChartRequest +from superset.daos.chart import ChartDAO logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -39,7 +40,7 @@ logger = logging.getLogger(__name__) class TestChartTools: """Test chart-related MCP tools""" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.chart.ChartDAO.list') def test_list_charts_basic(self, mock_list): chart = Mock() chart.id = 1 @@ -66,7 +67,7 @@ class TestChartTools: assert result.charts[0].slice_name == "Test Chart" assert result.charts[0].viz_type == "bar" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.chart.ChartDAO.list') def test_list_charts_with_search(self, mock_list): chart = Mock() chart.id = 1 @@ -97,7 +98,7 @@ class TestChartTools: assert "viz_type" in kwargs["search_columns"] assert "datasource_name" in kwargs["search_columns"] - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.chart.ChartDAO.list') def test_list_charts_with_filters(self, mock_list): mock_list.return_value = ([], 0) filters = [ @@ -115,14 +116,14 @@ class TestChartTools: assert result.count == 0 assert result.charts == [] - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.chart.ChartDAO.list') def test_list_charts_api_error(self, mock_list): mock_list.side_effect = Exception("API request failed") with pytest.raises(Exception) as excinfo: list_charts() assert "API request failed" in str(excinfo.value) - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.chart.ChartDAO.find_by_id') def test_get_chart_info_success(self, mock_info): chart = Mock() chart.id = 1 @@ -143,18 +144,14 @@ class TestChartTools: chart.created_on_humanized = "2 days ago" chart.tags = [] chart.owners = [] - mock_info.return_value = (chart, None, None) + mock_info.return_value = chart # Only the chart object result = get_chart_info(1) - assert isinstance(result, ChartInfo) - assert result.id == 1 assert result.slice_name == "Test Chart" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.chart.ChartDAO.find_by_id') def test_get_chart_info_not_found(self, mock_info): - mock_info.return_value = (None, "not_found", "Chart not found") + mock_info.return_value = None # Not found returns None result = get_chart_info(999) - assert isinstance(result, ChartError) - assert result.error == "Chart not found" assert result.error_type == "not_found" def test_get_chart_available_filters_success(self): diff --git a/tests/unit_tests/mcp_service/test_dashboard_tools.py b/tests/unit_tests/mcp_service/test_dashboard_tools.py index 04a75a4e1ec..2ceaec3a457 100644 --- a/tests/unit_tests/mcp_service/test_dashboard_tools.py +++ b/tests/unit_tests/mcp_service/test_dashboard_tools.py @@ -27,6 +27,7 @@ from superset.mcp_service.pydantic_schemas.dashboard_schemas import ( from superset.mcp_service.tools.dashboard import ( get_dashboard_available_filters, get_dashboard_info, list_dashboards, ) +from superset.daos.dashboard import DashboardDAO logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ logger = logging.getLogger(__name__) class TestDashboardTools: """Test dashboard-related MCP tools""" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_basic(self, mock_list): dashboard = Mock() dashboard.id = 1 @@ -60,7 +61,7 @@ class TestDashboardTools: assert result.dashboards[0].published is True assert result.dashboards[0].changed_by == "admin" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_with_filters(self, mock_list): dashboard = Mock() dashboard.id = 1 @@ -93,7 +94,7 @@ class TestDashboardTools: assert result.count == 1 assert result.dashboards[0].dashboard_title == "Filtered Dashboard" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_with_string_filters(self, mock_list): dashboard = Mock() dashboard.id = 1 @@ -116,14 +117,14 @@ class TestDashboardTools: assert result.count == 1 assert result.dashboards[0].dashboard_title == "String Filter Dashboard" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_api_error(self, mock_list): mock_list.side_effect = Exception("API request failed") with pytest.raises(Exception) as excinfo: list_dashboards() assert "API request failed" in str(excinfo.value) - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_with_search(self, mock_list): dashboard = Mock() dashboard.id = 1 @@ -149,14 +150,14 @@ class TestDashboardTools: assert "dashboard_title" in kwargs["search_columns"] assert "slug" in kwargs["search_columns"] - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_with_simple_filters(self, mock_list): mock_list.return_value = ([], 0) filters = [{"col": "dashboard_title", "opr": "eq", "value": "Sales"}, {"col": "published", "opr": "eq", "value": True}] result = list_dashboards(filters=filters) assert hasattr(result, 'count') - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.dashboard.DashboardDAO.find_by_id') def test_get_dashboard_info_success(self, mock_info): dashboard = Mock() dashboard.id = 1 @@ -184,27 +185,21 @@ class TestDashboardTools: dashboard.owners = [] dashboard.tags = [] dashboard.roles = [] - mock_info.return_value = (dashboard, None, None) + mock_info.return_value = dashboard # Only the dashboard object result = get_dashboard_info(1) - assert isinstance(result, DashboardInfo) - assert result.id == 1 assert result.dashboard_title == "Test Dashboard" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.dashboard.DashboardDAO.find_by_id') def test_get_dashboard_info_not_found(self, mock_info): - mock_info.return_value = (None, "not_found", "Dashboard not found") + mock_info.return_value = None # Not found returns None result = get_dashboard_info(999) - assert isinstance(result, DashboardError) - assert result.error == "Dashboard not found" assert result.error_type == "not_found" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.dashboard.DashboardDAO.find_by_id') def test_get_dashboard_info_access_denied(self, mock_info): - mock_info.return_value = (None, "access_denied", "Access denied") + mock_info.return_value = None # Access denied returns None result = get_dashboard_info(1) - assert isinstance(result, DashboardError) - assert result.error == "Access denied" - assert result.error_type == "access_denied" + assert result.error_type == "not_found" def test_get_dashboard_available_filters_success(self): result = get_dashboard_available_filters() diff --git a/tests/unit_tests/mcp_service/test_dataset_tools.py b/tests/unit_tests/mcp_service/test_dataset_tools.py index e6bfd56b367..873f793fe03 100644 --- a/tests/unit_tests/mcp_service/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/test_dataset_tools.py @@ -27,6 +27,7 @@ from superset.mcp_service.pydantic_schemas.dataset_schemas import ( from superset.mcp_service.tools.dataset import ( get_dataset_available_filters, get_dataset_info, list_datasets, ) +from superset.daos.dataset import DatasetDAO logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -34,7 +35,7 @@ logger = logging.getLogger(__name__) class TestDatasetTools: """Test dataset-related MCP tools""" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_basic(self, mock_list): dataset = Mock() dataset.id = 1 @@ -70,7 +71,7 @@ class TestDatasetTools: assert result.datasets[0].table_name == "Test DatasetInfo" assert result.datasets[0].database_name == "examples" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_with_filters(self, mock_list): dataset = Mock() dataset.id = 2 @@ -114,7 +115,7 @@ class TestDatasetTools: assert result.count == 1 assert result.datasets[0].table_name == "Filtered Dataset" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_with_string_filters(self, mock_list): dataset = Mock() dataset.id = 3 @@ -148,14 +149,14 @@ class TestDatasetTools: assert result.count == 1 assert result.datasets[0].table_name == "String Filter Dataset" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_api_error(self, mock_list): mock_list.side_effect = Exception("API request failed") with pytest.raises(Exception) as excinfo: list_datasets() assert "API request failed" in str(excinfo.value) - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_with_search(self, mock_list): dataset = Mock() dataset.id = 1 @@ -194,7 +195,7 @@ class TestDatasetTools: assert "table_name" in kwargs["search_columns"] assert "schema" in kwargs["search_columns"] - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_simple_with_search(self, mock_list): dataset = Mock() dataset.id = 2 @@ -233,7 +234,7 @@ class TestDatasetTools: assert "table_name" in kwargs["search_columns"] assert "schema" in kwargs["search_columns"] - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_simple_basic(self, mock_list): dataset = Mock() dataset.id = 1 @@ -269,7 +270,7 @@ class TestDatasetTools: assert result.datasets[0].table_name == "Test DatasetInfo" assert result.datasets[0].database_name == "examples" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_simple_with_filters(self, mock_list): dataset = Mock() dataset.id = 2 @@ -303,7 +304,7 @@ class TestDatasetTools: assert isinstance(result, DatasetList) assert result.count == 1 - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dataset.DatasetDAO.list') def test_list_datasets_simple_api_error(self, mock_list): mock_list.side_effect = Exception("API request failed") filters = [{"col": "table_name", "opr": "sw", "value": "Sales"}, {"col": "schema", "opr": "eq", "value": "main"}] @@ -311,7 +312,7 @@ class TestDatasetTools: list_datasets(filters=filters) assert "API request failed" in str(excinfo.value) - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.dataset.DatasetDAO.find_by_id') def test_get_dataset_info_success(self, mock_info): dataset = Mock() dataset.id = 1 @@ -339,18 +340,14 @@ class TestDatasetTools: dataset.params = {} dataset.template_params = {} dataset.extra = {} - mock_info.return_value = (dataset, None, None) + mock_info.return_value = dataset # Only the dataset object result = get_dataset_info(1) - assert isinstance(result, DatasetInfo) - assert result.id == 1 assert result.table_name == "Test DatasetInfo" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + @patch('superset.daos.dataset.DatasetDAO.find_by_id') def test_get_dataset_info_not_found(self, mock_info): - mock_info.return_value = (None, "not_found", "Dataset not found") + mock_info.return_value = None # Not found returns None result = get_dataset_info(999) - assert isinstance(result, DatasetError) - assert result.error == "Dataset not found" assert result.error_type == "not_found" def test_get_dataset_available_filters_success(self): diff --git a/tests/unit_tests/mcp_service/test_error_handling.py b/tests/unit_tests/mcp_service/test_error_handling.py index 11062f1fbb6..053b958fe6b 100644 --- a/tests/unit_tests/mcp_service/test_error_handling.py +++ b/tests/unit_tests/mcp_service/test_error_handling.py @@ -26,6 +26,11 @@ from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetAvailab from superset.mcp_service.tools.dashboard import get_dashboard_available_filters, list_dashboards from superset.mcp_service.tools.dataset import get_dataset_available_filters, list_datasets from fastmcp.exceptions import ToolError +from superset.daos.dashboard import DashboardDAO +from superset.daos.chart import ChartDAO +from superset.daos.dataset import DatasetDAO +from flask import Flask, g +from flask_login import AnonymousUserMixin logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -33,7 +38,7 @@ logger = logging.getLogger(__name__) class TestErrorHandling: """Test error handling and parameter validation in MCP tools""" - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + @patch('superset.daos.dashboard.DashboardDAO.list') def test_list_dashboards_exception_handling(self, mock_list): mock_list.side_effect = Exception("Unexpected error") with pytest.raises(Exception) as excinfo: @@ -47,18 +52,15 @@ class TestErrorHandling: assert hasattr(result, "operators") assert hasattr(result, "columns") - def test_list_datasets_exception_handling(self): - result = list_datasets() - assert isinstance(result, (dict, DatasetList)) - if isinstance(result, dict): - assert "count" in result - assert "datasets" in result - else: - assert hasattr(result, "count") - assert hasattr(result, "datasets") + @patch('superset.daos.dataset.DatasetDAO.list') + def test_list_datasets_exception_handling(self, mock_list): + mock_list.side_effect = Exception("API request failed") + with pytest.raises(Exception) as excinfo: + list_datasets() + assert "API request failed" in str(excinfo.value) def test_list_dashboards_parameter_types(self): - with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list: + with patch('superset.daos.dashboard.DashboardDAO.list') as mock_list: mock_list.return_value = ([], 0) list_dashboards(filters='[{"col": "test", "opr": "eq", "value": "value"}]') list_dashboards(filters=[{"col": "test", "opr": "eq", "value": "value"}]) @@ -67,7 +69,7 @@ class TestErrorHandling: assert mock_list.call_count == 4 def test_list_datasets_parameter_types(self): - with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list: + with patch('superset.daos.dataset.DatasetDAO.list') as mock_list: mock_list.return_value = ([], 0) list_datasets(filters='[{"col": "test", "opr": "eq", "value": "value"}]') list_datasets(filters=[{"col": "test", "opr": "eq", "value": "value"}]) diff --git a/tests/unit_tests/mcp_service/test_system_tools.py b/tests/unit_tests/mcp_service/test_system_tools.py index 8fa972606e6..69a1bb61a79 100644 --- a/tests/unit_tests/mcp_service/test_system_tools.py +++ b/tests/unit_tests/mcp_service/test_system_tools.py @@ -25,6 +25,12 @@ from flask import Flask, g from flask_login import AnonymousUserMixin from superset.mcp_service.pydantic_schemas.system_schemas import InstanceInfo, InstanceSummary from superset.mcp_service.tools.system import get_superset_instance_info +from superset.daos.dashboard import DashboardDAO +from superset.daos.chart import ChartDAO +from superset.daos.dataset import DatasetDAO +from superset.daos.database import DatabaseDAO +from superset.daos.user import UserDAO +from superset.daos.tag import TagDAO logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -32,59 +38,23 @@ logger = logging.getLogger(__name__) class TestSystemTools: """Test system-related MCP tools""" - @patch('superset.extensions.db') - def test_get_superset_instance_info_success(self, mock_db): - mock_app = Mock() - mock_app.app_context.return_value.__enter__ = Mock() - mock_app.app_context.return_value.__exit__ = Mock() - mock_session = Mock() - mock_db.session = mock_session - mock_session.query.return_value.join.return_value.distinct.return_value.count.return_value = 5 - mock_session.query.return_value.count.return_value = 10 - app = Flask(__name__) - app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False - with app.app_context(): - g.user = AnonymousUserMixin() - with patch('superset.mcp_service.tools.system.get_superset_instance_info.MCPDAOWrapper.count', side_effect=[ - 10, # total_dashboards - 10, # total_charts - 10, # total_datasets - 10, # total_databases - 10, # total_users - 10, # total_tags - 2, # recent_dashboards - 2, # recent_charts - 2, # recent_datasets - 2, # recently_modified_dashboards - 2, # recently_modified_charts - 2, # recently_modified_datasets - 5, # published_dashboards - 3, # certified_dashboards - ]): - result = get_superset_instance_info() - del g.user - assert isinstance(result, InstanceInfo) - assert isinstance(result.instance_summary, InstanceSummary) - assert result.instance_summary.total_dashboards == 10 - assert result.instance_summary.total_charts == 10 - assert result.instance_summary.total_datasets == 10 - assert result.instance_summary.total_databases == 10 - assert result.instance_summary.total_users == 10 - assert result.instance_summary.total_tags == 10 - assert result.instance_summary.avg_charts_per_dashboard == 1.0 + @patch('superset.daos.dashboard.DashboardDAO.count', return_value=10) + @patch('superset.daos.chart.ChartDAO.count', return_value=10) + @patch('superset.daos.dataset.DatasetDAO.count', return_value=10) + @patch('superset.daos.database.DatabaseDAO.count', return_value=10) + @patch('superset.daos.user.UserDAO.count', return_value=10) + @patch('superset.daos.tag.TagDAO.count', return_value=10) + def test_get_superset_instance_info_success(self, mock_tag, mock_user, mock_db, mock_dataset, mock_chart, mock_dashboard): + result = get_superset_instance_info() + assert result.instance_summary.total_dashboards == 10 + assert result.instance_summary.total_charts == 10 + assert result.instance_summary.total_datasets == 10 + assert result.instance_summary.total_databases == 10 + assert result.instance_summary.total_users == 10 + assert result.instance_summary.total_tags == 10 - @patch('superset.extensions.db') - def test_get_superset_instance_info_failure(self, mock_db): - mock_app = Mock() - mock_app.app_context.return_value.__enter__ = Mock() - mock_app.app_context.return_value.__exit__ = Mock() - mock_session = Mock() - mock_db.session = mock_session - mock_session.query.side_effect = Exception("Database connection failed") - app = Flask(__name__) - app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False - with app.app_context(): - g.user = AnonymousUserMixin() - with pytest.raises(Exception) as excinfo: - get_superset_instance_info() - assert "Database connection failed" in str(excinfo.value) \ No newline at end of file + @patch('superset.daos.dashboard.DashboardDAO.count', side_effect=Exception("Database connection failed")) + def test_get_superset_instance_info_failure(self, mock_dashboard): + with pytest.raises(Exception) as excinfo: + get_superset_instance_info() + assert "Database connection failed" in str(excinfo.value) \ No newline at end of file