diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py
index 9a014915a7d..21f10a13b25 100644
--- a/superset/mcp_service/app.py
+++ b/superset/mcp_service/app.py
@@ -446,6 +446,7 @@ from superset.mcp_service.database.tool import (  # noqa: F401, E402
     list_databases,
 )
 from superset.mcp_service.dataset.tool import (  # noqa: F401, E402
+    get_column_sample_data,
     get_dataset_info,
     list_datasets,
 )
diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py
index f6083520d5b..fdfd95ce8a9 100644
--- a/superset/mcp_service/dataset/schemas.py
+++ b/superset/mcp_service/dataset/schemas.py
@@ -306,6 +306,36 @@ class GetDatasetInfoRequest(MetadataCacheControl):
     ]
 
 
+class GetColumnSampleDataRequest(BaseModel):
+    """Request schema for get_column_sample_data (limit is 1-100, default 20)."""
+
+    dataset_id: int = Field(..., description="The dataset ID to query")
+    column_name: str = Field(
+        ..., description="The column name to get distinct values for"
+    )
+    limit: int = Field(
+        default=20,
+        ge=1,
+        le=100,
+        description="Maximum number of distinct values to return (default 20, max 100)",
+    )
+
+
+class ColumnSampleDataResponse(BaseModel):
+    """Response schema for get_column_sample_data: values plus a truncation flag."""
+
+    dataset_id: int = Field(..., description="The dataset ID queried")
+    column_name: str = Field(..., description="The column name queried")
+    values: List[str | int | float | bool | None] = Field(
+        ..., description="Distinct values found in the column"
+    )
+    count: int = Field(..., description="Number of distinct values returned")
+    truncated: bool = Field(
+        False,
+        description="True if more distinct values exist beyond the limit",
+    )
+
+
 def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
     """Parse a field that may be stored as a JSON string into a dict."""
     value = getattr(obj, field_name, None)
diff --git a/superset/mcp_service/dataset/tool/__init__.py b/superset/mcp_service/dataset/tool/__init__.py
index 4446d2173ff..075debf643b 100644
--- 
a/superset/mcp_service/dataset/tool/__init__.py
+++ b/superset/mcp_service/dataset/tool/__init__.py
@@ -15,10 +15,12 @@
 # specific language governing permissions and limitations
 # under the License.
 
+from .get_column_sample_data import get_column_sample_data
 from .get_dataset_info import get_dataset_info
 from .list_datasets import list_datasets
 
 __all__ = [
-    "list_datasets",
+    "get_column_sample_data",
     "get_dataset_info",
+    "list_datasets",
 ]
