mirror of
https://github.com/apache/superset.git
synced 2026-05-12 19:35:17 +00:00
feat(mcp): get sample data tool
This commit is contained in:
@@ -446,6 +446,7 @@ from superset.mcp_service.database.tool import ( # noqa: F401, E402
|
|||||||
list_databases,
|
list_databases,
|
||||||
)
|
)
|
||||||
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||||
|
get_column_sample_data,
|
||||||
get_dataset_info,
|
get_dataset_info,
|
||||||
list_datasets,
|
list_datasets,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -306,6 +306,36 @@ class GetDatasetInfoRequest(MetadataCacheControl):
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
class GetColumnSampleDataRequest(BaseModel):
    """Request schema for get_column_sample_data."""

    # Target dataset; must exist and be readable by the requesting user.
    dataset_id: int = Field(..., description="The dataset ID to query")
    # Physical/logical column to sample distinct values from.
    column_name: str = Field(
        ..., description="The column name to get distinct values for"
    )
    # Bounded (1..100) so responses stay small enough for LLM clients.
    limit: int = Field(
        default=20,
        ge=1,
        le=100,
        description="Maximum number of distinct values to return (default 20, max 100)",
    )
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnSampleDataResponse(BaseModel):
    """Response schema for get_column_sample_data."""

    # Echo of the request parameters, so the response is self-describing.
    dataset_id: int = Field(..., description="The dataset ID queried")
    column_name: str = Field(..., description="The column name queried")
    # Values keep their native scalar type; SQL NULLs surface as None.
    values: List[str | int | float | bool | None] = Field(
        ..., description="Distinct values found in the column"
    )
    count: int = Field(..., description="Number of distinct values returned")
    # True when the underlying query produced more rows than `limit`.
    truncated: bool = Field(
        False,
        description="True if more distinct values exist beyond the limit",
    )
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
|
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
|
||||||
"""Parse a field that may be stored as a JSON string into a dict."""
|
"""Parse a field that may be stored as a JSON string into a dict."""
|
||||||
value = getattr(obj, field_name, None)
|
value = getattr(obj, field_name, None)
|
||||||
|
|||||||
@@ -15,10 +15,12 @@
|
|||||||
# specific language governing permissions and limitations
|
# specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
|
from .get_column_sample_data import get_column_sample_data
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets

