Compare commits

...

6 Commits

Author SHA1 Message Date
Amin Ghadersohi
0f9b0eeb92 ci: trigger CI for fix 2026-05-21 19:24:08 +00:00
Amin Ghadersohi
0f926b6757 fix(mcp): use ActionLogFilter for injected default dttm filter
Pydantic v2 rejects a ColumnOperator instance when validating
list[ActionLogFilter] — it requires an exact instance or dict, not
a parent-class instance. The injected 7-day default dttm filter was
created as a plain ColumnOperator, causing every test_list_action_logs_*
call to fail with '1 validation error for ActionLogList'.

Fix: construct the default filter as ActionLogFilter (which is a
subclass of ColumnOperator), so it passes pydantic validation for
ActionLogList.filters_applied: list[ActionLogFilter] and is still
accepted everywhere ColumnOperator is expected.
2026-05-21 19:24:08 +00:00
Amin Ghadersohi
59b0914278 fix(mcp): add task_key/task_name to TaskInfo and strengthen test coverage
- Add task_key and task_name fields to TaskInfo schema and ALL_TASK_COLUMNS;
  these are real Task model columns present in the REST API search_columns
- Expand search_columns in list_tasks to include task_key and task_name
- Strengthen test_list_action_logs_default_7day_filter_applied to also
  assert the injected filter appears in filters_applied with an ISO string value

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 19:24:08 +00:00
Amin Ghadersohi
6a26ba70a9 fix(mcp): field filtering and search for action-log and task list tools
- Add model_serializer to ActionLogInfo and TaskInfo that drops
  non-requested fields from output when select_columns context is set,
  matching the DatabaseInfo pattern
- Switch list_action_logs and list_tasks to return model_dump with
  serialization context so only requested columns appear in responses
- Add search field + search-XOR-filters validator to
  ListActionLogsRequest and ListTasksRequest
