Compare commits

...

2 Commits

Author SHA1 Message Date
Amin Ghadersohi
ba361ffe88 fix: add ct (contains) operator to ColumnOperatorEnum for TagFilter support
TagFilter documents 'ct' as a valid operator but ColumnOperatorEnum was
missing it, causing ValidationError when constructing TagFilter with opr="ct".
Added ct = "ct" with ilike("%val%") mapping, updated TYPE_OPERATOR_MAP for
string columns, and updated base_dao_test to cover the new operator.
2026-05-21 21:14:40 +00:00
Amin Ghadersohi
62b33b7421 feat(mcp): add list and get tools for tags
Adds list_tags and get_tag_info MCP tools in a new superset/mcp_service/tag/
domain, following the existing database/dataset/dashboard patterns.
Registers both tools in app.py and covers them with unit tests.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 21:14:40 +00:00
11 changed files with 835 additions and 2 deletions

View File

@@ -60,6 +60,7 @@ class ColumnOperatorEnum(str, Enum):
ne = "ne"
sw = "sw"
ew = "ew"
ct = "ct"
in_ = "in"
nin = "nin"
gt = "gt"
@@ -84,11 +85,12 @@ operator_map: Dict[ColumnOperatorEnum, Any] = {
ColumnOperatorEnum.ne: lambda col, val: col != val,
ColumnOperatorEnum.sw: lambda col, val: col.like(f"{val}%"),
ColumnOperatorEnum.ew: lambda col, val: col.like(f"%{val}"),
ColumnOperatorEnum.ct: lambda col, val: col.ilike(f"%{val}%"),
ColumnOperatorEnum.in_: lambda col, val: col.in_(
val if isinstance(val, (list, tuple)) else [val]
),
ColumnOperatorEnum.nin: lambda col, val: ~col.in_(
val if isinstance(val, (list, tuple)) else [val]
ColumnOperatorEnum.nin: lambda col, val: (
~col.in_(val if isinstance(val, (list, tuple)) else [val])
),
ColumnOperatorEnum.gt: lambda col, val: col > val,
ColumnOperatorEnum.gte: lambda col, val: col >= val,
@@ -107,6 +109,7 @@ TYPE_OPERATOR_MAP = {
ColumnOperatorEnum.ne,
ColumnOperatorEnum.sw,
ColumnOperatorEnum.ew,
ColumnOperatorEnum.ct,
ColumnOperatorEnum.in_,
ColumnOperatorEnum.nin,
ColumnOperatorEnum.like,

View File

@@ -652,6 +652,10 @@ from superset.mcp_service.system.tool import ( # noqa: F401, E402
get_schema,
health_check,
)
from superset.mcp_service.tag.tool import ( # noqa: F401, E402
get_tag_info,
list_tags,
)
def _remove_disabled_tools(disabled_tools: set[str]) -> None:

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,238 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for tag-related responses
"""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
import humanize
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.system.schemas import (
PaginationInfo,
TagInfo as BaseTagInfo,
)
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
class TagFilter(ColumnOperator):
"""
Filter object for tag listing.
col: The column to filter on. Must be one of the allowed filter fields.
opr: The operator to use. Must be one of the supported operators.
value: The value to filter by (type depends on col and opr).
"""
col: Literal["name", "type"] = Field(
...,
description="Column to filter on. Supported: 'name' (string match), "
"'type' (tag type: custom, type, owner, favorited_by).",
)
opr: ColumnOperatorEnum = Field(
...,
description="Operator to use. Common operators: 'eq' (equals), "
"'ct' (contains), 'sw' (starts with), 'ew' (ends with).",
)
value: str | int | float | bool | List[str | int | float | bool] = Field(
..., description="Value to filter by (type depends on col and opr)"
)
class TagInfo(BaseTagInfo):
"""Extends the shared BaseTagInfo with audit timestamps for MCP list/get tools."""
changed_on: str | datetime | None = Field(
None, description="Last modification timestamp"
)
changed_on_humanized: str | None = Field(
None, description="Humanized modification time"
)
model_config = ConfigDict(
from_attributes=True,
ser_json_timedelta="iso8601",
populate_by_name=True,
)
@model_serializer(mode="wrap")
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
"""Filter serialized fields to those requested via select_columns context."""
data: Dict[str, Any] = serializer(self)
if info.context and isinstance(info.context, dict):
select_columns = info.context.get("select_columns")
if select_columns:
requested_fields = set(select_columns)
return {k: v for k, v in data.items() if k in requested_fields}
return data
class TagList(BaseModel):
tags: List[TagInfo]
count: int
total_count: int
page: int
page_size: int
total_pages: int
has_previous: bool
has_next: bool
columns_requested: List[str] = Field(default_factory=list)
columns_loaded: List[str] = Field(default_factory=list)
columns_available: List[str] = Field(default_factory=list)
sortable_columns: List[str] = Field(default_factory=list)
filters_applied: List[TagFilter] = Field(default_factory=list)
pagination: PaginationInfo | None = None
timestamp: datetime | None = None
model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListTagsRequest(BaseModel):
"""Request schema for list_tags."""
filters: Annotated[
List[TagFilter],
Field(
default_factory=list,
description="List of filter objects (column, operator, value). Each "
"filter has 'col', 'opr', and 'value' properties. Cannot be used "
"together with 'search'.",
),
]
select_columns: Annotated[
List[str],
Field(
default_factory=list,
description="List of columns to select. Defaults to common columns if not "
"specified.",
),
]
search: Annotated[
str | None,
Field(
default=None,
description="Text search string to match against tag name. Cannot be used "
"together with 'filters'.",
),
]
order_column: Annotated[
str | None, Field(default=None, description="Column to order results by")
]
order_direction: Annotated[
Literal["asc", "desc"],
Field(
default="desc",
description="Direction to order results ('asc' or 'desc')",
),
]
page: Annotated[
PositiveInt,
Field(default=1, description="Page number for pagination (1-based)"),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> List[TagFilter]:
return parse_json_or_model_list(v, TagFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v: Any) -> List[str]:
return parse_json_or_list(v, "select_columns")
@model_validator(mode="after")
def validate_search_and_filters(self) -> "ListTagsRequest":
if self.search and self.filters:
raise ValueError(
"Cannot use both 'search' and 'filters' parameters simultaneously. "
"Use either 'search' for text-based searching or 'filters' for "
"precise column-based filtering, but not both."
)
return self
class TagError(BaseModel):
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Type of error")
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
@classmethod
def create(cls, error: str, error_type: str) -> "TagError":
from datetime import timezone
return cls(
error=error, error_type=error_type, timestamp=datetime.now(timezone.utc)
)
class GetTagInfoRequest(BaseModel):
"""Request schema for get_tag_info with numeric ID."""
identifier: Annotated[
int,
Field(description="Tag identifier — numeric ID"),
]
def _humanize_timestamp(dt: datetime | None) -> str | None:
if dt is None:
return None
now = datetime.now(dt.tzinfo) if dt.tzinfo else datetime.now()
return humanize.naturaltime(now - dt)
def serialize_tag_object(tag: Any) -> TagInfo | None:
if not tag:
return None
type_str: str | None = None
if (raw_type := getattr(tag, "type", None)) is not None:
type_str = raw_type.name if hasattr(raw_type, "name") else str(raw_type)
return TagInfo(
id=getattr(tag, "id", None),
name=getattr(tag, "name", None),
type=type_str,
description=getattr(tag, "description", None),
changed_on=getattr(tag, "changed_on", None),
changed_on_humanized=_humanize_timestamp(getattr(tag, "changed_on", None)),
)

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_tag_info import get_tag_info
from .list_tags import list_tags
__all__ = [
"get_tag_info",
"list_tags",
]

View File

@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get tag info FastMCP tool
This module contains the FastMCP tool for getting detailed information
about a specific tag.
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.tag.schemas import (
GetTagInfoRequest,
serialize_tag_object,
TagError,
TagInfo,
)
logger = logging.getLogger(__name__)
@tool(
tags=["discovery"],
class_permission_name="Tag",
annotations=ToolAnnotations(
title="Get tag info",
readOnlyHint=True,
destructiveHint=False,
),
)
async def get_tag_info(request: GetTagInfoRequest, ctx: Context) -> TagInfo | TagError:
"""Get tag metadata by numeric ID.
Returns tag details including name, type, and description.
Tag types: custom (user-created), type (implicit by object type),
owner (implicit by ownership), favorited_by (implicit by favorites).
To find a tag ID, use the list_tags tool first.
Example usage:
```json
{
"identifier": 1
}
```
"""
await ctx.info("Retrieving tag information: identifier=%s" % (request.identifier,))
try:
from superset.daos.tag import TagDAO
with event_logger.log_context(action="mcp.get_tag_info.lookup"):
get_tool = ModelGetInfoCore(
dao_class=TagDAO,
output_schema=TagInfo,
error_schema=TagError,
serializer=serialize_tag_object,
supports_slug=False,
logger=logger,
)
result = get_tool.run_tool(request.identifier)
if isinstance(result, TagInfo):
await ctx.info(
"Tag information retrieved successfully: tag_id=%s, name=%s, type=%s"
% (result.id, result.name, result.type)
)
else:
await ctx.warning(
"Tag retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result
except Exception as e:
await ctx.error(
"Tag information retrieval failed: identifier=%s, error=%s, error_type=%s"
% (request.identifier, str(e), type(e).__name__)
)
return TagError(
error=f"Failed to get tag info: {str(e)}",
error_type="InternalError",
timestamp=datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,148 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List tags FastMCP tool
This module contains the FastMCP tool for listing tags with filtering,
search, and pagination support.
"""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.tag.schemas import (
ListTagsRequest,
serialize_tag_object,
TagError,
TagFilter,
TagInfo,
TagList,
)
logger = logging.getLogger(__name__)
DEFAULT_TAG_COLUMNS = ["id", "name", "type"]
SORTABLE_TAG_COLUMNS = ["id", "name", "changed_on"]
ALL_TAG_COLUMNS = ["id", "name", "type", "description", "changed_on"]
_DEFAULT_LIST_TAGS_REQUEST = ListTagsRequest()
@tool(
tags=["core"],
class_permission_name="Tag",
annotations=ToolAnnotations(
title="List tags",
readOnlyHint=True,
destructiveHint=False,
),
)
async def list_tags(
request: ListTagsRequest | None = None,
ctx: Context | None = None,
) -> TagList | TagError:
"""List tags with filtering and search.
Returns tag metadata including name, type, and description.
Tag types: custom (user-created), type (implicit by object type),
owner (implicit by ownership), favorited_by (implicit by favorites).
Sortable columns for order_column: id, name, changed_on
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_tags")
request = request or _DEFAULT_LIST_TAGS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing tags: page=%s, page_size=%s, search=%s"
% (request.page, request.page_size, request.search)
)
await ctx.debug(
"Tag listing parameters: filters=%s, order_column=%s, "
"order_direction=%s, select_columns=%s"
% (
request.filters,
request.order_column,
request.order_direction,
request.select_columns,
)
)
try:
from superset.daos.tag import TagDAO
def _serialize_tag(obj: object, cols: list[str] | None) -> TagInfo | None:
return serialize_tag_object(obj)
list_tool = ModelListCore(
dao_class=TagDAO,
output_schema=TagInfo,
item_serializer=_serialize_tag,
filter_type=TagFilter,
default_columns=DEFAULT_TAG_COLUMNS,
search_columns=["name"],
list_field_name="tags",
output_list_schema=TagList,
all_columns=ALL_TAG_COLUMNS,
sortable_columns=SORTABLE_TAG_COLUMNS,
logger=logger,
)
with event_logger.log_context(action="mcp.list_tags.query"):
result = list_tool.run_tool(
filters=request.filters,
search=request.search,
select_columns=request.select_columns,
order_column=request.order_column,
order_direction=request.order_direction,
page=max(request.page - 1, 0),
page_size=request.page_size,
)
await ctx.info(
"Tags listed successfully: count=%s, total_count=%s, total_pages=%s"
% (
len(result.tags) if hasattr(result, "tags") else 0,
getattr(result, "total_count", None),
getattr(result, "total_pages", None),
)
)
columns_to_filter = result.columns_requested
await ctx.debug(
"Applying field filtering via serialization context: columns=%s"
% (columns_to_filter,)
)
with event_logger.log_context(action="mcp.list_tags.serialization"):
return result.model_dump(
mode="json",
context={"select_columns": columns_to_filter},
)
except Exception as e:
await ctx.error(
"Tag listing failed: page=%s, page_size=%s, error=%s, error_type=%s"
% (request.page, request.page_size, str(e), type(e).__name__)
)
raise

View File

@@ -73,6 +73,12 @@ def test_column_operator_enum_apply_method() -> None: # noqa: C901
(ColumnOperatorEnum.ne, TestModel.name, "test", "test_model.name != 'test'"),
(ColumnOperatorEnum.sw, TestModel.name, "test", "test_model.name LIKE 'test%'"),
(ColumnOperatorEnum.ew, TestModel.name, "test", "test_model.name LIKE '%test'"),
(
ColumnOperatorEnum.ct,
TestModel.name,
"test",
"lower(test_model.name) LIKE lower('%test%')",
),
(ColumnOperatorEnum.in_, TestModel.id, [1, 2, 3], "test_model.id IN (1, 2, 3)"),
(
ColumnOperatorEnum.nin,
@@ -130,6 +136,7 @@ def test_column_operator_enum_apply_method() -> None: # noqa: C901
ColumnOperatorEnum.ne,
ColumnOperatorEnum.sw,
ColumnOperatorEnum.ew,
ColumnOperatorEnum.ct,
ColumnOperatorEnum.in_,
ColumnOperatorEnum.nin,
ColumnOperatorEnum.gt,

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,253 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.tag.schemas import ListTagsRequest, TagFilter
from superset.utils import json
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
class TestTagFilterSchema:
"""Tests for TagFilter schema — filterable columns."""
def test_invalid_filter_column_rejected(self):
with pytest.raises(ValidationError):
TagFilter(col="not_a_real_column", opr="eq", value="x")
def test_valid_name_filter(self):
f = TagFilter(col="name", opr="ct", value="finance")
assert f.col == "name"
def test_valid_type_filter(self):
f = TagFilter(col="type", opr="eq", value="custom")
assert f.col == "type"
def create_mock_tag(
tag_id: int = 1,
name: str = "finance",
type_name: str = "custom",
description: str = "Finance related",
) -> MagicMock:
tag = MagicMock()
tag.id = tag_id
tag.name = name
mock_type = MagicMock()
mock_type.name = type_name
tag.type = mock_type
tag.description = description
tag.changed_on = None
tag.created_on = None
tag.changed_by = None
tag.created_by = None
return tag
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
from unittest.mock import Mock
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_basic(mock_list, mcp_server):
"""Test basic tag listing functionality."""
tag = create_mock_tag()
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
request = ListTagsRequest(page=1, page_size=10)
result = await client.call_tool("list_tags", {"request": request.model_dump()})
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["tags"] is not None
assert len(data["tags"]) == 1
assert data["tags"][0]["id"] == 1
assert data["tags"][0]["name"] == "finance"
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_without_request(mock_list, mcp_server):
"""Test listing tags with no request payload uses defaults."""
tag = create_mock_tag()
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_tags", {})
data = json.loads(result.content[0].text)
assert data["tags"] is not None
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_with_search(mock_list, mcp_server):
"""Test tag listing with search functionality."""
tag = create_mock_tag(name="sales")
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
request = ListTagsRequest(page=1, page_size=10, search="sales")
result = await client.call_tool("list_tags", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert data["tags"][0]["name"] == "sales"
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_with_filters(mock_list, mcp_server):
"""Test tag listing with column filters."""
tag = create_mock_tag(type_name="custom")
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
request = ListTagsRequest(
page=1,
page_size=10,
filters=[{"col": "type", "opr": "eq", "value": "custom"}],
)
result = await client.call_tool("list_tags", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert len(data["tags"]) == 1
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_empty_results(mock_list, mcp_server):
"""Test tag listing with no results."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
request = ListTagsRequest(page=1, page_size=10)
result = await client.call_tool("list_tags", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert data["tags"] == []
assert data["total_count"] == 0
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_api_error(mock_list, mcp_server):
"""Test error handling when DAO raises an exception."""
mock_list.side_effect = ToolError("Tag DAO error")
async with Client(mcp_server) as client:
request = ListTagsRequest(page=1, page_size=10)
with pytest.raises(ToolError) as excinfo: # noqa: PT012
await client.call_tool("list_tags", {"request": request.model_dump()})
assert "Tag DAO error" in str(excinfo.value)
def test_list_tags_search_and_filters_mutually_exclusive():
"""Test that search and filters cannot be used together."""
with pytest.raises(ValidationError):
ListTagsRequest(
search="finance",
filters=[{"col": "name", "opr": "eq", "value": "finance"}],
)
@patch("superset.daos.tag.TagDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_tag_info_basic(mock_find, mcp_server):
"""Test basic get tag info functionality."""
tag = create_mock_tag()
mock_find.return_value = tag
async with Client(mcp_server) as client:
result = await client.call_tool("get_tag_info", {"request": {"identifier": 1}})
assert result.content is not None
data = json.loads(result.content[0].text)
assert data["id"] == 1
assert data["name"] == "finance"
assert data["type"] == "custom"
assert data["description"] == "Finance related"
@patch("superset.daos.tag.TagDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_tag_info_not_found(mock_find, mcp_server):
"""Test get tag info when tag does not exist."""
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_tag_info", {"request": {"identifier": 999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"
@patch("superset.daos.tag.TagDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_tag_info_serializes_type_name(mock_find, mcp_server):
"""Test that the tag type enum is serialized as its name string."""
tag = create_mock_tag(type_name="owner")
mock_find.return_value = tag
async with Client(mcp_server) as client:
result = await client.call_tool("get_tag_info", {"request": {"identifier": 2}})
data = json.loads(result.content[0].text)
assert data["type"] == "owner"
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_select_columns_filters_response(mock_list, mcp_server):
"""select_columns restricts the fields returned in each tag object."""
tag = create_mock_tag()
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
request = ListTagsRequest(page=1, page_size=10, select_columns=["id", "name"])
result = await client.call_tool("list_tags", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert data["columns_requested"] == ["id", "name"]
tag_obj = data["tags"][0]
assert set(tag_obj.keys()) == {"id", "name"}
assert "type" not in tag_obj
assert "description" not in tag_obj
assert "changed_on" not in tag_obj
@patch("superset.daos.tag.TagDAO.list")
@pytest.mark.asyncio
async def test_list_tags_default_columns_are_id_name_type(mock_list, mcp_server):
"""Default response includes id, name, type but not description or timestamps."""
tag = create_mock_tag()
mock_list.return_value = ([tag], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_tags", {})
data = json.loads(result.content[0].text)
tag_obj = data["tags"][0]
assert "id" in tag_obj
assert "name" in tag_obj
assert "type" in tag_obj
assert "description" not in tag_obj
assert "changed_on" not in tag_obj