# Public API of the dataset tool package, kept alphabetical.
__all__ = [
    "get_column_sample_data",
    "get_dataset_info",
    "list_datasets",
]
|
||||||
|
|||||||
158
superset/mcp_service/dataset/tool/get_column_sample_data.py
Normal file
158
superset/mcp_service/dataset/tool/get_column_sample_data.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""
|
||||||
|
Get column sample data FastMCP tool
|
||||||
|
|
||||||
|
This module contains the FastMCP tool for retrieving distinct values
|
||||||
|
from a dataset column, useful for building filters in charts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastmcp import Context
|
||||||
|
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||||
|
|
||||||
|
from superset.exceptions import SupersetSecurityException
|
||||||
|
from superset.extensions import event_logger
|
||||||
|
from superset.mcp_service.dataset.schemas import (
|
||||||
|
ColumnSampleDataResponse,
|
||||||
|
DatasetError,
|
||||||
|
GetColumnSampleDataRequest,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
    tags=["discovery"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get column sample data",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_column_sample_data(
    request: GetColumnSampleDataRequest, ctx: Context
) -> ColumnSampleDataResponse | DatasetError:
    """Get distinct values for a dataset column.

    Returns up to `limit` distinct values from the specified column.
    Useful for discovering valid filter values when building charts.
    Respects row-level security and dataset fetch_values_predicate.

    IMPORTANT FOR LLM CLIENTS:
    - Use this tool BEFORE creating charts with filters to discover actual
      column values instead of guessing
    - Use get_dataset_info first to find column names and types
    - Low-cardinality columns (gender, status, category) work best

    Example usage:
    ```json
    {
        "dataset_id": 123,
        "column_name": "gender",
        "limit": 20
    }
    ```
    """
    await ctx.info(
        "Retrieving column sample data: dataset_id=%s, column=%s, limit=%s"
        % (request.dataset_id, request.column_name, request.limit)
    )

    try:
        # NOTE(review): import is function-local, presumably to avoid a
        # circular import at module load time — confirm against project layout.
        from superset.daos.dataset import DatasetDAO

        with event_logger.log_context(action="mcp.get_column_sample_data.lookup"):
            dataset = DatasetDAO.find_by_id(request.dataset_id)

        if not dataset:
            await ctx.warning(
                "Dataset not found: dataset_id=%s" % (request.dataset_id,)
            )
            return DatasetError(
                error=f"Dataset with ID {request.dataset_id} not found",
                error_type="NotFound",
                timestamp=datetime.now(timezone.utc),
            )

        # Enforce dataset-level access control before touching any data;
        # denial is returned as a structured error, not raised to the client.
        try:
            dataset.raise_for_access()
        except SupersetSecurityException as ex:
            await ctx.warning(
                "Permission denied for dataset_id=%s: %s"
                % (request.dataset_id, str(ex))
            )
            return DatasetError(
                error=f"Permission denied for dataset {request.dataset_id}",
                error_type="PermissionDenied",
                timestamp=datetime.now(timezone.utc),
            )

        # Fetch one extra value to detect truncation without a COUNT query
        fetch_limit = request.limit + 1
        denormalize_column = not dataset.normalize_columns

        with event_logger.log_context(action="mcp.get_column_sample_data.query"):
            raw_values = dataset.values_for_column(
                column_name=request.column_name,
                limit=fetch_limit,
                denormalize_column=denormalize_column,
            )

        # Getting more than `limit` rows back proves values were cut off.
        truncated = len(raw_values) > request.limit
        values = raw_values[: request.limit]

        await ctx.info(
            "Column sample data retrieved: dataset_id=%s, column=%s, "
            "count=%s, truncated=%s"
            % (request.dataset_id, request.column_name, len(values), truncated)
        )

        return ColumnSampleDataResponse(
            dataset_id=request.dataset_id,
            column_name=request.column_name,
            values=values,
            count=len(values),
            truncated=truncated,
        )

    except KeyError:
        # values_for_column signals an unknown column with KeyError
        # (the test suite relies on this mapping to ColumnNotFound).
        await ctx.warning(
            "Column not found: column=%s in dataset_id=%s"
            % (request.column_name, request.dataset_id)
        )
        return DatasetError(
            error=f"Column '{request.column_name}' does not exist "
            f"in dataset {request.dataset_id}",
            error_type="ColumnNotFound",
            timestamp=datetime.now(timezone.utc),
        )
    except Exception as e:
        # Boundary handler: convert any unexpected failure into a structured
        # DatasetError so the MCP client always receives a parseable result.
        await ctx.error(
            "Column sample data retrieval failed: dataset_id=%s, column=%s, "
            "error=%s, error_type=%s"
            % (request.dataset_id, request.column_name, str(e), type(e).__name__)
        )
        return DatasetError(
            error=f"Failed to get column sample data: {str(e)}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )
|
||||||
@@ -0,0 +1,248 @@
|
|||||||
|
# Licensed to the Apache Software Foundation (ASF) under one
|
||||||
|
# or more contributor license agreements. See the NOTICE file
|
||||||
|
# distributed with this work for additional information
|
||||||
|
# regarding copyright ownership. The ASF licenses this file
|
||||||
|
# to you under the Apache License, Version 2.0 (the
|
||||||
|
# "License"); you may not use this file except in compliance
|
||||||
|
# with the License. You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing,
|
||||||
|
# software distributed under the License is distributed on an
|
||||||
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||||
|
# KIND, either express or implied. See the License for the
|
||||||
|
# specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from unittest.mock import MagicMock, Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastmcp import Client
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||||
|
from superset.exceptions import SupersetSecurityException
|
||||||
|
from superset.mcp_service.app import mcp
|
||||||
|
from superset.mcp_service.dataset.schemas import GetColumnSampleDataRequest
|
||||||
|
from superset.utils import json
|
||||||
|
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mcp_server():
    """Provide the FastMCP application instance under test."""
    return mcp
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def mock_auth():
    """Mock authentication for all tests."""
    # autouse: every test in this module runs with a fake "admin" user so
    # the MCP auth layer never hits a real session/request.
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
|
||||||
|
|
||||||
|
|
||||||
|
def _create_mock_dataset(dataset_id=1, normalize_columns=True):
|
||||||
|
"""Create a mock dataset with values_for_column support."""
|
||||||
|
dataset = MagicMock()
|
||||||
|
dataset.id = dataset_id
|
||||||
|
dataset.normalize_columns = normalize_columns
|
||||||
|
dataset.raise_for_access = MagicMock()
|
||||||
|
dataset.values_for_column = MagicMock()
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_success(mock_find, mcp_server):
    """Test successful retrieval of column sample data."""
    mock_ds = _create_mock_dataset()
    mock_ds.values_for_column.return_value = ["Male", "Female", "Other"]
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "gender", "limit": 20}},
        )
        payload = json.loads(tool_result.content[0].text)
        expected = {
            "dataset_id": 1,
            "column_name": "gender",
            "values": ["Male", "Female", "Other"],
            "count": 3,
        }
        for key, value in expected.items():
            assert payload[key] == value
        assert payload["truncated"] is False

        mock_ds.values_for_column.assert_called_once_with(
            column_name="gender",
            limit=21,  # limit + 1 for truncation detection
            denormalize_column=False,  # normalize_columns=True -> denormalize=False
        )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_truncated(mock_find, mcp_server):
    """Test truncation detection when more values exist than the limit."""
    mock_ds = _create_mock_dataset()
    # The tool fetches limit + 1 = 4 rows; a populated 4th row means truncation.
    mock_ds.values_for_column.return_value = ["A", "B", "C", "D"]
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status", "limit": 3}},
        )
        payload = json.loads(tool_result.content[0].text)
        assert payload["values"] == ["A", "B", "C"]
        assert payload["count"] == 3
        assert payload["truncated"] is True
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_dataset_not_found(mock_find, mcp_server):
    """Test error when dataset does not exist."""
    # DAO lookup misses -> the tool must return a NotFound error payload.
    mock_find.return_value = None

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 999, "column_name": "gender"}},
        )
        payload = json.loads(tool_result.content[0].text)
        assert payload["error_type"] == "NotFound"
        assert "999" in payload["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_permission_denied(mock_find, mcp_server):
    """Test error when user lacks permission to access the dataset."""
    mock_ds = _create_mock_dataset()
    denial = SupersetError(
        message="Access denied",
        error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
        level=ErrorLevel.ERROR,
    )
    mock_ds.raise_for_access.side_effect = SupersetSecurityException(denial)
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "gender"}},
        )
        payload = json.loads(tool_result.content[0].text)
        assert payload["error_type"] == "PermissionDenied"
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_column_not_found(mock_find, mcp_server):
    """Test error when column does not exist in the dataset."""
    mock_ds = _create_mock_dataset()
    # values_for_column signals an unknown column with KeyError.
    mock_ds.values_for_column.side_effect = KeyError("nonexistent_col")
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "nonexistent_col"}},
        )
        payload = json.loads(tool_result.content[0].text)
        assert payload["error_type"] == "ColumnNotFound"
        assert "nonexistent_col" in payload["error"]
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_default_limit(mock_find, mcp_server):
    """Test that omitting limit defaults to 20."""
    mock_ds = _create_mock_dataset()
    mock_ds.values_for_column.return_value = list(range(10))
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "category"}},
        )
        payload = json.loads(tool_result.content[0].text)
        assert payload["count"] == 10

        # Default limit=20, so fetch_limit should be 21
        mock_ds.values_for_column.assert_called_once_with(
            column_name="category",
            limit=21,
            denormalize_column=False,
        )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_denormalize_column(mock_find, mcp_server):
    """Test that denormalize_column is set based on dataset.normalize_columns."""
    mock_ds = _create_mock_dataset(normalize_columns=False)
    mock_ds.values_for_column.return_value = ["val1"]
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "col"}},
        )

        mock_ds.values_for_column.assert_called_once_with(
            column_name="col",
            limit=21,
            denormalize_column=True,  # normalize_columns=False -> denormalize=True
        )
