diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index aad00a047ca..205720c3c8d 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -62,6 +62,7 @@ Dataset Management: - list_datasets: List datasets with advanced filters (1-based pagination) - get_dataset_info: Get detailed dataset information by ID (includes columns/metrics) - create_virtual_dataset: Save a SQL query as a virtual dataset for charting +- query_dataset: Query a dataset using its semantic layer (saved metrics, dimensions, filters) without needing a saved chart Chart Management: - list_charts: List charts with advanced filters (1-based pagination) @@ -164,6 +165,17 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both together for the union. All flags can be combined with 'filters' but not with 'search'. +To query a dataset's semantic layer (metrics, dimensions): +1. list_datasets(request={{}}) -> find a dataset +2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics +3. query_dataset(request={{ + "dataset_id": <id>, + "metrics": ["count", "avg_revenue"], + "columns": ["category"], + "time_range": "Last 7 days", + "row_limit": 100 + }}) -> returns tabular data using saved metrics and dimensions + To explore data with SQL: 1. list_datasets(request={{}}) -> find a dataset and note its database_id 2. 
# Closed set of operators a query_dataset filter may use; pydantic rejects
# any other string during request validation.
VALID_FILTER_OPS = Literal[
    "==",
    "!=",
    ">",
    "<",
    ">=",
    "<=",
    "LIKE",
    "NOT LIKE",
    "ILIKE",
    "NOT ILIKE",
    "IN",
    "NOT IN",
    "IS NULL",
    "IS NOT NULL",
    "IS TRUE",
    "IS FALSE",
    "TEMPORAL_RANGE",
]


class QueryDatasetFilter(BaseModel):
    """A single filter condition for dataset queries."""

    # Column the condition applies to.
    col: str = Field(..., description="Column name to filter on")
    # Comparison operator, restricted to VALID_FILTER_OPS.
    op: VALID_FILTER_OPS = Field(
        ...,
        description=(
            'Filter operator. Use "==" for equals, "!=" for not equals, '
            '"IN" / "NOT IN" for membership, "IS NULL" / "IS NOT NULL", '
            '"LIKE" for pattern matching, "TEMPORAL_RANGE" for time filters.'
        ),
    )
    # Right-hand operand; left unset for operators that take no value.
    val: Any = Field(
        default=None, description="Filter value (omit for IS NULL/IS NOT NULL)"
    )


class QueryDatasetRequest(QueryCacheControl):
    """Request schema for query_dataset tool."""

    # -- what to query -----------------------------------------------------
    dataset_id: int | str = Field(
        ..., description="Dataset identifier — numeric ID or UUID string."
    )
    metrics: List[str] = Field(
        default_factory=list,
        description=(
            "Saved metric names to compute (e.g. ['count', 'avg_revenue']). "
            "Use get_dataset_info to discover available metrics."
        ),
    )
    columns: List[str] = Field(
        default_factory=list,
        description=(
            "Column/dimension names for GROUP BY or SELECT "
            "(e.g. ['category', 'region']). "
            "Use get_dataset_info to discover available columns."
        ),
    )

    # -- how to restrict it ------------------------------------------------
    filters: List[QueryDatasetFilter] = Field(
        default_factory=list,
        description=(
            'Filter conditions (e.g. [{"col": "status", "op": "==", "val": "active"}]).'
        ),
    )
    time_range: str | None = Field(
        default=None,
        description=(
            "Time range filter (e.g. 'Last 7 days', 'Last month', "
            "'2024-01-01 : 2024-12-31'). Requires a temporal column "
            "on the dataset."
        ),
    )
    time_column: str | None = Field(
        default=None,
        description=(
            "Temporal column to apply time_range to. "
            "Defaults to the dataset's main datetime column."
        ),
    )

    # -- ordering and size -------------------------------------------------
    order_by: List[str] | None = Field(
        default=None, description="Column or metric names to sort results by."
    )
    order_desc: bool = Field(
        default=True, description="Sort descending (True) or ascending (False)."
    )
    row_limit: int = Field(
        default=1000,
        ge=1,
        le=50000,
        description="Maximum number of rows to return (default 1000, max 50000).",
    )

    @model_validator(mode="after")
    def validate_metrics_or_columns(self) -> "QueryDatasetRequest":
        """Reject requests that select neither a metric nor a column."""
        if self.metrics or self.columns:
            return self
        raise ValueError(
            "At least one of 'metrics' or 'columns' must be provided. "
            "Use get_dataset_info to discover available metrics and columns."
        )


