Compare commits

..

1 Commits

Author SHA1 Message Date
Amin Ghadersohi
21fa0148bf feat(mcp): add get_dashboard_datasets tool 2026-06-11 00:10:27 +00:00
8 changed files with 721 additions and 138 deletions

View File

@@ -103,19 +103,6 @@ class DatasourceTypeUpdateRequiredValidationError(ValidationError):
)
class ChartQueryContextDatasourceMismatchValidationError(ValidationError):
"""
Raised when a query-context-only update carries a datasource that does not
match the chart's own datasource.
"""
def __init__(self) -> None:
super().__init__(
_("The query context datasource does not match the chart datasource"),
field_name="query_context",
)
class ChartNotFoundError(CommandException):
message = "Chart not found."

View File

@@ -29,7 +29,6 @@ from superset.commands.chart.exceptions import (
ChartForbiddenError,
ChartInvalidError,
ChartNotFoundError,
ChartQueryContextDatasourceMismatchValidationError,
ChartUpdateFailedError,
DashboardsForbiddenError,
DashboardsNotFoundValidationError,
@@ -42,7 +41,6 @@ from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
from superset.tags.models import ObjectType
from superset.utils import json
from superset.utils.decorators import on_error, transaction
logger = logging.getLogger(__name__)
@@ -103,51 +101,6 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
if not security_manager.is_owner(dash):
raise DashboardsForbiddenError()
def _validate_query_context_datasource(
self, exceptions: list[ValidationError]
) -> None:
"""
Ensure a query-context-only update keeps the chart's own datasource.
The submitted query context is only verified when it carries a parseable
``datasource`` object; a payload that references a different datasource than
the chart's persisted one is rejected. Payloads without a datasource fall
back to the chart's datasource at execution time and need no check.
"""
if not self._model:
return
raw_query_context = self._properties.get("query_context")
if not raw_query_context:
return
try:
query_context = json.loads(raw_query_context)
except (TypeError, ValueError):
# An unparseable payload cannot be verified or replayed; leave it for
# downstream handling rather than guessing at its intent.
return
datasource = (
query_context.get("datasource") if isinstance(query_context, dict) else None
)
if not isinstance(datasource, dict):
return
try:
ids_match = int(datasource["id"]) == self._model.datasource_id
except (KeyError, TypeError, ValueError):
ids_match = False
datasource_type = datasource.get("type")
types_match = (
datasource_type is None
or str(datasource_type) == self._model.datasource_type
)
if not ids_match or not types_match:
exceptions.append(ChartQueryContextDatasourceMismatchValidationError())
def validate(self) -> None: # noqa: C901
exceptions: list[ValidationError] = []
dashboard_ids = self._properties.get("dashboards")
@@ -181,12 +134,6 @@ class UpdateChartCommand(UpdateMixin, BaseCommand):
raise ChartForbiddenError() from ex
except ValidationError as ex:
exceptions.append(ex)
else:
# The query-context-only path skips the ownership check so report and
# alert workers can refresh a chart's cached payload. Keep that payload
# bound to the chart's own datasource so it cannot be repointed at an
# unrelated one.
self._validate_query_context_datasource(exceptions)
# validate tags
try:

View File

@@ -128,6 +128,7 @@ Dashboard Management:
- list_dashboards: List dashboards with advanced filters (1-based pagination)
- get_dashboard_info: Get detailed dashboard information by ID
- get_dashboard_layout: Get parsed tabs and chart positions for a dashboard (companion to get_dashboard_info when its omitted_fields hint flags position_json)
- get_dashboard_datasets: List the datasets used by a dashboard's charts, with columns and metrics (context for configuring native filters)
- generate_dashboard: Create a dashboard from chart IDs (requires write access)
- add_chart_to_existing_dashboard: Add a chart to an existing dashboard (requires write access)
@@ -680,6 +681,7 @@ from superset.mcp_service.chart.tool import ( # noqa: F401, E402
from superset.mcp_service.dashboard.tool import ( # noqa: F401, E402
add_chart_to_existing_dashboard,
generate_dashboard,
get_dashboard_datasets,
get_dashboard_info,
get_dashboard_layout,
list_dashboards,

View File

@@ -374,6 +374,17 @@ class GetDashboardLayoutRequest(BaseModel):
]
class GetDashboardDatasetsRequest(BaseModel):
"""Request schema for get_dashboard_datasets."""
identifier: Annotated[
int | str,
Field(
description="Dashboard identifier - can be numeric ID, UUID string, or slug"
),
]
logger = logging.getLogger(__name__)
@@ -1298,3 +1309,225 @@ def dashboard_layout_serializer(dashboard: "Dashboard") -> DashboardLayout:
has_layout=bool(position_json_str),
)
)
# Per-dataset caps keep responses small enough for LLM context: wide
# datasets can have hundreds of columns, which would dwarf the fields an
# agent actually needs to configure native filters.
MAX_DASHBOARD_DATASET_COLUMNS = 100
MAX_DASHBOARD_DATASET_METRICS = 50
class DashboardDatasetColumn(BaseModel):
"""Lean column representation for dashboard dataset context."""
column_name: str = Field(..., description="Column name")
verbose_name: str | None = Field(None, description="Verbose (display) name")
type: str | None = Field(None, description="Column data type")
is_dttm: bool | None = Field(None, description="Is datetime column")
class DashboardDatasetMetric(BaseModel):
"""Lean metric representation for dashboard dataset context."""
metric_name: str = Field(..., description="Saved metric name")
verbose_name: str | None = Field(None, description="Verbose (display) name")
expression: str | None = Field(None, description="SQL expression")
class DashboardDatasetDatabaseInfo(BaseModel):
"""Database connection summary for a dashboard dataset."""
id: int | None = Field(None, description="Database ID")
name: str | None = Field(None, description="Database name")
backend: str | None = Field(None, description="Database backend (engine)")
class DashboardDatasetSummary(BaseModel):
"""A dataset used by a dashboard's charts, with columns and metrics."""
id: int | None = Field(None, description="Dataset ID")
uuid: str | None = Field(None, description="Dataset UUID")
table_name: str | None = Field(None, description="Table name")
schema_name: str | None = Field(None, description="Schema name")
database: DashboardDatasetDatabaseInfo | None = Field(
None, description="Database the dataset belongs to"
)
chart_count: int = Field(
0, description="Number of charts on the dashboard using this dataset"
)
columns: List[DashboardDatasetColumn] = Field(
default_factory=list, description="Dataset columns"
)
metrics: List[DashboardDatasetMetric] = Field(
default_factory=list, description="Dataset metrics"
)
total_column_count: int = Field(
0, description="Total number of columns on the dataset"
)
total_metric_count: int = Field(
0, description="Total number of metrics on the dataset"
)
columns_truncated: bool = Field(
False,
description=(
"True when the columns list was truncated to keep the response small"
),
)
metrics_truncated: bool = Field(
False,
description=(
"True when the metrics list was truncated to keep the response small"
),
)
@model_serializer(mode="wrap")
def _rename_schema_field(self, serializer: Any, info: Any) -> Dict[str, Any]:
"""Serialize 'schema_name' as 'schema' to match API conventions."""
data = serializer(self)
if "schema_name" in data:
data["schema"] = data.pop("schema_name")
return data
class DashboardDatasets(BaseModel):
"""Response schema for get_dashboard_datasets."""
id: int | None = Field(None, description="Dashboard ID")
dashboard_title: str | None = Field(None, description="Dashboard title")
uuid: str | None = Field(None, description="Dashboard UUID")
dataset_count: int = Field(
0, description="Number of accessible datasets used by the dashboard"
)
inaccessible_dataset_count: int = Field(
0,
description=(
"Number of datasets used by the dashboard that the current user "
"cannot access (excluded from 'datasets')"
),
)
datasets: List[DashboardDatasetSummary] = Field(
default_factory=list,
description="Datasets used by the dashboard's charts",
)
def _serialize_dashboard_dataset(
datasource: Any, chart_count: int
) -> DashboardDatasetSummary:
"""Serialize a datasource to a lean, LLM-safe dataset summary."""
all_columns = list(getattr(datasource, "columns", None) or [])
all_metrics = list(getattr(datasource, "metrics", None) or [])
columns = [
DashboardDatasetColumn(
column_name=escape_llm_context_delimiters(
getattr(column, "column_name", None) or ""
),
verbose_name=sanitize_for_llm_context(
getattr(column, "verbose_name", None),
field_path=("columns", str(index), "verbose_name"),
),
type=getattr(column, "type", None),
is_dttm=getattr(column, "is_dttm", None),
)
for index, column in enumerate(all_columns[:MAX_DASHBOARD_DATASET_COLUMNS])
]
metrics = [
DashboardDatasetMetric(
metric_name=escape_llm_context_delimiters(
getattr(metric, "metric_name", None) or ""
),
verbose_name=sanitize_for_llm_context(
getattr(metric, "verbose_name", None),
field_path=("metrics", str(index), "verbose_name"),
),
expression=sanitize_for_llm_context(
getattr(metric, "expression", None),
field_path=("metrics", str(index), "expression"),
),
)
for index, metric in enumerate(all_metrics[:MAX_DASHBOARD_DATASET_METRICS])
]
database = getattr(datasource, "database", None)
database_info = (
DashboardDatasetDatabaseInfo(
id=getattr(database, "id", None),
name=escape_llm_context_delimiters(
getattr(database, "database_name", None)
),
backend=getattr(database, "backend", None),
)
if database is not None
else None
)
dataset_uuid = getattr(datasource, "uuid", None)
return DashboardDatasetSummary(
id=getattr(datasource, "id", None),
uuid=str(dataset_uuid) if dataset_uuid else None,
table_name=escape_llm_context_delimiters(
getattr(datasource, "table_name", None)
),
schema_name=escape_llm_context_delimiters(getattr(datasource, "schema", None)),
database=database_info,
chart_count=chart_count,
columns=columns,
metrics=metrics,
total_column_count=len(all_columns),
total_metric_count=len(all_metrics),
columns_truncated=len(all_columns) > MAX_DASHBOARD_DATASET_COLUMNS,
metrics_truncated=len(all_metrics) > MAX_DASHBOARD_DATASET_METRICS,
)
def dashboard_datasets_serializer(dashboard: "Dashboard") -> DashboardDatasets:
"""Serialize a Dashboard model to the datasets used by its charts.
Groups the dashboard's charts by datasource (mirroring
``Dashboard.datasets_trimmed_for_slices``) but keeps the full column and
metric lists (capped) since native-filter configuration regularly needs
columns that no chart references. Datasets the current user cannot
access are excluded and only counted.
"""
from superset.mcp_service.auth import has_dataset_access
slices_by_datasource: Dict[int, List[Any]] = {}
for slc in getattr(dashboard, "slices", None) or []:
datasource_id = getattr(slc, "datasource_id", None)
if datasource_id is None:
continue
slices_by_datasource.setdefault(datasource_id, []).append(slc)
datasets: List[DashboardDatasetSummary] = []
inaccessible_count = 0
for slices in slices_by_datasource.values():
datasource = next(
(
getattr(slc, "datasource", None)
for slc in slices
if getattr(slc, "datasource", None) is not None
),
None,
)
if datasource is None:
continue
if not has_dataset_access(datasource):
inaccessible_count += 1
continue
datasets.append(_serialize_dashboard_dataset(datasource, len(slices)))
datasets.sort(key=lambda dataset: dataset.id or 0)
return DashboardDatasets(
id=dashboard.id,
dashboard_title=sanitize_for_llm_context(
dashboard.dashboard_title or "Untitled",
field_path=("dashboard_title",),
),
uuid=str(dashboard.uuid) if dashboard.uuid else None,
dataset_count=len(datasets),
inaccessible_dataset_count=inaccessible_count,
datasets=datasets,
)

