mirror of
https://github.com/apache/superset.git
synced 2026-07-02 21:05:36 +00:00
Compare commits
7 Commits
chore/ci-c
...
aminghader
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ec88c32af0 | ||
|
|
290ab5d882 | ||
|
|
1b9de8f4a6 | ||
|
|
d80343c20a | ||
|
|
48a73b5d59 | ||
|
|
89a52f99be | ||
|
|
8f2a01e294 |
@@ -201,6 +201,29 @@ class SemanticViewDAO(BaseDAO[SemanticView], AbstractSemanticViewDAO):
|
||||
for c in candidates
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def find_accessible(cls) -> list[SemanticView]:
|
||||
"""Return all views the current user can access, filtered at SQL level.
|
||||
|
||||
Mirrors the permission filter in ``DatasourceDAO.build_semantic_view_query``
|
||||
to avoid per-row Python-level access checks when listing all views.
|
||||
"""
|
||||
from sqlalchemy import or_
|
||||
|
||||
query = db.session.query(SemanticView)
|
||||
if not security_manager.can_access_all_datasources():
|
||||
perms = security_manager.user_view_menu_names("datasource_access")
|
||||
query = query.outerjoin(
|
||||
SemanticLayer,
|
||||
SemanticLayer.uuid == SemanticView.semantic_layer_uuid,
|
||||
).filter(
|
||||
or_(
|
||||
SemanticView.perm.in_(perms),
|
||||
SemanticLayer.perm.in_(perms),
|
||||
)
|
||||
)
|
||||
return query.all()
|
||||
|
||||
@classmethod
|
||||
def find_by_name(cls, name: str, layer_uuid: str) -> SemanticView | None:
|
||||
"""
|
||||
|
||||
@@ -291,6 +291,31 @@ Use created_by_me for authorship, owned_by_me for edit ownership, or both
|
||||
together for the union. All flags can be combined with 'filters' but not
|
||||
with 'search'.
|
||||
|
||||
To explore metrics across all data sources (built-in datasets + external semantic views):
|
||||
1. list_metrics(request={{"search": "<keyword>"}})
|
||||
-> returns metrics with dataset_id/view_id and compatible_dimensions inline
|
||||
2. get_table(request={{
|
||||
"dataset_id": <id>, # OR "view_id": <id> for external semantic views
|
||||
"metrics": ["revenue"],
|
||||
"dimensions": ["region"],
|
||||
"time_range": "Last 30 days",
|
||||
"row_limit": 500
|
||||
}}) -> returns tabular results
|
||||
- Use "dataset_id" when list_metrics returned source="builtin"
|
||||
- Use "view_id" when list_metrics returned source="external"
|
||||
|
||||
To progressively refine a query (compatible dimensions/metrics):
|
||||
- get_compatible_dimensions(request={{
|
||||
"selected_metrics": ["revenue"],
|
||||
"selected_dimensions": [],
|
||||
"dataset_id": <id> # or "view_id": <id>
|
||||
}}) -> dimensions valid to add to the current selection
|
||||
- get_compatible_metrics(request={{
|
||||
"selected_metrics": [],
|
||||
"selected_dimensions": ["region"],
|
||||
"view_id": <id> # useful for external semantic layers with constraints
|
||||
}}) -> metrics valid to add to the current selection
|
||||
|
||||
To query a dataset's semantic layer (metrics, dimensions):
|
||||
1. list_datasets(request={{}}) -> find a dataset
|
||||
2. get_dataset_info(request={{"identifier": <id>}}) -> examine columns AND metrics
|
||||
@@ -741,6 +766,12 @@ from superset.mcp_service.saved_query.tool import ( # noqa: F401, E402
|
||||
get_saved_query_info,
|
||||
list_saved_queries,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.tool import ( # noqa: F401, E402
|
||||
get_compatible_dimensions,
|
||||
get_compatible_metrics,
|
||||
get_table,
|
||||
list_metrics,
|
||||
)
|
||||
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
|
||||
execute_sql,
|
||||
open_sql_lab_with_context,
|
||||
|
||||
@@ -22,7 +22,6 @@ Query a dataset using its semantic layer (saved metrics, calculated columns,
|
||||
dimensions) without requiring a saved chart.
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
@@ -35,7 +34,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.chart.schemas import PerformanceMetadata
|
||||
from superset.mcp_service.dataset.schemas import (
|
||||
DatasetError,
|
||||
QueryDatasetFilter,
|
||||
@@ -50,6 +49,8 @@ from superset.mcp_service.privacy import (
|
||||
from superset.mcp_service.utils import _is_uuid
|
||||
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
|
||||
from superset.mcp_service.utils.query_utils import validate_names
|
||||
from superset.mcp_service.utils.response_utils import format_data_columns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -80,26 +81,6 @@ def _resolve_dataset(identifier: int | str, eager_options: list[Any]) -> Any | N
|
||||
return None
|
||||
|
||||
|
||||
def _validate_names(
|
||||
requested: list[str],
|
||||
valid: set[str],
|
||||
kind: str,
|
||||
) -> list[str]:
|
||||
"""Return list of error messages for names not found in *valid*.
|
||||
|
||||
Includes close-match suggestions when available.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
for name in requested:
|
||||
if name not in valid:
|
||||
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
|
||||
msg = f"Unknown {kind}: '{name}'"
|
||||
if suggestions:
|
||||
msg += f". Did you mean: {', '.join(suggestions)}?"
|
||||
errors.append(msg)
|
||||
return errors
|
||||
|
||||
|
||||
@requires_data_model_metadata_access
|
||||
@tool(
|
||||
tags=["data"],
|
||||
@@ -210,21 +191,21 @@ async def query_dataset( # noqa: C901
|
||||
|
||||
validation_errors: list[str] = []
|
||||
validation_errors.extend(
|
||||
_validate_names(request.columns, valid_columns, "column")
|
||||
validate_names(request.columns, valid_columns, "column")
|
||||
)
|
||||
validation_errors.extend(
|
||||
_validate_names(request.metrics, valid_metrics, "metric")
|
||||
validate_names(request.metrics, valid_metrics, "metric")
|
||||
)
|
||||
# Validate filter column names against dataset columns
|
||||
filter_cols = [f.col for f in request.filters]
|
||||
validation_errors.extend(
|
||||
_validate_names(filter_cols, valid_columns, "filter column")
|
||||
validate_names(filter_cols, valid_columns, "filter column")
|
||||
)
|
||||
# Validate order_by names against columns + metrics
|
||||
if request.order_by:
|
||||
valid_orderby = valid_columns | valid_metrics
|
||||
validation_errors.extend(
|
||||
_validate_names(request.order_by, valid_orderby, "order_by")
|
||||
validate_names(request.order_by, valid_orderby, "order_by")
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
@@ -379,44 +360,7 @@ async def query_dataset( # noqa: C901
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
# Build column metadata in a single pass per column.
|
||||
# Cap stats computation at STATS_SAMPLE rows to avoid O(rows*cols)
|
||||
# overhead on large result sets (row_limit allows up to 50k).
|
||||
stats_sample_size = 5000
|
||||
stats_rows = data[:stats_sample_size]
|
||||
|
||||
columns_meta: list[DataColumn] = []
|
||||
for col_name in raw_columns:
|
||||
sample_values = [
|
||||
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
|
||||
]
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
|
||||
# Compute null_count and unique non-null values in one pass
|
||||
null_count = 0
|
||||
unique_vals: set[str] = set()
|
||||
for row in stats_rows:
|
||||
val = row.get(col_name)
|
||||
if val is None:
|
||||
null_count += 1
|
||||
else:
|
||||
unique_vals.add(str(val))
|
||||
|
||||
columns_meta.append(
|
||||
DataColumn(
|
||||
name=col_name,
|
||||
display_name=col_name.replace("_", " ").title(),
|
||||
data_type=data_type,
|
||||
sample_values=sample_values[:3],
|
||||
null_count=null_count,
|
||||
unique_count=len(unique_vals),
|
||||
)
|
||||
)
|
||||
columns_meta = format_data_columns(data, raw_columns)
|
||||
|
||||
cache_status = get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
|
||||
16
superset/mcp_service/semantic_layer/__init__.py
Normal file
16
superset/mcp_service/semantic_layer/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
302
superset/mcp_service/semantic_layer/schemas.py
Normal file
302
superset/mcp_service/semantic_layer/schemas.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Pydantic schemas for semantic layer MCP tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from superset.mcp_service.chart.schemas import DataColumn, PerformanceMetadata
|
||||
from superset.mcp_service.common.cache_schemas import CacheStatus
|
||||
from superset.mcp_service.common.error_schemas import MCPBaseError
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared error schema
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SemanticLayerError(MCPBaseError):
|
||||
"""Error response returned by semantic layer tools."""
|
||||
|
||||
success: Literal[False] = False
|
||||
|
||||
@classmethod
|
||||
def create(cls, *, error: str, error_type: str) -> "SemanticLayerError":
|
||||
return cls(error=error, error_type=error_type)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dimension info (returned inside MetricInfo.compatible_dimensions)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DimensionInfo(BaseModel):
|
||||
"""Metadata for a single dimension / column."""
|
||||
|
||||
name: str
|
||||
verbose_name: str | None = None
|
||||
description: str | None = None
|
||||
type: str | None = None
|
||||
is_dttm: bool = False
|
||||
groupby: bool = True
|
||||
filterable: bool = True
|
||||
source: Literal["builtin", "external"] = "builtin"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Metric info
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MetricInfo(BaseModel):
|
||||
"""Metadata for a single metric, including compatible dimensions."""
|
||||
|
||||
name: str
|
||||
verbose_name: str | None = None
|
||||
description: str | None = None
|
||||
expression: str | None = None
|
||||
d3format: str | None = None
|
||||
warning_text: str | None = None
|
||||
source: Literal["builtin", "external"] = "builtin"
|
||||
dataset_id: int | None = None
|
||||
dataset_name: str | None = None
|
||||
view_id: int | None = None
|
||||
view_name: str | None = None
|
||||
compatible_dimensions: list[DimensionInfo] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# list_metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ListMetricsRequest(BaseModel):
|
||||
"""Request schema for list_metrics."""
|
||||
|
||||
search: str | None = Field(
|
||||
default=None,
|
||||
description="Optional search string to filter metrics by name or description.",
|
||||
)
|
||||
dataset_id: int | None = Field(
|
||||
default=None,
|
||||
description="Filter to metrics from a specific built-in dataset.",
|
||||
)
|
||||
view_id: int | None = Field(
|
||||
default=None,
|
||||
description="Filter to metrics from a specific semantic view.",
|
||||
)
|
||||
include_compatible_dimensions: bool = Field(
|
||||
default=True,
|
||||
description=(
|
||||
"When True, each metric includes its list of compatible dimensions. "
|
||||
"Set to False to reduce response size when dimensions aren't needed."
|
||||
),
|
||||
)
|
||||
page: int = Field(default=1, ge=1, description="1-based page number.")
|
||||
page_size: int = Field(
|
||||
default=50, ge=1, le=500, description="Number of metrics per page."
|
||||
)
|
||||
|
||||
|
||||
class MetricList(BaseModel):
|
||||
"""Response schema for list_metrics."""
|
||||
|
||||
metrics: list[MetricInfo]
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
success: Literal[True] = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_table
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GetTableFilter(BaseModel):
|
||||
"""A single filter clause for get_table."""
|
||||
|
||||
col: str = Field(..., description="Column or dimension name to filter on.")
|
||||
op: str = Field(
|
||||
default="==",
|
||||
description=(
|
||||
"Filter operator. Common values: '==', '!=', '>', '<', '>=', '<=', "
|
||||
"'IN', 'NOT IN', 'LIKE', 'ILIKE', 'TEMPORAL_RANGE'."
|
||||
),
|
||||
)
|
||||
val: Any = Field(
|
||||
default=None,
|
||||
description="Filter value. Use a list for 'IN'/'NOT IN' operators.",
|
||||
)
|
||||
|
||||
|
||||
class GetTableRequest(BaseModel):
|
||||
"""Request schema for get_table."""
|
||||
|
||||
dataset_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Built-in dataset ID to query. Obtained from list_metrics response "
|
||||
"when source='builtin'. Provide either this or view_id."
|
||||
),
|
||||
)
|
||||
view_id: int | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"External semantic view ID to query. Obtained from list_metrics "
|
||||
"response when source='external'. Provide either this or dataset_id."
|
||||
),
|
||||
)
|
||||
metrics: list[str] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Metric names to compute. All metrics must come from the same "
|
||||
"data source (dataset or semantic view)."
|
||||
),
|
||||
)
|
||||
dimensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Dimension or column names to group by.",
|
||||
)
|
||||
filters: list[GetTableFilter] = Field(
|
||||
default_factory=list,
|
||||
description="Optional filters to apply.",
|
||||
)
|
||||
time_range: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional time range string, e.g. 'Last 7 days', 'Last 30 days', "
|
||||
"'2024-01-01 : 2024-12-31'. Requires a datetime dimension."
|
||||
),
|
||||
)
|
||||
time_column: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Name of the datetime column/dimension to apply time_range to. "
|
||||
"Inferred from the dataset's main_dttm_col when omitted."
|
||||
),
|
||||
)
|
||||
row_limit: int = Field(
|
||||
default=1000,
|
||||
ge=1,
|
||||
le=50000,
|
||||
description="Maximum number of rows to return.",
|
||||
)
|
||||
order_by: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Column/metric names to sort by.",
|
||||
)
|
||||
order_desc: bool = Field(
|
||||
default=True,
|
||||
description="Sort descending when True (default).",
|
||||
)
|
||||
use_cache: bool = Field(default=True, description="Use query cache when available.")
|
||||
force_refresh: bool = Field(
|
||||
default=False,
|
||||
description="Force a cache refresh even when cached results exist.",
|
||||
)
|
||||
|
||||
|
||||
class GetTableResponse(BaseModel):
|
||||
"""Response schema for get_table."""
|
||||
|
||||
columns: list[DataColumn]
|
||||
data: list[dict[str, Any]]
|
||||
row_count: int
|
||||
total_rows: int | None = None
|
||||
summary: str
|
||||
source: Literal["builtin", "external"]
|
||||
dataset_id: int | None = None
|
||||
dataset_name: str | None = None
|
||||
view_id: int | None = None
|
||||
view_name: str | None = None
|
||||
performance: PerformanceMetadata | None = None
|
||||
cache_status: CacheStatus | None = None
|
||||
warnings: list[str] = Field(default_factory=list)
|
||||
success: Literal[True] = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_compatible_dimensions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GetCompatibleDimensionsRequest(BaseModel):
|
||||
"""Request schema for get_compatible_dimensions."""
|
||||
|
||||
selected_metrics: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Metric names already selected.",
|
||||
)
|
||||
selected_dimensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Dimension names already selected.",
|
||||
)
|
||||
dataset_id: int | None = Field(
|
||||
default=None,
|
||||
description="Built-in dataset ID to query. Provide either this or view_id.",
|
||||
)
|
||||
view_id: int | None = Field(
|
||||
default=None,
|
||||
description="Semantic view ID to query. Provide either this or dataset_id.",
|
||||
)
|
||||
|
||||
|
||||
class CompatibleDimensionsResponse(BaseModel):
|
||||
"""Response schema for get_compatible_dimensions."""
|
||||
|
||||
compatible_dimensions: list[DimensionInfo]
|
||||
source: Literal["builtin", "external"]
|
||||
success: Literal[True] = True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_compatible_metrics
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class GetCompatibleMetricsRequest(BaseModel):
|
||||
"""Request schema for get_compatible_metrics."""
|
||||
|
||||
selected_metrics: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Metric names already selected.",
|
||||
)
|
||||
selected_dimensions: list[str] = Field(
|
||||
default_factory=list,
|
||||
description="Dimension names already selected.",
|
||||
)
|
||||
dataset_id: int | None = Field(
|
||||
default=None,
|
||||
description="Built-in dataset ID to query. Provide either this or view_id.",
|
||||
)
|
||||
view_id: int | None = Field(
|
||||
default=None,
|
||||
description="Semantic view ID to query. Provide either this or dataset_id.",
|
||||
)
|
||||
|
||||
|
||||
class CompatibleMetricsResponse(BaseModel):
|
||||
"""Response schema for get_compatible_metrics."""
|
||||
|
||||
compatible_metrics: list[MetricInfo]
|
||||
source: Literal["builtin", "external"]
|
||||
success: Literal[True] = True
|
||||
27
superset/mcp_service/semantic_layer/tool/__init__.py
Normal file
27
superset/mcp_service/semantic_layer/tool/__init__.py
Normal file
@@ -0,0 +1,27 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from superset.mcp_service.semantic_layer.tool.get_compatible_dimensions import ( # noqa: F401
|
||||
get_compatible_dimensions,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.tool.get_compatible_metrics import ( # noqa: F401
|
||||
get_compatible_metrics,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.tool.get_table import get_table # noqa: F401
|
||||
from superset.mcp_service.semantic_layer.tool.list_metrics import ( # noqa: F401
|
||||
list_metrics,
|
||||
)
|
||||
@@ -0,0 +1,227 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""MCP tool: get_compatible_dimensions
|
||||
|
||||
Returns dimensions compatible with the current metric/dimension selection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.schemas import (
|
||||
CompatibleDimensionsResponse,
|
||||
DimensionInfo,
|
||||
GetCompatibleDimensionsRequest,
|
||||
SemanticLayerError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["data", "semantic"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get compatible dimensions",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
@requires_data_model_metadata_access
|
||||
async def get_compatible_dimensions(
|
||||
request: GetCompatibleDimensionsRequest,
|
||||
ctx: Context,
|
||||
) -> CompatibleDimensionsResponse | SemanticLayerError:
|
||||
"""Return dimensions compatible with the current metric/dimension selection.
|
||||
|
||||
Used to drive progressive disclosure in query builders: after the user
|
||||
selects one or more metrics (and optionally some dimensions), this tool
|
||||
returns the dimensions that can validly be added without breaking the
|
||||
underlying query.
|
||||
|
||||
Provide exactly one of ``dataset_id`` (built-in) or ``view_id`` (external).
|
||||
|
||||
For built-in datasets, returns all groupby-enabled columns from the dataset.
|
||||
SQL datasets have no semantic compatibility constraints between metrics and
|
||||
dimensions, so all groupby columns are always returned regardless of the
|
||||
selected metrics.
|
||||
|
||||
For external semantic views, delegates to the view's
|
||||
``get_compatible_dimensions`` implementation.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"selected_metrics": ["revenue"],
|
||||
"selected_dimensions": [],
|
||||
"view_id": 5
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Getting compatible dimensions: dataset_id=%s, view_id=%s, "
|
||||
"metrics=%s, dims=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.view_id,
|
||||
request.selected_metrics,
|
||||
request.selected_dimensions,
|
||||
)
|
||||
)
|
||||
|
||||
if not user_can_view_data_model_metadata():
|
||||
return SemanticLayerError.create(
|
||||
error="You don't have permission to access dataset details for your role.",
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
if request.dataset_id is None and request.view_id is None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide either dataset_id (built-in) or view_id (external).",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
if request.dataset_id is not None and request.view_id is not None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide only one of dataset_id or view_id, not both.",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
try:
|
||||
# ------------------------------------------------------------------
|
||||
# Built-in dataset path
|
||||
# ------------------------------------------------------------------
|
||||
if request.dataset_id is not None:
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
dataset_id: int = request.dataset_id
|
||||
with event_logger.log_context(
|
||||
action="mcp.get_compatible_dimensions.builtin"
|
||||
):
|
||||
dataset = DatasetDAO.find_by_id(
|
||||
dataset_id,
|
||||
query_options=[
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
],
|
||||
)
|
||||
|
||||
if dataset is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No dataset found with id: {request.dataset_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
# For built-in datasets all groupby columns are always compatible;
|
||||
# there's no per-metric compatibility constraint at the SQL level.
|
||||
dims = [
|
||||
DimensionInfo(
|
||||
name=col.column_name,
|
||||
verbose_name=col.verbose_name or None,
|
||||
description=col.description or None,
|
||||
type=col.type or None,
|
||||
is_dttm=bool(col.is_dttm),
|
||||
groupby=bool(col.groupby),
|
||||
filterable=bool(col.filterable),
|
||||
source="builtin",
|
||||
)
|
||||
for col in dataset.columns
|
||||
if col.groupby
|
||||
]
|
||||
|
||||
await ctx.info("Compatible dimensions (builtin): count=%d" % len(dims))
|
||||
return CompatibleDimensionsResponse(
|
||||
compatible_dimensions=dims,
|
||||
source="builtin",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External semantic view path
|
||||
# ------------------------------------------------------------------
|
||||
from superset.daos.semantic_layer import SemanticViewDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
view_id: int = request.view_id # type: ignore[assignment]
|
||||
with event_logger.log_context(action="mcp.get_compatible_dimensions.external"):
|
||||
view = SemanticViewDAO.find_by_id(view_id)
|
||||
|
||||
if view is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No semantic view found with id: {view_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
try:
|
||||
view.raise_for_access()
|
||||
except SupersetSecurityException as ex:
|
||||
return SemanticLayerError.create(
|
||||
error=str(ex.error.message),
|
||||
error_type="AccessDenied",
|
||||
)
|
||||
|
||||
compatible_names = view.get_compatible_dimensions(
|
||||
request.selected_metrics,
|
||||
request.selected_dimensions,
|
||||
)
|
||||
|
||||
# Enrich with full column metadata
|
||||
all_cols = {col.column_name: col for col in view.columns}
|
||||
dims = [
|
||||
DimensionInfo(
|
||||
name=name,
|
||||
verbose_name=all_cols[name].verbose_name if name in all_cols else None,
|
||||
description=all_cols[name].description if name in all_cols else None,
|
||||
type=all_cols[name].type if name in all_cols else None,
|
||||
is_dttm=all_cols[name].is_dttm if name in all_cols else False,
|
||||
groupby=all_cols[name].groupby if name in all_cols else True,
|
||||
filterable=all_cols[name].filterable if name in all_cols else True,
|
||||
source="external",
|
||||
)
|
||||
for name in compatible_names
|
||||
]
|
||||
|
||||
await ctx.info(
|
||||
"Compatible dimensions (external view id=%d): count=%d"
|
||||
% (view.id, len(dims))
|
||||
)
|
||||
return CompatibleDimensionsResponse(
|
||||
compatible_dimensions=dims,
|
||||
source="external",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error in get_compatible_dimensions: %s: %s",
|
||||
type(exc).__name__,
|
||||
str(exc),
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Internal error in get_compatible_dimensions: {exc}",
|
||||
error_type="InternalError",
|
||||
)
|
||||
@@ -0,0 +1,226 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""MCP tool: get_compatible_metrics
|
||||
|
||||
Returns metrics compatible with the current dimension/metric selection.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.schemas import (
|
||||
CompatibleMetricsResponse,
|
||||
GetCompatibleMetricsRequest,
|
||||
MetricInfo,
|
||||
SemanticLayerError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["data", "semantic"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get compatible metrics",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
@requires_data_model_metadata_access
|
||||
async def get_compatible_metrics(
|
||||
request: GetCompatibleMetricsRequest,
|
||||
ctx: Context,
|
||||
) -> CompatibleMetricsResponse | SemanticLayerError:
|
||||
"""Return metrics compatible with the current dimension/metric selection.
|
||||
|
||||
Used to progressively refine a query: given a set of already-selected
|
||||
metrics and dimensions, returns the additional metrics that can be
|
||||
combined without breaking the underlying semantic constraints.
|
||||
|
||||
Provide exactly one of ``dataset_id`` (built-in) or ``view_id`` (external).
|
||||
|
||||
For built-in datasets, all metrics from the dataset are considered
|
||||
compatible (SQL GROUP BY imposes no metric-level constraints).
|
||||
|
||||
For external semantic views, delegates to the view's
|
||||
``get_compatible_metrics`` implementation.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{
|
||||
"selected_metrics": [],
|
||||
"selected_dimensions": ["region"],
|
||||
"view_id": 5
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Getting compatible metrics: dataset_id=%s, view_id=%s, "
|
||||
"metrics=%s, dims=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.view_id,
|
||||
request.selected_metrics,
|
||||
request.selected_dimensions,
|
||||
)
|
||||
)
|
||||
|
||||
if not user_can_view_data_model_metadata():
|
||||
return SemanticLayerError.create(
|
||||
error="You don't have permission to access dataset details for your role.",
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
if request.dataset_id is None and request.view_id is None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide either dataset_id (built-in) or view_id (external).",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
if request.dataset_id is not None and request.view_id is not None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide only one of dataset_id or view_id, not both.",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
try:
|
||||
# ------------------------------------------------------------------
|
||||
# Built-in dataset path
|
||||
# ------------------------------------------------------------------
|
||||
if request.dataset_id is not None:
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_compatible_metrics.builtin"):
|
||||
dataset = DatasetDAO.find_by_id(
|
||||
request.dataset_id,
|
||||
query_options=[
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
],
|
||||
)
|
||||
|
||||
if dataset is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No dataset found with id: {request.dataset_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
# All metrics on a SQL dataset are always mutually compatible.
|
||||
compatible = [
|
||||
MetricInfo(
|
||||
name=m.metric_name,
|
||||
verbose_name=m.verbose_name or None,
|
||||
description=m.description or None,
|
||||
expression=m.expression or None,
|
||||
d3format=m.d3format or None,
|
||||
warning_text=m.warning_text or None,
|
||||
source="builtin",
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.table_name,
|
||||
)
|
||||
for m in dataset.metrics
|
||||
]
|
||||
|
||||
await ctx.info("Compatible metrics (builtin): count=%d" % len(compatible))
|
||||
return CompatibleMetricsResponse(
|
||||
compatible_metrics=compatible,
|
||||
source="builtin",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# External semantic view path
|
||||
# ------------------------------------------------------------------
|
||||
from superset.daos.semantic_layer import SemanticViewDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
view_id: int = request.view_id # type: ignore[assignment]
|
||||
with event_logger.log_context(action="mcp.get_compatible_metrics.external"):
|
||||
view = SemanticViewDAO.find_by_id(view_id)
|
||||
|
||||
if view is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No semantic view found with id: {view_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
try:
|
||||
view.raise_for_access()
|
||||
except SupersetSecurityException as ex:
|
||||
return SemanticLayerError.create(
|
||||
error=str(ex.error.message),
|
||||
error_type="AccessDenied",
|
||||
)
|
||||
|
||||
compatible_names = view.get_compatible_metrics(
|
||||
request.selected_metrics,
|
||||
request.selected_dimensions,
|
||||
)
|
||||
|
||||
# Enrich with full metric metadata
|
||||
all_metrics_map = {m.metric_name: m for m in view.metrics}
|
||||
compatible = [
|
||||
MetricInfo(
|
||||
name=name,
|
||||
description=(
|
||||
all_metrics_map[name].description
|
||||
if name in all_metrics_map
|
||||
else None
|
||||
),
|
||||
expression=(
|
||||
all_metrics_map[name].expression
|
||||
if name in all_metrics_map
|
||||
else None
|
||||
),
|
||||
source="external",
|
||||
view_id=view.id,
|
||||
view_name=view.name,
|
||||
)
|
||||
for name in compatible_names
|
||||
]
|
||||
|
||||
await ctx.info(
|
||||
"Compatible metrics (external view id=%d): count=%d"
|
||||
% (view.id, len(compatible))
|
||||
)
|
||||
return CompatibleMetricsResponse(
|
||||
compatible_metrics=compatible,
|
||||
source="external",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error in get_compatible_metrics: %s: %s",
|
||||
type(exc).__name__,
|
||||
str(exc),
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Internal error in get_compatible_metrics: {exc}",
|
||||
error_type="InternalError",
|
||||
)
|
||||
437
superset/mcp_service/semantic_layer/tool/get_table.py
Normal file
437
superset/mcp_service/semantic_layer/tool/get_table.py
Normal file
@@ -0,0 +1,437 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""MCP tool: get_table
|
||||
|
||||
Query a data source (built-in dataset or external semantic view) using
|
||||
metric and dimension names, returning tabular results.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.commands.exceptions import CommandException
|
||||
from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.chart.schemas import PerformanceMetadata
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.schemas import (
|
||||
GetTableRequest,
|
||||
GetTableResponse,
|
||||
SemanticLayerError,
|
||||
)
|
||||
from superset.mcp_service.utils.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.utils.oauth2_utils import build_oauth2_redirect_message
|
||||
from superset.mcp_service.utils.query_utils import validate_names
|
||||
from superset.mcp_service.utils.response_utils import format_data_columns
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_query_dict(
|
||||
request: GetTableRequest,
|
||||
time_col: str | None,
|
||||
) -> dict[str, Any]:
|
||||
"""Assemble the query dict for QueryContextFactory."""
|
||||
filters: list[dict[str, Any]] = [
|
||||
{"col": f.col, "op": f.op, "val": f.val} for f in request.filters
|
||||
]
|
||||
if request.time_range and time_col:
|
||||
filters.append(
|
||||
{"col": time_col, "op": "TEMPORAL_RANGE", "val": request.time_range}
|
||||
)
|
||||
|
||||
query_dict: dict[str, Any] = {
|
||||
"filters": filters,
|
||||
"columns": request.dimensions,
|
||||
"metrics": request.metrics,
|
||||
"row_limit": request.row_limit,
|
||||
"order_desc": request.order_desc,
|
||||
}
|
||||
if time_col:
|
||||
query_dict["granularity"] = time_col
|
||||
if request.order_by:
|
||||
query_dict["orderby"] = [
|
||||
(col, not request.order_desc) for col in request.order_by
|
||||
]
|
||||
return query_dict
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["data", "semantic"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get table",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
@requires_data_model_metadata_access
|
||||
async def get_table( # noqa: C901
|
||||
request: GetTableRequest,
|
||||
ctx: Context,
|
||||
) -> GetTableResponse | SemanticLayerError:
|
||||
"""Query a data source using metrics and dimensions, returning tabular results.
|
||||
|
||||
Works with both built-in datasets and external semantic views. The
|
||||
``dataset_id`` or ``view_id`` comes from the ``list_metrics`` response.
|
||||
|
||||
Workflow:
|
||||
1. list_metrics -> discover metrics and their compatible_dimensions
|
||||
2. get_table -> query with chosen metrics and dimensions
|
||||
|
||||
Example (built-in):
|
||||
```json
|
||||
{
|
||||
"dataset_id": 42,
|
||||
"metrics": ["revenue"],
|
||||
"dimensions": ["region", "product_category"],
|
||||
"time_range": "Last 30 days",
|
||||
"row_limit": 500
|
||||
}
|
||||
```
|
||||
|
||||
Example (external):
|
||||
```json
|
||||
{
|
||||
"view_id": 5,
|
||||
"metrics": ["bookings"],
|
||||
"dimensions": ["listing__country_name"],
|
||||
"row_limit": 100
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Starting get_table: dataset_id=%s, view_id=%s, metrics=%s, "
|
||||
"dimensions=%s, row_limit=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.view_id,
|
||||
request.metrics,
|
||||
request.dimensions,
|
||||
request.row_limit,
|
||||
)
|
||||
)
|
||||
|
||||
if not user_can_view_data_model_metadata():
|
||||
return SemanticLayerError.create(
|
||||
error="You don't have permission to access dataset details for your role.",
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
if request.dataset_id is None and request.view_id is None:
|
||||
return SemanticLayerError.create(
|
||||
error=(
|
||||
"Provide either dataset_id (built-in dataset) or view_id "
|
||||
"(external semantic view). Both are in the list_metrics response."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
if request.dataset_id is not None and request.view_id is not None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide only one of dataset_id or view_id, not both.",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
|
||||
is_builtin = request.dataset_id is not None
|
||||
datasource_type = "table" if is_builtin else "semantic_view"
|
||||
if is_builtin:
|
||||
assert request.dataset_id is not None
|
||||
datasource_id = request.dataset_id
|
||||
else:
|
||||
assert request.view_id is not None
|
||||
datasource_id = request.view_id
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Resolve datasource for metadata (time column, display name)
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(1, 5, "Resolving data source")
|
||||
display_name: str = f"{'Dataset' if is_builtin else 'View'} {datasource_id}"
|
||||
time_col: str | None = request.time_column
|
||||
warnings: list[str] = []
|
||||
valid_columns: set[str] = set()
|
||||
valid_metrics: set[str] = set()
|
||||
|
||||
if is_builtin:
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_table.resolve_dataset"):
|
||||
dataset = DatasetDAO.find_by_id(
|
||||
datasource_id,
|
||||
query_options=[
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
],
|
||||
)
|
||||
if dataset is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No dataset found with id: {request.dataset_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
display_name = dataset.table_name
|
||||
valid_columns = {c.column_name for c in dataset.columns}
|
||||
valid_dttm_columns = {c.column_name for c in dataset.columns if c.is_dttm}
|
||||
valid_metrics = {m.metric_name for m in dataset.metrics}
|
||||
if time_col is None and request.time_range:
|
||||
time_col = getattr(dataset, "main_dttm_col", None)
|
||||
if not time_col:
|
||||
return SemanticLayerError.create(
|
||||
error=(
|
||||
"time_range was provided but no temporal column is "
|
||||
"configured. Set time_column explicitly."
|
||||
),
|
||||
error_type="ValidationError",
|
||||
)
|
||||
if time_col is not None and time_col not in valid_dttm_columns:
|
||||
if time_col in valid_columns:
|
||||
error_msg = (
|
||||
f"time_column '{time_col}' on dataset '{display_name}' is "
|
||||
"not marked as a datetime column."
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
f"Unknown time_column: '{time_col}' on dataset "
|
||||
f"'{display_name}'."
|
||||
)
|
||||
return SemanticLayerError.create(
|
||||
error=error_msg,
|
||||
error_type="ValidationError",
|
||||
)
|
||||
else:
|
||||
from superset.daos.semantic_layer import SemanticViewDAO
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
with event_logger.log_context(action="mcp.get_table.resolve_view"):
|
||||
view = SemanticViewDAO.find_by_id(datasource_id)
|
||||
if view is None:
|
||||
return SemanticLayerError.create(
|
||||
error=f"No semantic view found with id: {request.view_id}.",
|
||||
error_type="NotFound",
|
||||
)
|
||||
try:
|
||||
view.raise_for_access()
|
||||
except SupersetSecurityException as ex:
|
||||
return SemanticLayerError.create(
|
||||
error=str(ex.error.message),
|
||||
error_type="AccessDenied",
|
||||
)
|
||||
display_name = view.name
|
||||
valid_columns = {c.column_name for c in view.columns}
|
||||
valid_dttm_columns = {c.column_name for c in view.columns if c.is_dttm}
|
||||
valid_metrics = {m.metric_name for m in view.metrics}
|
||||
if time_col is None and request.time_range:
|
||||
# Use first datetime dimension as the time column
|
||||
dttm_cols = [c for c in view.columns if c.is_dttm]
|
||||
if dttm_cols:
|
||||
time_col = dttm_cols[0].column_name
|
||||
else:
|
||||
warnings.append(
|
||||
"time_range provided but no datetime dimension found; "
|
||||
"time filter will not be applied."
|
||||
)
|
||||
time_col = None
|
||||
if time_col is not None and time_col not in valid_dttm_columns:
|
||||
if time_col in valid_columns:
|
||||
error_msg = (
|
||||
f"time_column '{time_col}' on view '{display_name}' is "
|
||||
"not marked as a datetime column."
|
||||
)
|
||||
else:
|
||||
error_msg = (
|
||||
f"Unknown time_column: '{time_col}' on view '{display_name}'."
|
||||
)
|
||||
return SemanticLayerError.create(
|
||||
error=error_msg,
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Validate requested metrics and dimensions against the datasource
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(2, 5, "Validating metrics and dimensions")
|
||||
validation_errors: list[str] = []
|
||||
validation_errors.extend(
|
||||
validate_names(request.dimensions, valid_columns, "dimension")
|
||||
)
|
||||
validation_errors.extend(
|
||||
validate_names(request.metrics, valid_metrics, "metric")
|
||||
)
|
||||
filter_cols = [f.col for f in request.filters]
|
||||
validation_errors.extend(
|
||||
validate_names(filter_cols, valid_columns, "filter column")
|
||||
)
|
||||
if request.order_by:
|
||||
valid_orderby = valid_columns | valid_metrics
|
||||
validation_errors.extend(
|
||||
validate_names(request.order_by, valid_orderby, "order_by")
|
||||
)
|
||||
if validation_errors:
|
||||
error_msg = "; ".join(validation_errors)
|
||||
await ctx.error("Validation failed: %s" % (error_msg,))
|
||||
return SemanticLayerError.create(
|
||||
error=error_msg,
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Build and execute query
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(3, 5, "Building query")
|
||||
query_dict = _build_query_dict(request, time_col)
|
||||
|
||||
await ctx.debug("Query dict: %s" % (sorted(query_dict.keys()),))
|
||||
await ctx.report_progress(4, 5, "Executing query")
|
||||
start_time = time.time()
|
||||
|
||||
with event_logger.log_context(action="mcp.get_table.execute"):
|
||||
factory = QueryContextFactory()
|
||||
query_context = factory.create(
|
||||
datasource={"id": datasource_id, "type": datasource_type},
|
||||
queries=[query_dict],
|
||||
form_data={},
|
||||
force=not request.use_cache or request.force_refresh,
|
||||
)
|
||||
command = ChartDataCommand(query_context)
|
||||
command.validate()
|
||||
result = command.run()
|
||||
|
||||
query_duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
if not result or "queries" not in result or not result["queries"]:
|
||||
return SemanticLayerError.create(
|
||||
error="Query returned no results.",
|
||||
error_type="EmptyQuery",
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Format response
|
||||
# ------------------------------------------------------------------
|
||||
await ctx.report_progress(5, 5, "Formatting results")
|
||||
query_result = result["queries"][0]
|
||||
data = query_result.get("data", [])
|
||||
raw_columns = query_result.get("colnames", [])
|
||||
|
||||
if not data:
|
||||
return GetTableResponse(
|
||||
columns=[],
|
||||
data=[],
|
||||
row_count=0,
|
||||
total_rows=0,
|
||||
summary=f"'{display_name}': query returned no data.",
|
||||
source="builtin" if is_builtin else "external",
|
||||
dataset_id=request.dataset_id,
|
||||
dataset_name=display_name if is_builtin else None,
|
||||
view_id=request.view_id,
|
||||
view_name=display_name if not is_builtin else None,
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status="no_data",
|
||||
),
|
||||
cache_status=get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
),
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
columns_meta = format_data_columns(data, raw_columns)
|
||||
cache_status = get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
)
|
||||
cache_label = "cached" if cache_status and cache_status.cache_hit else "fresh"
|
||||
summary = (
|
||||
f"'{display_name}': {len(data)} rows, "
|
||||
f"{len(raw_columns)} columns ({cache_label})."
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"get_table complete: rows=%d, columns=%d, duration=%dms"
|
||||
% (len(data), len(raw_columns), query_duration_ms)
|
||||
)
|
||||
|
||||
return GetTableResponse(
|
||||
columns=columns_meta,
|
||||
data=data,
|
||||
row_count=len(data),
|
||||
total_rows=query_result.get("rowcount"),
|
||||
summary=summary,
|
||||
source="builtin" if is_builtin else "external",
|
||||
dataset_id=request.dataset_id,
|
||||
dataset_name=display_name if is_builtin else None,
|
||||
view_id=request.view_id,
|
||||
view_name=display_name if not is_builtin else None,
|
||||
performance=PerformanceMetadata(
|
||||
query_duration_ms=query_duration_ms,
|
||||
cache_status=cache_label,
|
||||
),
|
||||
cache_status=cache_status,
|
||||
warnings=warnings,
|
||||
)
|
||||
|
||||
except OAuth2RedirectError as exc:
|
||||
redirect_msg = build_oauth2_redirect_message(exc)
|
||||
await ctx.error("OAuth2 redirect required: %s" % redirect_msg)
|
||||
return SemanticLayerError.create(
|
||||
error=redirect_msg,
|
||||
error_type="OAuth2Redirect",
|
||||
)
|
||||
|
||||
except OAuth2Error as exc:
|
||||
await ctx.error("OAuth2 error: %s" % str(exc))
|
||||
return SemanticLayerError.create(
|
||||
error=f"OAuth2 authentication error: {exc}",
|
||||
error_type="OAuth2Error",
|
||||
)
|
||||
|
||||
except (CommandException, SupersetException) as exc:
|
||||
await ctx.error("Query failed: %s" % str(exc))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Query execution failed: {exc}",
|
||||
error_type="QueryError",
|
||||
)
|
||||
|
||||
except SQLAlchemyError as exc:
|
||||
await ctx.error("Database error: %s" % str(exc))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Database error: {exc}",
|
||||
error_type="DatabaseError",
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error in get_table: %s: %s", type(exc).__name__, str(exc)
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Internal error executing get_table: {exc}",
|
||||
error_type="InternalError",
|
||||
)
|
||||
294
superset/mcp_service/semantic_layer/tool/list_metrics.py
Normal file
294
superset/mcp_service/semantic_layer/tool/list_metrics.py
Normal file
@@ -0,0 +1,294 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""MCP tool: list_metrics
|
||||
|
||||
Unified metric discovery across built-in datasets and external semantic views.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
from superset.daos.semantic_layer import SemanticViewDAO
|
||||
from superset.extensions import db, event_logger
|
||||
from superset.mcp_service.privacy import (
|
||||
DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
requires_data_model_metadata_access,
|
||||
user_can_view_data_model_metadata,
|
||||
)
|
||||
from superset.mcp_service.semantic_layer.schemas import (
|
||||
DimensionInfo,
|
||||
ListMetricsRequest,
|
||||
MetricInfo,
|
||||
MetricList,
|
||||
SemanticLayerError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _matches_search(text: str | None, search: str) -> bool:
|
||||
if not text:
|
||||
return False
|
||||
return search.lower() in text.lower()
|
||||
|
||||
|
||||
def _builtin_compatible_dims(dataset: Any) -> list[DimensionInfo]:
|
||||
"""Return groupby-enabled columns as compatible dimensions for a built-in metric."""
|
||||
return [
|
||||
DimensionInfo(
|
||||
name=col.column_name,
|
||||
verbose_name=col.verbose_name or None,
|
||||
description=col.description or None,
|
||||
type=col.type or None,
|
||||
is_dttm=bool(col.is_dttm),
|
||||
groupby=bool(col.groupby),
|
||||
filterable=bool(col.filterable),
|
||||
source="builtin",
|
||||
)
|
||||
for col in dataset.columns
|
||||
if col.groupby
|
||||
]
|
||||
|
||||
|
||||
def _collect_builtin_metrics(request: ListMetricsRequest) -> list[MetricInfo]:
|
||||
"""Collect metrics from built-in SqlaTable datasets."""
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
with event_logger.log_context(action="mcp.list_metrics.builtin_query"):
|
||||
if request.dataset_id is not None:
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
dataset = DatasetDAO.find_by_id(
|
||||
request.dataset_id,
|
||||
query_options=[
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
],
|
||||
)
|
||||
datasets = [dataset] if dataset else []
|
||||
else:
|
||||
from sqlalchemy.orm import subqueryload
|
||||
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
|
||||
# Use _apply_base_filter with explicit eager loading to avoid
|
||||
# N+1 queries when iterating dataset.metrics / dataset.columns.
|
||||
query = db.session.query(SqlaTable).options(
|
||||
subqueryload(SqlaTable.columns),
|
||||
subqueryload(SqlaTable.metrics),
|
||||
)
|
||||
datasets = DatasetDAO._apply_base_filter(query).all()
|
||||
|
||||
results: list[MetricInfo] = []
|
||||
for dataset in datasets:
|
||||
compat_dims = (
|
||||
_builtin_compatible_dims(dataset)
|
||||
if request.include_compatible_dimensions
|
||||
else []
|
||||
)
|
||||
for metric in dataset.metrics:
|
||||
name = metric.metric_name or ""
|
||||
desc = metric.description or ""
|
||||
if request.search and not (
|
||||
_matches_search(name, request.search)
|
||||
or _matches_search(desc, request.search)
|
||||
):
|
||||
continue
|
||||
results.append(
|
||||
MetricInfo(
|
||||
name=name,
|
||||
verbose_name=metric.verbose_name or None,
|
||||
description=desc or None,
|
||||
expression=metric.expression or None,
|
||||
d3format=metric.d3format or None,
|
||||
warning_text=metric.warning_text or None,
|
||||
source="builtin",
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.table_name,
|
||||
compatible_dimensions=compat_dims,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
async def _collect_external_metrics(
|
||||
request: ListMetricsRequest,
|
||||
ctx: Context,
|
||||
) -> list[MetricInfo]:
|
||||
"""Collect metrics from external SemanticView models."""
|
||||
with event_logger.log_context(action="mcp.list_metrics.external_query"):
|
||||
if request.view_id is not None:
|
||||
view = SemanticViewDAO.find_by_id(request.view_id)
|
||||
views = [view] if view else []
|
||||
else:
|
||||
# find_accessible filters at SQL level, avoiding a per-row
|
||||
# Python permission check and the audit noise of raise_for_access.
|
||||
views = SemanticViewDAO.find_accessible()
|
||||
|
||||
await ctx.debug("Found %d semantic views to scan for metrics" % len(views))
|
||||
|
||||
results: list[MetricInfo] = []
|
||||
for view in views:
|
||||
# raise_for_access must be called outside the broad except block below
|
||||
# so that auth errors are never silently swallowed.
|
||||
view.raise_for_access()
|
||||
try:
|
||||
raw_metrics = view.metrics
|
||||
raw_cols = view.columns if request.include_compatible_dimensions else []
|
||||
compat_dims = [
|
||||
DimensionInfo(
|
||||
name=col.column_name,
|
||||
verbose_name=col.verbose_name,
|
||||
description=col.description,
|
||||
type=col.type,
|
||||
is_dttm=col.is_dttm,
|
||||
groupby=col.groupby,
|
||||
filterable=col.filterable,
|
||||
source="external",
|
||||
)
|
||||
for col in raw_cols
|
||||
]
|
||||
for metric in raw_metrics:
|
||||
name = metric.metric_name or ""
|
||||
desc = metric.description or ""
|
||||
if request.search and not (
|
||||
_matches_search(name, request.search)
|
||||
or _matches_search(desc, request.search)
|
||||
):
|
||||
continue
|
||||
results.append(
|
||||
MetricInfo(
|
||||
name=name,
|
||||
description=desc or None,
|
||||
expression=metric.expression or None,
|
||||
source="external",
|
||||
view_id=view.id,
|
||||
view_name=view.name,
|
||||
compatible_dimensions=compat_dims,
|
||||
)
|
||||
)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
# External registry may be empty in OSS — degrade gracefully
|
||||
await ctx.warning(
|
||||
"Could not load metrics for view id=%s: %s" % (view.id, str(exc))
|
||||
)
|
||||
return results
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["data", "semantic"],
|
||||
class_permission_name="Dataset",
|
||||
annotations=ToolAnnotations(
|
||||
title="List metrics",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
@requires_data_model_metadata_access
|
||||
async def list_metrics(
|
||||
request: ListMetricsRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> MetricList | SemanticLayerError:
|
||||
"""List available metrics across built-in datasets and external semantic views.
|
||||
|
||||
This is the primary entry point for semantic layer exploration. Returns a
|
||||
unified list of metrics from all data sources the current user can access,
|
||||
with compatible dimensions included inline.
|
||||
|
||||
Workflow:
|
||||
1. list_metrics -> discover available metrics and their compatible dimensions
|
||||
2. get_table -> query data using chosen metrics and dimensions
|
||||
|
||||
Use ``search`` to filter by metric name or description. Use ``dataset_id``
|
||||
or ``view_id`` to scope to a specific data source.
|
||||
|
||||
Example:
|
||||
```json
|
||||
{"search": "revenue", "include_compatible_dimensions": true, "page": 1}
|
||||
```
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_metrics")
|
||||
|
||||
request = request or ListMetricsRequest()
|
||||
|
||||
await ctx.info(
|
||||
"Listing metrics: search=%s, dataset_id=%s, view_id=%s, page=%s"
|
||||
% (request.search, request.dataset_id, request.view_id, request.page)
|
||||
)
|
||||
|
||||
if not user_can_view_data_model_metadata():
|
||||
await ctx.warning("Metric listing blocked by data-model privacy controls")
|
||||
return SemanticLayerError.create(
|
||||
error="You don't have permission to access dataset details for your role.",
|
||||
error_type=DATA_MODEL_METADATA_ERROR_TYPE,
|
||||
)
|
||||
|
||||
if request.dataset_id is not None and request.view_id is not None:
|
||||
return SemanticLayerError.create(
|
||||
error="Provide only one of dataset_id or view_id to scope the search, not both.",
|
||||
error_type="ValidationError",
|
||||
)
|
||||
|
||||
try:
|
||||
all_metrics: list[MetricInfo] = []
|
||||
|
||||
if request.view_id is None:
|
||||
all_metrics.extend(_collect_builtin_metrics(request))
|
||||
await ctx.debug("Collected %d built-in metrics" % len(all_metrics))
|
||||
|
||||
if request.dataset_id is None:
|
||||
external = await _collect_external_metrics(request, ctx)
|
||||
all_metrics.extend(external)
|
||||
await ctx.debug("Collected %d external metrics" % len(external))
|
||||
|
||||
total_count = len(all_metrics)
|
||||
total_pages = max(1, (total_count + request.page_size - 1) // request.page_size)
|
||||
start = (request.page - 1) * request.page_size
|
||||
page_metrics = all_metrics[start : start + request.page_size]
|
||||
|
||||
await ctx.info(
|
||||
"Metrics listed: total=%d, page=%d/%d, returned=%d"
|
||||
% (total_count, request.page, total_pages, len(page_metrics))
|
||||
)
|
||||
|
||||
return MetricList(
|
||||
metrics=page_metrics,
|
||||
total_count=total_count,
|
||||
page=request.page,
|
||||
page_size=request.page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(
|
||||
"Unexpected error in list_metrics: %s: %s", type(exc).__name__, str(exc)
|
||||
)
|
||||
await ctx.error("Unexpected error: %s: %s" % (type(exc).__name__, str(exc)))
|
||||
return SemanticLayerError.create(
|
||||
error=f"Internal error listing metrics: {exc}",
|
||||
error_type="InternalError",
|
||||
)
|
||||
40
superset/mcp_service/utils/query_utils.py
Normal file
40
superset/mcp_service/utils/query_utils.py
Normal file
@@ -0,0 +1,40 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Shared query validation utilities for MCP tools."""
|
||||
|
||||
import difflib
|
||||
|
||||
|
||||
def validate_names(
|
||||
requested: list[str],
|
||||
valid: set[str],
|
||||
kind: str,
|
||||
) -> list[str]:
|
||||
"""Return list of error messages for names not found in *valid*.
|
||||
|
||||
Includes close-match suggestions when available.
|
||||
"""
|
||||
errors: list[str] = []
|
||||
for name in requested:
|
||||
if name not in valid:
|
||||
suggestions = difflib.get_close_matches(name, valid, n=3, cutoff=0.6)
|
||||
msg = f"Unknown {kind}: '{name}'"
|
||||
if suggestions:
|
||||
msg += f". Did you mean: {', '.join(suggestions)}?"
|
||||
errors.append(msg)
|
||||
return errors
|
||||
@@ -57,7 +57,10 @@ Usage example::
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict
|
||||
from typing import Any, Dict, TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.mcp_service.chart.schemas import DataColumn
|
||||
|
||||
import humanize
|
||||
|
||||
@@ -161,3 +164,49 @@ class OmittedFieldsBuilder:
|
||||
def build(self) -> Dict[str, str]:
|
||||
"""Return the omission metadata dict."""
|
||||
return dict(self._fields)
|
||||
|
||||
|
||||
def format_data_columns(
|
||||
data: list[dict[str, Any]], raw_columns: list[str]
|
||||
) -> list[DataColumn]:
|
||||
"""Build column metadata from query result data.
|
||||
|
||||
Caps statistics at 5000 rows to avoid O(rows*cols) overhead on large
|
||||
result sets.
|
||||
"""
|
||||
# Local import breaks the chart.schemas ↔ response_utils circular dependency.
|
||||
from superset.mcp_service.chart.schemas import DataColumn # noqa: PLC0415
|
||||
|
||||
stats_rows = data[:5000]
|
||||
columns_meta: list[DataColumn] = []
|
||||
for col_name in raw_columns:
|
||||
sample_values = [
|
||||
row.get(col_name) for row in data[:3] if row.get(col_name) is not None
|
||||
]
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
elif all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
|
||||
null_count = 0
|
||||
unique_vals: set[str] = set()
|
||||
for row in stats_rows:
|
||||
val = row.get(col_name)
|
||||
if val is None:
|
||||
null_count += 1
|
||||
else:
|
||||
unique_vals.add(str(val))
|
||||
|
||||
columns_meta.append(
|
||||
DataColumn(
|
||||
name=col_name,
|
||||
display_name=col_name.replace("_", " ").title(),
|
||||
data_type=data_type,
|
||||
sample_values=sample_values[:3],
|
||||
null_count=null_count,
|
||||
unique_count=len(unique_vals),
|
||||
)
|
||||
)
|
||||
return columns_meta
|
||||
|
||||
16
tests/unit_tests/mcp_service/semantic_layer/__init__.py
Normal file
16
tests/unit_tests/mcp_service/semantic_layer/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/semantic_layer/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/semantic_layer/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,231 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for the get_compatible_dimensions MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
get_compatible_dimensions_module = importlib.import_module(
|
||||
"superset.mcp_service.semantic_layer.tool.get_compatible_dimensions"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
get_compatible_dimensions_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_column(name: str, groupby: bool = True) -> MagicMock:
|
||||
col = MagicMock()
|
||||
col.column_name = name
|
||||
col.verbose_name = None
|
||||
col.description = None
|
||||
col.type = "VARCHAR"
|
||||
col.is_dttm = False
|
||||
col.groupby = groupby
|
||||
col.filterable = True
|
||||
return col
|
||||
|
||||
|
||||
def _make_dataset(dataset_id: int = 42) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = f"table_{dataset_id}"
|
||||
ds.metrics = []
|
||||
ds.columns = [
|
||||
_make_column("region"),
|
||||
_make_column("category"),
|
||||
_make_column("internal_only", groupby=False),
|
||||
]
|
||||
return ds
|
||||
|
||||
|
||||
def _make_view(view_id: int = 5) -> MagicMock:
|
||||
view = MagicMock()
|
||||
view.id = view_id
|
||||
view.name = f"view_{view_id}"
|
||||
view.raise_for_access = MagicMock(return_value=None)
|
||||
view.columns = [_make_column("country_name")]
|
||||
view.get_compatible_dimensions = MagicMock(return_value=["country_name"])
|
||||
return view
|
||||
|
||||
|
||||
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
|
||||
return SupersetSecurityException(
|
||||
SupersetError(
|
||||
message=message,
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_builtin_happy_path(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Builtin datasets return all groupby-enabled columns, ignoring selection."""
|
||||
mock_ds = _make_dataset(42)
|
||||
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions",
|
||||
{"request": {"dataset_id": 42, "selected_metrics": ["revenue"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["source"] == "builtin"
|
||||
names = {d["name"] for d in data["compatible_dimensions"]}
|
||||
assert names == {"region", "category"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_external_happy_path(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""External views delegate to view.get_compatible_dimensions()."""
|
||||
mock_view = _make_view(5)
|
||||
|
||||
with patch(
|
||||
"superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
|
||||
return_value=mock_view,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions",
|
||||
{"request": {"view_id": 5, "selected_metrics": ["bookings"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["source"] == "external"
|
||||
assert [d["name"] for d in data["compatible_dimensions"]] == ["country_name"]
|
||||
mock_view.get_compatible_dimensions.assert_called_once_with(["bookings"], [])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_mutual_exclusion_validation(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Errors when both dataset_id and view_id are provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions",
|
||||
{"request": {"dataset_id": 1, "view_id": 2}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_requires_one_source(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Errors when neither dataset_id nor view_id is provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_compatible_dimensions", {"request": {}})
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_privacy_check(mcp_server: FastMCP) -> None:
|
||||
"""Errors when the user lacks data-model metadata access."""
|
||||
with patch.object(
|
||||
get_compatible_dimensions_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions", {"request": {"dataset_id": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_external_access_denied(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Returns AccessDenied when raise_for_access rejects the view."""
|
||||
mock_view = _make_view(5)
|
||||
mock_view.raise_for_access.side_effect = _access_denied_exc()
|
||||
|
||||
with patch(
|
||||
"superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
|
||||
return_value=mock_view,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions", {"request": {"view_id": 5}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "AccessDenied"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_dimensions_not_found(mcp_server: FastMCP) -> None:
|
||||
"""Returns NotFound when the dataset doesn't exist."""
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_dimensions", {"request": {"dataset_id": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "NotFound"
|
||||
@@ -0,0 +1,222 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for the get_compatible_metrics MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
get_compatible_metrics_module = importlib.import_module(
|
||||
"superset.mcp_service.semantic_layer.tool.get_compatible_metrics"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
get_compatible_metrics_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.metric_name = name
|
||||
m.verbose_name = None
|
||||
m.expression = expression
|
||||
m.description = None
|
||||
m.d3format = None
|
||||
m.warning_text = None
|
||||
return m
|
||||
|
||||
|
||||
def _make_dataset(dataset_id: int = 42) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = f"table_{dataset_id}"
|
||||
ds.columns = []
|
||||
ds.metrics = [_make_metric("count"), _make_metric("revenue", "SUM(revenue)")]
|
||||
return ds
|
||||
|
||||
|
||||
def _make_view(view_id: int = 5) -> MagicMock:
|
||||
view = MagicMock()
|
||||
view.id = view_id
|
||||
view.name = f"view_{view_id}"
|
||||
view.raise_for_access = MagicMock(return_value=None)
|
||||
view.metrics = [_make_metric("bookings")]
|
||||
view.get_compatible_metrics = MagicMock(return_value=["bookings"])
|
||||
return view
|
||||
|
||||
|
||||
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
|
||||
return SupersetSecurityException(
|
||||
SupersetError(
|
||||
message=message,
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_builtin_happy_path(mcp_server: FastMCP) -> None:
|
||||
"""Builtin datasets return all metrics, ignoring the current selection."""
|
||||
mock_ds = _make_dataset(42)
|
||||
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics",
|
||||
{"request": {"dataset_id": 42, "selected_dimensions": ["region"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["source"] == "builtin"
|
||||
names = {m["name"] for m in data["compatible_metrics"]}
|
||||
assert names == {"count", "revenue"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_external_happy_path(mcp_server: FastMCP) -> None:
|
||||
"""External views delegate to view.get_compatible_metrics()."""
|
||||
mock_view = _make_view(5)
|
||||
|
||||
with patch(
|
||||
"superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
|
||||
return_value=mock_view,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics",
|
||||
{"request": {"view_id": 5, "selected_dimensions": ["country_name"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["source"] == "external"
|
||||
assert [m["name"] for m in data["compatible_metrics"]] == ["bookings"]
|
||||
mock_view.get_compatible_metrics.assert_called_once_with([], ["country_name"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_mutual_exclusion_validation(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Errors when both dataset_id and view_id are provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics",
|
||||
{"request": {"dataset_id": 1, "view_id": 2}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_requires_one_source(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Errors when neither dataset_id nor view_id is provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_compatible_metrics", {"request": {}})
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_privacy_check(mcp_server: FastMCP) -> None:
|
||||
"""Errors when the user lacks data-model metadata access."""
|
||||
with patch.object(
|
||||
get_compatible_metrics_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics", {"request": {"dataset_id": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_external_access_denied(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""Returns AccessDenied when raise_for_access rejects the view."""
|
||||
mock_view = _make_view(5)
|
||||
mock_view.raise_for_access.side_effect = _access_denied_exc()
|
||||
|
||||
with patch(
|
||||
"superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
|
||||
return_value=mock_view,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics", {"request": {"view_id": 5}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "AccessDenied"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_compatible_metrics_not_found(mcp_server: FastMCP) -> None:
|
||||
"""Returns NotFound when the dataset doesn't exist."""
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_compatible_metrics", {"request": {"dataset_id": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "NotFound"
|
||||
@@ -0,0 +1,263 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for the get_table MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
get_table_module = importlib.import_module(
|
||||
"superset.mcp_service.semantic_layer.tool.get_table"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
get_table_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.metric_name = name
|
||||
m.verbose_name = None
|
||||
m.expression = expression
|
||||
m.description = None
|
||||
m.d3format = None
|
||||
m.warning_text = None
|
||||
return m
|
||||
|
||||
|
||||
def _make_column(name: str, is_dttm: bool = False) -> MagicMock:
|
||||
col = MagicMock()
|
||||
col.column_name = name
|
||||
col.verbose_name = None
|
||||
col.description = None
|
||||
col.type = "VARCHAR"
|
||||
col.is_dttm = is_dttm
|
||||
col.groupby = True
|
||||
col.filterable = True
|
||||
return col
|
||||
|
||||
|
||||
def _make_dataset(dataset_id: int = 42) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = f"table_{dataset_id}"
|
||||
ds.main_dttm_col = "created_at"
|
||||
ds.metrics = [_make_metric("revenue", "SUM(revenue)")]
|
||||
ds.columns = [
|
||||
_make_column("region"),
|
||||
_make_column("created_at", is_dttm=True),
|
||||
]
|
||||
return ds
|
||||
|
||||
|
||||
def _make_view(view_id: int = 5) -> MagicMock:
|
||||
view = MagicMock()
|
||||
view.id = view_id
|
||||
view.name = f"view_{view_id}"
|
||||
view.raise_for_access = MagicMock(return_value=None)
|
||||
view.metrics = [_make_metric("bookings")]
|
||||
view.columns = [_make_column("country_name")]
|
||||
return view
|
||||
|
||||
|
||||
def _access_denied_exc(message: str = "Access denied") -> SupersetSecurityException:
|
||||
return SupersetSecurityException(
|
||||
SupersetError(
|
||||
message=message,
|
||||
error_type=SupersetErrorType.DATASOURCE_SECURITY_ACCESS_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_builtin_happy_path(mcp_server: FastMCP) -> None:
|
||||
"""get_table returns tabular data for a built-in dataset."""
|
||||
mock_ds = _make_dataset(42)
|
||||
query_result = {
|
||||
"queries": [
|
||||
{
|
||||
"data": [{"region": "west", "revenue": 100}],
|
||||
"colnames": ["region", "revenue"],
|
||||
"rowcount": 1,
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
with (
|
||||
patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds),
|
||||
patch(
|
||||
"superset.commands.chart.data.get_data_command.ChartDataCommand"
|
||||
) as mock_command_cls,
|
||||
patch(
|
||||
"superset.common.query_context_factory.QueryContextFactory"
|
||||
) as mock_factory_cls,
|
||||
):
|
||||
mock_command_cls.return_value.run.return_value = query_result
|
||||
mock_factory_cls.return_value.create.return_value = MagicMock()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_table",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 42,
|
||||
"metrics": ["revenue"],
|
||||
"dimensions": ["region"],
|
||||
}
|
||||
},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["row_count"] == 1
|
||||
assert data["source"] == "builtin"
|
||||
assert data["dataset_id"] == 42
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_requires_one_source(mcp_server: FastMCP) -> None:
|
||||
"""get_table errors when neither dataset_id nor view_id is provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_table", {"request": {}})
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_mutual_exclusion_validation(mcp_server: FastMCP) -> None:
|
||||
"""get_table errors when both dataset_id and view_id are provided."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_table", {"request": {"dataset_id": 1, "view_id": 2}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_privacy_check(mcp_server: FastMCP) -> None:
|
||||
"""get_table errors when the user lacks data-model metadata access."""
|
||||
with patch.object(
|
||||
get_table_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_table", {"request": {"dataset_id": 1}})
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_unknown_metric_validation_error(mcp_server: FastMCP) -> None:
|
||||
"""get_table errors when a requested metric doesn't exist on the dataset."""
|
||||
mock_ds = _make_dataset(42)
|
||||
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_table",
|
||||
{"request": {"dataset_id": 42, "metrics": ["does_not_exist"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_time_column_not_dttm_validation_error(
|
||||
mcp_server: FastMCP,
|
||||
) -> None:
|
||||
"""get_table rejects a time_column that isn't marked as a datetime column."""
|
||||
mock_ds = _make_dataset(42)
|
||||
|
||||
with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_ds):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_table",
|
||||
{
|
||||
"request": {
|
||||
"dataset_id": 42,
|
||||
"metrics": ["revenue"],
|
||||
"time_column": "region",
|
||||
}
|
||||
},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
assert "not marked as a datetime column" in data["message"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_table_external_view_access_denied(mcp_server: FastMCP) -> None:
|
||||
"""get_table returns AccessDenied when raise_for_access rejects the view."""
|
||||
mock_view = _make_view(5)
|
||||
mock_view.raise_for_access.side_effect = _access_denied_exc()
|
||||
|
||||
with patch(
|
||||
"superset.daos.semantic_layer.SemanticViewDAO.find_by_id",
|
||||
return_value=mock_view,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_table",
|
||||
{"request": {"view_id": 5, "metrics": ["bookings"]}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "AccessDenied"
|
||||
@@ -0,0 +1,183 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for the list_metrics MCP tool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client, FastMCP
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
list_metrics_module = importlib.import_module(
|
||||
"superset.mcp_service.semantic_layer.tool.list_metrics"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server() -> FastMCP:
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth() -> Generator[MagicMock, None, None]:
|
||||
with (
|
||||
patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
|
||||
patch.object(
|
||||
list_metrics_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=True,
|
||||
),
|
||||
):
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_metric(name: str, expression: str = "COUNT(*)") -> MagicMock:
|
||||
m = MagicMock()
|
||||
m.metric_name = name
|
||||
m.verbose_name = None
|
||||
m.expression = expression
|
||||
m.description = None
|
||||
m.d3format = None
|
||||
m.warning_text = None
|
||||
return m
|
||||
|
||||
|
||||
def _make_column(name: str) -> MagicMock:
|
||||
col = MagicMock()
|
||||
col.column_name = name
|
||||
col.verbose_name = None
|
||||
col.description = None
|
||||
col.type = "VARCHAR"
|
||||
col.is_dttm = False
|
||||
col.groupby = True
|
||||
col.filterable = True
|
||||
return col
|
||||
|
||||
|
||||
def _make_dataset(dataset_id: int = 1) -> MagicMock:
|
||||
ds = MagicMock()
|
||||
ds.id = dataset_id
|
||||
ds.table_name = f"table_{dataset_id}"
|
||||
ds.metrics = [_make_metric("count"), _make_metric("revenue", "SUM(revenue)")]
|
||||
ds.columns = [_make_column("region"), _make_column("category")]
|
||||
return ds
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_metrics_builtin_happy_path(mcp_server: FastMCP) -> None:
|
||||
"""list_metrics returns builtin metrics when only datasets exist."""
|
||||
mock_ds = _make_dataset(42)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.semantic_layer.tool.list_metrics.DatasetDAO"
|
||||
) as mock_dao,
|
||||
patch(
|
||||
"superset.mcp_service.semantic_layer.tool.list_metrics.SemanticViewDAO"
|
||||
) as mock_view_dao,
|
||||
):
|
||||
mock_dao.find_by_id.return_value = mock_ds
|
||||
mock_view_dao.find_accessible.return_value = []
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"list_metrics",
|
||||
{"request": {"dataset_id": 42, "include_compatible_dimensions": False}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
assert data["total_count"] == 2
|
||||
metrics = data["metrics"]
|
||||
assert {m["name"] for m in metrics} == {"count", "revenue"}
|
||||
assert all(m["source"] == "builtin" for m in metrics)
|
||||
assert all(m["dataset_id"] == 42 for m in metrics)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_metrics_mutual_exclusion_validation(mcp_server: FastMCP) -> None:
|
||||
"""list_metrics returns a validation error when dataset_id and view_id coexist."""
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"list_metrics",
|
||||
{"request": {"dataset_id": 1, "view_id": 2}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "ValidationError"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_metrics_privacy_check(mcp_server: FastMCP) -> None:
|
||||
"""list_metrics returns an error when the user lacks data-model metadata access."""
|
||||
with patch.object(
|
||||
list_metrics_module,
|
||||
"user_can_view_data_model_metadata",
|
||||
return_value=False,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_metrics", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is False
|
||||
assert data["error_type"] == "DataModelMetadataRestricted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_metrics_search_filter(mcp_server: FastMCP) -> None:
|
||||
"""list_metrics filters metrics by search term."""
|
||||
mock_ds = _make_dataset(1)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"superset.mcp_service.semantic_layer.tool.list_metrics.DatasetDAO"
|
||||
) as mock_dao,
|
||||
patch(
|
||||
"superset.mcp_service.semantic_layer.tool.list_metrics.SemanticViewDAO"
|
||||
) as mock_view_dao,
|
||||
patch("superset.mcp_service.semantic_layer.tool.list_metrics.db") as mock_db,
|
||||
):
|
||||
mock_view_dao.find_accessible.return_value = []
|
||||
mock_query = MagicMock()
|
||||
mock_db.session.query.return_value.options.return_value = mock_query
|
||||
mock_dao._apply_base_filter.return_value = mock_query
|
||||
mock_query.all.return_value = [mock_ds]
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"list_metrics",
|
||||
{"request": {"search": "revenue"}},
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["success"] is True
|
||||
# Only the "revenue" metric should match the search
|
||||
metrics = data["metrics"]
|
||||
assert len(metrics) == 1
|
||||
assert metrics[0]["name"] == "revenue"
|
||||
Reference in New Issue
Block a user