mirror of
https://github.com/apache/superset.git
synced 2026-06-11 18:49:15 +00:00
Compare commits
6 Commits
tanstack-r
...
amin/mcp-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b491bd167b | ||
|
|
40767a4397 | ||
|
|
d85f8e7cf2 | ||
|
|
56a844970e | ||
|
|
46ae254444 | ||
|
|
e0f01f3a23 |
@@ -365,3 +365,164 @@ def serialize_annotation(obj: Any) -> AnnotationInfo | None:
|
||||
layer_id=getattr(obj, "layer_id", None),
|
||||
)
|
||||
)
|
||||
class CreateLayerAnnotationRequest(BaseModel):
|
||||
"""Request schema for create_layer_annotation."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
layer_id: int = Field(
|
||||
...,
|
||||
description="ID of the annotation layer to add the annotation to.",
|
||||
)
|
||||
short_descr: str = Field(
|
||||
...,
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
description="Short description / title of the annotation. "
|
||||
"Must be unique within the annotation layer.",
|
||||
)
|
||||
start_dttm: datetime = Field(
|
||||
...,
|
||||
description="Annotation start time in ISO 8601 format "
|
||||
"(e.g. '2024-01-15T08:00:00').",
|
||||
)
|
||||
end_dttm: datetime = Field(
|
||||
...,
|
||||
description="Annotation end time in ISO 8601 format "
|
||||
"(e.g. '2024-01-15T09:00:00'). Must be >= start_dttm.",
|
||||
)
|
||||
long_descr: str | None = Field(
|
||||
None,
|
||||
description="Detailed description of the annotation (optional).",
|
||||
)
|
||||
json_metadata: str | None = Field(
|
||||
None,
|
||||
description="Optional JSON metadata string for the annotation.",
|
||||
)
|
||||
|
||||
@field_validator("json_metadata")
|
||||
@classmethod
|
||||
def validate_json_metadata(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
json_utils.loads(v)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError("json_metadata must be valid JSON") from exc
|
||||
return v
|
||||
|
||||
|
||||
class CreateLayerAnnotationResponse(BaseModel):
|
||||
"""Response schema for create_layer_annotation."""
|
||||
|
||||
id: int | None = Field(
|
||||
None,
|
||||
description="ID of the created annotation. None if creation failed.",
|
||||
)
|
||||
layer_id: int = Field(
|
||||
...,
|
||||
description="ID of the annotation layer the annotation was added to.",
|
||||
)
|
||||
short_descr: str = Field(
|
||||
...,
|
||||
description="Short description / title of the annotation.",
|
||||
)
|
||||
start_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="Annotation start time.",
|
||||
)
|
||||
end_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="Annotation end time.",
|
||||
)
|
||||
long_descr: str | None = Field(
|
||||
None,
|
||||
description="Detailed description of the annotation.",
|
||||
)
|
||||
error: str | None = Field(
|
||||
None,
|
||||
description="Error message if creation failed, otherwise null.",
|
||||
)
|
||||
|
||||
|
||||
class UpdateLayerAnnotationRequest(BaseModel):
|
||||
"""Request schema for update_layer_annotation."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
|
||||
layer_id: int = Field(
|
||||
...,
|
||||
description="ID of the annotation layer the annotation belongs to.",
|
||||
)
|
||||
annotation_id: int = Field(
|
||||
...,
|
||||
description="ID of the annotation to update.",
|
||||
)
|
||||
short_descr: str | None = Field(
|
||||
None,
|
||||
min_length=1,
|
||||
max_length=500,
|
||||
description="New short description / title. "
|
||||
"Must be unique within the annotation layer.",
|
||||
)
|
||||
start_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="New annotation start time in ISO 8601 format.",
|
||||
)
|
||||
end_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="New annotation end time in ISO 8601 format. "
|
||||
"Must be >= start_dttm.",
|
||||
)
|
||||
long_descr: str | None = Field(
|
||||
None,
|
||||
description="New detailed description (optional).",
|
||||
)
|
||||
json_metadata: str | None = Field(
|
||||
None,
|
||||
description="New JSON metadata string (optional).",
|
||||
)
|
||||
|
||||
@field_validator("json_metadata")
|
||||
@classmethod
|
||||
def validate_json_metadata(cls, v: str | None) -> str | None:
|
||||
if v is None:
|
||||
return v
|
||||
try:
|
||||
json.loads(v)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError("json_metadata must be valid JSON") from exc
|
||||
return v
|
||||
|
||||
|
||||
class UpdateLayerAnnotationResponse(BaseModel):
|
||||
"""Response schema for update_layer_annotation."""
|
||||
|
||||
id: int | None = Field(
|
||||
None,
|
||||
description="ID of the updated annotation. None if update failed.",
|
||||
)
|
||||
layer_id: int = Field(
|
||||
...,
|
||||
description="ID of the annotation layer.",
|
||||
)
|
||||
short_descr: str | None = Field(
|
||||
None,
|
||||
description="Short description / title of the annotation.",
|
||||
)
|
||||
start_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="Annotation start time.",
|
||||
)
|
||||
end_dttm: datetime | None = Field(
|
||||
None,
|
||||
description="Annotation end time.",
|
||||
)
|
||||
long_descr: str | None = Field(
|
||||
None,
|
||||
description="Detailed description of the annotation.",
|
||||
)
|
||||
error: str | None = Field(
|
||||
None,
|
||||
description="Error message if update failed, otherwise null.",
|
||||
)
|
||||
|
||||
@@ -15,14 +15,18 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .create_layer_annotation import create_layer_annotation
|
||||
from .get_annotation_layer_info import get_annotation_layer_info
|
||||
from .get_layer_annotation_info import get_layer_annotation_info
|
||||
from .list_annotation_layers import list_annotation_layers
|
||||
from .list_layer_annotations import list_layer_annotations
|
||||
from .update_layer_annotation import update_layer_annotation
|
||||
|
||||
__all__ = [
|
||||
"list_annotation_layers",
|
||||
"get_annotation_layer_info",
|
||||
"list_layer_annotations",
|
||||
"get_layer_annotation_info",
|
||||
"create_layer_annotation",
|
||||
"update_layer_annotation",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.annotation_layer.schemas import (
|
||||
CreateLayerAnnotationRequest,
|
||||
CreateLayerAnnotationResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["mutate"],
|
||||
class_permission_name="Annotation",
|
||||
method_permission_name="write",
|
||||
annotations=ToolAnnotations(
|
||||
title="Add annotation to an annotation layer",
|
||||
readOnlyHint=False,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def create_layer_annotation(
|
||||
request: CreateLayerAnnotationRequest, ctx: Context
|
||||
) -> CreateLayerAnnotationResponse:
|
||||
"""Add a new annotation to an existing annotation layer.
|
||||
|
||||
Use this tool when a user wants to mark a specific time range on charts
|
||||
with a label or note — for example, to flag a deployment, an outage, or
|
||||
a campaign period.
|
||||
|
||||
Workflow:
|
||||
1. Identify the target annotation layer (layer_id). Use list tools if needed.
|
||||
2. Call this tool with the layer_id, a short description, and the time range.
|
||||
3. The annotation will appear on charts that reference that layer.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Creating annotation: layer_id=%s, short_descr=%r"
|
||||
% (request.layer_id, request.short_descr)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.commands.annotation_layer.annotation.create import (
|
||||
CreateAnnotationCommand,
|
||||
)
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationCreateFailedError,
|
||||
AnnotationInvalidError,
|
||||
)
|
||||
from superset.commands.annotation_layer.exceptions import (
|
||||
AnnotationLayerNotFoundError,
|
||||
)
|
||||
|
||||
properties: dict[str, Any] = {
|
||||
"layer": request.layer_id,
|
||||
"short_descr": request.short_descr,
|
||||
"start_dttm": request.start_dttm,
|
||||
"end_dttm": request.end_dttm,
|
||||
}
|
||||
if request.long_descr is not None:
|
||||
properties["long_descr"] = request.long_descr
|
||||
if request.json_metadata is not None:
|
||||
properties["json_metadata"] = request.json_metadata
|
||||
|
||||
with event_logger.log_context(action="mcp.create_layer_annotation.create"):
|
||||
annotation = CreateAnnotationCommand(properties).run()
|
||||
|
||||
await ctx.info(
|
||||
"Annotation created: id=%s, layer_id=%s, short_descr=%r"
|
||||
% (annotation.id, request.layer_id, request.short_descr)
|
||||
)
|
||||
|
||||
return CreateLayerAnnotationResponse(
|
||||
id=annotation.id,
|
||||
layer_id=request.layer_id,
|
||||
short_descr=annotation.short_descr,
|
||||
start_dttm=annotation.start_dttm,
|
||||
end_dttm=annotation.end_dttm,
|
||||
long_descr=getattr(annotation, "long_descr", None),
|
||||
)
|
||||
|
||||
except AnnotationLayerNotFoundError as exc:
|
||||
await ctx.warning(
|
||||
"Annotation layer not found: layer_id=%s" % (request.layer_id,)
|
||||
)
|
||||
return CreateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
short_descr=request.short_descr,
|
||||
start_dttm=request.start_dttm,
|
||||
end_dttm=request.end_dttm,
|
||||
error=f"Annotation layer {request.layer_id} not found: {exc}",
|
||||
)
|
||||
except AnnotationInvalidError as exc:
|
||||
messages = exc.normalized_messages()
|
||||
await ctx.warning("Annotation validation failed: %s" % (messages,))
|
||||
return CreateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
short_descr=request.short_descr,
|
||||
start_dttm=request.start_dttm,
|
||||
end_dttm=request.end_dttm,
|
||||
error=str(messages),
|
||||
)
|
||||
except AnnotationCreateFailedError as exc:
|
||||
await ctx.error("Annotation creation failed: %s" % (str(exc),))
|
||||
return CreateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
short_descr=request.short_descr,
|
||||
start_dttm=request.start_dttm,
|
||||
end_dttm=request.end_dttm,
|
||||
error=f"Failed to create annotation: {exc}",
|
||||
)
|
||||
except Exception as exc:
|
||||
await ctx.error(
|
||||
"Unexpected error creating annotation: %s: %s"
|
||||
% (type(exc).__name__, str(exc))
|
||||
)
|
||||
raise
|
||||
@@ -0,0 +1,153 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.annotation_layer.schemas import (
|
||||
UpdateLayerAnnotationRequest,
|
||||
UpdateLayerAnnotationResponse,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _build_update_properties(request: UpdateLayerAnnotationRequest) -> dict[str, Any]:
|
||||
"""Build the properties dict for UpdateAnnotationCommand from the request."""
|
||||
properties: dict[str, Any] = {"layer": request.layer_id}
|
||||
if request.short_descr is not None:
|
||||
properties["short_descr"] = request.short_descr
|
||||
if request.start_dttm is not None:
|
||||
properties["start_dttm"] = request.start_dttm
|
||||
if request.end_dttm is not None:
|
||||
properties["end_dttm"] = request.end_dttm
|
||||
if request.long_descr is not None:
|
||||
properties["long_descr"] = request.long_descr
|
||||
if request.json_metadata is not None:
|
||||
properties["json_metadata"] = request.json_metadata
|
||||
return properties
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["mutate"],
|
||||
class_permission_name="Annotation",
|
||||
method_permission_name="write",
|
||||
annotations=ToolAnnotations(
|
||||
title="Update an annotation in an annotation layer",
|
||||
readOnlyHint=False,
|
||||
destructiveHint=True,
|
||||
),
|
||||
)
|
||||
async def update_layer_annotation(
|
||||
request: UpdateLayerAnnotationRequest, ctx: Context
|
||||
) -> UpdateLayerAnnotationResponse:
|
||||
"""Update an existing annotation in an annotation layer.
|
||||
|
||||
Use this tool to change the time range, description, or metadata of an
|
||||
existing annotation — for example to correct a deployment window or extend
|
||||
an outage marker.
|
||||
|
||||
All fields except layer_id and annotation_id are optional; only the
|
||||
fields provided will be updated.
|
||||
|
||||
Workflow:
|
||||
1. Identify the annotation layer (layer_id) and annotation (annotation_id).
|
||||
2. Call this tool with the fields you want to change.
|
||||
3. The annotation will be updated in place on all charts that reference
|
||||
that layer.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Updating annotation: layer_id=%s, annotation_id=%s"
|
||||
% (request.layer_id, request.annotation_id)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationInvalidError,
|
||||
AnnotationNotFoundError,
|
||||
AnnotationUpdateFailedError,
|
||||
)
|
||||
from superset.commands.annotation_layer.annotation.update import (
|
||||
UpdateAnnotationCommand,
|
||||
)
|
||||
from superset.commands.annotation_layer.exceptions import (
|
||||
AnnotationLayerNotFoundError,
|
||||
)
|
||||
|
||||
properties = _build_update_properties(request)
|
||||
|
||||
with event_logger.log_context(action="mcp.update_layer_annotation.update"):
|
||||
annotation = UpdateAnnotationCommand(
|
||||
request.annotation_id, properties
|
||||
).run()
|
||||
|
||||
await ctx.info(
|
||||
"Annotation updated: id=%s, layer_id=%s" % (annotation.id, request.layer_id)
|
||||
)
|
||||
|
||||
return UpdateLayerAnnotationResponse(
|
||||
id=annotation.id,
|
||||
layer_id=request.layer_id,
|
||||
short_descr=annotation.short_descr,
|
||||
start_dttm=annotation.start_dttm,
|
||||
end_dttm=annotation.end_dttm,
|
||||
long_descr=getattr(annotation, "long_descr", None),
|
||||
)
|
||||
|
||||
except AnnotationNotFoundError as exc:
|
||||
await ctx.warning(
|
||||
"Annotation not found: annotation_id=%s" % (request.annotation_id,)
|
||||
)
|
||||
return UpdateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
error=f"Annotation {request.annotation_id} not found: {exc}",
|
||||
)
|
||||
except AnnotationLayerNotFoundError as exc:
|
||||
await ctx.warning(
|
||||
"Annotation layer not found: layer_id=%s" % (request.layer_id,)
|
||||
)
|
||||
return UpdateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
error=f"Annotation layer {request.layer_id} not found: {exc}",
|
||||
)
|
||||
except AnnotationInvalidError as exc:
|
||||
messages = exc.normalized_messages()
|
||||
await ctx.warning("Annotation validation failed: %s" % (messages,))
|
||||
return UpdateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
error=str(messages),
|
||||
)
|
||||
except AnnotationUpdateFailedError as exc:
|
||||
await ctx.error("Annotation update failed: %s" % (str(exc),))
|
||||
return UpdateLayerAnnotationResponse(
|
||||
id=None,
|
||||
layer_id=request.layer_id,
|
||||
error=f"Failed to update annotation: {exc}",
|
||||
)
|
||||
except Exception as exc:
|
||||
await ctx.error(
|
||||
"Unexpected error updating annotation: %s: %s"
|
||||
% (type(exc).__name__, str(exc))
|
||||
)
|
||||
raise
|
||||
@@ -673,10 +673,12 @@ from superset.mcp_service.action_log.tool import ( # noqa: F401, E402
|
||||
list_action_logs,
|
||||
)
|
||||
from superset.mcp_service.annotation_layer.tool import ( # noqa: F401, E402
|
||||
create_layer_annotation,
|
||||
get_annotation_layer_info,
|
||||
get_layer_annotation_info,
|
||||
list_annotation_layers,
|
||||
list_layer_annotations,
|
||||
update_layer_annotation,
|
||||
)
|
||||
from superset.mcp_service.chart import ( # noqa: F401, E402
|
||||
prompts as chart_prompts,
|
||||
|
||||
@@ -0,0 +1,467 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.annotation_layer.schemas import (
|
||||
CreateLayerAnnotationRequest,
|
||||
)
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
from unittest.mock import Mock
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
def _make_request(**kwargs) -> CreateLayerAnnotationRequest:
|
||||
defaults = {
|
||||
"layer_id": 1,
|
||||
"short_descr": "Deploy v2.0",
|
||||
"start_dttm": datetime(2024, 1, 15, 8, 0, tzinfo=timezone.utc),
|
||||
"end_dttm": datetime(2024, 1, 15, 9, 0, tzinfo=timezone.utc),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return CreateLayerAnnotationRequest(**defaults)
|
||||
|
||||
|
||||
def _make_mock_annotation(
|
||||
id: int = 42,
|
||||
short_descr: str = "Deploy v2.0",
|
||||
long_descr: str | None = None,
|
||||
) -> MagicMock:
|
||||
annotation = MagicMock()
|
||||
annotation.id = id
|
||||
annotation.short_descr = short_descr
|
||||
annotation.long_descr = long_descr
|
||||
annotation.start_dttm = datetime(2024, 1, 15, 8, 0, tzinfo=timezone.utc)
|
||||
annotation.end_dttm = datetime(2024, 1, 15, 9, 0, tzinfo=timezone.utc)
|
||||
return annotation
|
||||
|
||||
|
||||
# --- Schema tests ---
|
||||
|
||||
|
||||
def test_request_valid() -> None:
|
||||
req = _make_request()
|
||||
assert req.layer_id == 1
|
||||
assert req.short_descr == "Deploy v2.0"
|
||||
assert req.long_descr is None
|
||||
assert req.json_metadata is None
|
||||
|
||||
|
||||
def test_request_short_descr_too_long() -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
_make_request(short_descr="x" * 501)
|
||||
|
||||
|
||||
def test_request_end_before_start_is_allowed_at_schema_level() -> None:
|
||||
# Date ordering is enforced by the command, not the Pydantic schema
|
||||
req = _make_request(
|
||||
start_dttm=datetime(2024, 1, 15, 10, 0),
|
||||
end_dttm=datetime(2024, 1, 15, 8, 0),
|
||||
)
|
||||
assert req.start_dttm > req.end_dttm
|
||||
|
||||
|
||||
def test_request_invalid_json_metadata_fails() -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="json_metadata must be valid JSON"):
|
||||
_make_request(json_metadata="not-json{")
|
||||
|
||||
|
||||
def test_request_valid_json_metadata() -> None:
|
||||
req = _make_request(json_metadata='{"key": "value"}')
|
||||
assert req.json_metadata == '{"key": "value"}'
|
||||
|
||||
|
||||
# --- Tool logic tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_layer_annotation_success(mcp_server: object) -> None:
|
||||
"""Happy path: annotation created, id and fields returned."""
|
||||
mock_annotation = _make_mock_annotation(id=42, short_descr="Deploy v2.0")
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.return_value = mock_annotation
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.create.CreateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_request()
|
||||
result = await client.call_tool(
|
||||
"create_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] == 42
|
||||
assert data["layer_id"] == 1
|
||||
assert data["short_descr"] == "Deploy v2.0"
|
||||
assert data["error"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_layer_annotation_layer_not_found(mcp_server: object) -> None:
|
||||
"""AnnotationLayerNotFoundError returns structured error response."""
|
||||
from superset.commands.annotation_layer.exceptions import (
|
||||
AnnotationLayerNotFoundError,
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = AnnotationLayerNotFoundError()
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.create.CreateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_request(layer_id=999)
|
||||
result = await client.call_tool(
|
||||
"create_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["layer_id"] == 999
|
||||
assert data["error"] is not None
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_layer_annotation_invalid_error(mcp_server: object) -> None:
|
||||
"""AnnotationInvalidError (e.g. duplicate short_descr) returns structured error."""
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationInvalidError,
|
||||
AnnotationUniquenessValidationError,
|
||||
)
|
||||
|
||||
invalid_exc = AnnotationInvalidError()
|
||||
invalid_exc.append(AnnotationUniquenessValidationError())
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = invalid_exc
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.create.CreateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_request()
|
||||
result = await client.call_tool(
|
||||
"create_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_layer_annotation_create_failed(mcp_server: object) -> None:
|
||||
"""AnnotationCreateFailedError is caught and returned as error response."""
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationCreateFailedError,
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = AnnotationCreateFailedError()
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.create.CreateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_request()
|
||||
result = await client.call_tool(
|
||||
"create_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
assert "Failed to create annotation" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_layer_annotation_optional_fields_forwarded(
|
||||
mcp_server: object,
|
||||
) -> None:
|
||||
"""long_descr and json_metadata are forwarded to CreateAnnotationCommand."""
|
||||
mock_annotation = _make_mock_annotation(
|
||||
id=5, short_descr="Outage", long_descr="DB connectivity lost"
|
||||
)
|
||||
mock_command_instance = MagicMock()
|
||||
mock_command_instance.run.return_value = mock_annotation
|
||||
mock_command_cls = MagicMock(return_value=mock_command_instance)
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.create.CreateAnnotationCommand",
|
||||
mock_command_cls,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_request(
|
||||
short_descr="Outage",
|
||||
long_descr="DB connectivity lost",
|
||||
json_metadata='{"severity": "high"}',
|
||||
)
|
||||
await client.call_tool(
|
||||
"create_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
props = mock_command_cls.call_args[0][0]
|
||||
assert props["long_descr"] == "DB connectivity lost"
|
||||
assert props["json_metadata"] == '{"severity": "high"}'
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# update_layer_annotation tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
from superset.mcp_service.annotation_layer.schemas import ( # noqa: E402
|
||||
UpdateLayerAnnotationRequest,
|
||||
)
|
||||
|
||||
|
||||
def _make_update_request(**kwargs) -> UpdateLayerAnnotationRequest:
|
||||
defaults: dict[str, object] = {
|
||||
"layer_id": 1,
|
||||
"annotation_id": 42,
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return UpdateLayerAnnotationRequest(**defaults)
|
||||
|
||||
|
||||
# --- Schema tests ---
|
||||
|
||||
|
||||
def test_update_request_valid() -> None:
|
||||
req = _make_update_request()
|
||||
assert req.layer_id == 1
|
||||
assert req.annotation_id == 42
|
||||
assert req.short_descr is None
|
||||
|
||||
|
||||
def test_update_request_short_descr_too_long() -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
_make_update_request(short_descr="x" * 501)
|
||||
|
||||
|
||||
def test_update_request_invalid_json_metadata_fails() -> None:
|
||||
from pydantic import ValidationError
|
||||
|
||||
with pytest.raises(ValidationError, match="json_metadata must be valid JSON"):
|
||||
_make_update_request(json_metadata="not-json{")
|
||||
|
||||
|
||||
def test_update_request_valid_json_metadata() -> None:
|
||||
req = _make_update_request(json_metadata='{"key": "value"}')
|
||||
assert req.json_metadata == '{"key": "value"}'
|
||||
|
||||
|
||||
# --- Tool logic tests ---
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_success(mcp_server: object) -> None:
|
||||
"""Happy path: annotation updated, id and fields returned."""
|
||||
mock_annotation = _make_mock_annotation(id=42, short_descr="Fixed title")
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.return_value = mock_annotation
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request(short_descr="Fixed title")
|
||||
result = await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] == 42
|
||||
assert data["layer_id"] == 1
|
||||
assert data["short_descr"] == "Fixed title"
|
||||
assert data["error"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_not_found(mcp_server: object) -> None:
|
||||
"""AnnotationNotFoundError returns structured error response."""
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationNotFoundError,
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = AnnotationNotFoundError()
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request(annotation_id=999)
|
||||
result = await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_layer_not_found(mcp_server: object) -> None:
|
||||
"""AnnotationLayerNotFoundError returns structured error response."""
|
||||
from superset.commands.annotation_layer.exceptions import (
|
||||
AnnotationLayerNotFoundError,
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = AnnotationLayerNotFoundError()
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request(layer_id=999)
|
||||
result = await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
assert "999" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_invalid_error(mcp_server: object) -> None:
|
||||
"""AnnotationInvalidError (e.g. duplicate short_descr) returns structured error."""
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationInvalidError,
|
||||
AnnotationUniquenessValidationError,
|
||||
)
|
||||
|
||||
invalid_exc = AnnotationInvalidError()
|
||||
invalid_exc.append(AnnotationUniquenessValidationError())
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = invalid_exc
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request(short_descr="Duplicate")
|
||||
result = await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_update_failed(mcp_server: object) -> None:
|
||||
"""AnnotationUpdateFailedError is caught and returned as error response."""
|
||||
from superset.commands.annotation_layer.annotation.exceptions import (
|
||||
AnnotationUpdateFailedError,
|
||||
)
|
||||
|
||||
mock_command = MagicMock()
|
||||
mock_command.run.side_effect = AnnotationUpdateFailedError()
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
return_value=mock_command,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request()
|
||||
result = await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
|
||||
assert data["id"] is None
|
||||
assert data["error"] is not None
|
||||
assert "Failed to update annotation" in data["error"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_layer_annotation_only_provided_fields_forwarded(
|
||||
mcp_server: object,
|
||||
) -> None:
|
||||
"""Only non-None fields are forwarded to UpdateAnnotationCommand."""
|
||||
mock_annotation = _make_mock_annotation(id=42)
|
||||
mock_command_instance = MagicMock()
|
||||
mock_command_instance.run.return_value = mock_annotation
|
||||
mock_command_cls = MagicMock(return_value=mock_command_instance)
|
||||
|
||||
with patch(
|
||||
"superset.commands.annotation_layer.annotation.update.UpdateAnnotationCommand",
|
||||
mock_command_cls,
|
||||
):
|
||||
async with Client(mcp_server) as client:
|
||||
request = _make_update_request(
|
||||
short_descr="New title",
|
||||
long_descr="Updated description",
|
||||
)
|
||||
await client.call_tool(
|
||||
"update_layer_annotation", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
# annotation_id is the first positional arg, properties is the second
|
||||
call_args = mock_command_cls.call_args
|
||||
annotation_id_arg = call_args[0][0]
|
||||
props = call_args[0][1]
|
||||
|
||||
assert annotation_id_arg == 42
|
||||
assert props["short_descr"] == "New title"
|
||||
assert props["long_descr"] == "Updated description"
|
||||
# Fields not provided should not be in properties
|
||||
assert "start_dttm" not in props
|
||||
assert "end_dttm" not in props
|
||||
assert "json_metadata" not in props
|
||||
Reference in New Issue
Block a user