class QueryDatasetResponse(BaseModel):
    """Response schema for query_dataset tool."""

    model_config = ConfigDict(ser_json_timedelta="iso8601")

    # -- dataset identity --------------------------------------------------
    dataset_id: int = Field(..., description="Dataset ID")
    dataset_name: str = Field(..., description="Dataset name")

    # -- result payload ----------------------------------------------------
    columns: List[DataColumn] = Field(
        default_factory=list, description="Column metadata for returned data"
    )
    data: List[Dict[str, Any]] = Field(
        default_factory=list, description="Query result rows"
    )
    row_count: int = Field(default=0, description="Number of rows returned")
    total_rows: int | None = Field(
        default=None, description="Total row count from the query engine"
    )
    summary: str = Field(
        default="", description="Human-readable summary of the results"
    )

    # -- execution details -------------------------------------------------
    performance: PerformanceMetadata | None = Field(
        default=None, description="Query performance metadata"
    )
    cache_status: CacheStatus | None = Field(
        default=None, description="Cache hit/miss information"
    )
    applied_filters: List[QueryDatasetFilter] = Field(
        default_factory=list, description="Filters that were applied to the query"
    )
    warnings: List[str] = Field(
        default_factory=list, description="Any warnings encountered during execution"
    )
"query_dataset", ] diff --git a/superset/mcp_service/dataset/tool/query_dataset.py b/superset/mcp_service/dataset/tool/query_dataset.py new file mode 100644 index 00000000000..d62c7fd9d2d --- /dev/null +++ b/superset/mcp_service/dataset/tool/query_dataset.py @@ -0,0 +1,489 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: query_dataset + +Query a dataset using its semantic layer (saved metrics, calculated columns, +dimensions) without requiring a saved chart. 
+""" + +import difflib +import logging +import time +from typing import Any + +from fastmcp import Context +from sqlalchemy.exc import SQLAlchemyError +from sqlalchemy.orm import joinedload, subqueryload +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.commands.exceptions import CommandException +from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException +from superset.extensions import event_logger +from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata +from superset.mcp_service.dataset.schemas import ( + DatasetError, + QueryDatasetFilter, + QueryDatasetRequest, + QueryDatasetResponse, +) +from superset.mcp_service.privacy import ( + DATA_MODEL_METADATA_ERROR_TYPE, + requires_data_model_metadata_access, + user_can_view_data_model_metadata, +) +from superset.mcp_service.utils import _is_uuid +from superset.mcp_service.utils.cache_utils import get_cache_status_from_result +from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message + +logger = logging.getLogger(__name__) + + +def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | None: + """Resolve a dataset by int ID or UUID string. + + Replicates the identifier resolution logic from ModelGetInfoCore._find_object(). + """ + from superset.daos.dataset import DatasetDAO + + opts = eager_options or None + + if isinstance(identifier, int): + return DatasetDAO.find_by_id(identifier, query_options=opts) + + # Try parsing as int + try: + id_val = int(identifier) + return DatasetDAO.find_by_id(id_val, query_options=opts) + except (ValueError, TypeError): + pass + + # Try UUID + if _is_uuid(str(identifier)): + return DatasetDAO.find_by_id(identifier, id_column="uuid", query_options=opts) + + return None + + +def _validate_names( + requested: list[str], + valid: set[str], + kind: str, +) -> list[str]: + """Return list of error messages for names not found in *valid*. 
+ + Includes close-match suggestions when available. + """ + errors: list[str] = [] + for name in requested: + if name not in valid: + suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6) + msg = f"Unknown {kind}: '{name}'" + if suggestions: + msg += f". Did you mean: {', '.join(suggestions)}?" + errors.append(msg) + return errors + + +@requires_data_model_metadata_access +@tool( + tags=["data"], + class_permission_name="Dataset", + annotations=ToolAnnotations( + title="Query dataset", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def query_dataset( # noqa: C901 + request: QueryDatasetRequest, ctx: Context +) -> QueryDatasetResponse | DatasetError: + """Query a dataset using its semantic layer (saved metrics, dimensions, filters). + + Returns tabular data without requiring a saved chart. Use this when you want + to compute saved metrics, group by dimensions, or apply filters directly + against a dataset's curated semantic layer. + + Workflow: + 1. list_datasets -> find a dataset + 2. get_dataset_info -> discover available columns and metrics + 3. query_dataset -> query using metric names and column names + + Example: + ```json + { + "dataset_id": 123, + "metrics": ["count", "avg_revenue"], + "columns": ["product_category"], + "time_range": "Last 7 days", + "row_limit": 100 + } + ``` + """ + await ctx.info( + "Starting dataset query: dataset_id=%s, metrics=%s, columns=%s, " + "row_limit=%s" + % ( + request.dataset_id, + request.metrics, + request.columns, + request.row_limit, + ) + ) + + try: + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + from superset.connectors.sqla.models import SqlaTable + + # ------------------------------------------------------------------ + # Step 1: Check data-model metadata access BEFORE the dataset lookup. 
def _build_column_metadata(
    raw_columns: list[str],
    data: list[dict[str, Any]],
    stats_sample_size: int = 5000,
) -> list[DataColumn]:
    """Describe each returned column: inferred type, samples, null/unique counts.

    Args:
        raw_columns: Column names reported by the query result.
        data: Result rows as dicts keyed by column name.
        stats_sample_size: Maximum number of leading rows scanned for
            ``null_count`` / ``unique_count``.

    Returns:
        One ``DataColumn`` per entry of *raw_columns*, in the same order.
    """
    # Cap stats computation at ``stats_sample_size`` rows to avoid
    # O(rows*cols) overhead on large result sets (row_limit allows up to 50k).
    stats_rows = data[:stats_sample_size]
    # Type inference and sample values only look at the first three rows.
    head_rows = data[:3]

    columns_meta: list[DataColumn] = []
    for col_name in raw_columns:
        sample_values = [
            value
            for value in (row.get(col_name) for row in head_rows)
            if value is not None
        ]

        data_type = "string"
        if sample_values:
            # bool is a subclass of int, so it has to be tested first.
            if all(isinstance(v, bool) for v in sample_values):
                data_type = "boolean"
            elif all(isinstance(v, (int, float)) for v in sample_values):
                data_type = "numeric"

        # Compute null_count and unique non-null values in one pass.
        null_count = 0
        unique_vals: set[str] = set()
        for row in stats_rows:
            val = row.get(col_name)
            if val is None:
                null_count += 1
            else:
                unique_vals.add(str(val))

        columns_meta.append(
            DataColumn(
                name=col_name,
                display_name=col_name.replace("_", " ").title(),
                data_type=data_type,
                sample_values=sample_values,
                null_count=null_count,
                unique_count=len(unique_vals),
            )
        )
    return columns_meta


