mirror of
https://github.com/apache/superset.git
synced 2026-06-11 18:49:15 +00:00
Compare commits
6 Commits
fix/chart-
...
mcp-action
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0f9b0eeb92 | ||
|
|
0f926b6757 | ||
|
|
59b0914278 | ||
|
|
6a26ba70a9 | ||
|
|
ef38144f19 | ||
|
|
9a6c927eba |
16
superset/mcp_service/action_log/__init__.py
Normal file
16
superset/mcp_service/action_log/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
243
superset/mcp_service/action_log/schemas.py
Normal file
243
superset/mcp_service/action_log/schemas.py
Normal file
@@ -0,0 +1,243 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Pydantic schemas for action-log MCP tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_LOG_COLUMNS: list[str] = ["id", "action", "user_id", "dttm"]
|
||||
ALL_LOG_COLUMNS: list[str] = [
|
||||
"id",
|
||||
"action",
|
||||
"user_id",
|
||||
"dttm",
|
||||
"dashboard_id",
|
||||
"slice_id",
|
||||
"json",
|
||||
]
|
||||
LOG_SORTABLE_COLUMNS: list[str] = ["id", "dttm"]
|
||||
|
||||
|
||||
class ActionLogFilter(ColumnOperator):
|
||||
"""Filter object for action-log listing.
|
||||
|
||||
col: Column to filter on.
|
||||
opr: Operator to use.
|
||||
value: Value to filter by.
|
||||
"""
|
||||
|
||||
col: Literal["action", "user_id", "dashboard_id", "slice_id", "dttm"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | list[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class ActionLogInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Log entry ID")
|
||||
action: str | None = Field(None, description="Action name")
|
||||
user_id: int | None = Field(
|
||||
None, description="ID of the user who performed the action"
|
||||
)
|
||||
dttm: str | datetime | None = Field(None, description="Timestamp of the action")
|
||||
dashboard_id: int | None = Field(None, description="Associated dashboard ID")
|
||||
slice_id: int | None = Field(None, description="Associated chart/slice ID")
|
||||
json: str | None = Field(None, description="JSON payload of the action")
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
def model_post_init(self, __context: Any) -> None:
|
||||
if isinstance(self.dttm, datetime) and self.dttm.tzinfo is None:
|
||||
from datetime import timezone
|
||||
|
||||
object.__setattr__(self, "dttm", self.dttm.replace(tzinfo=timezone.utc))
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class ActionLogList(BaseModel):
|
||||
action_logs: list[ActionLogInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: list[str] = Field(default_factory=list)
|
||||
columns_loaded: list[str] = Field(default_factory=list)
|
||||
columns_available: list[str] = Field(default_factory=list)
|
||||
sortable_columns: list[str] = Field(default_factory=list)
|
||||
filters_applied: list[ActionLogFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListActionLogsRequest(BaseModel):
|
||||
"""Request schema for list_action_logs."""
|
||||
|
||||
filters: Annotated[
|
||||
list[ActionLogFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of filter objects (col, opr, value). "
|
||||
"Filter columns: action, user_id, dashboard_id, slice_id, dttm. "
|
||||
"Cannot be used with 'search'."
|
||||
),
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
list[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to return. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Text search string matched against action. "
|
||||
"Cannot be used together with 'filters'."
|
||||
),
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None,
|
||||
Field(default=None, description="Column to sort by (default: dttm)"),
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction ('asc' or 'desc')"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> list[ActionLogFilter]:
|
||||
return parse_json_or_model_list(v, ActionLogFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> list[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListActionLogsRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError(
|
||||
"Cannot use both 'search' and 'filters' simultaneously. "
|
||||
"Use 'search' for text matching on action, or 'filters' for "
|
||||
"column-based filtering, but not both."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class ActionLogError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Error type")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "ActionLogError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error,
|
||||
error_type=error_type,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class GetActionLogInfoRequest(BaseModel):
|
||||
"""Request schema for get_action_log_info (ID-only lookup)."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="Log entry ID (integer)"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_action_log_object(log: Any) -> ActionLogInfo | None:
|
||||
if not log:
|
||||
return None
|
||||
dttm = getattr(log, "dttm", None)
|
||||
if isinstance(dttm, datetime) and dttm.tzinfo is None:
|
||||
from datetime import timezone
|
||||
|
||||
dttm = dttm.replace(tzinfo=timezone.utc)
|
||||
return ActionLogInfo(
|
||||
id=getattr(log, "id", None),
|
||||
action=getattr(log, "action", None),
|
||||
user_id=getattr(log, "user_id", None),
|
||||
dttm=dttm,
|
||||
dashboard_id=getattr(log, "dashboard_id", None),
|
||||
slice_id=getattr(log, "slice_id", None),
|
||||
json=getattr(log, "json", None),
|
||||
)
|
||||
24
superset/mcp_service/action_log/tool/__init__.py
Normal file
24
superset/mcp_service/action_log/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_action_log_info import get_action_log_info
|
||||
from .list_action_logs import list_action_logs
|
||||
|
||||
__all__ = [
|
||||
"list_action_logs",
|
||||
"get_action_log_info",
|
||||
]
|
||||
97
superset/mcp_service/action_log/tool/get_action_log_info.py
Normal file
97
superset/mcp_service/action_log/tool/get_action_log_info.py
Normal file
@@ -0,0 +1,97 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Get action log info MCP tool."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.action_log.schemas import (
|
||||
ActionLogError,
|
||||
ActionLogInfo,
|
||||
GetActionLogInfoRequest,
|
||||
serialize_action_log_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Log",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get action log info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_action_log_info(
|
||||
request: GetActionLogInfoRequest,
|
||||
ctx: Context,
|
||||
) -> ActionLogInfo | ActionLogError:
|
||||
"""Get a single action log entry by its integer ID.
|
||||
|
||||
Returns the action, user_id, timestamp (dttm), dashboard_id, slice_id,
|
||||
and JSON payload for the specified log record.
|
||||
|
||||
ADMIN-ONLY: This tool requires admin privileges.
|
||||
|
||||
Use list_action_logs to discover log IDs.
|
||||
"""
|
||||
await ctx.info("Retrieving action log: identifier=%s" % (request.identifier,))
|
||||
|
||||
try:
|
||||
from superset.daos.log import LogDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_action_log_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=LogDAO,
|
||||
output_schema=ActionLogInfo,
|
||||
error_schema=ActionLogError,
|
||||
serializer=serialize_action_log_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, ActionLogInfo):
|
||||
await ctx.info(
|
||||
"Action log retrieved: id=%s, action=%s" % (result.id, result.action)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Action log retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Action log retrieval failed: identifier=%s, error=%s, error_type=%s"
|
||||
% (request.identifier, str(e), type(e).__name__)
|
||||
)
|
||||
return ActionLogError(
|
||||
error=f"Failed to get action log info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
153
superset/mcp_service/action_log/tool/list_action_logs.py
Normal file
153
superset/mcp_service/action_log/tool/list_action_logs.py
Normal file
@@ -0,0 +1,153 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""List action logs MCP tool."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.action_log.schemas import (
|
||||
ActionLogError,
|
||||
ActionLogFilter,
|
||||
ActionLogInfo,
|
||||
ActionLogList,
|
||||
ALL_LOG_COLUMNS,
|
||||
DEFAULT_LOG_COLUMNS,
|
||||
ListActionLogsRequest,
|
||||
LOG_SORTABLE_COLUMNS,
|
||||
serialize_action_log_object,
|
||||
)
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_ACTION_LOGS_REQUEST = ListActionLogsRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Log",
|
||||
annotations=ToolAnnotations(
|
||||
title="List action logs",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_action_logs(
|
||||
request: ListActionLogsRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> ActionLogList | ActionLogError:
|
||||
"""List Superset action logs with filtering and pagination.
|
||||
|
||||
Returns audit log entries recording user interactions with dashboards and
|
||||
charts. Defaults to the last 7 days to avoid pulling large result sets.
|
||||
|
||||
ADMIN-ONLY: This tool requires admin privileges. Non-admin users will
|
||||
receive a permission error.
|
||||
|
||||
Sortable columns for order_column: id, dttm
|
||||
Filter columns: action, user_id, dashboard_id, slice_id, dttm
|
||||
|
||||
When no dttm filter is provided the tool automatically applies
|
||||
dttm >= (now - 7 days). Add an explicit dttm filter to override.
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_action_logs")
|
||||
|
||||
request = request or _DEFAULT_LIST_ACTION_LOGS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing action logs: page=%s, page_size=%s" % (request.page, request.page_size)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Action log parameters: filters=%s, order_column=%s, order_direction=%s"
|
||||
% (request.filters, request.order_column, request.order_direction)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.log import LogDAO
|
||||
|
||||
# Inject default 7-day dttm filter unless caller already provides one
|
||||
filters: list[ColumnOperator] = list(request.filters)
|
||||
has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters)
|
||||
if not has_dttm_filter:
|
||||
cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).isoformat()
|
||||
default_filter = ActionLogFilter(
|
||||
col="dttm",
|
||||
opr=ColumnOperatorEnum.gte,
|
||||
value=cutoff,
|
||||
)
|
||||
filters = [default_filter] + filters
|
||||
await ctx.debug("Applied default 7-day dttm filter: cutoff=%s" % (cutoff,))
|
||||
|
||||
def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None:
|
||||
return serialize_action_log_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=LogDAO,
|
||||
output_schema=ActionLogInfo,
|
||||
item_serializer=_serialize,
|
||||
filter_type=ActionLogFilter,
|
||||
default_columns=DEFAULT_LOG_COLUMNS,
|
||||
search_columns=["action"],
|
||||
list_field_name="action_logs",
|
||||
output_list_schema=ActionLogList,
|
||||
all_columns=ALL_LOG_COLUMNS,
|
||||
sortable_columns=LOG_SORTABLE_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_action_logs.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column or "dttm",
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Action logs listed: count=%s, total_count=%s"
|
||||
% (
|
||||
len(result.action_logs) if hasattr(result, "action_logs") else 0,
|
||||
getattr(result, "total_count", None),
|
||||
)
|
||||
)
|
||||
columns_to_filter = result.columns_requested
|
||||
await ctx.debug(
|
||||
"Applying field filtering via serialization context: columns=%s"
|
||||
% (columns_to_filter,)
|
||||
)
|
||||
with event_logger.log_context(action="mcp.list_action_logs.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Action log listing failed: page=%s, error=%s, error_type=%s"
|
||||
% (request.page, str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
@@ -602,6 +602,10 @@ warnings.filterwarnings(
|
||||
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
|
||||
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
|
||||
# They register automatically on import, similar to tools.
|
||||
from superset.mcp_service.action_log.tool import ( # noqa: F401, E402
|
||||
get_action_log_info,
|
||||
list_action_logs,
|
||||
)
|
||||
from superset.mcp_service.chart import ( # noqa: F401, E402
|
||||
prompts as chart_prompts,
|
||||
resources as chart_resources,
|
||||
@@ -652,6 +656,10 @@ from superset.mcp_service.system.tool import ( # noqa: F401, E402
|
||||
get_schema,
|
||||
health_check,
|
||||
)
|
||||
from superset.mcp_service.task.tool import ( # noqa: F401, E402
|
||||
get_task_info,
|
||||
list_tasks,
|
||||
)
|
||||
|
||||
|
||||
def _remove_disabled_tools(disabled_tools: set[str]) -> None:
|
||||
|
||||
16
superset/mcp_service/task/__init__.py
Normal file
16
superset/mcp_service/task/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
240
superset/mcp_service/task/schemas.py
Normal file
240
superset/mcp_service/task/schemas.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Pydantic schemas for task MCP tools."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_TASK_COLUMNS: list[str] = ["id", "uuid", "task_type", "status", "changed_on"]
|
||||
ALL_TASK_COLUMNS: list[str] = [
|
||||
"id",
|
||||
"uuid",
|
||||
"task_type",
|
||||
"task_key",
|
||||
"task_name",
|
||||
"status",
|
||||
"scope",
|
||||
"changed_on",
|
||||
"created_on",
|
||||
]
|
||||
TASK_SORTABLE_COLUMNS: list[str] = ["id", "changed_on", "created_on", "status"]
|
||||
|
||||
|
||||
class TaskColumnFilter(ColumnOperator):
|
||||
"""Filter object for task listing.
|
||||
|
||||
col: Column to filter on.
|
||||
opr: Operator to use.
|
||||
value: Value to filter by.
|
||||
"""
|
||||
|
||||
col: Literal["task_type", "status", "scope"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
|
||||
value: str | int | float | bool | list[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by"
|
||||
)
|
||||
|
||||
|
||||
class TaskInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Task ID")
|
||||
uuid: str | None = Field(None, description="Task UUID")
|
||||
task_type: str | None = Field(None, description="Task type (e.g., sql_execution)")
|
||||
task_key: str | None = Field(None, description="Task deduplication key")
|
||||
task_name: str | None = Field(None, description="Human-readable task name")
|
||||
status: str | None = Field(None, description="Task status")
|
||||
scope: str | None = Field(None, description="Task scope (private/shared/system)")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
created_on: str | datetime | None = Field(None, description="Creation timestamp")
|
||||
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]:
|
||||
data = serializer(self)
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
return data
|
||||
|
||||
|
||||
class TaskList(BaseModel):
|
||||
tasks: list[TaskInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: list[str] = Field(default_factory=list)
|
||||
columns_loaded: list[str] = Field(default_factory=list)
|
||||
columns_available: list[str] = Field(default_factory=list)
|
||||
sortable_columns: list[str] = Field(default_factory=list)
|
||||
filters_applied: list[TaskColumnFilter] = Field(default_factory=list)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListTasksRequest(BaseModel):
|
||||
"""Request schema for list_tasks."""
|
||||
|
||||
filters: Annotated[
|
||||
list[TaskColumnFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description=(
|
||||
"List of filter objects (col, opr, value). "
|
||||
"Filter columns: task_type, status, scope. "
|
||||
"Cannot be used with 'search'."
|
||||
),
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
list[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="Columns to return. Defaults to common columns.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Text search string matched against task_type, task_key, "
|
||||
"task_name, status, and scope. "
|
||||
"Cannot be used together with 'filters'."
|
||||
),
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None,
|
||||
Field(default=None, description="Column to sort by (default: changed_on)"),
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(default="desc", description="Sort direction ('asc' or 'desc')"),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> list[TaskColumnFilter]:
|
||||
return parse_json_or_model_list(v, TaskColumnFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> list[str]:
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListTasksRequest":
|
||||
if self.search and self.filters:
|
||||
raise ValueError(
|
||||
"Cannot use both 'search' and 'filters' simultaneously. "
|
||||
"Use 'search' for text matching on task_type/status/scope, or "
|
||||
"'filters' for column-based filtering, but not both."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class TaskError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Error type")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
@classmethod
|
||||
def create(cls, error: str, error_type: str) -> "TaskError":
|
||||
from datetime import timezone
|
||||
|
||||
return cls(
|
||||
error=error,
|
||||
error_type=error_type,
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
class GetTaskInfoRequest(BaseModel):
|
||||
"""Request schema for get_task_info (ID or UUID lookup)."""
|
||||
|
||||
identifier: Annotated[
|
||||
int | str,
|
||||
Field(description="Task identifier — numeric ID or UUID string"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_task_object(task: Any) -> TaskInfo | None:
|
||||
if not task:
|
||||
return None
|
||||
uuid_val = getattr(task, "uuid", None)
|
||||
return TaskInfo(
|
||||
id=getattr(task, "id", None),
|
||||
uuid=str(uuid_val) if uuid_val is not None else None,
|
||||
task_type=getattr(task, "task_type", None),
|
||||
task_key=getattr(task, "task_key", None),
|
||||
task_name=getattr(task, "task_name", None),
|
||||
status=getattr(task, "status", None),
|
||||
scope=getattr(task, "scope", None),
|
||||
changed_on=getattr(task, "changed_on", None),
|
||||
created_on=getattr(task, "created_on", None),
|
||||
)
|
||||
24
superset/mcp_service/task/tool/__init__.py
Normal file
24
superset/mcp_service/task/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_task_info import get_task_info
|
||||
from .list_tasks import list_tasks
|
||||
|
||||
__all__ = [
|
||||
"list_tasks",
|
||||
"get_task_info",
|
||||
]
|
||||
108
superset/mcp_service/task/tool/get_task_info.py
Normal file
108
superset/mcp_service/task/tool/get_task_info.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Get task info MCP tool."""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.task.schemas import (
|
||||
GetTaskInfoRequest,
|
||||
serialize_task_object,
|
||||
TaskError,
|
||||
TaskInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Task",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get task info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_task_info(
|
||||
request: GetTaskInfoRequest,
|
||||
ctx: Context,
|
||||
) -> TaskInfo | TaskError:
|
||||
"""Get details for a single async task by ID or UUID.
|
||||
|
||||
Returns task_type, status, scope, and timestamps for the specified task.
|
||||
Non-admin users can only retrieve tasks they are subscribed to.
|
||||
|
||||
Use list_tasks to discover task IDs and UUIDs.
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{"identifier": 42}
|
||||
```
|
||||
|
||||
Or with UUID:
|
||||
```json
|
||||
{"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"}
|
||||
```
|
||||
"""
|
||||
await ctx.info("Retrieving task: identifier=%s" % (request.identifier,))
|
||||
|
||||
try:
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_task_info.lookup"):
|
||||
# ModelGetInfoCore handles int ID and UUID string automatically.
|
||||
# TaskDAO.base_filter (TaskFilter) enforces subscription-based access.
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=TaskDAO,
|
||||
output_schema=TaskInfo,
|
||||
error_schema=TaskError,
|
||||
serializer=serialize_task_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, TaskInfo):
|
||||
await ctx.info(
|
||||
"Task retrieved: id=%s, task_type=%s, status=%s"
|
||||
% (result.id, result.task_type, result.status)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Task retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Task retrieval failed: identifier=%s, error=%s, error_type=%s"
|
||||
% (request.identifier, str(e), type(e).__name__)
|
||||
)
|
||||
return TaskError(
|
||||
error=f"Failed to get task info: {str(e)}",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
140
superset/mcp_service/task/tool/list_tasks.py
Normal file
140
superset/mcp_service/task/tool/list_tasks.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""List tasks MCP tool."""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.task.schemas import (
|
||||
ALL_TASK_COLUMNS,
|
||||
DEFAULT_TASK_COLUMNS,
|
||||
ListTasksRequest,
|
||||
serialize_task_object,
|
||||
TASK_SORTABLE_COLUMNS,
|
||||
TaskColumnFilter,
|
||||
TaskError,
|
||||
TaskInfo,
|
||||
TaskList,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_TASKS_REQUEST = ListTasksRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Task",
|
||||
annotations=ToolAnnotations(
|
||||
title="List tasks",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_tasks(
|
||||
request: ListTasksRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> TaskList | TaskError:
|
||||
"""List async tasks with filtering and pagination.
|
||||
|
||||
Returns tasks visible to the current user. Non-admin users only see tasks
|
||||
they are subscribed to (task creators are auto-subscribed). Admins see all
|
||||
tasks.
|
||||
|
||||
Sortable columns for order_column: id, changed_on, created_on, status
|
||||
Filter columns: task_type, status, scope
|
||||
Search columns (via search=): task_type, task_key, task_name, status, scope
|
||||
|
||||
Common task_type values: sql_execution, thumbnail, report
|
||||
Common status values: pending, in_progress, success, failure, aborted
|
||||
Common scope values: private, shared, system
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_tasks")
|
||||
|
||||
request = request or _DEFAULT_LIST_TASKS_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing tasks: page=%s, page_size=%s" % (request.page, request.page_size)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Task parameters: filters=%s, order_column=%s, order_direction=%s"
|
||||
% (request.filters, request.order_column, request.order_direction)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.tasks import TaskDAO
|
||||
|
||||
def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None:
|
||||
return serialize_task_object(obj)
|
||||
|
||||
# TaskDAO.base_filter = TaskFilter automatically scopes results:
|
||||
# non-admins only see their subscribed tasks; admins see all.
|
||||
list_tool = ModelListCore(
|
||||
dao_class=TaskDAO,
|
||||
output_schema=TaskInfo,
|
||||
item_serializer=_serialize,
|
||||
filter_type=TaskColumnFilter,
|
||||
default_columns=DEFAULT_TASK_COLUMNS,
|
||||
search_columns=["task_type", "task_key", "task_name", "status", "scope"],
|
||||
list_field_name="tasks",
|
||||
output_list_schema=TaskList,
|
||||
all_columns=ALL_TASK_COLUMNS,
|
||||
sortable_columns=TASK_SORTABLE_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_tasks.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Tasks listed: count=%s, total_count=%s"
|
||||
% (
|
||||
len(result.tasks) if hasattr(result, "tasks") else 0,
|
||||
getattr(result, "total_count", None),
|
||||
)
|
||||
)
|
||||
columns_to_filter = result.columns_requested
|
||||
await ctx.debug(
|
||||
"Applying field filtering via serialization context: columns=%s"
|
||||
% (columns_to_filter,)
|
||||
)
|
||||
with event_logger.log_context(action="mcp.list_tasks.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Task listing failed: page=%s, error=%s, error_type=%s"
|
||||
% (request.page, str(e), type(e).__name__)
|
||||
)
|
||||
raise
|
||||
16
tests/unit_tests/mcp_service/action_log/__init__.py
Normal file
16
tests/unit_tests/mcp_service/action_log/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/action_log/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/action_log/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,221 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for list_action_logs and get_action_log_info MCP tools."""
|
||||
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.action_log.schemas import (
|
||||
ActionLogFilter,
|
||||
ListActionLogsRequest,
|
||||
)
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
def create_mock_log(
|
||||
log_id: int = 1,
|
||||
action: str = "log",
|
||||
user_id: int = 1,
|
||||
dashboard_id: int | None = None,
|
||||
slice_id: int | None = None,
|
||||
json_payload: str | None = None,
|
||||
dttm: datetime | None = None,
|
||||
) -> MagicMock:
|
||||
log = MagicMock()
|
||||
log.id = log_id
|
||||
log.action = action
|
||||
log.user_id = user_id
|
||||
log.dashboard_id = dashboard_id
|
||||
log.slice_id = slice_id
|
||||
log.json = json_payload or '{"event_name": "explore_slice"}'
|
||||
log.dttm = dttm or datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
|
||||
log.duration_ms = None
|
||||
log.referrer = None
|
||||
return log
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
from unittest.mock import Mock
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestActionLogFilterSchema:
|
||||
def test_valid_filter_columns_accepted(self):
|
||||
for col in ("action", "user_id", "dashboard_id", "slice_id", "dttm"):
|
||||
f = ActionLogFilter(col=col, opr="eq", value="test")
|
||||
assert f.col == col
|
||||
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ActionLogFilter(col="not_a_column", opr="eq", value="x")
|
||||
|
||||
def test_created_by_fk_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
ActionLogFilter(col="created_by_fk", opr="eq", value=1)
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_action_logs_basic(mock_list, mcp_server):
|
||||
"""Basic listing returns structured response."""
|
||||
log = create_mock_log()
|
||||
mock_list.return_value = ([log], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_action_logs", {})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["action_logs"] is not None
|
||||
assert len(data["action_logs"]) == 1
|
||||
assert data["action_logs"][0]["id"] == 1
|
||||
assert data["action_logs"][0]["action"] == "log"
|
||||
assert data["action_logs"][0]["user_id"] == 1
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_server):
|
||||
"""When no dttm filter is provided, a 7-day filter is injected automatically."""
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_action_logs", {})
|
||||
|
||||
# Verify list() was called with a dttm filter in column_operators
|
||||
call_kwargs = mock_list.call_args.kwargs
|
||||
col_operators = call_kwargs.get("column_operators", [])
|
||||
dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"]
|
||||
assert len(dttm_filters) == 1
|
||||
assert dttm_filters[0].opr == "gte"
|
||||
|
||||
# Verify the injected filter appears in the serialized filters_applied
|
||||
data = json.loads(result.content[0].text)
|
||||
filters_applied = data.get("filters_applied", [])
|
||||
dttm_applied = [f for f in filters_applied if f.get("col") == "dttm"]
|
||||
assert len(dttm_applied) == 1
|
||||
assert dttm_applied[0]["opr"] == "gte"
|
||||
assert isinstance(dttm_applied[0]["value"], str) # ISO string, not datetime
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_action_logs_explicit_dttm_filter_skips_default(
|
||||
mock_list, mcp_server
|
||||
):
|
||||
"""When a dttm filter is provided, the default 7-day filter is NOT injected."""
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
request = ListActionLogsRequest(
|
||||
filters=[{"col": "dttm", "opr": "gte", "value": "2020-01-01T00:00:00"}]
|
||||
)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool("list_action_logs", {"request": request.model_dump()})
|
||||
|
||||
call_kwargs = mock_list.call_args.kwargs
|
||||
col_operators = call_kwargs.get("column_operators", [])
|
||||
dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"]
|
||||
# Only the user-provided filter, not the injected default
|
||||
assert len(dttm_filters) == 1
|
||||
assert dttm_filters[0].value == "2020-01-01T00:00:00"
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_action_logs_default_sort_is_dttm_desc(mock_list, mcp_server):
|
||||
"""Default sort is dttm descending (most recent first)."""
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool("list_action_logs", {})
|
||||
|
||||
call_kwargs = mock_list.call_args.kwargs
|
||||
assert call_kwargs.get("order_column") == "dttm"
|
||||
assert call_kwargs.get("order_direction") == "desc"
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_action_logs_pagination(mock_list, mcp_server):
|
||||
"""Pagination metadata is correct."""
|
||||
logs = [create_mock_log(log_id=i) for i in range(1, 6)]
|
||||
mock_list.return_value = (logs, 20)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListActionLogsRequest(page=1, page_size=5)
|
||||
result = await client.call_tool(
|
||||
"list_action_logs", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 5
|
||||
assert data["total_count"] == 20
|
||||
assert data["page"] == 1
|
||||
assert data["page_size"] == 5
|
||||
assert data["has_next"] is True
|
||||
assert data["has_previous"] is False
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_action_log_info_basic(mock_find, mcp_server):
|
||||
"""get_action_log_info returns log details by integer ID."""
|
||||
log = create_mock_log(log_id=42, action="explore_chart", user_id=7)
|
||||
mock_find.return_value = log
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_action_log_info", {"request": {"identifier": 42}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 42
|
||||
assert data["action"] == "explore_chart"
|
||||
assert data["user_id"] == 7
|
||||
|
||||
|
||||
@patch("superset.daos.log.LogDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_action_log_info_not_found(mock_find, mcp_server):
|
||||
"""get_action_log_info returns error when log does not exist."""
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_action_log_info", {"request": {"identifier": 9999}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
16
tests/unit_tests/mcp_service/task/__init__.py
Normal file
16
tests/unit_tests/mcp_service/task/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/task/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/task/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
249
tests/unit_tests/mcp_service/task/tool/test_task_tools.py
Normal file
249
tests/unit_tests/mcp_service/task/tool/test_task_tools.py
Normal file
@@ -0,0 +1,249 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for list_tasks and get_task_info MCP tools."""
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.task.schemas import ListTasksRequest, TaskColumnFilter
|
||||
from superset.utils import json
|
||||
|
||||
SAMPLE_UUID = str(uuid.uuid4())
|
||||
|
||||
|
||||
def create_mock_task(
|
||||
task_id: int = 1,
|
||||
task_uuid: str | None = None,
|
||||
task_type: str = "sql_execution",
|
||||
task_key: str = "default-key",
|
||||
task_name: str | None = None,
|
||||
status: str = "success",
|
||||
scope: str = "private",
|
||||
changed_on: datetime | None = None,
|
||||
created_on: datetime | None = None,
|
||||
) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.id = task_id
|
||||
task.uuid = task_uuid or SAMPLE_UUID
|
||||
task.task_type = task_type
|
||||
task.task_key = task_key
|
||||
task.task_name = task_name
|
||||
task.status = status
|
||||
task.scope = scope
|
||||
task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc)
|
||||
task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
|
||||
return task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
from unittest.mock import Mock
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "testuser"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
class TestTaskColumnFilterSchema:
|
||||
def test_valid_filter_columns_accepted(self):
|
||||
for col in ("task_type", "status", "scope"):
|
||||
f = TaskColumnFilter(col=col, opr="eq", value="test")
|
||||
assert f.col == col
|
||||
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TaskColumnFilter(col="user_id", opr="eq", value=1)
|
||||
|
||||
def test_uuid_column_rejected(self):
|
||||
with pytest.raises(ValidationError):
|
||||
TaskColumnFilter(col="uuid", opr="eq", value="some-uuid")
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_basic(mock_list, mcp_server):
|
||||
"""Basic task listing returns structured response."""
|
||||
task = create_mock_task()
|
||||
mock_list.return_value = ([task], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_tasks", {})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tasks"] is not None
|
||||
assert len(data["tasks"]) == 1
|
||||
assert data["tasks"][0]["id"] == 1
|
||||
assert data["tasks"][0]["task_type"] == "sql_execution"
|
||||
assert data["tasks"][0]["status"] == "success"
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_with_status_filter(mock_list, mcp_server):
|
||||
"""Status filter is passed through to the DAO correctly."""
|
||||
task = create_mock_task(status="pending")
|
||||
mock_list.return_value = ([task], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTasksRequest(
|
||||
filters=[{"col": "status", "opr": "eq", "value": "pending"}]
|
||||
)
|
||||
result = await client.call_tool("list_tasks", {"request": request.model_dump()})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert len(data["tasks"]) == 1
|
||||
assert data["tasks"][0]["status"] == "pending"
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server):
|
||||
"""TaskDAO.list is called with base_filter (TaskFilter) applied via DAO class."""
|
||||
mock_list.return_value = ([], 0)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
await client.call_tool("list_tasks", {})
|
||||
|
||||
# Verify the DAO's list() is called — the TaskFilter scoping is enforced
|
||||
# by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically.
|
||||
assert mock_list.called
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_pagination(mock_list, mcp_server):
|
||||
"""Pagination metadata is correct."""
|
||||
tasks = [create_mock_task(task_id=i) for i in range(1, 4)]
|
||||
mock_list.return_value = (tasks, 30)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListTasksRequest(page=2, page_size=3)
|
||||
result = await client.call_tool("list_tasks", {"request": request.model_dump()})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["count"] == 3
|
||||
assert data["total_count"] == 30
|
||||
assert data["page"] == 2
|
||||
assert data["page_size"] == 3
|
||||
assert data["has_previous"] is True
|
||||
assert data["has_next"] is True
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_uuid_in_response(mock_list, mcp_server):
|
||||
"""Task UUID is returned as a string in the response."""
|
||||
task_uuid = str(uuid.uuid4())
|
||||
task = create_mock_task(task_uuid=task_uuid)
|
||||
mock_list.return_value = ([task], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_tasks", {})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["tasks"][0]["uuid"] == task_uuid
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_info_by_integer_id(mock_find, mcp_server):
|
||||
"""get_task_info resolves a task by integer ID."""
|
||||
task = create_mock_task(task_id=5, task_type="thumbnail", status="in_progress")
|
||||
mock_find.return_value = task
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("get_task_info", {"request": {"identifier": 5}})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 5
|
||||
assert data["task_type"] == "thumbnail"
|
||||
assert data["status"] == "in_progress"
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_info_by_uuid(mock_find, mcp_server):
|
||||
"""get_task_info resolves a task by UUID string."""
|
||||
task_uuid = str(uuid.uuid4())
|
||||
task = create_mock_task(task_id=10, task_uuid=task_uuid, status="success")
|
||||
mock_find.return_value = task
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_task_info", {"request": {"identifier": task_uuid}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 10
|
||||
assert data["status"] == "success"
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_task_info_not_found(mock_find, mcp_server):
|
||||
"""get_task_info returns error when task does not exist."""
|
||||
mock_find.return_value = None
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_task_info", {"request": {"identifier": 9999}}
|
||||
)
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.tasks.TaskDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_non_admin_sees_only_subscribed(mock_list, mcp_server):
|
||||
"""Non-admin users only receive tasks their subscriptions allow.
|
||||
|
||||
The subscription scoping is enforced by TaskDAO.base_filter = TaskFilter,
|
||||
which BaseDAO._apply_base_filter injects before each query. The MCP tool
|
||||
itself adds no extra filtering — it just delegates to TaskDAO.list(), which
|
||||
carries the filter class. This test confirms that:
|
||||
|
||||
1. list_tasks calls TaskDAO.list() (so the base_filter runs)
|
||||
2. Only items returned by that call appear in the response
|
||||
"""
|
||||
# Simulate DAO returning only the subscribed task
|
||||
subscribed_task = create_mock_task(task_id=42, status="pending")
|
||||
mock_list.return_value = ([subscribed_task], 1)
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_tasks", {})
|
||||
|
||||
data = json.loads(result.content[0].text)
|
||||
assert len(data["tasks"]) == 1
|
||||
assert data["tasks"][0]["id"] == 42
|
||||
# TaskDAO.list was called exactly once — base_filter is applied inside
|
||||
assert mock_list.call_count == 1
|
||||
Reference in New Issue
Block a user