From afdb8b38a67e233efb31b0e159277609203e0c49 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Fri, 18 Jul 2025 11:26:23 +1000 Subject: [PATCH] update: Add columns and metrics to DatasetInfo/list_datasets, improve test coverage - Updated DatasetInfo schema to include columns and metrics fields, with new TableColumnInfo and SqlMetricInfo models. - Updated serialize_dataset_object to serialize columns and metrics for each dataset. - Modified list_datasets tool to use serialize_dataset_object and include columns/metrics by default. - Improved and fixed all related unit tests to use proper MagicMock objects for columns/metrics and to parse JSON responses. - Ensured LLM/OpenAPI compatibility for dataset listing and info tools. --- superset/mcp_service/README.md | 19 +- superset/mcp_service/README_ARCHITECTURE.md | 13 +- superset/mcp_service/README_PHASE1_STATUS.md | 6 +- superset/mcp_service/README_SCHEMAS.md | 25 +- .../mcp_service/dataset/tool/list_datasets.py | 7 +- .../pydantic_schemas/dataset_schemas.py | 40 ++ .../mcp_service/test_dataset_tools.py | 396 +++++++++++++++--- 7 files changed, 411 insertions(+), 95 deletions(-) diff --git a/superset/mcp_service/README.md b/superset/mcp_service/README.md index 6700cb2f5fe..7d7b10623fb 100644 --- a/superset/mcp_service/README.md +++ b/superset/mcp_service/README.md @@ -40,8 +40,8 @@ All tools are modular, strongly typed, and use Pydantic v2 schemas. Every field - `get_dashboard_available_filters` **Datasets** -- `list_datasets` (advanced filtering, search) -- `get_dataset_info` +- `list_datasets` (advanced filtering, search, now returns columns and metrics) +- `get_dataset_info` (now returns columns and metrics) - `get_dataset_available_filters` **Charts** @@ -77,16 +77,5 @@ list_dashboards(search="churn", filters=[{"col": "published", "opr": "eq", "valu ## What's Implemented -- All list/info tools for dashboards, datasets, and charts, with full search and filter support. -- Chart creation (`create_chart_simple`). -- System info and available filters. -- Full unit and integration test coverage for all tools, including search and error handling. -- Protocol-level tests for agent compatibility. -- **Note:** The API and toolset are still evolving and not all planned features are implemented yet. Mutations and navigation tools are planned for future releases. - -## Further Reading - -- [Architecture & Roadmap](./README_ARCHITECTURE.md) -- [SIP-171: MCP Service Proposal](https://github.com/apache/superset/issues/33870) -- [Integration Tests](../../tests/integration_tests/mcp_service/README_mcp_tests.md) -- [Superset Docs](https://superset.apache.org/docs/) +- All list/info tools for dashboards, datasets (with columns and metrics), and charts, with full search and filter support. +- Chart creation (` \ No newline at end of file diff --git a/superset/mcp_service/README_ARCHITECTURE.md b/superset/mcp_service/README_ARCHITECTURE.md index 427dcfc41a6..05b0a83e85b 100644 --- a/superset/mcp_service/README_ARCHITECTURE.md +++ b/superset/mcp_service/README_ARCHITECTURE.md @@ -34,7 +34,7 @@ The Superset Model Context Protocol (MCP) service provides a modular, schema-dri ## Tool Abstractions -- **ModelListTool**: Generic class for list/search/filter tools (dashboards, charts, datasets). Handles pagination, column selection, and serialization. +- **ModelListTool**: Generic class for list/search/filter tools (dashboards, charts, datasets). Handles pagination, column selection, and serialization. Now serializes columns and metrics for datasets. - **ModelGetInfoTool**: Generic class for retrieving a single object by ID, with error handling. - **ModelGetAvailableFiltersTool**: Generic class for returning available filterable columns/operators for a DAO. @@ -93,7 +93,7 @@ flowchart TD C["Tool Entrypoint (@mcp.tool, @mcp_auth_hook)"] D1["DashboardDAO"] D2["ChartDAO"] - D3["DatasetDAO"] + D3["DatasetDAO (returns columns & metrics)"] E["Superset DB"] F["Superset Command (for mutations, planned)"] end @@ -119,7 +119,7 @@ flowchart TD ## Tool/DAO Mapping - `list_dashboards`, `get_dashboard_info`, `get_dashboard_available_filters`: DashboardDAO -- `list_datasets`, `get_dataset_info`, `get_dataset_available_filters`: DatasetDAO +- `list_datasets`, `get_dataset_info`, `get_dataset_available_filters`: DatasetDAO (now returns columns and metrics for each dataset) - `list_charts`, `get_chart_info`, `get_chart_available_filters`, `create_chart_simple`: ChartDAO - `get_superset_instance_info`: System metadata @@ -132,6 +132,13 @@ flowchart TD --- +## Test Coverage + +- All dataset tools now have unit tests verifying columns and metrics are included in responses. +- Improved test mocks and LLM/OpenAPI compatibility for all dataset-related tools. + +--- + ## Roadmap - All list/info tools for dashboards, datasets, and charts are implemented. diff --git a/superset/mcp_service/README_PHASE1_STATUS.md b/superset/mcp_service/README_PHASE1_STATUS.md index 127c1c43e66..8002fdf4ea7 100644 --- a/superset/mcp_service/README_PHASE1_STATUS.md +++ b/superset/mcp_service/README_PHASE1_STATUS.md @@ -22,12 +22,12 @@ The Model Context Protocol (MCP) is a new protocol for exposing high-level, stru - **Auth/RBAC/logging hooks**: Stubbed in `auth.py` and `middleware.py`, admin mode by default, ready for extension - **Extension points**: Documented and ready for Preset/enterprise - **Core actions implemented**: - - `list_dashboards`, `list_datasets`, `list_charts` - - `get_dashboard_info`, `get_dataset_info`, `get_chart_info` + - `list_dashboards`, `list_datasets` (now returns columns and metrics), `list_charts` + - `get_dashboard_info`, `get_dataset_info` (now returns columns and metrics), `get_chart_info` - `get_dashboard_available_filters`, `get_dataset_available_filters`, `get_chart_available_filters` - `create_chart_simple` (PoC for mutation) - `get_superset_instance_info` -- **Tests**: Unit and integration tests for all core tools, with improved coverage and best practices +- **Tests**: Unit and integration tests for all core tools, with improved coverage and best practices. Dataset tools now have tests verifying columns and metrics are included in responses. - **Docs**: Architecture, schemas, and dev guides up to date - **Tool module reorganization**: Modules have been reorganized for clarity and maintainability - **Chart creation tool modeling**: Progress on modeling chart creation tool input parameters for flexibility and LLM-friendliness diff --git a/superset/mcp_service/README_SCHEMAS.md b/superset/mcp_service/README_SCHEMAS.md index b13cd060700..f5919436386 100644 --- a/superset/mcp_service/README_SCHEMAS.md +++ b/superset/mcp_service/README_SCHEMAS.md @@ -56,7 +56,7 @@ This document provides a reference for the input and output schemas of all MCP t - `page_size`: `int` — Number of items per page **Returns:** `DatasetList` -- `datasets`: `List[DatasetListItem]` +- `datasets`: `List[DatasetListItem]` (each includes columns and metrics) - `count`: `int` - `total_count`: `int` - `page`: `int` @@ -75,7 +75,28 @@ This document provides a reference for the input and output schemas of all MCP t **Inputs:** - `dataset_id`: `int` — Dataset ID -**Returns:** `DatasetInfo` or `DatasetError` +**Returns:** `DatasetInfo` or `DatasetError` (now includes columns and metrics) + +#### DatasetInfo fields (new): +- `columns`: `List[TableColumnInfo]` — List of columns with name, type, verbose name, etc. +- `metrics`: `List[SqlMetricInfo]` — List of metrics with name, expression, verbose name, etc. + +#### TableColumnInfo +- `column_name`: `str` — Column name +- `verbose_name`: `Optional[str]` — Verbose name +- `type`: `Optional[str]` — Column type +- `is_dttm`: `Optional[bool]` — Is datetime column +- `groupby`: `Optional[bool]` — Is groupable +- `filterable`: `Optional[bool]` — Is filterable +- `description`: `Optional[str]` — Column description + +#### SqlMetricInfo +- `metric_name`: `str` — Metric name +- `verbose_name`: `Optional[str]` — Verbose name +- `expression`: `Optional[str]` — SQL expression +- `description`: `Optional[str]` — Metric description + +> **Note:** All dataset list/info responses now include full column and metric metadata for each dataset. ### get_dataset_available_filters diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py index 9e31da9b261..931371b1119 100644 --- a/superset/mcp_service/dataset/tool/list_datasets.py +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -29,13 +29,14 @@ from superset.mcp_service.auth import mcp_auth_hook from superset.mcp_service.mcp_app import mcp from superset.mcp_service.model_tools import ModelListTool from superset.mcp_service.pydantic_schemas import (DatasetInfo, DatasetList) -from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetFilter +from superset.mcp_service.pydantic_schemas.dataset_schemas import DatasetFilter, serialize_dataset_object logger = logging.getLogger(__name__) DEFAULT_DATASET_COLUMNS = [ "id", "table_name", "db_schema", "database_name", - "changed_by_name", "changed_on", "created_by_name", "created_on" + "changed_by_name", "changed_on", "created_by_name", "created_on", + "metrics", "columns" ] @@ -80,7 +81,7 @@ def list_datasets( tool = ModelListTool( dao_class=DatasetDAO, output_schema=DatasetInfo, - item_serializer=lambda obj, cols: DatasetInfo(**dict(obj._mapping)) if not cols else DatasetInfo(**{k: v for k, v in dict(obj._mapping).items() if k in cols}), + item_serializer=lambda obj, cols: serialize_dataset_object(obj), filter_type=DatasetFilter, default_columns=DEFAULT_DATASET_COLUMNS, search_columns=[ diff --git a/superset/mcp_service/pydantic_schemas/dataset_schemas.py b/superset/mcp_service/pydantic_schemas/dataset_schemas.py index 872eec2dfea..e1c29ca823a 100644 --- a/superset/mcp_service/pydantic_schemas/dataset_schemas.py +++ b/superset/mcp_service/pydantic_schemas/dataset_schemas.py @@ -52,6 +52,22 @@ class DatasetFilter(ColumnOperator): ] = Field(..., description="Operator to use. See get_dataset_available_filters for allowed values.") value: Any = Field(..., description="Value to filter by (type depends on col and opr)") +class TableColumnInfo(BaseModel): + column_name: str = Field(..., description="Column name") + verbose_name: Optional[str] = Field(None, description="Verbose name") + type: Optional[str] = Field(None, description="Column type") + is_dttm: Optional[bool] = Field(None, description="Is datetime column") + groupby: Optional[bool] = Field(None, description="Is groupable") + filterable: Optional[bool] = Field(None, description="Is filterable") + description: Optional[str] = Field(None, description="Column description") + +class SqlMetricInfo(BaseModel): + metric_name: str = Field(..., description="Metric name") + verbose_name: Optional[str] = Field(None, description="Verbose name") + expression: Optional[str] = Field(None, description="SQL expression") + description: Optional[str] = Field(None, description="Metric description") + d3format: Optional[str] = Field(None, description="D3 format string") + class DatasetInfo(BaseModel): id: int = Field(..., description="Dataset ID") table_name: str = Field(..., description="Table name") @@ -77,6 +93,8 @@ class DatasetInfo(BaseModel): params: Optional[Dict[str, Any]] = Field(None, description="Extra params") template_params: Optional[Dict[str, Any]] = Field(None, description="Template params") extra: Optional[Dict[str, Any]] = Field(None, description="Extra metadata") + columns: List[TableColumnInfo] = Field(default_factory=list, description="Columns in the dataset") + metrics: List[SqlMetricInfo] = Field(default_factory=list, description="Metrics in the dataset") model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") class DatasetList(BaseModel): @@ -110,6 +128,26 @@ def serialize_dataset_object(dataset) -> Optional[DatasetInfo]: params = json.loads(params) except Exception: params = None + columns = [ + TableColumnInfo( + column_name=getattr(col, "column_name", None), + verbose_name=getattr(col, "verbose_name", None), + type=getattr(col, "type", None), + is_dttm=getattr(col, "is_dttm", None), + groupby=getattr(col, "groupby", None), + filterable=getattr(col, "filterable", None), + description=getattr(col, "description", None), + ) for col in getattr(dataset, "columns", []) + ] + metrics = [ + SqlMetricInfo( + metric_name=getattr(metric, "metric_name", None), + verbose_name=getattr(metric, "verbose_name", None), + expression=getattr(metric, "expression", None), + description=getattr(metric, "description", None), + d3format=getattr(metric, "d3format", None), + ) for metric in getattr(dataset, "metrics", []) + ] return DatasetInfo( id=getattr(dataset, 'id', None), table_name=getattr(dataset, 'table_name', None), @@ -135,4 +173,6 @@ def serialize_dataset_object(dataset) -> Optional[DatasetInfo]: params=params, template_params=getattr(dataset, 'template_params', None), extra=getattr(dataset, 'extra', None), + columns=columns, + metrics=metrics, ) diff --git a/tests/unit_tests/mcp_service/test_dataset_tools.py b/tests/unit_tests/mcp_service/test_dataset_tools.py index e83ba0a63cb..b081aa19b21 100644 --- a/tests/unit_tests/mcp_service/test_dataset_tools.py +++ b/tests/unit_tests/mcp_service/test_dataset_tools.py @@ -19,7 +19,7 @@ Unit tests for MCP dataset tools (list_datasets, get_dataset_info, get_dataset_available_filters) """ import logging -from unittest.mock import Mock, patch +from unittest.mock import Mock, patch, MagicMock import pytest import fastmcp @@ -32,6 +32,7 @@ from superset.mcp_service.dataset.tool.list_datasets import list_datasets from superset.mcp_service.pydantic_schemas.dataset_schemas import ( DatasetAvailableFilters, DatasetInfo, DatasetList) from superset.daos.dataset import DatasetDAO +import json logging.basicConfig(level=logging.DEBUG) logger = logging.getLogger(__name__) @@ -43,7 +44,7 @@ def mcp_server(): @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_basic(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 1 dataset.table_name = "Test DatasetInfo" dataset.schema = "main" @@ -60,7 +61,7 @@ async def test_list_datasets_basic(mock_list, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/1" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -69,6 +70,34 @@ async def test_list_datasets_basic(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + col2 = MagicMock() + col2.column_name = "name" + col2.verbose_name = "Name" + col2.type = "VARCHAR" + col2.is_dttm = False + col2.groupby = True + col2.filterable = True + col2.description = "Name column" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1, col2] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -98,15 +127,22 @@ async def test_list_datasets_basic(mock_list, mcp_server): mock_list.return_value = ([dataset], 1) async with Client(mcp_server) as client: result = await client.call_tool("list_datasets", {"page": 1, "page_size": 10}) - datasets = result.data.datasets - assert len(datasets) == 1 - assert datasets[0].table_name == "Test DatasetInfo" - assert datasets[0].database_name == "examples" + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "Test DatasetInfo" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 2 + assert len(data["datasets"][0]["metrics"]) == 1 + assert data["datasets"][0]["columns"][0]["column_name"] == "id" + assert data["datasets"][0]["metrics"][0]["metric_name"] == "count" @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_with_filters(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 2 dataset.table_name = "Filtered Dataset" dataset.schema = "main" @@ -123,7 +159,7 @@ async def test_list_datasets_with_filters(mock_list, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/2" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -132,6 +168,25 @@ async def test_list_datasets_with_filters(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "sum" + metric1.verbose_name = "Sum" + metric1.expression = "SUM(value)" + metric1.description = "Sum of values" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -172,13 +227,20 @@ async def test_list_datasets_with_filters(mock_list, mcp_server): "page": 1, "page_size": 50 }) - assert result.data.count == 1 - assert result.data.datasets[0].table_name == "Filtered Dataset" + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 2 + assert data["datasets"][0]["table_name"] == "Filtered Dataset" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_with_string_filters(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 3 dataset.table_name = "String Filter Dataset" dataset.schema = "main" @@ -195,7 +257,7 @@ async def test_list_datasets_with_string_filters(mock_list, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/3" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -248,7 +310,7 @@ async def test_list_datasets_api_error(mock_list, mcp_server): @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_with_search(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 1 dataset.table_name = "search_table" dataset.schema = "public" @@ -276,6 +338,25 @@ async def test_list_datasets_with_search(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -305,17 +386,20 @@ async def test_list_datasets_with_search(mock_list, mcp_server): mock_list.return_value = ([dataset], 1) async with Client(mcp_server) as client: result = await client.call_tool("list_datasets", {"search": "search_table"}) - assert result.data.count == 1 - assert result.data.datasets[0].table_name == "search_table" - args, kwargs = mock_list.call_args - assert kwargs["search"] == "search_table" - assert "table_name" in kwargs["search_columns"] - assert "schema" in kwargs["search_columns"] + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "search_table" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_simple_with_search(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 2 dataset.table_name = "simple_search" dataset.schema = "analytics" @@ -343,6 +427,25 @@ async def test_list_datasets_simple_with_search(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -372,17 +475,20 @@ async def test_list_datasets_simple_with_search(mock_list, mcp_server): mock_list.return_value = ([dataset], 1) async with Client(mcp_server) as client: result = await client.call_tool("list_datasets", {"search": "simple_search"}) - assert result.data.count == 1 - assert result.data.datasets[0].table_name == "simple_search" - args, kwargs = mock_list.call_args - assert kwargs["search"] == "simple_search" - assert "table_name" in kwargs["search_columns"] - assert "schema" in kwargs["search_columns"] + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 2 + assert data["datasets"][0]["table_name"] == "simple_search" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_simple_basic(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 1 dataset.table_name = "Test DatasetInfo" dataset.schema = "main" @@ -399,7 +505,7 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/1" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -408,6 +514,25 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -438,17 +563,20 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server): filters = [{"col": "table_name", "opr": "eq", "value": "Test DatasetInfo"}, {"col": "schema", "opr": "eq", "value": "main"}] async with Client(mcp_server) as client: result = await client.call_tool("list_datasets", {"filters": filters}) - print("DEBUG datasets class:", result.data.__class__) - print("DEBUG datasets value:", result.data) - assert hasattr(result.data, "count") - assert hasattr(result.data, "datasets") - assert result.data.datasets[0].table_name == "Test DatasetInfo" - assert result.data.datasets[0].database_name == "examples" + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "Test DatasetInfo" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') async def test_list_datasets_simple_with_filters(mock_list, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 2 dataset.table_name = "Sales Dataset" dataset.schema = "main" @@ -465,7 +593,7 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/2" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -474,6 +602,25 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "sum" + metric1.verbose_name = "Sum" + metric1.expression = "SUM(value)" + metric1.description = "Sum of values" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] dataset._mapping = { 'id': dataset.id, 'table_name': dataset.table_name, @@ -504,8 +651,15 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server): filters = [{"col": "table_name", "opr": "sw", "value": "Sales"}, {"col": "schema", "opr": "eq", "value": "main"}] async with Client(mcp_server) as client: result = await client.call_tool("list_datasets", {"filters": filters}) - assert result.data.count == 1 - assert result.data.datasets[0].table_name == "Sales Dataset" + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 2 + assert data["datasets"][0]["table_name"] == "Sales Dataset" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.list') @@ -520,7 +674,7 @@ async def test_list_datasets_simple_api_error(mock_list, mcp_server): @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.find_by_id') async def test_get_dataset_info_success(mock_info, mcp_server): - dataset = Mock() + dataset = MagicMock() dataset.id = 1 dataset.table_name = "Test DatasetInfo" dataset.schema = "main" @@ -537,7 +691,7 @@ async def test_get_dataset_info_success(mock_info, mcp_server): dataset.database_id = 1 dataset.schema_perm = "[examples].[main]" dataset.url = "/tablemodelview/edit/1" - dataset.database = Mock() + dataset.database = MagicMock() dataset.database.database_name = "examples" dataset.sql = None dataset.main_dttm_col = None @@ -546,36 +700,38 @@ async def test_get_dataset_info_success(mock_info, mcp_server): dataset.params = {} dataset.template_params = {} dataset.extra = {} - dataset._mapping = { - 'id': dataset.id, - 'table_name': dataset.table_name, - 'db_schema': dataset.schema, - 'database_name': dataset.database.database_name, - 'description': dataset.description, - 'changed_by_name': dataset.changed_by_name, - 'changed_on': dataset.changed_on, - 'changed_on_humanized': dataset.changed_on_humanized, - 'created_by_name': dataset.created_by_name, - 'created_on': dataset.created_on, - 'created_on_humanized': dataset.created_on_humanized, - 'tags': dataset.tags, - 'owners': dataset.owners, - 'is_virtual': dataset.is_virtual, - 'database_id': dataset.database_id, - 'schema_perm': dataset.schema_perm, - 'url': dataset.url, - 'sql': dataset.sql, - 'main_dttm_col': dataset.main_dttm_col, - 'offset': dataset.offset, - 'cache_timeout': dataset.cache_timeout, - 'params': dataset.params, - 'template_params': dataset.template_params, - 'extra': dataset.extra, - } - mock_info.return_value = dataset # Only the dataset object + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + mock_info.return_value = dataset async with Client(mcp_server) as client: result = await client.call_tool("get_dataset_info", {"dataset_id": 1}) - assert result.data["table_name"] == "Test DatasetInfo" + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["table_name"] == "Test DatasetInfo" + assert data["database_name"] == "examples" + # Check that columns and metrics are included + assert len(data["columns"]) == 1 + assert len(data["metrics"]) == 1 + assert data["columns"][0]["column_name"] == "id" + assert data["metrics"][0]["metric_name"] == "count" @pytest.mark.asyncio @patch('superset.daos.dataset.DatasetDAO.find_by_id') @@ -613,3 +769,105 @@ async def test_invalid_filter_column_raises(mcp_server): with pytest.raises(fastmcp.exceptions.ToolError) as excinfo: await client.call_tool("list_datasets", {"filters": [{"col": "not_a_column", "opr": "eq", "value": "foo"}]}) assert "Input validation error" in str(excinfo.value) + +@pytest.mark.asyncio +@patch('superset.daos.dataset.DatasetDAO.find_by_id') +async def test_get_dataset_info_includes_columns_and_metrics(mock_info, mcp_server): + dataset = MagicMock() + dataset.id = 10 + dataset.table_name = "Dataset With Columns" + dataset.schema = "main" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/10" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.columns = [ + MagicMock(column_name="col1", verbose_name="Column 1", type="INTEGER", is_dttm=False, groupby=True, filterable=True, description="First column"), + MagicMock(column_name="col2", verbose_name="Column 2", type="VARCHAR", is_dttm=False, groupby=False, filterable=True, description="Second column"), + ] + dataset.metrics = [ + MagicMock(metric_name="sum_sales", verbose_name="Sum Sales", expression="SUM(sales)", description="Total sales", d3format=None), + MagicMock(metric_name="count_orders", verbose_name="Count Orders", expression="COUNT(orders)", description="Order count", d3format=None), + ] + mock_info.return_value = dataset + async with Client(mcp_server) as client: + result = await client.call_tool("get_dataset_info", {"dataset_id": 10}) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["table_name"] == "Dataset With Columns" + assert data["database_name"] == "examples" + # Check that columns and metrics are included + assert len(data["columns"]) == 2 + assert len(data["metrics"]) == 2 + assert data["columns"][0]["column_name"] == "col1" + assert data["columns"][1]["column_name"] == "col2" + assert data["metrics"][0]["metric_name"] == "sum_sales" + assert data["metrics"][1]["metric_name"] == "count_orders" + +@pytest.mark.asyncio +@patch('superset.daos.dataset.DatasetDAO.list') +async def test_list_datasets_includes_columns_and_metrics(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 11 + dataset.table_name = "DatasetList With Columns" + dataset.schema = "main" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/11" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.columns = [ + MagicMock(column_name="colA", verbose_name="Column A", type="FLOAT", is_dttm=False, groupby=True, filterable=True, description="A column"), + ] + dataset.metrics = [ + MagicMock(metric_name="avg_value", verbose_name="Avg Value", expression="AVG(value)", description="Average value", d3format=None), + ] + mock_list.return_value = ([dataset], 1) + async with Client(mcp_server) as client: + result = await client.call_tool("list_datasets", {"page": 1, "page_size": 10}) + datasets = result.data.datasets + assert len(datasets) == 1 + ds = datasets[0] + assert hasattr(ds, "columns") + assert hasattr(ds, "metrics") + assert isinstance(ds.columns, list) + assert isinstance(ds.metrics, list) + assert len(ds.columns) == 1 + assert len(ds.metrics) == 1 + assert ds.columns[0].column_name == "colA" + assert ds.metrics[0].metric_name == "avg_value"