mirror of
https://github.com/apache/superset.git
synced 2026-06-12 11:09:15 +00:00
Compare commits
2 Commits
fix/chart-
...
mcp-tags-9
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ba361ffe88 | ||
|
|
62b33b7421 |
@@ -60,6 +60,7 @@ class ColumnOperatorEnum(str, Enum):
|
||||
ne = "ne"
|
||||
sw = "sw"
|
||||
ew = "ew"
|
||||
ct = "ct"
|
||||
in_ = "in"
|
||||
nin = "nin"
|
||||
gt = "gt"
|
||||
@@ -84,11 +85,12 @@ operator_map: Dict[ColumnOperatorEnum, Any] = {
|
||||
ColumnOperatorEnum.ne: lambda col, val: col != val,
|
||||
ColumnOperatorEnum.sw: lambda col, val: col.like(f"{val}%"),
|
||||
ColumnOperatorEnum.ew: lambda col, val: col.like(f"%{val}"),
|
||||
ColumnOperatorEnum.ct: lambda col, val: col.ilike(f"%{val}%"),
|
||||
ColumnOperatorEnum.in_: lambda col, val: col.in_(
|
||||
val if isinstance(val, (list, tuple)) else [val]
|
||||
),
|
||||
ColumnOperatorEnum.nin: lambda col, val: ~col.in_(
|
||||
val if isinstance(val, (list, tuple)) else [val]
|
||||
ColumnOperatorEnum.nin: lambda col, val: (
|
||||
~col.in_(val if isinstance(val, (list, tuple)) else [val])
|
||||
),
|
||||
ColumnOperatorEnum.gt: lambda col, val: col > val,
|
||||
ColumnOperatorEnum.gte: lambda col, val: col >= val,
|
||||
@@ -107,6 +109,7 @@ TYPE_OPERATOR_MAP = {
|
||||
ColumnOperatorEnum.ne,
|
||||
ColumnOperatorEnum.sw,
|
||||
ColumnOperatorEnum.ew,
|
||||
ColumnOperatorEnum.ct,
|
||||
ColumnOperatorEnum.in_,
|
||||
ColumnOperatorEnum.nin,
|
||||
ColumnOperatorEnum.like,
|
||||
|
||||
@@ -652,6 +652,10 @@ from superset.mcp_service.system.tool import ( # noqa: F401, E402
|
||||
get_schema,
|
||||
health_check,
|
||||
)
|
||||
from superset.mcp_service.tag.tool import ( # noqa: F401, E402
|
||||
get_tag_info,
|
||||
list_tags,
|
||||
)
|
||||
|
||||
|
||||
def _remove_disabled_tools(disabled_tools: set[str]) -> None:
|
||||
|
||||
16
superset/mcp_service/tag/__init__.py
Normal file
16
superset/mcp_service/tag/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
238
superset/mcp_service/tag/schemas.py
Normal file
238
superset/mcp_service/tag/schemas.py
Normal file
@@ -0,0 +1,238 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for tag-related responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
import humanize
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import (
|
||||
PaginationInfo,
|
||||
TagInfo as BaseTagInfo,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
|
||||
class TagFilter(ColumnOperator):
|
||||
"""
|
||||
Filter object for tag listing.
|
||||
col: The column to filter on. Must be one of the allowed filter fields.
|
||||
opr: The operator to use. Must be one of the supported operators.
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal["name", "type"] = Field(
|
||||
...,
|
||||
description="Column to filter on. Supported: 'name' (string match), "
|
||||
"'type' (tag type: custom, type, owner, favorited_by).",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
description="Operator to use. Common operators: 'eq' (equals), "
|
||||
"'ct' (contains), 'sw' (starts with), 'ew' (ends with).",
|
||||
)
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by (type depends on col and opr)"
|
||||
)
|
||||
|
||||
|
||||
class TagInfo(BaseTagInfo):
|
||||
"""Extends the shared BaseTagInfo with audit timestamps for MCP list/get tools."""
|
||||
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
changed_on_humanized: str | None = Field(
|
||||
None, description="Humanized modification time"
|
||||
)
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
"""Filter serialized fields to those requested via select_columns context."""
|
||||
data: Dict[str, Any] = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class TagList(BaseModel):
|
||||
tags: List[TagInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(default_factory=list)
|
||||
columns_loaded: List[str] = Field(default_factory=list)
|
||||
columns_available: List[str] = Field(default_factory=list)
|
||||
sortable_columns: List[str] = Field(default_factory=list)
|
||||
filters_applied: List[TagFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListTagsRequest(BaseModel):
|
||||
"""Request schema for list_tags."""
|
||||
|
||||
filters: Annotated[
|
||||
List[TagFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (column, operator, value). Each "
|
||||
"filter has 'col', 'opr', and 'value' properties. Cannot be used "
|
||||
"together with 'search'.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of columns to select. Defaults to common columns if not "
|
||||
"specified.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search string to match against tag name. Cannot be used "
|
||||
"together with 'filters'.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(
|
||||
default="desc",
|
||||
description="Direction to order results ('asc' or 'desc')",
|
||||
),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number for pagination (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[TagFilter]:
|
||||
return parse_json_or_model_list(v, TagFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListTagsRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError(
|
||||
"Cannot use both 'search' and 'filters' parameters simultaneously. "
|
||||
"Use either 'search' for text-based searching or 'filters' for "
|
||||
"precise column-based filtering, but not both."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class TagError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "TagError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
|
||||
|
||||
class GetTagInfoRequest(BaseModel):
|
||||
"""Request schema for get_tag_info with numeric ID."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="Tag identifier — numeric ID"),
|
||||
]
|
||||
|
||||
|
||||
def _humanize_timestamp(dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
return None
|
||||
now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now()
|
||||
return humanize.naturaltime(now - dt)
|
||||
|
||||
|
||||
def serialize_tag_object(tag: Any) -> TagInfo | None:
|
||||
if not tag:
|
||||
return None
|
||||
|
||||
type_str: str | None = None
|
||||
if (raw_type := getattr(tag, "type", None)) is not None:
|
||||
type_str = raw_type.name if hasattr(raw_type, "name") else str(raw_type)
|
||||
|
||||
return TagInfo(
|
||||
id=getattr(tag, "id", None),
|
||||
name=getattr(tag, "name", None),
|
||||
type=type_str,
|
||||
description=getattr(tag, "description", None),
|
||||
changed_on=getattr(tag, "changed_on", None),
|
||||
changed_on_humanized=_humanize_timestamp(getattr(tag, "changed_on", None)),
|
||||
)
|
||||
24
superset/mcp_service/tag/tool/__init__.py
Normal file
24
superset/mcp_service/tag/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_tag_info import get_tag_info
|
||||
from .list_tags import list_tags
|
||||
|
||||
__all__ = [
|
||||
"get_tag_info",
|
||||
"list_tags",
|
||||
]
|
||||
108
superset/mcp_service/tag/tool/get_tag_info.py
Normal file
108
superset/mcp_service/tag/tool/get_tag_info.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get tag info FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for getting detailed information
|
||||
about a specific tag.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.tag.schemas import (
|
||||
GetTagInfoRequest,
|
||||
serialize_tag_object,
|
||||
TagError,
|
||||
TagInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Tag",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get tag info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_tag_info(request: GetTagInfoRequest, ctx: Context) -> TagInfo | TagError:
|
||||
"""Get tag metadata by numeric ID.
|
||||
|
||||
Returns tag details including name, type, and description.
|
||||
|
||||
Tag types: custom (user-created), type (implicit by object type),
|
||||
owner (implicit by ownership), favorited_by (implicit by favorites).
|
||||
|
||||
To find a tag ID, use the list_tags tool first.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{
|
||||
"identifier": 1
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info("Retrieving tag information: identifier=%s" % (request.identifier,))
|
||||
|
||||
try:
|
||||
from superset.daos.tag import TagDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_tag_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=TagDAO,
|
||||
output_schema=TagInfo,
|
||||
error_schema=TagError,
|
||||
serializer=serialize_tag_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, TagInfo):
|
||||
await ctx.info(
|
||||
"Tag information retrieved successfully: tag_id=%s, name=%s, type=%s"
|
||||
% (result.id, result.name, result.type)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Tag retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Tag information retrieval failed: identifier=%s, error=%s, error_type=%s"
|
||||
% (request.identifier, str(e), type(e).__name__)
|
||||
)
|
||||
return TagError(
|
||||
error=f"Failed to get tag info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
148
superset/mcp_service/tag/tool/list_tags.py
Normal file
148
superset/mcp_service/tag/tool/list_tags.py
Normal file
@@ -0,0 +1,148 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List tags FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for listing tags with filtering,
|
||||
search, and pagination support.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.tag.schemas import (
|
||||
ListTagsRequest,
|
||||
serialize_tag_object,
|
||||
TagError,
|
||||
TagFilter,
|
||||
TagInfo,
|
||||
TagList,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_TAG_COLUMNS = ["id", "name", "type"]
|
||||
SORTABLE_TAG_COLUMNS = ["id", "name", "changed_on"]
|
||||
ALL_TAG_COLUMNS = ["id", "name", "type", "description", "changed_on"]
|
||||
|
||||
_DEFAULT_LIST_TAGS_REQUEST = ListTagsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Tag",
|
||||
annotations=ToolAnnotations(
|
||||
title="List tags",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_tags(
|
||||
request: ListTagsRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> TagList | TagError:
|
||||
"""List tags with filtering and search.
|
||||
|
||||
Returns tag metadata including name, type, and description.
|
||||
|
||||
Tag types: custom (user-created), type (implicit by object type),
|
||||
owner (implicit by ownership), favorited_by (implicit by favorites).
|
||||
|
||||
Sortable columns for order_column: id, name, changed_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_tags")
|
||||
|
||||
request = request or _DEFAULT_LIST_TAGS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing tags: page=%s, page_size=%s, search=%s"
|
||||
% (request.page, request.page_size, request.search)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Tag listing parameters: filters=%s, order_column=%s, "
|
||||
"order_direction=%s, select_columns=%s"
|
||||
% (
|
||||
request.filters,
|
||||
request.order_column,
|
||||
request.order_direction,
|
||||
request.select_columns,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.tag import TagDAO
|
||||
|
||||
def _serialize_tag(obj: object, cols: list[str] | None) -> TagInfo | None:
|
||||
return serialize_tag_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=TagDAO,
|
||||
output_schema=TagInfo,
|
||||
item_serializer=_serialize_tag,
|
||||
filter_type=TagFilter,
|
||||
default_columns=DEFAULT_TAG_COLUMNS,
|
||||
search_columns=["name"],
|
||||
list_field_name="tags",
|
||||
output_list_schema=TagList,
|
||||
all_columns=ALL_TAG_COLUMNS,
|
||||
sortable_columns=SORTABLE_TAG_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_tags.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Tags listed successfully: count=%s, total_count=%s, total_pages=%s"
|
||||
% (
|
||||
len(result.tags) if hasattr(result, "tags") else 0,
|
||||
getattr(result, "total_count", None),
|
||||
getattr(result, "total_pages", None),
|
||||
)
|
||||
)
|
||||
|
||||
columns_to_filter = result.columns_requested
|
||||
await ctx.debug(
|
||||
"Applying field filtering via serialization context: columns=%s"
|
||||
% (columns_to_filter,)
|
||||
)
|
||||
with event_logger.log_context(action="mcp.list_tags.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Tag listing failed: page=%s, page_size=%s, error=%s, error_type=%s"
|
||||
% (request.page, request.page_size, str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
@@ -73,6 +73,12 @@ def test_column_operator_enum_apply_method() -> None: # noqa: C901
|
||||
(ColumnOperatorEnum.ne, TestModel.name, "test", "test_model.name != 'test'"),
|
||||
(ColumnOperatorEnum.sw, TestModel.name, "test", "test_model.name LIKE 'test%'"),
|
||||
(ColumnOperatorEnum.ew, TestModel.name, "test", "test_model.name LIKE '%test'"),
|
||||
(
|
||||
ColumnOperatorEnum.ct,
|
||||
TestModel.name,
|
||||
"test",
|
||||
"lower(test_model.name) LIKE lower('%test%')",
|
||||
),
|
||||
(ColumnOperatorEnum.in_, TestModel.id, [1, 2, 3], "test_model.id IN (1, 2, 3)"),
|
||||
(
|
||||
ColumnOperatorEnum.nin,
|
||||
@@ -130,6 +136,7 @@ def test_column_operator_enum_apply_method() -> None: # noqa: C901
|
||||
ColumnOperatorEnum.ne,
|
||||
ColumnOperatorEnum.sw,
|
||||
ColumnOperatorEnum.ew,
|
||||
ColumnOperatorEnum.ct,
|
||||
ColumnOperatorEnum.in_,
|
||||
ColumnOperatorEnum.nin,
|
||||
ColumnOperatorEnum.gt,
|
||||
|
||||
16
tests/unit_tests/mcp_service/tag/__init__.py
Normal file
16
tests/unit_tests/mcp_service/tag/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/tag/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/tag/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
253
tests/unit_tests/mcp_service/tag/tool/test_tag_tools.py
Normal file
253
tests/unit_tests/mcp_service/tag/tool/test_tag_tools.py
Normal file
@@ -0,0 +1,253 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.tag.schemas import ListTagsRequest, TagFilter
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestTagFilterSchema:
|
||||
"""Tests for TagFilter schema — filterable columns."""
|
||||
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TagFilter(col="not_a_real_column", opr="eq", value="x")
|
||||
|
||||
def test_valid_name_filter(self):
|
||||
f = TagFilter(col="name", opr="ct", value="finance")
|
||||
assert f.col == "name"
|
||||
|
||||
def test_valid_type_filter(self):
|
||||
f = TagFilter(col="type", opr="eq", value="custom")
|
||||
assert f.col == "type"
|
||||
|
||||
|
||||
def create_mock_tag(
|
||||
tag_id: int = 1,
|
||||
name: str = "finance",
|
||||
type_name: str = "custom",
|
||||
description: str = "Finance related",
|
||||
) -> MagicMock:
|
||||
tag = MagicMock()
|
||||
tag.id = tag_id
|
||||
tag.name = name
|
||||
mock_type = MagicMock()
|
||||
mock_type.name = type_name
|
||||
tag.type = mock_type
|
||||
tag.description = description
|
||||
tag.changed_on = None
|
||||
tag.created_on = None
|
||||
tag.changed_by = None
|
||||
tag.created_by = None
|
||||
return tag
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
from unittest.mock import Mock
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_basic(mock_list, mcp_server):
|
||||
"""Test basic tag listing functionality."""
|
||||
tag = create_mock_tag()
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(page=1, page_size=10)
|
||||
result = await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tags"] is not None
|
||||
assert len(data["tags"]) == 1
|
||||
assert data["tags"][0]["id"] == 1
|
||||
assert data["tags"][0]["name"] == "finance"
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_without_request(mock_list, mcp_server):
|
||||
"""Test listing tags with no request payload uses defaults."""
|
||||
tag = create_mock_tag()
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_tags", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tags"] is not None
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_with_search(mock_list, mcp_server):
|
||||
"""Test tag listing with search functionality."""
|
||||
tag = create_mock_tag(name="sales")
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(page=1, page_size=10, search="sales")
|
||||
result = await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tags"][0]["name"] == "sales"
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_with_filters(mock_list, mcp_server):
|
||||
"""Test tag listing with column filters."""
|
||||
tag = create_mock_tag(type_name="custom")
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
filters=[{"col": "type", "opr": "eq", "value": "custom"}],
|
||||
)
|
||||
result = await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert len(data["tags"]) == 1
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_empty_results(mock_list, mcp_server):
|
||||
"""Test tag listing with no results."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(page=1, page_size=10)
|
||||
result = await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tags"] == []
|
||||
assert data["total_count"] == 0
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_api_error(mock_list, mcp_server):
|
||||
"""Test error handling when DAO raises an exception."""
|
||||
mock_list.side_effect = ToolError("Tag DAO error")
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(page=1, page_size=10)
|
||||
with pytest.raises(ToolError) as excinfo: # noqa: PT012
|
||||
await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
assert "Tag DAO error" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_list_tags_search_and_filters_mutually_exclusive():
|
||||
"""Test that search and filters cannot be used together."""
|
||||
with pytest.raises(ValidationError):
|
||||
ListTagsRequest(
|
||||
search="finance",
|
||||
filters=[{"col": "name", "opr": "eq", "value": "finance"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tag_info_basic(mock_find, mcp_server):
|
||||
"""Test basic get tag info functionality."""
|
||||
tag = create_mock_tag()
|
||||
mock_find.return_value = tag
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_tag_info", {"request": {"identifier": 1}})
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["name"] == "finance"
|
||||
assert data["type"] == "custom"
|
||||
assert data["description"] == "Finance related"
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tag_info_not_found(mock_find, mcp_server):
|
||||
"""Test get tag info when tag does not exist."""
|
||||
mock_find.return_value = None
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_tag_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_tag_info_serializes_type_name(mock_find, mcp_server):
|
||||
"""Test that the tag type enum is serialized as its name string."""
|
||||
tag = create_mock_tag(type_name="owner")
|
||||
mock_find.return_value = tag
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_tag_info", {"request": {"identifier": 2}})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["type"] == "owner"
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_select_columns_filters_response(mock_list, mcp_server):
|
||||
"""select_columns restricts the fields returned in each tag object."""
|
||||
tag = create_mock_tag()
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTagsRequest(page=1, page_size=10, select_columns=["id", "name"])
|
||||
result = await client.call_tool("list_tags", {"request": request.model_dump()})
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["columns_requested"] == ["id", "name"]
|
||||
tag_obj = data["tags"][0]
|
||||
assert set(tag_obj.keys()) == {"id", "name"}
|
||||
assert "type" not in tag_obj
|
||||
assert "description" not in tag_obj
|
||||
assert "changed_on" not in tag_obj
|
||||
|
||||
|
||||
@patch("superset.daos.tag.TagDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tags_default_columns_are_id_name_type(mock_list, mcp_server):
|
||||
"""Default response includes id, name, type but not description or timestamps."""
|
||||
tag = create_mock_tag()
|
||||
mock_list.return_value = ([tag], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_tags", {})
|
||||
data = json.loads(result.content[0].text)
|
||||
tag_obj = data["tags"][0]
|
||||
assert "id" in tag_obj
|
||||
assert "name" in tag_obj
|
||||
assert "type" in tag_obj
|
||||
assert "description" not in tag_obj
|
||||
assert "changed_on" not in tag_obj
|
||||
Reference in New Issue
Block a user