Compare commits

...

7 Commits

Author SHA1 Message Date
Amin Ghadersohi
ec88c32af0 fix(mcp): revert accidental chmod +x on pre-existing mcp_service files
An earlier commit on this branch flipped the executable bit (644 -> 755)
on ~111 unrelated files across superset/mcp_service/ as a side effect,
inflating the PR diff to 128 files. Restore their mode to 644 to match
master; no content changes.
2026-07-01 18:33:05 +00:00
Amin Ghadersohi
290ab5d882 fix(mcp): validate time_column datetime type, add missing tool tests
- get_table: reject a time_column that doesn't exist or isn't marked as
  a datetime column, instead of passing it through to the query engine
  and surfacing a cryptic downstream error
- add unit tests for get_table, get_compatible_dimensions, and
  get_compatible_metrics (previously only list_metrics had coverage)
2026-07-01 18:15:40 +00:00
Amin Ghadersohi
1b9de8f4a6 fix(mcp): add ASF license headers to __init__.py files; fix test patch paths
- Add Apache license headers to semantic_layer test __init__.py files
  (license check CI was failing due to missing headers)
- Promote DatasetDAO, SemanticViewDAO, and db to module-level imports in
  list_metrics.py so unit test patch("...list_metrics.DatasetDAO") resolves
  correctly (unit-tests CI was failing with AttributeError because these
  were lazy-imported inside function bodies)
2026-07-01 17:54:14 +00:00
Amin Ghadersohi
d80343c20a fix(mcp): address Codex review findings on semantic layer tools
- BLOCKER: add view.raise_for_access() on all external SemanticView
  paths; use SemanticViewDAO.find_accessible() for list-all to filter
  at SQL level, mirroring DatasourceDAO.build_semantic_view_query()
- MEDIUM: validate requested metrics/dimensions/filters/order_by against
  the resolved datasource in get_table, matching query_dataset behavior;
  extract shared validate_names() to mcp_service/utils/query_utils.py
- MEDIUM: fix N+1 in list_metrics find_all path by using _apply_base_filter
  with subqueryload eager options instead of lazy-loaded find_all()
- NIT: extract _format_columns to format_data_columns() in response_utils.py;
  query_dataset now shares the same implementation
- NIT: SemanticLayerError now extends MCPBaseError for a consistent error
  shape across all MCP tools
- NIT: fix file permissions (644) on new semantic layer Python files
- add mutual exclusion validation (dataset_id XOR view_id) to list_metrics
- return typed SemanticLayerError instead of re-raising in final except blocks
- fix circular import: use TYPE_CHECKING guard for DataColumn in response_utils
- add unit tests for list_metrics (happy path, mutual exclusion, privacy, search)
2026-07-01 17:21:12 +00:00
Amin Ghadersohi
48a73b5d59 fix(mcp): address Codex review findings on semantic layer tools
- BLOCKER: add view.raise_for_access() on all external SemanticView
  paths; use SemanticViewDAO.find_accessible() for list-all to filter
  at SQL level, mirroring DatasourceDAO.build_semantic_view_query()
- MEDIUM: validate requested metrics/dimensions/filters/order_by against
  the resolved datasource in get_table, matching query_dataset behavior;
  extract shared validate_names() to mcp_service/utils/query_utils.py
- MEDIUM: fix N+1 in list_metrics find_all path by using _apply_base_filter
  with subqueryload eager options instead of lazy-loaded find_all()
- NIT: extract _format_columns to format_data_columns() in response_utils.py;
  query_dataset now shares the same implementation
- NIT: SemanticLayerError now extends MCPBaseError for a consistent error
  shape across all MCP tools
- NIT: fix file permissions (644) on new semantic layer Python files
2026-07-01 17:13:05 +00:00
Amin Ghadersohi
89a52f99be feat(mcp): add semantic layer MCP tools (list_metrics, get_table, get_compatible_dimensions, get_compatible_metrics)
Implements Phase 1 of SC-98803. Adds 4 new MCP tools that surface Superset's
semantic layer to LLM clients, spanning both built-in SqlaTable datasets and
external SemanticView implementations:

- list_metrics: unified metric discovery across all data sources; includes
  compatible_dimensions inline per metric for progressive query building
- get_table: routes queries to ChartDataCommand (built-in) or the semantic
  view's query path (external) based on dataset_id vs view_id
- get_compatible_dimensions: returns valid dimensions for a given metric
  selection; delegates to SemanticView.get_compatible_dimensions for external
- get_compatible_metrics: returns valid metrics for a given dimension
  selection; delegates to SemanticView.get_compatible_metrics for external

External semantic view calls degrade gracefully when the registry is empty
(OSS default). All tools follow existing patterns: @tool decorator, Pydantic
schemas, DAO-based lookups, event_logger, ctx logging, and ASF license headers.
2026-07-01 17:13:05 +00:00
Amin Ghadersohi
8f2a01e294 feat(mcp): add semantic layer MCP tools (list_metrics, get_table, get_compatible_dimensions, get_compatible_metrics)
Implements Phase 1 of SC-98803. Adds 4 new MCP tools that surface Superset's
semantic layer to LLM clients, spanning both built-in SqlaTable datasets and
external SemanticView implementations:

- list_metrics: unified metric discovery across all data sources; includes
  compatible_dimensions inline per metric for progressive query building
- get_table: routes queries to ChartDataCommand (built-in) or the semantic
  view's query path (external) based on dataset_id vs view_id
- get_compatible_dimensions: returns valid dimensions for a given metric
  selection; delegates to SemanticView.get_compatible_dimensions for external
- get_compatible_metrics: returns valid metrics for a given dimension
  selection; delegates to SemanticView.get_compatible_metrics for external

External semantic view calls degrade gracefully when the registry is empty
(OSS default). All tools follow existing patterns: @tool decorator, Pydantic
schemas, DAO-based lookups, event_logger, ctx logging, and ASF license headers.
2026-07-01 17:13:05 +00:00
18 changed files with 2612 additions and 65 deletions

View File

@@ -201,6 +201,29 @@ class SemanticViewDAO(BaseDAO[SemanticView], AbstractSemanticViewDAO):
for c in candidates
)
@classmethod
def find_accessible(cls) -> list[SemanticView]:
    """List every semantic view the current user is allowed to see.

    The permission check is expressed as a SQL filter so that listing
    all views does not need a Python-level access check per row. It is
    meant to match the filter used by
    ``DatasourceDAO.build_semantic_view_query``.

    :returns: all ``SemanticView`` rows when the user can access every
        datasource; otherwise only the views whose own ``perm`` or whose
        layer's ``perm`` is among the user's ``datasource_access``
        permissions.
    """
    from sqlalchemy import or_

    views = db.session.query(SemanticView)

    # Blanket datasource access: nothing to filter.
    if security_manager.can_access_all_datasources():
        return views.all()

    granted = security_manager.user_view_menu_names("datasource_access")
    view_is_granted = SemanticView.perm.in_(granted)
    layer_is_granted = SemanticLayer.perm.in_(granted)

    # Outer join keeps views whose layer row is missing, so they can
    # still qualify through their own permission.
    restricted = views.outerjoin(
        SemanticLayer,
        SemanticLayer.uuid == SemanticView.semantic_layer_uuid,
    ).filter(or_(view_is_granted, layer_is_granted))
    return restricted.all()
