mirror of
https://github.com/apache/superset.git
synced 2026-05-15 12:55:08 +00:00
Compare commits
5 Commits
fix/mcp-ex
...
fix-mcp-pe
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a2dfaa873e | ||
|
|
f13e24a6df | ||
|
|
534b5aa799 | ||
|
|
6275159eb4 | ||
|
|
c54f77a0b2 |
@@ -84,6 +84,7 @@ Schema Discovery:
|
||||
|
||||
System Information:
|
||||
- get_instance_info: Get instance-wide statistics, metadata, and current user identity
|
||||
- find_users: Resolve a person's name to user IDs for use as a filter value
|
||||
- health_check: Simple health check tool (takes NO parameters, call without arguments)
|
||||
|
||||
Available Resources:
|
||||
@@ -123,6 +124,16 @@ Some tools do not use a request wrapper, so follow each tool's schema
|
||||
|
||||
Recommended Workflows:
|
||||
|
||||
To filter dashboards/charts/datasets by a person ("show me what <name> is working on"):
|
||||
1. find_users(request={{"query": "<name>"}}) -> resolve to user IDs
|
||||
2. Pick the matching user.id from the response
|
||||
3. list_dashboards(request={{"filters": [
|
||||
{{"col": "created_by_fk", "opr": "eq", "value": <id>}}
|
||||
]}}) — same shape for list_charts / list_datasets.
|
||||
(use changed_by_fk for "last modified by", or "in" with a list of IDs for
|
||||
multiple matches). Do NOT pass the person's name as the search parameter —
|
||||
search matches titles, not people.
|
||||
|
||||
To add a chart to an existing dashboard:
|
||||
1. add_chart_to_existing_dashboard(dashboard_id, chart_id) -> updates dashboard directly
|
||||
- If permission_denied=True is returned: inform the user they lack edit rights,
|
||||
@@ -263,6 +274,11 @@ Permission Awareness:
|
||||
contact details, roles, admin status, ownership, or access-list information.
|
||||
- Do NOT infer access-list answers from dashboard metadata such as published status,
|
||||
role restrictions, empty owner lists, or schema fields.
|
||||
- find_users is sanctioned ONLY for resolving a name the user supplied into a
|
||||
user ID for filtering (e.g., "what is <name> working on" -> filter
|
||||
list_dashboards by created_by_fk). Do NOT use find_users to answer "who owns
|
||||
X", "who can access X", "is <name> an admin", or to enumerate the directory.
|
||||
Never return find_users output to the user verbatim.
|
||||
- Do NOT use execute_sql to query user, role, owner, or access-list tables for this
|
||||
information.
|
||||
- You may reference the current user's own identity details when appropriate, such
|
||||
@@ -509,6 +525,7 @@ from superset.mcp_service.system import ( # noqa: F401, E402
|
||||
resources as system_resources,
|
||||
)
|
||||
from superset.mcp_service.system.tool import ( # noqa: F401, E402
|
||||
find_users,
|
||||
get_instance_info,
|
||||
get_schema,
|
||||
health_check,
|
||||
|
||||
@@ -23,7 +23,7 @@ from __future__ import annotations
|
||||
|
||||
import difflib
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated, Any, Dict, List, Literal, Protocol
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal, Protocol
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -141,7 +141,7 @@ class ChartInfo(BaseModel):
|
||||
),
|
||||
)
|
||||
form_data_key: str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Cache key used to retrieve unsaved form_data. When present, indicates "
|
||||
"the form_data came from cache (unsaved edits) rather than the saved chart."
|
||||
@@ -435,14 +435,18 @@ class ChartFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"slice_name",
|
||||
"viz_type",
|
||||
"datasource_name",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description="Column to filter on. Use get_schema(model_type='chart') for "
|
||||
"available filter columns.",
|
||||
"available filter columns. To filter by a person, first call find_users "
|
||||
"to resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
@@ -1351,7 +1355,10 @@ class ListChartsRequest(MetadataCacheControl):
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, ChartFilter, "filters")
|
||||
return cast(
|
||||
List[ChartFilter],
|
||||
parse_json_or_model_list(v, ChartFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -90,6 +90,10 @@ async def list_charts(
|
||||
|
||||
Sortable columns for order_column: id, slice_name, viz_type, description,
|
||||
changed_on, created_on
|
||||
|
||||
To filter by a person, call find_users to resolve the name to a user ID,
|
||||
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
|
||||
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Listing charts: page=%s, page_size=%s, search=%s"
|
||||
|
||||
@@ -37,10 +37,12 @@ class ColumnMetadata(BaseModel):
|
||||
"""Metadata for a selectable column."""
|
||||
|
||||
name: str = Field(..., description="Column name to use in select_columns")
|
||||
description: str | None = Field(None, description="Column description")
|
||||
type: str | None = Field(None, description="Data type (str, int, datetime, etc.)")
|
||||
description: str | None = Field(default=None, description="Column description")
|
||||
type: str | None = Field(
|
||||
default=None, description="Data type (str, int, datetime, etc.)"
|
||||
)
|
||||
is_default: bool = Field(
|
||||
False, description="Whether this column is included by default"
|
||||
default=False, description="Whether this column is included by default"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal, TYPE_CHECKING
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -155,16 +155,20 @@ class DashboardFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"dashboard_title",
|
||||
"published",
|
||||
"favorite",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description=(
|
||||
"Column to filter on. Use "
|
||||
"get_schema(model_type='dashboard') for available "
|
||||
"filter columns."
|
||||
"filter columns. To filter by a person, first call find_users to "
|
||||
"resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID."
|
||||
),
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
@@ -209,7 +213,10 @@ class ListDashboardsRequest(MetadataCacheControl):
|
||||
"""
|
||||
from superset.mcp_service.utils.schema_utils import parse_json_or_model_list
|
||||
|
||||
return parse_json_or_model_list(v, DashboardFilter, "filters")
|
||||
return cast(
|
||||
List[DashboardFilter],
|
||||
parse_json_or_model_list(v, DashboardFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
@@ -379,14 +386,14 @@ class DashboardInfo(BaseModel):
|
||||
|
||||
# Fields for permalink/filter state support
|
||||
permalink_key: str | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Permalink key used to retrieve filter state. When present, indicates "
|
||||
"the filter_state came from a permalink rather than the default dashboard."
|
||||
),
|
||||
)
|
||||
filter_state: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
default=None,
|
||||
description=(
|
||||
"Filter state from permalink. Contains dataMask (native filter values), "
|
||||
"activeTabs, anchor, and urlParams. When present, represents the actual "
|
||||
|
||||
@@ -84,6 +84,12 @@ async def list_dashboards(
|
||||
|
||||
Sortable columns for order_column: id, dashboard_title, slug, published,
|
||||
changed_on, created_on
|
||||
|
||||
To filter by a person (e.g. "dashboards Maxime is working on"), do NOT pass
|
||||
the name as the search parameter — search matches titles and slugs only.
|
||||
Instead, call find_users to resolve the name to a user ID, then pass it as
|
||||
a filter: filters=[{"col": "created_by_fk", "opr": "eq", "value": <id>}]
|
||||
(or "changed_by_fk" for "last modified by").
|
||||
"""
|
||||
await ctx.info(
|
||||
"Listing dashboards: page=%s, page_size=%s, search=%s"
|
||||
|
||||
@@ -22,7 +22,7 @@ Pydantic schemas for database-related responses
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
from typing import Annotated, Any, cast, Dict, List, Literal
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
@@ -55,7 +55,7 @@ class DatabaseFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"database_name",
|
||||
"expose_in_sqllab",
|
||||
"allow_file_upload",
|
||||
@@ -239,7 +239,10 @@ class ListDatabasesRequest(MetadataCacheControl):
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[DatabaseFilter]:
|
||||
"""Accept both JSON string and list of objects."""
|
||||
return parse_json_or_model_list(v, DatabaseFilter, "filters")
|
||||
return cast(
|
||||
List[DatabaseFilter],
|
||||
parse_json_or_model_list(v, DatabaseFilter, "filters"),
|
||||
)
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -54,14 +54,18 @@ class DatasetFilter(ColumnOperator):
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal[
|
||||
col: Literal[ # pyright: ignore[reportIncompatibleVariableOverride]
|
||||
"table_name",
|
||||
"schema",
|
||||
"database_name",
|
||||
"created_by_fk",
|
||||
"changed_by_fk",
|
||||
] = Field(
|
||||
...,
|
||||
description="Column to filter on. Use get_schema(model_type='dataset') for "
|
||||
"available filter columns.",
|
||||
"available filter columns. To filter by a person, first call find_users "
|
||||
"to resolve a name to a user ID, then filter by created_by_fk or "
|
||||
"changed_by_fk with that integer ID.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
@@ -415,7 +419,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
|
||||
params = None
|
||||
columns = [
|
||||
TableColumnInfo(
|
||||
column_name=getattr(col, "column_name", None),
|
||||
column_name=getattr(col, "column_name", None) or "",
|
||||
verbose_name=getattr(col, "verbose_name", None),
|
||||
type=getattr(col, "type", None),
|
||||
is_dttm=getattr(col, "is_dttm", None),
|
||||
@@ -427,7 +431,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
|
||||
]
|
||||
metrics = [
|
||||
SqlMetricInfo(
|
||||
metric_name=getattr(metric, "metric_name", None),
|
||||
metric_name=getattr(metric, "metric_name", None) or "",
|
||||
verbose_name=getattr(metric, "verbose_name", None),
|
||||
expression=getattr(metric, "expression", None),
|
||||
description=getattr(metric, "description", None),
|
||||
@@ -438,7 +442,7 @@ def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
|
||||
return DatasetInfo(
|
||||
id=getattr(dataset, "id", None),
|
||||
table_name=getattr(dataset, "table_name", None),
|
||||
schema_name=getattr(dataset, "schema", None),
|
||||
schema=getattr(dataset, "schema", None),
|
||||
database_name=getattr(dataset.database, "database_name", None)
|
||||
if getattr(dataset, "database", None)
|
||||
else None,
|
||||
|
||||
@@ -98,6 +98,10 @@ async def list_datasets(
|
||||
|
||||
Sortable columns for order_column: id, table_name, schema, changed_on,
|
||||
created_on
|
||||
|
||||
To filter by a person, call find_users to resolve the name to a user ID,
|
||||
then pass it as a filter: filters=[{"col": "created_by_fk", "opr": "eq",
|
||||
"value": <id>}] (or "changed_by_fk"). Do not pass the name as search.
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_datasets")
|
||||
|
||||
@@ -29,6 +29,7 @@ from superset.mcp_service.constants import ModelType
|
||||
from superset.mcp_service.privacy import (
|
||||
filter_user_directory_columns,
|
||||
USER_DIRECTORY_FIELDS,
|
||||
USER_FILTER_FIELDS,
|
||||
)
|
||||
from superset.mcp_service.utils import _is_uuid
|
||||
|
||||
@@ -245,14 +246,6 @@ class ModelListCore(BaseCore, Generic[L]):
|
||||
has_previous=page > 0,
|
||||
)
|
||||
|
||||
# Build response
|
||||
def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]:
|
||||
if hasattr(obj, "model_dump"):
|
||||
return list(obj.model_dump().keys())
|
||||
elif isinstance(obj, dict):
|
||||
return list(obj.keys())
|
||||
return []
|
||||
|
||||
response_kwargs = {
|
||||
self.list_field_name: item_objs,
|
||||
"count": len(item_objs),
|
||||
@@ -425,7 +418,7 @@ class InstanceInfoCore(BaseCore):
|
||||
return counts
|
||||
|
||||
def _calculate_time_based_metrics(
|
||||
self, base_counts: Dict[str, int]
|
||||
self, _base_counts: Dict[str, int]
|
||||
) -> Dict[str, Dict[str, int]]:
|
||||
"""Calculate time-based metrics for recent activity."""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
@@ -612,7 +605,9 @@ class ModelGetSchemaCore(BaseCore, Generic[S]):
|
||||
self.default_sort = default_sort
|
||||
self.default_sort_direction = default_sort_direction
|
||||
self.exclude_filter_columns = set(exclude_filter_columns or set())
|
||||
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS)
|
||||
# Hide user-directory columns from filter discovery, except the small
|
||||
# set callers may legitimately filter by ID (resolved via find_users).
|
||||
self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS - USER_FILTER_FIELDS)
|
||||
|
||||
def _get_filter_columns(self) -> Dict[str, List[str]]:
|
||||
"""Get filterable columns and operators from the DAO."""
|
||||
|
||||
@@ -44,6 +44,12 @@ USER_DIRECTORY_FIELDS = frozenset(
|
||||
}
|
||||
)
|
||||
|
||||
# User-directory columns that may be used as filter values (an integer user ID).
|
||||
# These remain stripped from select_columns, sort, search, and tool responses
|
||||
# (so the directory itself is never exposed), but list tools may filter rows by
|
||||
# them when the caller already has an ID — typically resolved via find_users.
|
||||
USER_FILTER_FIELDS = frozenset({"created_by_fk", "changed_by_fk"})
|
||||
|
||||
DATA_MODEL_METADATA_ACCESS_ATTR = "_requires_data_model_metadata_access"
|
||||
DATA_MODEL_METADATA_ERROR_TYPE = "DataModelMetadataRestricted"
|
||||
DATA_MODEL_METADATA_PRIVACY_SCOPE = "data_model"
|
||||
|
||||
@@ -25,9 +25,11 @@ system-level info.
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List
|
||||
from typing import Annotated, Any, Dict, List
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
|
||||
|
||||
class HealthCheckResponse(BaseModel):
|
||||
@@ -170,6 +172,84 @@ def serialize_user_object(user: Any) -> UserInfo | None:
|
||||
)
|
||||
|
||||
|
||||
class FindUsersRequest(BaseModel):
|
||||
"""Request schema for find_users tool.
|
||||
|
||||
Resolves a person's name (or partial name, username, or email) to user IDs
|
||||
so they can be passed to listing tools as filter values for created_by_fk
|
||||
or changed_by_fk. This is the only sanctioned path for "show me what
|
||||
<person> is working on" queries.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
query: Annotated[
|
||||
str,
|
||||
Field(
|
||||
min_length=1,
|
||||
max_length=200,
|
||||
description=(
|
||||
"Substring to match (case-insensitive) against username, "
|
||||
"first_name, last_name, and email. Required and non-empty: "
|
||||
"this tool does not enumerate the full user directory."
|
||||
),
|
||||
),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Maximum number of matches to return (max {MAX_PAGE_SIZE}).",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("query")
|
||||
@classmethod
|
||||
def _reject_blank_query(cls, value: str) -> str:
|
||||
# min_length=1 alone admits whitespace-only strings, which strip to "" and
|
||||
# produce a "%%" LIKE pattern that matches every user. Strip and require
|
||||
# at least one non-space character.
|
||||
stripped = value.strip()
|
||||
if not stripped:
|
||||
raise ValueError("query must contain at least one non-whitespace character")
|
||||
return stripped
|
||||
|
||||
|
||||
class UserMatch(BaseModel):
|
||||
"""Minimal user projection returned by find_users.
|
||||
|
||||
Intentionally narrower than UserInfo: only the fields needed to disambiguate
|
||||
matches and pass an id to created_by_fk / changed_by_fk filters. Email,
|
||||
active flag, and roles are deliberately excluded to limit identity
|
||||
exposure through this directory-resolution path.
|
||||
"""
|
||||
|
||||
id: int | None = None
|
||||
username: str | None = None
|
||||
first_name: str | None = None
|
||||
last_name: str | None = None
|
||||
|
||||
|
||||
class FindUsersResponse(BaseModel):
|
||||
"""Response schema for find_users tool."""
|
||||
|
||||
users: List[UserMatch] = Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"Matching users. Pass user.id as the value for created_by_fk or "
|
||||
"changed_by_fk filters on list_dashboards, list_charts, and "
|
||||
"list_datasets."
|
||||
),
|
||||
)
|
||||
count: int = Field(..., description="Number of users returned in this response.")
|
||||
truncated: bool = Field(
|
||||
default=False,
|
||||
description="True when the query matched more rows than page_size allows.",
|
||||
)
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
|
||||
id: int | None = None
|
||||
name: str | None = None
|
||||
|
||||
@@ -17,11 +17,13 @@
|
||||
|
||||
"""System tools for MCP service."""
|
||||
|
||||
from .find_users import find_users
|
||||
from .get_instance_info import get_instance_info
|
||||
from .get_schema import get_schema
|
||||
from .health_check import health_check
|
||||
|
||||
__all__ = [
|
||||
"find_users",
|
||||
"health_check",
|
||||
"get_instance_info",
|
||||
"get_schema",
|
||||
|
||||
101
superset/mcp_service/system/tool/find_users.py
Normal file
101
superset/mcp_service/system/tool/find_users.py
Normal file
@@ -0,0 +1,101 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""find_users MCP tool: resolve a person's name to user IDs for filtering."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy import or_
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import db, event_logger, security_manager
|
||||
from superset.mcp_service.system.schemas import (
|
||||
FindUsersRequest,
|
||||
FindUsersResponse,
|
||||
UserMatch,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
annotations=ToolAnnotations(
|
||||
title="Find users",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def find_users(request: FindUsersRequest, ctx: Context) -> FindUsersResponse:
|
||||
"""Resolve a person's name to user IDs so they can be used as filter values.
|
||||
|
||||
Use this when the caller asks "show me <person>'s dashboards/charts/datasets"
|
||||
or "what is <person> working on". Take the matching user.id and pass it as
|
||||
the value for a created_by_fk or changed_by_fk filter on list_dashboards,
|
||||
list_charts, or list_datasets.
|
||||
|
||||
Matches case-insensitively against username, first_name, last_name, and
|
||||
email. The query is required and non-empty; this tool does not enumerate
|
||||
the full user directory.
|
||||
|
||||
Privacy: returning a user's identity here is sanctioned only for resolving
|
||||
filter values. Do not use the response to answer "who owns X", "who can
|
||||
access X", or any access-list question — those remain off-limits per the
|
||||
server instructions.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Resolving user query: query=%s, page_size=%s"
|
||||
% (request.query, request.page_size)
|
||||
)
|
||||
|
||||
user_model = security_manager.user_model
|
||||
needle = f"%{request.query.strip()}%"
|
||||
|
||||
with event_logger.log_context(action="mcp.find_users.query"):
|
||||
query = (
|
||||
db.session.query(user_model)
|
||||
.filter(
|
||||
or_(
|
||||
user_model.username.ilike(needle),
|
||||
user_model.first_name.ilike(needle),
|
||||
user_model.last_name.ilike(needle),
|
||||
user_model.email.ilike(needle),
|
||||
)
|
||||
)
|
||||
.order_by(user_model.username.asc())
|
||||
)
|
||||
# Fetch one extra row to detect truncation without a separate count query.
|
||||
rows = query.limit(request.page_size + 1).all()
|
||||
|
||||
truncated = len(rows) > request.page_size
|
||||
rows = rows[: request.page_size]
|
||||
|
||||
users: list[UserMatch] = [
|
||||
UserMatch(
|
||||
id=getattr(row, "id", None),
|
||||
username=getattr(row, "username", None),
|
||||
first_name=getattr(row, "first_name", None),
|
||||
last_name=getattr(row, "last_name", None),
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
|
||||
await ctx.info(
|
||||
"Resolved user query: matches=%s, truncated=%s" % (len(users), truncated)
|
||||
)
|
||||
return FindUsersResponse(users=users, count=len(users), truncated=truncated)
|
||||
257
tests/unit_tests/mcp_service/system/tool/test_find_users.py
Normal file
257
tests/unit_tests/mcp_service/system/tool/test_find_users.py
Normal file
@@ -0,0 +1,257 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for find_users MCP tool and its filter contract."""
|
||||
|
||||
import importlib
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.system.schemas import FindUsersRequest, FindUsersResponse
|
||||
from superset.utils import json
|
||||
|
||||
# Import the submodule directly so ``patch.object`` targets the module (not the
|
||||
# ``find_users`` function that ``tool/__init__.py`` re-exports onto the
|
||||
# package). The package attribute is the function, so dotted-string patches
|
||||
# like ``superset.mcp_service.system.tool.find_users.db`` can resolve to the
|
||||
# function in some import orderings and fail with AttributeError.
|
||||
find_users_module = importlib.import_module(
|
||||
"superset.mcp_service.system.tool.find_users"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_user(id_, username, first=None, last=None, email=None, active=True):
|
||||
"""Build a Mock user with the attributes serialize_user_object reads."""
|
||||
user = Mock(
|
||||
spec=["id", "username", "first_name", "last_name", "email", "active", "roles"]
|
||||
)
|
||||
user.id = id_
|
||||
user.username = username
|
||||
user.first_name = first
|
||||
user.last_name = last
|
||||
user.email = email
|
||||
user.active = active
|
||||
user.roles = []
|
||||
return user
|
||||
|
||||
|
||||
def _patch_user_query(rows):
|
||||
"""Patch the SQLAlchemy chain used by find_users to return a fixed result set."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.order_by.return_value = chain
|
||||
chain.limit.return_value = chain
|
||||
chain.all.return_value = rows
|
||||
session = MagicMock()
|
||||
session.query.return_value = chain
|
||||
return session, chain
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Schema tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_find_users_request_rejects_empty_query():
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query="")
|
||||
|
||||
|
||||
def test_find_users_request_rejects_extra_fields():
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query="maxime", random_field="x")
|
||||
|
||||
|
||||
def test_find_users_response_default_truncated_false():
|
||||
resp = FindUsersResponse(users=[], count=0)
|
||||
assert resp.truncated is False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tool-level tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_returns_matches(mcp_server):
|
||||
rows = [
|
||||
_make_user(
|
||||
7, "maxime", first="Maxime", last="Beauchemin", email="m@example.com"
|
||||
)
|
||||
]
|
||||
session, _ = _patch_user_query(rows)
|
||||
|
||||
with (
|
||||
patch.object(find_users_module, "db") as mock_db,
|
||||
patch.object(find_users_module, "security_manager") as mock_sm,
|
||||
patch.object(find_users_module, "or_") as mock_or,
|
||||
):
|
||||
mock_db.session = session
|
||||
mock_sm.user_model = MagicMock()
|
||||
mock_or.return_value = MagicMock()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"find_users", {"request": {"query": "maxime"}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 1
|
||||
assert data["truncated"] is False
|
||||
assert data["users"][0]["id"] == 7
|
||||
assert data["users"][0]["username"] == "maxime"
|
||||
assert data["users"][0]["first_name"] == "Maxime"
|
||||
assert data["users"][0]["last_name"] == "Beauchemin"
|
||||
# Privacy: minimal projection excludes identity attributes that aren't
|
||||
# required for filter resolution. Catch regressions on the response shape.
|
||||
for forbidden in ("email", "active", "roles"):
|
||||
assert forbidden not in data["users"][0]
|
||||
# or_ should have been built across the four matched columns
|
||||
assert mock_or.called
|
||||
assert len(mock_or.call_args.args) == 4
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_truncates_when_more_rows_than_page_size(mcp_server):
|
||||
# page_size=2 with 3 returned rows -> truncated, response trimmed to 2
|
||||
rows = [
|
||||
_make_user(1, "a"),
|
||||
_make_user(2, "b"),
|
||||
_make_user(3, "c"),
|
||||
]
|
||||
session, chain = _patch_user_query(rows)
|
||||
|
||||
with (
|
||||
patch.object(find_users_module, "db") as mock_db,
|
||||
patch.object(find_users_module, "security_manager") as mock_sm,
|
||||
patch.object(find_users_module, "or_") as mock_or,
|
||||
):
|
||||
mock_db.session = session
|
||||
mock_sm.user_model = MagicMock()
|
||||
mock_or.return_value = MagicMock()
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"find_users", {"request": {"query": "a", "page_size": 2}}
|
||||
)
|
||||
|
||||
# Tool requested page_size+1 rows for truncation detection
|
||||
chain.limit.assert_called_with(3)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 2
|
||||
assert data["truncated"] is True
|
||||
assert [u["id"] for u in data["users"]] == [1, 2]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_find_users_rejects_empty_query_via_client(mcp_server):
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError):
|
||||
await client.call_tool("find_users", {"request": {"query": ""}})
|
||||
|
||||
|
||||
@pytest.mark.parametrize("blank", [" ", " ", "\t", "\n \t"])
|
||||
def test_find_users_request_rejects_whitespace_only_query(blank):
|
||||
# Whitespace-only queries would strip to "" and produce a LIKE "%%" pattern
|
||||
# that enumerates the entire user directory. The validator must reject them.
|
||||
with pytest.raises(ValidationError):
|
||||
FindUsersRequest(query=blank)
|
||||
|
||||
|
||||
def test_find_users_request_strips_query_whitespace():
|
||||
# Validator should normalize the stored query so downstream LIKE patterns
|
||||
# don't carry leading/trailing whitespace.
|
||||
request = FindUsersRequest(query=" maxime ")
|
||||
assert request.query == "maxime"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Filter contract: created_by_fk / changed_by_fk filtering on list tools
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_dashboards_passes_created_by_fk_filter_to_dao(
|
||||
mock_list, mcp_server
|
||||
):
|
||||
"""list_dashboards should accept created_by_fk filter and forward it."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"list_dashboards",
|
||||
{
|
||||
"request": {
|
||||
"filters": [{"col": "created_by_fk", "opr": "eq", "value": 7}],
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_list.called
|
||||
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
|
||||
assert forwarded_filters is not None
|
||||
assert any(
|
||||
getattr(f, "col", None) == "created_by_fk" and getattr(f, "value", None) == 7
|
||||
for f in forwarded_filters
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_charts_passes_changed_by_fk_filter_to_dao(mock_list, mcp_server):
|
||||
"""list_charts should accept changed_by_fk filter and forward it."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool(
|
||||
"list_charts",
|
||||
{
|
||||
"request": {
|
||||
"filters": [{"col": "changed_by_fk", "opr": "eq", "value": 7}],
|
||||
"page": 1,
|
||||
"page_size": 10,
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
assert mock_list.called
|
||||
forwarded_filters = mock_list.call_args.kwargs.get("column_operators")
|
||||
assert forwarded_filters is not None
|
||||
assert any(getattr(f, "col", None) == "changed_by_fk" for f in forwarded_filters)
|
||||
@@ -22,6 +22,8 @@ from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.client.client import CallToolResult
|
||||
from mcp.types import TextContent
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
@@ -45,6 +47,14 @@ get_schema_module = importlib.import_module(
|
||||
"superset.mcp_service.system.tool.get_schema"
|
||||
)
|
||||
|
||||
|
||||
def _result_text(result: CallToolResult) -> str:
|
||||
"""Return the text payload from the first content block of a tool result."""
|
||||
block = result.content[0]
|
||||
assert isinstance(block, TextContent)
|
||||
return block.text
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -198,7 +208,7 @@ async def test_get_schema_returns_structured_privacy_error_for_dataset(mcp_serve
|
||||
{"request": {"model_type": "dataset"}},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert data["error_type"] == DATA_MODEL_METADATA_ERROR_TYPE
|
||||
assert data["privacy_scope"] == "data_model"
|
||||
|
||||
@@ -241,7 +251,7 @@ async def test_get_schema_redacts_chart_data_model_fields(mcp_server):
|
||||
{"request": {"model_type": "chart"}},
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
schema_info = data["schema_info"]
|
||||
assert all(
|
||||
column["name"] not in CHART_DATA_MODEL_COLUMNS
|
||||
@@ -389,7 +399,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert "current_user" in data
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 5
|
||||
@@ -418,7 +428,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
assert data["current_user"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -444,7 +454,7 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_instance_info", {"request": {}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
data = json.loads(_result_text(result))
|
||||
cu = data["current_user"]
|
||||
assert cu["id"] == 99
|
||||
assert cu["username"] == "bot"
|
||||
@@ -460,28 +470,50 @@ class TestGetInstanceInfoCurrentUserViaMCP:
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_chart_filter_rejects_created_by_fk() -> None:
|
||||
"""Test that ChartFilter rejects user-directory columns."""
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="created_by_fk", opr="eq", value=42)
|
||||
def test_chart_filter_rejects_user_directory_columns_other_than_fk() -> None:
|
||||
"""ChartFilter still rejects user-directory columns that expose names."""
|
||||
for col in ("created_by_name", "owners", "changed_by"):
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter.model_validate({"col": col, "opr": "eq", "value": "anything"})
|
||||
|
||||
|
||||
def test_chart_filter_accepts_created_and_changed_by_fk() -> None:
|
||||
"""ChartFilter allows filtering by created_by_fk / changed_by_fk (user IDs)."""
|
||||
for col in ("created_by_fk", "changed_by_fk"):
|
||||
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": 42})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_chart_filter_rejects_invalid_column():
|
||||
"""Test that ChartFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
ChartFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
ChartFilter.model_validate(
|
||||
{"col": "nonexistent_column", "opr": "eq", "value": 42}
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_filter_rejects_created_by_fk():
|
||||
"""Test that DashboardFilter rejects user-directory columns."""
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="created_by_fk", opr="eq", value=42)
|
||||
def test_dashboard_filter_rejects_user_directory_columns_other_than_fk() -> None:
|
||||
"""DashboardFilter still rejects user-directory columns that expose names."""
|
||||
for col in ("created_by_name", "owners", "changed_by"):
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter.model_validate(
|
||||
{"col": col, "opr": "eq", "value": "anything"}
|
||||
)
|
||||
|
||||
|
||||
def test_dashboard_filter_accepts_created_and_changed_by_fk() -> None:
|
||||
"""DashboardFilter allows filtering by created_by_fk / changed_by_fk."""
|
||||
for col in ("created_by_fk", "changed_by_fk"):
|
||||
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": 42})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_dashboard_filter_rejects_invalid_column():
|
||||
"""Test that DashboardFilter rejects invalid column names."""
|
||||
with pytest.raises(ValidationError):
|
||||
DashboardFilter(col="nonexistent_column", opr="eq", value=42)
|
||||
DashboardFilter.model_validate(
|
||||
{"col": "nonexistent_column", "opr": "eq", "value": 42}
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -492,12 +524,12 @@ def test_dashboard_filter_rejects_invalid_column():
|
||||
def test_chart_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing chart filter columns are not broken."""
|
||||
for col in ("slice_name", "viz_type", "datasource_name"):
|
||||
f = ChartFilter(col=col, opr="eq", value="test")
|
||||
f = ChartFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
|
||||
assert f.col == col
|
||||
|
||||
|
||||
def test_dashboard_filter_existing_columns_still_work():
|
||||
"""Test that pre-existing dashboard filter columns are not broken."""
|
||||
for col in ("dashboard_title", "published", "favorite"):
|
||||
f = DashboardFilter(col=col, opr="eq", value="test")
|
||||
f = DashboardFilter.model_validate({"col": col, "opr": "eq", "value": "test"})
|
||||
assert f.col == col
|
||||
|
||||
@@ -326,11 +326,19 @@ class TestGetSchemaToolViaClient:
|
||||
async def test_get_schema_omits_user_directory_columns(
|
||||
self, mock_filters, mcp_server
|
||||
):
|
||||
"""Test that schema discovery does not advertise user/access fields."""
|
||||
"""Test that schema discovery does not advertise user/access fields.
|
||||
|
||||
created_by_fk and changed_by_fk are intentionally allowed in
|
||||
filter_columns so callers can filter by user ID resolved via find_users,
|
||||
but they remain hidden from select_columns and sortable_columns so the
|
||||
directory itself is never exposed.
|
||||
"""
|
||||
mock_filters.return_value = {
|
||||
"dashboard_title": ["eq", "ilike"],
|
||||
"owner": ["rel_m_m"],
|
||||
"published": ["eq"],
|
||||
"created_by_fk": ["eq", "in"],
|
||||
"changed_by_fk": ["eq", "in"],
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
@@ -352,9 +360,16 @@ class TestGetSchemaToolViaClient:
|
||||
"owner",
|
||||
):
|
||||
assert field not in select_column_names
|
||||
assert field not in info["filter_columns"]
|
||||
assert field not in info["sortable_columns"]
|
||||
|
||||
# User-name and relationship fields stay out of filter_columns
|
||||
for field in ("owners", "roles", "created_by", "changed_by", "owner"):
|
||||
assert field not in info["filter_columns"]
|
||||
|
||||
# ID-only filter columns are advertised so callers can filter via find_users
|
||||
assert "created_by_fk" in info["filter_columns"]
|
||||
assert "changed_by_fk" in info["filter_columns"]
|
||||
|
||||
|
||||
class TestGetSchemaEdgeCases:
|
||||
"""Test edge cases for get_schema tool."""
|
||||
|
||||
Reference in New Issue
Block a user