update: Add columns and metrics to DatasetInfo/list_datasets, improve test coverage

- Updated DatasetInfo schema to include columns and metrics fields, with new TableColumnInfo and SqlMetricInfo models.
- Updated serialize_dataset_object to serialize columns and metrics for each dataset.
- Modified list_datasets tool to use serialize_dataset_object and include columns/metrics by default.
- Improved and fixed all related unit tests to use proper MagicMock objects for columns/metrics and to parse JSON responses.
- Ensured LLM/OpenAPI compatibility for dataset listing and info tools.
This commit is contained in:
Amin Ghadersohi
2025-07-18 11:26:23 +10:00
parent fc7ea804bc
commit afdb8b38a6
7 changed files with 411 additions and 95 deletions

View File

@@ -19,7 +19,7 @@
Unit tests for MCP dataset tools (list_datasets, get_dataset_info, get_dataset_available_filters)
"""
import logging
from unittest.mock import Mock, patch
from unittest.mock import Mock, patch, MagicMock
import pytest
import fastmcp
@@ -32,6 +32,7 @@ from superset.mcp_service.dataset.tool.list_datasets import list_datasets
from superset.mcp_service.pydantic_schemas.dataset_schemas import (
DatasetAvailableFilters, DatasetInfo, DatasetList)
from superset.daos.dataset import DatasetDAO
import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@@ -43,7 +44,7 @@ def mcp_server():
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_basic(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 1
dataset.table_name = "Test DatasetInfo"
dataset.schema = "main"
@@ -60,7 +61,7 @@ async def test_list_datasets_basic(mock_list, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/1"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -69,6 +70,34 @@ async def test_list_datasets_basic(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
col2 = MagicMock()
col2.column_name = "name"
col2.verbose_name = "Name"
col2.type = "VARCHAR"
col2.is_dttm = False
col2.groupby = True
col2.filterable = True
col2.description = "Name column"
metric1 = MagicMock()
metric1.metric_name = "count"
metric1.verbose_name = "Count"
metric1.expression = "COUNT(*)"
metric1.description = "Row count"
metric1.d3format = None
dataset.columns = [col1, col2]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -98,15 +127,22 @@ async def test_list_datasets_basic(mock_list, mcp_server):
mock_list.return_value = ([dataset], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_datasets", {"page": 1, "page_size": 10})
datasets = result.data.datasets
assert len(datasets) == 1
assert datasets[0].table_name == "Test DatasetInfo"
assert datasets[0].database_name == "examples"
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 1
assert data["datasets"][0]["table_name"] == "Test DatasetInfo"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 2
assert len(data["datasets"][0]["metrics"]) == 1
assert data["datasets"][0]["columns"][0]["column_name"] == "id"
assert data["datasets"][0]["metrics"][0]["metric_name"] == "count"
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_with_filters(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 2
dataset.table_name = "Filtered Dataset"
dataset.schema = "main"
@@ -123,7 +159,7 @@ async def test_list_datasets_with_filters(mock_list, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/2"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -132,6 +168,25 @@ async def test_list_datasets_with_filters(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "sum"
metric1.verbose_name = "Sum"
metric1.expression = "SUM(value)"
metric1.description = "Sum of values"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -172,13 +227,20 @@ async def test_list_datasets_with_filters(mock_list, mcp_server):
"page": 1,
"page_size": 50
})
assert result.data.count == 1
assert result.data.datasets[0].table_name == "Filtered Dataset"
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 2
assert data["datasets"][0]["table_name"] == "Filtered Dataset"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 1
assert len(data["datasets"][0]["metrics"]) == 1
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_with_string_filters(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 3
dataset.table_name = "String Filter Dataset"
dataset.schema = "main"
@@ -195,7 +257,7 @@ async def test_list_datasets_with_string_filters(mock_list, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/3"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -248,7 +310,7 @@ async def test_list_datasets_api_error(mock_list, mcp_server):
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_with_search(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 1
dataset.table_name = "search_table"
dataset.schema = "public"
@@ -276,6 +338,25 @@ async def test_list_datasets_with_search(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "count"
metric1.verbose_name = "Count"
metric1.expression = "COUNT(*)"
metric1.description = "Row count"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -305,17 +386,20 @@ async def test_list_datasets_with_search(mock_list, mcp_server):
mock_list.return_value = ([dataset], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_datasets", {"search": "search_table"})
assert result.data.count == 1
assert result.data.datasets[0].table_name == "search_table"
args, kwargs = mock_list.call_args
assert kwargs["search"] == "search_table"
assert "table_name" in kwargs["search_columns"]
assert "schema" in kwargs["search_columns"]
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 1
assert data["datasets"][0]["table_name"] == "search_table"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 1
assert len(data["datasets"][0]["metrics"]) == 1
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_simple_with_search(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 2
dataset.table_name = "simple_search"
dataset.schema = "analytics"
@@ -343,6 +427,25 @@ async def test_list_datasets_simple_with_search(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "count"
metric1.verbose_name = "Count"
metric1.expression = "COUNT(*)"
metric1.description = "Row count"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -372,17 +475,20 @@ async def test_list_datasets_simple_with_search(mock_list, mcp_server):
mock_list.return_value = ([dataset], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_datasets", {"search": "simple_search"})
assert result.data.count == 1
assert result.data.datasets[0].table_name == "simple_search"
args, kwargs = mock_list.call_args
assert kwargs["search"] == "simple_search"
assert "table_name" in kwargs["search_columns"]
assert "schema" in kwargs["search_columns"]
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 2
assert data["datasets"][0]["table_name"] == "simple_search"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 1
assert len(data["datasets"][0]["metrics"]) == 1
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_simple_basic(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 1
dataset.table_name = "Test DatasetInfo"
dataset.schema = "main"
@@ -399,7 +505,7 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/1"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -408,6 +514,25 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "count"
metric1.verbose_name = "Count"
metric1.expression = "COUNT(*)"
metric1.description = "Row count"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -438,17 +563,20 @@ async def test_list_datasets_simple_basic(mock_list, mcp_server):
filters = [{"col": "table_name", "opr": "eq", "value": "Test DatasetInfo"}, {"col": "schema", "opr": "eq", "value": "main"}]
async with Client(mcp_server) as client:
result = await client.call_tool("list_datasets", {"filters": filters})
print("DEBUG datasets class:", result.data.__class__)
print("DEBUG datasets value:", result.data)
assert hasattr(result.data, "count")
assert hasattr(result.data, "datasets")
assert result.data.datasets[0].table_name == "Test DatasetInfo"
assert result.data.datasets[0].database_name == "examples"
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 1
assert data["datasets"][0]["table_name"] == "Test DatasetInfo"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 1
assert len(data["datasets"][0]["metrics"]) == 1
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_simple_with_filters(mock_list, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 2
dataset.table_name = "Sales Dataset"
dataset.schema = "main"
@@ -465,7 +593,7 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/2"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -474,6 +602,25 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "sum"
metric1.verbose_name = "Sum"
metric1.expression = "SUM(value)"
metric1.description = "Sum of values"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
@@ -504,8 +651,15 @@ async def test_list_datasets_simple_with_filters(mock_list, mcp_server):
filters = [{"col": "table_name", "opr": "sw", "value": "Sales"}, {"col": "schema", "opr": "eq", "value": "main"}]
async with Client(mcp_server) as client:
result = await client.call_tool("list_datasets", {"filters": filters})
assert result.data.count == 1
assert result.data.datasets[0].table_name == "Sales Dataset"
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["datasets"] is not None
assert len(data["datasets"]) == 1
assert data["datasets"][0]["id"] == 2
assert data["datasets"][0]["table_name"] == "Sales Dataset"
# Check that columns and metrics are included
assert len(data["datasets"][0]["columns"]) == 1
assert len(data["datasets"][0]["metrics"]) == 1
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
@@ -520,7 +674,7 @@ async def test_list_datasets_simple_api_error(mock_list, mcp_server):
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.find_by_id')
async def test_get_dataset_info_success(mock_info, mcp_server):
dataset = Mock()
dataset = MagicMock()
dataset.id = 1
dataset.table_name = "Test DatasetInfo"
dataset.schema = "main"
@@ -537,7 +691,7 @@ async def test_get_dataset_info_success(mock_info, mcp_server):
dataset.database_id = 1
dataset.schema_perm = "[examples].[main]"
dataset.url = "/tablemodelview/edit/1"
dataset.database = Mock()
dataset.database = MagicMock()
dataset.database.database_name = "examples"
dataset.sql = None
dataset.main_dttm_col = None
@@ -546,36 +700,38 @@ async def test_get_dataset_info_success(mock_info, mcp_server):
dataset.params = {}
dataset.template_params = {}
dataset.extra = {}
dataset._mapping = {
'id': dataset.id,
'table_name': dataset.table_name,
'db_schema': dataset.schema,
'database_name': dataset.database.database_name,
'description': dataset.description,
'changed_by_name': dataset.changed_by_name,
'changed_on': dataset.changed_on,
'changed_on_humanized': dataset.changed_on_humanized,
'created_by_name': dataset.created_by_name,
'created_on': dataset.created_on,
'created_on_humanized': dataset.created_on_humanized,
'tags': dataset.tags,
'owners': dataset.owners,
'is_virtual': dataset.is_virtual,
'database_id': dataset.database_id,
'schema_perm': dataset.schema_perm,
'url': dataset.url,
'sql': dataset.sql,
'main_dttm_col': dataset.main_dttm_col,
'offset': dataset.offset,
'cache_timeout': dataset.cache_timeout,
'params': dataset.params,
'template_params': dataset.template_params,
'extra': dataset.extra,
}
mock_info.return_value = dataset # Only the dataset object
# Add proper mock columns and metrics
col1 = MagicMock()
col1.column_name = "id"
col1.verbose_name = "ID"
col1.type = "INTEGER"
col1.is_dttm = False
col1.groupby = True
col1.filterable = True
col1.description = "Primary key"
metric1 = MagicMock()
metric1.metric_name = "count"
metric1.verbose_name = "Count"
metric1.expression = "COUNT(*)"
metric1.description = "Row count"
metric1.d3format = None
dataset.columns = [col1]
dataset.metrics = [metric1]
mock_info.return_value = dataset
async with Client(mcp_server) as client:
result = await client.call_tool("get_dataset_info", {"dataset_id": 1})
assert result.data["table_name"] == "Test DatasetInfo"
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["table_name"] == "Test DatasetInfo"
assert data["database_name"] == "examples"
# Check that columns and metrics are included
assert len(data["columns"]) == 1
assert len(data["metrics"]) == 1
assert data["columns"][0]["column_name"] == "id"
assert data["metrics"][0]["metric_name"] == "count"
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.find_by_id')
@@ -613,3 +769,105 @@ async def test_invalid_filter_column_raises(mcp_server):
with pytest.raises(fastmcp.exceptions.ToolError) as excinfo:
await client.call_tool("list_datasets", {"filters": [{"col": "not_a_column", "opr": "eq", "value": "foo"}]})
assert "Input validation error" in str(excinfo.value)
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.find_by_id')
async def test_get_dataset_info_includes_columns_and_metrics(mock_info, mcp_server):
    """Check that get_dataset_info returns the dataset's columns and metrics.

    DatasetDAO.find_by_id is patched to hand back a mocked dataset carrying
    two columns and two metrics; the tool's JSON text payload must list all
    of them, in order.
    """
    # Plain attributes of the fake dataset, applied in one pass below.
    plain_attrs = {
        "id": 10,
        "table_name": "Dataset With Columns",
        "schema": "main",
        "description": "desc",
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "is_virtual": False,
        "database_id": 1,
        "schema_perm": "[examples].[main]",
        "url": "/tablemodelview/edit/10",
        "sql": None,
        "main_dttm_col": None,
        "offset": 0,
        "cache_timeout": 0,
        "params": {},
        "template_params": {},
        "extra": {},
    }
    column_specs = [
        {
            "column_name": "col1",
            "verbose_name": "Column 1",
            "type": "INTEGER",
            "is_dttm": False,
            "groupby": True,
            "filterable": True,
            "description": "First column",
        },
        {
            "column_name": "col2",
            "verbose_name": "Column 2",
            "type": "VARCHAR",
            "is_dttm": False,
            "groupby": False,
            "filterable": True,
            "description": "Second column",
        },
    ]
    metric_specs = [
        {
            "metric_name": "sum_sales",
            "verbose_name": "Sum Sales",
            "expression": "SUM(sales)",
            "description": "Total sales",
            "d3format": None,
        },
        {
            "metric_name": "count_orders",
            "verbose_name": "Count Orders",
            "expression": "COUNT(orders)",
            "description": "Order count",
            "d3format": None,
        },
    ]

    fake_dataset = MagicMock()
    for attr_name, attr_value in plain_attrs.items():
        setattr(fake_dataset, attr_name, attr_value)
    fake_dataset.database = MagicMock()
    fake_dataset.database.database_name = "examples"
    # MagicMock keyword arguments become attributes on the created mock.
    fake_dataset.columns = [MagicMock(**spec) for spec in column_specs]
    fake_dataset.metrics = [MagicMock(**spec) for spec in metric_specs]
    mock_info.return_value = fake_dataset

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_dataset_info", {"dataset_id": 10})
        assert result.content is not None
        payload = json.loads(result.content[0].text)
        assert payload["table_name"] == "Dataset With Columns"
        assert payload["database_name"] == "examples"
        # Check that columns and metrics are included, in the order supplied
        returned_columns = [col["column_name"] for col in payload["columns"]]
        returned_metrics = [met["metric_name"] for met in payload["metrics"]]
        assert returned_columns == ["col1", "col2"]
        assert returned_metrics == ["sum_sales", "count_orders"]
@pytest.mark.asyncio
@patch('superset.daos.dataset.DatasetDAO.list')
async def test_list_datasets_includes_columns_and_metrics(mock_list, mcp_server):
    """Check that list_datasets returns columns and metrics for each dataset.

    DatasetDAO.list is patched to return a single mocked dataset that has one
    column and one metric. The response is checked through its JSON text
    payload, the same way the other list_datasets tests in this module do,
    rather than through attribute access on ``result.data``.
    """
    dataset = MagicMock()
    dataset.id = 11
    dataset.table_name = "DatasetList With Columns"
    dataset.schema = "main"
    dataset.database = MagicMock()
    dataset.database.database_name = "examples"
    dataset.description = "desc"
    dataset.changed_by_name = "admin"
    dataset.changed_on = None
    dataset.changed_on_humanized = None
    dataset.created_by_name = "admin"
    dataset.created_on = None
    dataset.created_on_humanized = None
    dataset.tags = []
    dataset.owners = []
    dataset.is_virtual = False
    dataset.database_id = 1
    dataset.schema_perm = "[examples].[main]"
    dataset.url = "/tablemodelview/edit/11"
    dataset.sql = None
    dataset.main_dttm_col = None
    dataset.offset = 0
    dataset.cache_timeout = 0
    dataset.params = {}
    dataset.template_params = {}
    dataset.extra = {}
    # MagicMock keyword arguments become attributes on the created mock.
    dataset.columns = [
        MagicMock(
            column_name="colA",
            verbose_name="Column A",
            type="FLOAT",
            is_dttm=False,
            groupby=True,
            filterable=True,
            description="A column",
        ),
    ]
    dataset.metrics = [
        MagicMock(
            metric_name="avg_value",
            verbose_name="Avg Value",
            expression="AVG(value)",
            description="Average value",
            d3format=None,
        ),
    ]
    mock_list.return_value = ([dataset], 1)
    async with Client(mcp_server) as client:
        result = await client.call_tool("list_datasets", {"page": 1, "page_size": 10})
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["datasets"] is not None
        assert len(data["datasets"]) == 1
        ds = data["datasets"][0]
        assert ds["id"] == 11
        assert ds["table_name"] == "DatasetList With Columns"
        # Check that columns and metrics are included
        assert "columns" in ds
        assert "metrics" in ds
        assert isinstance(ds["columns"], list)
        assert isinstance(ds["metrics"], list)
        assert len(ds["columns"]) == 1
        assert len(ds["metrics"]) == 1
        assert ds["columns"][0]["column_name"] == "colA"
        assert ds["metrics"][0]["metric_name"] == "avg_value"