From 9d40fe913f7045422311b7b8ec9fc04c4d0eec7f Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Mon, 14 Jul 2025 21:15:32 +1000 Subject: [PATCH] update: wip --- superset/mcp_service/README_SCHEMAS.md | 106 +-- superset/mcp_service/dao_wrapper.py | 48 +- .../mcp_service/pydantic_schemas/__init__.py | 103 +-- .../pydantic_schemas/chart_schemas.py | 40 +- .../pydantic_schemas/dashboard_schemas.py | 217 ++--- .../pydantic_schemas/dataset_schemas.py | 108 +-- .../pydantic_schemas/system_schemas.py | 36 +- superset/mcp_service/server.py | 2 +- .../tools/chart/create_chart_simple.py | 2 +- .../mcp_service/tools/chart/get_chart_info.py | 19 +- .../mcp_service/tools/chart/list_charts.py | 17 +- .../get_dashboard_available_filters.py | 8 +- .../tools/dashboard/get_dashboard_info.py | 63 +- .../tools/dashboard/list_dashboards.py | 48 +- .../dataset/get_dataset_available_filters.py | 8 +- .../tools/dataset/get_dataset_info.py | 51 +- .../tools/dataset/list_datasets.py | 34 +- .../system/get_superset_instance_info.py | 26 +- .../mcp_service/run_mcp_tests.py | 44 +- .../mcp_service/test_get_chart_list_tools.py | 18 +- .../test_get_dataset_list_tools.py | 18 +- .../mcp_service/test_chart_tools.py | 224 +++++ .../mcp_service/test_dashboard_tools.py | 221 +++++ .../mcp_service/test_dataset_tools.py | 367 ++++++++ .../mcp_service/test_error_handling.py | 78 ++ .../mcp_service/test_fastmcp_tools.py | 827 ------------------ .../mcp_service/test_protocol_integration.py | 108 +++ .../mcp_service/test_system_tools.py | 90 ++ 28 files changed, 1465 insertions(+), 1466 deletions(-) create mode 100644 tests/unit_tests/mcp_service/test_chart_tools.py create mode 100644 tests/unit_tests/mcp_service/test_dashboard_tools.py create mode 100644 tests/unit_tests/mcp_service/test_dataset_tools.py create mode 100644 tests/unit_tests/mcp_service/test_error_handling.py delete mode 100644 tests/unit_tests/mcp_service/test_fastmcp_tools.py create mode 100644 
tests/unit_tests/mcp_service/test_protocol_integration.py create mode 100644 tests/unit_tests/mcp_service/test_system_tools.py diff --git a/superset/mcp_service/README_SCHEMAS.md b/superset/mcp_service/README_SCHEMAS.md index 20c9b23d6e4..d6f1d827480 100644 --- a/superset/mcp_service/README_SCHEMAS.md +++ b/superset/mcp_service/README_SCHEMAS.md @@ -17,8 +17,8 @@ This document provides a reference for the input and output parameters of all MC - `select_columns`: `Optional[List[str]]` — Columns to select (overrides columns/keys) - `search`: `Optional[str]` — Free-text search string -**Returns:** `DashboardListResponse` -- `dashboards`: `List[DashboardListItem]` +**Returns:** `DashboardList` +- `dashboards`: `List[DashboardInfo]` - `count`: `int` - `total_count`: `int` - `page`: `int` @@ -32,26 +32,14 @@ This document provides a reference for the input and output parameters of all MC - `pagination`: `PaginationInfo` - `timestamp`: `datetime` -### list_dashboards_simple - -**Inputs:** -- `filters`: `Optional[DashboardSimpleFilters]` — Simple filter object -- `order_column`: `Optional[str]` — Column to order results by -- `order_direction`: `Literal['asc', 'desc']` — Order direction -- `page`: `int` — Page number (1-based) -- `page_size`: `int` — Number of items per page -- `search`: `Optional[str]` — Free-text search string - -**Returns:** `DashboardListResponse` (see above) - ### get_dashboard_info **Inputs:** - `dashboard_id`: `int` — Dashboard ID -**Returns:** `DashboardInfoResponse` or `DashboardErrorResponse` +**Returns:** `DashboardInfo` or `DashboardError` -**DashboardInfoResponse:** +**DashboardInfo:** - `id`: `int` - `dashboard_title`: `str` - `slug`: `Optional[str]` @@ -79,7 +67,7 @@ This document provides a reference for the input and output parameters of all MC - `roles`: `List[RoleInfo]` - `charts`: `List[ChartInfo]` -**DashboardErrorResponse:** +**DashboardError:** - `error`: `str` - `error_type`: `str` - `timestamp`: `Optional[Union[str, datetime]]` @@ 
-89,7 +77,7 @@ This document provides a reference for the input and output parameters of all MC **Inputs:** - (none) -**Returns:** `DashboardAvailableFiltersResponse` +**Returns:** `DashboardAvailableFilters` - `filters`: `Dict[str, Any]` - `operators`: `List[str]` - `columns`: `List[str]` @@ -109,8 +97,8 @@ This document provides a reference for the input and output parameters of all MC - `select_columns`: `Optional[List[str]]` — Columns to select (overrides columns/keys) - `search`: `Optional[str]` — Free-text search string -**Returns:** `DatasetListResponse` -- `datasets`: `List[DatasetListItem]` +**Returns:** `DatasetList` +- `datasets`: `List[DatasetInfo]` - `count`: `int` - `total_count`: `int` - `page`: `int` @@ -134,14 +122,14 @@ This document provides a reference for the input and output parameters of all MC - `page_size`: `int` — Number of items per page - `search`: `Optional[str]` — Free-text search string -**Returns:** `DatasetListResponse` (see above) +**Returns:** `DatasetList` (see above) ### get_dataset_info **Inputs:** -- `dataset_id`: `int` — Dataset ID +- `dataset_id`: `int` — DatasetInfo ID -**Returns:** `DatasetInfoResponse` or `DatasetErrorResponse` +**Returns:** `DatasetInfoResponse` or `DatasetError` **DatasetInfoResponse:** - `id`: `int` @@ -169,7 +157,7 @@ This document provides a reference for the input and output parameters of all MC - `template_params`: `Optional[Dict[str, Any]]` - `extra`: `Optional[Dict[str, Any]]` -**DatasetErrorResponse:** +**DatasetError:** - `error`: `str` - `error_type`: `str` - `timestamp`: `Optional[Union[str, datetime]]` @@ -179,7 +167,7 @@ This document provides a reference for the input and output parameters of all MC **Inputs:** - (none) -**Returns:** `DatasetAvailableFiltersResponse` +**Returns:** `DatasetAvailableFilters` - `filters`: `Dict[str, Any]` - `operators`: `List[str]` - `columns`: `List[str]` @@ -199,8 +187,8 @@ This document provides a reference for the input and output parameters of all MC - 
`select_columns`: `Optional[List[str]]` — Columns to select (overrides columns/keys) - `search`: `Optional[str]` — Free-text search string -**Returns:** `ChartListResponse` -- `charts`: `List[ChartListItem]` +**Returns:** `ChartList` +- `charts`: `List[ChartInfo]` - `count`: `int` - `total_count`: `int` - `page`: `int` @@ -214,29 +202,17 @@ This document provides a reference for the input and output parameters of all MC - `pagination`: `PaginationInfo` - `timestamp`: `datetime` -### list_charts_simple - -**Inputs:** -- `filters`: `Optional[ChartSimpleFilters]` — Simple filter object -- `order_column`: `Optional[str]` — Column to order results by -- `order_direction`: `Literal['asc', 'desc']` — Order direction -- `page`: `int` — Page number (1-based) -- `page_size`: `int` — Number of items per page -- `search`: `Optional[str]` — Free-text search string - -**Returns:** `ChartListResponse` (see above) - ### get_chart_info **Inputs:** - `chart_id`: `int` — Chart ID -**Returns:** `ChartInfoResponse` or `ChartErrorResponse` +**Returns:** `ChartInfoResponse` or `ChartError` **ChartInfoResponse:** -- `chart`: `ChartListItem` +- `chart`: `ChartInfo` -**ChartErrorResponse:** +**ChartError:** - `error`: `str` - `error_type`: `str` - `timestamp`: `Optional[Union[str, datetime]]` @@ -257,7 +233,7 @@ This document provides a reference for the input and output parameters of all MC - `request`: `CreateSimpleChartRequest` — Chart creation request **Returns:** `CreateSimpleChartResponse` -- `chart`: `Optional[ChartListItem]` +- `chart`: `Optional[ChartInfo]` - `embed_url`: `Optional[str]` - `thumbnail_url`: `Optional[str]` - `embed_html`: `Optional[str]` @@ -270,7 +246,7 @@ This document provides a reference for the input and output parameters of all MC **Inputs:** - (none) -**Returns:** `SupersetInstanceInfoResponse` +**Returns:** `InstanceInfo` - `instance_summary`: `InstanceSummary` - `recent_activity`: `RecentActivity` - `dashboard_breakdown`: `DashboardBreakdown` @@ -282,34 
+258,12 @@ This document provides a reference for the input and output parameters of all MC ## Complex Type Definitions -### DashboardSimpleFilters -- `dashboard_title`: `Optional[str]` -- `published`: `Optional[bool]` -- `changed_by`: `Optional[str]` -- `created_by`: `Optional[str]` -- `owner`: `Optional[str]` -- `certified`: `Optional[bool]` -- `favorite`: `Optional[bool]` -- `chart_count`: `Optional[int]` -- `chart_count_min`: `Optional[int]` -- `chart_count_max`: `Optional[int]` -- `tags`: `Optional[str]` - ### ChartFilter - `col`: `Literal[ ... ]` (see allowed columns in code) - `opr`: `Literal[ ... ]` (see allowed operators in code) - `value`: `Any` -### ChartSimpleFilters -- `slice_name`: `Optional[str]` -- `viz_type`: `Optional[str]` -- `datasource_name`: `Optional[str]` -- `changed_by`: `Optional[str]` -- `created_by`: `Optional[str]` -- `owner`: `Optional[str]` -- `tags`: `Optional[str]` - -### ChartListItem +### ChartInfo - `id`: `int` - `slice_name`: `str` - `viz_type`: `Optional[str]` @@ -373,7 +327,7 @@ This document provides a reference for the input and output parameters of all MC - `created_on`: `Optional[Union[str, datetime]]` - `changed_on`: `Optional[Union[str, datetime]]` -### DashboardListItem +### DashboardInfo - `id`: `int` - `dashboard_title`: `str` - `slug`: `Optional[str]` @@ -389,7 +343,7 @@ This document provides a reference for the input and output parameters of all MC - `tags`: `List[TagInfo]` - `owners`: `List[UserInfo]` -### DatasetListItem +### DatasetInfo - `id`: `int` - `table_name`: `str` - `db_schema`: `Optional[str]` @@ -409,16 +363,6 @@ This document provides a reference for the input and output parameters of all MC - `schema_perm`: `Optional[str]` - `url`: `Optional[str]` -### DatasetSimpleFilters -- `table_name`: `Optional[str]` -- `db_schema`: `Optional[str]` -- `database_name`: `Optional[str]` -- `changed_by`: `Optional[str]` -- `created_by`: `Optional[str]` -- `owner`: `Optional[str]` -- `is_virtual`: `Optional[bool]` -- 
`tags`: `Optional[str]` - ### DatasetFilter - `col`: `Literal[ ... ]` (see allowed columns in code) - `opr`: `Literal[ ... ]` (see allowed operators in code) @@ -438,7 +382,7 @@ This document provides a reference for the input and output parameters of all MC - `return_embed`: `Optional[bool]` ### CreateSimpleChartResponse -- `chart`: `Optional[ChartListItem]` +- `chart`: `Optional[ChartInfo]` - `embed_url`: `Optional[str]` - `thumbnail_url`: `Optional[str]` - `embed_html`: `Optional[str]` @@ -474,4 +418,4 @@ This document provides a reference for the input and output parameters of all MC ### PopularContent - `top_tags`: `List[str]` -- `top_creators`: `List[str]` \ No newline at end of file +- `top_creators`: `List[str]` diff --git a/superset/mcp_service/dao_wrapper.py b/superset/mcp_service/dao_wrapper.py index 06669ff7ef1..473ceddc3a2 100644 --- a/superset/mcp_service/dao_wrapper.py +++ b/superset/mcp_service/dao_wrapper.py @@ -58,7 +58,7 @@ from flask import current_app, g from flask_appbuilder.models.sqla import Model from flask_login import AnonymousUserMixin -from superset.daos.base import BaseDAO +from superset.daos.base import BaseDAO, ColumnOperator from superset.extensions import security_manager logger = logging.getLogger(__name__) @@ -108,23 +108,7 @@ class MCPDAOWrapper: if not item: self.logger.warning(f"{self.model_name.capitalize()} with ID {item_id} not found") return None, "not_found", f"{self.model_name.capitalize()} with ID {item_id} not found" - - # Apply security context - check if user has access - try: - # Try to call raise_for_access if the model supports it - if hasattr(item, 'raise_for_access'): - item.raise_for_access() - elif hasattr(security_manager, f'raise_for_access'): - # Use security manager's generic access check - security_manager.raise_for_access(**{self.model_name: item}) - - self.logger.debug(f"User has access to {self.model_name} {item_id}") - return item, None, None - - except Exception as access_error: - 
self.logger.warning( - f"User does not have access to {self.model_name} {item_id}: {access_error}") - return None, "access_denied", f"Access denied to {self.model_name} {item_id}" + return item, None, None except Exception as e: error_msg = f"Unexpected error getting {self.model_name} info: {str(e)}" @@ -133,27 +117,29 @@ class MCPDAOWrapper: def list( self, - filters: Optional[list[dict] | Dict[str, Any]] = None, + column_operators: Optional[List[ColumnOperator]] = None, order_column: str = "changed_on", order_direction: str = "desc", page: int = 0, page_size: int = 100, search: Optional[str] = None, - search_columns: Optional[List[str]] = None, - ) -> Tuple[List[T], int]: + search_columns: Optional[list] = None, + custom_filters: Optional[dict] = None, + ) -> Tuple[list, int]: """ - List items using the DAO's list method. Supports advanced filters: a list of dicts with col, opr, value, or a simple dict for backward compatibility. + Generic list method for filtered, sorted, and paginated results. """ - self.logger.info(f"Listing {self.model_name}s with filters: {filters}") + self.logger.info(f"Listing {self.model_name}s with column_operators: {column_operators}") try: items, total_count = self.dao_class.list( - filters=filters, + column_operators=column_operators, order_column=order_column, order_direction=order_direction, page=page, page_size=page_size, search=search, search_columns=search_columns, + custom_filters=custom_filters, ) self.logger.info(f"Retrieved {len(items)} {self.model_name}s (total: {total_count})") return items, total_count @@ -162,14 +148,18 @@ class MCPDAOWrapper: self.logger.error(error_msg, exc_info=True) return [], 0 - def count(self, filters: Optional[List[dict] | dict] = None) -> int: + def count( + self, + column_operators: Optional[List[ColumnOperator]] = None, + skip_base_filter: bool = False, + ) -> int: """ - Return the count of records, optionally filtered. 
Supports advanced filters: a list of dicts with col, opr, value, or a simple dict for backward compatibility. + Count the number of records for the model, optionally filtered by column operators. """ - if filters is None: - filters = [] + if column_operators is None: + column_operators = [] try: - return self.dao_class.count(filters) + return self.dao_class.count(column_operators, skip_base_filter=skip_base_filter) except Exception as e: self.logger.error(f"Error counting records: {e}") return 0 diff --git a/superset/mcp_service/pydantic_schemas/__init__.py b/superset/mcp_service/pydantic_schemas/__init__.py index 0a754a55c59..90091eb1784 100644 --- a/superset/mcp_service/pydantic_schemas/__init__.py +++ b/superset/mcp_service/pydantic_schemas/__init__.py @@ -22,90 +22,67 @@ This package contains Pydantic schemas for the MCP service responses. """ from .dashboard_schemas import ( - DashboardInfoResponse, - DashboardErrorResponse, - DashboardListResponse, - DashboardListItem, - PaginationInfo, - UserInfo, - TagInfo, - RoleInfo, - ChartInfo, - serialize_user_object, - serialize_tag_object, - serialize_role_object, - serialize_chart_object, - DashboardAvailableFiltersResponse, + DashboardInfo, + DashboardError, + DashboardList, + DashboardAvailableFilters, DashboardFilter, ) from .system_schemas import ( - SupersetInstanceInfoResponse, + InstanceInfo, InstanceSummary, RecentActivity, DashboardBreakdown, DatabaseBreakdown, PopularContent, + UserInfo, + TagInfo, + RoleInfo, + PaginationInfo, ) from .dataset_schemas import ( - DatasetListItem, - DatasetListResponse, - DatasetSimpleFilters, + DatasetInfo, + DatasetList, serialize_dataset_object, - DatasetAvailableFiltersResponse, - DatasetInfoResponse, - DatasetErrorResponse, + DatasetAvailableFilters, + DatasetError, DatasetFilter, ) from .chart_schemas import ( - ChartListResponse, - ChartListItem, - ChartSimpleFilters, + ChartList, + ChartInfo, ChartAvailableFiltersResponse, - ChartInfoResponse, - ChartErrorResponse, + 
ChartError, serialize_chart_object, - CreateSimpleChartRequest, - CreateSimpleChartResponse, ChartFilter, ) __all__ = [ - "DashboardInfoResponse", - "DashboardErrorResponse", - "DashboardListResponse", - "DashboardListItem", - "PaginationInfo", - "UserInfo", - "TagInfo", - "RoleInfo", - "ChartInfo", - "serialize_user_object", - "serialize_tag_object", - "serialize_role_object", - "serialize_chart_object", - "DatasetListItem", - "DatasetListResponse", - "DatasetSimpleFilters", - "serialize_dataset_object", - "DashboardAvailableFiltersResponse", - "SupersetInstanceInfoResponse", - "InstanceSummary", - "RecentActivity", - "DashboardBreakdown", - "DatabaseBreakdown", - "PopularContent", - "DatasetAvailableFiltersResponse", - "DatasetInfoResponse", - "DatasetErrorResponse", - "ChartListResponse", - "ChartListItem", - "ChartSimpleFilters", "ChartAvailableFiltersResponse", - "ChartInfoResponse", - "ChartErrorResponse", - "CreateSimpleChartRequest", - "CreateSimpleChartResponse", + "ChartError", "ChartFilter", + "ChartInfo", + "ChartList", + "DashboardAvailableFilters", + "DashboardBreakdown", + "DashboardError", "DashboardFilter", + "DashboardInfo", + "DashboardList", + "DatabaseBreakdown", + "DatasetAvailableFilters", + "DatasetError", "DatasetFilter", -] \ No newline at end of file + "DatasetInfo", + "DatasetList", + "InstanceInfo", + "InstanceSummary", + "PaginationInfo", + "PopularContent", + "RecentActivity", + "RoleInfo", + "TagInfo", + "UserInfo", + "serialize_chart_object", + "serialize_dataset_object", +] diff --git a/superset/mcp_service/pydantic_schemas/chart_schemas.py b/superset/mcp_service/pydantic_schemas/chart_schemas.py index 4031e91ce2e..ebee4888bdd 100644 --- a/superset/mcp_service/pydantic_schemas/chart_schemas.py +++ b/superset/mcp_service/pydantic_schemas/chart_schemas.py @@ -19,12 +19,15 @@ Pydantic schemas for chart-related responses """ from datetime import datetime -from typing import Any, Dict, List, Optional, Union, Literal -from pydantic 
import BaseModel, Field, ConfigDict -from .dashboard_schemas import UserInfo, TagInfo, PaginationInfo +from typing import Any, Dict, List, Literal, Optional, Union -class ChartListItem(BaseModel): - """Chart item for list responses""" +from pydantic import BaseModel, ConfigDict, Field +from superset.daos.base import ColumnOperator +from superset.mcp_service.pydantic_schemas.system_schemas import PaginationInfo, TagInfo, UserInfo + + +class ChartInfo(BaseModel): + """Full chart model with all possible attributes.""" id: int = Field(..., description="Chart ID") slice_name: str = Field(..., description="Chart name") viz_type: Optional[str] = Field(None, description="Visualization type") @@ -46,26 +49,13 @@ class ChartListItem(BaseModel): owners: List[UserInfo] = Field(default_factory=list, description="Chart owners") model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") -class ChartSimpleFilters(BaseModel): - slice_name: Optional[str] = Field(None, description="Filter by chart name (partial match)") - viz_type: Optional[str] = Field(None, description="Filter by visualization type") - datasource_name: Optional[str] = Field(None, description="Filter by datasource name") - changed_by: Optional[str] = Field(None, description="Filter by last modifier (username)") - created_by: Optional[str] = Field(None, description="Filter by creator (username)") - owner: Optional[str] = Field(None, description="Filter by owner (username)") - tags: Optional[str] = Field(None, description="Filter by tags (comma-separated)") - class ChartAvailableFiltersResponse(BaseModel): filters: Dict[str, Any] = Field(..., description="Available filters and their metadata") operators: List[str] = Field(..., description="Supported filter operators") columns: List[str] = Field(..., description="Available columns for filtering") -class ChartInfoResponse(BaseModel): - chart: ChartListItem = Field(..., description="Detailed chart info") - model_config = 
ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") - -class ChartListResponse(BaseModel): - charts: List[ChartInfoResponse] +class ChartList(BaseModel): + charts: List[ChartInfo] count: int total_count: int page: int @@ -80,16 +70,16 @@ class ChartListResponse(BaseModel): timestamp: Optional[datetime] = None model_config = ConfigDict(ser_json_timedelta="iso8601") -class ChartErrorResponse(BaseModel): +class ChartError(BaseModel): error: str = Field(..., description="Error message") error_type: str = Field(..., description="Type of error") timestamp: Optional[Union[str, datetime]] = Field(None, description="Error timestamp") model_config = ConfigDict(ser_json_timedelta="iso8601") -def serialize_chart_object(chart) -> Optional[ChartListItem]: +def serialize_chart_object(chart) -> Optional[ChartInfo]: if not chart: return None - return ChartListItem( + return ChartInfo( id=getattr(chart, 'id', None), slice_name=getattr(chart, 'slice_name', None), viz_type=getattr(chart, 'viz_type', None), @@ -109,7 +99,7 @@ def serialize_chart_object(chart) -> Optional[ChartListItem]: created_on_humanized=getattr(chart, 'created_on_humanized', None), tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in getattr(chart, 'tags', [])] if getattr(chart, 'tags', None) else [], owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in getattr(chart, 'owners', [])] if getattr(chart, 'owners', None) else [], - ) + ) class CreateSimpleChartRequest(BaseModel): """ @@ -131,7 +121,7 @@ class CreateSimpleChartResponse(BaseModel): """ Response schema for create_chart_simple tool. 
""" - chart: Optional[ChartListItem] = Field(None, description="The created chart info, if successful") + chart: Optional[ChartInfo] = Field(None, description="The created chart info, if successful") embed_url: Optional[str] = Field(None, description="URL to view or embed the chart, if requested.") thumbnail_url: Optional[str] = Field(None, description="URL to a thumbnail image of the chart, if requested.") embed_html: Optional[str] = Field(None, description="HTML snippet (e.g., iframe) to embed the chart, if requested.") diff --git a/superset/mcp_service/pydantic_schemas/dashboard_schemas.py b/superset/mcp_service/pydantic_schemas/dashboard_schemas.py index c93ee297945..91139893cb7 100644 --- a/superset/mcp_service/pydantic_schemas/dashboard_schemas.py +++ b/superset/mcp_service/pydantic_schemas/dashboard_schemas.py @@ -23,7 +23,7 @@ in a consistent and type-safe manner. Example usage: # For detailed dashboard info - dashboard_info = DashboardInfoResponse( + dashboard_info = DashboardInfo( id=1, dashboard_title="Sales Dashboard", published=True, @@ -32,9 +32,9 @@ Example usage: ) # For dashboard list responses - dashboard_list = DashboardListResponse( + dashboard_list = DashboardList( dashboards=[ - DashboardListItem( + DashboardInfo( id=1, dashboard_title="Sales Dashboard", published=True, @@ -64,147 +64,15 @@ Example usage: """ from datetime import datetime -from typing import Any, Dict, List, Optional, Union, Mapping, Literal -from pydantic import BaseModel, Field, ConfigDict +from typing import Any, Dict, List, Literal, Optional, Union + +from pydantic import BaseModel, ConfigDict, Field +from superset.daos.base import ColumnOperator +from superset.mcp_service.pydantic_schemas.chart_schemas import ChartInfo +from superset.mcp_service.pydantic_schemas.system_schemas import PaginationInfo, RoleInfo, TagInfo, UserInfo -class UserInfo(BaseModel): - """User information for dashboard owners and creators""" - id: Optional[int] = None - username: Optional[str] = None 
- first_name: Optional[str] = None - last_name: Optional[str] = None - email: Optional[str] = None - active: Optional[bool] = None - - -class TagInfo(BaseModel): - """Tag information for dashboard tags""" - id: Optional[int] = None - name: Optional[str] = None - type: Optional[str] = None - description: Optional[str] = None - - -class RoleInfo(BaseModel): - """Role information for dashboard roles""" - id: Optional[int] = None - name: Optional[str] = None - permissions: Optional[List[str]] = None - - -class ChartInfo(BaseModel): - """Chart information for dashboard charts""" - id: Optional[int] = None - slice_name: Optional[str] = None - viz_type: Optional[str] = None - datasource_name: Optional[str] = None - datasource_type: Optional[str] = None - url: Optional[str] = None - description: Optional[str] = None - cache_timeout: Optional[int] = None - form_data: Optional[Dict[str, Any]] = None - query_context: Optional[Any] = None - created_by: Optional[UserInfo] = None - changed_by: Optional[UserInfo] = None - created_on: Optional[Union[str, datetime]] = None - changed_on: Optional[Union[str, datetime]] = None - model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") - - -class DashboardListItem(BaseModel): - """Dashboard item for list responses - simplified version of DashboardInfoResponse""" - id: int = Field(..., description="Dashboard ID") - dashboard_title: str = Field(..., description="Dashboard title") - slug: Optional[str] = Field(None, description="Dashboard slug") - url: Optional[str] = Field(None, description="Dashboard URL") - published: Optional[bool] = Field(None, description="Whether the dashboard is published") - changed_by: Optional[str] = Field(None, description="Last modifier (username)") - changed_by_name: Optional[str] = Field(None, description="Last modifier (display name)") - changed_on: Optional[Union[str, datetime]] = Field(None, description="Last modification timestamp") - changed_on_humanized: Optional[str] = 
Field(None, description="Humanized modification time") - created_by: Optional[str] = Field(None, description="Dashboard creator (username)") - created_on: Optional[Union[str, datetime]] = Field(None, description="Creation timestamp") - created_on_humanized: Optional[str] = Field(None, description="Humanized creation time") - tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags") - owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners") - - model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") - - -class PaginationInfo(BaseModel): - """Pagination information for list responses""" - page: int = Field(..., description="Current page number") - page_size: int = Field(..., description="Number of items per page") - total_count: int = Field(..., description="Total number of items") - total_pages: int = Field(..., description="Total number of pages") - has_next: bool = Field(..., description="Whether there is a next page") - has_previous: bool = Field(..., description="Whether there is a previous page") - - model_config = ConfigDict(ser_json_timedelta="iso8601") - - -class DashboardInfoResponse(BaseModel): - """Detailed dashboard information response - maps exactly to Dashboard model""" - - # Core Dashboard model fields - id: int = Field(..., description="Dashboard ID") - dashboard_title: str = Field(..., description="Dashboard title") - slug: Optional[str] = Field(None, description="Dashboard slug") - description: Optional[str] = Field(None, description="Dashboard description") - css: Optional[str] = Field(None, description="Custom CSS for the dashboard") - certified_by: Optional[str] = Field(None, description="Who certified the dashboard") - certification_details: Optional[str] = Field(None, description="Certification details") - json_metadata: Optional[str] = Field(None, description="Dashboard metadata (JSON string)") - position_json: Optional[str] = Field(None, description="Chart positions 
(JSON string)") - published: Optional[bool] = Field(None, description="Whether the dashboard is published") - is_managed_externally: Optional[bool] = Field(None, description="Whether managed externally") - external_url: Optional[str] = Field(None, description="External URL") - - # AuditMixinNullable fields - created_on: Optional[Union[str, datetime]] = Field(None, description="Creation timestamp") - changed_on: Optional[Union[str, datetime]] = Field(None, description="Last modification timestamp") - created_by: Optional[str] = Field(None, description="Dashboard creator (username)") - changed_by: Optional[str] = Field(None, description="Last modifier (username)") - - # ImportExportMixin fields - uuid: Optional[str] = Field(None, description="Dashboard UUID (converted to string)") - - # Computed properties - url: Optional[str] = Field(None, description="Dashboard URL") - thumbnail_url: Optional[str] = Field(None, description="Thumbnail URL") - created_on_humanized: Optional[str] = Field(None, description="Humanized creation time") - changed_on_humanized: Optional[str] = Field(None, description="Humanized modification time") - chart_count: int = Field(0, description="Number of charts in the dashboard") - - # Related entities - owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners") - tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags") - roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard roles") - charts: List[ChartInfo] = Field(default_factory=list, description="Dashboard charts") - - model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") - - -class DashboardListResponse(BaseModel): - dashboards: List[DashboardInfoResponse] - count: int - total_count: int - page: int - page_size: int - total_pages: int - has_previous: bool - has_next: bool - columns_requested: Optional[List[str]] = None - columns_loaded: Optional[List[str]] = None - filters_applied: List[dict] = 
Field( - default_factory=list, description="List of advanced filter dicts applied to the query.") - pagination: Optional[PaginationInfo] = None - timestamp: Optional[datetime] = None - - model_config = ConfigDict(ser_json_timedelta="iso8601") - -class DashboardErrorResponse(BaseModel): +class DashboardError(BaseModel): """Error response for dashboard operations""" error: str = Field(..., description="Error message") error_type: str = Field(..., description="Type of error") @@ -254,7 +122,7 @@ def serialize_role_object(role) -> Optional[RoleInfo]: def serialize_chart_object(chart) -> Optional[ChartInfo]: - """Serialize a chart object to ChartInfo""" + """Serialize a chart object to Chart""" if not chart: return None @@ -276,13 +144,13 @@ def serialize_chart_object(chart) -> Optional[ChartInfo]: ) -class DashboardAvailableFiltersResponse(BaseModel): +class DashboardAvailableFilters(BaseModel): filters: Dict[str, Any] = Field(..., description="Available filters and their metadata") operators: List[str] = Field(..., description="Supported filter operators") columns: List[str] = Field(..., description="Available columns for filtering") -class DashboardFilter(BaseModel): +class DashboardFilter(ColumnOperator): """ Filter object for dashboard listing. col: The column to filter on. Must be one of the allowed filter fields. @@ -303,21 +171,50 @@ class DashboardFilter(BaseModel): opr: Literal[ "eq", "ne", "in", "nin", "sw", "ew", "gte", "lte", "gt", "lt" ] = Field(..., description="Operator to use. 
See get_dashboard_available_filters for allowed values.") - value: Any = Field(..., description="Value to filter by (type depends on col and opr)") + value: Any = Field(..., description="Value to filter by (type depends on col and opr)") -class DashboardSimpleFilters(BaseModel): - dashboard_title: Optional[str] = Field(None, description="Filter by dashboard title (partial match)") - published: Optional[bool] = Field(None, description="Filter by published status") - changed_by: Optional[str] = Field(None, description="Filter by last modifier (username)") - created_by: Optional[str] = Field(None, description="Filter by creator (username)") - owner: Optional[str] = Field(None, description="Filter by owner (username)") - certified: Optional[bool] = Field(None, description="Filter by certified status") - favorite: Optional[bool] = Field(None, description="Filter by favorite status") - chart_count: Optional[int] = Field(None, description="Filter by number of charts") - chart_count_min: Optional[int] = Field(None, description="Filter by minimum number of charts") - chart_count_max: Optional[int] = Field(None, description="Filter by maximum number of charts") - tags: Optional[str] = Field(None, description="Filter by tags (comma-separated)") +class DashboardInfo(BaseModel): + id: int = Field(..., description="Dashboard ID") + dashboard_title: str = Field(..., description="Dashboard title") + slug: Optional[str] = Field(None, description="Dashboard slug") + description: Optional[str] = Field(None, description="Dashboard description") + css: Optional[str] = Field(None, description="Custom CSS for the dashboard") + certified_by: Optional[str] = Field(None, description="Who certified the dashboard") + certification_details: Optional[str] = Field(None, description="Certification details") + json_metadata: Optional[str] = Field(None, description="Dashboard metadata (JSON string)") + position_json: Optional[str] = Field(None, description="Chart positions (JSON string)") + 
published: Optional[bool] = Field(None, description="Whether the dashboard is published") + is_managed_externally: Optional[bool] = Field(None, description="Whether managed externally") + external_url: Optional[str] = Field(None, description="External URL") + created_on: Optional[Union[str, datetime]] = Field(None, description="Creation timestamp") + changed_on: Optional[Union[str, datetime]] = Field(None, description="Last modification timestamp") + created_by: Optional[str] = Field(None, description="Dashboard creator (username)") + changed_by: Optional[str] = Field(None, description="Last modifier (username)") + uuid: Optional[str] = Field(None, description="Dashboard UUID (converted to string)") + url: Optional[str] = Field(None, description="Dashboard URL") + thumbnail_url: Optional[str] = Field(None, description="Thumbnail URL") + created_on_humanized: Optional[str] = Field(None, description="Humanized creation time") + changed_on_humanized: Optional[str] = Field(None, description="Humanized modification time") + chart_count: int = Field(0, description="Number of charts in the dashboard") + owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners") + tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags") + roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard roles") + charts: List[ChartInfo] = Field(default_factory=list, description="Dashboard charts") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") - -# ... rest of the file remains unchanged ... 
+class DashboardList(BaseModel): + dashboards: List[DashboardInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: Optional[List[str]] = None + columns_loaded: Optional[List[str]] = None + filters_applied: List[dict] = Field(default_factory=list, description="List of advanced filter dicts applied to the query.") + pagination: Optional[PaginationInfo] = None + timestamp: Optional[datetime] = None + model_config = ConfigDict(ser_json_timedelta="iso8601") diff --git a/superset/mcp_service/pydantic_schemas/dataset_schemas.py b/superset/mcp_service/pydantic_schemas/dataset_schemas.py index 9492444c951..23a4e9be748 100644 --- a/superset/mcp_service/pydantic_schemas/dataset_schemas.py +++ b/superset/mcp_service/pydantic_schemas/dataset_schemas.py @@ -19,49 +19,19 @@ Pydantic schemas for dataset-related responses """ from datetime import datetime -from typing import Any, Dict, List, Optional, Union, Literal -from pydantic import BaseModel, Field, ConfigDict -from .dashboard_schemas import UserInfo, TagInfo, PaginationInfo +from typing import Any, Dict, List, Literal, Optional, Union -class DatasetListItem(BaseModel): - """Dataset item for list responses""" - id: int = Field(..., description="Dataset ID") - table_name: str = Field(..., description="Table name") - db_schema: Optional[str] = Field(None, alias="schema", description="Schema name") - database_name: Optional[str] = Field(None, description="Database name") - description: Optional[str] = Field(None, description="Dataset description") - changed_by: Optional[str] = Field(None, description="Last modifier (username)") - changed_by_name: Optional[str] = Field(None, description="Last modifier (display name)") - changed_on: Optional[Union[str, datetime]] = Field(None, description="Last modification timestamp") - changed_on_humanized: Optional[str] = Field(None, description="Humanized modification time") - created_by: Optional[str] = 
Field(None, description="Dataset creator (username)") - created_on: Optional[Union[str, datetime]] = Field(None, description="Creation timestamp") - created_on_humanized: Optional[str] = Field(None, description="Humanized creation time") - tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags") - owners: List[UserInfo] = Field(default_factory=list, description="Dataset owners") - is_virtual: Optional[bool] = Field(None, description="Whether the dataset is virtual (uses SQL)") - database_id: Optional[int] = Field(None, description="Database ID") - schema_perm: Optional[str] = Field(None, description="Schema permission string") - url: Optional[str] = Field(None, description="Dataset URL") - model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") +from pydantic import BaseModel, ConfigDict, Field +from superset.daos.base import ColumnOperator +from superset.mcp_service.pydantic_schemas.system_schemas import PaginationInfo, TagInfo, UserInfo -class DatasetSimpleFilters(BaseModel): - table_name: Optional[str] = Field(None, description="Filter by table name (partial match)") - db_schema: Optional[str] = Field(None, alias="schema", description="Filter by schema name") - database_name: Optional[str] = Field(None, description="Filter by database name") - changed_by: Optional[str] = Field(None, description="Filter by last modifier (username)") - created_by: Optional[str] = Field(None, description="Filter by creator (username)") - owner: Optional[str] = Field(None, description="Filter by owner (username)") - is_virtual: Optional[bool] = Field(None, description="Filter by whether the dataset is virtual (uses SQL)") - tags: Optional[str] = Field(None, description="Filter by tags (comma-separated)") - -class DatasetAvailableFiltersResponse(BaseModel): +class DatasetAvailableFilters(BaseModel): filters: Dict[str, Any] = Field(..., description="Available filters and their metadata") operators: List[str] = Field(..., 
description="Supported filter operators") columns: List[str] = Field(..., description="Available columns for filtering") -class DatasetFilter(BaseModel): +class DatasetFilter(ColumnOperator): """ Filter object for dataset listing. col: The column to filter on. Must be one of the allowed filter fields. @@ -83,32 +53,7 @@ class DatasetFilter(BaseModel): ] = Field(..., description="Operator to use. See get_dataset_available_filters for allowed values.") value: Any = Field(..., description="Value to filter by (type depends on col and opr)") -def serialize_dataset_object(dataset) -> Optional[DatasetListItem]: - if not dataset: - return None - return DatasetListItem( - id=getattr(dataset, 'id', None), - table_name=getattr(dataset, 'table_name', None), - db_schema=getattr(dataset, 'schema', None), - database_name=getattr(dataset.database, 'database_name', None) if getattr(dataset, 'database', None) else None, - description=getattr(dataset, 'description', None), - changed_by=getattr(dataset, 'changed_by_name', None) or (str(dataset.changed_by) if getattr(dataset, 'changed_by', None) else None), - changed_by_name=getattr(dataset, 'changed_by_name', None) or (str(dataset.changed_by) if getattr(dataset, 'changed_by', None) else None), - changed_on=getattr(dataset, 'changed_on', None), - changed_on_humanized=getattr(dataset, 'changed_on_humanized', None), - created_by=getattr(dataset, 'created_by_name', None) or (str(dataset.created_by) if getattr(dataset, 'created_by', None) else None), - created_on=getattr(dataset, 'created_on', None), - created_on_humanized=getattr(dataset, 'created_on_humanized', None), - tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in getattr(dataset, 'tags', [])] if getattr(dataset, 'tags', None) else [], - owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in getattr(dataset, 'owners', [])] if getattr(dataset, 'owners', None) else [], - is_virtual=getattr(dataset, 'is_virtual', None), - database_id=getattr(dataset, 
'database_id', None), - schema_perm=getattr(dataset, 'schema_perm', None), - url=getattr(dataset, 'url', None), - ) - -class DatasetInfoResponse(BaseModel): - """Detailed dataset information response - maps exactly to Dataset model""" +class DatasetInfo(BaseModel): id: int = Field(..., description="Dataset ID") table_name: str = Field(..., description="Table name") db_schema: Optional[str] = Field(None, alias="schema", description="Schema name") @@ -121,7 +66,7 @@ class DatasetInfoResponse(BaseModel): created_on: Optional[Union[str, datetime]] = Field(None, description="Creation timestamp") created_on_humanized: Optional[str] = Field(None, description="Humanized creation time") tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags") - owners: List[UserInfo] = Field(default_factory=list, description="Dataset owners") + owners: List[UserInfo] = Field(default_factory=list, description="Dataset owners") is_virtual: Optional[bool] = Field(None, description="Whether the dataset is virtual (uses SQL)") database_id: Optional[int] = Field(None, description="Database ID") schema_perm: Optional[str] = Field(None, description="Schema permission string") @@ -135,9 +80,8 @@ class DatasetInfoResponse(BaseModel): extra: Optional[Dict[str, Any]] = Field(None, description="Extra metadata") model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") -class DatasetListResponse(BaseModel): - """Response for dataset list operations""" - datasets: List[DatasetInfoResponse] +class DatasetList(BaseModel): + datasets: List[DatasetInfo] count: int total_count: int page: int @@ -152,8 +96,38 @@ class DatasetListResponse(BaseModel): timestamp: Optional[datetime] = None model_config = ConfigDict(ser_json_timedelta="iso8601") -class DatasetErrorResponse(BaseModel): +class DatasetError(BaseModel): error: str = Field(..., description="Error message") error_type: str = Field(..., description="Type of error") timestamp: Optional[Union[str, datetime]] =
Field(None, description="Error timestamp") model_config = ConfigDict(ser_json_timedelta="iso8601") + +def serialize_dataset_object(dataset) -> Optional[DatasetInfo]: + if not dataset: + return None + return DatasetInfo( + id=getattr(dataset, 'id', None), + table_name=getattr(dataset, 'table_name', None), + db_schema=getattr(dataset, 'schema', None), + database_name=getattr(dataset.database, 'database_name', None) if getattr(dataset, 'database', None) else None, + description=getattr(dataset, 'description', None), + changed_by=getattr(dataset, 'changed_by_name', None) or (str(dataset.changed_by) if getattr(dataset, 'changed_by', None) else None), + changed_on=getattr(dataset, 'changed_on', None), + changed_on_humanized=getattr(dataset, 'changed_on_humanized', None), + created_by=getattr(dataset, 'created_by_name', None) or (str(dataset.created_by) if getattr(dataset, 'created_by', None) else None), + created_on=getattr(dataset, 'created_on', None), + created_on_humanized=getattr(dataset, 'created_on_humanized', None), + tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in getattr(dataset, 'tags', [])] if getattr(dataset, 'tags', None) else [], + owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in getattr(dataset, 'owners', [])] if getattr(dataset, 'owners', None) else [], + is_virtual=getattr(dataset, 'is_virtual', None), + database_id=getattr(dataset, 'database_id', None), + schema_perm=getattr(dataset, 'schema_perm', None), + url=getattr(dataset, 'url', None), + sql=getattr(dataset, 'sql', None), + main_dttm_col=getattr(dataset, 'main_dttm_col', None), + offset=getattr(dataset, 'offset', None), + cache_timeout=getattr(dataset, 'cache_timeout', None), + params=getattr(dataset, 'params', None), + template_params=getattr(dataset, 'template_params', None), + extra=getattr(dataset, 'extra', None), + ) diff --git a/superset/mcp_service/pydantic_schemas/system_schemas.py b/superset/mcp_service/pydantic_schemas/system_schemas.py index 
19a0747ebe3..4af8f2e525e 100644 --- a/superset/mcp_service/pydantic_schemas/system_schemas.py +++ b/superset/mcp_service/pydantic_schemas/system_schemas.py @@ -22,8 +22,8 @@ This module contains Pydantic models for serializing Superset instance metadata """ from datetime import datetime -from typing import Dict, List -from pydantic import BaseModel, Field +from typing import Dict, List, Optional +from pydantic import BaseModel, Field, ConfigDict class InstanceSummary(BaseModel): total_dashboards: int = Field(..., description="Total number of dashboards") @@ -57,10 +57,38 @@ class PopularContent(BaseModel): top_tags: List[str] = Field(..., description="Most popular tags") top_creators: List[str] = Field(..., description="Most active creators") -class SupersetInstanceInfoResponse(BaseModel): +class InstanceInfo(BaseModel): instance_summary: InstanceSummary = Field(..., description="Instance summary information") recent_activity: RecentActivity = Field(..., description="Recent activity information") dashboard_breakdown: DashboardBreakdown = Field(..., description="Dashboard breakdown information") database_breakdown: DatabaseBreakdown = Field(..., description="Database breakdown by type") popular_content: PopularContent = Field(..., description="Popular content information") - timestamp: datetime = Field(..., description="Response timestamp") \ No newline at end of file + timestamp: datetime = Field(..., description="Response timestamp") + +class UserInfo(BaseModel): + id: Optional[int] = None + username: Optional[str] = None + first_name: Optional[str] = None + last_name: Optional[str] = None + email: Optional[str] = None + active: Optional[bool] = None + +class TagInfo(BaseModel): + id: Optional[int] = None + name: Optional[str] = None + type: Optional[str] = None + description: Optional[str] = None + +class RoleInfo(BaseModel): + id: Optional[int] = None + name: Optional[str] = None + permissions: Optional[List[str]] = None + +class PaginationInfo(BaseModel): + 
page: int = Field(..., description="Current page number") + page_size: int = Field(..., description="Number of items per page") + total_count: int = Field(..., description="Total number of items") + total_pages: int = Field(..., description="Total number of pages") + has_next: bool = Field(..., description="Whether there is a next page") + has_previous: bool = Field(..., description="Whether there is a previous page") + model_config = ConfigDict(ser_json_timedelta="iso8601") diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 0cc08922981..f0377a805aa 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -46,7 +46,7 @@ Available tools include: - get_dashboard_info: Get detailed information about a dashboard by its integer ID - get_superset_instance_info: Get high-level statistics and metadata about the Superset instance (no arguments) - get_dashboard_available_filters: List all available dashboard filter fields and operators -- list_datasets: Dataset listing with advanced filters (use 'filters' for advanced queries, 1-based pagination) +- list_datasets: Dataset listing with advanced filters (use 'filters' for advanced queries, 1-based pagination) - get_dataset_info: Get detailed information about a dataset by its integer ID - get_dataset_available_filters: List all available dataset filter fields and operators - list_charts: Chart listing with advanced filters (use 'filters' for advanced queries, 1-based pagination) diff --git a/superset/mcp_service/tools/chart/create_chart_simple.py b/superset/mcp_service/tools/chart/create_chart_simple.py index dc9e51483bc..1eac2772ad1 100644 --- a/superset/mcp_service/tools/chart/create_chart_simple.py +++ b/superset/mcp_service/tools/chart/create_chart_simple.py @@ -20,7 +20,7 @@ MCP tool: create_chart_simple from typing import Annotated from pydantic import Field from superset.mcp_service.pydantic_schemas.chart_schemas import ( - CreateSimpleChartRequest,
CreateSimpleChartResponse, ChartListItem + CreateSimpleChartRequest, CreateSimpleChartResponse ) from superset.commands.chart.create import CreateChartCommand from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object diff --git a/superset/mcp_service/tools/chart/get_chart_info.py b/superset/mcp_service/tools/chart/get_chart_info.py index c1fd323d47a..e1f57841ec5 100644 --- a/superset/mcp_service/tools/chart/get_chart_info.py +++ b/superset/mcp_service/tools/chart/get_chart_info.py @@ -19,7 +19,7 @@ MCP tool: get_chart_info """ from typing import Any, Dict, Optional, Annotated -from superset.mcp_service.pydantic_schemas import ChartInfoResponse, ChartErrorResponse +from superset.mcp_service.pydantic_schemas import ChartInfo, ChartError from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object from datetime import datetime @@ -31,24 +31,17 @@ def get_chart_info( int, Field(description="ID of the chart to retrieve information for") ] -) -> ChartInfoResponse | ChartErrorResponse: +) -> ChartInfo | ChartError: """ Get detailed information about a chart by ID (MCP tool). - Parameters - ---------- - chart_id : int - ID of the chart to retrieve information for. - Returns - ------- - ChartInfoResponse or ChartErrorResponse - Detailed chart information or error response. + Returns a ChartInfo model or ChartError on error. 
""" try: chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") chart, error_type, error_message = chart_wrapper.info(chart_id) if not chart: - return ChartErrorResponse(error=error_message or "Chart not found", error_type=error_type or "not_found", timestamp=datetime.utcnow()) + return ChartError(error=error_message or "Chart not found", error_type=error_type or "not_found", timestamp=datetime.utcnow()) chart_info = serialize_chart_object(chart) - return ChartInfoResponse(chart=chart_info) + return chart_info except Exception as ex: - return ChartErrorResponse(error=str(ex), error_type="get_chart_info_error", timestamp=datetime.utcnow()) \ No newline at end of file + return ChartError(error=str(ex), error_type="get_chart_info_error", timestamp=datetime.utcnow()) diff --git a/superset/mcp_service/tools/chart/list_charts.py b/superset/mcp_service/tools/chart/list_charts.py index ffdc6dd3d3a..462a62c127e 100644 --- a/superset/mcp_service/tools/chart/list_charts.py +++ b/superset/mcp_service/tools/chart/list_charts.py @@ -19,7 +19,7 @@ MCP tool: list_charts (advanced filtering) """ from typing import Any, Dict, List, Optional, Literal, Annotated, Union -from superset.mcp_service.pydantic_schemas import ChartListResponse, ChartListItem +from superset.mcp_service.pydantic_schemas import ChartList, ChartInfo from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object from datetime import datetime, timezone @@ -27,6 +27,7 @@ from pydantic import BaseModel, conlist, constr, PositiveInt, Field from superset.mcp_service.pydantic_schemas.dashboard_schemas import PaginationInfo from superset.daos.chart import ChartDAO from superset.mcp_service.pydantic_schemas.chart_schemas import ChartFilter +import json def list_charts( @@ -66,14 +67,17 @@ def list_charts( Optional[str], Field(description="Text search string to match against chart fields") ] = None, -) -> ChartListResponse: +) -> ChartList: """ 
List charts with advanced filtering (MCP tool). - Returns a ChartListResponse Pydantic model (not a dict), matching list_dashboards and list_datasets. + Returns a ChartList Pydantic model (not a dict), matching list_dashboards and list_datasets. """ + # If filters is a string (e.g., from a test), parse it as JSON + if isinstance(filters, str): + filters = json.loads(filters) chart_wrapper = MCPDAOWrapper(ChartDAO, "chart") charts, total_count = chart_wrapper.list( - filters=filters, + column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", page=max(page - 1, 0), @@ -81,6 +85,7 @@ def list_charts( search=search, search_columns=["slice_name", "viz_type", "datasource_name"] if search else None, ) + # ChartList expects a list of ChartInfo chart_items = [serialize_chart_object(chart) for chart in charts] total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0 pagination_info = PaginationInfo( @@ -91,7 +96,7 @@ def list_charts( has_next=page < total_pages, has_previous=page > 1 ) - response = ChartListResponse( + response = ChartList( charts=chart_items, count=len(chart_items), total_count=total_count, @@ -102,7 +107,7 @@ def list_charts( has_next=page < total_pages - 1, columns_requested=columns or [], columns_loaded=columns or [], - filters_applied=filters or [], + filters_applied=filters if isinstance(filters, list) else [], pagination=pagination_info, timestamp=datetime.now(timezone.utc), ) diff --git a/superset/mcp_service/tools/dashboard/get_dashboard_available_filters.py b/superset/mcp_service/tools/dashboard/get_dashboard_available_filters.py index f25766c7bc9..eec415d82f3 100644 --- a/superset/mcp_service/tools/dashboard/get_dashboard_available_filters.py +++ b/superset/mcp_service/tools/dashboard/get_dashboard_available_filters.py @@ -5,15 +5,15 @@ Get available filters FastMCP tool """ import logging from typing import Any -from 
superset.mcp_service.pydantic_schemas.dashboard_schemas import DashboardAvailableFiltersResponse +from superset.mcp_service.pydantic_schemas.dashboard_schemas import DashboardAvailableFilters logger = logging.getLogger(__name__) -def get_dashboard_available_filters() -> DashboardAvailableFiltersResponse: +def get_dashboard_available_filters() -> DashboardAvailableFilters: """ Get information about available dashboard filters and their operators Returns: - DashboardAvailableFiltersResponse + DashboardAvailableFilters """ try: filters = { @@ -88,7 +88,7 @@ def get_dashboard_available_filters() -> DashboardAvailableFiltersResponse: "certification_details", "chart_count", "owners", "tags", "is_managed_externally", "external_url", "uuid", "version" ] - response = DashboardAvailableFiltersResponse( + response = DashboardAvailableFilters( filters=filters, operators=operators, columns=columns diff --git a/superset/mcp_service/tools/dashboard/get_dashboard_info.py b/superset/mcp_service/tools/dashboard/get_dashboard_info.py index e45d5273af4..b21d65c2212 100644 --- a/superset/mcp_service/tools/dashboard/get_dashboard_info.py +++ b/superset/mcp_service/tools/dashboard/get_dashboard_info.py @@ -26,16 +26,13 @@ from datetime import datetime, timezone from typing import Annotated from pydantic import Field + from superset.daos.dashboard import DashboardDAO from superset.mcp_service.dao_wrapper import MCPDAOWrapper -from superset.mcp_service.pydantic_schemas.dashboard_schemas import ( - ChartInfo, - DashboardErrorResponse, - DashboardInfoResponse, - RoleInfo, - TagInfo, - UserInfo, -) +from superset.mcp_service.pydantic_schemas import DashboardError, DashboardInfo +from superset.mcp_service.pydantic_schemas.chart_schemas import serialize_chart_object +from superset.mcp_service.pydantic_schemas.system_schemas import RoleInfo, TagInfo, \ + UserInfo logger = logging.getLogger(__name__) @@ -45,27 +42,16 @@ def get_dashboard_info( int, Field(description="ID of the dashboard to 
retrieve information for") ] -) -> DashboardInfoResponse | DashboardErrorResponse: +) -> DashboardInfo | DashboardError: """ Get detailed information about a specific dashboard. - Parameters - ---------- - dashboard_id : int - ID of the dashboard to retrieve information for. - Returns - ------- - DashboardInfoResponse or DashboardErrorResponse - Detailed dashboard information or error response. + Returns a DashboardInfo model or DashboardError on error. """ - try: - # Use the generic DAO wrapper dao_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") dashboard, error_type, error_message = dao_wrapper.info(dashboard_id) - if dashboard is None: - # Handle error cases - error_data = DashboardErrorResponse( + error_data = DashboardError( error=error_message, error_type=error_type, timestamp=datetime.now(timezone.utc) @@ -73,10 +59,7 @@ def get_dashboard_info( logger.warning( f"Dashboard {dashboard_id} error: {error_type} - {error_message}") return error_data - - # Create dashboard response using Pydantic constructor - most Pythonic approach - response = DashboardInfoResponse( - # Core dashboard attributes + response = DashboardInfo( id=dashboard.id, dashboard_title=dashboard.dashboard_title or "Untitled", slug=dashboard.slug or "", @@ -89,37 +72,21 @@ def get_dashboard_info( published=dashboard.published, is_managed_externally=dashboard.is_managed_externally, external_url=dashboard.external_url, - - # Audit fields created_on=dashboard.created_on, changed_on=dashboard.changed_on, - created_by=getattr( - dashboard.created_by, 'username', - None) if dashboard.created_by else None, - changed_by=getattr( - dashboard.changed_by, 'username', - None) if dashboard.changed_by else None, - - # UUID and computed fields + created_by=getattr(dashboard.created_by, 'username', None) if dashboard.created_by else None, + changed_by=getattr(dashboard.changed_by, 'username', None) if dashboard.changed_by else None, uuid=str(dashboard.uuid) if dashboard.uuid else None, url=dashboard.url, 
thumbnail_url=dashboard.thumbnail_url, created_on_humanized=dashboard.created_on_humanized, changed_on_humanized=dashboard.changed_on_humanized, chart_count=len(dashboard.slices) if dashboard.slices else 0, - - # Related entities - use model_validate for each type for proper - # serialization - owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in - dashboard.owners] if dashboard.owners else [], - tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in - dashboard.tags] if dashboard.tags else [], - roles=[RoleInfo.model_validate(role, from_attributes=True) for role in - dashboard.roles] if dashboard.roles else [], - charts=[ChartInfo.model_validate(chart, from_attributes=True) for chart in - dashboard.slices] if dashboard.slices else [] + owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in dashboard.owners] if dashboard.owners else [], + tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in dashboard.tags] if dashboard.tags else [], + roles=[RoleInfo.model_validate(role, from_attributes=True) for role in dashboard.roles] if dashboard.roles else [], + charts=[serialize_chart_object(chart) for chart in dashboard.slices] if dashboard.slices else [] ) - logger.info( f"Dashboard response created successfully for dashboard {dashboard.id}") return response diff --git a/superset/mcp_service/tools/dashboard/list_dashboards.py b/superset/mcp_service/tools/dashboard/list_dashboards.py index 7581a7adce5..d574c100bdc 100644 --- a/superset/mcp_service/tools/dashboard/list_dashboards.py +++ b/superset/mcp_service/tools/dashboard/list_dashboards.py @@ -21,15 +21,21 @@ List dashboards FastMCP tool (Advanced) This module contains the FastMCP tool for listing dashboards using advanced filtering with complex filter objects and JSON payload. 
""" +import json import logging from datetime import datetime, timezone -from typing import Annotated, Any, Literal, Optional +from typing import Annotated, Literal, Optional + +from pydantic import conlist, constr, Field, PositiveInt -from pydantic import BaseModel, conlist, constr, Field, PositiveInt from superset.daos.dashboard import DashboardDAO from superset.mcp_service.dao_wrapper import MCPDAOWrapper -from superset.mcp_service.pydantic_schemas.dashboard_schemas import ( - DashboardListItem, DashboardListResponse, PaginationInfo, TagInfo, UserInfo, DashboardFilter) +from superset.mcp_service.pydantic_schemas import ( + DashboardFilter, DashboardInfo, DashboardList) +from superset.mcp_service.pydantic_schemas.chart_schemas import ( + serialize_chart_object, ) +from superset.mcp_service.pydantic_schemas.system_schemas import ( + PaginationInfo, TagInfo, UserInfo) logger = logging.getLogger(__name__) @@ -71,11 +77,15 @@ def list_dashboards( Optional[str], Field(description="Text search string to match against dataset fields") ] = None, -) -> DashboardListResponse: +) -> DashboardList: """ ADVANCED FILTERING: List dashboards using complex filter objects and JSON payload - Returns a DashboardListResponse Pydantic model (not a dict), matching list_dashboards_simple. + Returns a DashboardList Pydantic model (not a dict), matching + list_dashboards_simple. 
""" + # If filters is a string (e.g., from a test), parse it as JSON + if isinstance(filters, str): + filters = json.loads(filters) dao_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") search_columns = ( "created_by", @@ -90,7 +100,7 @@ def list_dashboards( "uuid", ) dashboards, total_count = dao_wrapper.list( - filters=filters, + column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", page=max(page - 1, 0), @@ -101,7 +111,8 @@ def list_dashboards( columns_to_load = [] if select_columns: if isinstance(select_columns, str): - select_columns = [col.strip() for col in select_columns.split(",") if col.strip()] + select_columns = [col.strip() for col in select_columns.split(",") if + col.strip()] columns_to_load = select_columns elif columns: columns_to_load = columns @@ -114,7 +125,7 @@ def list_dashboards( ] dashboard_items = [] for dashboard in dashboards: - dashboard_item = DashboardListItem( + dashboard_item = DashboardInfo( id=dashboard.id, dashboard_title=dashboard.dashboard_title or "Untitled", slug=dashboard.slug or "", @@ -124,14 +135,21 @@ def list_dashboards( str(dashboard.changed_by) if dashboard.changed_by else None), changed_by_name=getattr(dashboard, "changed_by_name", None) or ( str(dashboard.changed_by) if dashboard.changed_by else None), - changed_on=dashboard.changed_on if getattr(dashboard, "changed_on", None) else None, + changed_on=dashboard.changed_on if getattr( + dashboard, "changed_on", None) else None, changed_on_humanized=getattr(dashboard, "changed_on_humanized", None), created_by=getattr(dashboard, "created_by_name", None) or ( str(dashboard.created_by) if dashboard.created_by else None), - created_on=dashboard.created_on if getattr(dashboard, "created_on", None) else None, + created_on=dashboard.created_on if getattr( + dashboard, "created_on", None) else None, created_on_humanized=getattr(dashboard, "created_on_humanized", None), - tags=[TagInfo.model_validate(tag, 
from_attributes=True) for tag in dashboard.tags] if dashboard.tags else [], - owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in dashboard.owners] if dashboard.owners else [] + tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in + dashboard.tags] if dashboard.tags else [], + owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in + dashboard.owners] if dashboard.owners else [], + charts=[serialize_chart_object(chart) for chart in + getattr(dashboard, 'slices', [])] if getattr( + dashboard, 'slices', None) else [], ) dashboard_items.append(dashboard_item) total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0 @@ -143,7 +161,7 @@ def list_dashboards( has_next=page < total_pages - 1, has_previous=page > 0 ) - response = DashboardListResponse( + response = DashboardList( dashboards=dashboard_items, count=len(dashboard_items), total_count=total_count, @@ -154,7 +172,7 @@ def list_dashboards( has_next=page < total_pages - 1, columns_requested=columns or [], columns_loaded=columns or [], - filters_applied=filters or [], + filters_applied=filters if isinstance(filters, list) else [], pagination=pagination_info, timestamp=datetime.now(timezone.utc) ) diff --git a/superset/mcp_service/tools/dataset/get_dataset_available_filters.py b/superset/mcp_service/tools/dataset/get_dataset_available_filters.py index 3c5483ac529..70002519305 100644 --- a/superset/mcp_service/tools/dataset/get_dataset_available_filters.py +++ b/superset/mcp_service/tools/dataset/get_dataset_available_filters.py @@ -18,15 +18,15 @@ Get available dataset filters FastMCP tool """ import logging -from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetAvailableFiltersResponse +from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetAvailableFilters logger = logging.getLogger(__name__) -def get_dataset_available_filters() -> DatasetAvailableFiltersResponse: +def get_dataset_available_filters() 
-> DatasetAvailableFilters: """ Get information about available dataset filters and their operators Returns: - DatasetAvailableFiltersResponse + DatasetAvailableFilters """ try: filters = { @@ -93,7 +93,7 @@ def get_dataset_available_filters() -> DatasetAvailableFiltersResponse: "changed_on", "created_by", "created_on", "is_virtual", "database_id", "schema_perm", "url", "tags", "owners" ] - response = DatasetAvailableFiltersResponse( + response = DatasetAvailableFilters( filters=filters, operators=operators, columns=columns diff --git a/superset/mcp_service/tools/dataset/get_dataset_info.py b/superset/mcp_service/tools/dataset/get_dataset_info.py index 7aba2fce107..a488a73ea25 100644 --- a/superset/mcp_service/tools/dataset/get_dataset_info.py +++ b/superset/mcp_service/tools/dataset/get_dataset_info.py @@ -27,12 +27,7 @@ from typing import Any, Annotated from pydantic import Field from superset.daos.dataset import DatasetDAO from superset.mcp_service.dao_wrapper import MCPDAOWrapper -from superset.mcp_service.pydantic_schemas import ( - DatasetInfoResponse, - DatasetErrorResponse, - TagInfo, - UserInfo, -) +from superset.mcp_service.pydantic_schemas import DatasetInfo, DatasetError, serialize_dataset_object logger = logging.getLogger(__name__) @@ -41,56 +36,24 @@ def get_dataset_info( int, Field(description="ID of the dataset to retrieve information for") ] -) -> DatasetInfoResponse | DatasetErrorResponse: +) -> DatasetInfo | DatasetError: """ Get detailed information about a specific dataset. - Parameters - ---------- - dataset_id : int - ID of the dataset to retrieve information for. - Returns - ------- - DatasetInfoResponse or DatasetErrorResponse - Detailed dataset information or error response. + Returns a DatasetInfo model or DatasetError on error. 
""" try: dao_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") dataset, error_type, error_message = dao_wrapper.info(dataset_id) if dataset is None: - error_data = DatasetErrorResponse( + error_data = DatasetError( error=error_message, error_type=error_type, timestamp=datetime.now(timezone.utc) ) - logger.warning(f"Dataset {dataset_id} error: {error_type} - {error_message}") + logger.warning(f"DatasetInfo {dataset_id} error: {error_type} - {error_message}") return error_data - response = DatasetInfoResponse( - id=dataset.id, - table_name=dataset.table_name, - db_schema=getattr(dataset, 'schema', None), - database_name=getattr(dataset.database, 'database_name', None) if getattr(dataset, 'database', None) else None, - description=getattr(dataset, 'description', None), - changed_by=getattr(dataset, 'changed_by_name', None) or (str(dataset.changed_by) if getattr(dataset, 'changed_by', None) else None), - changed_on=getattr(dataset, 'changed_on', None), - changed_on_humanized=getattr(dataset, 'changed_on_humanized', None), - created_by=getattr(dataset, 'created_by_name', None) or (str(dataset.created_by) if getattr(dataset, 'created_by', None) else None), - created_on=getattr(dataset, 'created_on', None), - created_on_humanized=getattr(dataset, 'created_on_humanized', None), - tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in getattr(dataset, 'tags', [])] if getattr(dataset, 'tags', None) else [], - owners=[UserInfo.model_validate(owner, from_attributes=True) for owner in getattr(dataset, 'owners', [])] if getattr(dataset, 'owners', None) else [], - is_virtual=getattr(dataset, 'is_virtual', None), - database_id=getattr(dataset, 'database_id', None), - schema_perm=getattr(dataset, 'schema_perm', None), - url=getattr(dataset, 'url', None), - sql=getattr(dataset, 'sql', None), - main_dttm_col=getattr(dataset, 'main_dttm_col', None), - offset=getattr(dataset, 'offset', None), - cache_timeout=getattr(dataset, 'cache_timeout', None), - params=getattr(dataset, 
'params', None), - template_params=getattr(dataset, 'template_params', None), - extra=getattr(dataset, 'extra', None), - ) - logger.info(f"Dataset response created successfully for dataset {dataset.id}") + response = serialize_dataset_object(dataset) + logger.info(f"DatasetInfo response created successfully for dataset {dataset.id}") return response except Exception as context_error: error_msg = f"Error within Flask app context: {str(context_error)}" diff --git a/superset/mcp_service/tools/dataset/list_datasets.py b/superset/mcp_service/tools/dataset/list_datasets.py index acb34240e0f..1312718c686 100644 --- a/superset/mcp_service/tools/dataset/list_datasets.py +++ b/superset/mcp_service/tools/dataset/list_datasets.py @@ -24,14 +24,16 @@ advanced filtering with complex filter objects and JSON payload. import logging from datetime import datetime, timezone from typing import Annotated, Any, Literal, Optional +import json from pydantic import BaseModel, conlist, constr, Field, PositiveInt from superset.daos.dataset import DatasetDAO from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas import ( - DatasetListResponse, + DatasetList, PaginationInfo, serialize_dataset_object, + DatasetInfo, ) from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetFilter @@ -74,21 +76,14 @@ def list_datasets( Optional[str], Field(description="Text search string to match against dataset fields") ] = None, -) -> DatasetListResponse: +) -> DatasetList: """ ADVANCED FILTERING: List datasets using complex filter objects and JSON payload - Returns a DatasetListResponse Pydantic model (not a dict), matching list_datasets_simple. + Returns a DatasetList Pydantic model (not a dict), matching list_datasets_simple. 
""" - simple_filters = {} - if filters: - for filter_obj in filters: - if isinstance(filter_obj, DatasetFilter): - col = filter_obj.col - value = filter_obj.value - if filter_obj.opr == 'eq': - simple_filters[col] = value - elif filter_obj.opr == 'sw': - simple_filters[col] = f"{value}%" + # If filters is a string (e.g., from a test), parse it as JSON + if isinstance(filters, str): + filters = json.loads(filters) dao_wrapper = MCPDAOWrapper(DatasetDAO, "dataset") search_columns = [ "id", @@ -103,7 +98,7 @@ def list_datasets( "uuid", ] datasets, total_count = dao_wrapper.list( - filters=simple_filters, + column_operators=filters, order_column=order_column or "changed_on", order_direction=order_direction or "desc", page=max(page - 1, 0), @@ -125,7 +120,12 @@ def list_datasets( "id", "table_name", "db_schema", "database_name", "description", "changed_by_name", "changed_on", "created_by_name", "created_on" ] - dataset_items = [serialize_dataset_object(dataset) for dataset in datasets] + # Robustly handle both 'schema' and 'db_schema' for DatasetInfo + dataset_items = [] + for dataset in datasets: + item = serialize_dataset_object(dataset) + if item is not None: + dataset_items.append(item) total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0 pagination_info = PaginationInfo( page=page, @@ -135,7 +135,7 @@ def list_datasets( has_next=page < total_pages - 1, has_previous=page > 0 ) - response = DatasetListResponse( + response = DatasetList( datasets=dataset_items, count=len(dataset_items), total_count=total_count, @@ -146,7 +146,7 @@ def list_datasets( has_next=page < total_pages - 1, columns_requested=columns_to_load, columns_loaded=list(set([col for item in dataset_items for col in item.model_dump().keys()])), - filters_applied=simple_filters, + filters_applied=filters if isinstance(filters, list) else [], pagination=pagination_info, timestamp=datetime.now(timezone.utc) ) diff --git 
a/superset/mcp_service/tools/system/get_superset_instance_info.py b/superset/mcp_service/tools/system/get_superset_instance_info.py index 94cb8074ad0..c8afded97f3 100644 --- a/superset/mcp_service/tools/system/get_superset_instance_info.py +++ b/superset/mcp_service/tools/system/get_superset_instance_info.py @@ -9,15 +9,15 @@ from datetime import datetime, timedelta, timezone from superset.mcp_service.dao_wrapper import MCPDAOWrapper from superset.mcp_service.pydantic_schemas.system_schemas import ( DashboardBreakdown, DatabaseBreakdown, InstanceSummary, PopularContent, - RecentActivity, SupersetInstanceInfoResponse, ) + RecentActivity, InstanceInfo, ) logger = logging.getLogger(__name__) -def get_superset_instance_info() -> SupersetInstanceInfoResponse: +def get_superset_instance_info() -> InstanceInfo: """ Get high-level information about the Superset instance (direct DB query, not via REST API) Returns: - SupersetInstanceInfoResponse + InstanceInfo """ try: from superset.extensions import db @@ -42,7 +42,7 @@ def get_superset_instance_info() -> SupersetInstanceInfoResponse: from superset.daos.css import CssTemplateDAO from superset.daos.query import QueryDAO, SavedQueryDAO from superset.daos.datasource import DatasourceDAO - from superset.daos.base import BaseDAO + from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum # Instantiate MCPDAOWrappers dashboard_wrapper = MCPDAOWrapper(DashboardDAO, "dashboard") @@ -66,24 +66,24 @@ def get_superset_instance_info() -> SupersetInstanceInfoResponse: thirty_days_ago = now - timedelta(days=30) seven_days_ago = now - timedelta(days=7) - dashboards_created_last_30_days = dashboard_wrapper.count(filters={"created_on": thirty_days_ago}) - charts_created_last_30_days = chart_wrapper.count(filters={"created_on": thirty_days_ago}) - datasets_created_last_30_days = dataset_wrapper.count(filters={"created_on": thirty_days_ago}) + dashboards_created_last_30_days = 
dashboard_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) + charts_created_last_30_days = chart_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) + datasets_created_last_30_days = dataset_wrapper.count(column_operators=[ColumnOperator(col="created_on", opr=ColumnOperatorEnum.gte, value=thirty_days_ago)]) - dashboards_modified_last_7_days = dashboard_wrapper.count(filters={"changed_on": seven_days_ago}) - charts_modified_last_7_days = chart_wrapper.count(filters={"changed_on": seven_days_ago}) - datasets_modified_last_7_days = dataset_wrapper.count(filters={"changed_on": seven_days_ago}) + dashboards_modified_last_7_days = dashboard_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) + charts_modified_last_7_days = chart_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) + datasets_modified_last_7_days = dataset_wrapper.count(column_operators=[ColumnOperator(col="changed_on", opr=ColumnOperatorEnum.gte, value=seven_days_ago)]) # Dashboard breakdown - published_count = dashboard_wrapper.count(filters={"published": True}) + published_count = dashboard_wrapper.count(column_operators=[ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True)]) unpublished_dashboards = total_dashboards - published_count - certified_count = dashboard_wrapper.count(filters={"certified_by": "not_null"}) # Custom logic may be needed + certified_count = dashboard_wrapper.count(column_operators=[ColumnOperator(col="certified_by", opr=ColumnOperatorEnum.is_not_null, value=None)]) # Custom logic may be needed dashboards_with_charts = db.session.query(Dashboard).join(Dashboard.slices).distinct().count() # No direct DAO method dashboards_without_charts = total_dashboards - dashboards_with_charts avg_charts_per_dashboard 
= (total_charts / total_dashboards) if total_dashboards > 0 else 0 # Compose response using keyword arguments and nested models - response = SupersetInstanceInfoResponse( + response = InstanceInfo( instance_summary=InstanceSummary( total_dashboards=total_dashboards, total_charts=total_charts, diff --git a/tests/integration_tests/mcp_service/run_mcp_tests.py b/tests/integration_tests/mcp_service/run_mcp_tests.py index 4cc516d711c..222be989a41 100755 --- a/tests/integration_tests/mcp_service/run_mcp_tests.py +++ b/tests/integration_tests/mcp_service/run_mcp_tests.py @@ -49,28 +49,28 @@ def test_mcp_service_connection(): async def test_mcp_tools(client): """Test all available MCP tools""" logger.info("Testing MCP tools...") - + try: # Use the client within the async context manager async with client: # Test ping to verify connection await client.ping() logger.info("✅ Ping successful - MCP service is reachable") - + # List available tools tools = await client.list_tools() logger.info(f"✅ Found {len(tools)} available tools:") for tool in tools: - logger.info(f" - {tool.name}: {tool.description}") - + logger.info(f" - {tool.name}") + # Test get_dashboard_info tool logger.info("Testing get_dashboard_info tool...") try: # First, get a list of dashboards to find a valid dashboard ID - dashboards_result = await client.call_tool("list_dashboards_simple", {"page_size": 10}) - logger.info(f"list_dashboards_simple output (repr): {repr(dashboards_result.data)}") + dashboards_result = await client.call_tool("list_dashboards", {"page_size": 1}) + logger.info(f"list_dashboards output (repr): {repr(dashboards_result.data)}") if hasattr(dashboards_result.data, "model_dump"): - logger.info(f"list_dashboards_simple output (dict): {dashboards_result.data.model_dump()}") + logger.info(f"list_dashboards output (dict): {dashboards_result.data.model_dump()}") dashboards_data = dashboards_result.data if hasattr(dashboards_data, "model_dump"): dashboards_dict = dashboards_data.model_dump() @@ 
-116,30 +116,30 @@ async def test_mcp_tools(client): except Exception as e: logger.info(f"✅ get_dashboard_info correctly rejected invalid parameters: {e}") - # Test list_dashboards_simple tool - logger.info("Testing list_dashboards_simple tool...") + # Test list_dashboards tool + logger.info("Testing list_dashboards tool...") try: - result = await client.call_tool("list_dashboards_simple", {}) - logger.info(f"list_dashboards_simple output (repr): {repr(result.data)}") + result = await client.call_tool("list_dashboards", {}) + logger.info(f"list_dashboards output (repr): {repr(result.data)}") # Always convert to dict if possible data_dict = None if hasattr(result.data, "model_dump"): data_dict = result.data.model_dump() - logger.info(f"list_dashboards_simple output (dict): {data_dict}") + logger.info(f"list_dashboards output (dict): {data_dict}") elif isinstance(result.data, dict): data_dict = result.data - logger.info(f"list_dashboards_simple output (dict): {data_dict}") + logger.info(f"list_dashboards output (dict): {data_dict}") else: - logger.warning(f"list_dashboards_simple returned a non-dict, non-Pydantic type: {type(result.data)}. Skipping validation.") - logger.info("✅ list_dashboards_simple tool call successful (skipped validation)") + logger.warning(f"list_dashboards returned a non-dict, non-Pydantic type: {type(result.data)}. 
Skipping validation.") + logger.info("✅ list_dashboards tool call successful (skipped validation)") return True - logger.info("✅ list_dashboards_simple tool call successful") + logger.info("✅ list_dashboards tool call successful") # Validate response structure if data_dict is not None: expected_fields = ["dashboards", "count", "total_count"] missing_fields = [field for field in expected_fields if field not in data_dict] if missing_fields: - logger.error(f"❌ list_dashboards_simple missing expected fields: {missing_fields}") + logger.error(f"❌ list_dashboards missing expected fields: {missing_fields}") return False if not isinstance(data_dict["dashboards"], list): logger.error(f"❌ 'dashboards' should be list, got {type(data_dict['dashboards'])}") @@ -153,8 +153,8 @@ async def test_mcp_tools(client): logger.info(f"Found {len(data_dict['dashboards'])} dashboards") if len(data_dict["dashboards"]) > 0: dashboard = data_dict["dashboards"][0] - if hasattr(dashboard, "model_dump"): - dashboard = dashboard.model_dump() + # if hasattr(dashboard, "model_dump"): + # dashboard = dashboard.model_dump() if not isinstance(dashboard, dict): logger.error(f"❌ Dashboard should be dict, got {type(dashboard)}") return False @@ -165,7 +165,7 @@ async def test_mcp_tools(client): return False logger.info(f"✅ First dashboard validated: {dashboard.get('dashboard_title', 'N/A')}") except Exception as e: - logger.error(f"❌ list_dashboards_simple failed: {e}") + logger.error(f"❌ list_dashboards failed: {e}") return False # Test list_dashboards tool @@ -300,12 +300,12 @@ async def test_mcp_tools(client): if hasattr(dataset, "model_dump"): dataset = dataset.model_dump() if not isinstance(dataset, dict): - logger.error(f"❌ Dataset should be dict, got {type(dataset)}") + logger.error(f"❌ DatasetInfo should be dict, got {type(dataset)}") return False required_fields = ["id", "table_name"] missing_fields = [field for field in required_fields if field not in dataset] if missing_fields: - 
logger.error(f"❌ Dataset missing required fields: {missing_fields}") + logger.error(f"❌ DatasetInfo missing required fields: {missing_fields}") return False logger.info(f"✅ First dataset validated: {dataset.get('table_name', 'N/A')}") except Exception as e: diff --git a/tests/integration_tests/mcp_service/test_get_chart_list_tools.py b/tests/integration_tests/mcp_service/test_get_chart_list_tools.py index 53a22619194..4ebb3014d92 100644 --- a/tests/integration_tests/mcp_service/test_get_chart_list_tools.py +++ b/tests/integration_tests/mcp_service/test_get_chart_list_tools.py @@ -1,7 +1,6 @@ -import logging -import sys -import traceback import json +import logging +import traceback logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) @@ -39,16 +38,9 @@ async def test_tool(client, tool_name, payload, label, issues): async def main(): from fastmcp import Client - logger.info("Starting integration test for list_charts and list_charts_simple tools") + logger.info("Starting integration test for list_charts and list_charts tools") issues = [] async with Client("http://localhost:5008/mcp") as client: - # Test list_charts_simple with default params - await test_tool(client, "list_charts_simple", {}, "(default)", issues) - # Test list_charts_simple with a filter - await test_tool(client, "list_charts_simple", {"filters": {"viz_type": "bar"}}, "(viz_type=bar)", issues) - # Test list_charts_simple with pagination - await test_tool(client, "list_charts_simple", {"page": 1, "page_size": 2}, "(page=1, page_size=2)", issues) - # Test list_charts (advanced) with default params await test_tool(client, "list_charts", {}, "(default)", issues) # Test list_charts with a filter (slice_name sw 'ab') @@ -69,8 +61,8 @@ async def main(): for tool_name, label, msg in issues: logger.warning(f" {tool_name} {label}: {msg}") else: - logger.info("All list_charts and list_charts_simple calls returned successfully with no 
errors or warnings.") + logger.info("All list_charts calls returned successfully with no errors or warnings.") if __name__ == "__main__": import asyncio - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/integration_tests/mcp_service/test_get_dataset_list_tools.py b/tests/integration_tests/mcp_service/test_get_dataset_list_tools.py index 0cd3a092480..f502000cb27 100644 --- a/tests/integration_tests/mcp_service/test_get_dataset_list_tools.py +++ b/tests/integration_tests/mcp_service/test_get_dataset_list_tools.py @@ -39,15 +39,15 @@ async def test_tool(client, tool_name, payload, label, issues): async def main(): from fastmcp import Client - logger.info("Starting integration test for list_datasets and list_datasets_simple tools") + logger.info("Starting integration test for the list_datasets tool") issues = [] async with Client("http://localhost:5008/mcp") as client: - # Test list_datasets_simple with default params - await test_tool(client, "list_datasets_simple", {}, "(default)", issues) - # Test list_datasets_simple with a filter - await test_tool(client, "list_datasets_simple", {"filters": {"schema": "public"}}, "(schema=public)", issues) - # Test list_datasets_simple with pagination - await test_tool(client, "list_datasets_simple", {"page": 1, "page_size": 2}, "(page=1, page_size=2)", issues) + # Test list_datasets with default params + await test_tool(client, "list_datasets", {}, "(default)", issues) + # Test list_datasets with a filter + await test_tool(client, "list_datasets", {"filters": {"schema": "public"}}, "(schema=public)", issues) + # Test list_datasets with pagination + await test_tool(client, "list_datasets", {"page": 1, "page_size": 2}, "(page=1, page_size=2)", issues) # Test list_datasets (advanced) with default params await test_tool(client, "list_datasets", {}, "(default)", issues) @@ -66,8 +66,8 @@ async def main():
{tool_name} {label}: {msg}") else: - logger.info("All list_datasets and list_datasets_simple calls returned successfully with no errors or warnings.") + logger.info("All list_datasets calls returned successfully with no errors or warnings.") if __name__ == "__main__": import asyncio - asyncio.run(main()) \ No newline at end of file + asyncio.run(main()) diff --git a/tests/unit_tests/mcp_service/test_chart_tools.py b/tests/unit_tests/mcp_service/test_chart_tools.py new file mode 100644 index 00000000000..5df10cccb1f --- /dev/null +++ b/tests/unit_tests/mcp_service/test_chart_tools.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+ +""" +Unit tests for MCP chart tools (list_charts, get_chart_info, get_chart_available_filters, create_chart_simple) +""" +import logging +from unittest.mock import Mock, patch +import pytest +from superset.mcp_service.pydantic_schemas.chart_schemas import ( + ChartInfo, ChartError, ChartAvailableFiltersResponse +) +from superset.mcp_service.tools.chart import ( + list_charts, get_chart_info, get_chart_available_filters, create_chart_simple +) +from superset.mcp_service.pydantic_schemas.chart_schemas import CreateSimpleChartRequest + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class TestChartTools: + """Test chart-related MCP tools""" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_charts_basic(self, mock_list): + chart = Mock() + chart.id = 1 + chart.slice_name = "Test Chart" + chart.viz_type = "bar" + chart.datasource_name = "test_ds" + chart.datasource_type = "table" + chart.url = "/chart/1" + chart.description = "desc" + chart.cache_timeout = 60 + chart.form_data = {} + chart.query_context = {} + chart.changed_by_name = "admin" + chart.changed_on = None + chart.changed_on_humanized = "1 day ago" + chart.created_by_name = "admin" + chart.created_on = None + chart.created_on_humanized = "2 days ago" + chart.tags = [] + chart.owners = [] + mock_list.return_value = ([chart], 1) + result = list_charts() + assert result.count == 1 + assert result.charts[0].slice_name == "Test Chart" + assert result.charts[0].viz_type == "bar" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_charts_with_search(self, mock_list): + chart = Mock() + chart.id = 1 + chart.slice_name = "search_chart" + chart.viz_type = "bar" + chart.datasource_name = "test_ds" + chart.datasource_type = "table" + chart.url = "/chart/1" + chart.description = "desc" + chart.cache_timeout = 60 + chart.form_data = {} + chart.query_context = {} + chart.changed_by_name = "admin" + chart.changed_on = None + 
chart.changed_on_humanized = "1 day ago" + chart.created_by_name = "admin" + chart.created_on = None + chart.created_on_humanized = "2 days ago" + chart.tags = [] + chart.owners = [] + mock_list.return_value = ([chart], 1) + result = list_charts(search="search_chart") + assert result.count == 1 + assert result.charts[0].slice_name == "search_chart" + args, kwargs = mock_list.call_args + assert kwargs["search"] == "search_chart" + assert "slice_name" in kwargs["search_columns"] + assert "viz_type" in kwargs["search_columns"] + assert "datasource_name" in kwargs["search_columns"] + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_charts_with_filters(self, mock_list): + mock_list.return_value = ([], 0) + filters = [ + {"col": "slice_name", "opr": "sw", "value": "Sales"}, + {"col": "viz_type", "opr": "eq", "value": "bar"} + ] + result = list_charts( + filters=filters, + select_columns=["id", "slice_name"], + order_column="changed_on", + order_direction="desc", + page=1, + page_size=50 + ) + assert result.count == 0 + assert result.charts == [] + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_charts_api_error(self, mock_list): + mock_list.side_effect = Exception("API request failed") + with pytest.raises(Exception) as excinfo: + list_charts() + assert "API request failed" in str(excinfo.value) + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_chart_info_success(self, mock_info): + chart = Mock() + chart.id = 1 + chart.slice_name = "Test Chart" + chart.viz_type = "bar" + chart.datasource_name = "test_ds" + chart.datasource_type = "table" + chart.url = "/chart/1" + chart.description = "desc" + chart.cache_timeout = 60 + chart.form_data = {} + chart.query_context = {} + chart.changed_by_name = "admin" + chart.changed_on = None + chart.changed_on_humanized = "1 day ago" + chart.created_by_name = "admin" + chart.created_on = None + chart.created_on_humanized = "2 days ago" + 
chart.tags = [] + chart.owners = [] + mock_info.return_value = (chart, None, None) + result = get_chart_info(1) + assert isinstance(result, ChartInfo) + assert result.id == 1 + assert result.slice_name == "Test Chart" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_chart_info_not_found(self, mock_info): + mock_info.return_value = (None, "not_found", "Chart not found") + result = get_chart_info(999) + assert isinstance(result, ChartError) + assert result.error == "Chart not found" + assert result.error_type == "not_found" + + def test_get_chart_available_filters_success(self): + result = get_chart_available_filters() + assert isinstance(result, ChartAvailableFiltersResponse) + assert "slice_name" in result.filters + assert "eq" in result.operators + assert "slice_name" in result.columns or "id" in result.columns + + def test_get_chart_available_filters_exception_handling(self): + result = get_chart_available_filters() + assert isinstance(result, ChartAvailableFiltersResponse) + assert hasattr(result, "filters") + assert hasattr(result, "operators") + assert hasattr(result, "columns") + + @patch('superset.commands.chart.create.CreateChartCommand.run') + def test_create_chart_simple_success(self, mock_run): + chart = Mock() + chart.id = 42 + chart.slice_name = "Created Chart" + chart.viz_type = "bar" + chart.datasource_name = "test_ds" + chart.datasource_type = "table" + chart.url = "/chart/42" + chart.description = "desc" + chart.cache_timeout = 60 + chart.form_data = {} + chart.query_context = {} + chart.changed_by_name = "admin" + chart.changed_on = None + chart.changed_on_humanized = "1 day ago" + chart.created_by_name = "admin" + chart.created_on = None + chart.created_on_humanized = "2 days ago" + chart.tags = [] + chart.owners = [] + mock_run.return_value = chart + req = CreateSimpleChartRequest( + slice_name="Created Chart", + viz_type="bar", + datasource_id=1, + metrics=["sum__sales"], + dimensions=["region"], + 
filters=[{"col": "year", "opr": "eq", "value": 2024}], + description="A chart created by test", + return_embed=True + ) + result = create_chart_simple(request=req) + assert result.chart is not None + assert result.chart.slice_name == "Created Chart" + assert result.embed_url is not None + assert result.thumbnail_url is not None + assert result.embed_html is not None + + @patch('superset.commands.chart.create.CreateChartCommand.run') + def test_create_chart_simple_error(self, mock_run): + mock_run.side_effect = Exception("Chart creation failed") + req = CreateSimpleChartRequest( + slice_name="Fail Chart", + viz_type="bar", + datasource_id=1, + metrics=["sum__sales"], + dimensions=["region"], + filters=[{"col": "year", "opr": "eq", "value": 2024}], + description="A chart that fails", + return_embed=False + ) + result = create_chart_simple(request=req) + assert result.error is not None + assert "Chart creation failed" in result.error \ No newline at end of file diff --git a/tests/unit_tests/mcp_service/test_dashboard_tools.py b/tests/unit_tests/mcp_service/test_dashboard_tools.py new file mode 100644 index 00000000000..04a75a4e1ec --- /dev/null +++ b/tests/unit_tests/mcp_service/test_dashboard_tools.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for MCP dashboard tools (list_dashboards, get_dashboard_info, get_dashboard_available_filters) +""" +import logging +from unittest.mock import Mock, patch +import pytest +from superset.mcp_service.pydantic_schemas.dashboard_schemas import ( + DashboardAvailableFilters, DashboardError, DashboardInfo, DashboardList, +) +from superset.mcp_service.tools.dashboard import ( + get_dashboard_available_filters, get_dashboard_info, list_dashboards, +) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class TestDashboardTools: + """Test dashboard-related MCP tools""" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_basic(self, mock_list): + dashboard = Mock() + dashboard.id = 1 + dashboard.dashboard_title = "Test Dashboard" + dashboard.slug = "test-dashboard" + dashboard.url = "/dashboard/1" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + mock_list.return_value = ([dashboard], 1) + + result = list_dashboards() + assert result.count == 1 + assert result.total_count == 1 + assert result.dashboards[0].dashboard_title == "Test Dashboard" + assert result.dashboards[0].published is True + assert result.dashboards[0].changed_by == "admin" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_with_filters(self, mock_list): + dashboard = Mock() + dashboard.id = 1 + dashboard.dashboard_title = "Filtered Dashboard" + dashboard.slug = "filtered-dashboard" + dashboard.url = "/dashboard/2" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + 
dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + mock_list.return_value = ([dashboard], 1) + filters = [ + {"col": "dashboard_title", "opr": "sw", "value": "Sales"}, + {"col": "published", "opr": "eq", "value": True} + ] + result = list_dashboards( + filters=filters, + select_columns=["id", "dashboard_title"], + order_column="changed_on", + order_direction="desc", + page=1, + page_size=50 + ) + assert result.count == 1 + assert result.dashboards[0].dashboard_title == "Filtered Dashboard" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_with_string_filters(self, mock_list): + dashboard = Mock() + dashboard.id = 1 + dashboard.dashboard_title = "String Filter Dashboard" + dashboard.slug = "string-filter-dashboard" + dashboard.url = "/dashboard/3" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + mock_list.return_value = ([dashboard], 1) + filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]' + result = list_dashboards(filters=filters) + assert result.count == 1 + assert result.dashboards[0].dashboard_title == "String Filter Dashboard" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_api_error(self, mock_list): + mock_list.side_effect = Exception("API request failed") + with pytest.raises(Exception) as excinfo: + list_dashboards() + assert "API request failed" in str(excinfo.value) + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_with_search(self, mock_list): + dashboard = Mock() + dashboard.id 
= 1 + dashboard.dashboard_title = "search_dashboard" + dashboard.slug = "search-dashboard" + dashboard.url = "/dashboard/1" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + mock_list.return_value = ([dashboard], 1) + result = list_dashboards(search="search_dashboard") + assert result.count == 1 + assert result.dashboards[0].dashboard_title == "search_dashboard" + args, kwargs = mock_list.call_args + assert kwargs["search"] == "search_dashboard" + assert "dashboard_title" in kwargs["search_columns"] + assert "slug" in kwargs["search_columns"] + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_with_simple_filters(self, mock_list): + mock_list.return_value = ([], 0) + filters = [{"col": "dashboard_title", "opr": "eq", "value": "Sales"}, {"col": "published", "opr": "eq", "value": True}] + result = list_dashboards(filters=filters) + assert hasattr(result, 'count') + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_dashboard_info_success(self, mock_info): + dashboard = Mock() + dashboard.id = 1 + dashboard.dashboard_title = "Test Dashboard" + dashboard.slug = "test-dashboard" + dashboard.description = "Test description" + dashboard.css = None + dashboard.certified_by = None + dashboard.certification_details = None + dashboard.json_metadata = None + dashboard.position_json = None + dashboard.published = True + dashboard.is_managed_externally = False + dashboard.external_url = None + dashboard.created_on = None + dashboard.changed_on = None + dashboard.created_by = None + dashboard.changed_by = None + dashboard.uuid = None + dashboard.url = "/dashboard/1" + dashboard.thumbnail_url = None + dashboard.created_on_humanized = None + 
dashboard.changed_on_humanized = None + dashboard.slices = [] + dashboard.owners = [] + dashboard.tags = [] + dashboard.roles = [] + mock_info.return_value = (dashboard, None, None) + result = get_dashboard_info(1) + assert isinstance(result, DashboardInfo) + assert result.id == 1 + assert result.dashboard_title == "Test Dashboard" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_dashboard_info_not_found(self, mock_info): + mock_info.return_value = (None, "not_found", "Dashboard not found") + result = get_dashboard_info(999) + assert isinstance(result, DashboardError) + assert result.error == "Dashboard not found" + assert result.error_type == "not_found" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_dashboard_info_access_denied(self, mock_info): + mock_info.return_value = (None, "access_denied", "Access denied") + result = get_dashboard_info(1) + assert isinstance(result, DashboardError) + assert result.error == "Access denied" + assert result.error_type == "access_denied" + + def test_get_dashboard_available_filters_success(self): + result = get_dashboard_available_filters() + assert isinstance(result, DashboardAvailableFilters) + assert "dashboard_title" in result.filters + assert "eq" in result.operators + assert "dashboard_title" in result.columns or "id" in result.columns + + def test_get_dashboard_available_filters_exception_handling(self): + result = get_dashboard_available_filters() + assert isinstance(result, DashboardAvailableFilters) + assert hasattr(result, "filters") + assert hasattr(result, "operators") + assert hasattr(result, "columns") \ No newline at end of file diff --git a/tests/unit_tests/mcp_service/test_dataset_tools.py b/tests/unit_tests/mcp_service/test_dataset_tools.py new file mode 100644 index 00000000000..e6bfd56b367 --- /dev/null +++ b/tests/unit_tests/mcp_service/test_dataset_tools.py @@ -0,0 +1,367 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# 
or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for MCP dataset tools (list_datasets, get_dataset_info, get_dataset_available_filters) +""" +import logging +from unittest.mock import Mock, patch +import pytest +from superset.mcp_service.pydantic_schemas.dataset_schemas import ( + DatasetAvailableFilters, DatasetList, DatasetError, DatasetInfo +) +from superset.mcp_service.tools.dataset import ( + get_dataset_available_filters, get_dataset_info, list_datasets, +) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class TestDatasetTools: + """Test dataset-related MCP tools""" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_basic(self, mock_list): + dataset = Mock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = Mock() + 
dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + + result = list_datasets() + assert result.count == 1 + assert result.total_count == 1 + assert result.datasets[0].table_name == "Test DatasetInfo" + assert result.datasets[0].database_name == "examples" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_with_filters(self, mock_list): + dataset = Mock() + dataset.id = 2 + dataset.table_name = "Filtered Dataset" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/2" + dataset.database = Mock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + filters = [ + {"col": "table_name", "opr": "sw", "value": "Sales"}, + {"col": "schema", "opr": "eq", "value": "main"} + ] + result = list_datasets( + filters=filters, + select_columns=["id", "table_name"], + order_column="changed_on", + order_direction="desc", + page=1, + page_size=50 + ) + assert result.count == 1 + assert result.datasets[0].table_name == "Filtered Dataset" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_with_string_filters(self, mock_list): + dataset = Mock() + dataset.id = 3 + dataset.table_name = "String Filter Dataset" + 
dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/3" + dataset.database = Mock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + filters = '[{"col": "table_name", "opr": "sw", "value": "Sales"}]' + result = list_datasets(filters=filters) + assert result.count == 1 + assert result.datasets[0].table_name == "String Filter Dataset" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_api_error(self, mock_list): + mock_list.side_effect = Exception("API request failed") + with pytest.raises(Exception) as excinfo: + list_datasets() + assert "API request failed" in str(excinfo.value) + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_with_search(self, mock_list): + dataset = Mock() + dataset.id = 1 + dataset.table_name = "search_table" + dataset.db_schema = "public" + dataset.database_name = "test_db" + dataset.database = None + dataset.description = "A test dataset" + dataset.changed_by = "admin" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = "admin" + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = None + dataset.url = None + dataset.sql = None + 
dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + result = list_datasets(search="search_table") + assert result.count == 1 + assert result.datasets[0].table_name == "search_table" + args, kwargs = mock_list.call_args + assert kwargs["search"] == "search_table" + assert "table_name" in kwargs["search_columns"] + assert "schema" in kwargs["search_columns"] + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_simple_with_search(self, mock_list): + dataset = Mock() + dataset.id = 2 + dataset.table_name = "simple_search" + dataset.db_schema = "analytics" + dataset.database_name = "analytics_db" + dataset.database = None + dataset.description = "Another test dataset" + dataset.changed_by = "user" + dataset.changed_by_name = "user" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = "user" + dataset.created_by_name = "user" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = True + dataset.database_id = 2 + dataset.schema_perm = None + dataset.url = None + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + result = list_datasets(search="simple_search") + assert result.count == 1 + assert result.datasets[0].table_name == "simple_search" + args, kwargs = mock_list.call_args + assert kwargs["search"] == "simple_search" + assert "table_name" in kwargs["search_columns"] + assert "schema" in kwargs["search_columns"] + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_simple_basic(self, mock_list): + dataset = Mock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + 
dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = Mock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_list.return_value = ([dataset], 1) + filters = [{"col": "table_name", "opr": "eq", "value": "Test DatasetInfo"}, {"col": "schema", "opr": "eq", "value": "main"}] + result = list_datasets(filters=filters) + assert isinstance(result, DatasetList) + assert result.count == 1 + assert result.datasets[0].table_name == "Test DatasetInfo" + assert result.datasets[0].database_name == "examples" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_simple_with_filters(self, mock_list): + dataset = Mock() + dataset.id = 2 + dataset.table_name = "Sales Dataset" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/2" + dataset.database = Mock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + 
mock_list.return_value = ([dataset], 1) + filters = [{"col": "table_name", "opr": "sw", "value": "Sales"}, {"col": "schema", "opr": "eq", "value": "main"}] + result = list_datasets(filters=filters) + assert isinstance(result, DatasetList) + assert result.count == 1 + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_datasets_simple_api_error(self, mock_list): + mock_list.side_effect = Exception("API request failed") + filters = [{"col": "table_name", "opr": "sw", "value": "Sales"}, {"col": "schema", "opr": "eq", "value": "main"}] + with pytest.raises(Exception) as excinfo: + list_datasets(filters=filters) + assert "API request failed" in str(excinfo.value) + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_dataset_info_success(self, mock_info): + dataset = Mock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = Mock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + mock_info.return_value = (dataset, None, None) + result = get_dataset_info(1) + assert isinstance(result, DatasetInfo) + assert result.id == 1 + assert result.table_name == "Test DatasetInfo" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') + def test_get_dataset_info_not_found(self, mock_info): + mock_info.return_value = (None, "not_found", "Dataset not found") + result = 
get_dataset_info(999) + assert isinstance(result, DatasetError) + assert result.error == "Dataset not found" + assert result.error_type == "not_found" + + def test_get_dataset_available_filters_success(self): + result = get_dataset_available_filters() + assert hasattr(result, "filters") + assert hasattr(result, "operators") + assert hasattr(result, "columns") + + def test_get_dataset_available_filters_exception_handling(self): + result = get_dataset_available_filters() + assert isinstance(result, DatasetAvailableFilters) + assert hasattr(result, "filters") + assert hasattr(result, "operators") + assert hasattr(result, "columns") \ No newline at end of file diff --git a/tests/unit_tests/mcp_service/test_error_handling.py b/tests/unit_tests/mcp_service/test_error_handling.py new file mode 100644 index 00000000000..11062f1fbb6 --- /dev/null +++ b/tests/unit_tests/mcp_service/test_error_handling.py @@ -0,0 +1,78 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Unit tests for MCP tool error handling and parameter validation +""" +import logging +from unittest.mock import patch +import pytest +from superset.mcp_service.pydantic_schemas.dashboard_schemas import DashboardAvailableFilters, DashboardList +from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetAvailableFilters, DatasetList +from superset.mcp_service.tools.dashboard import get_dashboard_available_filters, list_dashboards +from superset.mcp_service.tools.dataset import get_dataset_available_filters, list_datasets +from fastmcp.exceptions import ToolError + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + +class TestErrorHandling: + """Test error handling and parameter validation in MCP tools""" + + @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') + def test_list_dashboards_exception_handling(self, mock_list): + mock_list.side_effect = Exception("Unexpected error") + with pytest.raises(Exception) as excinfo: + list_dashboards() + assert "Unexpected error" in str(excinfo.value) + + def test_get_dashboard_available_filters_exception_handling(self): + result = get_dashboard_available_filters() + assert isinstance(result, DashboardAvailableFilters) + assert hasattr(result, "filters") + assert hasattr(result, "operators") + assert hasattr(result, "columns") + + def test_list_datasets_exception_handling(self): + result = list_datasets() + assert isinstance(result, (dict, DatasetList)) + if isinstance(result, dict): + assert "count" in result + assert "datasets" in result + else: + assert hasattr(result, "count") + assert hasattr(result, "datasets") + + def test_list_dashboards_parameter_types(self): + with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list: + mock_list.return_value = ([], 0) + list_dashboards(filters='[{"col": "test", "opr": "eq", "value": "value"}]') + list_dashboards(filters=[{"col": "test", "opr": "eq", "value": "value"}]) + 
list_dashboards(select_columns="id,dashboard_title") + list_dashboards(select_columns=["id", "dashboard_title"]) + assert mock_list.call_count == 4 + + def test_list_datasets_parameter_types(self): + with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list: + mock_list.return_value = ([], 0) + list_datasets(filters='[{"col": "test", "opr": "eq", "value": "value"}]') + list_datasets(filters=[{"col": "test", "opr": "eq", "value": "value"}]) + list_datasets(select_columns="id,table_name") + list_datasets(select_columns=["id", "table_name"]) + assert mock_list.call_count == 4 + + # Example: test for missing required param, extra param, and malformed input would be in protocol/integration tests \ No newline at end of file diff --git a/tests/unit_tests/mcp_service/test_fastmcp_tools.py b/tests/unit_tests/mcp_service/test_fastmcp_tools.py deleted file mode 100644 index 26bb403189e..00000000000 --- a/tests/unit_tests/mcp_service/test_fastmcp_tools.py +++ /dev/null @@ -1,827 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
- -""" -Unit tests for FastMCP server tools - -This module tests all FastMCP tools in the MCP service: -- Dashboard tools: list_dashboards, list_dashboards_simple, get_dashboard_info -- System tools: get_superset_instance_info, get_dashboard_available_filters -""" - -import logging -from unittest.mock import Mock, patch - -import pytest -from fastmcp import FastMCP -from fastmcp.client.client import CallToolResult -from fastmcp.exceptions import ToolError -from flask import Flask, g -from flask_login import AnonymousUserMixin -from superset.mcp_service.pydantic_schemas.dashboard_schemas import ( - DashboardAvailableFiltersResponse, DashboardErrorResponse, DashboardInfoResponse, DashboardListResponse, - DashboardSimpleFilters, ) -from superset.mcp_service.pydantic_schemas.dataset_schemas import ( - DatasetAvailableFiltersResponse, DatasetListResponse, DatasetSimpleFilters, ) -from superset.mcp_service.pydantic_schemas.system_schemas import (InstanceSummary, SupersetInstanceInfoResponse) -from superset.mcp_service.tools import get_dataset_available_filters -# Import the original functions before they get decorated -from superset.mcp_service.tools.dashboard import ( - get_dashboard_available_filters, get_dashboard_info, list_dashboards -) -from superset.mcp_service.tools.dataset import list_datasets -from superset.mcp_service.tools.system import get_superset_instance_info - -# Configure logging for tests -logging.basicConfig(level=logging.DEBUG) -logger = logging.getLogger(__name__) - - -class TestDashboardTools: - """Test dashboard-related FastMCP tools""" - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_basic(self, mock_list): - """Test list_dashboards with basic parameters""" - # Mock dashboard object - dashboard = Mock() - dashboard.id = 1 - dashboard.dashboard_title = "Test Dashboard" - dashboard.slug = "test-dashboard" - dashboard.url = "/dashboard/1" - dashboard.published = True - dashboard.changed_by_name = "admin" - 
dashboard.changed_on = None - dashboard.changed_on_humanized = None - dashboard.created_by_name = "admin" - dashboard.created_on = None - dashboard.created_on_humanized = None - dashboard.tags = [] - dashboard.owners = [] - mock_list.return_value = ([dashboard], 1) - - result = list_dashboards() - assert result.count == 1 - assert result.total_count == 1 - assert result.dashboards[0].dashboard_title == "Test Dashboard" - assert result.dashboards[0].published is True - assert result.dashboards[0].changed_by == "admin" - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_with_filters(self, mock_list): - """Test list_dashboards with complex filters""" - mock_list.return_value = ([], 0) - filters = [ - {"col": "dashboard_title", "opr": "sw", "value": "Sales"}, - {"col": "published", "opr": "eq", "value": True} - ] - result = list_dashboards( - filters=filters, - select_columns=["id", "dashboard_title"], - order_column="changed_on", - order_direction="desc", - page=1, - page_size=50 - ) - assert result.count == 0 - assert result.total_count == 0 - assert result.dashboards == [] - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_with_string_filters(self, mock_list): - """Test list_dashboards with string filter input (should fail or be ignored)""" - mock_list.return_value = ([], 0) - # Remove this test or update to expect failure, as only advanced filter lists are supported - filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]' - with pytest.raises(Exception): - list_dashboards(filters=filters) - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_api_error(self, mock_list): - """Test list_dashboards with API error""" - mock_list.side_effect = Exception("API request failed") - with pytest.raises(Exception) as excinfo: - list_dashboards() - assert "API request failed" in str(excinfo.value) - - 
@patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_with_search(self, mock_list): - """Test list_dashboards with a text search parameter""" - dashboard = Mock() - dashboard.id = 1 - dashboard.dashboard_title = "search_dashboard" - dashboard.slug = "search-dashboard" - dashboard.url = "/dashboard/1" - dashboard.published = True - dashboard.changed_by_name = "admin" - dashboard.changed_on = None - dashboard.changed_on_humanized = None - dashboard.created_by_name = "admin" - dashboard.created_on = None - dashboard.created_on_humanized = None - dashboard.tags = [] - dashboard.owners = [] - mock_list.return_value = ([dashboard], 1) - result = list_dashboards(search="search_dashboard") - assert result.count == 1 - assert result.dashboards[0].dashboard_title == "search_dashboard" - # Ensure search and search_columns were passed - args, kwargs = mock_list.call_args - assert kwargs["search"] == "search_dashboard" - assert "dashboard_title" in kwargs["search_columns"] - assert "slug" in kwargs["search_columns"] - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') - def test_list_dashboards_with_simple_filters(self, mock_list): - """Test list_dashboards with simple filter parameters (previously tested in list_dashboards_simple)""" - mock_list.return_value = ([], 0) - filters = [{"col": "dashboard_title", "opr": "eq", "value": "Sales"}, {"col": "published", "opr": "eq", "value": True}] - result = list_dashboards(filters=filters) - assert hasattr(result, 'count') - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') - def test_get_dashboard_info_success(self, mock_info): - """Test get_dashboard_info with successful response""" - dashboard = Mock() - dashboard.id = 1 - dashboard.dashboard_title = "Test Dashboard" - dashboard.slug = "test-dashboard" - dashboard.description = "Test description" - dashboard.css = None - dashboard.certified_by = None - dashboard.certification_details = None - dashboard.json_metadata = 
None - dashboard.position_json = None - dashboard.published = True - dashboard.is_managed_externally = False - dashboard.external_url = None - dashboard.created_on = None - dashboard.changed_on = None - dashboard.created_by = None - dashboard.changed_by = None - dashboard.uuid = None - dashboard.url = "/dashboard/1" - dashboard.thumbnail_url = None - dashboard.created_on_humanized = None - dashboard.changed_on_humanized = None - dashboard.slices = [] - dashboard.owners = [] - dashboard.tags = [] - dashboard.roles = [] - mock_info.return_value = (dashboard, None, None) - result = get_dashboard_info(1) - assert isinstance(result, DashboardInfoResponse) - assert result.id == 1 - assert result.dashboard_title == "Test Dashboard" - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') - def test_get_dashboard_info_not_found(self, mock_info): - """Test get_dashboard_info with 404 error""" - mock_info.return_value = (None, "not_found", "Dashboard not found") - result = get_dashboard_info(999) - assert isinstance(result, DashboardErrorResponse) - assert result.error == "Dashboard not found" - assert result.error_type == "not_found" - - @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.info') - def test_get_dashboard_info_access_denied(self, mock_info): - """Test get_dashboard_info with 403 error""" - mock_info.return_value = (None, "access_denied", "Access denied") - result = get_dashboard_info(1) - assert isinstance(result, DashboardErrorResponse) - assert result.error == "Access denied" - assert result.error_type == "access_denied" - - -class TestSystemTools: - """Test system-related FastMCP tools""" - - @patch('superset.extensions.db') - def test_get_superset_instance_info_success(self, mock_db): - """Test get_superset_instance_info with successful response""" - mock_app = Mock() - mock_app.app_context.return_value.__enter__ = Mock() - mock_app.app_context.return_value.__exit__ = Mock() - mock_session = Mock() - mock_db.session = mock_session - # Patch 
class TestSystemTools:
    """Tests for system-level MCP tools and the static *_available_filters tools.

    Fix: both instance-info tests built a `mock_app` Mock (with stubbed
    `app_context.__enter__`/`__exit__`) that was never used — the tests
    actually run inside a real Flask app context. The dead mock has been
    removed.
    """

    @patch('superset.extensions.db')
    def test_get_superset_instance_info_success(self, mock_db):
        """Aggregated instance statistics are assembled from DAO counts."""
        mock_session = Mock()
        mock_db.session = mock_session
        # dashboards_with_charts query chain -> 5; query(Role).count() -> 10.
        mock_session.query.return_value.join.return_value.distinct.return_value.count.return_value = 5
        mock_session.query.return_value.count.return_value = 10
        # One DAO count per statistic, consumed in this order by the tool.
        dao_counts = [
            10,  # total_dashboards
            10,  # total_charts
            10,  # total_datasets
            10,  # total_databases
            10,  # total_users
            10,  # total_tags
            2,   # recent_dashboards
            2,   # recent_charts
            2,   # recent_datasets
            2,   # recently_modified_dashboards
            2,   # recently_modified_charts
            2,   # recently_modified_datasets
            5,   # published_dashboards
            3,   # certified_dashboards
        ]
        app = Flask(__name__)
        app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        with app.app_context():
            g.user = AnonymousUserMixin()
            with patch(
                'superset.mcp_service.tools.system.get_superset_instance_info.MCPDAOWrapper.count',
                side_effect=dao_counts,
            ):
                result = get_superset_instance_info()
            del g.user
        assert isinstance(result, SupersetInstanceInfoResponse)
        assert isinstance(result.instance_summary, InstanceSummary)
        assert result.instance_summary.total_dashboards == 10
        assert result.instance_summary.total_charts == 10
        assert result.instance_summary.total_datasets == 10
        assert result.instance_summary.total_databases == 10
        assert result.instance_summary.total_users == 10
        assert result.instance_summary.total_tags == 10
        assert result.instance_summary.avg_charts_per_dashboard == 1.0

    @patch('superset.extensions.db')
    def test_get_superset_instance_info_failure(self, mock_db):
        """Database errors surface to the caller instead of being swallowed."""
        mock_session = Mock()
        mock_db.session = mock_session
        mock_session.query.side_effect = Exception("Database connection failed")
        app = Flask(__name__)
        app.config['SQLALCHEMY_TRACK_MODIFICATIONS'] = False
        with app.app_context():
            g.user = AnonymousUserMixin()
            with pytest.raises(Exception) as excinfo:
                get_superset_instance_info()
            assert "Database connection failed" in str(excinfo.value)

    def test_get_dashboard_available_filters_success(self):
        """The static dashboard filter metadata contains the expected entries."""
        result = get_dashboard_available_filters()
        assert isinstance(result, DashboardAvailableFiltersResponse)
        assert "dashboard_title" in result.filters
        assert "eq" in result.operators
        assert "dashboard_title" in result.columns or "id" in result.columns

    def test_get_dashboard_available_filters_exception_handling(self):
        """The tool makes no API calls; it must always return a valid structure.

        NOTE(review): an identical test also lives in TestErrorHandling —
        one of the two copies should eventually be dropped.
        """
        result = get_dashboard_available_filters()
        assert isinstance(result, DashboardAvailableFiltersResponse)
        assert hasattr(result, "filters")
        assert hasattr(result, "operators")
        assert hasattr(result, "columns")

    def test_get_dataset_available_filters_success(self):
        """The static dataset filter metadata has the expected shape."""
        from superset.mcp_service.tools.dataset.get_dataset_available_filters import (
            get_dataset_available_filters,
        )
        result = get_dataset_available_filters()
        assert hasattr(result, "filters")
        assert hasattr(result, "operators")
        assert hasattr(result, "columns")

    def test_get_dataset_available_filters_exception_handling(self):
        """No API calls are made; the response structure is always valid."""
        result = get_dataset_available_filters()
        assert isinstance(result, DatasetAvailableFiltersResponse)
        assert hasattr(result, "filters")
        assert hasattr(result, "operators")
        assert hasattr(result, "columns")
