feat(mcp): add get_column_sample_data tool

This commit is contained in:
alexandrusoare
2026-04-14 15:09:15 +03:00
parent c2a35e2eea
commit 41d8aeff5e
5 changed files with 440 additions and 1 deletions

View File

@@ -446,6 +446,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
list_databases,
)
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
get_column_sample_data,
get_dataset_info,
list_datasets,
)

View File

@@ -306,6 +306,36 @@ class GetDatasetInfoRequest(MetadataCacheControl):
]
class GetColumnSampleDataRequest(BaseModel):
    """Request schema for get_column_sample_data."""

    # Target dataset; an unknown ID results in a NotFound error response.
    dataset_id: int = Field(..., description="The dataset ID to query")
    # Column to sample; an unknown name surfaces as a ColumnNotFound error.
    column_name: str = Field(
        ..., description="The column name to get distinct values for"
    )
    # Bounded on both ends so a single request cannot trigger an expensive
    # unbounded DISTINCT scan.
    limit: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum number of distinct values to return (default 20, max 100)",
    )
class ColumnSampleDataResponse(BaseModel):
    """Response schema for get_column_sample_data."""

    # Echo of the request parameters so the response is self-describing.
    dataset_id: int = Field(..., description="The dataset ID queried")
    column_name: str = Field(..., description="The column name queried")
    # Distinct values in column order; None entries correspond to SQL NULLs.
    values: List[str | int | float | bool | None] = Field(
        ..., description="Distinct values found in the column"
    )
    count: int = Field(..., description="Number of distinct values returned")
    # Set when the tool's limit+1 over-fetch found more rows than `limit`.
    truncated: bool = Field(
        False,
        description="True if more distinct values exist beyond the limit",
    )
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
"""Parse a field that may be stored as a JSON string into a dict."""
value = getattr(obj, field_name, None)

View File

@@ -15,10 +15,12 @@
# specific language governing permissions and limitations
# under the License.
# Public API of the dataset tool package: one symbol per tool module.
from .get_column_sample_data import get_column_sample_data
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets

# Fix: "list_datasets" was listed twice; keep the list deduplicated and sorted.
__all__ = [
    "get_column_sample_data",
    "get_dataset_info",
    "list_datasets",
]

View File