View File

@@ -17,12 +17,14 @@
from .add_chart_to_existing_dashboard import add_chart_to_existing_dashboard
from .generate_dashboard import generate_dashboard
from .get_dashboard_datasets import get_dashboard_datasets
from .get_dashboard_info import get_dashboard_info
from .get_dashboard_layout import get_dashboard_layout
from .list_dashboards import list_dashboards
__all__ = [
"list_dashboards",
"get_dashboard_datasets",
"get_dashboard_info",
"get_dashboard_layout",
"generate_dashboard",

View File

@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get dashboard datasets FastMCP tool
Returns the datasets used by a dashboard's charts, including columns and
metrics. This is the prerequisite context an agent needs before configuring
native filters on a dashboard (e.g. picking filter target columns).
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from sqlalchemy.orm import subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.dashboard.schemas import (
dashboard_datasets_serializer,
DashboardDatasets,
DashboardError,
GetDashboardDatasetsRequest,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
logger = logging.getLogger(__name__)
@tool(
tags=["core"],
class_permission_name="Dashboard",
annotations=ToolAnnotations(
title="Get dashboard datasets",
readOnlyHint=True,
destructiveHint=False,
),
)
async def get_dashboard_datasets(
request: GetDashboardDatasetsRequest, ctx: Context
) -> DashboardDatasets | DashboardError:
"""
List the datasets used by a dashboard's charts, by ID, UUID, or slug.
Each dataset includes its table name, schema, database connection
(id, name, backend), columns (name, type, is_dttm, verbose_name) and
metrics (name, expression, verbose_name). Use this to understand which
columns and metrics are available before configuring native filters or
analyzing a dashboard's data model.
Datasets the current user cannot access are excluded from the response
and reported via inaccessible_dataset_count. Column and metric lists are
capped per dataset; when truncated, columns_truncated/metrics_truncated
are set and total counts are reported.
Example usage:
```json
{
"identifier": 123
}
```
"""
await ctx.info(
"Retrieving dashboard datasets: identifier=%s" % (request.identifier,)
)
try:
from superset.daos.dashboard import DashboardDAO
from superset.models.dashboard import Dashboard
# Eager load slices to avoid N+1 queries when grouping by datasource.
eager_options = [subqueryload(Dashboard.slices)]
with event_logger.log_context(action="mcp.get_dashboard_datasets.lookup"):
core = ModelGetInfoCore(
dao_class=DashboardDAO,
output_schema=DashboardDatasets,
error_schema=DashboardError,
serializer=dashboard_datasets_serializer,
supports_slug=True,
logger=logger,
query_options=eager_options,
)
result = core.run_tool(request.identifier)
if isinstance(result, DashboardDatasets):
await ctx.info(
"Dashboard datasets retrieved: id=%s, dataset_count=%s, "
"inaccessible_dataset_count=%s"
% (
result.id,
result.dataset_count,
result.inaccessible_dataset_count,
)
)
else:
await ctx.warning(
"Dashboard datasets retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result
except Exception as e:
await ctx.error(
"Dashboard datasets retrieval failed: identifier=%s, error=%s, "
"error_type=%s" % (request.identifier, str(e), type(e).__name__)
)
return DashboardError(
error=f"Failed to get dashboard datasets: {str(e)}",
error_type="InternalError",
timestamp=datetime.now(timezone.utc),
)

View File

@@ -17,11 +17,10 @@
import pytest
from pytest_mock import MockerFixture
from superset.commands.chart.exceptions import ChartForbiddenError, ChartInvalidError
from superset.commands.chart.exceptions import ChartForbiddenError
from superset.commands.chart.update import UpdateChartCommand
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.utils import json
def _ownership_exc() -> SupersetSecurityException:
@@ -92,73 +91,3 @@ def test_update_chart_owner_can_perform_regular_update(
find_by_id.assert_called_once_with(1)
raise_for_ownership.assert_called_once()
def _query_context_payload(datasource: object) -> dict[str, object]:
return {
"query_context": json.dumps({"datasource": datasource, "queries": []}),
"query_context_generation": True,
}
def test_update_chart_query_context_matching_datasource_is_allowed(
mocker: MockerFixture,
) -> None:
"""A query context that targets the chart's own datasource is accepted."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
UpdateChartCommand(
1, _query_context_payload({"id": 42, "type": "table"})
).validate()
@pytest.mark.parametrize(
"datasource",
[
{"id": 99, "type": "table"}, # different id
{"id": 42, "type": "query"}, # different type
{"id": "99", "type": "table"}, # different id as string
],
)
def test_update_chart_query_context_mismatched_datasource_is_rejected(
mocker: MockerFixture,
datasource: dict[str, object],
) -> None:
"""A query context pointing at a different datasource is rejected with a 4xx."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
with pytest.raises(ChartInvalidError):
UpdateChartCommand(1, _query_context_payload(datasource)).validate()
@pytest.mark.parametrize(
"query_context",
[
"{}", # no datasource key
'{"datasource": null}', # null datasource
"not-json", # unparseable payload
],
)
def test_update_chart_query_context_without_datasource_is_allowed(
mocker: MockerFixture,
query_context: str,
) -> None:
"""Payloads with no verifiable datasource fall back to the chart's own."""
find_by_id = mocker.patch("superset.commands.chart.update.ChartDAO.find_by_id")
find_by_id.return_value = mocker.MagicMock(
id=1, tags=[], dashboards=[], datasource_id=42, datasource_type="table"
)
mocker.patch("superset.commands.chart.update.security_manager.raise_for_ownership")
UpdateChartCommand(
1,
{"query_context": query_context, "query_context_generation": True},
).validate()

View File

@@ -0,0 +1,355 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the MCP get_dashboard_datasets tool."""
from unittest.mock import Mock, patch
import pytest
from fastmcp import Client
from superset.mcp_service.app import mcp
from superset.mcp_service.utils.sanitization import (
LLM_CONTEXT_CLOSE_DELIMITER,
LLM_CONTEXT_OPEN_DELIMITER,
)
from superset.utils import json
def _wrapped(value: str) -> str:
return f"{LLM_CONTEXT_OPEN_DELIMITER}\n{value}\n{LLM_CONTEXT_CLOSE_DELIMITER}"
def _build_column_mock(
name: str,
*,
verbose_name: str | None = None,
type_: str | None = "VARCHAR",
is_dttm: bool = False,
) -> Mock:
column = Mock()
column.column_name = name
column.verbose_name = verbose_name
column.type = type_
column.is_dttm = is_dttm
return column
def _build_metric_mock(
name: str,
*,
verbose_name: str | None = None,
expression: str | None = None,
) -> Mock:
metric = Mock()
metric.metric_name = name
metric.verbose_name = verbose_name
metric.expression = expression
return metric
def _build_database_mock(
*, database_id: int = 7, name: str = "examples", backend: str = "postgresql"
) -> Mock:
database = Mock()
database.id = database_id
database.database_name = name
database.backend = backend
return database
def _build_datasource_mock(
*,
dataset_id: int,
uuid: str | None = None,
table_name: str = "my_table",
schema: str | None = "public",
database: Mock | None = None,
columns: list[Mock] | None = None,
metrics: list[Mock] | None = None,
) -> Mock:
datasource = Mock()
datasource.id = dataset_id
datasource.uuid = uuid
datasource.table_name = table_name
datasource.schema = schema
datasource.database = database
datasource.columns = columns or []
datasource.metrics = metrics or []
return datasource
def _build_slice_mock(datasource: Mock) -> Mock:
slc = Mock()
slc.datasource_id = datasource.id
slc.datasource = datasource
return slc
def _build_dashboard_mock(
*,
dashboard_id: int = 1,
title: str = "Test Dashboard",
uuid: str | None = "dashboard-uuid-1",
slices: list[Mock] | None = None,
) -> Mock:
dashboard = Mock()
dashboard.id = dashboard_id
dashboard.dashboard_title = title
dashboard.uuid = uuid
dashboard.slices = slices or []
return dashboard
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
@pytest.fixture(autouse=True)
def mock_dataset_access():
with patch(
"superset.mcp_service.auth.has_dataset_access", return_value=True
) as mock_access:
yield mock_access
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_multiple_datasets(mock_find, mcp_server):
sales = _build_datasource_mock(
dataset_id=10,
uuid="dataset-uuid-10",
table_name="sales",
schema="public",
database=_build_database_mock(),
columns=[
_build_column_mock("region", verbose_name="Region"),
_build_column_mock("order_date", type_="TIMESTAMP", is_dttm=True),
],
metrics=[
_build_metric_mock(
"total_revenue",
verbose_name="Total Revenue",
expression="SUM(revenue)",
)
],
)
customers = _build_datasource_mock(
dataset_id=20,
uuid="dataset-uuid-20",
table_name="customers",
schema="crm",
database=_build_database_mock(database_id=8, name="crm_db", backend="mysql"),
columns=[_build_column_mock("customer_name")],
metrics=[],
)
mock_find.return_value = _build_dashboard_mock(
slices=[
_build_slice_mock(sales),
_build_slice_mock(sales),
_build_slice_mock(customers),
]
)
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["dashboard_title"] == _wrapped("Test Dashboard")
assert data["uuid"] == "dashboard-uuid-1"
assert data["dataset_count"] == 2
assert data["inaccessible_dataset_count"] == 0
assert len(data["datasets"]) == 2
datasets_by_id = {d["id"]: d for d in data["datasets"]}
sales_data = datasets_by_id[10]
assert sales_data["uuid"] == "dataset-uuid-10"
assert sales_data["table_name"] == "sales"
assert sales_data["schema"] == "public"
assert sales_data["database"] == {
"id": 7,
"name": "examples",
"backend": "postgresql",
}
assert sales_data["chart_count"] == 2
assert sales_data["columns"] == [
{
"column_name": "region",
"verbose_name": _wrapped("Region"),
"type": "VARCHAR",
"is_dttm": False,
},
{
"column_name": "order_date",
"verbose_name": None,
"type": "TIMESTAMP",
"is_dttm": True,
},
]
assert sales_data["metrics"] == [
{
"metric_name": "total_revenue",
"verbose_name": _wrapped("Total Revenue"),
"expression": _wrapped("SUM(revenue)"),
}
]
assert sales_data["total_column_count"] == 2
assert sales_data["total_metric_count"] == 1
assert sales_data["columns_truncated"] is False
assert sales_data["metrics_truncated"] is False
customers_data = datasets_by_id[20]
assert customers_data["table_name"] == "customers"
assert customers_data["schema"] == "crm"
assert customers_data["chart_count"] == 1
assert customers_data["metrics"] == []
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_by_slug(mock_find, mcp_server):
datasource = _build_datasource_mock(
dataset_id=10,
table_name="sales",
database=_build_database_mock(),
columns=[_build_column_mock("region")],
)
dashboard = _build_dashboard_mock(slices=[_build_slice_mock(datasource)])
def find_by_id(identifier, id_column=None, query_options=None):
if id_column == "slug" and identifier == "sales-dash":
return dashboard
return None
mock_find.side_effect = find_by_id
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": "sales-dash"}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["dataset_count"] == 1
assert data["datasets"][0]["table_name"] == "sales"
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_not_found(mock_find, mcp_server):
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": 999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_empty_dashboard(mock_find, mcp_server):
mock_find.return_value = _build_dashboard_mock(slices=[])
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["dataset_count"] == 0
assert data["inaccessible_dataset_count"] == 0
assert data["datasets"] == []
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_excludes_inaccessible(
mock_find, mcp_server, mock_dataset_access
):
allowed = _build_datasource_mock(dataset_id=10, table_name="sales")
denied = _build_datasource_mock(dataset_id=20, table_name="secrets")
mock_find.return_value = _build_dashboard_mock(
slices=[_build_slice_mock(allowed), _build_slice_mock(denied)]
)
mock_dataset_access.side_effect = lambda datasource: datasource.id != 20
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
assert data["dataset_count"] == 1
assert data["inaccessible_dataset_count"] == 1
assert [d["id"] for d in data["datasets"]] == [10]
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_datasets_truncates_wide_datasets(mock_find, mcp_server):
from superset.mcp_service.dashboard.schemas import (
MAX_DASHBOARD_DATASET_COLUMNS,
MAX_DASHBOARD_DATASET_METRICS,
)
datasource = _build_datasource_mock(
dataset_id=10,
table_name="wide_table",
columns=[
_build_column_mock(f"col_{i}")
for i in range(MAX_DASHBOARD_DATASET_COLUMNS + 5)
],
metrics=[
_build_metric_mock(f"metric_{i}")
for i in range(MAX_DASHBOARD_DATASET_METRICS + 3)
],
)
mock_find.return_value = _build_dashboard_mock(
slices=[_build_slice_mock(datasource)]
)
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_dashboard_datasets", {"request": {"identifier": 1}}
)
data = json.loads(result.content[0].text)
dataset = data["datasets"][0]
assert len(dataset["columns"]) == MAX_DASHBOARD_DATASET_COLUMNS
assert len(dataset["metrics"]) == MAX_DASHBOARD_DATASET_METRICS
assert dataset["columns_truncated"] is True
assert dataset["metrics_truncated"] is True
assert dataset["total_column_count"] == MAX_DASHBOARD_DATASET_COLUMNS + 5
assert dataset["total_metric_count"] == MAX_DASHBOARD_DATASET_METRICS + 3