class TestDatasetTools:
    """Tests for the dataset FastMCP tools (list_datasets).

    Fix: the nine tests each repeated a near-identical 17-line Mock setup;
    the shared `_dataset_row` factory removes that duplication. The
    `_simple_` test names are retained for history — the list_datasets_simple
    tool they exercised has been consolidated into list_datasets.
    """

    @staticmethod
    def _dataset_row(**overrides):
        """Fake ORM row with the attributes the dataset serializer reads.

        Defaults mirror a physical table in the `examples` database; keyword
        overrides let each test tweak only the fields it cares about
        (including `database=None` for rows without a database relation).
        """
        attrs = {
            "id": 1,
            "table_name": "Test Dataset",
            "schema": "main",
            "description": "desc",
            "changed_by_name": "admin",
            "changed_on": None,
            "changed_on_humanized": None,
            "created_by_name": "admin",
            "created_on": None,
            "created_on_humanized": None,
            "tags": [],
            "owners": [],
            "is_virtual": False,
            "database_id": 1,
            "schema_perm": "[examples].[main]",
            "url": "/tablemodelview/edit/1",
        }
        if "database" not in overrides:
            database = Mock()
            database.database_name = "examples"
            attrs["database"] = database
        attrs.update(overrides)
        row = Mock()
        for name, value in attrs.items():
            setattr(row, name, value)
        return row

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_basic(self, mock_list):
        """A single mocked row comes back with its fields mapped through."""
        mock_list.return_value = ([self._dataset_row()], 1)
        result = list_datasets()
        assert result.count == 1
        assert result.total_count == 1
        assert result.datasets[0].table_name == "Test Dataset"
        assert result.datasets[0].database_name == "examples"

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_with_filters(self, mock_list):
        """Advanced filter dicts plus paging/ordering options are accepted."""
        mock_list.return_value = ([], 0)
        filters = [
            {"col": "table_name", "opr": "sw", "value": "Sales"},
            {"col": "schema", "opr": "eq", "value": "main"},
        ]
        result = list_datasets(
            filters=filters,
            select_columns=["id", "table_name"],
            order_column="changed_on",
            order_direction="desc",
            page=1,
            page_size=50,
        )
        assert result.count == 0
        assert result.total_count == 0
        assert result.datasets == []

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_with_string_filters(self, mock_list):
        """A JSON string is not a valid filter payload; only lists of dicts are."""
        mock_list.return_value = ([], 0)
        payload = '[{"col": "table_name", "opr": "sw", "value": "Sales"}]'
        with pytest.raises(Exception):
            list_datasets(filters=payload)

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_api_error(self, mock_list):
        """DAO failures propagate to the caller."""
        mock_list.side_effect = Exception("API request failed")
        with pytest.raises(Exception) as excinfo:
            list_datasets()
        assert "API request failed" in str(excinfo.value)

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_with_search(self, mock_list):
        """Free-text search is forwarded with the expected search columns."""
        row = self._dataset_row(
            table_name="search_table",
            db_schema="public",
            database_name="test_db",
            database=None,
            description="A test dataset",
            changed_by="admin",
            created_by="admin",
            schema_perm=None,
            url=None,
        )
        mock_list.return_value = ([row], 1)
        result = list_datasets(search="search_table")
        assert result.count == 1
        assert result.datasets[0].table_name == "search_table"
        _, kwargs = mock_list.call_args
        assert kwargs["search"] == "search_table"
        assert "table_name" in kwargs["search_columns"]
        assert "db_schema" in kwargs["search_columns"]
        assert "description" in kwargs["search_columns"]

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_simple_with_search(self, mock_list):
        """Search through the consolidated tool (legacy list_datasets_simple)."""
        row = self._dataset_row(
            id=2,
            table_name="simple_search",
            db_schema="analytics",
            database_name="analytics_db",
            database=None,
            description="Another test dataset",
            changed_by="user",
            changed_by_name="user",
            created_by="user",
            created_by_name="user",
            is_virtual=True,
            database_id=2,
            schema_perm=None,
            url=None,
        )
        mock_list.return_value = ([row], 1)
        result = list_datasets(search="simple_search")
        assert result.count == 1
        assert result.datasets[0].table_name == "simple_search"
        _, kwargs = mock_list.call_args
        assert kwargs["search"] == "simple_search"
        assert "table_name" in kwargs["search_columns"]
        assert "db_schema" in kwargs["search_columns"]
        assert "description" in kwargs["search_columns"]

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_simple_basic(self, mock_list):
        """Equality filters through the consolidated tool."""
        mock_list.return_value = ([self._dataset_row()], 1)
        filters = [
            {"col": "table_name", "opr": "eq", "value": "Test Dataset"},
            {"col": "schema", "opr": "eq", "value": "main"},
        ]
        result = list_datasets(filters=filters)
        assert isinstance(result, DatasetListResponse)
        assert result.count == 1
        assert result.datasets[0].table_name == "Test Dataset"
        assert result.datasets[0].database_name == "examples"

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_simple_with_filters(self, mock_list):
        """Mixed-operator filters converted to the advanced filter form."""
        mock_list.return_value = ([], 0)
        filters = [
            {"col": "table_name", "opr": "sw", "value": "Sales"},
            {"col": "schema", "opr": "eq", "value": "main"},
        ]
        result = list_datasets(filters=filters)
        assert isinstance(result, DatasetListResponse)
        assert result.count == 0

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_simple_api_error(self, mock_list):
        """DAO failures propagate even when filters are supplied."""
        mock_list.side_effect = Exception("API request failed")
        filters = [
            {"col": "table_name", "opr": "sw", "value": "Sales"},
            {"col": "schema", "opr": "eq", "value": "main"},
        ]
        with pytest.raises(Exception) as excinfo:
            list_datasets(filters=filters)
        assert "API request failed" in str(excinfo.value)