- Pass search=request.search through to ModelListCore.run_tool()

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 19:24:08 +00:00
Amin Ghadersohi
ef38144f19 fix(mcp): convert dttm cutoff to ISO string so filters_applied validates
The injected 7-day default filter used a datetime object as the value,
but ActionLogFilter.value only allows str|int|float|bool|list. Pydantic
rejects the datetime when building the filters_applied list in
ActionLogList, causing a ValidationError on every call that triggered
the default filter.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 19:24:08 +00:00
Amin Ghadersohi
9a6c927eba feat(mcp): add list and get tools for action log and tasks
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-21 19:24:08 +00:00
17 changed files with 1603 additions and 0 deletions

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,243 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pydantic schemas for action-log MCP tools."""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
DEFAULT_LOG_COLUMNS: list[str] = ["id", "action", "user_id", "dttm"]
ALL_LOG_COLUMNS: list[str] = [
"id",
"action",
"user_id",
"dttm",
"dashboard_id",
"slice_id",
"json",
]
LOG_SORTABLE_COLUMNS: list[str] = ["id", "dttm"]
class ActionLogFilter(ColumnOperator):
"""Filter object for action-log listing.
col: Column to filter on.
opr: Operator to use.
value: Value to filter by.
"""
col: Literal["action", "user_id", "dashboard_id", "slice_id", "dttm"] = Field(
...,
description="Column to filter on.",
)
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
value: str | int | float | bool | list[str | int | float | bool] = Field(
..., description="Value to filter by"
)
class ActionLogInfo(BaseModel):
id: int | None = Field(None, description="Log entry ID")
action: str | None = Field(None, description="Action name")
user_id: int | None = Field(
None, description="ID of the user who performed the action"
)
dttm: str | datetime | None = Field(None, description="Timestamp of the action")
dashboard_id: int | None = Field(None, description="Associated dashboard ID")
slice_id: int | None = Field(None, description="Associated chart/slice ID")
json: str | None = Field(None, description="JSON payload of the action")
model_config = ConfigDict(
from_attributes=True,
ser_json_timedelta="iso8601",
populate_by_name=True,
)
def model_post_init(self, __context: Any) -> None:
if isinstance(self.dttm, datetime) and self.dttm.tzinfo is None:
from datetime import timezone
object.__setattr__(self, "dttm", self.dttm.replace(tzinfo=timezone.utc))
@model_serializer(mode="wrap")
def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]:
data = serializer(self)
if info.context and isinstance(info.context, dict):
select_columns = info.context.get("select_columns")
if select_columns:
requested_fields = set(select_columns)
return {k: v for k, v in data.items() if k in requested_fields}
return data
class ActionLogList(BaseModel):
action_logs: list[ActionLogInfo]
count: int
total_count: int
page: int
page_size: int
total_pages: int
has_previous: bool
has_next: bool
columns_requested: list[str] = Field(default_factory=list)
columns_loaded: list[str] = Field(default_factory=list)
columns_available: list[str] = Field(default_factory=list)
sortable_columns: list[str] = Field(default_factory=list)
filters_applied: list[ActionLogFilter] = Field(default_factory=list)
pagination: PaginationInfo | None = None
timestamp: datetime | None = None
model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListActionLogsRequest(BaseModel):
"""Request schema for list_action_logs."""
filters: Annotated[
list[ActionLogFilter],
Field(
default_factory=list,
description=(
"List of filter objects (col, opr, value). "
"Filter columns: action, user_id, dashboard_id, slice_id, dttm. "
"Cannot be used with 'search'."
),
),
]
select_columns: Annotated[
list[str],
Field(
default_factory=list,
description="Columns to return. Defaults to common columns.",
),
]
search: Annotated[
str | None,
Field(
default=None,
description=(
"Text search string matched against action. "
"Cannot be used together with 'filters'."
),
),
]
order_column: Annotated[
str | None,
Field(default=None, description="Column to sort by (default: dttm)"),
]
order_direction: Annotated[
Literal["asc", "desc"],
Field(default="desc", description="Sort direction ('asc' or 'desc')"),
]
page: Annotated[
PositiveInt,
Field(default=1, description="Page number (1-based)"),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> list[ActionLogFilter]:
return parse_json_or_model_list(v, ActionLogFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v: Any) -> list[str]:
return parse_json_or_list(v, "select_columns")
@model_validator(mode="after")
def validate_search_and_filters(self) -> "ListActionLogsRequest":
if self.search and self.filters:
raise ValueError(
"Cannot use both 'search' and 'filters' simultaneously. "
"Use 'search' for text matching on action, or 'filters' for "
"column-based filtering, but not both."
)
return self
class ActionLogError(BaseModel):
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Error type")
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
@classmethod
def create(cls, error: str, error_type: str) -> "ActionLogError":
from datetime import timezone
return cls(
error=error,
error_type=error_type,
timestamp=datetime.now(timezone.utc),
)
class GetActionLogInfoRequest(BaseModel):
"""Request schema for get_action_log_info (ID-only lookup)."""
identifier: Annotated[
int,
Field(description="Log entry ID (integer)"),
]
def serialize_action_log_object(log: Any) -> ActionLogInfo | None:
if not log:
return None
dttm = getattr(log, "dttm", None)
if isinstance(dttm, datetime) and dttm.tzinfo is None:
from datetime import timezone
dttm = dttm.replace(tzinfo=timezone.utc)
return ActionLogInfo(
id=getattr(log, "id", None),
action=getattr(log, "action", None),
user_id=getattr(log, "user_id", None),
dttm=dttm,
dashboard_id=getattr(log, "dashboard_id", None),
slice_id=getattr(log, "slice_id", None),
json=getattr(log, "json", None),
)

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_action_log_info import get_action_log_info
from .list_action_logs import list_action_logs
__all__ = [
"list_action_logs",
"get_action_log_info",
]

View File

@@ -0,0 +1,97 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Get action log info MCP tool."""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.action_log.schemas import (
ActionLogError,
ActionLogInfo,
GetActionLogInfoRequest,
serialize_action_log_object,
)
from superset.mcp_service.mcp_core import ModelGetInfoCore
logger = logging.getLogger(__name__)
@tool(
tags=["discovery"],
class_permission_name="Log",
annotations=ToolAnnotations(
title="Get action log info",
readOnlyHint=True,
destructiveHint=False,
),
)
async def get_action_log_info(
request: GetActionLogInfoRequest,
ctx: Context,
) -> ActionLogInfo | ActionLogError:
"""Get a single action log entry by its integer ID.
Returns the action, user_id, timestamp (dttm), dashboard_id, slice_id,
and JSON payload for the specified log record.
ADMIN-ONLY: This tool requires admin privileges.
Use list_action_logs to discover log IDs.
"""
await ctx.info("Retrieving action log: identifier=%s" % (request.identifier,))
try:
from superset.daos.log import LogDAO
with event_logger.log_context(action="mcp.get_action_log_info.lookup"):
get_tool = ModelGetInfoCore(
dao_class=LogDAO,
output_schema=ActionLogInfo,
error_schema=ActionLogError,
serializer=serialize_action_log_object,
supports_slug=False,
logger=logger,
)
result = get_tool.run_tool(request.identifier)
if isinstance(result, ActionLogInfo):
await ctx.info(
"Action log retrieved: id=%s, action=%s" % (result.id, result.action)
)
else:
await ctx.warning(
"Action log retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result
except Exception as e:
await ctx.error(
"Action log retrieval failed: identifier=%s, error=%s, error_type=%s"
% (request.identifier, str(e), type(e).__name__)
)
return ActionLogError(
error=f"Failed to get action log info: {str(e)}",
error_type="InternalError",
timestamp=datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,153 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""List action logs MCP tool."""
import logging
from datetime import datetime, timedelta, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.extensions import event_logger
from superset.mcp_service.action_log.schemas import (
ActionLogError,
ActionLogFilter,
ActionLogInfo,
ActionLogList,
ALL_LOG_COLUMNS,
DEFAULT_LOG_COLUMNS,
ListActionLogsRequest,
LOG_SORTABLE_COLUMNS,
serialize_action_log_object,
)
from superset.mcp_service.mcp_core import ModelListCore
logger = logging.getLogger(__name__)
_DEFAULT_LIST_ACTION_LOGS_REQUEST = ListActionLogsRequest()
@tool(
tags=["core"],
class_permission_name="Log",
annotations=ToolAnnotations(
title="List action logs",
readOnlyHint=True,
destructiveHint=False,
),
)
async def list_action_logs(
request: ListActionLogsRequest | None = None,
ctx: Context | None = None,
) -> ActionLogList | ActionLogError:
"""List Superset action logs with filtering and pagination.
Returns audit log entries recording user interactions with dashboards and
charts. Defaults to the last 7 days to avoid pulling large result sets.
ADMIN-ONLY: This tool requires admin privileges. Non-admin users will
receive a permission error.
Sortable columns for order_column: id, dttm
Filter columns: action, user_id, dashboard_id, slice_id, dttm
When no dttm filter is provided the tool automatically applies
dttm >= (now - 7 days). Add an explicit dttm filter to override.
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_action_logs")
request = request or _DEFAULT_LIST_ACTION_LOGS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing action logs: page=%s, page_size=%s" % (request.page, request.page_size)
)
await ctx.debug(
"Action log parameters: filters=%s, order_column=%s, order_direction=%s"
% (request.filters, request.order_column, request.order_direction)
)
try:
from superset.daos.log import LogDAO
# Inject default 7-day dttm filter unless caller already provides one
filters: list[ColumnOperator] = list(request.filters)
has_dttm_filter = any(getattr(f, "col", None) == "dttm" for f in filters)
if not has_dttm_filter:
cutoff = (datetime.now(timezone.utc) - timedelta(days=7)).isoformat()
default_filter = ActionLogFilter(
col="dttm",
opr=ColumnOperatorEnum.gte,
value=cutoff,
)
filters = [default_filter] + filters
await ctx.debug("Applied default 7-day dttm filter: cutoff=%s" % (cutoff,))
def _serialize(obj: object, cols: list[str] | None) -> ActionLogInfo | None:
return serialize_action_log_object(obj)
list_tool = ModelListCore(
dao_class=LogDAO,
output_schema=ActionLogInfo,
item_serializer=_serialize,
filter_type=ActionLogFilter,
default_columns=DEFAULT_LOG_COLUMNS,
search_columns=["action"],
list_field_name="action_logs",
output_list_schema=ActionLogList,
all_columns=ALL_LOG_COLUMNS,
sortable_columns=LOG_SORTABLE_COLUMNS,
logger=logger,
)
with event_logger.log_context(action="mcp.list_action_logs.query"):
result = list_tool.run_tool(
filters=filters,
search=request.search,
select_columns=request.select_columns,
order_column=request.order_column or "dttm",
order_direction=request.order_direction,
page=max(request.page - 1, 0),
page_size=request.page_size,
)
await ctx.info(
"Action logs listed: count=%s, total_count=%s"
% (
len(result.action_logs) if hasattr(result, "action_logs") else 0,
getattr(result, "total_count", None),
)
)
columns_to_filter = result.columns_requested
await ctx.debug(
"Applying field filtering via serialization context: columns=%s"
% (columns_to_filter,)
)
with event_logger.log_context(action="mcp.list_action_logs.serialization"):
return result.model_dump(
mode="json",
context={"select_columns": columns_to_filter},
)
except Exception as e:
await ctx.error(
"Action log listing failed: page=%s, error=%s, error_type=%s"
% (request.page, str(e), type(e).__name__)
)
raise

View File

@@ -602,6 +602,10 @@ warnings.filterwarnings(
# NOTE: Always add new prompt/resource imports here when creating new prompts/resources.
# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators.
# They register automatically on import, similar to tools.
from superset.mcp_service.action_log.tool import ( # noqa: F401, E402
get_action_log_info,
list_action_logs,
)
from superset.mcp_service.chart import ( # noqa: F401, E402
prompts as chart_prompts,
resources as chart_resources,
@@ -652,6 +656,10 @@ from superset.mcp_service.system.tool import ( # noqa: F401, E402
get_schema,
health_check,
)
from superset.mcp_service.task.tool import ( # noqa: F401, E402
get_task_info,
list_tasks,
)
def _remove_disabled_tools(disabled_tools: set[str]) -> None:

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,240 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Pydantic schemas for task MCP tools."""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
DEFAULT_TASK_COLUMNS: list[str] = ["id", "uuid", "task_type", "status", "changed_on"]
ALL_TASK_COLUMNS: list[str] = [
"id",
"uuid",
"task_type",
"task_key",
"task_name",
"status",
"scope",
"changed_on",
"created_on",
]
TASK_SORTABLE_COLUMNS: list[str] = ["id", "changed_on", "created_on", "status"]
class TaskColumnFilter(ColumnOperator):
"""Filter object for task listing.
col: Column to filter on.
opr: Operator to use.
value: Value to filter by.
"""
col: Literal["task_type", "status", "scope"] = Field(
...,
description="Column to filter on.",
)
opr: ColumnOperatorEnum = Field(..., description="Operator to use.")
value: str | int | float | bool | list[str | int | float | bool] = Field(
..., description="Value to filter by"
)
class TaskInfo(BaseModel):
id: int | None = Field(None, description="Task ID")
uuid: str | None = Field(None, description="Task UUID")
task_type: str | None = Field(None, description="Task type (e.g., sql_execution)")
task_key: str | None = Field(None, description="Task deduplication key")
task_name: str | None = Field(None, description="Human-readable task name")
status: str | None = Field(None, description="Task status")
scope: str | None = Field(None, description="Task scope (private/shared/system)")
changed_on: str | datetime | None = Field(
None, description="Last modification timestamp"
)
created_on: str | datetime | None = Field(None, description="Creation timestamp")
model_config = ConfigDict(
from_attributes=True,
ser_json_timedelta="iso8601",
populate_by_name=True,
)
@model_serializer(mode="wrap")
def _filter_fields_by_context(self, serializer: Any, info: Any) -> dict[str, Any]:
data = serializer(self)
if info.context and isinstance(info.context, dict):
select_columns = info.context.get("select_columns")
if select_columns:
requested_fields = set(select_columns)
return {k: v for k, v in data.items() if k in requested_fields}
return data
class TaskList(BaseModel):
tasks: list[TaskInfo]
count: int
total_count: int
page: int
page_size: int
total_pages: int
has_previous: bool
has_next: bool
columns_requested: list[str] = Field(default_factory=list)
columns_loaded: list[str] = Field(default_factory=list)
columns_available: list[str] = Field(default_factory=list)
sortable_columns: list[str] = Field(default_factory=list)
filters_applied: list[TaskColumnFilter] = Field(default_factory=list)
pagination: PaginationInfo | None = None
timestamp: datetime | None = None
model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListTasksRequest(BaseModel):
"""Request schema for list_tasks."""
filters: Annotated[
list[TaskColumnFilter],
Field(
default_factory=list,
description=(
"List of filter objects (col, opr, value). "
"Filter columns: task_type, status, scope. "
"Cannot be used with 'search'."
),
),
]
select_columns: Annotated[
list[str],
Field(
default_factory=list,
description="Columns to return. Defaults to common columns.",
),
]
search: Annotated[
str | None,
Field(
default=None,
description=(
"Text search string matched against task_type, task_key, "
"task_name, status, and scope. "
"Cannot be used together with 'filters'."
),
),
]
order_column: Annotated[
str | None,
Field(default=None, description="Column to sort by (default: changed_on)"),
]
order_direction: Annotated[
Literal["asc", "desc"],
Field(default="desc", description="Sort direction ('asc' or 'desc')"),
]
page: Annotated[
PositiveInt,
Field(default=1, description="Page number (1-based)"),
]
page_size: Annotated[
int,
Field(
default=DEFAULT_PAGE_SIZE,
gt=0,
le=MAX_PAGE_SIZE,
description=f"Items per page (max {MAX_PAGE_SIZE})",
),
]
@field_validator("filters", mode="before")
@classmethod
def parse_filters(cls, v: Any) -> list[TaskColumnFilter]:
return parse_json_or_model_list(v, TaskColumnFilter, "filters")
@field_validator("select_columns", mode="before")
@classmethod
def parse_columns(cls, v: Any) -> list[str]:
return parse_json_or_list(v, "select_columns")
@model_validator(mode="after")
def validate_search_and_filters(self) -> "ListTasksRequest":
if self.search and self.filters:
raise ValueError(
"Cannot use both 'search' and 'filters' simultaneously. "
"Use 'search' for text matching on task_type/status/scope, or "
"'filters' for column-based filtering, but not both."
)
return self
class TaskError(BaseModel):
error: str = Field(..., description="Error message")
error_type: str = Field(..., description="Error type")
timestamp: str | datetime | None = Field(None, description="Error timestamp")
model_config = ConfigDict(ser_json_timedelta="iso8601")
@classmethod
def create(cls, error: str, error_type: str) -> "TaskError":
from datetime import timezone
return cls(
error=error,
error_type=error_type,
timestamp=datetime.now(timezone.utc),
)
class GetTaskInfoRequest(BaseModel):
"""Request schema for get_task_info (ID or UUID lookup)."""
identifier: Annotated[
int | str,
Field(description="Task identifier — numeric ID or UUID string"),
]
def serialize_task_object(task: Any) -> TaskInfo | None:
if not task:
return None
uuid_val = getattr(task, "uuid", None)
return TaskInfo(
id=getattr(task, "id", None),
uuid=str(uuid_val) if uuid_val is not None else None,
task_type=getattr(task, "task_type", None),
task_key=getattr(task, "task_key", None),
task_name=getattr(task, "task_name", None),
status=getattr(task, "status", None),
scope=getattr(task, "scope", None),
changed_on=getattr(task, "changed_on", None),
created_on=getattr(task, "created_on", None),
)

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_task_info import get_task_info
from .list_tasks import list_tasks
__all__ = [
"list_tasks",
"get_task_info",
]

View File

@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Get task info MCP tool."""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.task.schemas import (
GetTaskInfoRequest,
serialize_task_object,
TaskError,
TaskInfo,
)
logger = logging.getLogger(__name__)
@tool(
tags=["discovery"],
class_permission_name="Task",
annotations=ToolAnnotations(
title="Get task info",
readOnlyHint=True,
destructiveHint=False,
),
)
async def get_task_info(
request: GetTaskInfoRequest,
ctx: Context,
) -> TaskInfo | TaskError:
"""Get details for a single async task by ID or UUID.
Returns task_type, status, scope, and timestamps for the specified task.
Non-admin users can only retrieve tasks they are subscribed to.
Use list_tasks to discover task IDs and UUIDs.
Example usage:
```json
{"identifier": 42}
```
Or with UUID:
```json
{"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"}
```
"""
await ctx.info("Retrieving task: identifier=%s" % (request.identifier,))
try:
from superset.daos.tasks import TaskDAO
with event_logger.log_context(action="mcp.get_task_info.lookup"):
# ModelGetInfoCore handles int ID and UUID string automatically.
# TaskDAO.base_filter (TaskFilter) enforces subscription-based access.
get_tool = ModelGetInfoCore(
dao_class=TaskDAO,
output_schema=TaskInfo,
error_schema=TaskError,
serializer=serialize_task_object,
supports_slug=False,
logger=logger,
)
result = get_tool.run_tool(request.identifier)
if isinstance(result, TaskInfo):
await ctx.info(
"Task retrieved: id=%s, task_type=%s, status=%s"
% (result.id, result.task_type, result.status)
)
else:
await ctx.warning(
"Task retrieval failed: error_type=%s, error=%s"
% (result.error_type, result.error)
)
return result
except Exception as e:
await ctx.error(
"Task retrieval failed: identifier=%s, error=%s, error_type=%s"
% (request.identifier, str(e), type(e).__name__)
)
return TaskError(
error=f"Failed to get task info: {str(e)}",
error_type="InternalError",
timestamp=datetime.now(timezone.utc),
)

View File

@@ -0,0 +1,140 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""List tasks MCP tool."""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.task.schemas import (
ALL_TASK_COLUMNS,
DEFAULT_TASK_COLUMNS,
ListTasksRequest,
serialize_task_object,
TASK_SORTABLE_COLUMNS,
TaskColumnFilter,
TaskError,
TaskInfo,
TaskList,
)
logger = logging.getLogger(__name__)
_DEFAULT_LIST_TASKS_REQUEST = ListTasksRequest()
@tool(
tags=["core"],
class_permission_name="Task",
annotations=ToolAnnotations(
title="List tasks",
readOnlyHint=True,
destructiveHint=False,
),
)
async def list_tasks(
request: ListTasksRequest | None = None,
ctx: Context | None = None,
) -> TaskList | TaskError:
"""List async tasks with filtering and pagination.
Returns tasks visible to the current user. Non-admin users only see tasks
they are subscribed to (task creators are auto-subscribed). Admins see all
tasks.
Sortable columns for order_column: id, changed_on, created_on, status
Filter columns: task_type, status, scope
Search columns (via search=): task_type, task_key, task_name, status, scope
Common task_type values: sql_execution, thumbnail, report
Common status values: pending, in_progress, success, failure, aborted
Common scope values: private, shared, system
"""
if ctx is None:
raise RuntimeError("FastMCP context is required for list_tasks")
request = request or _DEFAULT_LIST_TASKS_REQUEST.model_copy(deep=True)
await ctx.info(
"Listing tasks: page=%s, page_size=%s" % (request.page, request.page_size)
)
await ctx.debug(
"Task parameters: filters=%s, order_column=%s, order_direction=%s"
% (request.filters, request.order_column, request.order_direction)
)
try:
from superset.daos.tasks import TaskDAO
def _serialize(obj: object, cols: list[str] | None) -> TaskInfo | None:
return serialize_task_object(obj)
# TaskDAO.base_filter = TaskFilter automatically scopes results:
# non-admins only see their subscribed tasks; admins see all.
list_tool = ModelListCore(
dao_class=TaskDAO,
output_schema=TaskInfo,
item_serializer=_serialize,
filter_type=TaskColumnFilter,
default_columns=DEFAULT_TASK_COLUMNS,
search_columns=["task_type", "task_key", "task_name", "status", "scope"],
list_field_name="tasks",
output_list_schema=TaskList,
all_columns=ALL_TASK_COLUMNS,
sortable_columns=TASK_SORTABLE_COLUMNS,
logger=logger,
)
with event_logger.log_context(action="mcp.list_tasks.query"):
result = list_tool.run_tool(
filters=request.filters,
search=request.search,
select_columns=request.select_columns,
order_column=request.order_column,
order_direction=request.order_direction,
page=max(request.page - 1, 0),
page_size=request.page_size,
)
await ctx.info(
"Tasks listed: count=%s, total_count=%s"
% (
len(result.tasks) if hasattr(result, "tasks") else 0,
getattr(result, "total_count", None),
)
)
columns_to_filter = result.columns_requested
await ctx.debug(
"Applying field filtering via serialization context: columns=%s"
% (columns_to_filter,)
)
with event_logger.log_context(action="mcp.list_tasks.serialization"):
return result.model_dump(
mode="json",
context={"select_columns": columns_to_filter},
)
except Exception as e:
await ctx.error(
"Task listing failed: page=%s, error=%s, error_type=%s"
% (request.page, str(e), type(e).__name__)
)
raise

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for list_action_logs and get_action_log_info MCP tools."""
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import Client
from pydantic import ValidationError
from superset.mcp_service.action_log.schemas import (
ActionLogFilter,
ListActionLogsRequest,
)
from superset.mcp_service.app import mcp
from superset.utils import json
def create_mock_log(
log_id: int = 1,
action: str = "log",
user_id: int = 1,
dashboard_id: int | None = None,
slice_id: int | None = None,
json_payload: str | None = None,
dttm: datetime | None = None,
) -> MagicMock:
log = MagicMock()
log.id = log_id
log.action = action
log.user_id = user_id
log.dashboard_id = dashboard_id
log.slice_id = slice_id
log.json = json_payload or '{"event_name": "explore_slice"}'
log.dttm = dttm or datetime(2024, 1, 1, 12, 0, 0, tzinfo=timezone.utc)
log.duration_ms = None
log.referrer = None
return log
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
from unittest.mock import Mock
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
class TestActionLogFilterSchema:
def test_valid_filter_columns_accepted(self):
for col in ("action", "user_id", "dashboard_id", "slice_id", "dttm"):
f = ActionLogFilter(col=col, opr="eq", value="test")
assert f.col == col
def test_invalid_filter_column_rejected(self):
with pytest.raises(ValidationError):
ActionLogFilter(col="not_a_column", opr="eq", value="x")
def test_created_by_fk_rejected(self):
with pytest.raises(ValidationError):
ActionLogFilter(col="created_by_fk", opr="eq", value=1)
@patch("superset.daos.log.LogDAO.list")
@pytest.mark.asyncio
async def test_list_action_logs_basic(mock_list, mcp_server):
"""Basic listing returns structured response."""
log = create_mock_log()
mock_list.return_value = ([log], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_action_logs", {})
data = json.loads(result.content[0].text)
assert data["action_logs"] is not None
assert len(data["action_logs"]) == 1
assert data["action_logs"][0]["id"] == 1
assert data["action_logs"][0]["action"] == "log"
assert data["action_logs"][0]["user_id"] == 1
@patch("superset.daos.log.LogDAO.list")
@pytest.mark.asyncio
async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_server):
"""When no dttm filter is provided, a 7-day filter is injected automatically."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
result = await client.call_tool("list_action_logs", {})
# Verify list() was called with a dttm filter in column_operators
call_kwargs = mock_list.call_args.kwargs
col_operators = call_kwargs.get("column_operators", [])
dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"]
assert len(dttm_filters) == 1
assert dttm_filters[0].opr == "gte"
# Verify the injected filter appears in the serialized filters_applied
data = json.loads(result.content[0].text)
filters_applied = data.get("filters_applied", [])
dttm_applied = [f for f in filters_applied if f.get("col") == "dttm"]
assert len(dttm_applied) == 1
assert dttm_applied[0]["opr"] == "gte"
assert isinstance(dttm_applied[0]["value"], str) # ISO string, not datetime
@patch("superset.daos.log.LogDAO.list")
@pytest.mark.asyncio
async def test_list_action_logs_explicit_dttm_filter_skips_default(
mock_list, mcp_server
):
"""When a dttm filter is provided, the default 7-day filter is NOT injected."""
mock_list.return_value = ([], 0)
request = ListActionLogsRequest(
filters=[{"col": "dttm", "opr": "gte", "value": "2020-01-01T00:00:00"}]
)
async with Client(mcp_server) as client:
await client.call_tool("list_action_logs", {"request": request.model_dump()})
call_kwargs = mock_list.call_args.kwargs
col_operators = call_kwargs.get("column_operators", [])
dttm_filters = [f for f in col_operators if getattr(f, "col", None) == "dttm"]
# Only the user-provided filter, not the injected default
assert len(dttm_filters) == 1
assert dttm_filters[0].value == "2020-01-01T00:00:00"
@patch("superset.daos.log.LogDAO.list")
@pytest.mark.asyncio
async def test_list_action_logs_default_sort_is_dttm_desc(mock_list, mcp_server):
"""Default sort is dttm descending (most recent first)."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
await client.call_tool("list_action_logs", {})
call_kwargs = mock_list.call_args.kwargs
assert call_kwargs.get("order_column") == "dttm"
assert call_kwargs.get("order_direction") == "desc"
@patch("superset.daos.log.LogDAO.list")
@pytest.mark.asyncio
async def test_list_action_logs_pagination(mock_list, mcp_server):
"""Pagination metadata is correct."""
logs = [create_mock_log(log_id=i) for i in range(1, 6)]
mock_list.return_value = (logs, 20)
async with Client(mcp_server) as client:
request = ListActionLogsRequest(page=1, page_size=5)
result = await client.call_tool(
"list_action_logs", {"request": request.model_dump()}
)
data = json.loads(result.content[0].text)
assert data["count"] == 5
assert data["total_count"] == 20
assert data["page"] == 1
assert data["page_size"] == 5
assert data["has_next"] is True
assert data["has_previous"] is False
@patch("superset.daos.log.LogDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_action_log_info_basic(mock_find, mcp_server):
"""get_action_log_info returns log details by integer ID."""
log = create_mock_log(log_id=42, action="explore_chart", user_id=7)
mock_find.return_value = log
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_action_log_info", {"request": {"identifier": 42}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 42
assert data["action"] == "explore_chart"
assert data["user_id"] == 7
@patch("superset.daos.log.LogDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_action_log_info_not_found(mock_find, mcp_server):
"""get_action_log_info returns error when log does not exist."""
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_action_log_info", {"request": {"identifier": 9999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,249 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Unit tests for list_tasks and get_task_info MCP tools."""
import uuid
from datetime import datetime, timezone
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import Client
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.task.schemas import ListTasksRequest, TaskColumnFilter
from superset.utils import json
SAMPLE_UUID = str(uuid.uuid4())
def create_mock_task(
task_id: int = 1,
task_uuid: str | None = None,
task_type: str = "sql_execution",
task_key: str = "default-key",
task_name: str | None = None,
status: str = "success",
scope: str = "private",
changed_on: datetime | None = None,
created_on: datetime | None = None,
) -> MagicMock:
task = MagicMock()
task.id = task_id
task.uuid = task_uuid or SAMPLE_UUID
task.task_type = task_type
task.task_key = task_key
task.task_name = task_name
task.status = status
task.scope = scope
task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc)
task.created_on = created_on or datetime(2024, 1, 1, 9, 0, 0, tzinfo=timezone.utc)
return task
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
from unittest.mock import Mock
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "testuser"
mock_get_user.return_value = mock_user
yield mock_get_user
class TestTaskColumnFilterSchema:
def test_valid_filter_columns_accepted(self):
for col in ("task_type", "status", "scope"):
f = TaskColumnFilter(col=col, opr="eq", value="test")
assert f.col == col
def test_invalid_filter_column_rejected(self):
with pytest.raises(ValidationError):
TaskColumnFilter(col="user_id", opr="eq", value=1)
def test_uuid_column_rejected(self):
with pytest.raises(ValidationError):
TaskColumnFilter(col="uuid", opr="eq", value="some-uuid")
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_basic(mock_list, mcp_server):
"""Basic task listing returns structured response."""
task = create_mock_task()
mock_list.return_value = ([task], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_tasks", {})
data = json.loads(result.content[0].text)
assert data["tasks"] is not None
assert len(data["tasks"]) == 1
assert data["tasks"][0]["id"] == 1
assert data["tasks"][0]["task_type"] == "sql_execution"
assert data["tasks"][0]["status"] == "success"
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(mock_list, mcp_server):
"""Status filter is passed through to the DAO correctly."""
task = create_mock_task(status="pending")
mock_list.return_value = ([task], 1)
async with Client(mcp_server) as client:
request = ListTasksRequest(
filters=[{"col": "status", "opr": "eq", "value": "pending"}]
)
result = await client.call_tool("list_tasks", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert len(data["tasks"]) == 1
assert data["tasks"][0]["status"] == "pending"
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_taskfilter_scoping_is_applied(mock_list, mcp_server):
"""TaskDAO.list is called with base_filter (TaskFilter) applied via DAO class."""
mock_list.return_value = ([], 0)
async with Client(mcp_server) as client:
await client.call_tool("list_tasks", {})
# Verify the DAO's list() is called — the TaskFilter scoping is enforced
# by TaskDAO.base_filter = TaskFilter which BaseDAO applies automatically.
assert mock_list.called
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_pagination(mock_list, mcp_server):
"""Pagination metadata is correct."""
tasks = [create_mock_task(task_id=i) for i in range(1, 4)]
mock_list.return_value = (tasks, 30)
async with Client(mcp_server) as client:
request = ListTasksRequest(page=2, page_size=3)
result = await client.call_tool("list_tasks", {"request": request.model_dump()})
data = json.loads(result.content[0].text)
assert data["count"] == 3
assert data["total_count"] == 30
assert data["page"] == 2
assert data["page_size"] == 3
assert data["has_previous"] is True
assert data["has_next"] is True
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_uuid_in_response(mock_list, mcp_server):
"""Task UUID is returned as a string in the response."""
task_uuid = str(uuid.uuid4())
task = create_mock_task(task_uuid=task_uuid)
mock_list.return_value = ([task], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_tasks", {})
data = json.loads(result.content[0].text)
assert data["tasks"][0]["uuid"] == task_uuid
@patch("superset.daos.tasks.TaskDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_task_info_by_integer_id(mock_find, mcp_server):
"""get_task_info resolves a task by integer ID."""
task = create_mock_task(task_id=5, task_type="thumbnail", status="in_progress")
mock_find.return_value = task
async with Client(mcp_server) as client:
result = await client.call_tool("get_task_info", {"request": {"identifier": 5}})
data = json.loads(result.content[0].text)
assert data["id"] == 5
assert data["task_type"] == "thumbnail"
assert data["status"] == "in_progress"
@patch("superset.daos.tasks.TaskDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_task_info_by_uuid(mock_find, mcp_server):
"""get_task_info resolves a task by UUID string."""
task_uuid = str(uuid.uuid4())
task = create_mock_task(task_id=10, task_uuid=task_uuid, status="success")
mock_find.return_value = task
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_task_info", {"request": {"identifier": task_uuid}}
)
data = json.loads(result.content[0].text)
assert data["id"] == 10
assert data["status"] == "success"
@patch("superset.daos.tasks.TaskDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_task_info_not_found(mock_find, mcp_server):
"""get_task_info returns error when task does not exist."""
mock_find.return_value = None
async with Client(mcp_server) as client:
result = await client.call_tool(
"get_task_info", {"request": {"identifier": 9999}}
)
data = json.loads(result.content[0].text)
assert data["error_type"] == "not_found"
@patch("superset.daos.tasks.TaskDAO.list")
@pytest.mark.asyncio
async def test_list_tasks_non_admin_sees_only_subscribed(mock_list, mcp_server):
"""Non-admin users only receive tasks their subscriptions allow.
The subscription scoping is enforced by TaskDAO.base_filter = TaskFilter,
which BaseDAO._apply_base_filter injects before each query. The MCP tool
itself adds no extra filtering — it just delegates to TaskDAO.list(), which
carries the filter class. This test confirms that:
1. list_tasks calls TaskDAO.list() (so the base_filter runs)
2. Only items returned by that call appear in the response
"""
# Simulate DAO returning only the subscribed task
subscribed_task = create_mock_task(task_id=42, status="pending")
mock_list.return_value = ([subscribed_task], 1)
async with Client(mcp_server) as client:
result = await client.call_tool("list_tasks", {})
data = json.loads(result.content[0].text)
assert len(data["tasks"]) == 1
assert data["tasks"][0]["id"] == 42
# TaskDAO.list was called exactly once — base_filter is applied inside
assert mock_list.call_count == 1