@requires_data_model_metadata_access
@tool(
    tags=["data"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Query dataset",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def query_dataset(  # noqa: C901
    request: QueryDatasetRequest, ctx: Context
) -> QueryDatasetResponse | DatasetError:
    """Query a dataset using its semantic layer (saved metrics, dimensions, filters).

    Returns tabular data without requiring a saved chart. Use this when you want
    to compute saved metrics, group by dimensions, or apply filters directly
    against a dataset's curated semantic layer.

    Workflow:
    1. list_datasets -> find a dataset
    2. get_dataset_info -> discover available columns and metrics
    3. query_dataset -> query using metric names and column names

    Example:
    ```json
    {
        "dataset_id": 123,
        "metrics": ["count", "avg_revenue"],
        "columns": ["product_category"],
        "time_range": "Last 7 days",
        "row_limit": 100
    }
    ```
    """
    # NOTE: the docstring above is kept verbatim on purpose; it is what the
    # tool advertises to clients. Every failure path below returns a
    # DatasetError rather than raising.
    await ctx.info(
        f"Starting dataset query: dataset_id={request.dataset_id}, "
        f"metrics={request.metrics}, columns={request.columns}, "
        f"row_limit={request.row_limit}"
    )

    try:
        from superset.commands.chart.data.get_data_command import ChartDataCommand
        from superset.common.query_context_factory import QueryContextFactory
        from superset.connectors.sqla.models import SqlaTable

        # ------------------------------------------------------------------
        # Step 1: Check data-model metadata access BEFORE the dataset lookup.
        # Doing this first prevents leaking dataset existence — restricted
        # users always receive DataModelMetadataRestricted, never NotFound.
        # The decorator hides this tool from search; this check enforces
        # direct calls that bypass tool discovery.
        # ------------------------------------------------------------------
        if not user_can_view_data_model_metadata():
            await ctx.warning("Dataset metadata access blocked by privacy controls")
            return DatasetError.create(
                error=(
                    "You don't have permission to access dataset details for your role."
                ),
                error_type=DATA_MODEL_METADATA_ERROR_TYPE,
            )

        # ------------------------------------------------------------------
        # Step 2: Resolve dataset
        # ------------------------------------------------------------------
        await ctx.report_progress(1, 5, "Looking up dataset")
        eager_options = [
            subqueryload(SqlaTable.columns),
            subqueryload(SqlaTable.metrics),
            joinedload(SqlaTable.database),
        ]

        with event_logger.log_context(action="mcp.query_dataset.lookup"):
            dataset = _resolve_dataset(request.dataset_id, eager_options)

        if dataset is None:
            await ctx.error(f"Dataset not found: identifier={request.dataset_id}")
            return DatasetError.create(
                error=f"No dataset found with identifier: {request.dataset_id}",
                error_type="NotFound",
            )

        dataset_name = getattr(dataset, "table_name", None) or f"Dataset {dataset.id}"
        await ctx.info(
            f"Dataset found: id={dataset.id}, name={dataset_name}, "
            f"columns={len(dataset.columns)}, metrics={len(dataset.metrics)}"
        )

        # ------------------------------------------------------------------
        # Step 3: Validate requested columns and metrics
        # ------------------------------------------------------------------
        await ctx.report_progress(2, 5, "Validating columns and metrics")
        valid_columns = {c.column_name for c in dataset.columns}
        valid_metrics = {m.metric_name for m in dataset.metrics}

        validation_errors: list[str] = []
        validation_errors.extend(
            _validate_names(request.columns, valid_columns, "column")
        )
        validation_errors.extend(
            _validate_names(request.metrics, valid_metrics, "metric")
        )
        # Validate filter column names against dataset columns
        filter_cols = [f.col for f in request.filters]
        validation_errors.extend(
            _validate_names(filter_cols, valid_columns, "filter column")
        )
        # Validate order_by names against columns + metrics
        if request.order_by:
            valid_orderby = valid_columns | valid_metrics
            validation_errors.extend(
                _validate_names(request.order_by, valid_orderby, "order_by")
            )

        if validation_errors:
            error_msg = "; ".join(validation_errors)
            await ctx.error(f"Validation failed: {error_msg}")
            return DatasetError.create(
                error=error_msg,
                error_type="ValidationError",
            )

        # ------------------------------------------------------------------
        # Step 4: Build filters and time range
        # ------------------------------------------------------------------
        warnings: list[str] = []
        query_filters: list[dict[str, Any]] = [
            {"col": f.col, "op": f.op, "val": f.val} for f in request.filters
        ]
        # Track all applied filters (including synthesized ones) for the response.
        effective_filters: list[QueryDatasetFilter] = list(request.filters)
        granularity: str | None = None

        if request.time_range:
            temporal_col = request.time_column or getattr(
                dataset, "main_dttm_col", None
            )
            if not temporal_col:
                await ctx.error("time_range provided but no temporal column available")
                return DatasetError.create(
                    error=(
                        "time_range was provided but no temporal column is available. "
                        "Either set time_column explicitly or ensure the dataset has "
                        "a main datetime column configured."
                    ),
                    error_type="ValidationError",
                )
            # Validate that the temporal column actually exists on the dataset
            if temporal_col not in valid_columns:
                await ctx.error(f"time_column '{temporal_col}' not found on dataset")
                return DatasetError.create(
                    error=(
                        f"time_column '{temporal_col}' does not exist on this dataset."
                    ),
                    error_type="ValidationError",
                )
            # Warn if the chosen temporal column isn't marked as datetime
            dttm_cols = {c.column_name for c in dataset.columns if c.is_dttm}
            if temporal_col not in dttm_cols:
                warnings.append(
                    f"Column '{temporal_col}' is not marked as a datetime "
                    f"column on this dataset. Time filtering may not work "
                    f"as expected."
                )

            query_filters.append(
                {
                    "col": temporal_col,
                    "op": "TEMPORAL_RANGE",
                    "val": request.time_range,
                }
            )
            effective_filters.append(
                QueryDatasetFilter(
                    col=temporal_col,
                    op="TEMPORAL_RANGE",
                    val=request.time_range,
                )
            )
            granularity = temporal_col
            await ctx.debug(
                f"Time filter: column={temporal_col}, range={request.time_range}"
            )
        elif request.time_column:
            # time_column only has an effect together with time_range; tell
            # the caller instead of silently dropping it.
            warnings.append(
                f"time_column '{request.time_column}' was ignored because no "
                f"time_range was provided."
            )

        # ------------------------------------------------------------------
        # Step 5: Build query dict
        # ------------------------------------------------------------------
        await ctx.report_progress(3, 5, "Building query")
        query_dict: dict[str, Any] = {
            "filters": query_filters,
            "columns": request.columns,
            "metrics": request.metrics,
            "row_limit": request.row_limit,
            "order_desc": request.order_desc,
        }
        if granularity:
            query_dict["granularity"] = granularity
        if request.order_by:
            # OrderBy = tuple[Metric | Column, bool] where bool is ascending
            query_dict["orderby"] = [
                (col, not request.order_desc) for col in request.order_by
            ]

        await ctx.debug(f"Query dict keys: {sorted(query_dict)}")

        # ------------------------------------------------------------------
        # Step 6: Create QueryContext and execute
        # ------------------------------------------------------------------
        await ctx.report_progress(4, 5, "Executing query")
        start_time = time.time()

        with event_logger.log_context(action="mcp.query_dataset.execute"):
            factory = QueryContextFactory()
            # datasource_type is "table" because this tool queries SqlaTable
            # datasets (Superset's built-in semantic layer). External semantic
            # layers (dbt, Snowflake Cortex, etc.) use "semantic_view" and have
            # a different query path — see SemanticView + mapper.py.
            query_context = factory.create(
                datasource={"id": dataset.id, "type": "table"},
                queries=[query_dict],
                form_data={},
                force=not request.use_cache or request.force_refresh,
                custom_cache_timeout=request.cache_timeout,
            )

            command = ChartDataCommand(query_context)
            command.validate()
            result = command.run()

        query_duration_ms = int((time.time() - start_time) * 1000)

        # ``.get`` also covers a present-but-None/empty "queries" entry.
        if not result or not result.get("queries"):
            await ctx.warning(f"Query returned no results for dataset {dataset.id}")
            return DatasetError.create(
                error="Query returned no results.",
                error_type="EmptyQuery",
            )

        # ------------------------------------------------------------------
        # Step 7: Format response
        # ------------------------------------------------------------------
        await ctx.report_progress(5, 5, "Formatting results")
        query_result = result["queries"][0]
        data = query_result.get("data", [])
        raw_columns = query_result.get("colnames", [])
        cache_status = get_cache_status_from_result(
            query_result, force_refresh=request.force_refresh
        )

        if not data:
            return QueryDatasetResponse(
                dataset_id=dataset.id,
                dataset_name=dataset_name,
                columns=[],
                data=[],
                row_count=0,
                total_rows=0,
                summary=f"Query on '{dataset_name}' returned no data.",
                performance=PerformanceMetadata(
                    query_duration_ms=query_duration_ms,
                    cache_status="no_data",
                ),
                cache_status=cache_status,
                applied_filters=effective_filters,
                warnings=warnings,
            )

        columns_meta = _build_column_metadata(raw_columns, data)

        cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
        summary = (
            f"Dataset '{dataset_name}': {len(data)} rows, "
            f"{len(raw_columns)} columns ({cache_label})."
        )

        await ctx.info(
            f"Query complete: rows={len(data)}, columns={len(raw_columns)}, "
            f"duration={query_duration_ms}ms"
        )

        return QueryDatasetResponse(
            dataset_id=dataset.id,
            dataset_name=dataset_name,
            columns=columns_meta,
            data=data,
            row_count=len(data),
            total_rows=query_result.get("rowcount"),
            summary=summary,
            performance=PerformanceMetadata(
                query_duration_ms=query_duration_ms,
                cache_status=cache_label,
            ),
            cache_status=cache_status,
            applied_filters=effective_filters,
            warnings=warnings,
        )

    except OAuth2RedirectError as exc:
        redirect_msg = build_oauth2_redirect_message(exc)
        await ctx.error(f"OAuth2 redirect required: {redirect_msg}")
        return DatasetError.create(
            error=redirect_msg,
            error_type="OAuth2Redirect",
        )

    except OAuth2Error as exc:
        await ctx.error(f"OAuth2 error: {exc}")
        return DatasetError.create(
            error=f"OAuth2 authentication error: {exc}",
            error_type="OAuth2Error",
        )

    except (CommandException, SupersetException) as exc:
        await ctx.error(f"Query failed: {exc}")
        return DatasetError.create(
            error=f"Query execution failed: {exc}",
            error_type="QueryError",
        )

    except SQLAlchemyError as exc:
        await ctx.error(f"Database error: {exc}")
        return DatasetError.create(
            error=f"Database error: {exc}",
            error_type="DatabaseError",
        )

    except Exception as exc:
        # Last-resort boundary for the tool: log the full traceback
        # server-side and hand the client a generic message only.
        logger.exception(
            "Unexpected error while querying dataset: %s: %s",
            type(exc).__name__,
            str(exc),
        )
        await ctx.error(f"Unexpected error: {type(exc).__name__}: {exc}")
        return DatasetError.create(
            error="An unexpected error occurred while querying the dataset.",
            error_type="UnexpectedError",
        )
b/tests/unit_tests/mcp_service/dataset/tool/test_query_dataset.py @@ -0,0 +1,831 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the query_dataset MCP tool.""" + +from __future__ import annotations + +import importlib +from collections.abc import Generator +from typing import Any +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client, FastMCP + +from superset.mcp_service.app import mcp +from superset.utils import json + +query_dataset_module = importlib.import_module( + "superset.mcp_service.dataset.tool.query_dataset" +) + + +@pytest.fixture +def mcp_server() -> FastMCP: + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth() -> Generator[MagicMock, None, None]: + """Mock authentication and metadata access for all tests.""" + with ( + patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user, + patch.object( + query_dataset_module, + "user_can_view_data_model_metadata", + return_value=True, + ), + ): + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +def _make_column(name: str, is_dttm: bool = False) -> MagicMock: + """Build a mock SqlaTable column with the given 
name and datetime flag.""" + col = MagicMock() + col.column_name = name + col.is_dttm = is_dttm + col.verbose_name = None + col.type = "VARCHAR" + col.groupby = True + col.filterable = True + col.description = None + return col + + +def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock: + """Build a mock SqlMetric with the given name and SQL expression.""" + metric = MagicMock() + metric.metric_name = name + metric.verbose_name = None + metric.expression = expression + metric.description = None + metric.d3format = None + return metric + + +def _make_dataset( + dataset_id: int = 1, + table_name: str = "orders", + columns: list[Any] | None = None, + metrics: list[Any] | None = None, + main_dttm_col: str | None = None, +) -> MagicMock: + """Build a mock SqlaTable dataset with default columns and metrics.""" + ds = MagicMock() + ds.id = dataset_id + ds.table_name = table_name + ds.uuid = f"test-uuid-{dataset_id}" + ds.main_dttm_col = main_dttm_col + ds.database = MagicMock() + ds.database.database_name = "examples" + ds.columns = columns or [ + _make_column("category"), + _make_column("region"), + _make_column("order_date", is_dttm=True), + ] + ds.metrics = metrics or [ + _make_metric("count", "COUNT(*)"), + _make_metric("total_revenue", "SUM(revenue)"), + ] + return ds + + +def _mock_command_result( + data: list[dict[str, Any]] | None = None, + colnames: list[str] | None = None, +) -> dict[str, Any]: + """Build the result dict that ChartDataCommand.run() returns.""" + data = data or [ + {"category": "Electronics", "count": 42}, + {"category": "Clothing", "count": 17}, + ] + colnames = colnames or ["category", "count"] + return { + "queries": [ + { + "data": data, + "colnames": colnames, + "rowcount": len(data), + "cache_key": "abc123", + "is_cached": False, + "cached_dttm": None, + "cache_timeout": 300, + } + ] + } + + +@pytest.mark.asyncio +async def test_query_dataset_success(mcp_server: FastMCP) -> None: + """Happy path: metrics + columns returns 
@pytest.mark.asyncio
async def test_query_dataset_success(mcp_server: FastMCP) -> None:
    """Happy path: metrics + columns returns data."""
    dataset = _make_dataset()
    result_data = _mock_command_result()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "columns": ["category"],
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["dataset_id"] == 1
    assert data["dataset_name"] == "orders"
    assert data["row_count"] == 2
    assert len(data["data"]) == 2
    assert data["data"][0]["category"] == "Electronics"


@pytest.mark.asyncio
async def test_query_dataset_not_found(mcp_server: FastMCP) -> None:
    """Dataset ID that doesn't exist returns error."""
    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=None,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 999,
                        "metrics": ["count"],
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["error_type"] == "NotFound"
    assert "999" in data["error"]


@pytest.mark.asyncio
async def test_query_dataset_invalid_metric(mcp_server: FastMCP) -> None:
    """Unknown metric name returns validation error with suggestions."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["countt"],  # typo
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["error_type"] == "ValidationError"
    assert "countt" in data["error"]
    # Should suggest "count" as a close match. Checking for the bare word
    # "count" would pass trivially because "countt" already contains it, so
    # assert on the suggestion phrase itself.
    assert "Did you mean: count" in data["error"]


@pytest.mark.asyncio
async def test_query_dataset_invalid_column(mcp_server: FastMCP) -> None:
    """Unknown column name returns validation error."""
    dataset = _make_dataset()

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "columns": ["nonexistent_col"],
                        "metrics": ["count"],
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["error_type"] == "ValidationError"
    assert "nonexistent_col" in data["error"]


@pytest.mark.asyncio
async def test_query_dataset_no_metrics_no_columns(mcp_server: FastMCP) -> None:
    """Providing neither metrics nor columns raises validation error."""
    from fastmcp.exceptions import ToolError

    async with Client(mcp_server) as client:
        with pytest.raises(ToolError, match="metrics.*columns"):
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": [],
                        "columns": [],
                    }
                },
            )


@pytest.mark.asyncio
async def test_query_dataset_with_time_range(mcp_server: FastMCP) -> None:
    """time_range is converted to TEMPORAL_RANGE filter + granularity."""
    dataset = _make_dataset(main_dttm_col="order_date")
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    def capture_create(**kwargs: Any) -> MagicMock:
        # Record the query dicts handed to QueryContextFactory.create().
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 7 days",
                    }
                },
            )

    assert len(captured_queries) == 1
    query_dict = captured_queries[0]
    # Should have TEMPORAL_RANGE filter
    temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"]
    assert len(temporal_filters) == 1
    assert temporal_filters[0]["col"] == "order_date"
    assert temporal_filters[0]["val"] == "Last 7 days"
    # Should set granularity
    assert query_dict["granularity"] == "order_date"
    # applied_filters in response must include the synthesized TEMPORAL_RANGE filter
    data = json.loads(result.content[0].text)
    resp_filters = data["applied_filters"]
    temporal_resp = [f for f in resp_filters if f["op"] == "TEMPORAL_RANGE"]
    assert len(temporal_resp) == 1
    assert temporal_resp[0]["col"] == "order_date"
    assert temporal_resp[0]["val"] == "Last 7 days"


@pytest.mark.asyncio
async def test_query_dataset_time_range_no_temporal_column(mcp_server: FastMCP) -> None:
    """time_range without a temporal column returns error."""
    dataset = _make_dataset(main_dttm_col=None)

    with patch.object(
        query_dataset_module,
        "_resolve_dataset",
        return_value=dataset,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "time_range": "Last 7 days",
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["error_type"] == "ValidationError"
    assert "temporal column" in data["error"].lower()


@pytest.mark.asyncio
async def test_query_dataset_with_filters(mcp_server: FastMCP) -> None:
    """User-provided filters are passed through to the query."""
    dataset = _make_dataset()
    result_data = _mock_command_result()
    captured_queries: list[dict[str, Any]] = []

    def capture_create(**kwargs: Any) -> MagicMock:
        # Record the query dicts handed to QueryContextFactory.create().
        captured_queries.extend(kwargs.get("queries", []))
        return MagicMock()

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=result_data,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            side_effect=capture_create,
        ),
    ):
        async with Client(mcp_server) as client:
            await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                        "filters": [
                            {"col": "category", "op": "==", "val": "Electronics"}
                        ],
                    }
                },
            )

    assert len(captured_queries) == 1
    filters = captured_queries[0]["filters"]
    assert len(filters) == 1
    assert filters[0]["col"] == "category"
    assert filters[0]["op"] == "=="
    assert filters[0]["val"] == "Electronics"


@pytest.mark.asyncio
async def test_query_dataset_empty_results(mcp_server: FastMCP) -> None:
    """Query that returns no data gives a response with row_count=0."""
    dataset = _make_dataset()
    empty_result = {
        "queries": [
            {
                "data": [],
                "colnames": [],
                "rowcount": 0,
                "is_cached": False,
                "cached_dttm": None,
                "cache_timeout": 300,
            }
        ]
    }

    with (
        patch.object(
            query_dataset_module,
            "_resolve_dataset",
            return_value=dataset,
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.validate",
        ),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand.run",
            return_value=empty_result,
        ),
        patch(
            "superset.common.query_context_factory.QueryContextFactory.create",
            return_value=MagicMock(),
        ),
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "query_dataset",
                {
                    "request": {
                        "dataset_id": 1,
                        "metrics": ["count"],
                    }
                },
            )

    data = json.loads(result.content[0].text)
    assert data["row_count"] == 0
    assert data["data"] == []
    assert "no data" in data["summary"].lower()
"metrics": ["count"], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["row_count"] == 0 + assert data["data"] == [] + assert "no data" in data["summary"].lower() + + +@pytest.mark.asyncio +async def test_query_dataset_by_uuid(mcp_server: FastMCP) -> None: + """UUID-based lookup works.""" + dataset = _make_dataset() + result_data = _mock_command_result() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ) as mock_resolve, + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.validate", + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.run", + return_value=result_data, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory.create", + return_value=MagicMock(), + ), + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": "a1b2c3d4-5678-90ab-cdef-1234567890ab", + "metrics": ["count"], + } + }, + ) + + # Verify the resolve function was called with the UUID + mock_resolve.assert_called_once() + call_args = mock_resolve.call_args + assert call_args[0][0] == "a1b2c3d4-5678-90ab-cdef-1234567890ab" + + data = json.loads(result.content[0].text) + assert data["dataset_id"] == 1 + + +@pytest.mark.asyncio +async def test_query_dataset_permission_denied(mcp_server: FastMCP) -> None: + """Permission denied from ChartDataCommand.validate() returns error.""" + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetSecurityException + + dataset = _make_dataset() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory.create", + return_value=MagicMock(), + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.validate", + side_effect=SupersetSecurityException( + 
SupersetError( + message="Access denied", + error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR, + level=ErrorLevel.WARNING, + ) + ), + ), + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "QueryError" + + +@pytest.mark.asyncio +async def test_query_dataset_order_by_valid(mcp_server: FastMCP) -> None: + """order_by with valid column/metric names passes through.""" + dataset = _make_dataset() + result_data = _mock_command_result() + captured_queries: list[dict[str, Any]] = [] + + def capture_create(**kwargs): + captured_queries.extend(kwargs.get("queries", [])) + return MagicMock() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.validate", + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.run", + return_value=result_data, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory.create", + side_effect=capture_create, + ), + ): + async with Client(mcp_server) as client: + await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + "columns": ["category"], + "order_by": ["count"], + "order_desc": True, + } + }, + ) + + assert len(captured_queries) == 1 + orderby = captured_queries[0].get("orderby", []) + assert len(orderby) == 1 + assert orderby[0][0] == "count" + # order_desc=True -> ascending=False + assert orderby[0][1] is False + + +@pytest.mark.asyncio +async def test_query_dataset_order_by_invalid(mcp_server: FastMCP) -> None: + """order_by with an unknown name returns validation error.""" + dataset = _make_dataset() + + with patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ): + async with 
Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + "order_by": ["nonexistent"], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "ValidationError" + assert "nonexistent" in data["error"] + + +@pytest.mark.asyncio +async def test_query_dataset_time_column_override(mcp_server: FastMCP) -> None: + """Explicit time_column overrides dataset main_dttm_col.""" + dataset = _make_dataset(main_dttm_col="order_date") + result_data = _mock_command_result() + captured_queries: list[dict[str, Any]] = [] + + def capture_create(**kwargs): + captured_queries.extend(kwargs.get("queries", [])) + return MagicMock() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.validate", + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.run", + return_value=result_data, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory.create", + side_effect=capture_create, + ), + ): + async with Client(mcp_server) as client: + await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + "time_range": "Last 30 days", + "time_column": "order_date", + } + }, + ) + + assert len(captured_queries) == 1 + query_dict = captured_queries[0] + assert query_dict["granularity"] == "order_date" + temporal_filters = [f for f in query_dict["filters"] if f["op"] == "TEMPORAL_RANGE"] + assert temporal_filters[0]["col"] == "order_date" + + +@pytest.mark.asyncio +async def test_query_dataset_non_dttm_time_column_warns(mcp_server: FastMCP) -> None: + """Using a non-datetime column for time_range produces a warning.""" + dataset = _make_dataset(main_dttm_col=None) + result_data = _mock_command_result() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + 
return_value=dataset, + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.validate", + ), + patch( + "superset.commands.chart.data.get_data_command.ChartDataCommand.run", + return_value=result_data, + ), + patch( + "superset.common.query_context_factory.QueryContextFactory.create", + return_value=MagicMock(), + ), + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + "time_range": "Last 7 days", + "time_column": "category", + } + }, + ) + + data = json.loads(result.content[0].text) + assert len(data["warnings"]) > 0 + assert "not marked as a datetime" in data["warnings"][0] + + +@pytest.mark.asyncio +async def test_query_dataset_invalid_filter_column(mcp_server: FastMCP) -> None: + """Filter on a column that doesn't exist returns validation error.""" + dataset = _make_dataset() + + with patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + "metrics": ["count"], + "filters": [ + { + "col": "nonexistent", + "op": "==", + "val": "test", + } + ], + } + }, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "ValidationError" + assert "nonexistent" in data["error"] + + +@pytest.mark.asyncio +async def test_query_dataset_metadata_access_denied_no_suggestions( + mcp_server: FastMCP, +) -> None: + """Users without data-model metadata access cannot probe column/metric names. + + The privacy gate must fire before the validation step that returns close-match + suggestions, so restricted users cannot enumerate schema details via typos. 
+ """ + dataset = _make_dataset() + + with ( + patch.object( + query_dataset_module, + "_resolve_dataset", + return_value=dataset, + ), + patch.object( + query_dataset_module, + "user_can_view_data_model_metadata", + return_value=False, + ), + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + "dataset_id": 1, + # Typo that would normally trigger close-match suggestions + "metrics": ["countt"], + } + }, + ) + + data = json.loads(result.content[0].text) + # Must be denied before returning any schema suggestions + assert data["error_type"] == "DataModelMetadataRestricted" + # Must NOT contain column/metric name suggestions + assert "countt" not in data.get("error", "") + assert "count" not in data.get("error", "") + + +@pytest.mark.asyncio +async def test_query_dataset_metadata_access_denied_nonexistent_dataset( + mcp_server: FastMCP, +) -> None: + """Metadata-restricted users must not be able to probe dataset existence. + + The privacy gate fires before the DAO lookup, so a restricted caller + always receives DataModelMetadataRestricted — never NotFound — regardless + of whether the requested dataset ID exists. + """ + with patch.object( + query_dataset_module, + "user_can_view_data_model_metadata", + return_value=False, + ): + async with Client(mcp_server) as client: + result = await client.call_tool( + "query_dataset", + { + "request": { + # Use a dataset_id that does not exist + "dataset_id": 999999, + "metrics": ["count"], + } + }, + ) + + data = json.loads(result.content[0].text) + # Must receive restricted error, not a NotFound that leaks existence + assert data["error_type"] == "DataModelMetadataRestricted" + assert data["error_type"] != "NotFound"