class TestFastMCPServerIntegration:
    """FastMCP server construction and tool registration."""

    def test_fastmcp_server_initialization(self):
        """The server factory returns a named FastMCP instance."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import FastMCP
        assert isinstance(mcp, FastMCP)
        assert mcp.name == "Superset MCP Server"

    def test_tool_registration(self):
        """All expected tools are registered on the server.

        Fix: the original version returned unconditionally halfway through,
        leaving the registration assertions (and a duplicated import block)
        unreachable. The dead code is gone and the name check now runs
        whenever the FastMCP API exposes a tool listing.
        """
        import asyncio
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        registered_tools = []
        if hasattr(mcp, 'tools'):
            registered_tools = [tool.name for tool in mcp.tools]
        elif hasattr(mcp, 'get_tools'):
            tools_result = mcp.get_tools()
            if asyncio.iscoroutine(tools_result):
                tools_result = asyncio.run(tools_result)
            registered_tools = list(tools_result)
        if registered_tools:
            expected_tools = [
                "list_dashboards",
                "get_dashboard_info",
                "get_superset_instance_info",
                "get_dashboard_available_filters",
                "list_datasets",
            ]
            for tool_name in expected_tools:
                assert tool_name in registered_tools
        else:
            # FastMCP exposed no listing API; fall back to import-level
            # smoke checks on the decorated tool callables.
            from superset.mcp_service.tools.dashboard import (
                list_dashboards,
                get_dashboard_info,
                get_dashboard_available_filters,
            )
            from superset.mcp_service.tools.system import get_superset_instance_info
            from superset.mcp_service.tools.dataset import list_datasets
            assert list_dashboards is not None
            assert get_dashboard_info is not None
            assert get_superset_instance_info is not None
            assert get_dashboard_available_filters is not None
            assert list_datasets is not None


class TestErrorHandling:
    """Error-path behavior shared across the FastMCP tools."""

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_dashboards_exception_handling(self, mock_list):
        """Unexpected DAO errors propagate with their message intact."""
        mock_list.side_effect = Exception("Unexpected error")
        with pytest.raises(Exception) as excinfo:
            list_dashboards()
        assert "Unexpected error" in str(excinfo.value)

    def test_get_dashboard_available_filters_exception_handling(self):
        """Static metadata tool must always return a well-formed structure."""
        result = get_dashboard_available_filters()
        assert isinstance(result, DashboardAvailableFiltersResponse)
        assert hasattr(result, "filters")
        assert hasattr(result, "operators")
        assert hasattr(result, "columns")

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_datasets_exception_handling(self, mock_list):
        """list_datasets returns a well-formed response structure.

        Fix: the original called list_datasets() with no DAO mock, turning a
        unit test into an accidental integration test against a live
        database. The DAO is now stubbed to return an empty page.
        """
        mock_list.return_value = ([], 0)
        result = list_datasets()
        assert isinstance(result, (dict, DatasetListResponse))
        if isinstance(result, dict):
            assert "count" in result
            assert "datasets" in result
        else:
            assert hasattr(result, "count")
            assert hasattr(result, "datasets")
class TestParameterValidation:
    """Parameter parsing/validation behavior of the list tools.

    Fix: the original parameter-type tests asserted that a JSON *string*
    filter payload succeeded (call_count == 4), which directly contradicts
    test_list_dashboards_with_string_filters / test_list_datasets_with_string_filters,
    where the same payload is expected to raise. String filters are now
    expected to raise here as well, and the call count is adjusted.
    """

    def test_list_dashboards_parameter_types(self):
        """Only list filters are accepted; select_columns may be str or list."""
        with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list:
            mock_list.return_value = ([], 0)
            # String filters are rejected before the DAO is reached.
            with pytest.raises(Exception):
                list_dashboards(filters='[{"col": "test", "opr": "eq", "value": "value"}]')
            list_dashboards(filters=[{"col": "test", "opr": "eq", "value": "value"}])
            list_dashboards(select_columns="id,dashboard_title")
            list_dashboards(select_columns=["id", "dashboard_title"])
            # Only the three valid calls hit the DAO.
            assert mock_list.call_count == 3

    def test_list_dashboards_simple_parameter_types(self):
        """Equality filters (legacy list_dashboards_simple) round-trip cleanly."""
        with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list:
            mock_list.return_value = ([], 0)
            filters = [
                {"col": "dashboard_title", "opr": "eq", "value": "Sales"},
                {"col": "published", "opr": "eq", "value": True},
            ]
            result = list_dashboards(filters=filters)
            assert isinstance(result, DashboardListResponse)

    def test_list_datasets_parameter_types(self):
        """Same contract as the dashboard tool, applied to list_datasets."""
        with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list:
            mock_list.return_value = ([], 0)
            with pytest.raises(Exception):
                list_datasets(filters='[{"col": "test", "opr": "eq", "value": "value"}]')
            list_datasets(filters=[{"col": "test", "opr": "eq", "value": "value"}])
            list_datasets(select_columns="id,table_name")
            list_datasets(select_columns=["id", "table_name"])
            assert mock_list.call_count == 3

    def test_list_datasets_simple_parameter_types(self):
        """Equality filters (legacy list_datasets_simple) round-trip cleanly."""
        with patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list') as mock_list:
            mock_list.return_value = ([], 0)
            filters = [
                {"col": "table_name", "opr": "eq", "value": "test"},
                {"col": "schema", "opr": "eq", "value": "main"},
            ]
            result = list_datasets(filters=filters)
            assert isinstance(result, DatasetListResponse)


class TestFastMCPInMemoryProtocol:
    """
    In-memory protocol-level tests for the FastMCP server, following best
    practices from https://www.jlowin.dev/blog/stop-vibe-testing-mcp-servers

    These tests require pytest-asyncio to be installed and enabled.
    - Use fastmcp.Client(mcp) to call tools as an agent would (no network,
      no subprocess)
    - Assert on tool discovery, valid/invalid calls, error envelopes, and
      schema validation
    - Cover chaos-agent edge cases (missing/extra/wrong-type/malformed input)
    """

    @pytest.mark.asyncio
    async def test_tool_listing(self):
        """All expected tools are discoverable via the MCP protocol."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            tools = await client.list_tools()
            tool_names = [t.name for t in tools]
            expected = [
                "list_dashboards", "get_dashboard_info",
                "get_superset_instance_info", "get_dashboard_available_filters",
                "get_dataset_available_filters", "list_datasets",
                "list_charts", "get_chart_info", "get_chart_available_filters",
                "get_dataset_info", "create_chart_simple",
            ]
            for name in expected:
                assert name in tool_names

    @pytest.mark.asyncio
    async def test_valid_list_dashboards_call(self):
        """A valid list_dashboards call returns a structured CallToolResult."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            result = await client.call_tool("list_dashboards", {"page": 1, "page_size": 2})
            assert isinstance(result, CallToolResult)
            assert hasattr(result, "data")
            assert hasattr(result, "structured_content")
            assert hasattr(result.data, "dashboards")
            assert hasattr(result.data, "count")

    @pytest.mark.asyncio
    async def test_missing_required_param(self):
        """Omitting 'page' succeeds because it defaults to 1."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            result = await client.call_tool("list_dashboards", {"page_size": 2})
            assert isinstance(result, CallToolResult)
            assert hasattr(result, "data")
            assert hasattr(result, "structured_content")
            assert hasattr(result.data, "dashboards")
            assert hasattr(result.data, "count")

    @pytest.mark.asyncio
    async def test_wrong_type_param(self):
        """A wrong-typed parameter is rejected with a ToolError."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    "list_dashboards", {"page": "not_an_int", "page_size": 2}
                )

    @pytest.mark.asyncio
    async def test_extra_param(self):
        """An unexpected extra parameter is rejected with a ToolError."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    "list_dashboards", {"page": 1, "page_size": 2, "unexpected": 123}
                )

    @pytest.mark.asyncio
    async def test_malformed_input(self):
        """Completely malformed (non-dict) input raises."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            with pytest.raises(Exception):
                await client.call_tool("list_dashboards", "this is not a dict")

    @pytest.mark.asyncio
    async def test_error_envelope_on_internal_error(self):
        """Calling an unknown tool produces a proper ToolError envelope."""
        from superset.mcp_service.server import init_fastmcp_server
        mcp = init_fastmcp_server()
        from fastmcp import Client
        async with Client(mcp) as client:
            with pytest.raises(ToolError):
                await client.call_tool("not_a_real_tool", {})
class TestChartTools:
    """Tests for the chart FastMCP tools.

    Fix: the two tests duplicated an 18-line Mock setup; the shared
    `_chart_row` factory removes that duplication.
    """

    @staticmethod
    def _chart_row(**overrides):
        """Fake ORM row with the attributes the chart serializer reads."""
        attrs = {
            "id": 1,
            "slice_name": "search_chart",
            "viz_type": "bar",
            "datasource_name": "test_ds",
            "datasource_type": "table",
            "url": "/chart/1",
            "description": "desc",
            "cache_timeout": 60,
            "form_data": {},
            "query_context": {},
            "changed_by_name": "admin",
            "changed_on": None,
            "changed_on_humanized": "1 day ago",
            "created_by_name": "admin",
            "created_on": None,
            "created_on_humanized": "2 days ago",
            "tags": [],
            "owners": [],
        }
        attrs.update(overrides)
        row = Mock()
        for name, value in attrs.items():
            setattr(row, name, value)
        return row

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_charts_with_search(self, mock_list):
        """Free-text search is forwarded with the expected search columns."""
        from superset.mcp_service.tools.chart import list_charts
        mock_list.return_value = ([self._chart_row()], 1)
        result = list_charts(search="search_chart")
        assert result.count == 1
        assert result.charts[0].slice_name == "search_chart"
        _, kwargs = mock_list.call_args
        assert kwargs["search"] == "search_chart"
        assert "slice_name" in kwargs["search_columns"]
        assert "viz_type" in kwargs["search_columns"]
        assert "datasource_name" in kwargs["search_columns"]

    @patch('superset.mcp_service.dao_wrapper.MCPDAOWrapper.list')
    def test_list_charts_simple_with_search(self, mock_list):
        """Search via the simple chart-list variant."""
        from superset.mcp_service.tools.chart import list_charts_simple
        row = self._chart_row(
            id=2,
            slice_name="simple_search",
            viz_type="line",
            datasource_name="simple_ds",
            url="/chart/2",
            description="desc2",
            cache_timeout=120,
            changed_by_name="user",
            changed_on_humanized="3 days ago",
            created_by_name="user",
            created_on_humanized="4 days ago",
        )
        mock_list.return_value = ([row], 1)
        result = list_charts_simple(search="simple_search")
        assert result.count == 1
        assert result.charts[0].slice_name == "simple_search"
        _, kwargs = mock_list.call_args
        assert kwargs["search"] == "simple_search"
        assert "slice_name" in kwargs["search_columns"]
        assert "viz_type" in kwargs["search_columns"]
        assert "datasource_name" in kwargs["search_columns"]


if __name__ == "__main__":
    pytest.main([__file__])
class TestFastMCPInMemoryProtocol:
    """
    In-memory protocol-level tests for the FastMCP server.
    These tests require pytest-asyncio to be installed and enabled.

    Fix: every test repeated the same two imports plus server/client
    construction; the `_client` helper removes that duplication.
    """

    @staticmethod
    def _client():
        """Return an in-memory MCP client bound to a fresh server instance."""
        from fastmcp import Client
        from superset.mcp_service.server import init_fastmcp_server
        return Client(init_fastmcp_server())

    @pytest.mark.asyncio
    async def test_tool_listing(self):
        """All expected tools are discoverable via the MCP protocol."""
        async with self._client() as client:
            tools = await client.list_tools()
            tool_names = [t.name for t in tools]
            expected = [
                "list_dashboards", "get_dashboard_info",
                "get_superset_instance_info", "get_dashboard_available_filters",
                "get_dataset_available_filters", "list_datasets",
                "list_charts", "get_chart_info", "get_chart_available_filters",
                "get_dataset_info", "create_chart_simple",
            ]
            for name in expected:
                assert name in tool_names

    @pytest.mark.asyncio
    async def test_valid_list_dashboards_call(self):
        """A valid call returns a structured CallToolResult."""
        async with self._client() as client:
            result = await client.call_tool("list_dashboards", {"page": 1, "page_size": 2})
            assert isinstance(result, CallToolResult)
            assert hasattr(result, "data")
            assert hasattr(result, "structured_content")
            assert hasattr(result.data, "dashboards")
            assert hasattr(result.data, "count")

    @pytest.mark.asyncio
    async def test_missing_required_param(self):
        """Omitting 'page' succeeds because it defaults to 1."""
        async with self._client() as client:
            result = await client.call_tool("list_dashboards", {"page_size": 2})
            assert isinstance(result, CallToolResult)
            assert hasattr(result, "data")
            assert hasattr(result, "structured_content")
            assert hasattr(result.data, "dashboards")
            assert hasattr(result.data, "count")

    @pytest.mark.asyncio
    async def test_wrong_type_param(self):
        """A wrong-typed parameter is rejected with a ToolError."""
        async with self._client() as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    "list_dashboards", {"page": "not_an_int", "page_size": 2}
                )

    @pytest.mark.asyncio
    async def test_extra_param(self):
        """An unexpected extra parameter is rejected with a ToolError."""
        async with self._client() as client:
            with pytest.raises(ToolError):
                await client.call_tool(
                    "list_dashboards", {"page": 1, "page_size": 2, "unexpected": 123}
                )

    @pytest.mark.asyncio
    async def test_malformed_input(self):
        """Completely malformed (non-dict) input raises."""
        async with self._client() as client:
            with pytest.raises(Exception):
                await client.call_tool("list_dashboards", "this is not a dict")

    @pytest.mark.asyncio
    async def test_error_envelope_on_internal_error(self):
        """Calling an unknown tool produces a proper ToolError envelope."""
        async with self._client() as client:
            with pytest.raises(ToolError):
                await client.call_tool("not_a_real_tool", {})
# The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unit tests for MCP system tools (get_superset_instance_info)
"""
import logging
from unittest.mock import Mock, patch

