mirror of
https://github.com/apache/superset.git
synced 2026-06-02 22:29:26 +00:00
feat(mcp): add list and get tools for saved queries and query history (#40346)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -162,6 +162,10 @@ SQL Lab Integration:
|
||||
- execute_sql: Execute SQL queries and get results (requires database_id and SQL access)
|
||||
- save_sql_query: Save a SQL query to Saved Queries list (requires write access)
|
||||
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
|
||||
- list_saved_queries: List saved SQL queries with filtering and search (1-based pagination)
|
||||
- get_saved_query_info: Get saved query details by ID or UUID
|
||||
- list_queries: List SQL query history with filtering and search (1-based pagination)
|
||||
- get_query_info: Get SQL query history details by ID
|
||||
|
||||
Schema Discovery:
|
||||
- get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters)
|
||||
@@ -668,6 +672,14 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
|
||||
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
generate_explore_link,
|
||||
)
|
||||
from superset.mcp_service.query.tool import ( # noqa: F401, E402
|
||||
get_query_info,
|
||||
list_queries,
|
||||
)
|
||||
from superset.mcp_service.saved_query.tool import ( # noqa: F401, E402
|
||||
get_saved_query_info,
|
||||
list_saved_queries,
|
||||
)
|
||||
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
|
||||
execute_sql,
|
||||
open_sql_lab_with_context,
|
||||
|
||||
16
superset/mcp_service/query/__init__.py
Normal file
16
superset/mcp_service/query/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
293
superset/mcp_service/query/schemas.py
Normal file
293
superset/mcp_service/query/schemas.py
Normal file
@@ -0,0 +1,293 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for query history-related responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import MAX_PAGE_SIZE
|
||||
from superset.mcp_service.privacy import filter_user_directory_fields
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_QUERY_COLUMNS = ["id", "status", "start_time", "database_id", "schema"]
|
||||
SORTABLE_QUERY_COLUMNS = [
|
||||
"id",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"status",
|
||||
"database_id",
|
||||
"changed_on",
|
||||
]
|
||||
ALL_QUERY_COLUMNS = [
|
||||
"id",
|
||||
"sql",
|
||||
"executed_sql",
|
||||
"status",
|
||||
"start_time",
|
||||
"end_time",
|
||||
"rows",
|
||||
"database_id",
|
||||
"schema",
|
||||
"catalog",
|
||||
"tab_name",
|
||||
"error_message",
|
||||
"client_id",
|
||||
"limit",
|
||||
"progress",
|
||||
"changed_on",
|
||||
"user_id",
|
||||
]
|
||||
|
||||
DEFAULT_QUERY_PAGE_SIZE = 25
|
||||
|
||||
|
||||
class QueryFilter(ColumnOperator):
|
||||
"""
|
||||
Filter object for query history listing.
|
||||
col: The column to filter on. Must be one of the allowed filter fields.
|
||||
opr: The operator to use. Must be one of the supported operators.
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal["status", "database_id", "schema", "user_id", "start_time"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
description="Operator to use.",
|
||||
)
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by (type depends on col and opr)"
|
||||
)
|
||||
|
||||
|
||||
class QueryInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Query ID")
|
||||
sql: str | None = Field(None, description="SQL query text as submitted")
|
||||
executed_sql: str | None = Field(
|
||||
None, description="Actual SQL executed after templating/CTAS rewriting"
|
||||
)
|
||||
status: str | None = Field(None, description="Query execution status")
|
||||
start_time: float | None = Field(
|
||||
None, description="Query start time (seconds since epoch)"
|
||||
)
|
||||
end_time: float | None = Field(
|
||||
None, description="Query end time (seconds since epoch)"
|
||||
)
|
||||
rows: int | None = Field(None, description="Number of rows returned or affected")
|
||||
database_id: int | None = Field(None, description="Database connection ID")
|
||||
schema: str | None = Field(None, description="Database schema name")
|
||||
catalog: str | None = Field(None, description="Database catalog name")
|
||||
tab_name: str | None = Field(None, description="SQL Lab tab name")
|
||||
error_message: str | None = Field(None, description="Error message if query failed")
|
||||
client_id: str | None = Field(None, description="Client-assigned query identifier")
|
||||
limit: int | None = Field(None, description="Row limit applied to the query")
|
||||
progress: int | None = Field(None, description="Query execution progress (0-100)")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
user_id: int | None = Field(None, description="ID of the user who ran the query")
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = filter_user_directory_fields(serializer(self))
|
||||
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class QueryList(BaseModel):
|
||||
queries: List[QueryInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Requested columns for the response",
|
||||
)
|
||||
columns_loaded: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that were actually loaded for each query",
|
||||
)
|
||||
columns_available: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="All columns available for selection via select_columns parameter",
|
||||
)
|
||||
sortable_columns: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that can be used with order_column parameter",
|
||||
)
|
||||
filters_applied: List[QueryFilter] = Field(
|
||||
default_factory=list,
|
||||
description="List of advanced filter dicts applied to the query.",
|
||||
)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListQueriesRequest(BaseModel):
|
||||
"""Request schema for list_queries."""
|
||||
|
||||
filters: Annotated[
|
||||
List[QueryFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (column, operator, value). Each "
|
||||
"filter is an object with 'col', 'opr', and 'value' "
|
||||
"properties. Cannot be used together with 'search'.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of columns to select. Defaults to common columns if not "
|
||||
"specified.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search string to match against query fields. "
|
||||
"Cannot be used together with 'filters'.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None,
|
||||
Field(default=None, description="Column to order results by"),
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(
|
||||
default="desc",
|
||||
description="Direction to order results ('asc' or 'desc')",
|
||||
),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number for pagination (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_QUERY_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[QueryFilter]:
|
||||
"""Accept both JSON string and list of objects."""
|
||||
return parse_json_or_model_list(v, QueryFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
"""Accept JSON array, list, or comma-separated string."""
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListQueriesRequest":
|
||||
"""Prevent using both search and filters simultaneously."""
|
||||
if self.search and self.filters:
|
||||
raise ValueError(
|
||||
"Cannot use both 'search' and 'filters' parameters simultaneously. "
|
||||
"Use either 'search' for text-based searching across multiple fields, "
|
||||
"or 'filters' for precise column-based filtering, but not both."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class QueryError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class GetQueryInfoRequest(BaseModel):
|
||||
"""Request schema for get_query_info with support for numeric ID only."""
|
||||
|
||||
identifier: Annotated[
|
||||
int,
|
||||
Field(description="Query ID (numeric)"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_query_object(query: Any) -> QueryInfo | None:
|
||||
if not query:
|
||||
return None
|
||||
|
||||
return QueryInfo(
|
||||
id=getattr(query, "id", None),
|
||||
sql=getattr(query, "sql", None),
|
||||
executed_sql=getattr(query, "executed_sql", None),
|
||||
status=getattr(query, "status", None),
|
||||
start_time=getattr(query, "start_time", None),
|
||||
end_time=getattr(query, "end_time", None),
|
||||
rows=getattr(query, "rows", None),
|
||||
database_id=getattr(query, "database_id", None),
|
||||
schema=getattr(query, "schema", None),
|
||||
catalog=getattr(query, "catalog", None),
|
||||
tab_name=getattr(query, "tab_name", None),
|
||||
error_message=getattr(query, "error_message", None),
|
||||
client_id=getattr(query, "client_id", None),
|
||||
limit=getattr(query, "limit", None),
|
||||
progress=getattr(query, "progress", None),
|
||||
changed_on=getattr(query, "changed_on", None),
|
||||
user_id=getattr(query, "user_id", None),
|
||||
)
|
||||
24
superset/mcp_service/query/tool/__init__.py
Normal file
24
superset/mcp_service/query/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_query_info import get_query_info
|
||||
from .list_queries import list_queries
|
||||
|
||||
__all__ = [
|
||||
"list_queries",
|
||||
"get_query_info",
|
||||
]
|
||||
122
superset/mcp_service/query/tool/get_query_info.py
Normal file
122
superset/mcp_service/query/tool/get_query_info.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get query info FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for getting detailed information
|
||||
about a specific SQL query from the query history.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.query.schemas import (
|
||||
GetQueryInfoRequest,
|
||||
QueryError,
|
||||
QueryInfo,
|
||||
serialize_query_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="Query",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get query info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_query_info(
|
||||
request: GetQueryInfoRequest, ctx: Context
|
||||
) -> QueryInfo | QueryError:
|
||||
"""Get SQL query history details by ID.
|
||||
|
||||
Returns query details including SQL text, execution status, timing,
|
||||
row count, and any error messages.
|
||||
|
||||
IMPORTANT FOR LLM CLIENTS:
|
||||
- Use numeric ID (e.g., 123)
|
||||
- To find a query ID, use the list_queries tool first
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{
|
||||
"identifier": 123
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving query information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_query_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=QueryDAO,
|
||||
output_schema=QueryInfo,
|
||||
error_schema=QueryError,
|
||||
serializer=serialize_query_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, QueryInfo):
|
||||
await ctx.info(
|
||||
"Query information retrieved successfully: "
|
||||
"query_id=%s, status=%s, database_id=%s"
|
||||
% (
|
||||
result.id,
|
||||
result.status,
|
||||
result.database_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Query retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Query information retrieval failed: identifier=%s, error=%s, "
|
||||
"error_type=%s"
|
||||
% (
|
||||
request.identifier,
|
||||
str(e),
|
||||
type(e).__name__,
|
||||
)
|
||||
)
|
||||
return QueryError(
|
||||
error="Failed to get query info",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
156
superset/mcp_service/query/tool/list_queries.py
Normal file
156
superset/mcp_service/query/tool/list_queries.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List queries FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for listing SQL query history
|
||||
with filtering, search, and pagination.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.query.schemas import (
|
||||
DEFAULT_QUERY_COLUMNS,
|
||||
ListQueriesRequest,
|
||||
QueryError,
|
||||
QueryFilter,
|
||||
QueryInfo,
|
||||
QueryList,
|
||||
serialize_query_object,
|
||||
SORTABLE_QUERY_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_QUERIES_REQUEST = ListQueriesRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="Query",
|
||||
annotations=ToolAnnotations(
|
||||
title="List queries",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_queries(
|
||||
request: ListQueriesRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> QueryList | QueryError:
|
||||
"""List SQL query history with filtering and search.
|
||||
|
||||
Returns recent queries executed by the current user (or all queries for
|
||||
admins), including SQL text, status, timing, and database information.
|
||||
Results are ordered by changed_on descending by default (start_time is not
|
||||
always populated for all query records).
|
||||
|
||||
Sortable columns for order_column: id, start_time, end_time, status,
|
||||
database_id, changed_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_queries")
|
||||
|
||||
request = request or _DEFAULT_LIST_QUERIES_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing queries: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
request.page,
|
||||
request.page_size,
|
||||
request.search,
|
||||
)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Query listing parameters: filters=%s, order_column=%s, "
|
||||
"order_direction=%s, select_columns=%s"
|
||||
% (
|
||||
request.filters,
|
||||
request.order_column,
|
||||
request.order_direction,
|
||||
request.select_columns,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.query import QueryDAO
|
||||
|
||||
def _serialize_query(obj: object, cols: list[str] | None) -> QueryInfo | None:
|
||||
return serialize_query_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=QueryDAO,
|
||||
output_schema=QueryInfo,
|
||||
item_serializer=_serialize_query,
|
||||
filter_type=QueryFilter,
|
||||
default_columns=DEFAULT_QUERY_COLUMNS,
|
||||
search_columns=["tab_name", "sql"],
|
||||
list_field_name="queries",
|
||||
output_list_schema=QueryList,
|
||||
all_columns=list(QueryInfo.model_fields.keys()),
|
||||
sortable_columns=SORTABLE_QUERY_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_queries.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column or "changed_on",
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Queries listed successfully: count=%s, total_count=%s, total_pages=%s"
|
||||
% (
|
||||
len(result.queries) if hasattr(result, "queries") else 0,
|
||||
getattr(result, "total_count", None),
|
||||
getattr(result, "total_pages", None),
|
||||
)
|
||||
)
|
||||
|
||||
columns_to_filter = result.columns_requested
|
||||
await ctx.debug(
|
||||
"Applying field filtering via serialization context: columns=%s"
|
||||
% (columns_to_filter,)
|
||||
)
|
||||
with event_logger.log_context(action="mcp.list_queries.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Query listing failed: page=%s, page_size=%s, error=%s, error_type=%s"
|
||||
% (
|
||||
request.page,
|
||||
request.page_size,
|
||||
str(e),
|
||||
type(e).__name__,
|
||||
)
|
||||
)
|
||||
raise
|
||||
16
superset/mcp_service/saved_query/__init__.py
Normal file
16
superset/mcp_service/saved_query/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
269
superset/mcp_service/saved_query/schemas.py
Normal file
269
superset/mcp_service/saved_query/schemas.py
Normal file
@@ -0,0 +1,269 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for saved query-related responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
field_validator,
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
)
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
|
||||
from superset.mcp_service.privacy import filter_user_directory_fields
|
||||
from superset.mcp_service.system.schemas import PaginationInfo
|
||||
from superset.mcp_service.utils.schema_utils import (
|
||||
parse_json_or_list,
|
||||
parse_json_or_model_list,
|
||||
)
|
||||
|
||||
DEFAULT_SAVED_QUERY_COLUMNS = ["id", "label", "db_id", "schema", "uuid"]
|
||||
SORTABLE_SAVED_QUERY_COLUMNS = [
|
||||
"id",
|
||||
"label",
|
||||
"db_id",
|
||||
"schema",
|
||||
"changed_on",
|
||||
"created_on",
|
||||
]
|
||||
ALL_SAVED_QUERY_COLUMNS = [
|
||||
"id",
|
||||
"label",
|
||||
"db_id",
|
||||
"schema",
|
||||
"catalog",
|
||||
"uuid",
|
||||
"sql",
|
||||
"description",
|
||||
"changed_on",
|
||||
"created_on",
|
||||
"last_run",
|
||||
]
|
||||
|
||||
|
||||
class SavedQueryFilter(ColumnOperator):
|
||||
"""
|
||||
Filter object for saved query listing.
|
||||
col: The column to filter on. Must be one of the allowed filter fields.
|
||||
opr: The operator to use. Must be one of the supported operators.
|
||||
value: The value to filter by (type depends on col and opr).
|
||||
"""
|
||||
|
||||
col: Literal["label", "db_id", "schema", "catalog", "created_by_fk"] = Field(
|
||||
...,
|
||||
description="Column to filter on.",
|
||||
)
|
||||
opr: ColumnOperatorEnum = Field(
|
||||
...,
|
||||
description="Operator to use.",
|
||||
)
|
||||
value: str | int | float | bool | List[str | int | float | bool] = Field(
|
||||
..., description="Value to filter by (type depends on col and opr)"
|
||||
)
|
||||
|
||||
|
||||
class SavedQueryInfo(BaseModel):
|
||||
id: int | None = Field(None, description="Saved query ID")
|
||||
uuid: str | None = Field(None, description="Saved query UUID")
|
||||
label: str | None = Field(None, description="Saved query label/name")
|
||||
sql: str | None = Field(None, description="SQL query text")
|
||||
db_id: int | None = Field(None, description="Database connection ID")
|
||||
schema: str | None = Field(None, description="Database schema name")
|
||||
catalog: str | None = Field(None, description="Database catalog name")
|
||||
description: str | None = Field(None, description="User-provided description")
|
||||
changed_on: str | datetime | None = Field(
|
||||
None, description="Last modification timestamp"
|
||||
)
|
||||
created_on: str | datetime | None = Field(None, description="Creation timestamp")
|
||||
last_run: str | datetime | None = Field(
|
||||
None, description="Timestamp of last execution"
|
||||
)
|
||||
model_config = ConfigDict(
|
||||
from_attributes=True,
|
||||
ser_json_timedelta="iso8601",
|
||||
populate_by_name=True,
|
||||
)
|
||||
|
||||
@model_serializer(mode="wrap")
|
||||
def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
|
||||
data = filter_user_directory_fields(serializer(self))
|
||||
|
||||
if info.context and isinstance(info.context, dict):
|
||||
select_columns = info.context.get("select_columns")
|
||||
if select_columns:
|
||||
requested_fields = set(select_columns)
|
||||
return {k: v for k, v in data.items() if k in requested_fields}
|
||||
|
||||
return data
|
||||
|
||||
|
||||
class SavedQueryList(BaseModel):
|
||||
saved_queries: List[SavedQueryInfo]
|
||||
count: int
|
||||
total_count: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
has_previous: bool
|
||||
has_next: bool
|
||||
columns_requested: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Requested columns for the response",
|
||||
)
|
||||
columns_loaded: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that were actually loaded for each saved query",
|
||||
)
|
||||
columns_available: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="All columns available for selection via select_columns parameter",
|
||||
)
|
||||
sortable_columns: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="Columns that can be used with order_column parameter",
|
||||
)
|
||||
filters_applied: List[SavedQueryFilter] = Field(
|
||||
default_factory=list,
|
||||
description="List of advanced filter dicts applied to the query.",
|
||||
)
|
||||
pagination: PaginationInfo | None = None
|
||||
timestamp: datetime | None = None
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListSavedQueriesRequest(BaseModel):
|
||||
"""Request schema for list_saved_queries."""
|
||||
|
||||
filters: Annotated[
|
||||
List[SavedQueryFilter],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of filter objects (column, operator, value). Each "
|
||||
"filter is an object with 'col', 'opr', and 'value' "
|
||||
"properties. Cannot be used together with 'search'.",
|
||||
),
|
||||
]
|
||||
select_columns: Annotated[
|
||||
List[str],
|
||||
Field(
|
||||
default_factory=list,
|
||||
description="List of columns to select. Defaults to common columns if not "
|
||||
"specified.",
|
||||
),
|
||||
]
|
||||
search: Annotated[
|
||||
str | None,
|
||||
Field(
|
||||
default=None,
|
||||
description="Text search string to match against saved query fields. "
|
||||
"Cannot be used together with 'filters'.",
|
||||
),
|
||||
]
|
||||
order_column: Annotated[
|
||||
str | None, Field(default=None, description="Column to order results by")
|
||||
]
|
||||
order_direction: Annotated[
|
||||
Literal["asc", "desc"],
|
||||
Field(
|
||||
default="desc", description="Direction to order results ('asc' or 'desc')"
|
||||
),
|
||||
]
|
||||
page: Annotated[
|
||||
PositiveInt,
|
||||
Field(default=1, description="Page number for pagination (1-based)"),
|
||||
]
|
||||
page_size: Annotated[
|
||||
int,
|
||||
Field(
|
||||
default=DEFAULT_PAGE_SIZE,
|
||||
gt=0,
|
||||
le=MAX_PAGE_SIZE,
|
||||
description=f"Number of items per page (max {MAX_PAGE_SIZE})",
|
||||
),
|
||||
]
|
||||
|
||||
@field_validator("filters", mode="before")
|
||||
@classmethod
|
||||
def parse_filters(cls, v: Any) -> List[SavedQueryFilter]:
|
||||
"""Accept both JSON string and list of objects."""
|
||||
return parse_json_or_model_list(v, SavedQueryFilter, "filters")
|
||||
|
||||
@field_validator("select_columns", mode="before")
|
||||
@classmethod
|
||||
def parse_columns(cls, v: Any) -> List[str]:
|
||||
"""Accept JSON array, list, or comma-separated string."""
|
||||
return parse_json_or_list(v, "select_columns")
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_search_and_filters(self) -> "ListSavedQueriesRequest":
|
||||
"""Prevent using both search and filters simultaneously."""
|
||||
if self.search and self.filters:
|
||||
raise ValueError(
|
||||
"Cannot use both 'search' and 'filters' parameters simultaneously. "
|
||||
"Use either 'search' for text-based searching across multiple fields, "
|
||||
"or 'filters' for precise column-based filtering, but not both."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class SavedQueryError(BaseModel):
|
||||
error: str = Field(..., description="Error message")
|
||||
error_type: str = Field(..., description="Type of error")
|
||||
timestamp: str | datetime | None = Field(None, description="Error timestamp")
|
||||
model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class GetSavedQueryInfoRequest(BaseModel):
|
||||
"""Request schema for get_saved_query_info with support for ID or UUID."""
|
||||
|
||||
identifier: Annotated[
|
||||
int | str,
|
||||
Field(description="Saved query identifier - can be numeric ID or UUID string"),
|
||||
]
|
||||
|
||||
|
||||
def serialize_saved_query_object(saved_query: Any) -> SavedQueryInfo | None:
|
||||
if not saved_query:
|
||||
return None
|
||||
|
||||
return SavedQueryInfo(
|
||||
id=getattr(saved_query, "id", None),
|
||||
uuid=str(getattr(saved_query, "uuid", ""))
|
||||
if getattr(saved_query, "uuid", None)
|
||||
else None,
|
||||
label=getattr(saved_query, "label", None),
|
||||
sql=getattr(saved_query, "sql", None),
|
||||
db_id=getattr(saved_query, "db_id", None),
|
||||
schema=getattr(saved_query, "schema", None),
|
||||
catalog=getattr(saved_query, "catalog", None),
|
||||
description=getattr(saved_query, "description", None),
|
||||
changed_on=getattr(saved_query, "changed_on", None),
|
||||
created_on=getattr(saved_query, "created_on", None),
|
||||
last_run=getattr(saved_query, "last_run", None),
|
||||
)
|
||||
24
superset/mcp_service/saved_query/tool/__init__.py
Normal file
24
superset/mcp_service/saved_query/tool/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .get_saved_query_info import get_saved_query_info
|
||||
from .list_saved_queries import list_saved_queries
|
||||
|
||||
__all__ = [
|
||||
"list_saved_queries",
|
||||
"get_saved_query_info",
|
||||
]
|
||||
129
superset/mcp_service/saved_query/tool/get_saved_query_info.py
Normal file
129
superset/mcp_service/saved_query/tool/get_saved_query_info.py
Normal file
@@ -0,0 +1,129 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get saved query info FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for getting detailed information
|
||||
about a specific saved SQL query.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelGetInfoCore
|
||||
from superset.mcp_service.saved_query.schemas import (
|
||||
GetSavedQueryInfoRequest,
|
||||
SavedQueryError,
|
||||
SavedQueryInfo,
|
||||
serialize_saved_query_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["discovery"],
|
||||
class_permission_name="SavedQuery",
|
||||
annotations=ToolAnnotations(
|
||||
title="Get saved query info",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def get_saved_query_info(
|
||||
request: GetSavedQueryInfoRequest, ctx: Context
|
||||
) -> SavedQueryInfo | SavedQueryError:
|
||||
"""Get saved query details by ID or UUID.
|
||||
|
||||
Returns the full saved query including SQL text, label, database,
|
||||
schema, and timestamps.
|
||||
|
||||
IMPORTANT FOR LLM CLIENTS:
|
||||
- Use numeric ID (e.g., 42) or UUID string (e.g., "a1b2c3d4-...")
|
||||
- To find a saved query ID, use the list_saved_queries tool first
|
||||
|
||||
Example usage:
|
||||
```json
|
||||
{
|
||||
"identifier": 42
|
||||
}
|
||||
```
|
||||
|
||||
Or with UUID:
|
||||
```json
|
||||
{
|
||||
"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
}
|
||||
```
|
||||
"""
|
||||
await ctx.info(
|
||||
"Retrieving saved query information: identifier=%s" % (request.identifier,)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.query import SavedQueryDAO
|
||||
|
||||
with event_logger.log_context(action="mcp.get_saved_query_info.lookup"):
|
||||
get_tool = ModelGetInfoCore(
|
||||
dao_class=SavedQueryDAO,
|
||||
output_schema=SavedQueryInfo,
|
||||
error_schema=SavedQueryError,
|
||||
serializer=serialize_saved_query_object,
|
||||
supports_slug=False,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
result = get_tool.run_tool(request.identifier)
|
||||
|
||||
if isinstance(result, SavedQueryInfo):
|
||||
await ctx.info(
|
||||
"Saved query information retrieved successfully: "
|
||||
"saved_query_id=%s, label=%s, db_id=%s"
|
||||
% (
|
||||
result.id,
|
||||
result.label,
|
||||
result.db_id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
await ctx.warning(
|
||||
"Saved query retrieval failed: error_type=%s, error=%s"
|
||||
% (result.error_type, result.error)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Saved query information retrieval failed: identifier=%s, error=%s, "
|
||||
"error_type=%s"
|
||||
% (
|
||||
request.identifier,
|
||||
str(e),
|
||||
type(e).__name__,
|
||||
)
|
||||
)
|
||||
return SavedQueryError(
|
||||
error="Failed to get saved query info",
|
||||
error_type="InternalError",
|
||||
timestamp=datetime.now(timezone.utc),
|
||||
)
|
||||
158
superset/mcp_service/saved_query/tool/list_saved_queries.py
Normal file
158
superset/mcp_service/saved_query/tool/list_saved_queries.py
Normal file
@@ -0,0 +1,158 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List saved queries FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for listing saved SQL queries
|
||||
with filtering, search, and pagination.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from superset_core.mcp.decorators import tool, ToolAnnotations
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.mcp_core import ModelListCore
|
||||
from superset.mcp_service.saved_query.schemas import (
|
||||
DEFAULT_SAVED_QUERY_COLUMNS,
|
||||
ListSavedQueriesRequest,
|
||||
SavedQueryError,
|
||||
SavedQueryFilter,
|
||||
SavedQueryInfo,
|
||||
SavedQueryList,
|
||||
serialize_saved_query_object,
|
||||
SORTABLE_SAVED_QUERY_COLUMNS,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_LIST_SAVED_QUERIES_REQUEST = ListSavedQueriesRequest()
|
||||
|
||||
|
||||
@tool(
|
||||
tags=["core"],
|
||||
class_permission_name="SavedQuery",
|
||||
annotations=ToolAnnotations(
|
||||
title="List saved queries",
|
||||
readOnlyHint=True,
|
||||
destructiveHint=False,
|
||||
),
|
||||
)
|
||||
async def list_saved_queries(
|
||||
request: ListSavedQueriesRequest | None = None,
|
||||
ctx: Context | None = None,
|
||||
) -> SavedQueryList | SavedQueryError:
|
||||
"""List saved SQL queries with filtering and search.
|
||||
|
||||
Returns saved queries owned by the current user, including label, SQL,
|
||||
database ID, and schema.
|
||||
|
||||
Sortable columns for order_column: id, label, db_id, schema,
|
||||
changed_on, created_on
|
||||
"""
|
||||
if ctx is None:
|
||||
raise RuntimeError("FastMCP context is required for list_saved_queries")
|
||||
|
||||
request = request or _DEFAULT_LIST_SAVED_QUERIES_REQUEST.model_copy(deep=True)
|
||||
|
||||
await ctx.info(
|
||||
"Listing saved queries: page=%s, page_size=%s, search=%s"
|
||||
% (
|
||||
request.page,
|
||||
request.page_size,
|
||||
request.search,
|
||||
)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Saved query listing parameters: filters=%s, order_column=%s, "
|
||||
"order_direction=%s, select_columns=%s"
|
||||
% (
|
||||
request.filters,
|
||||
request.order_column,
|
||||
request.order_direction,
|
||||
request.select_columns,
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
from superset.daos.query import SavedQueryDAO
|
||||
|
||||
def _serialize_saved_query(
|
||||
obj: object, cols: list[str] | None
|
||||
) -> SavedQueryInfo | None:
|
||||
return serialize_saved_query_object(obj)
|
||||
|
||||
list_tool = ModelListCore(
|
||||
dao_class=SavedQueryDAO,
|
||||
output_schema=SavedQueryInfo,
|
||||
item_serializer=_serialize_saved_query,
|
||||
filter_type=SavedQueryFilter,
|
||||
default_columns=DEFAULT_SAVED_QUERY_COLUMNS,
|
||||
search_columns=["label", "description", "sql"],
|
||||
list_field_name="saved_queries",
|
||||
output_list_schema=SavedQueryList,
|
||||
all_columns=list(SavedQueryInfo.model_fields.keys()),
|
||||
sortable_columns=SORTABLE_SAVED_QUERY_COLUMNS,
|
||||
logger=logger,
|
||||
)
|
||||
|
||||
with event_logger.log_context(action="mcp.list_saved_queries.query"):
|
||||
result = list_tool.run_tool(
|
||||
filters=request.filters,
|
||||
search=request.search,
|
||||
select_columns=request.select_columns,
|
||||
order_column=request.order_column,
|
||||
order_direction=request.order_direction,
|
||||
page=max(request.page - 1, 0),
|
||||
page_size=request.page_size,
|
||||
)
|
||||
|
||||
await ctx.info(
|
||||
"Saved queries listed successfully: count=%s, total_count=%s, "
|
||||
"total_pages=%s"
|
||||
% (
|
||||
len(result.saved_queries) if hasattr(result, "saved_queries") else 0,
|
||||
getattr(result, "total_count", None),
|
||||
getattr(result, "total_pages", None),
|
||||
)
|
||||
)
|
||||
|
||||
columns_to_filter = result.columns_requested
|
||||
await ctx.debug(
|
||||
"Applying field filtering via serialization context: columns=%s"
|
||||
% (columns_to_filter,)
|
||||
)
|
||||
with event_logger.log_context(action="mcp.list_saved_queries.serialization"):
|
||||
return result.model_dump(
|
||||
mode="json",
|
||||
context={"select_columns": columns_to_filter},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Saved query listing failed: page=%s, page_size=%s, error=%s, "
|
||||
"error_type=%s"
|
||||
% (
|
||||
request.page,
|
||||
request.page_size,
|
||||
str(e),
|
||||
type(e).__name__,
|
||||
)
|
||||
)
|
||||
raise
|
||||
16
tests/unit_tests/mcp_service/query/__init__.py
Normal file
16
tests/unit_tests/mcp_service/query/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/query/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/query/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
328
tests/unit_tests/mcp_service/query/tool/test_query_tools.py
Normal file
328
tests/unit_tests/mcp_service/query/tool/test_query_tools.py
Normal file
@@ -0,0 +1,328 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.query.schemas import (
|
||||
ListQueriesRequest,
|
||||
QueryFilter,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestQueryFilterSchema:
|
||||
"""Tests for QueryFilter schema — filterable columns."""
|
||||
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
"""Columns not in the Literal set must be rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
QueryFilter(col="not_a_real_column", opr="eq", value="test")
|
||||
|
||||
def test_valid_status_filter_accepted(self):
|
||||
"""status is a valid filter column."""
|
||||
f = QueryFilter(col="status", opr="eq", value="success")
|
||||
assert f.col == "status"
|
||||
|
||||
def test_valid_database_id_filter_accepted(self):
|
||||
"""database_id is a valid filter column."""
|
||||
f = QueryFilter(col="database_id", opr="eq", value=1)
|
||||
assert f.col == "database_id"
|
||||
|
||||
def test_valid_schema_filter_accepted(self):
|
||||
"""schema is a valid filter column."""
|
||||
f = QueryFilter(col="schema", opr="eq", value="public")
|
||||
assert f.col == "schema"
|
||||
|
||||
def test_valid_user_id_filter_accepted(self):
|
||||
"""user_id filter enables admin-level filtering by user."""
|
||||
f = QueryFilter(col="user_id", opr="eq", value=42)
|
||||
assert f.col == "user_id"
|
||||
|
||||
def test_valid_start_time_filter_accepted(self):
|
||||
"""start_time filter enables time-range queries."""
|
||||
f = QueryFilter(col="start_time", opr="gt", value=1700000000.0)
|
||||
assert f.col == "start_time"
|
||||
|
||||
|
||||
def create_mock_query(
|
||||
query_id: int = 1,
|
||||
sql: str = "SELECT * FROM table",
|
||||
executed_sql: str | None = None,
|
||||
status: str = "success",
|
||||
start_time: float = 1700000000.0,
|
||||
end_time: float = 1700000001.0,
|
||||
rows: int = 100,
|
||||
database_id: int = 1,
|
||||
schema: str = "public",
|
||||
catalog: str | None = None,
|
||||
tab_name: str = "SQL Lab 1",
|
||||
error_message: str | None = None,
|
||||
client_id: str = "abc123",
|
||||
user_id: int | None = 1,
|
||||
) -> MagicMock:
|
||||
"""Factory function to create mock query objects with sensible defaults."""
|
||||
query = MagicMock()
|
||||
query.id = query_id
|
||||
query.sql = sql
|
||||
query.executed_sql = executed_sql
|
||||
query.status = status
|
||||
query.start_time = start_time
|
||||
query.end_time = end_time
|
||||
query.rows = rows
|
||||
query.database_id = database_id
|
||||
query.schema = schema
|
||||
query.catalog = catalog
|
||||
query.tab_name = tab_name
|
||||
query.error_message = error_message
|
||||
query.client_id = client_id
|
||||
query.limit = 1000
|
||||
query.progress = 100
|
||||
query.changed_on = None
|
||||
query.user_id = user_id
|
||||
return query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_basic(mock_list, mcp_server):
|
||||
"""Test basic query listing functionality."""
|
||||
query = create_mock_query()
|
||||
query._mapping = {
|
||||
"id": query.id,
|
||||
"sql": query.sql,
|
||||
"status": query.status,
|
||||
"start_time": query.start_time,
|
||||
"database_id": query.database_id,
|
||||
"schema": query.schema,
|
||||
}
|
||||
mock_list.return_value = ([query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListQueriesRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["queries"] is not None
|
||||
assert len(data["queries"]) == 1
|
||||
assert data["queries"][0]["id"] == 1
|
||||
assert data["queries"][0]["status"] == "success"
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_with_status_filter(mock_list, mcp_server):
|
||||
"""Test query listing with status filter."""
|
||||
query = create_mock_query(status="failed", error_message="Syntax error")
|
||||
query._mapping = {
|
||||
"id": query.id,
|
||||
"sql": query.sql,
|
||||
"status": query.status,
|
||||
"error_message": query.error_message,
|
||||
}
|
||||
mock_list.return_value = ([query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListQueriesRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
filters=[
|
||||
{"col": "status", "opr": "eq", "value": "failed"},
|
||||
],
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["queries"] is not None
|
||||
assert len(data["queries"]) == 1
|
||||
assert data["queries"][0]["status"] == "failed"
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_default_page_size(mock_list, mcp_server):
|
||||
"""Test that default page size is 25 for query history."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_queries", {})
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["page_size"] == 25
|
||||
|
||||
|
||||
def test_list_queries_request_rejects_both_search_and_filters():
|
||||
"""Cannot use search and filters simultaneously."""
|
||||
with pytest.raises(ValidationError):
|
||||
ListQueriesRequest(
|
||||
search="test",
|
||||
filters=[{"col": "status", "opr": "eq", "value": "success"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_query_info_basic(mock_find, mcp_server):
|
||||
"""Test basic get query info functionality."""
|
||||
query = create_mock_query()
|
||||
mock_find.return_value = query
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_query_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["status"] == "success"
|
||||
assert data["database_id"] == 1
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_query_info_not_found(mock_find, mcp_server):
|
||||
"""Test get query info when query does not exist."""
|
||||
mock_find.return_value = None
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_query_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_empty(mock_list, mcp_server):
|
||||
"""Test query listing returns empty list when no results."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListQueriesRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["queries"] == []
|
||||
assert data["count"] == 0
|
||||
assert data["total_count"] == 0
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_pagination_info(mock_list, mcp_server):
|
||||
"""Test that pagination info is correctly returned."""
|
||||
queries = [create_mock_query(query_id=i) for i in range(1, 4)]
|
||||
for q in queries:
|
||||
q._mapping = {"id": q.id, "sql": q.sql, "status": q.status}
|
||||
mock_list.return_value = (queries, 100)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListQueriesRequest(page=1, page_size=3)
|
||||
result = await client.call_tool(
|
||||
"list_queries", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["total_count"] == 100
|
||||
assert data["page_size"] == 3
|
||||
assert data["has_next"] is True
|
||||
assert data["has_previous"] is False
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_default_order_is_changed_on_desc(mock_list, mcp_server):
|
||||
"""Test that default ordering is changed_on descending."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("list_queries", {})
|
||||
assert result.content is not None
|
||||
mock_list.assert_called_once()
|
||||
call_kwargs = mock_list.call_args
|
||||
assert call_kwargs.kwargs.get("order_column") == "changed_on"
|
||||
assert call_kwargs.kwargs.get("order_direction") == "desc"
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_select_columns_projects_fields(mock_list, mcp_server):
|
||||
"""select_columns limits which fields appear in each query result."""
|
||||
query = create_mock_query()
|
||||
query._mapping = {"id": query.id, "status": query.status}
|
||||
mock_list.return_value = ([query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListQueriesRequest(
|
||||
page=1, page_size=10, select_columns=["id", "status"]
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_queries", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["queries"] is not None
|
||||
q = data["queries"][0]
|
||||
assert set(q.keys()) == {"id", "status"}
|
||||
assert q["id"] == 1
|
||||
assert q["status"] == "success"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_queries_invalid_order_column_raises(mcp_server):
|
||||
"""order_column not in SORTABLE_QUERY_COLUMNS must be rejected."""
|
||||
request = ListQueriesRequest(page=1, page_size=10, order_column="tab_name")
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="Invalid order_column"):
|
||||
await client.call_tool("list_queries", {"request": request.model_dump()})
|
||||
|
||||
|
||||
@patch("superset.daos.query.QueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_query_info_internal_error(mock_find, mcp_server):
|
||||
"""When an unexpected exception is raised, get_query_info returns InternalError."""
|
||||
mock_find.side_effect = RuntimeError("unexpected db failure")
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_query_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "InternalError"
|
||||
assert data["error"] == "Failed to get query info"
|
||||
16
tests/unit_tests/mcp_service/saved_query/__init__.py
Normal file
16
tests/unit_tests/mcp_service/saved_query/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/mcp_service/saved_query/tool/__init__.py
Normal file
16
tests/unit_tests/mcp_service/saved_query/tool/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,337 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.app import mcp
|
||||
from superset.mcp_service.saved_query.schemas import (
|
||||
ListSavedQueriesRequest,
|
||||
SavedQueryFilter,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class TestSavedQueryFilterSchema:
|
||||
"""Tests for SavedQueryFilter schema — filterable columns."""
|
||||
|
||||
def test_invalid_filter_column_rejected(self):
|
||||
"""Columns not in the Literal set must be rejected."""
|
||||
with pytest.raises(ValidationError):
|
||||
SavedQueryFilter(col="not_a_real_column", opr="eq", value="test")
|
||||
|
||||
def test_user_id_is_rejected_as_filter_column(self):
|
||||
"""user_id is not a filter column; use created_by_fk instead."""
|
||||
with pytest.raises(ValidationError):
|
||||
SavedQueryFilter(col="user_id", opr="eq", value=1)
|
||||
|
||||
def test_valid_label_filter_accepted(self):
|
||||
"""label is a valid filter column."""
|
||||
f = SavedQueryFilter(col="label", opr="eq", value="my query")
|
||||
assert f.col == "label"
|
||||
|
||||
def test_valid_db_id_filter_accepted(self):
|
||||
"""db_id is a valid filter column."""
|
||||
f = SavedQueryFilter(col="db_id", opr="eq", value=1)
|
||||
assert f.col == "db_id"
|
||||
|
||||
def test_valid_schema_filter_accepted(self):
|
||||
"""schema is a valid filter column."""
|
||||
f = SavedQueryFilter(col="schema", opr="eq", value="public")
|
||||
assert f.col == "schema"
|
||||
|
||||
def test_valid_catalog_filter_accepted(self):
|
||||
"""catalog is a valid filter column."""
|
||||
f = SavedQueryFilter(col="catalog", opr="eq", value="my_catalog")
|
||||
assert f.col == "catalog"
|
||||
|
||||
def test_valid_created_by_fk_filter_accepted(self):
|
||||
"""created_by_fk enables filtering by the owner user ID."""
|
||||
f = SavedQueryFilter(col="created_by_fk", opr="eq", value=42)
|
||||
assert f.col == "created_by_fk"
|
||||
|
||||
|
||||
def create_mock_saved_query(
|
||||
saved_query_id: int = 1,
|
||||
label: str = "My Query",
|
||||
sql: str = "SELECT 1",
|
||||
db_id: int = 1,
|
||||
schema: str = "public",
|
||||
catalog: str | None = None,
|
||||
description: str = "Test query",
|
||||
uuid: str = "test-uuid-1234",
|
||||
last_run: datetime | None = None,
|
||||
) -> MagicMock:
|
||||
"""Factory function to create mock saved query objects with sensible defaults."""
|
||||
saved_query = MagicMock()
|
||||
saved_query.id = saved_query_id
|
||||
saved_query.label = label
|
||||
saved_query.sql = sql
|
||||
saved_query.db_id = db_id
|
||||
saved_query.schema = schema
|
||||
saved_query.catalog = catalog
|
||||
saved_query.description = description
|
||||
saved_query.uuid = uuid
|
||||
saved_query.changed_on = None
|
||||
saved_query.created_on = None
|
||||
saved_query.last_run = last_run
|
||||
return saved_query
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_server():
|
||||
return mcp
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_auth():
|
||||
"""Mock authentication for all tests."""
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
|
||||
mock_user = Mock()
|
||||
mock_user.id = 1
|
||||
mock_user.username = "admin"
|
||||
mock_get_user.return_value = mock_user
|
||||
yield mock_get_user
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_basic(mock_list, mcp_server):
|
||||
"""Test basic saved query listing functionality."""
|
||||
saved_query = create_mock_saved_query()
|
||||
saved_query._mapping = {
|
||||
"id": saved_query.id,
|
||||
"label": saved_query.label,
|
||||
"db_id": saved_query.db_id,
|
||||
"schema": saved_query.schema,
|
||||
"uuid": saved_query.uuid,
|
||||
}
|
||||
mock_list.return_value = ([saved_query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["saved_queries"] is not None
|
||||
assert len(data["saved_queries"]) == 1
|
||||
assert data["saved_queries"][0]["id"] == 1
|
||||
assert data["saved_queries"][0]["label"] == "My Query"
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_with_search(mock_list, mcp_server):
|
||||
"""Test saved query listing with search functionality."""
|
||||
saved_query = create_mock_saved_query(label="Production Query")
|
||||
saved_query._mapping = {
|
||||
"id": saved_query.id,
|
||||
"label": saved_query.label,
|
||||
}
|
||||
mock_list.return_value = ([saved_query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(page=1, page_size=10, search="Production")
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["saved_queries"] is not None
|
||||
assert len(data["saved_queries"]) == 1
|
||||
assert data["saved_queries"][0]["label"] == "Production Query"
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_with_filters(mock_list, mcp_server):
|
||||
"""Test saved query listing with filters."""
|
||||
saved_query = create_mock_saved_query(db_id=2)
|
||||
saved_query._mapping = {
|
||||
"id": saved_query.id,
|
||||
"label": saved_query.label,
|
||||
"db_id": saved_query.db_id,
|
||||
}
|
||||
mock_list.return_value = ([saved_query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(
|
||||
page=1,
|
||||
page_size=10,
|
||||
filters=[
|
||||
{"col": "db_id", "opr": "eq", "value": 2},
|
||||
],
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["saved_queries"] is not None
|
||||
assert len(data["saved_queries"]) == 1
|
||||
|
||||
|
||||
def test_list_saved_queries_request_rejects_both_search_and_filters():
|
||||
"""Cannot use search and filters simultaneously."""
|
||||
with pytest.raises(ValidationError):
|
||||
ListSavedQueriesRequest(
|
||||
search="test",
|
||||
filters=[{"col": "label", "opr": "eq", "value": "test"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_saved_query_info_basic(mock_find, mcp_server):
|
||||
"""Test basic get saved query info functionality."""
|
||||
saved_query = create_mock_saved_query()
|
||||
mock_find.return_value = saved_query
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_saved_query_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["label"] == "My Query"
|
||||
assert data["sql"] == "SELECT 1"
|
||||
assert data["db_id"] == 1
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_saved_query_info_not_found(mock_find, mcp_server):
|
||||
"""Test get saved query info when saved query does not exist."""
|
||||
mock_find.return_value = None
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_saved_query_info", {"request": {"identifier": 999}}
|
||||
)
|
||||
assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_empty(mock_list, mcp_server):
|
||||
"""Test saved query listing returns empty list when no results."""
|
||||
mock_list.return_value = ([], 0)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(page=1, page_size=10)
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["saved_queries"] == []
|
||||
assert data["count"] == 0
|
||||
assert data["total_count"] == 0
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_saved_query_info_by_uuid(mock_find, mcp_server):
|
||||
"""Test get saved query info by UUID string."""
|
||||
saved_query = create_mock_saved_query(uuid="a1b2c3d4-5678-90ab-cdef-1234567890ab")
|
||||
mock_find.return_value = saved_query
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_saved_query_info",
|
||||
{"request": {"identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"}},
|
||||
)
|
||||
assert result.content is not None
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["id"] == 1
|
||||
assert data["uuid"] == "a1b2c3d4-5678-90ab-cdef-1234567890ab"
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_pagination_info(mock_list, mcp_server):
|
||||
"""Test that pagination info is correctly returned."""
|
||||
saved_queries = [create_mock_saved_query(saved_query_id=i) for i in range(1, 4)]
|
||||
for sq in saved_queries:
|
||||
sq._mapping = {"id": sq.id, "label": sq.label}
|
||||
mock_list.return_value = (saved_queries, 25)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(page=1, page_size=3)
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["total_count"] == 25
|
||||
assert data["count"] == 3
|
||||
assert data["page_size"] == 3
|
||||
assert data["has_next"] is True
|
||||
assert data["has_previous"] is False
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.list")
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_select_columns_projects_fields(mock_list, mcp_server):
|
||||
"""select_columns limits which fields appear in each saved query result."""
|
||||
saved_query = create_mock_saved_query()
|
||||
saved_query._mapping = {"id": saved_query.id, "label": saved_query.label}
|
||||
mock_list.return_value = ([saved_query], 1)
|
||||
async with Client(mcp_server) as client:
|
||||
request = ListSavedQueriesRequest(
|
||||
page=1, page_size=10, select_columns=["id", "label"]
|
||||
)
|
||||
result = await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["saved_queries"] is not None
|
||||
sq = data["saved_queries"][0]
|
||||
assert set(sq.keys()) == {"id", "label"}
|
||||
assert sq["id"] == 1
|
||||
assert sq["label"] == "My Query"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_saved_queries_invalid_order_column_raises(mcp_server):
|
||||
"""order_column not in SORTABLE_SAVED_QUERY_COLUMNS must be rejected."""
|
||||
request = ListSavedQueriesRequest(page=1, page_size=10, order_column="sql")
|
||||
async with Client(mcp_server) as client:
|
||||
with pytest.raises(ToolError, match="Invalid order_column"):
|
||||
await client.call_tool(
|
||||
"list_saved_queries", {"request": request.model_dump()}
|
||||
)
|
||||
|
||||
|
||||
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_saved_query_info_internal_error(mock_find, mcp_server):
|
||||
"""Unexpected exception in get_saved_query_info returns InternalError."""
|
||||
mock_find.side_effect = RuntimeError("unexpected db failure")
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool(
|
||||
"get_saved_query_info", {"request": {"identifier": 1}}
|
||||
)
|
||||
data = json.loads(result.content[0].text)
|
||||
assert data["error_type"] == "InternalError"
|
||||
assert data["error"] == "Failed to get saved query info"
|
||||
Reference in New Issue
Block a user