@classmethod
def find_by_name(cls, name: str, layer_uuid: str) -> SemanticView | None:
"""

View File

@@ -291,6 +291,31 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
together for the union. All flags can be combined with 'filters' but not
with 'search'.
To explore metrics across all data sources (built-in datasets + external semantic views):
1. list_metrics(request={{"search": "<keyword>"}})
-> returns metrics with dataset_id/view_id and compatible_dimensions inline
2. get_table(request={{
"dataset_id": <id>, # OR "view_id": <id> for external semantic views
"metrics": ["revenue"],
"dimensions": ["region"],
"time_range": "Last 30 days",
"row_limit": 500
}}) -> returns tabular results
- Use "dataset_id" when list_metrics returned source="builtin"
- Use "view_id" when list_metrics returned source="external"
To progressively refine a query (compatible dimensions/metrics):
- get_compatible_dimensions(request={{
"selected_metrics": ["revenue"],
"selected_dimensions": [],
"dataset_id": <id> # or "view_id": <id>
}}) -> dimensions valid to add to the current selection
- get_compatible_metrics(request={{
"selected_metrics": [],
"selected_dimensions": ["region"],
"view_id": <id> # useful for external semantic layers with constraints
}}) -> metrics valid to add to the current selection
To query a dataset's semantic layer (metrics, dimensions):
1. list_datasets(request={{}}) -> find a dataset
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
@@ -741,6 +766,12 @@ from superset.mcp_service.saved_query.tool import ( # noqa: F401, E402
get_saved_query_info,
list_saved_queries,
)
from superset.mcp_service.semantic_layer.tool import ( # noqa: F401, E402
get_compatible_dimensions,
get_compatible_metrics,
get_table,
list_metrics,
)
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
execute_sql,
open_sql_lab_with_context,

View File

@@ -22,7 +22,6 @@ Query a dataset using its semantic layer (saved metrics, calculated columns,
dimensions) without requiring a saved chart.
"""
import difflib
import logging
import time
from typing import Any
@@ -35,7 +34,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.chart.schemas import PerformanceMetadata
from superset.mcp_service.dataset.schemas import (
DatasetError,
QueryDatasetFilter,
@@ -50,6 +49,8 @@ from superset.mcp_service.privacy import (
from superset.mcp_service.utils import _is_uuid
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
from superset.mcp_service.utils.query_utils import validate_names
from superset.mcp_service.utils.response_utils import format_data_columns
logger = logging.getLogger(__name__)
@@ -80,26 +81,6 @@ def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | N
return None
def _validate_names(
    requested: list[str],
    valid: set[str],
    kind: str,
) -> list[str]:
    """Report every requested name that is missing from *valid*.

    One message is produced per unknown name, in request order. When
    ``difflib`` finds close matches (at most three, similarity cutoff
    0.6) they are appended to the message as a "Did you mean" hint.

    :param requested: names supplied by the caller
    :param valid: names that are accepted
    :param kind: label used in the message, e.g. ``"metric"``
    :returns: error messages; empty when every name is valid
    """

    def _describe(unknown: str) -> str:
        hints = difflib.get_close_matches(unknown, valid, n=3, cutoff=0.6)
        text = f"Unknown {kind}: '{unknown}'"
        if hints:
            text += f". Did you mean: {', '.join(hints)}?"
        return text

    return [_describe(name) for name in requested if name not in valid]
@requires_data_model_metadata_access
@tool(
tags=["data"],
@@ -210,21 +191,21 @@ async def query_dataset( # noqa: C901
validation_errors: list[str] = []
validation_errors.extend(
_validate_names(request.columns, valid_columns, "column")
validate_names(request.columns, valid_columns, "column")
)
validation_errors.extend(
_validate_names(request.metrics, valid_metrics, "metric")
validate_names(request.metrics, valid_metrics, "metric")
)
# Validate filter column names against dataset columns
filter_cols = [f.col for f in request.filters]
validation_errors.extend(
_validate_names(filter_cols, valid_columns, "filter column")
validate_names(filter_cols, valid_columns, "filter column")
)
# Validate order_by names against columns + metrics
if request.order_by:
valid_orderby = valid_columns | valid_metrics
validation_errors.extend(
_validate_names(request.order_by, valid_orderby, "order_by")
validate_names(request.order_by, valid_orderby, "order_by")
)
if validation_errors:
@@ -379,44 +360,7 @@ async def query_dataset( # noqa: C901
warnings=warnings,
)
# Build column metadata in a single pass per column.
# Cap stats computation at STATS_SAMPLE rows to avoid O(rows*cols)
# overhead on large result sets (row_limit allows up to 50k).
stats_sample_size = 5000
stats_rows = data[:stats_sample_size]
columns_meta: list[DataColumn] = []
for col_name in raw_columns:
sample_values = [
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
]
data_type = "string"
if sample_values:
if all(isinstance(v, bool) for v in sample_values):
data_type = "boolean"
elif all(isinstance(v, (int, float)) for v in sample_values):
data_type = "numeric"
# Compute null_count and unique non-null values in one pass
null_count = 0
unique_vals: set[str] = set()
for row in stats_rows:
val = row.get(col_name)
if val is None:
null_count += 1
else:
unique_vals.add(str(val))
columns_meta.append(
DataColumn(
name=col_name,
display_name=col_name.replace("_", " ").title(),
data_type=data_type,
sample_values=sample_values[:3],
null_count=null_count,
unique_count=len(unique_vals),
)
)
columns_meta = format_data_columns(data, raw_columns)
cache_status = get_cache_status_from_result(
query_result, force_refresh=request.force_refresh

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,302 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pydantic schemas for semantic layer MCP tools."""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
from superset.mcp_service.common.cache_schemas import CacheStatus
from superset.mcp_service.common.error_schemas import MCPBaseError
# ---------------------------------------------------------------------------
# Shared error schema
# ---------------------------------------------------------------------------
class SemanticLayerError(MCPBaseError):
    """Error response returned by semantic layer tools."""

    # Fixed to False so this model can be told apart from the success
    # responses in this module, which all declare ``success: Literal[True]``.
    success: Literal[False] = False

    @classmethod
    def create(cls, *, error: str, error_type: str) -> "SemanticLayerError":
        """Build an error response from a message and an error-type label.

        Both arguments are keyword-only and are passed straight through to
        the model constructor.
        NOTE(review): ``error`` / ``error_type`` are presumably fields
        declared on ``MCPBaseError`` -- not visible here, confirm.
        """
        return cls(error=error, error_type=error_type)
# ---------------------------------------------------------------------------
# Dimension info (returned inside MetricInfo.compatible_dimensions)
# ---------------------------------------------------------------------------
class DimensionInfo(BaseModel):
    """Metadata for a single dimension / column."""

    # Only `name` is required; every other field has a default.
    name: str
    verbose_name: str | None = None
    description: str | None = None
    type: str | None = None
    # Declared as plain `bool`, so these three do not accept None.
    is_dttm: bool = False
    groupby: bool = True
    filterable: bool = True
    # "builtin" for dataset columns, "external" for semantic view columns.
    source: Literal["builtin", "external"] = "builtin"
# ---------------------------------------------------------------------------
# Metric info
# ---------------------------------------------------------------------------
class MetricInfo(BaseModel):
    """Metadata for a single metric, including compatible dimensions."""

    # Only `name` is required; every other field has a default.
    name: str
    verbose_name: str | None = None
    description: str | None = None
    expression: str | None = None
    d3format: str | None = None
    warning_text: str | None = None
    source: Literal["builtin", "external"] = "builtin"
    # NOTE(review): presumably dataset_* is filled for source="builtin" and
    # view_* for source="external"; this model does not enforce that.
    dataset_id: int | None = None
    dataset_name: str | None = None
    view_id: int | None = None
    view_name: str | None = None
    # default_factory gives each instance its own empty list.
    compatible_dimensions: list[DimensionInfo] = Field(default_factory=list)
# ---------------------------------------------------------------------------
# list_metrics
# ---------------------------------------------------------------------------
class ListMetricsRequest(BaseModel):
    """Request schema for list_metrics."""

    # All fields are optional; an empty request is valid for this model.
    search: str | None = Field(
        default=None,
        description="Optional search string to filter metrics by name or description.",
    )
    # No validator in this model stops dataset_id and view_id from being
    # set together; any such check has to happen in the tool.
    dataset_id: int | None = Field(
        default=None,
        description="Filter to metrics from a specific built-in dataset.",
    )
    view_id: int | None = Field(
        default=None,
        description="Filter to metrics from a specific semantic view.",
    )
    include_compatible_dimensions: bool = Field(
        default=True,
        description=(
            "When True, each metric includes its list of compatible dimensions. "
            "Set to False to reduce response size when dimensions aren't needed."
        ),
    )
    # Pagination bounds: page >= 1, 1 <= page_size <= 500.
    page: int = Field(default=1, ge=1, description="1-based page number.")
    page_size: int = Field(
        default=50, ge=1, le=500, description="Number of metrics per page."
    )
class MetricList(BaseModel):
    """Response schema for list_metrics."""

    metrics: list[MetricInfo]
    # Pagination fields are all required; this model computes none of them.
    total_count: int
    page: int
    page_size: int
    total_pages: int
    # Fixed to True; SemanticLayerError carries ``success: Literal[False]``.
    success: Literal[True] = True
# ---------------------------------------------------------------------------
# get_table
# ---------------------------------------------------------------------------
class GetTableFilter(BaseModel):
    """A single filter clause for get_table."""

    # The only required field.
    col: str = Field(..., description="Column or dimension name to filter on.")
    # Free-form string: the operators listed in the description are
    # examples, and nothing in this model restricts `op` to them.
    op: str = Field(
        default="==",
        description=(
            "Filter operator. Common values: '==', '!=', '>', '<', '>=', '<=', "
            "'IN', 'NOT IN', 'LIKE', 'ILIKE', 'TEMPORAL_RANGE'."
        ),
    )
    # Typed as Any, so this model does not check `val` against `op`.
    val: Any = Field(
        default=None,
        description="Filter value. Use a list for 'IN'/'NOT IN' operators.",
    )
class GetTableRequest(BaseModel):
    """Request schema for get_table."""

    # The descriptions ask for either dataset_id or view_id, but no
    # validator in this model enforces that; both default to None.
    dataset_id: int | None = Field(
        default=None,
        description=(
            "Built-in dataset ID to query. Obtained from list_metrics response "
            "when source='builtin'. Provide either this or view_id."
        ),
    )
    view_id: int | None = Field(
        default=None,
        description=(
            "External semantic view ID to query. Obtained from list_metrics "
            "response when source='external'. Provide either this or dataset_id."
        ),
    )
    # List fields use default_factory so each request gets its own list.
    metrics: list[str] = Field(
        default_factory=list,
        description=(
            "Metric names to compute. All metrics must come from the same "
            "data source (dataset or semantic view)."
        ),
    )
    dimensions: list[str] = Field(
        default_factory=list,
        description="Dimension or column names to group by.",
    )
    filters: list[GetTableFilter] = Field(
        default_factory=list,
        description="Optional filters to apply.",
    )
    # Plain string; this model does not parse or validate the range format.
    time_range: str | None = Field(
        default=None,
        description=(
            "Optional time range string, e.g. 'Last 7 days', 'Last 30 days', "
            "'2024-01-01 : 2024-12-31'. Requires a datetime dimension."
        ),
    )
    time_column: str | None = Field(
        default=None,
        description=(
            "Name of the datetime column/dimension to apply time_range to. "
            "Inferred from the dataset's main_dttm_col when omitted."
        ),
    )
    # Bounded to 1..50000 rows by the Field constraints.
    row_limit: int = Field(
        default=1000,
        ge=1,
        le=50000,
        description="Maximum number of rows to return.",
    )
    order_by: list[str] = Field(
        default_factory=list,
        description="Column/metric names to sort by.",
    )
    # A single flag: there is no per-column sort direction in this schema.
    order_desc: bool = Field(
        default=True,
        description="Sort descending when True (default).",
    )
    use_cache: bool = Field(default=True, description="Use query cache when available.")
    force_refresh: bool = Field(
        default=False,
        description="Force a cache refresh even when cached results exist.",
    )
class GetTableResponse(BaseModel):
    """Response schema for get_table."""

    # Required payload: column metadata, row dicts, row count, summary, source.
    columns: list[DataColumn]
    data: list[dict[str, Any]]
    row_count: int
    total_rows: int | None = None
    summary: str
    # Required here (no default), unlike MetricInfo / DimensionInfo.
    source: Literal["builtin", "external"]
    # NOTE(review): presumably dataset_* is filled for source="builtin" and
    # view_* for source="external"; this model does not enforce that.
    dataset_id: int | None = None
    dataset_name: str | None = None
    view_id: int | None = None
    view_name: str | None = None
    performance: PerformanceMetadata | None = None
    cache_status: CacheStatus | None = None
    warnings: list[str] = Field(default_factory=list)
    # Fixed to True; SemanticLayerError carries ``success: Literal[False]``.
    success: Literal[True] = True
# ---------------------------------------------------------------------------
# get_compatible_dimensions
# ---------------------------------------------------------------------------
class GetCompatibleDimensionsRequest(BaseModel):
    """Request schema for get_compatible_dimensions."""

    # Both selections default to empty lists.
    selected_metrics: list[str] = Field(
        default_factory=list,
        description="Metric names already selected.",
    )
    selected_dimensions: list[str] = Field(
        default_factory=list,
        description="Dimension names already selected.",
    )
    # This model accepts neither or both IDs; the "exactly one" rule is
    # checked in the get_compatible_dimensions tool, not here.
    dataset_id: int | None = Field(
        default=None,
        description="Built-in dataset ID to query. Provide either this or view_id.",
    )
    view_id: int | None = Field(
        default=None,
        description="Semantic view ID to query. Provide either this or dataset_id.",
    )
class CompatibleDimensionsResponse(BaseModel):
    """Response schema for get_compatible_dimensions."""

    compatible_dimensions: list[DimensionInfo]
    # Which path produced the result: built-in dataset or external view.
    source: Literal["builtin", "external"]
    # Fixed to True; SemanticLayerError carries ``success: Literal[False]``.
    success: Literal[True] = True
# ---------------------------------------------------------------------------
# get_compatible_metrics
# ---------------------------------------------------------------------------
class GetCompatibleMetricsRequest(BaseModel):
    """Request schema for get_compatible_metrics."""

    # Same field set as GetCompatibleDimensionsRequest.
    selected_metrics: list[str] = Field(
        default_factory=list,
        description="Metric names already selected.",
    )
    selected_dimensions: list[str] = Field(
        default_factory=list,
        description="Dimension names already selected.",
    )
    # This model accepts neither or both IDs; it has no validator for the
    # "provide either" rule stated in the descriptions.
    dataset_id: int | None = Field(
        default=None,
        description="Built-in dataset ID to query. Provide either this or view_id.",
    )
    view_id: int | None = Field(
        default=None,
        description="Semantic view ID to query. Provide either this or dataset_id.",
    )
class CompatibleMetricsResponse(BaseModel):
    """Response schema for get_compatible_metrics."""

    compatible_metrics: list[MetricInfo]
    # Which path produced the result: built-in dataset or external view.
    source: Literal["builtin", "external"]
    # Fixed to True; SemanticLayerError carries ``success: Literal[False]``.
    success: Literal[True] = True

View File

@@ -0,0 +1,27 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from superset.mcp_service.semantic_layer.tool.get_compatible_dimensions import ( # noqa: F401
get_compatible_dimensions,
)
from superset.mcp_service.semantic_layer.tool.get_compatible_metrics import ( # noqa: F401
get_compatible_metrics,
)
from superset.mcp_service.semantic_layer.tool.get_table import get_table # noqa: F401
from superset.mcp_service.semantic_layer.tool.list_metrics import ( # noqa: F401
list_metrics,
)

View File

@@ -0,0 +1,227 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MCP tool: get_compatible_dimensions
Returns dimensions compatible with the current metric/dimension selection.
"""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.semantic_layer.schemas import (
CompatibleDimensionsResponse,
DimensionInfo,
GetCompatibleDimensionsRequest,
SemanticLayerError,
)
logger = logging.getLogger(__name__)
@tool(
    tags=["data", "semantic"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get compatible dimensions",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
@requires_data_model_metadata_access
async def get_compatible_dimensions(
    request: GetCompatibleDimensionsRequest,
    ctx: Context,
) -> CompatibleDimensionsResponse | SemanticLayerError:
    """Return dimensions compatible with the current metric/dimension selection.

    Used to drive progressive disclosure in query builders: after the user
    selects one or more metrics (and optionally some dimensions), this tool
    returns the dimensions that can validly be added without breaking the
    underlying query.

    Provide exactly one of ``dataset_id`` (built-in) or ``view_id`` (external).

    For built-in datasets, returns all groupby-enabled columns from the dataset.
    SQL datasets have no semantic compatibility constraints between metrics and
    dimensions, so all groupby columns are always returned regardless of the
    selected metrics.

    For external semantic views, delegates to the view's
    ``get_compatible_dimensions`` implementation.

    Example:
    ```json
    {
        "selected_metrics": ["revenue"],
        "selected_dimensions": [],
        "view_id": 5
    }
    ```
    """
    # Maintainer notes (kept out of the docstring, which is presumably
    # surfaced to clients as the tool description -- confirm):
    # - Every failure is *returned* as a SemanticLayerError, never raised:
    #   permission, validation, not-found, access-denied, and any unexpected
    #   exception (reported as "InternalError").
    # - Heavy imports are done lazily inside the branch that needs them.
    await ctx.info(
        "Getting compatible dimensions: dataset_id=%s, view_id=%s, "
        "metrics=%s, dims=%s"
        % (
            request.dataset_id,
            request.view_id,
            request.selected_metrics,
            request.selected_dimensions,
        )
    )

    if not user_can_view_data_model_metadata():
        return SemanticLayerError.create(
            error="You don't have permission to access dataset details for your role.",
            error_type=DATA_MODEL_METADATA_ERROR_TYPE,
        )

    # The request schema allows neither or both IDs, so the "exactly one"
    # rule is enforced here.
    if request.dataset_id is None and request.view_id is None:
        return SemanticLayerError.create(
            error="Provide either dataset_id (built-in) or view_id (external).",
            error_type="ValidationError",
        )
    if request.dataset_id is not None and request.view_id is not None:
        return SemanticLayerError.create(
            error="Provide only one of dataset_id or view_id, not both.",
            error_type="ValidationError",
        )

    try:
        # ------------------------------------------------------------------
        # Built-in dataset path
        # ------------------------------------------------------------------
        if request.dataset_id is not None:
            from sqlalchemy.orm import subqueryload

            from superset.connectors.sqla.models import SqlaTable
            from superset.daos.dataset import DatasetDAO

            dataset_id: int = request.dataset_id
            with event_logger.log_context(
                action="mcp.get_compatible_dimensions.builtin"
            ):
                dataset = DatasetDAO.find_by_id(
                    dataset_id,
                    query_options=[
                        subqueryload(SqlaTable.columns),
                        subqueryload(SqlaTable.metrics),
                    ],
                )
            if dataset is None:
                return SemanticLayerError.create(
                    error=f"No dataset found with id: {request.dataset_id}.",
                    error_type="NotFound",
                )

            # For built-in datasets all groupby columns are always compatible;
            # there's no per-metric compatibility constraint at the SQL level.
            dims = [
                DimensionInfo(
                    name=col.column_name,
                    verbose_name=col.verbose_name or None,
                    description=col.description or None,
                    type=col.type or None,
                    is_dttm=bool(col.is_dttm),
                    groupby=bool(col.groupby),
                    filterable=bool(col.filterable),
                    source="builtin",
                )
                for col in dataset.columns
                if col.groupby
            ]
            await ctx.info("Compatible dimensions (builtin): count=%d" % len(dims))
            return CompatibleDimensionsResponse(
                compatible_dimensions=dims,
                source="builtin",
            )

        # ------------------------------------------------------------------
        # External semantic view path
        # ------------------------------------------------------------------
        from superset.daos.semantic_layer import SemanticViewDAO
        from superset.exceptions import SupersetSecurityException

        view_id: int = request.view_id  # type: ignore[assignment]
        with event_logger.log_context(action="mcp.get_compatible_dimensions.external"):
            view = SemanticViewDAO.find_by_id(view_id)
        if view is None:
            return SemanticLayerError.create(
                error=f"No semantic view found with id: {view_id}.",
                error_type="NotFound",
            )
        try:
            view.raise_for_access()
        except SupersetSecurityException as ex:
            return SemanticLayerError.create(
                error=str(ex.error.message),
                error_type="AccessDenied",
            )

        compatible_names = view.get_compatible_dimensions(
            request.selected_metrics,
            request.selected_dimensions,
        )

        # Enrich with full column metadata. A name returned by the view may
        # have no matching column, in which case the defaults below apply.
        all_cols = {col.column_name: col for col in view.columns}
        dims = []
        for name in compatible_names:
            col = all_cols.get(name)
            # A flag left as None on the column is treated like a missing
            # column (same default) instead of failing validation of the
            # ``bool`` fields and turning the whole call into InternalError.
            is_dttm = getattr(col, "is_dttm", None)
            groupby = getattr(col, "groupby", None)
            filterable = getattr(col, "filterable", None)
            dims.append(
                DimensionInfo(
                    name=name,
                    verbose_name=col.verbose_name if col is not None else None,
                    description=col.description if col is not None else None,
                    type=col.type if col is not None else None,
                    is_dttm=False if is_dttm is None else is_dttm,
                    groupby=True if groupby is None else groupby,
                    filterable=True if filterable is None else filterable,
                    source="external",
                )
            )

        await ctx.info(
            "Compatible dimensions (external view id=%d): count=%d"
            % (view.id, len(dims))
        )
        return CompatibleDimensionsResponse(
            compatible_dimensions=dims,
            source="external",
        )
    except Exception as exc:
        # Top-level boundary: log with traceback, notify the client context,
        # and return a typed error rather than propagating.
        logger.exception(
            "Unexpected error in get_compatible_dimensions: %s: %s",
            type(exc).__name__,
            str(exc),
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return SemanticLayerError.create(
            error=f"Internal error in get_compatible_dimensions: {exc}",
            error_type="InternalError",
        )

View File

@@ -0,0 +1,226 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MCP tool: get_compatible_metrics
Returns metrics compatible with the current dimension/metric selection.
"""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.semantic_layer.schemas import (
CompatibleMetricsResponse,
GetCompatibleMetricsRequest,
MetricInfo,
SemanticLayerError,
)
logger = logging.getLogger(__name__)
@tool(
    tags=["data", "semantic"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get compatible metrics",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
@requires_data_model_metadata_access
async def get_compatible_metrics(
    request: GetCompatibleMetricsRequest,
    ctx: Context,
) -> CompatibleMetricsResponse | SemanticLayerError:
    """Return metrics compatible with the current dimension/metric selection.

    Used to progressively refine a query: given a set of already-selected
    metrics and dimensions, returns the additional metrics that can be
    combined without breaking the underlying semantic constraints.

    Provide exactly one of ``dataset_id`` (built-in) or ``view_id`` (external).

    For built-in datasets, all metrics from the dataset are considered
    compatible (SQL GROUP BY imposes no metric-level constraints).

    For external semantic views, delegates to the view's
    ``get_compatible_metrics`` implementation.

    Example:
    ```json
    {
        "selected_metrics": [],
        "selected_dimensions": ["region"],
        "view_id": 5
    }
    ```
    """
    # NOTE(review): the docstring above is presumably surfaced to MCP clients
    # as the tool description, so its wording is intentionally left untouched.
    await ctx.info(
        "Getting compatible metrics: dataset_id=%s, view_id=%s, "
        "metrics=%s, dims=%s"
        % (
            request.dataset_id,
            request.view_id,
            request.selected_metrics,
            request.selected_dimensions,
        )
    )

    if not user_can_view_data_model_metadata():
        return SemanticLayerError.create(
            error="You don't have permission to access dataset details for your role.",
            error_type=DATA_MODEL_METADATA_ERROR_TYPE,
        )

    # Exactly one of the two identifiers must be supplied.
    has_dataset = request.dataset_id is not None
    has_view = request.view_id is not None
    if not has_dataset and not has_view:
        return SemanticLayerError.create(
            error="Provide either dataset_id (built-in) or view_id (external).",
            error_type="ValidationError",
        )
    if has_dataset and has_view:
        return SemanticLayerError.create(
            error="Provide only one of dataset_id or view_id, not both.",
            error_type="ValidationError",
        )

    try:
        # --------------------------------------------------------------
        # Built-in dataset: every metric on the dataset is returned.
        # --------------------------------------------------------------
        if has_dataset:
            from sqlalchemy.orm import subqueryload

            from superset.connectors.sqla.models import SqlaTable
            from superset.daos.dataset import DatasetDAO

            with event_logger.log_context(action="mcp.get_compatible_metrics.builtin"):
                dataset = DatasetDAO.find_by_id(
                    request.dataset_id,
                    query_options=[
                        subqueryload(SqlaTable.columns),
                        subqueryload(SqlaTable.metrics),
                    ],
                )
            if dataset is None:
                return SemanticLayerError.create(
                    error=f"No dataset found with id: {request.dataset_id}.",
                    error_type="NotFound",
                )

            # All metrics on a SQL dataset are always mutually compatible.
            builtin_metrics = []
            for metric in dataset.metrics:
                builtin_metrics.append(
                    MetricInfo(
                        name=metric.metric_name,
                        verbose_name=metric.verbose_name or None,
                        description=metric.description or None,
                        expression=metric.expression or None,
                        d3format=metric.d3format or None,
                        warning_text=metric.warning_text or None,
                        source="builtin",
                        dataset_id=dataset.id,
                        dataset_name=dataset.table_name,
                    )
                )
            await ctx.info(
                "Compatible metrics (builtin): count=%d" % len(builtin_metrics)
            )
            return CompatibleMetricsResponse(
                compatible_metrics=builtin_metrics,
                source="builtin",
            )

        # --------------------------------------------------------------
        # External semantic view: the view decides what is compatible.
        # --------------------------------------------------------------
        from superset.daos.semantic_layer import SemanticViewDAO
        from superset.exceptions import SupersetSecurityException

        view_id: int = request.view_id  # type: ignore[assignment]
        with event_logger.log_context(action="mcp.get_compatible_metrics.external"):
            view = SemanticViewDAO.find_by_id(view_id)
        if view is None:
            return SemanticLayerError.create(
                error=f"No semantic view found with id: {view_id}.",
                error_type="NotFound",
            )

        # Access failures are reported as AccessDenied, not InternalError.
        try:
            view.raise_for_access()
        except SupersetSecurityException as ex:
            return SemanticLayerError.create(
                error=str(ex.error.message),
                error_type="AccessDenied",
            )

        compatible_names = view.get_compatible_metrics(
            request.selected_metrics,
            request.selected_dimensions,
        )

        # Enrich each returned name with description/expression when the view
        # exposes a metric of that name; otherwise those fields stay None.
        metrics_by_name = {m.metric_name: m for m in view.metrics}
        external_metrics = []
        for name in compatible_names:
            known = metrics_by_name.get(name)
            external_metrics.append(
                MetricInfo(
                    name=name,
                    description=known.description if known is not None else None,
                    expression=known.expression if known is not None else None,
                    source="external",
                    view_id=view.id,
                    view_name=view.name,
                )
            )
        await ctx.info(
            "Compatible metrics (external view id=%d): count=%d"
            % (view.id, len(external_metrics))
        )
        return CompatibleMetricsResponse(
            compatible_metrics=external_metrics,
            source="external",
        )
    except Exception as exc:
        # Top-level boundary: log with traceback and return a structured error.
        logger.exception(
            "Unexpected error in get_compatible_metrics: %s: %s",
            type(exc).__name__,
            str(exc),
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return SemanticLayerError.create(
            error=f"Internal error in get_compatible_metrics: {exc}",
            error_type="InternalError",
        )

View File

@@ -0,0 +1,437 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MCP tool: get_table
Query a data source (built-in dataset or external semantic view) using
metric and dimension names, returning tabular results.
"""
import logging
import time
from typing import Any
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.commands.exceptions import CommandException
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
from superset.extensions import event_logger
from superset.mcp_service.chart.schemas import PerformanceMetadata
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.semantic_layer.schemas import (
GetTableRequest,
GetTableResponse,
SemanticLayerError,
)
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
from superset.mcp_service.utils.query_utils import validate_names
from superset.mcp_service.utils.response_utils import format_data_columns
logger = logging.getLogger(__name__)
def _build_query_dict(
    request: GetTableRequest,
    time_col: str | None,
) -> dict[str, Any]:
    """Translate a ``GetTableRequest`` into a single query dict.

    Args:
        request: The tool request; its filters, dimensions, metrics,
            row limit, ordering and time range are read.
        time_col: Resolved temporal column name, or ``None`` when no
            time column applies.

    Returns:
        A dict with ``filters``, ``columns``, ``metrics``, ``row_limit`` and
        ``order_desc`` keys, plus ``granularity`` when ``time_col`` is set
        and ``orderby`` when the request names ordering columns.
    """
    filter_clauses: list[dict[str, Any]] = []
    for flt in request.filters:
        filter_clauses.append({"col": flt.col, "op": flt.op, "val": flt.val})

    # The time range is only expressible as a filter when a column is known.
    if time_col and request.time_range:
        filter_clauses.append(
            {"col": time_col, "op": "TEMPORAL_RANGE", "val": request.time_range}
        )

    payload: dict[str, Any] = {
        "filters": filter_clauses,
        "columns": request.dimensions,
        "metrics": request.metrics,
        "row_limit": request.row_limit,
        "order_desc": request.order_desc,
    }

    if time_col:
        payload["granularity"] = time_col

    if request.order_by:
        # Second tuple element is the negation of order_desc
        # (presumably the "ascending" flag — confirm against the query object).
        ascending = not request.order_desc
        payload["orderby"] = [(col, ascending) for col in request.order_by]

    return payload
@tool(
    tags=["data", "semantic"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="Get table",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
@requires_data_model_metadata_access
async def get_table(  # noqa: C901
    request: GetTableRequest,
    ctx: Context,
) -> GetTableResponse | SemanticLayerError:
    """Query a data source using metrics and dimensions, returning tabular results.

    Works with both built-in datasets and external semantic views. The
    ``dataset_id`` or ``view_id`` comes from the ``list_metrics`` response.

    Workflow:
    1. list_metrics -> discover metrics and their compatible_dimensions
    2. get_table -> query with chosen metrics and dimensions

    Example (built-in):
    ```json
    {
        "dataset_id": 42,
        "metrics": ["revenue"],
        "dimensions": ["region", "product_category"],
        "time_range": "Last 30 days",
        "row_limit": 500
    }
    ```

    Example (external):
    ```json
    {
        "view_id": 5,
        "metrics": ["bookings"],
        "dimensions": ["listing__country_name"],
        "row_limit": 100
    }
    ```
    """
    # Overall flow: validate the request shape, resolve the datasource and its
    # valid column/metric names, validate requested names, run the query via
    # ChartDataCommand, then format the first query result. Every failure is
    # returned as a SemanticLayerError rather than raised.
    await ctx.info(
        "Starting get_table: dataset_id=%s, view_id=%s, metrics=%s, "
        "dimensions=%s, row_limit=%s"
        % (
            request.dataset_id,
            request.view_id,
            request.metrics,
            request.dimensions,
            request.row_limit,
        )
    )

    if not user_can_view_data_model_metadata():
        return SemanticLayerError.create(
            error="You don't have permission to access dataset details for your role.",
            error_type=DATA_MODEL_METADATA_ERROR_TYPE,
        )

    # Exactly one of dataset_id / view_id must be provided.
    if request.dataset_id is None and request.view_id is None:
        return SemanticLayerError.create(
            error=(
                "Provide either dataset_id (built-in dataset) or view_id "
                "(external semantic view). Both are in the list_metrics response."
            ),
            error_type="ValidationError",
        )
    if request.dataset_id is not None and request.view_id is not None:
        return SemanticLayerError.create(
            error="Provide only one of dataset_id or view_id, not both.",
            error_type="ValidationError",
        )

    try:
        from superset.commands.chart.data.get_data_command import ChartDataCommand
        from superset.common.query_context_factory import QueryContextFactory

        is_builtin = request.dataset_id is not None
        datasource_type = "table" if is_builtin else "semantic_view"
        # The asserts only narrow the Optional ids for the type checker; the
        # checks above already guarantee the selected id is not None.
        if is_builtin:
            assert request.dataset_id is not None
            datasource_id = request.dataset_id
        else:
            assert request.view_id is not None
            datasource_id = request.view_id

        # ------------------------------------------------------------------
        # Resolve datasource for metadata (time column, display name)
        # ------------------------------------------------------------------
        await ctx.report_progress(1, 5, "Resolving data source")
        # Placeholder name, replaced once the datasource has been loaded.
        display_name: str = f"{'Dataset' if is_builtin else 'View'} {datasource_id}"
        time_col: str | None = request.time_column
        warnings: list[str] = []
        valid_columns: set[str] = set()
        valid_metrics: set[str] = set()

        if is_builtin:
            from sqlalchemy.orm import subqueryload

            from superset.connectors.sqla.models import SqlaTable
            from superset.daos.dataset import DatasetDAO

            with event_logger.log_context(action="mcp.get_table.resolve_dataset"):
                dataset = DatasetDAO.find_by_id(
                    datasource_id,
                    query_options=[
                        subqueryload(SqlaTable.columns),
                        subqueryload(SqlaTable.metrics),
                    ],
                )
            if dataset is None:
                return SemanticLayerError.create(
                    error=f"No dataset found with id: {request.dataset_id}.",
                    error_type="NotFound",
                )
            display_name = dataset.table_name
            valid_columns = {c.column_name for c in dataset.columns}
            valid_dttm_columns = {c.column_name for c in dataset.columns if c.is_dttm}
            valid_metrics = {m.metric_name for m in dataset.metrics}

            # No explicit time column: fall back to the dataset's main_dttm_col.
            # For built-in datasets a missing temporal column is a hard error.
            if time_col is None and request.time_range:
                time_col = getattr(dataset, "main_dttm_col", None)
                if not time_col:
                    return SemanticLayerError.create(
                        error=(
                            "time_range was provided but no temporal column is "
                            "configured. Set time_column explicitly."
                        ),
                        error_type="ValidationError",
                    )

            # Reject a time column that is unknown or not flagged is_dttm,
            # with a distinct message for each case.
            if time_col is not None and time_col not in valid_dttm_columns:
                if time_col in valid_columns:
                    error_msg = (
                        f"time_column '{time_col}' on dataset '{display_name}' is "
                        "not marked as a datetime column."
                    )
                else:
                    error_msg = (
                        f"Unknown time_column: '{time_col}' on dataset "
                        f"'{display_name}'."
                    )
                return SemanticLayerError.create(
                    error=error_msg,
                    error_type="ValidationError",
                )
        else:
            from superset.daos.semantic_layer import SemanticViewDAO
            from superset.exceptions import SupersetSecurityException

            with event_logger.log_context(action="mcp.get_table.resolve_view"):
                view = SemanticViewDAO.find_by_id(datasource_id)
            if view is None:
                return SemanticLayerError.create(
                    error=f"No semantic view found with id: {request.view_id}.",
                    error_type="NotFound",
                )
            # Access failures on the view are reported as AccessDenied.
            try:
                view.raise_for_access()
            except SupersetSecurityException as ex:
                return SemanticLayerError.create(
                    error=str(ex.error.message),
                    error_type="AccessDenied",
                )
            display_name = view.name
            valid_columns = {c.column_name for c in view.columns}
            valid_dttm_columns = {c.column_name for c in view.columns if c.is_dttm}
            valid_metrics = {m.metric_name for m in view.metrics}

            # Unlike the built-in path, a missing datetime dimension only
            # produces a warning and the time filter is dropped.
            if time_col is None and request.time_range:
                # Use first datetime dimension as the time column
                dttm_cols = [c for c in view.columns if c.is_dttm]
                if dttm_cols:
                    time_col = dttm_cols[0].column_name
                else:
                    warnings.append(
                        "time_range provided but no datetime dimension found; "
                        "time filter will not be applied."
                    )
                    time_col = None

            if time_col is not None and time_col not in valid_dttm_columns:
                if time_col in valid_columns:
                    error_msg = (
                        f"time_column '{time_col}' on view '{display_name}' is "
                        "not marked as a datetime column."
                    )
                else:
                    error_msg = (
                        f"Unknown time_column: '{time_col}' on view '{display_name}'."
                    )
                return SemanticLayerError.create(
                    error=error_msg,
                    error_type="ValidationError",
                )

        # ------------------------------------------------------------------
        # Validate requested metrics and dimensions against the datasource
        # ------------------------------------------------------------------
        await ctx.report_progress(2, 5, "Validating metrics and dimensions")
        # All problems are collected first so they can be reported together.
        validation_errors: list[str] = []
        validation_errors.extend(
            validate_names(request.dimensions, valid_columns, "dimension")
        )
        validation_errors.extend(
            validate_names(request.metrics, valid_metrics, "metric")
        )
        filter_cols = [f.col for f in request.filters]
        validation_errors.extend(
            validate_names(filter_cols, valid_columns, "filter column")
        )
        if request.order_by:
            # Ordering may reference either a column or a metric.
            valid_orderby = valid_columns | valid_metrics
            validation_errors.extend(
                validate_names(request.order_by, valid_orderby, "order_by")
            )
        if validation_errors:
            error_msg = "; ".join(validation_errors)
            await ctx.error("Validation failed: %s" % (error_msg,))
            return SemanticLayerError.create(
                error=error_msg,
                error_type="ValidationError",
            )

        # ------------------------------------------------------------------
        # Build and execute query
        # ------------------------------------------------------------------
        await ctx.report_progress(3, 5, "Building query")
        query_dict = _build_query_dict(request, time_col)
        # Only the key names are emitted to the debug channel, not the values.
        await ctx.debug("Query dict: %s" % (sorted(query_dict.keys()),))

        await ctx.report_progress(4, 5, "Executing query")
        start_time = time.time()
        with event_logger.log_context(action="mcp.get_table.execute"):
            factory = QueryContextFactory()
            query_context = factory.create(
                datasource={"id": datasource_id, "type": datasource_type},
                queries=[query_dict],
                form_data={},
                # True when caching is disabled or a refresh was requested
                # (presumably forces a cache bypass — confirm in the factory).
                force=not request.use_cache or request.force_refresh,
            )
            command = ChartDataCommand(query_context)
            command.validate()
            result = command.run()
        query_duration_ms = int((time.time() - start_time) * 1000)

        if not result or "queries" not in result or not result["queries"]:
            return SemanticLayerError.create(
                error="Query returned no results.",
                error_type="EmptyQuery",
            )

        # ------------------------------------------------------------------
        # Format response
        # ------------------------------------------------------------------
        await ctx.report_progress(5, 5, "Formatting results")
        # Only one query was submitted, so only the first result is read.
        query_result = result["queries"][0]
        data = query_result.get("data", [])
        raw_columns = query_result.get("colnames", [])

        # A successful query with zero rows is a valid (empty) response,
        # not an error.
        if not data:
            return GetTableResponse(
                columns=[],
                data=[],
                row_count=0,
                total_rows=0,
                summary=f"'{display_name}': query returned no data.",
                source="builtin" if is_builtin else "external",
                dataset_id=request.dataset_id,
                dataset_name=display_name if is_builtin else None,
                view_id=request.view_id,
                view_name=display_name if not is_builtin else None,
                performance=PerformanceMetadata(
                    query_duration_ms=query_duration_ms,
                    cache_status="no_data",
                ),
                cache_status=get_cache_status_from_result(
                    query_result, force_refresh=request.force_refresh
                ),
                warnings=warnings,
            )

        columns_meta = format_data_columns(data, raw_columns)
        cache_status = get_cache_status_from_result(
            query_result, force_refresh=request.force_refresh
        )
        cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
        summary = (
            f"'{display_name}': {len(data)} rows, "
            f"{len(raw_columns)} columns ({cache_label})."
        )
        await ctx.info(
            "get_table complete: rows=%d, columns=%d, duration=%dms"
            % (len(data), len(raw_columns), query_duration_ms)
        )
        return GetTableResponse(
            columns=columns_meta,
            data=data,
            row_count=len(data),
            total_rows=query_result.get("rowcount"),
            summary=summary,
            source="builtin" if is_builtin else "external",
            dataset_id=request.dataset_id,
            dataset_name=display_name if is_builtin else None,
            view_id=request.view_id,
            view_name=display_name if not is_builtin else None,
            performance=PerformanceMetadata(
                query_duration_ms=query_duration_ms,
                cache_status=cache_label,
            ),
            cache_status=cache_status,
            warnings=warnings,
        )
    # Handlers run from most specific to most general; each maps the failure
    # to a distinct error_type in the returned SemanticLayerError.
    except OAuth2RedirectError as exc:
        redirect_msg = build_oauth2_redirect_message(exc)
        await ctx.error("OAuth2 redirect required: %s" % redirect_msg)
        return SemanticLayerError.create(
            error=redirect_msg,
            error_type="OAuth2Redirect",
        )
    except OAuth2Error as exc:
        await ctx.error("OAuth2 error: %s" % str(exc))
        return SemanticLayerError.create(
            error=f"OAuth2 authentication error: {exc}",
            error_type="OAuth2Error",
        )
    except (CommandException, SupersetException) as exc:
        await ctx.error("Query failed: %s" % str(exc))
        return SemanticLayerError.create(
            error=f"Query execution failed: {exc}",
            error_type="QueryError",
        )
    except SQLAlchemyError as exc:
        await ctx.error("Database error: %s" % str(exc))
        return SemanticLayerError.create(
            error=f"Database error: {exc}",
            error_type="DatabaseError",
        )
    except Exception as exc:
        # Last-resort boundary: log the traceback and return InternalError.
        logger.exception(
            "Unexpected error in get_table: %s: %s", type(exc).__name__, str(exc)
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return SemanticLayerError.create(
            error=f"Internal error executing get_table: {exc}",
            error_type="InternalError",
        )

View File

@@ -0,0 +1,294 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""MCP tool: list_metrics
Unified metric discovery across built-in datasets and external semantic views.
"""
import logging
from typing import Any
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.daos.dataset import DatasetDAO
from superset.daos.semantic_layer import SemanticViewDAO
from superset.extensions import db, event_logger
from superset.mcp_service.privacy import (
DATA_MODEL_METADATA_ERROR_TYPE,
requires_data_model_metadata_access,
user_can_view_data_model_metadata,
)
from superset.mcp_service.semantic_layer.schemas import (
DimensionInfo,
ListMetricsRequest,
MetricInfo,
MetricList,
SemanticLayerError,
)
logger = logging.getLogger(__name__)
def _matches_search(text: str | None, search: str) -> bool:
if not text:
return False
return search.lower() in text.lower()
def _builtin_compatible_dims(dataset: Any) -> list[DimensionInfo]:
    """Describe the groupby-enabled columns of a built-in dataset.

    Columns whose ``groupby`` flag is falsy are skipped. Empty optional text
    fields are normalised to ``None`` and the flags are coerced to ``bool``.
    """
    dimensions: list[DimensionInfo] = []
    for column in dataset.columns:
        if not column.groupby:
            continue
        dimensions.append(
            DimensionInfo(
                name=column.column_name,
                verbose_name=column.verbose_name or None,
                description=column.description or None,
                type=column.type or None,
                is_dttm=bool(column.is_dttm),
                groupby=bool(column.groupby),
                filterable=bool(column.filterable),
                source="builtin",
            )
        )
    return dimensions
def _collect_builtin_metrics(request: ListMetricsRequest) -> list[MetricInfo]:
    """Collect metrics from built-in SqlaTable datasets.

    When ``request.dataset_id`` is set, only that dataset is loaded (an
    unknown id yields an empty result). Otherwise every dataset returned by
    ``DatasetDAO._apply_base_filter`` is scanned. Metrics are filtered by
    ``request.search`` against their name and description.

    Args:
        request: Listing request; ``dataset_id``, ``search`` and
            ``include_compatible_dimensions`` are read.

    Returns:
        One ``MetricInfo`` per matching metric, in dataset order.
    """
    # Imported once for both branches below.
    from sqlalchemy.orm import subqueryload

    from superset.connectors.sqla.models import SqlaTable

    # Eager-load columns and metrics so iterating them below does not
    # trigger a lazy load per dataset (N+1 queries).
    eager_options = [
        subqueryload(SqlaTable.columns),
        subqueryload(SqlaTable.metrics),
    ]

    with event_logger.log_context(action="mcp.list_metrics.builtin_query"):
        if request.dataset_id is not None:
            dataset = DatasetDAO.find_by_id(
                request.dataset_id,
                query_options=eager_options,
            )
            datasets = [dataset] if dataset else []
        else:
            query = db.session.query(SqlaTable).options(*eager_options)
            datasets = DatasetDAO._apply_base_filter(query).all()

    results: list[MetricInfo] = []
    for dataset in datasets:
        # Computed once per dataset and shared by all of its metrics.
        compat_dims = (
            _builtin_compatible_dims(dataset)
            if request.include_compatible_dimensions
            else []
        )
        for metric in dataset.metrics:
            name = metric.metric_name or ""
            desc = metric.description or ""
            if request.search and not (
                _matches_search(name, request.search)
                or _matches_search(desc, request.search)
            ):
                continue
            results.append(
                MetricInfo(
                    name=name,
                    verbose_name=metric.verbose_name or None,
                    description=desc or None,
                    expression=metric.expression or None,
                    d3format=metric.d3format or None,
                    warning_text=metric.warning_text or None,
                    source="builtin",
                    dataset_id=dataset.id,
                    dataset_name=dataset.table_name,
                    compatible_dimensions=compat_dims,
                )
            )
    return results
async def _collect_external_metrics(
    request: ListMetricsRequest,
    ctx: Context,
) -> list[MetricInfo]:
    """Collect metrics from external SemanticView models.

    With ``request.view_id`` set only that view is inspected; otherwise all
    views returned by ``SemanticViewDAO.find_accessible()`` are scanned.
    A failure while reading one view's metadata is reported through
    ``ctx.warning`` and the remaining views are still processed. Access
    errors from ``raise_for_access`` are not caught here and propagate.
    """
    with event_logger.log_context(action="mcp.list_metrics.external_query"):
        if request.view_id is None:
            # find_accessible filters at SQL level, avoiding a per-row
            # Python permission check and the audit noise of raise_for_access.
            views = SemanticViewDAO.find_accessible()
        else:
            single_view = SemanticViewDAO.find_by_id(request.view_id)
            views = [single_view] if single_view else []
    await ctx.debug("Found %d semantic views to scan for metrics" % len(views))

    collected: list[MetricInfo] = []
    for view in views:
        # Deliberately outside the try below so that an authorization
        # failure is never downgraded to a per-view warning.
        view.raise_for_access()
        try:
            view_metrics = view.metrics
            dim_columns = (
                view.columns if request.include_compatible_dimensions else []
            )
            view_dims = []
            for column in dim_columns:
                view_dims.append(
                    DimensionInfo(
                        name=column.column_name,
                        verbose_name=column.verbose_name,
                        description=column.description,
                        type=column.type,
                        is_dttm=column.is_dttm,
                        groupby=column.groupby,
                        filterable=column.filterable,
                        source="external",
                    )
                )
            for metric in view_metrics:
                name = metric.metric_name or ""
                desc = metric.description or ""
                # Skip metrics whose name and description both miss the term.
                if (
                    request.search
                    and not _matches_search(name, request.search)
                    and not _matches_search(desc, request.search)
                ):
                    continue
                collected.append(
                    MetricInfo(
                        name=name,
                        description=desc or None,
                        expression=metric.expression or None,
                        source="external",
                        view_id=view.id,
                        view_name=view.name,
                        compatible_dimensions=view_dims,
                    )
                )
        except Exception as exc:  # noqa: BLE001
            # External registry may be empty in OSS — degrade gracefully
            await ctx.warning(
                "Could not load metrics for view id=%s: %s" % (view.id, str(exc))
            )
    return collected
@tool(
    tags=["data", "semantic"],
    class_permission_name="Dataset",
    annotations=ToolAnnotations(
        title="List metrics",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
@requires_data_model_metadata_access
async def list_metrics(
    request: ListMetricsRequest | None = None,
    ctx: Context | None = None,
) -> MetricList | SemanticLayerError:
    """List available metrics across built-in datasets and external semantic views.

    This is the primary entry point for semantic layer exploration. Returns a
    unified list of metrics from all data sources the current user can access,
    with compatible dimensions included inline.

    Workflow:
    1. list_metrics -> discover available metrics and their compatible dimensions
    2. get_table -> query data using chosen metrics and dimensions

    Use ``search`` to filter by metric name or description. Use ``dataset_id``
    or ``view_id`` to scope to a specific data source.

    Example:
    ```json
    {"search": "revenue", "include_compatible_dimensions": true, "page": 1}
    ```
    """
    # Function-scope import; only needed for the AccessDenied handler below.
    from superset.exceptions import SupersetSecurityException

    if ctx is None:
        raise RuntimeError("FastMCP context is required for list_metrics")
    request = request or ListMetricsRequest()

    await ctx.info(
        "Listing metrics: search=%s, dataset_id=%s, view_id=%s, page=%s"
        % (request.search, request.dataset_id, request.view_id, request.page)
    )

    if not user_can_view_data_model_metadata():
        await ctx.warning("Metric listing blocked by data-model privacy controls")
        return SemanticLayerError.create(
            error="You don't have permission to access dataset details for your role.",
            error_type=DATA_MODEL_METADATA_ERROR_TYPE,
        )

    if request.dataset_id is not None and request.view_id is not None:
        return SemanticLayerError.create(
            error="Provide only one of dataset_id or view_id to scope the search, not both.",
            error_type="ValidationError",
        )

    try:
        all_metrics: list[MetricInfo] = []

        # Scoping to a view skips built-in datasets, and vice versa; with no
        # scope both sources are collected (built-in first).
        if request.view_id is None:
            all_metrics.extend(_collect_builtin_metrics(request))
            await ctx.debug("Collected %d built-in metrics" % len(all_metrics))

        if request.dataset_id is None:
            external = await _collect_external_metrics(request, ctx)
            all_metrics.extend(external)
            await ctx.debug("Collected %d external metrics" % len(external))

        # Pagination is applied in memory over the combined list.
        total_count = len(all_metrics)
        # Ceiling division, with at least one page even when empty.
        total_pages = max(1, (total_count + request.page_size - 1) // request.page_size)
        start = (request.page - 1) * request.page_size
        page_metrics = all_metrics[start : start + request.page_size]

        await ctx.info(
            "Metrics listed: total=%d, page=%d/%d, returned=%d"
            % (total_count, request.page, total_pages, len(page_metrics))
        )
        return MetricList(
            metrics=page_metrics,
            total_count=total_count,
            page=request.page,
            page_size=request.page_size,
            total_pages=total_pages,
        )
    except SupersetSecurityException as ex:
        # An authorization failure while scanning views is an expected
        # outcome, not an internal fault: report it as AccessDenied instead
        # of logging a traceback and returning InternalError.
        return SemanticLayerError.create(
            error=str(ex.error.message),
            error_type="AccessDenied",
        )
    except Exception as exc:
        # Top-level boundary: log with traceback and return a structured error.
        logger.exception(
            "Unexpected error in list_metrics: %s: %s", type(exc).__name__, str(exc)
        )
        await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
        return SemanticLayerError.create(
            error=f"Internal error listing metrics: {exc}",
            error_type="InternalError",
        )

View File

@@ -0,0 +1,40 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Shared query validation utilities for MCP tools."""
import difflib
def validate_names(
    requested: list[str],
    valid: set[str],
    kind: str,
) -> list[str]:
    """Return error messages for requested names not found in *valid*.

    Each unknown name is reported once, in first-seen order, even when it
    appears several times in *requested* (for example two range filters on
    the same column). Close-match suggestions are included when available.

    Args:
        requested: Names supplied by the caller.
        valid: Names that exist on the datasource.
        kind: Human-readable label used in the message (e.g. ``"metric"``).

    Returns:
        A list of messages; empty when every requested name is valid.
    """
    errors: list[str] = []
    reported: set[str] = set()
    for name in requested:
        if name in valid or name in reported:
            continue
        reported.add(name)
        suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
        msg = f"Unknown {kind}: '{name}'"
        if suggestions:
            msg += f". Did you mean: {', '.join(suggestions)}?"
        errors.append(msg)
    return errors

View File

@@ -57,7 +57,10 @@ Usage example::
from __future__ import annotations
from datetime import datetime
from typing import Dict
from typing import Any, Dict, TYPE_CHECKING
if TYPE_CHECKING:
from superset.mcp_service.chart.schemas import DataColumn
import humanize
@@ -161,3 +164,49 @@ class OmittedFieldsBuilder:
def build(self) -> Dict[str, str]:
    """Return the omission metadata dict.

    The result is a new ``dict`` built from ``self._fields``.
    """
    return dict(self._fields)
def format_data_columns(
    data: list[dict[str, Any]], raw_columns: list[str]
) -> list[DataColumn]:
    """Derive per-column metadata from query result rows.

    The data type and sample values come from the non-null values in the
    first three rows. Null and distinct counts are computed over at most
    the first 5000 rows so large result sets stay cheap to describe.
    """
    # Local import breaks the chart.schemas ↔ response_utils circular dependency.
    from superset.mcp_service.chart.schemas import DataColumn  # noqa: PLC0415

    head_rows = data[:3]
    capped_rows = data[:5000]

    described: list[DataColumn] = []
    for col_name in raw_columns:
        head_values = (row.get(col_name) for row in head_rows)
        sample_values = [value for value in head_values if value is not None]

        # bool is tested before int/float because bool is an int subclass.
        if not sample_values:
            data_type = "string"
        elif all(isinstance(value, bool) for value in sample_values):
            data_type = "boolean"
        elif all(isinstance(value, (int, float)) for value in sample_values):
            data_type = "numeric"
        else:
            data_type = "string"

        capped_values = [row.get(col_name) for row in capped_rows]
        present = [value for value in capped_values if value is not None]
        null_count = len(capped_values) - len(present)
        # Distinctness is judged on the string form of each value.
        unique_count = len({str(value) for value in present})

        described.append(
            DataColumn(
                name=col_name,
                display_name=col_name.replace("_", " ").title(),
                data_type=data_type,
                sample_values=sample_values,
                null_count=null_count,
                unique_count=unique_count,
            )
        )
    return described

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,231 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the get_compatible_dimensions MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.app import mcp
from superset.utils import json
# Module object of the tool under test, resolved by dotted path so the fixture
# and tests below can patch attributes on it with patch.object.
get_compatible_dimensions_module = importlib.import_module(
"superset.mcp_service.semantic_layer.tool.get_compatible_dimensions"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    """Provide the application's FastMCP instance for in-process client calls."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Stub authentication and the metadata-privacy gate for every test.

    Replaces ``get_user_from_request`` with a mock returning a fake user
    (id 1, username "admin") and forces ``user_can_view_data_model_metadata``
    on the tool module to return ``True``. Both patches remain active while the
    test body runs.

    Yields:
        The mock standing in for ``get_user_from_request``.
    """
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            get_compatible_dimensions_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
def _make_column(name: str, groupby: bool = True) -> MagicMock:
col = MagicMock()
col.column_name = name
col.verbose_name = None
col.description = None
col.type = "VARCHAR"
col.is_dttm = False
col.groupby = groupby
col.filterable = True
return col
def _make_dataset(dataset_id: int = 42) -> MagicMock:
    """Build a mock dataset with two groupable columns and one that is not.

    Args:
        dataset_id: Used as ``id`` and to derive ``table_name``.

    Returns:
        A ``MagicMock`` with an empty metric list and the columns "region" and
        "category" (groupby enabled) plus "internal_only" (groupby disabled).
    """
    ds = MagicMock()
    ds.id = dataset_id
    ds.table_name = f"table_{dataset_id}"
    ds.metrics = []
    ds.columns = [
        _make_column("region"),
        _make_column("category"),
        _make_column("internal_only", groupby=False),
    ]
    return ds
def _make_view(view_id: int = 5) -> MagicMock:
    """Build a mock semantic view that grants access and exposes one column.

    Args:
        view_id: Used as ``id`` and to derive ``name``.

    Returns:
        A ``MagicMock`` whose ``raise_for_access`` returns ``None`` (no error),
        whose only column is "country_name", and whose
        ``get_compatible_dimensions`` returns ``["country_name"]``.
    """
    view = MagicMock()
    view.id = view_id
    view.name = f"view_{view_id}"
    view.raise_for_access = MagicMock(return_value=None)
    view.columns = [_make_column("country_name")]
    view.get_compatible_dimensions = MagicMock(return_value=["country_name"])
    return view
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
    """Create the security exception used to simulate a denied datasource.

    Args:
        message: Text placed in the wrapped ``SupersetError``.

    Returns:
        A ``SupersetSecurityException`` wrapping an error-level
        ``DATASOURCE_SECURITY_ACCESS_ERROR``.
    """
    return SupersetSecurityException(
        SupersetError(
            message=message,
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.ERROR,
        )
    )
@pytest.mark.asyncio
async def test_get_compatible_dimensions_builtin_happy_path(
    mcp_server: FastMCP,
) -> None:
    """Builtin datasets return all groupby-enabled columns, ignoring selection.

    The mock dataset also carries a column with ``groupby=False``; the exact
    set comparison below verifies that it is left out of the response.
    """
    mock_ds = _make_dataset(42)
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_dimensions",
                {"request": {"dataset_id": 42, "selected_metrics": ["revenue"]}},
            )
        # The tool response is JSON text carried in the first content item.
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["source"] == "builtin"
        names = {d["name"] for d in data["compatible_dimensions"]}
        assert names == {"region", "category"}
@pytest.mark.asyncio
async def test_get_compatible_dimensions_external_happy_path(
    mcp_server: FastMCP,
) -> None:
    """External views delegate to view.get_compatible_dimensions().

    The request carries only ``selected_metrics``; the view method is expected
    to be called once with those metrics and an empty second argument.
    """
    mock_view = _make_view(5)
    with patch(
        "superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
        return_value=mock_view,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_dimensions",
                {"request": {"view_id": 5, "selected_metrics": ["bookings"]}},
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["source"] == "external"
        assert [d["name"] for d in data["compatible_dimensions"]] == ["country_name"]
        mock_view.get_compatible_dimensions.assert_called_once_with(["bookings"], [])
@pytest.mark.asyncio
async def test_get_compatible_dimensions_mutual_exclusion_validation(
    mcp_server: FastMCP,
) -> None:
    """Errors when both dataset_id and view_id are provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_compatible_dimensions",
            {"request": {"dataset_id": 1, "view_id": 2}},
        )
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_compatible_dimensions_requires_one_source(
    mcp_server: FastMCP,
) -> None:
    """Errors when neither dataset_id nor view_id is provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool("get_compatible_dimensions", {"request": {}})
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_compatible_dimensions_privacy_check(mcp_server: FastMCP) -> None:
    """Errors when the user lacks data-model metadata access.

    Re-patches ``user_can_view_data_model_metadata`` to return ``False`` for
    this test only, overriding the ``True`` set by the autouse fixture.
    """
    with patch.object(
        get_compatible_dimensions_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_dimensions", {"request": {"dataset_id": 1}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "DataModelMetadataRestricted"
@pytest.mark.asyncio
async def test_get_compatible_dimensions_external_access_denied(
    mcp_server: FastMCP,
) -> None:
    """Returns AccessDenied when raise_for_access rejects the view."""
    mock_view = _make_view(5)
    # Make the view's access check raise instead of returning None.
    mock_view.raise_for_access.side_effect = _access_denied_exc()
    with patch(
        "superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
        return_value=mock_view,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_dimensions", {"request": {"view_id": 5}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "AccessDenied"
@pytest.mark.asyncio
async def test_get_compatible_dimensions_not_found(mcp_server: FastMCP) -> None:
    """Returns NotFound when the dataset doesn't exist.

    ``DatasetDAO.find_by_id`` is patched to return ``None`` for the lookup.
    """
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_dimensions", {"request": {"dataset_id": 999}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "NotFound"

View File

@@ -0,0 +1,222 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the get_compatible_metrics MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.app import mcp
from superset.utils import json
# Module object of the tool under test, resolved by dotted path so the fixture
# and tests below can patch attributes on it with patch.object.
get_compatible_metrics_module = importlib.import_module(
"superset.mcp_service.semantic_layer.tool.get_compatible_metrics"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    """Provide the application's FastMCP instance for in-process client calls."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Stub authentication and the metadata-privacy gate for every test.

    Replaces ``get_user_from_request`` with a mock returning a fake user
    (id 1, username "admin") and forces ``user_can_view_data_model_metadata``
    on the tool module to return ``True``. Both patches remain active while the
    test body runs.

    Yields:
        The mock standing in for ``get_user_from_request``.
    """
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            get_compatible_metrics_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
m = MagicMock()
m.metric_name = name
m.verbose_name = None
m.expression = expression
m.description = None
m.d3format = None
m.warning_text = None
return m
def _make_dataset(dataset_id: int = 42) -> MagicMock:
    """Build a mock dataset carrying two metrics and no columns.

    Args:
        dataset_id: Used as ``id`` and to derive ``table_name``.

    Returns:
        A ``MagicMock`` with metrics "count" (``COUNT(*)``) and "revenue"
        (``SUM(revenue)``) and an empty column list.
    """
    ds = MagicMock()
    ds.id = dataset_id
    ds.table_name = f"table_{dataset_id}"
    ds.columns = []
    ds.metrics = [_make_metric("count"), _make_metric("revenue", "SUM(revenue)")]
    return ds
def _make_view(view_id: int = 5) -> MagicMock:
    """Build a mock semantic view that grants access and exposes one metric.

    Args:
        view_id: Used as ``id`` and to derive ``name``.

    Returns:
        A ``MagicMock`` whose ``raise_for_access`` returns ``None`` (no error),
        whose only metric is "bookings", and whose ``get_compatible_metrics``
        returns ``["bookings"]``.
    """
    view = MagicMock()
    view.id = view_id
    view.name = f"view_{view_id}"
    view.raise_for_access = MagicMock(return_value=None)
    view.metrics = [_make_metric("bookings")]
    view.get_compatible_metrics = MagicMock(return_value=["bookings"])
    return view
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
    """Create the security exception used to simulate a denied datasource.

    Args:
        message: Text placed in the wrapped ``SupersetError``.

    Returns:
        A ``SupersetSecurityException`` wrapping an error-level
        ``DATASOURCE_SECURITY_ACCESS_ERROR``.
    """
    return SupersetSecurityException(
        SupersetError(
            message=message,
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.ERROR,
        )
    )
@pytest.mark.asyncio
async def test_get_compatible_metrics_builtin_happy_path(mcp_server: FastMCP) -> None:
    """Builtin datasets return all metrics, ignoring the current selection.

    The request includes ``selected_dimensions``, yet both metrics defined on
    the mock dataset are expected back.
    """
    mock_ds = _make_dataset(42)
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_metrics",
                {"request": {"dataset_id": 42, "selected_dimensions": ["region"]}},
            )
        # The tool response is JSON text carried in the first content item.
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["source"] == "builtin"
        names = {m["name"] for m in data["compatible_metrics"]}
        assert names == {"count", "revenue"}
@pytest.mark.asyncio
async def test_get_compatible_metrics_external_happy_path(mcp_server: FastMCP) -> None:
    """External views delegate to view.get_compatible_metrics().

    The request carries only ``selected_dimensions``; the view method is
    expected to be called once with an empty first argument and those
    dimensions.
    """
    mock_view = _make_view(5)
    with patch(
        "superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
        return_value=mock_view,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_metrics",
                {"request": {"view_id": 5, "selected_dimensions": ["country_name"]}},
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["source"] == "external"
        assert [m["name"] for m in data["compatible_metrics"]] == ["bookings"]
        mock_view.get_compatible_metrics.assert_called_once_with([], ["country_name"])
@pytest.mark.asyncio
async def test_get_compatible_metrics_mutual_exclusion_validation(
    mcp_server: FastMCP,
) -> None:
    """Errors when both dataset_id and view_id are provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_compatible_metrics",
            {"request": {"dataset_id": 1, "view_id": 2}},
        )
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_compatible_metrics_requires_one_source(
    mcp_server: FastMCP,
) -> None:
    """Errors when neither dataset_id nor view_id is provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool("get_compatible_metrics", {"request": {}})
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_compatible_metrics_privacy_check(mcp_server: FastMCP) -> None:
    """Errors when the user lacks data-model metadata access.

    Re-patches ``user_can_view_data_model_metadata`` to return ``False`` for
    this test only, overriding the ``True`` set by the autouse fixture.
    """
    with patch.object(
        get_compatible_metrics_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_metrics", {"request": {"dataset_id": 1}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "DataModelMetadataRestricted"
@pytest.mark.asyncio
async def test_get_compatible_metrics_external_access_denied(
    mcp_server: FastMCP,
) -> None:
    """Returns AccessDenied when raise_for_access rejects the view."""
    mock_view = _make_view(5)
    # Make the view's access check raise instead of returning None.
    mock_view.raise_for_access.side_effect = _access_denied_exc()
    with patch(
        "superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
        return_value=mock_view,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_metrics", {"request": {"view_id": 5}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "AccessDenied"
@pytest.mark.asyncio
async def test_get_compatible_metrics_not_found(mcp_server: FastMCP) -> None:
    """Returns NotFound when the dataset doesn't exist.

    ``DatasetDAO.find_by_id`` is patched to return ``None`` for the lookup.
    """
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_compatible_metrics", {"request": {"dataset_id": 999}}
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "NotFound"

View File

@@ -0,0 +1,263 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the get_table MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetSecurityException
from superset.mcp_service.app import mcp
from superset.utils import json
# Module object of the tool under test, resolved by dotted path so the fixture
# and tests below can patch attributes on it with patch.object.
get_table_module = importlib.import_module(
"superset.mcp_service.semantic_layer.tool.get_table"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    """Provide the application's FastMCP instance for in-process client calls."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Stub authentication and the metadata-privacy gate for every test.

    Replaces ``get_user_from_request`` with a mock returning a fake user
    (id 1, username "admin") and forces ``user_can_view_data_model_metadata``
    on the tool module to return ``True``. Both patches remain active while the
    test body runs.

    Yields:
        The mock standing in for ``get_user_from_request``.
    """
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            get_table_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
m = MagicMock()
m.metric_name = name
m.verbose_name = None
m.expression = expression
m.description = None
m.d3format = None
m.warning_text = None
return m
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
col = MagicMock()
col.column_name = name
col.verbose_name = None
col.description = None
col.type = "VARCHAR"
col.is_dttm = is_dttm
col.groupby = True
col.filterable = True
return col
def _make_dataset(dataset_id: int = 42) -> MagicMock:
    """Build a mock dataset with one metric, one plain and one datetime column.

    Args:
        dataset_id: Used as ``id`` and to derive ``table_name``.

    Returns:
        A ``MagicMock`` whose ``main_dttm_col`` is "created_at", with the metric
        "revenue" (``SUM(revenue)``) and the columns "region" (not datetime)
        and "created_at" (datetime).
    """
    ds = MagicMock()
    ds.id = dataset_id
    ds.table_name = f"table_{dataset_id}"
    ds.main_dttm_col = "created_at"
    ds.metrics = [_make_metric("revenue", "SUM(revenue)")]
    ds.columns = [
        _make_column("region"),
        _make_column("created_at", is_dttm=True),
    ]
    return ds
def _make_view(view_id: int = 5) -> MagicMock:
    """Build a mock semantic view that grants access by default.

    Args:
        view_id: Used as ``id`` and to derive ``name``.

    Returns:
        A ``MagicMock`` whose ``raise_for_access`` returns ``None`` (no error),
        with the metric "bookings" and the column "country_name".
    """
    view = MagicMock()
    view.id = view_id
    view.name = f"view_{view_id}"
    view.raise_for_access = MagicMock(return_value=None)
    view.metrics = [_make_metric("bookings")]
    view.columns = [_make_column("country_name")]
    return view
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
    """Create the security exception used to simulate a denied datasource.

    Args:
        message: Text placed in the wrapped ``SupersetError``.

    Returns:
        A ``SupersetSecurityException`` wrapping an error-level
        ``DATASOURCE_SECURITY_ACCESS_ERROR``.
    """
    return SupersetSecurityException(
        SupersetError(
            message=message,
            error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
            level=ErrorLevel.ERROR,
        )
    )
@pytest.mark.asyncio
async def test_get_table_builtin_happy_path(mcp_server: FastMCP) -> None:
    """get_table returns tabular data for a built-in dataset.

    The dataset lookup, ``QueryContextFactory`` and ``ChartDataCommand`` are all
    patched, so no real query is built or executed; the command's ``run()``
    returns a canned single-row result.
    """
    mock_ds = _make_dataset(42)
    query_result = {
        "queries": [
            {
                "data": [{"region": "west", "revenue": 100}],
                "colnames": ["region", "revenue"],
                "rowcount": 1,
            }
        ]
    }
    with (
        patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds),
        patch(
            "superset.commands.chart.data.get_data_command.ChartDataCommand"
        ) as mock_command_cls,
        patch(
            "superset.common.query_context_factory.QueryContextFactory"
        ) as mock_factory_cls,
    ):
        # Instances created from the patched classes hand back the canned data.
        mock_command_cls.return_value.run.return_value = query_result
        mock_factory_cls.return_value.create.return_value = MagicMock()
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_table",
                {
                    "request": {
                        "dataset_id": 42,
                        "metrics": ["revenue"],
                        "dimensions": ["region"],
                    }
                },
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["row_count"] == 1
        assert data["source"] == "builtin"
        assert data["dataset_id"] == 42
@pytest.mark.asyncio
async def test_get_table_requires_one_source(mcp_server: FastMCP) -> None:
    """get_table errors when neither dataset_id nor view_id is provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool("get_table", {"request": {}})
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_table_mutual_exclusion_validation(mcp_server: FastMCP) -> None:
    """get_table errors when both dataset_id and view_id are provided."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_table", {"request": {"dataset_id": 1, "view_id": 2}}
        )
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_table_privacy_check(mcp_server: FastMCP) -> None:
    """get_table errors when the user lacks data-model metadata access.

    Re-patches ``user_can_view_data_model_metadata`` to return ``False`` for
    this test only, overriding the ``True`` set by the autouse fixture.
    """
    with patch.object(
        get_table_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool("get_table", {"request": {"dataset_id": 1}})
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "DataModelMetadataRestricted"
@pytest.mark.asyncio
async def test_get_table_unknown_metric_validation_error(mcp_server: FastMCP) -> None:
    """get_table errors when a requested metric doesn't exist on the dataset.

    The mock dataset defines only the "revenue" metric, so "does_not_exist"
    must be rejected.
    """
    mock_ds = _make_dataset(42)
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_table",
                {"request": {"dataset_id": 42, "metrics": ["does_not_exist"]}},
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_get_table_time_column_not_dttm_validation_error(
    mcp_server: FastMCP,
) -> None:
    """get_table rejects a time_column that isn't marked as a datetime column.

    "region" exists on the mock dataset but has ``is_dttm=False``, so using it
    as ``time_column`` must produce a validation error with an explanatory
    message.
    """
    mock_ds = _make_dataset(42)
    with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_table",
                {
                    "request": {
                        "dataset_id": 42,
                        "metrics": ["revenue"],
                        "time_column": "region",
                    }
                },
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "ValidationError"
        assert "not marked as a datetime column" in data["message"]
@pytest.mark.asyncio
async def test_get_table_external_view_access_denied(mcp_server: FastMCP) -> None:
    """get_table returns AccessDenied when raise_for_access rejects the view."""
    mock_view = _make_view(5)
    # Make the view's access check raise instead of returning None.
    mock_view.raise_for_access.side_effect = _access_denied_exc()
    with patch(
        "superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
        return_value=mock_view,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "get_table",
                {"request": {"view_id": 5, "metrics": ["bookings"]}},
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "AccessDenied"

View File

@@ -0,0 +1,183 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for the list_metrics MCP tool."""
from __future__ import annotations
import importlib
from collections.abc import Generator
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client, FastMCP
from superset.mcp_service.app import mcp
from superset.utils import json
# Module object of the tool under test, resolved by dotted path so the fixture
# and tests below can patch attributes on it with patch.object.
list_metrics_module = importlib.import_module(
"superset.mcp_service.semantic_layer.tool.list_metrics"
)
@pytest.fixture
def mcp_server() -> FastMCP:
    """Provide the application's FastMCP instance for in-process client calls."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth() -> Generator[MagicMock, None, None]:
    """Stub authentication and the metadata-privacy gate for every test.

    Replaces ``get_user_from_request`` with a mock returning a fake user
    (id 1, username "admin") and forces ``user_can_view_data_model_metadata``
    on the tool module to return ``True``. Both patches remain active while the
    test body runs.

    Yields:
        The mock standing in for ``get_user_from_request``.
    """
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch.object(
            list_metrics_module,
            "user_can_view_data_model_metadata",
            return_value=True,
        ),
    ):
        mock_user = Mock()
        mock_user.id = 1
        mock_user.username = "admin"
        mock_get_user.return_value = mock_user
        yield mock_get_user
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
m = MagicMock()
m.metric_name = name
m.verbose_name = None
m.expression = expression
m.description = None
m.d3format = None
m.warning_text = None
return m
def _make_column(name: str) -> MagicMock:
col = MagicMock()
col.column_name = name
col.verbose_name = None
col.description = None
col.type = "VARCHAR"
col.is_dttm = False
col.groupby = True
col.filterable = True
return col
def _make_dataset(dataset_id: int = 1) -> MagicMock:
    """Build a mock dataset carrying two metrics and two columns.

    Args:
        dataset_id: Used as ``id`` and to derive ``table_name``.

    Returns:
        A ``MagicMock`` with metrics "count" (``COUNT(*)``) and "revenue"
        (``SUM(revenue)``) and columns "region" and "category".
    """
    ds = MagicMock()
    ds.id = dataset_id
    ds.table_name = f"table_{dataset_id}"
    ds.metrics = [_make_metric("count"), _make_metric("revenue", "SUM(revenue)")]
    ds.columns = [_make_column("region"), _make_column("category")]
    return ds
@pytest.mark.asyncio
async def test_list_metrics_builtin_happy_path(mcp_server: FastMCP) -> None:
    """list_metrics returns builtin metrics when only datasets exist.

    ``DatasetDAO`` and ``SemanticViewDAO`` are patched on the tool module: the
    dataset lookup yields the mock dataset and no semantic views are accessible.
    """
    mock_ds = _make_dataset(42)
    with (
        patch(
            "superset.mcp_service.semantic_layer.tool.list_metrics.DatasetDAO"
        ) as mock_dao,
        patch(
            "superset.mcp_service.semantic_layer.tool.list_metrics.SemanticViewDAO"
        ) as mock_view_dao,
    ):
        mock_dao.find_by_id.return_value = mock_ds
        mock_view_dao.find_accessible.return_value = []
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "list_metrics",
                {"request": {"dataset_id": 42, "include_compatible_dimensions": False}},
            )
        # The tool response is JSON text carried in the first content item.
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        assert data["total_count"] == 2
        metrics = data["metrics"]
        assert {m["name"] for m in metrics} == {"count", "revenue"}
        assert all(m["source"] == "builtin" for m in metrics)
        assert all(m["dataset_id"] == 42 for m in metrics)
@pytest.mark.asyncio
async def test_list_metrics_mutual_exclusion_validation(mcp_server: FastMCP) -> None:
    """list_metrics returns a validation error when dataset_id and view_id coexist."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_metrics",
            {"request": {"dataset_id": 1, "view_id": 2}},
        )
    data = json.loads(result.content[0].text)
    assert data["success"] is False
    assert data["error_type"] == "ValidationError"
@pytest.mark.asyncio
async def test_list_metrics_privacy_check(mcp_server: FastMCP) -> None:
    """list_metrics returns an error when the user lacks data-model metadata access.

    Re-patches ``user_can_view_data_model_metadata`` to return ``False`` for
    this test only, overriding the ``True`` set by the autouse fixture. The
    tool is called with empty arguments (no ``request`` key).
    """
    with patch.object(
        list_metrics_module,
        "user_can_view_data_model_metadata",
        return_value=False,
    ):
        async with Client(mcp_server) as client:
            result = await client.call_tool("list_metrics", {})
        data = json.loads(result.content[0].text)
        assert data["success"] is False
        assert data["error_type"] == "DataModelMetadataRestricted"
@pytest.mark.asyncio
async def test_list_metrics_search_filter(mcp_server: FastMCP) -> None:
    """list_metrics filters metrics by search term.

    No ``dataset_id`` is given, so the list-all path is exercised. The tool
    module's ``db`` session query chain and ``DatasetDAO._apply_base_filter``
    are mocked to resolve to a query whose ``all()`` yields the single mock
    dataset.
    """
    mock_ds = _make_dataset(1)
    with (
        patch(
            "superset.mcp_service.semantic_layer.tool.list_metrics.DatasetDAO"
        ) as mock_dao,
        patch(
            "superset.mcp_service.semantic_layer.tool.list_metrics.SemanticViewDAO"
        ) as mock_view_dao,
        patch("superset.mcp_service.semantic_layer.tool.list_metrics.db") as mock_db,
    ):
        mock_view_dao.find_accessible.return_value = []
        # One shared query mock is returned from both the .options() call and
        # the base filter, so either route ends at the same .all() result.
        mock_query = MagicMock()
        mock_db.session.query.return_value.options.return_value = mock_query
        mock_dao._apply_base_filter.return_value = mock_query
        mock_query.all.return_value = [mock_ds]
        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "list_metrics",
                {"request": {"search": "revenue"}},
            )
        data = json.loads(result.content[0].text)
        assert data["success"] is True
        # Only the "revenue" metric should match the search
        metrics = data["metrics"]
        assert len(metrics) == 1
        assert metrics[0]["name"] == "revenue"