From 2bbd529ab7cb79c4ae7f9be12e0afc2b9531bcc8 Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Wed, 20 May 2026 23:30:59 +0000 Subject: [PATCH] fix(mcp): add task_key/task_name to TaskInfo and strengthen test coverage - Add task_key and task_name fields to TaskInfo schema and ALL_TASK_COLUMNS; these are real Task model columns present in the REST API search_columns - Expand search_columns in list_tasks to include task_key and task_name - Strengthen test_list_action_logs_default_7day_filter_applied to also assert the injected filter appears in filters_applied with an ISO string value Co-Authored-By: Claude Sonnet 4.6 --- superset/mcp_service/task/schemas.py | 9 ++++++++- superset/mcp_service/task/tool/list_tasks.py | 3 ++- .../action_log/tool/test_action_log_tools.py | 10 +++++++++- .../mcp_service/task/tool/test_task_tools.py | 4 ++++ 4 files changed, 23 insertions(+), 3 deletions(-) diff --git a/superset/mcp_service/task/schemas.py b/superset/mcp_service/task/schemas.py index af5e7f6662d..c908abf19e5 100644 --- a/superset/mcp_service/task/schemas.py +++ b/superset/mcp_service/task/schemas.py @@ -45,6 +45,8 @@ ALL_TASK_COLUMNS: list[str] = [ "id", "uuid", "task_type", + "task_key", + "task_name", "status", "scope", "changed_on", @@ -75,6 +77,8 @@ class TaskInfo(BaseModel): id: int | None = Field(None, description="Task ID") uuid: str | None = Field(None, description="Task UUID") task_type: str | None = Field(None, description="Task type (e.g., sql_execution)") + task_key: str | None = Field(None, description="Task deduplication key") + task_name: str | None = Field(None, description="Human-readable task name") status: str | None = Field(None, description="Task status") scope: str | None = Field(None, description="Task scope (private/shared/system)") changed_on: str | datetime | None = Field( @@ -144,7 +148,8 @@ class ListTasksRequest(BaseModel): Field( default=None, description=( - "Text search string matched against task_type, status, and scope. " + "Text search string matched against task_type, task_key, " + "task_name, status, and scope. " "Cannot be used together with 'filters'." ), ), @@ -226,6 +231,8 @@ def serialize_task_object(task: Any) -> TaskInfo | None: id=getattr(task, "id", None), uuid=str(uuid_val) if uuid_val is not None else None, task_type=getattr(task, "task_type", None), + task_key=getattr(task, "task_key", None), + task_name=getattr(task, "task_name", None), status=getattr(task, "status", None), scope=getattr(task, "scope", None), changed_on=getattr(task, "changed_on", None), diff --git a/superset/mcp_service/task/tool/list_tasks.py b/superset/mcp_service/task/tool/list_tasks.py index a9ca6329145..1aa701dd4aa 100644 --- a/superset/mcp_service/task/tool/list_tasks.py +++ b/superset/mcp_service/task/tool/list_tasks.py @@ -62,6 +62,7 @@ async def list_tasks( Sortable columns for order_column: id, changed_on, created_on, status Filter columns: task_type, status, scope + Search columns (via search=): task_type, task_key, task_name, status, scope Common task_type values: sql_execution, thumbnail, report Common status values: pending, in_progress, success, failure, aborted @@ -94,7 +95,7 @@ async def list_tasks( item_serializer=_serialize, filter_type=TaskColumnFilter, default_columns=DEFAULT_TASK_COLUMNS, - search_columns=["task_type", "status", "scope"], + search_columns=["task_type", "task_key", "task_name", "status", "scope"], list_field_name="tasks", output_list_schema=TaskList, all_columns=ALL_TASK_COLUMNS, diff --git a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py index 0f57cdf56ff..89496ace78c 100644 --- a/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py +++ b/tests/unit_tests/mcp_service/action_log/tool/test_action_log_tools.py @@ -111,7 +111,7 @@ async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_serve mock_list.return_value = ([], 0) async with Client(mcp_server) as client: - await client.call_tool("list_action_logs", {}) + result = await client.call_tool("list_action_logs", {}) # Verify list() was called with a dttm filter in column_operators call_kwargs = mock_list.call_args.kwargs @@ -120,6 +120,14 @@ async def test_list_action_logs_default_7day_filter_applied(mock_list, mcp_serve assert len(dttm_filters) == 1 assert dttm_filters[0].opr == "gte" + # Verify the injected filter appears in the serialized filters_applied + data = json.loads(result.content[0].text) + filters_applied = data.get("filters_applied", []) + dttm_applied = [f for f in filters_applied if f.get("col") == "dttm"] + assert len(dttm_applied) == 1 + assert dttm_applied[0]["opr"] == "gte" + assert isinstance(dttm_applied[0]["value"], str) # ISO string, not datetime + @patch("superset.daos.log.LogDAO.list") @pytest.mark.asyncio diff --git a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py index 8cb616ff435..9257b2c002d 100644 --- a/tests/unit_tests/mcp_service/task/tool/test_task_tools.py +++ b/tests/unit_tests/mcp_service/task/tool/test_task_tools.py @@ -36,6 +36,8 @@ def create_mock_task( task_id: int = 1, task_uuid: str | None = None, task_type: str = "sql_execution", + task_key: str = "default-key", + task_name: str | None = None, status: str = "success", scope: str = "private", changed_on: datetime | None = None, @@ -45,6 +47,8 @@ def create_mock_task( task.id = task_id task.uuid = task_uuid or SAMPLE_UUID task.task_type = task_type + task.task_key = task_key + task.task_name = task_name task.status = status task.scope = scope task.changed_on = changed_on or datetime(2024, 1, 2, 10, 0, 0, tzinfo=timezone.utc)