|
||||||
|
|
||||||
|
|
||||||
|
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_column_sample_data_with_none_values(mock_find, mcp_server):
    """Test that None values (from NULL columns) are handled correctly."""
    mock_ds = _create_mock_dataset()
    mock_ds.values_for_column.return_value = ["Active", None, "Inactive"]
    mock_find.return_value = mock_ds

    async with Client(mcp_server) as client:
        tool_result = await client.call_tool(
            "get_column_sample_data",
            {"request": {"dataset_id": 1, "column_name": "status"}},
        )
        payload = json.loads(tool_result.content[0].text)
        # None must survive the round trip as JSON null, in original order.
        assert payload["values"] == ["Active", None, "Inactive"]
        assert payload["count"] == 3
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_column_sample_data_request_limit_validation():
    """Test that Pydantic rejects invalid limit values."""
    # Out-of-range limits are rejected with a descriptive message.
    for bad_limit, message in [
        (0, "greater than or equal to 1"),
        (101, "less than or equal to 100"),
        (-1, "greater than or equal to 1"),
    ]:
        with pytest.raises(ValidationError, match=message):
            GetColumnSampleDataRequest(
                dataset_id=1, column_name="col", limit=bad_limit
            )

    # Boundary values pass validation, and omitting limit falls back to 20.
    assert (
        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=1).limit
        == 1
    )
    assert (
        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=100).limit
        == 100
    )
    assert GetColumnSampleDataRequest(dataset_id=1, column_name="col").limit == 20
|
||||||
Reference in New Issue
Block a user