diff --git a/superset/mcp_service/dataset/tool/get_column_sample_data.py b/superset/mcp_service/dataset/tool/get_column_sample_data.py
new file mode 100644
index 00000000000..f7091b6711b
--- /dev/null
+++ b/superset/mcp_service/dataset/tool/get_column_sample_data.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Get column sample data FastMCP tool
+
+This module contains the FastMCP tool for retrieving distinct values
+from a dataset column, useful for building filters in charts.
+"""
+
+import logging
+from datetime import datetime, timezone
+
+from fastmcp import Context
+from superset_core.mcp.decorators import tool, ToolAnnotations
+
+from superset.exceptions import SupersetSecurityException
+from superset.extensions import event_logger
+from superset.mcp_service.dataset.schemas import (
+    ColumnSampleDataResponse,
+    DatasetError,
+    GetColumnSampleDataRequest,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def _coerce_sample_value(value: object) -> str | int | float | bool | None:
+    """Return ``value`` in a form the response schema accepts.
+
+    ``ColumnSampleDataResponse.values`` only allows str, int, float, bool and
+    None. Anything else the query may yield (for example dates or decimals) is
+    converted with ``str()`` so one such value does not fail the whole call.
+    """
+    if value is None or isinstance(value, (str, int, float, bool)):
+        return value
+    return str(value)
+
+
+@tool(
+    tags=["discovery"],
+    class_permission_name="Dataset",
+    annotations=ToolAnnotations(
+        title="Get column sample data",
+        readOnlyHint=True,
+        destructiveHint=False,
+    ),
+)
+async def get_column_sample_data(
+    request: GetColumnSampleDataRequest, ctx: Context
+) -> ColumnSampleDataResponse | DatasetError:
+    """Get distinct values for a dataset column.
+
+    Returns up to `limit` distinct values from the specified column.
+    Useful for discovering valid filter values when building charts.
+    Respects row-level security and dataset fetch_values_predicate.
+    Values that are not strings, numbers, booleans or null (for example
+    dates) are returned as strings.
+
+    IMPORTANT FOR LLM CLIENTS:
+    - Use this tool BEFORE creating charts with filters to discover actual
+      column values instead of guessing
+    - Use get_dataset_info first to find column names and types
+    - Low-cardinality columns (gender, status, category) work best
+
+    Example usage:
+    ```json
+    {
+        "dataset_id": 123,
+        "column_name": "gender",
+        "limit": 20
+    }
+    ```
+
+    Returns a ColumnSampleDataResponse on success, or a DatasetError whose
+    error_type is NotFound, PermissionDenied, ColumnNotFound or InternalError.
+    """
+    await ctx.info(
+        "Retrieving column sample data: dataset_id=%s, column=%s, limit=%s"
+        % (request.dataset_id, request.column_name, request.limit)
+    )
+
+    try:
+        from superset.daos.dataset import DatasetDAO
+
+        with event_logger.log_context(action="mcp.get_column_sample_data.lookup"):
+            dataset = DatasetDAO.find_by_id(request.dataset_id)
+
+        if not dataset:
+            await ctx.warning(
+                "Dataset not found: dataset_id=%s" % (request.dataset_id,)
+            )
+            return DatasetError(
+                error=f"Dataset with ID {request.dataset_id} not found",
+                error_type="NotFound",
+                timestamp=datetime.now(timezone.utc),
+            )
+
+        try:
+            dataset.raise_for_access()
+        except SupersetSecurityException as ex:
+            await ctx.warning(
+                "Permission denied for dataset_id=%s: %s"
+                % (request.dataset_id, str(ex))
+            )
+            return DatasetError(
+                error=f"Permission denied for dataset {request.dataset_id}",
+                error_type="PermissionDenied",
+                timestamp=datetime.now(timezone.utc),
+            )
+
+        # Fetch one extra value to detect truncation without a COUNT query
+        fetch_limit = request.limit + 1
+        denormalize_column = not dataset.normalize_columns
+
+        # Only a KeyError raised by the value lookup itself is reported as a
+        # missing column; a KeyError from any other step is an internal error.
+        try:
+            with event_logger.log_context(action="mcp.get_column_sample_data.query"):
+                raw_values = dataset.values_for_column(
+                    column_name=request.column_name,
+                    limit=fetch_limit,
+                    denormalize_column=denormalize_column,
+                )
+        except KeyError:
+            await ctx.warning(
+                "Column not found: column=%s in dataset_id=%s"
+                % (request.column_name, request.dataset_id)
+            )
+            return DatasetError(
+                error=f"Column '{request.column_name}' does not exist "
+                f"in dataset {request.dataset_id}",
+                error_type="ColumnNotFound",
+                timestamp=datetime.now(timezone.utc),
+            )
+
+        truncated = len(raw_values) > request.limit
+        # Drop the probe value, then make each value schema-compatible
+        values = [_coerce_sample_value(v) for v in raw_values[: request.limit]]
+
+        await ctx.info(
+            "Column sample data retrieved: dataset_id=%s, column=%s, "
+            "count=%s, truncated=%s"
+            % (request.dataset_id, request.column_name, len(values), truncated)
+        )
+
+        return ColumnSampleDataResponse(
+            dataset_id=request.dataset_id,
+            column_name=request.column_name,
+            values=values,
+            count=len(values),
+            truncated=truncated,
+        )
+
+    except Exception as e:
+        await ctx.error(
+            "Column sample data retrieval failed: dataset_id=%s, column=%s, "
+            "error=%s, error_type=%s"
+            % (request.dataset_id, request.column_name, str(e), type(e).__name__)
+        )
+        return DatasetError(
+            error=f"Failed to get column sample data: {str(e)}",
+            error_type="InternalError",
+            timestamp=datetime.now(timezone.utc),
+        )
diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_get_column_sample_data.py b/tests/unit_tests/mcp_service/dataset/tool/test_get_column_sample_data.py
new file mode 100644
index 00000000000..80c7fb061e1
--- /dev/null
+++ b/tests/unit_tests/mcp_service/dataset/tool/test_get_column_sample_data.py
@@ -0,0 +1,248 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import logging
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+from fastmcp import Client
+from pydantic import ValidationError
+
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetSecurityException
+from superset.mcp_service.app import mcp
+from superset.mcp_service.dataset.schemas import GetColumnSampleDataRequest
+from superset.utils import json
+
+logging.basicConfig(level=logging.DEBUG)
+logger = logging.getLogger(__name__)
+
+
+@pytest.fixture
+def mcp_server():
+    return mcp
+
+
+@pytest.fixture(autouse=True)
+def mock_auth():
+    """Patch MCP auth so every test runs as a mock user (id=1, "admin")."""
+    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
+        mock_user = Mock()
+        mock_user.id = 1
+        mock_user.username = "admin"
+        mock_get_user.return_value = mock_user
+        yield mock_get_user
+
+
+def _create_mock_dataset(dataset_id=1, normalize_columns=True):
+    """Return a MagicMock dataset whose access check and value lookup are mocks."""
+    dataset = MagicMock()
+    dataset.id = dataset_id
+    dataset.normalize_columns = normalize_columns
+    dataset.raise_for_access = MagicMock()
+    dataset.values_for_column = MagicMock()
+    return dataset
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_success(mock_find, mcp_server):
+    """Test that values, count and truncated=False are returned on success."""
+    dataset = _create_mock_dataset()
+    dataset.values_for_column.return_value = ["Male", "Female", "Other"]
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "gender", "limit": 20}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["dataset_id"] == 1
+        assert data["column_name"] == "gender"
+        assert data["values"] == ["Male", "Female", "Other"]
+        assert data["count"] == 3
+        assert data["truncated"] is False
+
+    dataset.values_for_column.assert_called_once_with(
+        column_name="gender",
+        limit=21,  # limit + 1 for truncation detection
+        denormalize_column=False,  # normalize_columns=True -> denormalize=False
+    )
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_truncated(mock_find, mcp_server):
+    """Test that the extra probe value is dropped and truncated is set to True."""
+    dataset = _create_mock_dataset()
+    # Mock yields 4 values for limit=3; the tool asks for limit+1=4 as a probe
+    dataset.values_for_column.return_value = ["A", "B", "C", "D"]
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "status", "limit": 3}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["values"] == ["A", "B", "C"]
+        assert data["count"] == 3
+        assert data["truncated"] is True
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_dataset_not_found(mock_find, mcp_server):
+    """Test that a missing dataset yields a NotFound error naming the ID."""
+    mock_find.return_value = None
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 999, "column_name": "gender"}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["error_type"] == "NotFound"
+        assert "999" in data["error"]
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_permission_denied(mock_find, mcp_server):
+    """Test that a SupersetSecurityException maps to a PermissionDenied error."""
+    dataset = _create_mock_dataset()
+    dataset.raise_for_access.side_effect = SupersetSecurityException(
+        SupersetError(
+            message="Access denied",
+            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
+            level=ErrorLevel.ERROR,
+        )
+    )
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "gender"}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["error_type"] == "PermissionDenied"
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_column_not_found(mock_find, mcp_server):
+    """Test that a KeyError from values_for_column maps to ColumnNotFound."""
+    dataset = _create_mock_dataset()
+    dataset.values_for_column.side_effect = KeyError("nonexistent_col")
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "nonexistent_col"}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["error_type"] == "ColumnNotFound"
+        assert "nonexistent_col" in data["error"]
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_default_limit(mock_find, mcp_server):
+    """Test that omitting limit defaults to 20."""
+    dataset = _create_mock_dataset()
+    dataset.values_for_column.return_value = list(range(10))
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "category"}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["count"] == 10
+
+    # Default limit=20, so fetch_limit should be 21
+    dataset.values_for_column.assert_called_once_with(
+        column_name="category",
+        limit=21,
+        denormalize_column=False,
+    )
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_denormalize_column(mock_find, mcp_server):
+    """Test that denormalize_column is the inverse of dataset.normalize_columns."""
+    dataset = _create_mock_dataset(normalize_columns=False)
+    dataset.values_for_column.return_value = ["val1"]
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "col"}},
+        )
+
+    dataset.values_for_column.assert_called_once_with(
+        column_name="col",
+        limit=21,
+        denormalize_column=True,  # normalize_columns=False -> denormalize=True
+    )
+
+
+@patch("superset.daos.dataset.DatasetDAO.find_by_id")
+@pytest.mark.asyncio
+async def test_get_column_sample_data_with_none_values(mock_find, mcp_server):
+    """Test that None values are passed through to the response and counted."""
+    dataset = _create_mock_dataset()
+    dataset.values_for_column.return_value = ["Active", None, "Inactive"]
+    mock_find.return_value = dataset
+
+    async with Client(mcp_server) as client:
+        result = await client.call_tool(
+            "get_column_sample_data",
+            {"request": {"dataset_id": 1, "column_name": "status"}},
+        )
+        data = json.loads(result.content[0].text)
+        assert data["values"] == ["Active", None, "Inactive"]
+        assert data["count"] == 3
+
+
+def test_get_column_sample_data_request_limit_validation():
+    """Test that limit must be within 1-100 and defaults to 20 when omitted."""
+    with pytest.raises(ValidationError, match="greater than or equal to 1"):
+        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=0)
+
+    with pytest.raises(ValidationError, match="less than or equal to 100"):
+        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=101)
+
+    with pytest.raises(ValidationError, match="greater than or equal to 1"):
+        GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=-1)
+
+    # Boundary values 1 and 100 are accepted; omitting limit yields the default
+    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=1)
+    assert req.limit == 1
+
+    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col", limit=100)
+    assert req.limit == 100
+
+    req = GetColumnSampleDataRequest(dataset_id=1, column_name="col")
+    assert req.limit == 20