feat(mcp): add query_dataset tool to query datasets using semantic layer (#39727)

Amin Ghadersohi
2026-04-30 18:03:41 -04:00
committed by GitHub
parent 3f550f166f
commit f29d82b3b1
5 changed files with 1478 additions and 0 deletions


@@ -62,6 +62,7 @@ Dataset Management:
- list_datasets: List datasets with advanced filters (1-based pagination)
- get_dataset_info: Get detailed dataset information by ID (includes columns/metrics)
- create_virtual_dataset: Save a SQL query as a virtual dataset for charting
- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart
Chart Management:
- list_charts: List charts with advanced filters (1-based pagination)
@@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
together for the union. All flags can be combined with 'filters' but not
with 'search'.
To query a dataset's semantic layer (metrics, dimensions):
1. list_datasets(request={{}}) -> find a dataset
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
3. query_dataset(request={{
"dataset_id": <id>,
"metrics": ["count", "avg_revenue"],
"columns": ["category"],
"time_range": "Last 7 days",
"row_limit": 100
}}) -> returns tabular data using saved metrics and dimensions
To explore data with SQL:
1. list_datasets(request={{}}) -> find a dataset and note its database_id
2. execute_sql(request={{"database_id": <id>, "sql": "SELECT ..."}})
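A minimal sketch of the semantic-layer workflow from an MCP client, assuming the in-process fastmcp Client wiring used in the tests below (the dataset ID and field values are illustrative):

```python
import asyncio

from fastmcp import Client

from superset.mcp_service.app import mcp  # in-process server, as in the tests


async def main() -> None:
    async with Client(mcp) as client:
        # Steps 1-2 would use list_datasets / get_dataset_info to discover
        # the dataset ID and its saved metrics and columns.
        result = await client.call_tool(
            "query_dataset",
            {
                "request": {
                    "dataset_id": 1,  # illustrative
                    "metrics": ["count"],
                    "columns": ["category"],
                    "time_range": "Last 7 days",
                    "row_limit": 100,
                }
            },
        )
        print(result.content[0].text)  # JSON-serialized QueryDatasetResponse


asyncio.run(main())
```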
@@ -520,6 +532,7 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
create_virtual_dataset,
get_dataset_info,
list_datasets,
query_dataset,
)
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,


@@ -36,10 +36,13 @@ from pydantic import (
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.common.cache_schemas import (
CacheStatus,
CreatedByMeMixin,
MetadataCacheControl,
OwnedByMeMixin,
QueryCacheControl,
)
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
@@ -393,6 +396,146 @@ class CreateVirtualDatasetResponse(BaseModel):
)
VALID_FILTER_OPS = Literal[
"==",
"!=",
">",
"<",
">=",
"<=",
"LIKE",
"NOT LIKE",
"ILIKE",
"NOT ILIKE",
"IN",
"NOT IN",
"IS NULL",
"IS NOT NULL",
"IS TRUE",
"IS FALSE",
"TEMPORAL_RANGE",
]
class QueryDatasetFilter(BaseModel):
"""A single filter condition for dataset queries."""
col: str = Field(..., description="Column name to filter on")
op: VALID_FILTER_OPS = Field(
...,
description=(
'Filter operator. Use "==" for equals, "!=" for not equals, '
'"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
'"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
),
)
val: Any = Field(
default=None,
description="Filter value (omit for IS NULL/IS NOT NULL)",
)
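Illustrative filter payloads covering a few of the operators above (column names and values are made up):

```python
# Hypothetical filter payloads; note that "val" is omitted for null checks.
filters = [
    {"col": "status", "op": "==", "val": "active"},
    {"col": "region", "op": "IN", "val": ["US", "EU"]},
    {"col": "deleted_at", "op": "IS NULL"},
    {"col": "order_date", "op": "TEMPORAL_RANGE", "val": "Last 30 days"},
]
```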
class QueryDatasetRequest(QueryCacheControl):
"""Request schema for query_dataset tool."""
dataset_id: int | str = Field(
...,
description="Dataset identifier — numeric ID or UUID string.",
)
metrics: List[str] = Field(
default_factory=list,
description=(
"Saved metric names to compute (e.g. ['count', 'avg_revenue']). "
"Use get_dataset_info to discover available metrics."
),
)
columns: List[str] = Field(
default_factory=list,
description=(
"Column/dimension names for GROUP BY or SELECT "
"(e.g. ['category', 'region']). "
"Use get_dataset_info to discover available columns."
),
)
filters: List[QueryDatasetFilter] = Field(
default_factory=list,
description=(
'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).'
),
)
time_range: str | None = Field(
default=None,
description=(
"Time range filter (e.g. 'Last 7 days', 'Last month', "
"'2024-01-01 : 2024-12-31'). Requires a temporal column "
"on the dataset."
),
)
time_column: str | None = Field(
default=None,
description=(
"Temporal column to apply time_range to. "
"Defaults to the dataset's main datetime column."
),
)
order_by: List[str] | None = Field(
default=None,
description="Column or metric names to sort results by.",
)
order_desc: bool = Field(
default=True,
description="Sort descending (True) or ascending (False).",
)
row_limit: int = Field(
default=1000,
ge=1,
le=50000,
description="Maximum number of rows to return (default 1000, max 50000).",
)
@model_validator(mode="after")
def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
"""At least one of metrics or columns must be provided."""
if not self.metrics and not self.columns:
raise ValueError(
"At least one of 'metrics' or 'columns' must be provided. "
"Use get_dataset_info to discover available metrics and columns."
)
return self
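A quick sketch of the validator's behavior, assuming pydantic v2 semantics (where ValidationError subclasses ValueError):

```python
import pytest

from superset.mcp_service.dataset.schemas import QueryDatasetRequest

# Neither metrics nor columns: rejected by the model validator above.
with pytest.raises(ValueError, match="metrics.*columns"):
    QueryDatasetRequest(dataset_id=1, metrics=[], columns=[])

# At least one of the two is enough; row_limit falls back to its default.
req = QueryDatasetRequest(dataset_id=1, metrics=["count"])
assert req.row_limit == 1000
```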
class QueryDatasetResponse(BaseModel):
"""Response schema for query_dataset tool."""
model_config = ConfigDict(ser_json_timedelta="iso8601")
dataset_id: int = Field(..., description="Dataset ID")
dataset_name: str = Field(..., description="Dataset name")
columns: List[DataColumn] = Field(
default_factory=list, description="Column metadata for returned data"
)
data: List[Dict[str, Any]] = Field(
default_factory=list, description="Query result rows"
)
row_count: int = Field(0, description="Number of rows returned")
total_rows: int | None = Field(
None, description="Total row count from the query engine"
)
summary: str = Field("", description="Human-readable summary of the results")
performance: PerformanceMetadata | None = Field(
None, description="Query performance metadata"
)
cache_status: CacheStatus | None = Field(
None, description="Cache hit/miss information"
)
applied_filters: List[QueryDatasetFilter] = Field(
default_factory=list, description="Filters that were applied to the query"
)
warnings: List[str] = Field(
default_factory=list, description="Any warnings encountered during execution"
)
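For orientation, a hypothetical serialized response in the shape defined above (values invented, abridged):

```python
example_response = {
    "dataset_id": 1,
    "dataset_name": "orders",
    "columns": [
        {"name": "category", "data_type": "string", "null_count": 0, "unique_count": 2}
    ],
    "data": [
        {"category": "Electronics", "count": 42},
        {"category": "Clothing", "count": 17},
    ],
    "row_count": 2,
    "total_rows": 2,
    "summary": "Dataset 'orders': 2 rows, 2 columns (fresh).",
    "applied_filters": [],
    "warnings": [],
}
```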
def _parse_json_field(obj: Any, field_name: str) -> Dict[str, Any] | None:
"""Parse a field that may be stored as a JSON string into a dict."""
value = getattr(obj, field_name, None)


@@ -18,9 +18,11 @@
from .create_virtual_dataset import create_virtual_dataset
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets
from .query_dataset import query_dataset
__all__ = [
"create_virtual_dataset",
"list_datasets",
"get_dataset_info",
"query_dataset",
]


@@ -0,0 +1,489 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
MCP tool: query_dataset
Query a dataset using its semantic layer (saved metrics, calculated columns,
dimensions) without requiring a saved chart.
"""
import difflib
import logging
import time
from typing import Any
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import joinedload, subqueryload
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.dataset.schemas import (
DatasetError,
QueryDatasetFilter,
QueryDatasetRequest,
QueryDatasetResponse,
)
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.utils import _is_uuid
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
logger = logging.getLogger(__name__)
def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None:
"""Resolve a dataset by int ID or UUID string.
Replicates the identifier resolution logic from ModelGetInfoCore._find_object().
"""
from superset.daos.dataset import DatasetDAO
opts = eager_options or None
if isinstance(identifier, int):
return DatasetDAO.find_by_id(identifier, query_options=opts)
# Try parsing as int
try:
id_val = int(identifier)
return DatasetDAO.find_by_id(id_val, query_options=opts)
except (ValueError, TypeError):
pass
# Try UUID
if _is_uuid(str(identifier)):
return DatasetDAO.find_by_id(identifier, id_column="uuid", query_options=opts)
return None
def _validate_names(
requested: list[str],
valid: set[str],
kind: str,
) -> list[str]:
"""Return list of error messages for names not found in *valid*.
Includes close-match suggestions when available.
"""
errors: list[str] = []
for name in requested:
if name not in valid:
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
msg = f"Unknown {kind}: '{name}'"
if suggestions:
msg += f". Did you mean: {', '.join(suggestions)}?"
errors.append(msg)
return errors
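The suggestion behavior comes straight from the standard library; for example:

```python
import difflib

valid_metrics = {"count", "total_revenue"}
print(difflib.get_close_matches("countt", valid_metrics, n=3, cutoff=0.6))
# ['count'] -> error reads: Unknown metric: 'countt'. Did you mean: count?
```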
@requires_data_model_metadata_access
@tool(
tags=["data"],
class_permission_name="Dataset",
annotations=ToolAnnotations(
title="Query dataset",
readOnlyHint=True,
destructiveHint=False,
),
)
async def query_dataset( # noqa: C901
request: QueryDatasetRequest, ctx: Context
) -> QueryDatasetResponse | DatasetError:
"""Query a dataset using its semantic layer (saved metrics, dimensions, filters).
Returns tabular data without requiring a saved chart. Use this when you want
to compute saved metrics, group by dimensions, or apply filters directly
against a dataset's curated semantic layer.
Workflow:
1. list_datasets -> find a dataset
2. get_dataset_info -> discover available columns and metrics
3. query_dataset -> query using metric names and column names
Example:
```json
{
"dataset_id": 123,
"metrics": ["count", "avg_revenue"],
"columns": ["product_category"],
"time_range": "Last 7 days",
"row_limit": 100
}
```
"""
await ctx.info(
"Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, "
"row_limit=%s"
% (
request.dataset_id,
request.metrics,
request.columns,
request.row_limit,
)
)
try:
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.common.query_context_factory import QueryContextFactory
from superset.connectors.sqla.models import SqlaTable
# ------------------------------------------------------------------
# Step 1: Check data-model metadata access BEFORE the dataset lookup.
# Doing this first prevents leaking dataset existence — restricted
# users always receive DataModelMetadataRestricted, never NotFound.
# The decorator hides this tool from discovery for restricted users;
# this explicit check also covers direct calls that bypass discovery.
# ------------------------------------------------------------------
if not user_can_view_data_model_metadata():
await ctx.warning("Dataset metadata access blocked by privacy controls")
return DatasetError.create(
error=(
"You don't have permission to access dataset details for your role."
),
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
)
# ------------------------------------------------------------------
# Step 2: Resolve dataset
# ------------------------------------------------------------------
await ctx.report_progress(1, 5, "Looking up dataset")
eager_options = [
subqueryload(SqlaTable.columns),
subqueryload(SqlaTable.metrics),
joinedload(SqlaTable.database),
]
with event_logger.log_context(action="mcp.query_dataset.lookup"):
dataset = _resolve_dataset(request.dataset_id, eager_options)
if dataset is None:
await ctx.error("Dataset not found: identifier=%s" % (request.dataset_id,))
return DatasetError.create(
error=f"No dataset found with identifier: {request.dataset_id}",
error_type="NotFound",
)
dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
await ctx.info(
"Dataset found: id=%s, name=%s, columns=%s, metrics=%s"
% (
dataset.id,
dataset_name,
len(dataset.columns),
len(dataset.metrics),
)
)
# ------------------------------------------------------------------
# Step 3: Validate requested columns and metrics
# ------------------------------------------------------------------
await ctx.report_progress(2, 5, "Validating columns and metrics")
valid_columns = {c.column_name for c in dataset.columns}
valid_metrics = {m.metric_name for m in dataset.metrics}
validation_errors: list[str] = []
validation_errors.extend(
_validate_names(request.columns, valid_columns, "column")
)
validation_errors.extend(
_validate_names(request.metrics, valid_metrics, "metric")
)
# Validate filter column names against dataset columns
filter_cols = [f.col for f in request.filters]
validation_errors.extend(
_validate_names(filter_cols, valid_columns, "filter column")
)
# Validate order_by names against columns + metrics
if request.order_by:
valid_orderby = valid_columns | valid_metrics
validation_errors.extend(
_validate_names(request.order_by, valid_orderby, "order_by")
)
if validation_errors:
error_msg = "; ".join(validation_errors)
await ctx.error("Validation failed: %s" % (error_msg,))
return DatasetError.create(
error=error_msg,
error_type="ValidationError",
)
# ------------------------------------------------------------------
# Step 4: Build filters and time range
# ------------------------------------------------------------------
warnings: list[str] = []
query_filters: list[dict[str, Any]] = [
{"col": f.col, "op": f.op, "val": f.val} for f in request.filters
]
# Track all applied filters (including synthesized ones) for the response.
effective_filters: list[QueryDatasetFilter] = list(request.filters)
granularity: str | None = None
if request.time_range:
temporal_col = request.time_column or getattr(
dataset, "main_dttm_col", None
)
if not temporal_col:
await ctx.error("time_range provided but no temporal column available")
return DatasetError.create(
error=(
"time_range was provided but no temporal column is available. "
"Either set time_column explicitly or ensure the dataset has "
"a main datetime column configured."
),
error_type="ValidationError",
)
# Validate that the temporal column actually exists on the dataset
if temporal_col not in valid_columns:
await ctx.error("time_column '%s' not found on dataset" % temporal_col)
return DatasetError.create(
error=(
f"time_column '{temporal_col}' does not exist on this dataset."
),
error_type="ValidationError",
)
# Warn if the chosen temporal column isn't marked as datetime
dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
if temporal_col not in dttm_cols:
warnings.append(
f"Column '{temporal_col}' is not marked as a datetime "
f"column on this dataset. Time filtering may not work "
f"as expected."
)
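# e.g. time_range="Last 7 days" with temporal_col="order_date" synthesizes
# {"col": "order_date", "op": "TEMPORAL_RANGE", "val": "Last 7 days"}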
query_filters.append(
{
"col": temporal_col,
"op": "TEMPORAL_RANGE",
"val": request.time_range,
}
)
effective_filters.append(
QueryDatasetFilter(
col=temporal_col,
op="TEMPORAL_RANGE",
val=request.time_range,
)
)
granularity = temporal_col
await ctx.debug(
"Time filter: column=%s, range=%s" % (temporal_col, request.time_range)
)
# ------------------------------------------------------------------
# Step 5: Build query dict
# ------------------------------------------------------------------
await ctx.report_progress(3, 5, "Building query")
query_dict: dict[str, Any] = {
"filters": query_filters,
"columns": request.columns,
"metrics": request.metrics,
"row_limit": request.row_limit,
"order_desc": request.order_desc,
}
if granularity:
query_dict["granularity"] = granularity
if request.order_by:
# OrderBy = tuple[Metric | Column, bool] where bool is ascending
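# e.g. order_by=["count"], order_desc=True -> [("count", False)] (count DESC)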
query_dict["orderby"] = [
(col, not request.order_desc) for col in request.order_by
]
await ctx.debug("Query dict keys: %s" % (sorted(query_dict.keys()),))
# ------------------------------------------------------------------
# Step 6: Create QueryContext and execute
# ------------------------------------------------------------------
await ctx.report_progress(4, 5, "Executing query")
start_time = time.time()
with event_logger.log_context(action="mcp.query_dataset.execute"):
factory = QueryContextFactory()
# datasource_type is "table" because this tool queries SqlaTable
# datasets (Superset's built-in semantic layer). External semantic
# layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
# a different query path — see SemanticView + mapper.py.
query_context = factory.create(
datasource={"id": dataset.id, "type": "table"},
queries=[query_dict],
form_data={},
force=not request.use_cache or request.force_refresh,
custom_cache_timeout=request.cache_timeout,
)
command = ChartDataCommand(query_context)
command.validate()
result = command.run()
query_duration_ms = int((time.time() - start_time) * 1000)
if not result or "queries" not in result or len(result["queries"]) == 0:
await ctx.warning("Query returned no results for dataset %s" % dataset.id)
return DatasetError.create(
error="Query returned no results.",
error_type="EmptyQuery",
)
# ------------------------------------------------------------------
# Step 7: Format response
# ------------------------------------------------------------------
await ctx.report_progress(5, 5, "Formatting results")
query_result = result["queries"][0]
data = query_result.get("data", [])
raw_columns = query_result.get("colnames", [])
if not data:
return QueryDatasetResponse(
dataset_id=dataset.id,
dataset_name=dataset_name,
columns=[],
data=[],
row_count=0,
total_rows=0,
summary=f"Query on '{dataset_name}' returned no data.",
performance=PerformanceMetadata(
query_duration_ms=query_duration_ms,
cache_status="no_data",
),
cache_status=get_cache_status_from_result(
query_result, force_refresh=request.force_refresh
),
applied_filters=effective_filters,
warnings=warnings,
)
# Build column metadata in a single pass per column.
# Cap stats computation at stats_sample_size rows to avoid O(rows*cols)
# overhead on large result sets (row_limit allows up to 50k).
stats_sample_size = 5000
stats_rows = data[:stats_sample_size]
columns_meta: list[DataColumn] = []
for col_name in raw_columns:
sample_values = [
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
]
data_type = "string"
if sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
# Compute null_count and unique non-null values in one pass
null_count = 0
unique_vals: set[str] = set()
for row in stats_rows:
val = row.get(col_name)
if val is None:
null_count += 1
else:
unique_vals.add(str(val))
columns_meta.append(
DataColumn(
name=col_name,
display_name=col_name.replace("_", " ").title(),
data_type=data_type,
sample_values=sample_values[:3],
null_count=null_count,
unique_count=len(unique_vals),
)
)
cache_status = get_cache_status_from_result(
query_result, force_refresh=request.force_refresh
)
cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
summary = (
f"Dataset '{dataset_name}': {len(data)} rows, "
f"{len(raw_columns)} columns ({cache_label})."
)
await ctx.info(
"Query complete: rows=%s, columns=%s, duration=%sms"
% (len(data), len(raw_columns), query_duration_ms)
)
return QueryDatasetResponse(
dataset_id=dataset.id,
dataset_name=dataset_name,
columns=columns_meta,
data=data,
row_count=len(data),
total_rows=query_result.get("rowcount"),
summary=summary,
performance=PerformanceMetadata(
query_duration_ms=query_duration_ms,
cache_status=cache_label,
),
cache_status=cache_status,
applied_filters=effective_filters,
warnings=warnings,
)
except OAuth2RedirectError as exc:
redirect_msg = build_oauth2_redirect_message(exc)
await ctx.error("OAuth2 redirect required: %s" % (redirect_msg,))
return DatasetError.create(
error=redirect_msg,
error_type="OAuth2Redirect",
)
except OAuth2Error as exc:
await ctx.error("OAuth2 error: %s" % (str(exc),))
return DatasetError.create(
error=f"OAuth2 authentication error: {exc}",
error_type="OAuth2Error",
)
except (CommandException, SupersetException) as exc:
await ctx.error("Query failed: %s" % (str(exc),))
return DatasetError.create(
error=f"Query execution failed: {exc}",
error_type="QueryError",
)
except SQLAlchemyError as exc:
await ctx.error("Database error: %s" % (str(exc),))
return DatasetError.create(
error=f"Database error: {exc}",
error_type="DatabaseError",
)
except Exception as exc:
logger.exception(
"Unexpected error while querying dataset: %s: %s",
type(exc).__name__,
str(exc),
)
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
return DatasetError.create(
error="An unexpected error occurred while querying the dataset.",
error_type="UnexpectedError",
)


@@ -0,0 +1,831 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the query_dataset MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from typing import Any
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.mcp_service.app import mcp
from superset.utils import json
query_dataset_module = importlib.import_module(
"superset.mcp_service.dataset.tool.query_dataset"
)
@pytest.fixture
def mcp_server() -> FastMCP:
return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
"""Mock authentication and metadata access for all tests."""
with (
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
patch.object(
query_dataset_module,
"user_can_view_data_model_metadata",
return_value=True,
),
):
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
"""Build a mock SqlaTable column with the given name and datetime flag."""
col = MagicMock()
col.column_name = name
col.is_dttm = is_dttm
col.verbose_name = None
col.type = "VARCHAR"
col.groupby = True
col.filterable = True
col.description = None
return col
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
"""Build a mock SqlMetric with the given name and SQL expression."""
metric = MagicMock()
metric.metric_name = name
metric.verbose_name = None
metric.expression = expression
metric.description = None
metric.d3format = None
return metric
def _make_dataset(
dataset_id: int = 1,
table_name: str = "orders",
columns: list[Any] | None = None,
metrics: list[Any] | None = None,
main_dttm_col: str | None = None,
) -> MagicMock:
"""Build a mock SqlaTable dataset with default columns and metrics."""
ds = MagicMock()
ds.id = dataset_id
ds.table_name = table_name
ds.uuid = f"test-uuid-{dataset_id}"
ds.main_dttm_col = main_dttm_col
ds.database = MagicMock()
ds.database.database_name = "examples"
ds.columns = columns or [
_make_column("category"),
_make_column("region"),
_make_column("order_date", is_dttm=True),
]
ds.metrics = metrics or [
_make_metric("count", "COUNT(*)"),
_make_metric("total_revenue", "SUM(revenue)"),
]
return ds
def _mock_command_result(
data: list[dict[str, Any]] | None = None,
colnames: list[str] | None = None,
) -> dict[str, Any]:
"""Build the result dict that ChartDataCommand.run() returns."""
data = data or [
{"category": "Electronics", "count": 42},
{"category": "Clothing", "count": 17},
]
colnames = colnames or ["category", "count"]
return {
"queries": [
{
"data": data,
"colnames": colnames,
"rowcount": len(data),
"cache_key": "abc123",
"is_cached": False,
"cached_dttm": None,
"cache_timeout": 300,
}
]
}
@pytest.mark.asyncio
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
"""Happy path: metrics + columns returns data."""
dataset = _make_dataset()
result_data = _mock_command_result()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
return_value=MagicMock(),
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"columns": ["category"],
}
},
)
data = json.loads(result.content[0].text)
assert data["dataset_id"] == 1
assert data["dataset_name"] == "orders"
assert data["row_count"] == 2
assert len(data["data"]) == 2
assert data["data"][0]["category"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
"""Dataset ID that doesn't exist returns error."""
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=None,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 999,
"metrics": ["count"],
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "NotFound"
assert "999" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
"""Unknown metric name returns validation error with suggestions."""
dataset = _make_dataset()
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["countt"], # typo
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "ValidationError"
assert "countt" in data["error"]
# Should suggest "count" as a close match, not just echo the typo
assert "Did you mean" in data["error"]
assert "count" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
"""Unknown column name returns validation error."""
dataset = _make_dataset()
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"columns": ["nonexistent_col"],
"metrics": ["count"],
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "ValidationError"
assert "nonexistent_col" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
"""Providing neither metrics nor columns raises validation error."""
from fastmcp.exceptions import ToolError
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="metrics.*columns"):
await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": [],
"columns": [],
}
},
)
@pytest.mark.asyncio
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
"""time_range is converted to TEMPORAL_RANGE filter + granularity."""
dataset = _make_dataset(main_dttm_col="order_date")
result_data = _mock_command_result()
captured_queries: list[dict[str, Any]] = []
def capture_create(**kwargs):
captured_queries.extend(kwargs.get("queries", []))
return MagicMock()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
side_effect=capture_create,
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"time_range": "Last 7 days",
}
},
)
assert len(captured_queries) == 1
query_dict = captured_queries[0]
# Should have TEMPORAL_RANGE filter
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
assert len(temporal_filters) == 1
assert temporal_filters[0]["col"] == "order_date"
assert temporal_filters[0]["val"] == "Last 7 days"
# Should set granularity
assert query_dict["granularity"] == "order_date"
# applied_filters in response must include the synthesized TEMPORAL_RANGE filter
data = json.loads(result.content[0].text)
resp_filters = data["applied_filters"]
temporal_resp = [f for f in resp_filters if f["op"] == "TEMPORAL_RANGE"]
assert len(temporal_resp) == 1
assert temporal_resp[0]["col"] == "order_date"
assert temporal_resp[0]["val"] == "Last 7 days"
@pytest.mark.asyncio
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
"""time_range without a temporal column returns error."""
dataset = _make_dataset(main_dttm_col=None)
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"time_range": "Last 7 days",
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "ValidationError"
assert "temporal column" in data["error"].lower()
@pytest.mark.asyncio
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
"""User-provided filters are passed through to the query."""
dataset = _make_dataset()
result_data = _mock_command_result()
captured_queries: list[dict[str, Any]] = []
def capture_create(**kwargs):
captured_queries.extend(kwargs.get("queries", []))
return MagicMock()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
side_effect=capture_create,
),
):
async with Client(mcp_server) as client:
await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"filters": [
{"col": "category", "op": "==", "val": "Electronics"}
],
}
},
)
assert len(captured_queries) == 1
filters = captured_queries[0]["filters"]
assert len(filters) == 1
assert filters[0]["col"] == "category"
assert filters[0]["op"] == "=="
assert filters[0]["val"] == "Electronics"
@pytest.mark.asyncio
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
"""Query that returns no data gives a response with row_count=0."""
dataset = _make_dataset()
empty_result = {
"queries": [
{
"data": [],
"colnames": [],
"rowcount": 0,
"is_cached": False,
"cached_dttm": None,
"cache_timeout": 300,
}
]
}
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=empty_result,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
return_value=MagicMock(),
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
}
},
)
data = json.loads(result.content[0].text)
assert data["row_count"] == 0
assert data["data"] == []
assert "no data" in data["summary"].lower()
@pytest.mark.asyncio
async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None:
"""UUID-based lookup works."""
dataset = _make_dataset()
result_data = _mock_command_result()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
) as mock_resolve,
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
return_value=MagicMock(),
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab",
"metrics": ["count"],
}
},
)
# Verify the resolve function was called with the UUID
mock_resolve.assert_called_once()
call_args = mock_resolve.call_args
assert call_args[0][0] == "a1b2c3d4-5678-90ab-cdef-1234567890ab"
data = json.loads(result.content[0].text)
assert data["dataset_id"] == 1
@pytest.mark.asyncio
async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None:
"""Permission denied from ChartDataCommand.validate() returns error."""
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
dataset = _make_dataset()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
return_value=MagicMock(),
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
side_effect=SupersetSecurityException(
SupersetError(
message="Access denied",
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
level=ErrorLevel.WARNING,
)
),
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "QueryError"
@pytest.mark.asyncio
async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None:
"""order_by with valid column/metric names passes through."""
dataset = _make_dataset()
result_data = _mock_command_result()
captured_queries: list[dict[str, Any]] = []
def capture_create(**kwargs):
captured_queries.extend(kwargs.get("queries", []))
return MagicMock()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
side_effect=capture_create,
),
):
async with Client(mcp_server) as client:
await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"columns": ["category"],
"order_by": ["count"],
"order_desc": True,
}
},
)
assert len(captured_queries) == 1
orderby = captured_queries[0].get("orderby", [])
assert len(orderby) == 1
assert orderby[0][0] == "count"
# order_desc=True -> ascending=False
assert orderby[0][1] is False
@pytest.mark.asyncio
async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None:
"""order_by with an unknown name returns validation error."""
dataset = _make_dataset()
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"order_by": ["nonexistent"],
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "ValidationError"
assert "nonexistent" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None:
"""Explicit time_column overrides dataset main_dttm_col."""
dataset = _make_dataset(main_dttm_col="order_date")
result_data = _mock_command_result()
captured_queries: list[dict[str, Any]] = []
def capture_create(**kwargs):
captured_queries.extend(kwargs.get("queries", []))
return MagicMock()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
side_effect=capture_create,
),
):
async with Client(mcp_server) as client:
await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"time_range": "Last 30 days",
"time_column": "order_date",
}
},
)
assert len(captured_queries) == 1
query_dict = captured_queries[0]
assert query_dict["granularity"] == "order_date"
temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
assert temporal_filters[0]["col"] == "order_date"
@pytest.mark.asyncio
async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None:
"""Using a non-datetime column for time_range produces a warning."""
dataset = _make_dataset(main_dttm_col=None)
result_data = _mock_command_result()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
),
patch(
"superset.commands.chart.data.get_data_command.ChartDataCommand.run",
return_value=result_data,
),
patch(
"superset.common.query_context_factory.QueryContextFactory.create",
return_value=MagicMock(),
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"time_range": "Last 7 days",
"time_column": "category",
}
},
)
data = json.loads(result.content[0].text)
assert len(data["warnings"]) > 0
assert "not marked as a datetime" in data["warnings"][0]
@pytest.mark.asyncio
async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None:
"""Filter on a column that doesn't exist returns validation error."""
dataset = _make_dataset()
with patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
"metrics": ["count"],
"filters": [
{
"col": "nonexistent",
"op": "==",
"val": "test",
}
],
}
},
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "ValidationError"
assert "nonexistent" in data["error"]
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_no_suggestions(
mcp_server: FastMCP,
) -> None:
"""Users without data-model metadata access cannot probe column/metric names.
The privacy gate must fire before the validation step that returns close-match
suggestions, so restricted users cannot enumerate schema details via typos.
"""
dataset = _make_dataset()
with (
patch.object(
query_dataset_module,
"_resolve_dataset",
return_value=dataset,
),
patch.object(
query_dataset_module,
"user_can_view_data_model_metadata",
return_value=False,
),
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
"dataset_id": 1,
# Typo that would normally trigger close-match suggestions
"metrics": ["countt"],
}
},
)
data = json.loads(result.content[0].text)
# Must be denied before returning any schema suggestions
assert data["error_type"] == "DataModelMetadataRestricted"
# Must NOT contain column/metric name suggestions
assert "countt" not in data.get("error", "")
assert "count" not in data.get("error", "")
@pytest.mark.asyncio
async def test_query_dataset_metadata_access_denied_nonexistent_dataset(
mcp_server: FastMCP,
) -> None:
"""Metadata-restricted users must not be able to probe dataset existence.
The privacy gate fires before the DAO lookup, so a restricted caller
always receives DataModelMetadataRestricted — never NotFound — regardless
of whether the requested dataset ID exists.
"""
with patch.object(
query_dataset_module,
"user_can_view_data_model_metadata",
return_value=False,
):
async with Client(mcp_server) as client:
result = await client.call_tool(
"query_dataset",
{
"request": {
# Use a dataset_id that does not exist
"dataset_id": 999999,
"metrics": ["count"],
}
},
)
data = json.loads(result.content[0].text)
# Must receive restricted error, not a NotFound that leaks existence
assert data["error_type"] == "DataModelMetadataRestricted"
assert data["error_type"] != "NotFound"