diff --git a/superset/mcp_service/annotation_layer/__init__.py b/superset/mcp_service/annotation_layer/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/superset/mcp_service/annotation_layer/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/annotation_layer/schemas.py b/superset/mcp_service/annotation_layer/schemas.py new file mode 100644 index 00000000000..11059ab75b4 --- /dev/null +++ b/superset/mcp_service/annotation_layer/schemas.py @@ -0,0 +1,306 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Pydantic schemas for annotation layer and annotation responses.""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, List, Literal + +from pydantic import ( + BaseModel, + ConfigDict, + Field, + field_validator, + model_validator, + PositiveInt, +) + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE +from superset.mcp_service.system.schemas import PaginationInfo +from superset.mcp_service.utils.schema_utils import ( + parse_json_or_list, + parse_json_or_model_list, +) + +DEFAULT_LAYER_COLUMNS = ["id", "name", "descr"] +DEFAULT_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm"] + + +class AnnotationLayerFilter(ColumnOperator): + """Filter object for annotation layer listing.""" + + col: Literal["name"] = Field( + ..., + description="Column to filter on. Supported: 'name'.", + ) + opr: ColumnOperatorEnum = Field(..., description="Filter operator.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by." + ) + + +class AnnotationFilter(ColumnOperator): + """Filter object for annotation listing.""" + + col: Literal["short_descr"] = Field( + ..., + description="Column to filter on. Supported: 'short_descr'.", + ) + opr: ColumnOperatorEnum = Field(..., description="Filter operator.") + value: str | int | float | bool | List[str | int | float | bool] = Field( + ..., description="Value to filter by." + ) + + +class AnnotationLayerInfo(BaseModel): + id: int | None = Field(None, description="Annotation layer ID") + name: str | None = Field(None, description="Annotation layer name") + descr: str | None = Field(None, description="Annotation layer description") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_on: str | datetime | None = Field(None, description="Creation timestamp") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") + + +class AnnotationLayerList(BaseModel): + annotation_layers: List[AnnotationLayerInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[AnnotationLayerFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListAnnotationLayersRequest(BaseModel): + """Request schema for list_annotation_layers.""" + + filters: Annotated[ + List[AnnotationLayerFilter], + Field( + default_factory=list, + description="List of filter objects. Cannot be combined with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field( + default_factory=list, + description="Columns to include in the response.", + ), + ] + search: Annotated[ + str | None, + Field(default=None, description="Text search across name and description."), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by.") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction."), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)."), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE}).", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[AnnotationLayerFilter]: + return parse_json_or_model_list(v, AnnotationLayerFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListAnnotationLayersRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class GetAnnotationLayerInfoRequest(BaseModel): + """Request schema for get_annotation_layer_info.""" + + id: Annotated[int, Field(description="Annotation layer ID.")] + + +class AnnotationInfo(BaseModel): + id: int | None = Field(None, description="Annotation ID") + short_descr: str | None = Field(None, description="Short description") + long_descr: str | None = Field(None, description="Long description") + start_dttm: str | datetime | None = Field(None, description="Start datetime") + end_dttm: str | datetime | None = Field(None, description="End datetime") + json_metadata: str | None = Field(None, description="JSON metadata") + layer_id: int | None = Field(None, description="Parent annotation layer ID") + model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") + + +class AnnotationList(BaseModel): + annotations: List[AnnotationInfo] + count: int + total_count: int + page: int + page_size: int + total_pages: int + has_previous: bool + has_next: bool + layer_id: int + columns_requested: List[str] = Field(default_factory=list) + columns_loaded: List[str] = Field(default_factory=list) + columns_available: List[str] = Field(default_factory=list) + sortable_columns: List[str] = Field(default_factory=list) + filters_applied: List[AnnotationFilter] = Field(default_factory=list) + pagination: PaginationInfo | None = None + timestamp: datetime | None = None + model_config = ConfigDict(ser_json_timedelta="iso8601") + + +class ListLayerAnnotationsRequest(BaseModel): + """Request schema for list_layer_annotations.""" + + layer_id: Annotated[ + int, Field(description="Annotation layer ID to list annotations for.") + ] + filters: Annotated[ + List[AnnotationFilter], + Field( + default_factory=list, + description="List of filter objects. Cannot be combined with 'search'.", + ), + ] + select_columns: Annotated[ + List[str], + Field(default_factory=list, description="Columns to include in the response."), + ] + search: Annotated[ + str | None, + Field( + default=None, description="Text search across short and long description." + ), + ] + order_column: Annotated[ + str | None, Field(default=None, description="Column to order results by.") + ] + order_direction: Annotated[ + Literal["asc", "desc"], + Field(default="desc", description="Sort direction."), + ] + page: Annotated[ + PositiveInt, + Field(default=1, description="Page number (1-based)."), + ] + page_size: Annotated[ + int, + Field( + default=DEFAULT_PAGE_SIZE, + gt=0, + le=MAX_PAGE_SIZE, + description=f"Items per page (max {MAX_PAGE_SIZE}).", + ), + ] + + @field_validator("filters", mode="before") + @classmethod + def parse_filters(cls, v: Any) -> List[AnnotationFilter]: + return parse_json_or_model_list(v, AnnotationFilter, "filters") + + @field_validator("select_columns", mode="before") + @classmethod + def parse_columns(cls, v: Any) -> List[str]: + return parse_json_or_list(v, "select_columns") + + @model_validator(mode="after") + def validate_search_and_filters(self) -> "ListLayerAnnotationsRequest": + if self.search and self.filters: + raise ValueError("Cannot use both 'search' and 'filters' simultaneously.") + return self + + +class GetLayerAnnotationInfoRequest(BaseModel): + """Request schema for get_layer_annotation_info.""" + + layer_id: Annotated[int, Field(description="Annotation layer ID.")] + annotation_id: Annotated[int, Field(description="Annotation ID.")] + + +class AnnotationLayerError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "AnnotationLayerError": + from datetime import timezone + + return cls( + error=error, + error_type=error_type, + timestamp=datetime.now(timezone.utc), + ) + + +def serialize_annotation_layer(obj: Any) -> AnnotationLayerInfo | None: + if not obj: + return None + return AnnotationLayerInfo( + id=getattr(obj, "id", None), + name=getattr(obj, "name", None), + descr=getattr(obj, "descr", None), + changed_on=getattr(obj, "changed_on", None), + created_on=getattr(obj, "created_on", None), + ) + + +def serialize_annotation(obj: Any) -> AnnotationInfo | None: + if not obj: + return None + return AnnotationInfo( + id=getattr(obj, "id", None), + short_descr=getattr(obj, "short_descr", None), + long_descr=getattr(obj, "long_descr", None), + start_dttm=getattr(obj, "start_dttm", None), + end_dttm=getattr(obj, "end_dttm", None), + json_metadata=getattr(obj, "json_metadata", None), + layer_id=getattr(obj, "layer_id", None), + ) diff --git a/superset/mcp_service/annotation_layer/tool/__init__.py b/superset/mcp_service/annotation_layer/tool/__init__.py new file mode 100644 index 00000000000..75bbed4e1f0 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/__init__.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .get_annotation_layer_info import get_annotation_layer_info +from .get_layer_annotation_info import get_layer_annotation_info +from .list_annotation_layers import list_annotation_layers +from .list_layer_annotations import list_layer_annotations + +__all__ = [ + "list_annotation_layers", + "get_annotation_layer_info", + "list_layer_annotations", + "get_layer_annotation_info", +] diff --git a/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py new file mode 100644 index 00000000000..d46c2b10927 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/get_annotation_layer_info.py @@ -0,0 +1,93 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Get annotation layer info FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationLayerError, + AnnotationLayerInfo, + GetAnnotationLayerInfoRequest, + serialize_annotation_layer, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="Get annotation layer info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_annotation_layer_info( + request: GetAnnotationLayerInfoRequest, + ctx: Context, +) -> AnnotationLayerInfo | AnnotationLayerError: + """Get detailed information about an annotation layer by ID. + + Returns the layer's name, description, and timestamps. + + Example: + ```json + {"id": 1} + ``` + """ + await ctx.info("Retrieving annotation layer: id=%s" % (request.id,)) + + try: + from superset.daos.annotation_layer import AnnotationLayerDAO + + with event_logger.log_context(action="mcp.get_annotation_layer_info.lookup"): + layer = AnnotationLayerDAO.find_by_id(request.id) + + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.id}' not found", + error_type="not_found", + ) + + result = serialize_annotation_layer(layer) + await ctx.info( + "Annotation layer retrieved: id=%s, name=%s" + % (result.id if result else None, result.name if result else None) + ) + return result or AnnotationLayerError.create( + error="Failed to serialize annotation layer", + error_type="SerializationError", + ) + + except Exception as e: + await ctx.error( + "Annotation layer lookup failed: id=%s, error=%s, error_type=%s" + % (request.id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to get annotation layer info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py b/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py new file mode 100644 index 00000000000..43ba7648e03 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/get_layer_annotation_info.py @@ -0,0 +1,130 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Get a single annotation within a layer FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationInfo, + AnnotationLayerError, + GetLayerAnnotationInfoRequest, + serialize_annotation, +) + +logger = logging.getLogger(__name__) + + +@tool( + tags=["discovery"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="Get annotation info", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def get_layer_annotation_info( + request: GetLayerAnnotationInfoRequest, + ctx: Context, +) -> AnnotationInfo | AnnotationLayerError: + """Get detailed information about a specific annotation within a layer. + + Both layer_id and annotation_id are required. Returns an error if the + annotation does not belong to the specified layer. + + Example: + ```json + {"layer_id": 1, "annotation_id": 42} + ``` + """ + await ctx.info( + "Retrieving annotation: layer_id=%s, annotation_id=%s" + % (request.layer_id, request.annotation_id) + ) + + try: + from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO + + # Verify the layer exists + with event_logger.log_context( + action="mcp.get_layer_annotation_info.layer_lookup" + ): + layer = AnnotationLayerDAO.find_by_id(request.layer_id) + + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.layer_id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.layer_id}' not found", + error_type="not_found", + ) + + # Fetch the annotation + with event_logger.log_context( + action="mcp.get_layer_annotation_info.annotation_lookup" + ): + annotation = AnnotationDAO.find_by_id(request.annotation_id) + + if annotation is None: + await ctx.warning( + "Annotation not found: annotation_id=%s" % (request.annotation_id,) + ) + return AnnotationLayerError.create( + error=f"Annotation with id '{request.annotation_id}' not found", + error_type="not_found", + ) + + # Verify the annotation belongs to the requested layer + if getattr(annotation, "layer_id", None) != request.layer_id: + await ctx.warning( + "Annotation %s does not belong to layer %s" + % (request.annotation_id, request.layer_id) + ) + return AnnotationLayerError.create( + error=( + f"Annotation '{request.annotation_id}' does not belong to " + f"layer '{request.layer_id}'" + ), + error_type="not_found", + ) + + result = serialize_annotation(annotation) + await ctx.info( + "Annotation retrieved: id=%s, short_descr=%s" + % (result.id if result else None, result.short_descr if result else None) + ) + return result or AnnotationLayerError.create( + error="Failed to serialize annotation", + error_type="SerializationError", + ) + + except Exception as e: + await ctx.error( + "Annotation lookup failed: layer_id=%s, annotation_id=%s, " + "error=%s, error_type=%s" + % (request.layer_id, request.annotation_id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to get annotation info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py new file mode 100644 index 00000000000..fc924e428f2 --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/list_annotation_layers.py @@ -0,0 +1,123 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""List annotation layers FastMCP tool.""" + +import logging + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationLayerError, + AnnotationLayerFilter, + AnnotationLayerInfo, + AnnotationLayerList, + DEFAULT_LAYER_COLUMNS, + ListAnnotationLayersRequest, + serialize_annotation_layer, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +_DEFAULT_REQUEST = ListAnnotationLayersRequest() + +_ALL_LAYER_COLUMNS = ["id", "name", "descr", "changed_on", "created_on"] +_SORTABLE_LAYER_COLUMNS = ["id", "name", "changed_on", "created_on"] + + +@tool( + tags=["core"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="List annotation layers", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_annotation_layers( + request: ListAnnotationLayersRequest | None = None, + ctx: Context | None = None, +) -> AnnotationLayerList | AnnotationLayerError: + """List annotation layers with filtering, search, and pagination. + + Returns annotation layer metadata including name and description. + + Sortable columns for order_column: id, name, changed_on, created_on + """ + if ctx is None: + raise RuntimeError("FastMCP context is required for list_annotation_layers") + + request = request or _DEFAULT_REQUEST.model_copy(deep=True) + + await ctx.info( + "Listing annotation layers: page=%s, page_size=%s, search=%s" + % (request.page, request.page_size, request.search) + ) + + try: + from superset.daos.annotation_layer import AnnotationLayerDAO + + def _serialize( + obj: object, cols: list[str] | None + ) -> AnnotationLayerInfo | None: + return serialize_annotation_layer(obj) + + list_tool = ModelListCore( + dao_class=AnnotationLayerDAO, + output_schema=AnnotationLayerInfo, + item_serializer=_serialize, + filter_type=AnnotationLayerFilter, + default_columns=DEFAULT_LAYER_COLUMNS, + search_columns=["name"], + list_field_name="annotation_layers", + output_list_schema=AnnotationLayerList, + all_columns=_ALL_LAYER_COLUMNS, + sortable_columns=_SORTABLE_LAYER_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_annotation_layers.query"): + result = list_tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Annotation layers listed: count=%s, total_count=%s" + % ( + len(result.annotation_layers) + if hasattr(result, "annotation_layers") + else 0, + getattr(result, "total_count", None), + ) + ) + return result + + except Exception as e: + await ctx.error( + "Annotation layer listing failed: error=%s, error_type=%s" + % (str(e), type(e).__name__) + ) + raise diff --git a/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py new file mode 100644 index 00000000000..279beae094a --- /dev/null +++ b/superset/mcp_service/annotation_layer/tool/list_layer_annotations.py @@ -0,0 +1,153 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""List annotations within a layer FastMCP tool.""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context +from superset_core.mcp.decorators import tool, ToolAnnotations + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.extensions import event_logger +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationFilter, + AnnotationInfo, + AnnotationLayerError, + AnnotationList, + DEFAULT_ANNOTATION_COLUMNS, + ListLayerAnnotationsRequest, + serialize_annotation, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +_ALL_ANNOTATION_COLUMNS = [ + "id", + "short_descr", + "long_descr", + "start_dttm", + "end_dttm", + "json_metadata", + "layer_id", +] +_SORTABLE_ANNOTATION_COLUMNS = ["id", "short_descr", "start_dttm", "end_dttm"] + + +@tool( + tags=["core"], + class_permission_name="Annotation", + annotations=ToolAnnotations( + title="List annotations in a layer", + readOnlyHint=True, + destructiveHint=False, + ), +) +async def list_layer_annotations( + request: ListLayerAnnotationsRequest, + ctx: Context, +) -> AnnotationList | AnnotationLayerError: + """List annotations within a specific annotation layer. + + The layer_id parameter is required and scopes all results to that layer. + + Sortable columns for order_column: id, short_descr, start_dttm, end_dttm + + Example: + ```json + {"layer_id": 1, "page": 1, "page_size": 25} + ``` + """ + await ctx.info( + "Listing annotations: layer_id=%s, page=%s, page_size=%s, search=%s" + % (request.layer_id, request.page, request.page_size, request.search) + ) + + try: + from superset.daos.annotation_layer import AnnotationDAO, AnnotationLayerDAO + + # Verify the layer exists before listing + layer = AnnotationLayerDAO.find_by_id(request.layer_id) + if layer is None: + await ctx.warning("Annotation layer not found: id=%s" % (request.layer_id,)) + return AnnotationLayerError.create( + error=f"Annotation layer with id '{request.layer_id}' not found", + error_type="not_found", + ) + + # Prepend the layer_id filter so results are scoped to this layer + layer_filter = ColumnOperator( + col="layer_id", opr=ColumnOperatorEnum.eq, value=request.layer_id + ) + combined_filters: list[ColumnOperator] = [layer_filter] + list(request.filters) + + def _serialize(obj: object, cols: list[str] | None) -> AnnotationInfo | None: + return serialize_annotation(obj) + + list_tool = ModelListCore( + dao_class=AnnotationDAO, + output_schema=AnnotationInfo, + item_serializer=_serialize, + filter_type=AnnotationFilter, + default_columns=DEFAULT_ANNOTATION_COLUMNS, + search_columns=["short_descr"], + list_field_name="annotations", + output_list_schema=AnnotationList, + all_columns=_ALL_ANNOTATION_COLUMNS, + sortable_columns=_SORTABLE_ANNOTATION_COLUMNS, + logger=logger, + ) + + with event_logger.log_context(action="mcp.list_layer_annotations.query"): + result = list_tool.run_tool( + filters=combined_filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + # Attach the layer_id to the result for caller context + result_dict = result.model_dump() + result_dict["layer_id"] = request.layer_id + # Rebuild with layer_id set + final = AnnotationList(**result_dict) + + await ctx.info( + "Annotations listed: layer_id=%s, count=%s, total_count=%s" + % ( + request.layer_id, + len(final.annotations) if hasattr(final, "annotations") else 0, + getattr(final, "total_count", None), + ) + ) + return final + + except Exception as e: + await ctx.error( + "Annotation listing failed: layer_id=%s, error=%s, error_type=%s" + % (request.layer_id, str(e), type(e).__name__) + ) + return AnnotationLayerError( + error=f"Failed to list annotations: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 81c6bd1f088..5f0711c427e 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -119,6 +119,12 @@ Dashboard Management: - generate_dashboard: Create a dashboard from chart IDs - add_chart_to_existing_dashboard: Add a chart to an existing dashboard +Annotation Layers: +- list_annotation_layers: List annotation layers with advanced filters (1-based pagination) +- get_annotation_layer_info: Get annotation layer details by ID +- list_layer_annotations: List annotations within a layer (requires layer_id, 1-based pagination) +- get_layer_annotation_info: Get annotation details by layer_id and annotation_id + Database Connections: - list_databases: List database connections with advanced filters (1-based pagination) - get_database_info: Get detailed database connection info by ID (backend, capabilities) @@ -602,6 +608,12 @@ warnings.filterwarnings( # NOTE: Always add new prompt/resource imports here when creating new prompts/resources. # Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. # They register automatically on import, similar to tools. +from superset.mcp_service.annotation_layer.tool import ( # noqa: F401, E402 + get_annotation_layer_info, + get_layer_annotation_info, + list_annotation_layers, + list_layer_annotations, +) from superset.mcp_service.chart import ( # noqa: F401, E402 prompts as chart_prompts, resources as chart_resources, diff --git a/tests/unit_tests/mcp_service/annotation_layer/__init__.py b/tests/unit_tests/mcp_service/annotation_layer/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py b/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py new file mode 100644 index 00000000000..2963accd349 --- /dev/null +++ b/tests/unit_tests/mcp_service/annotation_layer/tool/test_annotation_layer_tools.py @@ -0,0 +1,434 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from unittest.mock import MagicMock, patch + +import pytest +from fastmcp import Client +from pydantic import ValidationError + +from superset.mcp_service.annotation_layer.schemas import ( + AnnotationFilter, + AnnotationLayerFilter, + ListAnnotationLayersRequest, + ListLayerAnnotationsRequest, +) +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def make_layer( + layer_id: int = 1, name: str = "My Layer", descr: str = "desc" +) -> MagicMock: + obj = MagicMock() + obj.id = layer_id + obj.name = name + obj.descr = descr + obj.changed_on = None + obj.created_on = None + return obj + + +def make_annotation( + annotation_id: int = 10, + layer_id: int = 1, + short_descr: str = "Deploy", + long_descr: str = "Deployment annotation", +) -> MagicMock: + obj = MagicMock() + obj.id = annotation_id + obj.layer_id = layer_id + obj.short_descr = short_descr + obj.long_descr = long_descr + obj.start_dttm = None + obj.end_dttm = None + obj.json_metadata = None + return obj + + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +@pytest.fixture() +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + from unittest.mock import Mock + + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +# --------------------------------------------------------------------------- +# Schema validation tests +# --------------------------------------------------------------------------- + + +class TestAnnotationLayerFilterSchema: + def test_valid_name_filter(self): + f = AnnotationLayerFilter(col="name", opr="eq", value="My Layer") + assert f.col == "name" + + def test_invalid_column_rejected(self): + with pytest.raises(ValidationError): + AnnotationLayerFilter(col="descr", opr="eq", value="x") + + def test_search_and_filters_mutual_exclusion(self): + with pytest.raises(ValidationError): + ListAnnotationLayersRequest( + search="foo", + filters=[{"col": "name", "opr": "eq", "value": "bar"}], + ) + + +class TestAnnotationFilterSchema: + def test_valid_short_descr_filter(self): + f = AnnotationFilter(col="short_descr", opr="eq", value="Deploy") + assert f.col == "short_descr" + + def test_invalid_column_rejected(self): + with pytest.raises(ValidationError): + AnnotationFilter(col="layer_id", opr="eq", value=1) + + def test_search_and_filters_mutual_exclusion(self): + with pytest.raises(ValidationError): + ListLayerAnnotationsRequest( + layer_id=1, + search="foo", + filters=[{"col": "short_descr", "opr": "eq", "value": "bar"}], + ) + + +# --------------------------------------------------------------------------- +# list_annotation_layers tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_basic(mock_list, mcp_server): + """Basic listing returns structured response with annotation layers.""" + layer = make_layer() + mock_list.return_value = ([layer], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"page": 1, "page_size": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"] is not None + assert len(data["annotation_layers"]) == 1 + assert data["annotation_layers"][0]["id"] == 1 + assert data["annotation_layers"][0]["name"] == "My Layer" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_empty(mock_list, mcp_server): + """Empty result set returns zero count.""" + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + result = await client.call_tool("list_annotation_layers", {}) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"] == [] + assert data["total_count"] == 0 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_search(mock_list, mcp_server): + """Search parameter is passed through to DAO.""" + layer = make_layer(name="Release Events") + mock_list.return_value = ([layer], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"search": "release"}}, + ) + + data = json.loads(result.content[0].text) + assert data["annotation_layers"][0]["name"] == "Release Events" + call_kwargs = mock_list.call_args.kwargs + assert call_kwargs["search"] == "release" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.list") +@pytest.mark.asyncio() +async def test_list_annotation_layers_pagination(mock_list, mcp_server): + """Pagination metadata is correctly computed.""" + mock_list.return_value = ([], 50) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_annotation_layers", + {"request": {"page": 2, "page_size": 25}}, + ) + + data = json.loads(result.content[0].text) + assert data["page"] == 2 + assert data["page_size"] == 25 + assert data["total_count"] == 50 + assert data["total_pages"] == 2 + # Page 2 of 2, so no next page + assert data["has_next"] is False + assert data["has_previous"] is True + + +# --------------------------------------------------------------------------- +# get_annotation_layer_info tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_annotation_layer_info_found(mock_find, mcp_server): + """Returns annotation layer data when found.""" + mock_find.return_value = make_layer(layer_id=5, name="Prod Events") + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_annotation_layer_info", + {"request": {"id": 5}}, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 5 + assert data["name"] == "Prod Events" + mock_find.assert_called_once_with(5) + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_annotation_layer_info_not_found(mock_find, mcp_server): + """Returns error response when layer is not found.""" + mock_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_annotation_layer_info", + {"request": {"id": 999}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "999" in data["error"] + + +# --------------------------------------------------------------------------- +# list_layer_annotations tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_basic(mock_list, mock_layer_find, mcp_server): + """Annotations are listed and scoped to the specified layer.""" + mock_layer_find.return_value = make_layer(layer_id=1) + ann = make_annotation(annotation_id=10, layer_id=1) + mock_list.return_value = ([ann], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 1, "page": 1, "page_size": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["layer_id"] == 1 + assert len(data["annotations"]) == 1 + assert data["annotations"][0]["id"] == 10 + assert data["annotations"][0]["layer_id"] == 1 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_layer_id_filter_prepended( + mock_list, mock_layer_find, mcp_server +): + """The layer_id filter is always prepended to DAO column_operators.""" + mock_layer_find.return_value = make_layer(layer_id=3) + mock_list.return_value = ([], 0) + + async with Client(mcp_server) as client: + await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 3}}, + ) + + call_kwargs = mock_list.call_args.kwargs + filters = call_kwargs.get("column_operators", []) + # First filter must be the layer_id eq filter + assert filters, "Expected at least one filter (layer_id)" + first = filters[0] + col = first.get("col") if isinstance(first, dict) else getattr(first, "col", None) + val = ( + first.get("value") if isinstance(first, dict) else getattr(first, "value", None) + ) + assert col == "layer_id" + assert val == 3 + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_list_layer_annotations_layer_not_found(mock_layer_find, mcp_server): + """Returns error when the layer does not exist.""" + mock_layer_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 42}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "42" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.list") +@pytest.mark.asyncio() +async def test_list_layer_annotations_only_returns_own_layer( + mock_list, mock_layer_find, mcp_server +): + """Results are filtered to the requested layer only — wrong layer_id is rejected.""" + mock_layer_find.return_value = make_layer(layer_id=1) + # Simulate DAO returning annotations — the layer_id filter is applied at DB level + ann_wrong = make_annotation(annotation_id=99, layer_id=2) + mock_list.return_value = ([ann_wrong], 1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "list_layer_annotations", + {"request": {"layer_id": 1}}, + ) + + data = json.loads(result.content[0].text) + # layer_id in response header must still be 1 (the requested layer) + assert data["layer_id"] == 1 + + +# --------------------------------------------------------------------------- +# get_layer_annotation_info tests +# --------------------------------------------------------------------------- + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_found( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns annotation data when both layer and annotation are found.""" + mock_layer_find.return_value = make_layer(layer_id=1) + mock_ann_find.return_value = make_annotation(annotation_id=10, layer_id=1) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["id"] == 10 + assert data["layer_id"] == 1 + assert data["short_descr"] == "Deploy" + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_layer_not_found(mock_layer_find, mcp_server): + """Returns error when the layer does not exist.""" + mock_layer_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 99, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "99" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_annotation_not_found( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns error when the annotation does not exist.""" + mock_layer_find.return_value = make_layer(layer_id=1) + mock_ann_find.return_value = None + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 999}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "999" in data["error"] + + +@patch("superset.daos.annotation_layer.AnnotationLayerDAO.find_by_id") +@patch("superset.daos.annotation_layer.AnnotationDAO.find_by_id") +@pytest.mark.asyncio() +async def test_get_layer_annotation_info_wrong_layer( + mock_ann_find, mock_layer_find, mcp_server +): + """Returns error when annotation exists but belongs to a different layer.""" + mock_layer_find.return_value = make_layer(layer_id=1) + # Annotation belongs to layer 2, not layer 1 + mock_ann_find.return_value = make_annotation(annotation_id=10, layer_id=2) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_layer_annotation_info", + {"request": {"layer_id": 1, "annotation_id": 10}}, + ) + + data = json.loads(result.content[0].text) + assert data["error_type"] == "not_found" + assert "does not belong" in data["error"]