@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get column sample data FastMCP tool
This module contains the FastMCP tool for retrieving distinct values
from a dataset column, useful for building filters in charts.
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.exceptions import SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.dataset.schemas import (
ColumnSampleDataResponse,
DatasetError,
GetColumnSampleDataRequest,
)
logger = logging.getLogger(__name__)
@tool(
    tags=["discovery"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get column sample data",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_column_sample_data(
    request: GetColumnSampleDataRequest, ctx: Context
) -> ColumnSampleDataResponse | DatasetError:
    """Get distinct values for a dataset column.

    Returns up to `limit` distinct values from the specified column.
    Useful for discovering valid filter values when building charts.
    Respects row-level security and dataset fetch_values_predicate.

    IMPORTANT FOR LLM CLIENTS:
    - Use this tool BEFORE creating charts with filters to discover actual
      column values instead of guessing
    - Use get_dataset_info first to find column names and types
    - Low-cardinality columns (gender, status, category) work best

    Example usage:
    ```json
    {
        "dataset_id": 123,
        "column_name": "gender",
        "limit": 20
    }
    ```
    """
    await ctx.info(
        "Retrieving column sample data: dataset_id=%s, column=%s, limit=%s"
        % (request.dataset_id, request.column_name, request.limit)
    )
    try:
        # Imported at call time rather than module level — presumably to avoid
        # an import cycle at service startup; TODO confirm.
        from superset.daos.dataset import DatasetDAO

        with event_logger.log_context(action="mcp.get_column_sample_data.lookup"):
            dataset = DatasetDAO.find_by_id(request.dataset_id)
        if not dataset:
            await ctx.warning(
                "Dataset not found: dataset_id=%s" % (request.dataset_id,)
            )
            return DatasetError(
                error=f"Dataset with ID {request.dataset_id} not found",
                error_type="NotFound",
                timestamp=datetime.now(timezone.utc),
            )

        # Enforce dataset-level access control before issuing any query.
        try:
            dataset.raise_for_access()
        except SupersetSecurityException as ex:
            await ctx.warning(
                "Permission denied for dataset_id=%s: %s"
                % (request.dataset_id, str(ex))
            )
            return DatasetError(
                error=f"Permission denied for dataset {request.dataset_id}",
                error_type="PermissionDenied",
                timestamp=datetime.now(timezone.utc),
            )

        # Fetch one extra value to detect truncation without a COUNT query
        fetch_limit = request.limit + 1
        denormalize_column = not dataset.normalize_columns
        with event_logger.log_context(action="mcp.get_column_sample_data.query"):
            raw_values = dataset.values_for_column(
                column_name=request.column_name,
                limit=fetch_limit,
                denormalize_column=denormalize_column,
            )

        # A full over-fetch page means more values exist beyond `limit`.
        truncated = len(raw_values) > request.limit
        values = raw_values[: request.limit]
        await ctx.info(
            "Column sample data retrieved: dataset_id=%s, column=%s, "
            "count=%s, truncated=%s"
            % (request.dataset_id, request.column_name, len(values), truncated)
        )
        return ColumnSampleDataResponse(
            dataset_id=request.dataset_id,
            column_name=request.column_name,
            values=values,
            count=len(values),
            truncated=truncated,
        )
    except KeyError:
        # values_for_column raises KeyError for an unknown column name.
        await ctx.warning(
            "Column not found: column=%s in dataset_id=%s"
            % (request.column_name, request.dataset_id)
        )
        return DatasetError(
            error=f"Column '{request.column_name}' does not exist "
            f"in dataset {request.dataset_id}",
            error_type="ColumnNotFound",
            timestamp=datetime.now(timezone.utc),
        )
    except Exception as e:
        # Boundary catch-all: convert unexpected failures into a structured
        # error payload instead of leaking a raw exception to the MCP client.
        await ctx.error(
            "Column sample data retrieval failed: dataset_id=%s, column=%s, "
            "error=%s, error_type=%s"
            % (request.dataset_id, request.column_name, str(e), type(e).__name__)
        )
        return DatasetError(
            error=f"Failed to get column sample data: {str(e)}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )

View File

@@ -0,0 +1,248 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from pydantic import ValidationError
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.app import mcp
from superset.mcp_service.dataset.schemas import GetColumnSampleDataRequest
from superset.utils import json
# Verbose logging helps diagnose failures in the async MCP client/server flow.
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
    """Return the shared FastMCP application instance under test."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth():
    """Install a fake admin user for every test in this module."""
    with patch("superset.mcp_service.auth.get_user_from_request") as patched:
        fake_user = Mock()
        fake_user.id = 1
        fake_user.username = "admin"
        patched.return_value = fake_user
        yield patched
def _create_mock_dataset(dataset_id=1, normalize_columns=True):
"""Create a mock dataset with values_for_column support."""
dataset = MagicMock()
dataset.id = dataset_id
dataset.normalize_columns = normalize_columns
dataset.raise_for_access = MagicMock()
dataset.values_for_column = MagicMock()
return dataset
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_success(mock_find, mcp_server):
    """Happy path: values come back untruncated with correct metadata."""
    ds = _create_mock_dataset()
    ds.values_for_column.return_value = ["Male", "Female", "Other"]
    mock_find.return_value = ds

    payload = {"request": {"dataset_id": 1, "column_name": "gender", "limit": 20}}
    async with Client(mcp_server) as client:
        res = await client.call_tool("get_column_sample_data", payload)
        body = json.loads(res.content[0].text)
        assert body["dataset_id"] == 1
        assert body["column_name"] == "gender"
        assert body["values"] == ["Male", "Female", "Other"]
        assert body["count"] == 3
        assert body["truncated"] is False
        # The tool over-fetches one extra row to detect truncation, and maps
        # normalize_columns=True to denormalize_column=False.
        ds.values_for_column.assert_called_once_with(
            column_name="gender",
            limit=21,
            denormalize_column=False,
        )
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_truncated(mock_find, mcp_server):
    """A limit smaller than the distinct-value count sets truncated=True."""
    ds = _create_mock_dataset()
    # The tool fetches limit+1=4 rows; a full page signals truncation.
    ds.values_for_column.return_value = ["A", "B", "C", "D"]
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status", "limit": 3}},
        )
        body = json.loads(res.content[0].text)
        assert body["values"] == ["A", "B", "C"]
        assert body["count"] == 3
        assert body["truncated"] is True
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_dataset_not_found(mock_find, mcp_server):
    """An unknown dataset id yields a NotFound error payload."""
    mock_find.return_value = None

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 999, "column_name": "gender"}},
        )
        body = json.loads(res.content[0].text)
        assert body["error_type"] == "NotFound"
        assert "999" in body["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_permission_denied(mock_find, mcp_server):
    """A security exception from raise_for_access maps to PermissionDenied."""
    denial = SupersetSecurityException(
        SupersetError(
            message="Access denied",
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.ERROR,
        )
    )
    ds = _create_mock_dataset()
    ds.raise_for_access.side_effect = denial
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "gender"}},
        )
        body = json.loads(res.content[0].text)
        assert body["error_type"] == "PermissionDenied"
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_column_not_found(mock_find, mcp_server):
    """A KeyError from values_for_column maps to ColumnNotFound."""
    ds = _create_mock_dataset()
    ds.values_for_column.side_effect = KeyError("nonexistent_col")
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "nonexistent_col"}},
        )
        body = json.loads(res.content[0].text)
        assert body["error_type"] == "ColumnNotFound"
        assert "nonexistent_col" in body["error"]
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_default_limit(mock_find, mcp_server):
    """Omitting limit falls back to the schema default of 20."""
    ds = _create_mock_dataset()
    ds.values_for_column.return_value = list(range(10))
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "category"}},
        )
        body = json.loads(res.content[0].text)
        assert body["count"] == 10
        # Default limit=20 means the tool queries for 21 rows.
        ds.values_for_column.assert_called_once_with(
            column_name="category",
            limit=21,
            denormalize_column=False,
        )
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_denormalize_column(mock_find, mcp_server):
    """denormalize_column is the inverse of dataset.normalize_columns."""
    ds = _create_mock_dataset(normalize_columns=False)
    ds.values_for_column.return_value = ["val1"]
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "col"}},
        )
        ds.values_for_column.assert_called_once_with(
            column_name="col",
            limit=21,
            denormalize_column=True,
        )
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_with_none_values(mock_find, mcp_server):
    """NULL column values survive the round trip as JSON null entries."""
    ds = _create_mock_dataset()
    ds.values_for_column.return_value = ["Active", None, "Inactive"]
    mock_find.return_value = ds

    async with Client(mcp_server) as client:
        res = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status"}},
        )
        body = json.loads(res.content[0].text)
        assert body["values"] == ["Active", None, "Inactive"]
        assert body["count"] == 3
def test_get_column_sample_data_request_limit_validation():
    """Pydantic enforces 1 <= limit <= 100 and defaults limit to 20."""
    invalid_cases = [
        (0, "greater than or equal to 1"),
        (101, "less than or equal to 100"),
        (-1, "greater than or equal to 1"),
    ]
    for bad_limit, pattern in invalid_cases:
        with pytest.raises(ValidationError, match=pattern):
            GetColumnSampleDataRequest(
                dataset_id=1, column_name="col", limit=bad_limit
            )

    # Boundary values are accepted unchanged.
    for good_limit in (1, 100):
        req = GetColumnSampleDataRequest(
            dataset_id=1, column_name="col", limit=good_limit
        )
        assert req.limit == good_limit

    # Omitting limit uses the schema default.
    assert GetColumnSampleDataRequest(dataset_id=1, column_name="col").limit == 20