mirror of
https://github.com/apache/superset.git
synced 2026-05-07 08:54:23 +00:00
feat(mcp): get sample data tool
This commit is contained in:
@@ -446,6 +446,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
||||
list_databases,
|
||||
)
|
||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
get_column_sample_data,
|
||||
get_dataset_info,
|
||||
list_datasets,
|
||||
)
|
||||
|
||||
@@ -306,6 +306,36 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
||||
]
|
||||
|
||||
|
||||
class GetColumnSampleDataRequest(BaseModel):
|
||||
"""Request schema for get_column_sample_data."""
|
||||
|
||||
dataset_id: int = Field(..., description="The dataset ID to query")
|
||||
column_name: str = Field(
|
||||
..., description="The column name to get distinct values for"
|
||||
)
|
||||
limit: int = Field(
|
||||
default=20,
|
||||
ge=1,
|
||||
le=100,
|
||||
description="Maximum number of distinct values to return (default 20, max 100)",
|
||||
)
|
||||
|
||||
|
||||
class ColumnSampleDataResponse(BaseModel):
|
||||
"""Response schema for get_column_sample_data."""
|
||||
|
||||
dataset_id: int = Field(..., description="The dataset ID queried")
|
||||
column_name: str = Field(..., description="The column name queried")
|
||||
values: List[str | int | float | bool | None] = Field(
|
||||
..., description="Distinct values found in the column"
|
||||
)
|
||||
count: int = Field(..., description="Number of distinct values returned")
|
||||
truncated: bool = Field(
|
||||
False,
|
||||
description="True if more distinct values exist beyond the limit",
|
||||
)
|
||||
|
||||
|
||||
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
|
||||
"""Parse a field that may be stored as a JSON string into a dict."""
|
||||
value = getattr(obj, field_name, None)
|
||||
|
||||
@@ -15,10 +15,12 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_column_sample_data import get_column_sample_data
|
||||
from .get_dataset_info import get_dataset_info
|
||||
from .list_datasets import list_datasets
|
||||
|
||||
__all__ = [
|
||||
"list_datasets",
|
||||
"get_column_sample_data",
|
||||
"get_dataset_info",
|
||||
"list_datasets",
|
||||
]
|
||||
|
||||
158
superset/mcp_service/dataset/tool/get_column_sample_data.py
Normal file
158
superset/mcp_service/dataset/tool/get_column_sample_data.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get column sample data FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for retrieving distinct values
|
||||
from a dataset column, useful for building filters in charts.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
ColumnSampleDataResponse,
|
||||
DatasetError,
|
||||
GetColumnSampleDataRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get column sample data",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_column_sample_data(
|
||||
request: GetColumnSampleDataRequest, ctx: Context
|
||||
) -> ColumnSampleDataResponse | DatasetError:
|
||||
"""Get distinct values for a dataset column.
|
||||
|
||||
Returns up to `limit` distinct values from the specified column.
|
||||
Useful for discovering valid filter values when building charts.
|
||||
Respects row-level security and dataset fetch_values_predicate.
|
||||
|
||||
IMPORTANT FOR LLM CLIENTS:
|
||||
- Use this tool BEFORE creating charts with filters to discover actual
|
||||
column values instead of guessing
|
||||
- Use get_dataset_info first to find column names and types
|
||||
- Low-cardinality columns (gender, status, category) work best
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{
|
||||
"dataset_id": 123,
|
||||
"column_name": "gender",
|
||||
"limit": 20
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving column sample data: dataset_id=%s, column=%s, limit=%s"
|
||||
% (request.dataset_id, request.column_name, request.limit)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_column_sample_data.lookup"):
|
||||
dataset = DatasetDAO.find_by_id(request.dataset_id)
|
||||
|
||||
if not dataset:
|
||||
await ctx.warning(
|
||||
"Dataset not found: dataset_id=%s" % (request.dataset_id,)
|
||||
)
|
||||
return DatasetError(
|
||||
error=f"Dataset with ID {request.dataset_id} not found",
|
||||
error_type="NotFound",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
try:
|
||||
dataset.raise_for_access()
|
||||
except SupersetSecurityException as ex:
|
||||
await ctx.warning(
|
||||
"Permission denied for dataset_id=%s: %s"
|
||||
% (request.dataset_id, str(ex))
|
||||
)
|
||||
return DatasetError(
|
||||
error=f"Permission denied for dataset {request.dataset_id}",
|
||||
error_type="PermissionDenied",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
# Fetch one extra value to detect truncation without a COUNT query
|
||||
fetch_limit = request.limit + 1
|
||||
denormalize_column = not dataset.normalize_columns
|
||||
|
||||
with event_logger.log_context(action="mcp.get_column_sample_data.query"):
|
||||
raw_values = dataset.values_for_column(
|
||||
column_name=request.column_name,
|
||||
limit=fetch_limit,
|
||||
denormalize_column=denormalize_column,
|
||||
)
|
||||
|
||||
truncated = len(raw_values) > request.limit
|
||||
values = raw_values[: request.limit]
|
||||
|
||||
await ctx.info(
|
||||
"Column sample data retrieved: dataset_id=%s, column=%s, "
|
||||
"count=%s, truncated=%s"
|
||||
% (request.dataset_id, request.column_name, len(values), truncated)
|
||||
)
|
||||
|
||||
return ColumnSampleDataResponse(
|
||||
dataset_id=request.dataset_id,
|
||||
column_name=request.column_name,
|
||||
values=values,
|
||||
count=len(values),
|
||||
truncated=truncated,
|
||||
)
|
||||
|
||||
except KeyError:
|
||||
await ctx.warning(
|
||||
"Column not found: column=%s in dataset_id=%s"
|
||||
% (request.column_name, request.dataset_id)
|
||||
)
|
||||
return DatasetError(
|
||||
error=f"Column '{request.column_name}' does not exist "
|
||||
f"in dataset {request.dataset_id}",
|
||||
error_type="ColumnNotFound",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Column sample data retrieval failed: dataset_id=%s, column=%s, "
|
||||
"error=%s, error_type=%s"
|
||||
% (request.dataset_id, request.column_name, str(e), type(e).__name__)
|
||||
)
|
||||
return DatasetError(
|
||||
error=f"Failed to get column sample data: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
@@ -0,0 +1,248 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.dataset.schemas import GetColumnSampleDataRequest
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _create_mock_dataset(dataset_id=1, normalize_columns=True):
|
||||
"""Create a mock dataset with values_for_column support."""
|
||||
dataset = MagicMock()
|
||||
dataset.id = dataset_id
|
||||
dataset.normalize_columns = normalize_columns
|
||||
dataset.raise_for_access = MagicMock()
|
||||
dataset.values_for_column = MagicMock()
|
||||
return dataset
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_success(mock_find, mcp_server):
|
||||
"""Test successful retrieval of column sample data."""
|
||||
dataset = _create_mock_dataset()
|
||||
dataset.values_for_column.return_value = ["Male", "Female", "Other"]
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "gender", "limit": 20}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["dataset_id"] == 1
|
||||
assert data["column_name"] == "gender"
|
||||
assert data["values"] == ["Male", "Female", "Other"]
|
||||
assert data["count"] == 3
|
||||
assert data["truncated"] is False
|
||||
|
||||
dataset.values_for_column.assert_called_once_with(
|
||||
column_name="gender",
|
||||
limit=21, # limit + 1 for truncation detection
|
||||
denormalize_column=False, # normalize_columns=True -> denormalize=False
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_truncated(mock_find, mcp_server):
|
||||
"""Test truncation detection when more values exist than the limit."""
|
||||
dataset = _create_mock_dataset()
|
||||
# Return 4 values when limit is 3 (tool fetches limit+1=4)
|
||||
dataset.values_for_column.return_value = ["A", "B", "C", "D"]
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "status", "limit": 3}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["values"] == ["A", "B", "C"]
|
||||
assert data["count"] == 3
|
||||
assert data["truncated"] is True
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_dataset_not_found(mock_find, mcp_server):
|
||||
"""Test error when dataset does not exist."""
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 999, "column_name": "gender"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "NotFound"
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_permission_denied(mock_find, mcp_server):
|
||||
"""Test error when user lacks permission to access the dataset."""
|
||||
dataset = _create_mock_dataset()
|
||||
dataset.raise_for_access.side_effect = SupersetSecurityException(
|
||||
SupersetError(
|
||||
message="Access denied",
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "gender"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "PermissionDenied"
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_column_not_found(mock_find, mcp_server):
|
||||
"""Test error when column does not exist in the dataset."""
|
||||
dataset = _create_mock_dataset()
|
||||
dataset.values_for_column.side_effect = KeyError("nonexistent_col")
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "nonexistent_col"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "ColumnNotFound"
|
||||
assert "nonexistent_col" in data["error"]
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_default_limit(mock_find, mcp_server):
|
||||
"""Test that omitting limit defaults to 20."""
|
||||
dataset = _create_mock_dataset()
|
||||
dataset.values_for_column.return_value = list(range(10))
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "category"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 10
|
||||
|
||||
# Default limit=20, so fetch_limit should be 21
|
||||
dataset.values_for_column.assert_called_once_with(
|
||||
column_name="category",
|
||||
limit=21,
|
||||
denormalize_column=False,
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_denormalize_column(mock_find, mcp_server):
|
||||
"""Test that denormalize_column is set based on dataset.normalize_columns."""
|
||||
dataset = _create_mock_dataset(normalize_columns=False)
|
||||
dataset.values_for_column.return_value = ["val1"]
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "col"}},
|
||||
)
|
||||
|
||||
dataset.values_for_column.assert_called_once_with(
|
||||
column_name="col",
|
||||
limit=21,
|
||||
denormalize_column=True, # normalize_columns=False -> denormalize=True
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_column_sample_data_with_none_values(mock_find, mcp_server):
|
||||
"""Test that None values (from NULL columns) are handled correctly."""
|
||||
dataset = _create_mock_dataset()
|
||||
dataset.values_for_column.return_value = ["Active", None, "Inactive"]
|
||||
mock_find.return_value = dataset
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_column_sample_data",
|
||||
{"request": {"dataset_id": 1, "column_name": "status"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["values"] == ["Active", None, "Inactive"]
|
||||
assert data["count"] == 3
|
||||
|
||||
|
||||
def test_get_column_sample_data_request_limit_validation():
|
||||
"""Test that Pydantic rejects invalid limit values."""
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=0)
|
||||
|
||||
with pytest.raises(ValidationError, match="less than or equal to 100"):
|
||||
GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=101)
|
||||
|
||||
with pytest.raises(ValidationError, match="greater than or equal to 1"):
|
||||
GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=-1)
|
||||
|
||||
# Valid limits should work
|
||||
req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=1)
|
||||
assert req.limit == 1
|
||||
|
||||
req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=100)
|
||||
assert req.limit == 100
|
||||
|
||||
req = GetColumnSampleDataRequest(dataset_id=1, column_name="col")
|
||||
assert req.limit == 20
|
||||
Reference in New Issue
Block a user