import pytest
from flask import Flask, g
from flask_login import AnonymousUserMixin

from superset.mcp_service.pydantic_schemas.system_schemas import (
    InstanceInfo,
    InstanceSummary,
)
from superset.mcp_service.tools.system import get_superset_instance_info

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# One mocked count per MCPDAOWrapper.count call, in invocation order.
_COUNT_SIDE_EFFECTS = [
    10,  # total_dashboards
    10,  # total_charts
    10,  # total_datasets
    10,  # total_databases
    10,  # total_users
    10,  # total_tags
    2,   # recent_dashboards
    2,   # recent_charts
    2,   # recent_datasets
    2,   # recently_modified_dashboards
    2,   # recently_modified_charts
    2,   # recently_modified_datasets
    5,   # published_dashboards
    3,   # certified_dashboards
]


class TestSystemTools:
    """Test system-related MCP tools."""

    @patch("superset.extensions.db")
    def test_get_superset_instance_info_success(self, mock_db):
        """Happy path: DAO-level counts roll up into an InstanceSummary."""
        mock_session = Mock()
        mock_db.session = mock_session
        # Stub query-chain results in case any count bypasses the DAO wrapper.
        mock_session.query.return_value.join.return_value.distinct.return_value.count.return_value = 5
        mock_session.query.return_value.count.return_value = 10

        app = Flask(__name__)
        app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        with app.app_context():
            g.user = AnonymousUserMixin()
            with patch(
                "superset.mcp_service.tools.system."
                "get_superset_instance_info.MCPDAOWrapper.count",
                side_effect=_COUNT_SIDE_EFFECTS,
            ):
                result = get_superset_instance_info()
            del g.user

        assert isinstance(result, InstanceInfo)
        assert isinstance(result.instance_summary, InstanceSummary)
        assert result.instance_summary.total_dashboards == 10
        assert result.instance_summary.total_charts == 10
        assert result.instance_summary.total_datasets == 10
        assert result.instance_summary.total_databases == 10
        assert result.instance_summary.total_users == 10
        assert result.instance_summary.total_tags == 10
        assert result.instance_summary.avg_charts_per_dashboard == 1.0

    @patch("superset.extensions.db")
    def test_get_superset_instance_info_failure(self, mock_db):
        """A failing database session propagates out of the tool unchanged."""
        mock_session = Mock()
        mock_db.session = mock_session
        mock_session.query.side_effect = Exception("Database connection failed")

        app = Flask(__name__)
        app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        with app.app_context():
            g.user = AnonymousUserMixin()
            # match= replaces the manual `in str(excinfo.value)` check.
            with pytest.raises(Exception, match="Database connection failed"):
                get_superset_instance_info()