feat(mcp): add list and get tools for saved queries and query history (#40346)

Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-05-29 19:16:11 -07:00
committed by GitHub
parent f614863ed7
commit a69bbcb044
17 changed files with 1948 additions and 0 deletions

View File

@@ -162,6 +162,10 @@ SQL Lab Integration:
- execute_sql: Execute SQL queries and get results (requires database_id and SQL access)
- save_sql_query: Save a SQL query to Saved Queries list (requires write access)
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
- list_saved_queries: List saved SQL queries with filtering and search (1-based pagination)
- get_saved_query_info: Get saved query details by ID or UUID
- list_queries: List SQL query history with filtering and search (1-based pagination)
- get_query_info: Get SQL query history details by ID
Schema Discovery:
- get_schema: Get schema metadata for chart/dataset/dashboard (columns, filters)
@@ -668,6 +672,14 @@ from superset.mcp_service.dataset.tool import ( # noqa: F401, E402
from superset.mcp_service.explore.tool import ( # noqa: F401, E402
generate_explore_link,
)
from superset.mcp_service.query.tool import ( # noqa: F401, E402
get_query_info,
list_queries,
)
from superset.mcp_service.saved_query.tool import ( # noqa: F401, E402
get_saved_query_info,
list_saved_queries,
)
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
execute_sql,
open_sql_lab_with_context,

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,293 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for query history-related responses
"""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
# Column set handed to the list tool as its default selection
# (list_queries passes it as ``default_columns``).
DEFAULT_QUERY_COLUMNS = [
    "id",
    "status",
    "start_time",
    "database_id",
    "schema",
]

# Columns list_queries advertises for ``order_column``.
SORTABLE_QUERY_COLUMNS = [
    "id", "start_time", "end_time",
    "status", "database_id", "changed_on",
]

# Every selectable query-history column; same names and order as the
# fields declared on QueryInfo.
ALL_QUERY_COLUMNS = [
    # identity and SQL text
    "id", "sql", "executed_sql",
    # execution outcome
    "status", "start_time", "end_time", "rows",
    # where the query ran
    "database_id", "schema", "catalog", "tab_name",
    # diagnostics and bookkeeping
    "error_message", "client_id", "limit", "progress",
    "changed_on", "user_id",
]

# Default ``page_size`` for ListQueriesRequest.
DEFAULT_QUERY_PAGE_SIZE = 25
class QueryFilter(ColumnOperator):
    """
    Filter object for query history listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Only these query-history columns may be filtered on; any other name
    # fails validation because of the Literal type.
    col: Literal["status", "database_id", "schema", "user_id", "start_time"] = Field(
        ...,
        description="Column to filter on.",
    )
    # Comparison operator, limited to the members of ColumnOperatorEnum.
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use.",
    )
    # Either a single scalar or a list of scalars.
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )
class QueryInfo(BaseModel):
    # Response model for one query-history record. Every field is optional
    # so partially populated records still validate.

    id: int | None = Field(default=None, description="Query ID")
    sql: str | None = Field(default=None, description="SQL query text as submitted")
    executed_sql: str | None = Field(
        default=None,
        description="Actual SQL executed after templating/CTAS rewriting",
    )
    status: str | None = Field(default=None, description="Query execution status")
    start_time: float | None = Field(
        default=None, description="Query start time (seconds since epoch)"
    )
    end_time: float | None = Field(
        default=None, description="Query end time (seconds since epoch)"
    )
    rows: int | None = Field(
        default=None, description="Number of rows returned or affected"
    )
    database_id: int | None = Field(
        default=None, description="Database connection ID"
    )
    schema: str | None = Field(default=None, description="Database schema name")
    catalog: str | None = Field(default=None, description="Database catalog name")
    tab_name: str | None = Field(default=None, description="SQL Lab tab name")
    error_message: str | None = Field(
        default=None, description="Error message if query failed"
    )
    client_id: str | None = Field(
        default=None, description="Client-assigned query identifier"
    )
    limit: int | None = Field(
        default=None, description="Row limit applied to the query"
    )
    progress: int | None = Field(
        default=None, description="Query execution progress (0-100)"
    )
    changed_on: str | datetime | None = Field(
        default=None, description="Last modification timestamp"
    )
    user_id: int | None = Field(
        default=None, description="ID of the user who ran the query"
    )

    model_config = ConfigDict(
        from_attributes=True,
        ser_json_timedelta="iso8601",
        populate_by_name=True,
    )

    @model_serializer(mode="wrap")
    def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
        """Serialize the model, then optionally narrow it to requested columns.

        The default serializer output is first passed through
        ``filter_user_directory_fields``. If the serialization context is a
        non-empty dict carrying a truthy ``select_columns`` entry, only keys
        named there are kept; otherwise the dict is returned as-is.
        """
        payload = filter_user_directory_fields(serializer(self))

        context = info.context
        if not context or not isinstance(context, dict):
            return payload

        wanted = context.get("select_columns")
        if not wanted:
            return payload

        keep = set(wanted)
        return {key: value for key, value in payload.items() if key in keep}
class QueryList(BaseModel):
    # Paginated response envelope for list_queries: the page of QueryInfo
    # items plus paging and column metadata.

    # The items on this page.
    queries: List[QueryInfo]
    # Paging bookkeeping. Values are filled in by whatever builds the list;
    # the exact meaning of count vs total_count is not defined in this module.
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    # Column metadata echoed back to the caller.
    columns_requested: List[str] = Field(
        default_factory=list,
        description="Requested columns for the response",
    )
    columns_loaded: List[str] = Field(
        default_factory=list,
        description="Columns that were actually loaded for each query",
    )
    columns_available: List[str] = Field(
        default_factory=list,
        description="All columns available for selection via select_columns parameter",
    )
    sortable_columns: List[str] = Field(
        default_factory=list,
        description="Columns that can be used with order_column parameter",
    )
    filters_applied: List[QueryFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    # Optional extras; both default to None.
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None

    model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListQueriesRequest(BaseModel):
    """Request schema for list_queries."""

    # --- what to return -------------------------------------------------
    filters: Annotated[
        List[QueryFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=list,
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search string to match against query fields. "
            "Cannot be used together with 'filters'.",
        ),
    ]

    # --- ordering -------------------------------------------------------
    order_column: Annotated[
        str | None, Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="desc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]

    # --- paging (1-based page; size capped at MAX_PAGE_SIZE) -------------
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        int,
        Field(
            default=DEFAULT_QUERY_PAGE_SIZE,
            gt=0,
            le=MAX_PAGE_SIZE,
            description=f"Number of items per page (max {MAX_PAGE_SIZE})",
        ),
    ]

    @field_validator("filters", mode="before")
    @classmethod
    def parse_filters(cls, value: Any) -> List[QueryFilter]:
        """Normalize ``filters`` from a JSON string or a list of objects."""
        parsed = parse_json_or_model_list(value, QueryFilter, "filters")
        return parsed

    @field_validator("select_columns", mode="before")
    @classmethod
    def parse_columns(cls, value: Any) -> List[str]:
        """Normalize ``select_columns`` from JSON, a list, or a CSV string."""
        parsed = parse_json_or_list(value, "select_columns")
        return parsed

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListQueriesRequest":
        """Reject requests that set both ``search`` and ``filters``."""
        if not (self.search and self.filters):
            return self
        raise ValueError(
            "Cannot use both 'search' and 'filters' parameters simultaneously. "
            "Use either 'search' for text-based searching across multiple fields, "
            "or 'filters' for precise column-based filtering, but not both."
        )
class QueryError(BaseModel):
    # Error payload for the query-history tools; get_query_info returns it
    # instead of raising.

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    # Optional; accepts either a datetime or a preformatted string.
    timestamp: str | datetime | None = Field(None, description="Error timestamp")

    model_config = ConfigDict(ser_json_timedelta="iso8601")
class GetQueryInfoRequest(BaseModel):
    """Request schema for get_query_info with support for numeric ID only."""

    # Required (no default) and typed as int only.
    identifier: Annotated[
        int,
        Field(description="Query ID (numeric)"),
    ]
def serialize_query_object(query: Any) -> QueryInfo | None:
    """Build a QueryInfo from an arbitrary query-like object.

    Returns None when ``query`` is falsy. Otherwise each QueryInfo field is
    read from the attribute of the same name, with None substituted for any
    attribute the object does not have.
    """
    if not query:
        return None

    attribute_names = (
        "id",
        "sql",
        "executed_sql",
        "status",
        "start_time",
        "end_time",
        "rows",
        "database_id",
        "schema",
        "catalog",
        "tab_name",
        "error_message",
        "client_id",
        "limit",
        "progress",
        "changed_on",
        "user_id",
    )
    values = {name: getattr(query, name, None) for name in attribute_names}
    return QueryInfo(**values)

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_query_info import get_query_info
from .list_queries import list_queries
# Public names of this package: the two query-history tools imported above.
__all__ = [
    "list_queries",
    "get_query_info",
]

View File

@@ -0,0 +1,122 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get query info FastMCP tool
This module contains the FastMCP tool for getting detailed information
about a specific SQL query from the query history.
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.query.schemas import (
GetQueryInfoRequest,
QueryError,
QueryInfo,
serialize_query_object,
)
# Module-level logger; get_query_info hands it to ModelGetInfoCore.
logger = logging.getLogger(__name__)
@tool(
    tags=["discovery"],
    class_permission_name="Query",
    annotations=ToolAnnotations(
        title="Get query info",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_query_info(
    request: GetQueryInfoRequest, ctx: Context
) -> QueryInfo | QueryError:
    """Get SQL query history details by ID.

    Returns query details including SQL text, execution status, timing,
    row count, and any error messages.

    IMPORTANT FOR LLM CLIENTS:
    - Use numeric ID (e.g., 123)
    - To find a query ID, use the list_queries tool first

    Example usage:
    ```json
    {
        "identifier": 123
    }
    ```
    """
    query_id = request.identifier
    await ctx.info("Retrieving query information: identifier=%s" % (query_id,))

    try:
        from superset.daos.query import QueryDAO

        with event_logger.log_context(action="mcp.get_query_info.lookup"):
            lookup = ModelGetInfoCore(
                dao_class=QueryDAO,
                output_schema=QueryInfo,
                error_schema=QueryError,
                serializer=serialize_query_object,
                supports_slug=False,
                logger=logger,
            )
            result = lookup.run_tool(query_id)

        # Anything other than a QueryInfo is the error schema: report and
        # hand it back unchanged.
        if not isinstance(result, QueryInfo):
            failure_args = (result.error_type, result.error)
            await ctx.warning(
                "Query retrieval failed: error_type=%s, error=%s" % failure_args
            )
            return result

        success_args = (result.id, result.status, result.database_id)
        await ctx.info(
            "Query information retrieved successfully: "
            "query_id=%s, status=%s, database_id=%s" % success_args
        )
        return result
    except Exception as exc:
        # Tool boundary: the details go to the client log channel and the
        # caller receives a generic error payload rather than an exception.
        error_args = (query_id, str(exc), type(exc).__name__)
        await ctx.error(
            "Query information retrieval failed: identifier=%s, error=%s, "
            "error_type=%s" % error_args
        )
        return QueryError(
            error="Failed to get query info",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )

View File

@@ -0,0 +1,156 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List queries FastMCP tool
This module contains the FastMCP tool for listing SQL query history
with filtering, search, and pagination.
"""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.query.schemas import (
DEFAULT_QUERY_COLUMNS,
ListQueriesRequest,
QueryError,
QueryFilter,
QueryInfo,
QueryList,
serialize_query_object,
SORTABLE_QUERY_COLUMNS,
)
# Module-level logger; list_queries hands it to ModelListCore.
logger = logging.getLogger(__name__)
# Request with every field at its declared default. list_queries deep-copies
# it when called without a request, so this instance is never mutated.
_DEFAULT_LIST_QUERIES_REQUEST = ListQueriesRequest()
@tool(
    tags=["core"],
    class_permission_name="Query",
    annotations=ToolAnnotations(
        title="List queries",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def list_queries(
    request: ListQueriesRequest | None = None,
    ctx: Context | None = None,
) -> QueryList | QueryError:
    """List SQL query history with filtering and search.

    Returns recent queries executed by the current user (or all queries for
    admins), including SQL text, status, timing, and database information.
    Results are ordered by changed_on descending by default (start_time is not
    always populated for all query records).

    Sortable columns for order_column: id, start_time, end_time, status,
    database_id, changed_on
    """
    if ctx is None:
        raise RuntimeError("FastMCP context is required for list_queries")

    # No request supplied: work on a private copy of the module default.
    if not request:
        request = _DEFAULT_LIST_QUERIES_REQUEST.model_copy(deep=True)

    summary_args = (request.page, request.page_size, request.search)
    await ctx.info(
        "Listing queries: page=%s, page_size=%s, search=%s" % summary_args
    )

    detail_args = (
        request.filters,
        request.order_column,
        request.order_direction,
        request.select_columns,
    )
    await ctx.debug(
        "Query listing parameters: filters=%s, order_column=%s, "
        "order_direction=%s, select_columns=%s" % detail_args
    )

    try:
        from superset.daos.query import QueryDAO

        list_tool = ModelListCore(
            dao_class=QueryDAO,
            output_schema=QueryInfo,
            # The serializer receives (obj, cols); cols is ignored here.
            item_serializer=lambda obj, cols: serialize_query_object(obj),
            filter_type=QueryFilter,
            default_columns=DEFAULT_QUERY_COLUMNS,
            search_columns=["tab_name", "sql"],
            list_field_name="queries",
            output_list_schema=QueryList,
            all_columns=list(QueryInfo.model_fields.keys()),
            sortable_columns=SORTABLE_QUERY_COLUMNS,
            logger=logger,
        )

        # The request page is 1-based; run_tool is given page - 1, floored at 0.
        zero_based_page = max(request.page - 1, 0)
        sort_column = request.order_column or "changed_on"

        with event_logger.log_context(action="mcp.list_queries.query"):
            result = list_tool.run_tool(
                filters=request.filters,
                search=request.search,
                select_columns=request.select_columns,
                order_column=sort_column,
                order_direction=request.order_direction,
                page=zero_based_page,
                page_size=request.page_size,
            )

        listed = len(result.queries) if hasattr(result, "queries") else 0
        outcome_args = (
            listed,
            getattr(result, "total_count", None),
            getattr(result, "total_pages", None),
        )
        await ctx.info(
            "Queries listed successfully: count=%s, total_count=%s, total_pages=%s"
            % outcome_args
        )

        columns_to_filter = result.columns_requested
        await ctx.debug(
            "Applying field filtering via serialization context: columns=%s"
            % (columns_to_filter,)
        )

        # The context is read by QueryInfo's wrap serializer to trim each item.
        serialization_context = {"select_columns": columns_to_filter}
        with event_logger.log_context(action="mcp.list_queries.serialization"):
            return result.model_dump(mode="json", context=serialization_context)
    except Exception as exc:
        failure_args = (
            request.page,
            request.page_size,
            str(exc),
            type(exc).__name__,
        )
        await ctx.error(
            "Query listing failed: page=%s, page_size=%s, error=%s, error_type=%s"
            % failure_args
        )
        raise

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,269 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for saved query-related responses
"""
from __future__ import annotations
from datetime import datetime
from typing import Annotated, Any, Dict, List, Literal
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_serializer,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.constants import DEFAULT_PAGE_SIZE, MAX_PAGE_SIZE
from superset.mcp_service.privacy import filter_user_directory_fields
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils.schema_utils import (
parse_json_or_list,
parse_json_or_model_list,
)
# Default column selection for saved-query listings.
DEFAULT_SAVED_QUERY_COLUMNS = [
    "id",
    "label",
    "db_id",
    "schema",
    "uuid",
]

# Columns offered for ordering saved-query listings.
SORTABLE_SAVED_QUERY_COLUMNS = [
    "id", "label", "db_id",
    "schema", "changed_on", "created_on",
]

# Every selectable saved-query column.
ALL_SAVED_QUERY_COLUMNS = [
    # identity and location
    "id", "label", "db_id", "schema", "catalog", "uuid",
    # content
    "sql", "description",
    # timestamps
    "changed_on", "created_on", "last_run",
]
class SavedQueryFilter(ColumnOperator):
    """
    Filter object for saved query listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Only these saved-query columns may be filtered on; any other name
    # fails validation because of the Literal type.
    col: Literal["label", "db_id", "schema", "catalog", "created_by_fk"] = Field(
        ...,
        description="Column to filter on.",
    )
    # Comparison operator, limited to the members of ColumnOperatorEnum.
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use.",
    )
    # Either a single scalar or a list of scalars.
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )
class SavedQueryInfo(BaseModel):
    # Response model for one saved query. Every field is optional so
    # partially populated records still validate.

    id: int | None = Field(default=None, description="Saved query ID")
    uuid: str | None = Field(default=None, description="Saved query UUID")
    label: str | None = Field(default=None, description="Saved query label/name")
    sql: str | None = Field(default=None, description="SQL query text")
    db_id: int | None = Field(default=None, description="Database connection ID")
    schema: str | None = Field(default=None, description="Database schema name")
    catalog: str | None = Field(default=None, description="Database catalog name")
    description: str | None = Field(
        default=None, description="User-provided description"
    )
    changed_on: str | datetime | None = Field(
        default=None, description="Last modification timestamp"
    )
    created_on: str | datetime | None = Field(
        default=None, description="Creation timestamp"
    )
    last_run: str | datetime | None = Field(
        default=None, description="Timestamp of last execution"
    )

    model_config = ConfigDict(
        from_attributes=True,
        ser_json_timedelta="iso8601",
        populate_by_name=True,
    )

    @model_serializer(mode="wrap")
    def _filter_fields_by_context(self, serializer: Any, info: Any) -> Dict[str, Any]:
        """Serialize the model, then optionally narrow it to requested columns.

        The default serializer output is first passed through
        ``filter_user_directory_fields``. If the serialization context is a
        non-empty dict carrying a truthy ``select_columns`` entry, only keys
        named there are kept; otherwise the dict is returned as-is.
        """
        serialized = filter_user_directory_fields(serializer(self))

        context = info.context
        requested = (
            context.get("select_columns")
            if context and isinstance(context, dict)
            else None
        )
        if requested:
            allowed = set(requested)
            return {
                name: value for name, value in serialized.items() if name in allowed
            }
        return serialized
class SavedQueryList(BaseModel):
    # Paginated response envelope for saved-query listings: the page of
    # SavedQueryInfo items plus paging and column metadata.

    # The items on this page.
    saved_queries: List[SavedQueryInfo]
    # Paging bookkeeping. Values are filled in by whatever builds the list;
    # the exact meaning of count vs total_count is not defined in this module.
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    # Column metadata echoed back to the caller.
    columns_requested: List[str] = Field(
        default_factory=list,
        description="Requested columns for the response",
    )
    columns_loaded: List[str] = Field(
        default_factory=list,
        description="Columns that were actually loaded for each saved query",
    )
    columns_available: List[str] = Field(
        default_factory=list,
        description="All columns available for selection via select_columns parameter",
    )
    sortable_columns: List[str] = Field(
        default_factory=list,
        description="Columns that can be used with order_column parameter",
    )
    filters_applied: List[SavedQueryFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    # Optional extras; both default to None.
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None

    model_config = ConfigDict(ser_json_timedelta="iso8601")
class ListSavedQueriesRequest(BaseModel):
    """Request schema for list_saved_queries."""

    # --- what to return -------------------------------------------------
    filters: Annotated[
        List[SavedQueryFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=list,
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search string to match against saved query fields. "
            "Cannot be used together with 'filters'.",
        ),
    ]

    # --- ordering -------------------------------------------------------
    order_column: Annotated[
        str | None,
        Field(default=None, description="Column to order results by"),
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="desc",
            description="Direction to order results ('asc' or 'desc')",
        ),
    ]

    # --- paging (1-based page; size capped at MAX_PAGE_SIZE) -------------
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        int,
        Field(
            default=DEFAULT_PAGE_SIZE,
            gt=0,
            le=MAX_PAGE_SIZE,
            description=f"Number of items per page (max {MAX_PAGE_SIZE})",
        ),
    ]

    @field_validator("filters", mode="before")
    @classmethod
    def parse_filters(cls, value: Any) -> List[SavedQueryFilter]:
        """Normalize ``filters`` from a JSON string or a list of objects."""
        parsed = parse_json_or_model_list(value, SavedQueryFilter, "filters")
        return parsed

    @field_validator("select_columns", mode="before")
    @classmethod
    def parse_columns(cls, value: Any) -> List[str]:
        """Normalize ``select_columns`` from JSON, a list, or a CSV string."""
        parsed = parse_json_or_list(value, "select_columns")
        return parsed

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListSavedQueriesRequest":
        """Reject requests that set both ``search`` and ``filters``."""
        if not (self.search and self.filters):
            return self
        raise ValueError(
            "Cannot use both 'search' and 'filters' parameters simultaneously. "
            "Use either 'search' for text-based searching across multiple fields, "
            "or 'filters' for precise column-based filtering, but not both."
        )
class SavedQueryError(BaseModel):
    # Error payload for the saved-query tools; get_saved_query_info returns
    # it instead of raising.

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    # Optional; accepts either a datetime or a preformatted string.
    timestamp: str | datetime | None = Field(None, description="Error timestamp")

    model_config = ConfigDict(ser_json_timedelta="iso8601")
class GetSavedQueryInfoRequest(BaseModel):
    """Request schema for get_saved_query_info with support for ID or UUID."""

    # Required (no default); the value is passed on unchanged to the lookup
    # in get_saved_query_info.
    identifier: Annotated[
        int | str,
        Field(description="Saved query identifier - can be numeric ID or UUID string"),
    ]
def serialize_saved_query_object(saved_query: Any) -> SavedQueryInfo | None:
    """Build a SavedQueryInfo from an arbitrary saved-query-like object.

    Returns None when ``saved_query`` is falsy. Otherwise each field is read
    from the attribute of the same name, with None substituted for missing
    attributes. ``uuid`` is the exception: a truthy value is converted with
    ``str()``, anything else becomes None.
    """
    if not saved_query:
        return None

    raw_uuid = getattr(saved_query, "uuid", None)
    uuid_text = str(raw_uuid) if raw_uuid else None

    plain_attributes = (
        "id",
        "label",
        "sql",
        "db_id",
        "schema",
        "catalog",
        "description",
        "changed_on",
        "created_on",
        "last_run",
    )
    values = {name: getattr(saved_query, name, None) for name in plain_attributes}
    return SavedQueryInfo(uuid=uuid_text, **values)

View File

@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from .get_saved_query_info import get_saved_query_info
from .list_saved_queries import list_saved_queries
# Public names of this package: the two saved-query tools imported above.
__all__ = [
    "list_saved_queries",
    "get_saved_query_info",
]

View File

@@ -0,0 +1,129 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Get saved query info FastMCP tool
This module contains the FastMCP tool for getting detailed information
about a specific saved SQL query.
"""
import logging
from datetime import datetime, timezone
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelGetInfoCore
from superset.mcp_service.saved_query.schemas import (
GetSavedQueryInfoRequest,
SavedQueryError,
SavedQueryInfo,
serialize_saved_query_object,
)
# Module-level logger; get_saved_query_info hands it to ModelGetInfoCore.
logger = logging.getLogger(__name__)
@tool(
    tags=["discovery"],
    class_permission_name="SavedQuery",
    annotations=ToolAnnotations(
        title="Get saved query info",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def get_saved_query_info(
    request: GetSavedQueryInfoRequest, ctx: Context
) -> SavedQueryInfo | SavedQueryError:
    """Get saved query details by ID or UUID.

    Returns the full saved query including SQL text, label, database,
    schema, and timestamps.

    IMPORTANT FOR LLM CLIENTS:
    - Use numeric ID (e.g., 42) or UUID string (e.g., "a1b2c3d4-...")
    - To find a saved query ID, use the list_saved_queries tool first

    Example usage:
    ```json
    {
        "identifier": 42
    }
    ```

    Or with UUID:
    ```json
    {
        "identifier": "a1b2c3d4-5678-90ab-cdef-1234567890ab"
    }
    ```
    """
    identifier = request.identifier
    await ctx.info(
        "Retrieving saved query information: identifier=%s" % (identifier,)
    )

    try:
        from superset.daos.query import SavedQueryDAO

        with event_logger.log_context(action="mcp.get_saved_query_info.lookup"):
            lookup = ModelGetInfoCore(
                dao_class=SavedQueryDAO,
                output_schema=SavedQueryInfo,
                error_schema=SavedQueryError,
                serializer=serialize_saved_query_object,
                supports_slug=False,
                logger=logger,
            )
            result = lookup.run_tool(identifier)

        # Anything other than a SavedQueryInfo is the error schema: report
        # and hand it back unchanged.
        if not isinstance(result, SavedQueryInfo):
            failure_args = (result.error_type, result.error)
            await ctx.warning(
                "Saved query retrieval failed: error_type=%s, error=%s"
                % failure_args
            )
            return result

        success_args = (result.id, result.label, result.db_id)
        await ctx.info(
            "Saved query information retrieved successfully: "
            "saved_query_id=%s, label=%s, db_id=%s" % success_args
        )
        return result
    except Exception as exc:
        # Tool boundary: the details go to the client log channel and the
        # caller receives a generic error payload rather than an exception.
        error_args = (identifier, str(exc), type(exc).__name__)
        await ctx.error(
            "Saved query information retrieval failed: identifier=%s, error=%s, "
            "error_type=%s" % error_args
        )
        return SavedQueryError(
            error="Failed to get saved query info",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )

View File

@@ -0,0 +1,158 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
List saved queries FastMCP tool
This module contains the FastMCP tool for listing saved SQL queries
with filtering, search, and pagination.
"""
import logging
from fastmcp import Context
from superset_core.mcp.decorators import tool, ToolAnnotations
from superset.extensions import event_logger
from superset.mcp_service.mcp_core import ModelListCore
from superset.mcp_service.saved_query.schemas import (
DEFAULT_SAVED_QUERY_COLUMNS,
ListSavedQueriesRequest,
SavedQueryError,
SavedQueryFilter,
SavedQueryInfo,
SavedQueryList,
serialize_saved_query_object,
SORTABLE_SAVED_QUERY_COLUMNS,
)
logger = logging.getLogger(__name__)

# Template used when a call omits ``request``; it is deep-copied per call so
# the shared instance itself is never handed out.
_DEFAULT_LIST_SAVED_QUERIES_REQUEST = ListSavedQueriesRequest()


@tool(
    tags=["core"],
    class_permission_name="SavedQuery",
    annotations=ToolAnnotations(
        title="List saved queries",
        readOnlyHint=True,
        destructiveHint=False,
    ),
)
async def list_saved_queries(
    request: ListSavedQueriesRequest | None = None,
    ctx: Context | None = None,
) -> SavedQueryList | SavedQueryError:
    """List saved SQL queries with filtering and search.

    Returns saved queries owned by the current user, including label, SQL,
    database ID, and schema.

    Sortable columns for order_column: id, label, db_id, schema,
    changed_on, created_on
    """
    # Maintainer notes (the docstring above is the client-facing description):
    #   - ``ctx`` is required at runtime even though its default is None.
    #   - ``request.page`` is passed to the core as ``page - 1``, floored at 0.
    #   - Any exception is reported through ``ctx.error`` and then re-raised.
    if ctx is None:
        raise RuntimeError("FastMCP context is required for list_saved_queries")

    if not request:
        request = _DEFAULT_LIST_SAVED_QUERIES_REQUEST.model_copy(deep=True)

    paging_summary = (request.page, request.page_size, request.search)
    await ctx.info(
        "Listing saved queries: page=%s, page_size=%s, search=%s" % paging_summary
    )

    listing_options = (
        request.filters,
        request.order_column,
        request.order_direction,
        request.select_columns,
    )
    await ctx.debug(
        "Saved query listing parameters: filters=%s, order_column=%s, "
        "order_direction=%s, select_columns=%s" % listing_options
    )

    try:
        from superset.daos.query import SavedQueryDAO

        def _serialize_saved_query(
            obj: object, cols: list[str] | None
        ) -> SavedQueryInfo | None:
            # ``cols`` is accepted but not used; the object is serialized whole.
            return serialize_saved_query_object(obj)

        lister = ModelListCore(
            dao_class=SavedQueryDAO,
            output_schema=SavedQueryInfo,
            item_serializer=_serialize_saved_query,
            filter_type=SavedQueryFilter,
            default_columns=DEFAULT_SAVED_QUERY_COLUMNS,
            search_columns=["label", "description", "sql"],
            list_field_name="saved_queries",
            output_list_schema=SavedQueryList,
            all_columns=list(SavedQueryInfo.model_fields),
            sortable_columns=SORTABLE_SAVED_QUERY_COLUMNS,
            logger=logger,
        )

        zero_based_page = max(request.page - 1, 0)
        with event_logger.log_context(action="mcp.list_saved_queries.query"):
            listing = lister.run_tool(
                filters=request.filters,
                search=request.search,
                select_columns=request.select_columns,
                order_column=request.order_column,
                order_direction=request.order_direction,
                page=zero_based_page,
                page_size=request.page_size,
            )

        if hasattr(listing, "saved_queries"):
            returned = len(listing.saved_queries)
        else:
            returned = 0
        await ctx.info(
            "Saved queries listed successfully: count=%s, total_count=%s, "
            "total_pages=%s"
            % (
                returned,
                getattr(listing, "total_count", None),
                getattr(listing, "total_pages", None),
            )
        )

        requested_columns = listing.columns_requested
        await ctx.debug(
            "Applying field filtering via serialization context: columns=%s"
            % (requested_columns,)
        )
        with event_logger.log_context(action="mcp.list_saved_queries.serialization"):
            return listing.model_dump(
                mode="json",
                context={"select_columns": requested_columns},
            )
    except Exception as e:
        failure = (request.page, request.page_size, str(e), type(e).__name__)
        await ctx.error(
            "Saved query listing failed: page=%s, page_size=%s, error=%s, "
            "error_type=%s" % failure
        )
        raise

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,328 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.query.schemas import (
ListQueriesRequest,
QueryFilter,
)
from superset.utils import json
# Configure root logging at DEBUG level when this test module is imported.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger named after this test module.
logger = logging.getLogger(__name__)
class TestQueryFilterSchema:
    """Tests for QueryFilter schema — filterable columns."""

    def test_invalid_filter_column_rejected(self):
        """A column name outside the allowed set fails validation."""
        with pytest.raises(ValidationError):
            QueryFilter(col="not_a_real_column", opr="eq", value="test")

    def test_valid_status_filter_accepted(self):
        """``status`` is accepted as a filter column."""
        accepted = QueryFilter(col="status", opr="eq", value="success")
        assert accepted.col == "status"

    def test_valid_database_id_filter_accepted(self):
        """``database_id`` is accepted as a filter column."""
        accepted = QueryFilter(col="database_id", opr="eq", value=1)
        assert accepted.col == "database_id"

    def test_valid_schema_filter_accepted(self):
        """``schema`` is accepted as a filter column."""
        accepted = QueryFilter(col="schema", opr="eq", value="public")
        assert accepted.col == "schema"

    def test_valid_user_id_filter_accepted(self):
        """``user_id`` is accepted, allowing filtering by user."""
        accepted = QueryFilter(col="user_id", opr="eq", value=42)
        assert accepted.col == "user_id"

    def test_valid_start_time_filter_accepted(self):
        """``start_time`` is accepted, allowing time-range filters."""
        accepted = QueryFilter(col="start_time", opr="gt", value=1700000000.0)
        assert accepted.col == "start_time"
def create_mock_query(
query_id: int = 1,
sql: str = "SELECT * FROM table",
executed_sql: str | None = None,
status: str = "success",
start_time: float = 1700000000.0,
end_time: float = 1700000001.0,
rows: int = 100,
database_id: int = 1,
schema: str = "public",
catalog: str | None = None,
tab_name: str = "SQL Lab 1",
error_message: str | None = None,
client_id: str = "abc123",
user_id: int | None = 1,
) -> MagicMock:
"""Factory function to create mock query objects with sensible defaults."""
query = MagicMock()
query.id = query_id
query.sql = sql
query.executed_sql = executed_sql
query.status = status
query.start_time = start_time
query.end_time = end_time
query.rows = rows
query.database_id = database_id
query.schema = schema
query.catalog = catalog
query.tab_name = tab_name
query.error_message = error_message
query.client_id = client_id
query.limit = 1000
query.progress = 100
query.changed_on = None
query.user_id = user_id
return query
@pytest.fixture
def mcp_server():
    """Provide the imported ``mcp`` application object to the tests."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch ``get_user_from_request`` to return a stub user for every test.

    The stub has ``id=1`` and ``username="admin"``; the patch object is
    yielded so a test can inspect or reconfigure it.
    """
    from unittest.mock import Mock, patch

    stub_user = Mock(id=1, username="admin")
    with patch(
        "superset.mcp_service.auth.get_user_from_request",
        return_value=stub_user,
    ) as mock_get_user:
        yield mock_get_user
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_basic(mock_list, mcp_server):
    """list_queries returns the single mocked row with its id and status."""
    row = create_mock_query()
    row._mapping = {
        "id": row.id,
        "sql": row.sql,
        "status": row.status,
        "start_time": row.start_time,
        "database_id": row.database_id,
        "schema": row.schema,
    }
    mock_list.return_value = ([row], 1)
    request = ListQueriesRequest(page=1, page_size=10)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        queries = data["queries"]
        assert queries is not None
        assert len(queries) == 1
        assert queries[0]["id"] == 1
        assert queries[0]["status"] == "success"
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_with_status_filter(mock_list, mcp_server):
    """A request carrying a status filter returns the mocked failed query."""
    failed = create_mock_query(status="failed", error_message="Syntax error")
    failed._mapping = {
        "id": failed.id,
        "sql": failed.sql,
        "status": failed.status,
        "error_message": failed.error_message,
    }
    mock_list.return_value = ([failed], 1)
    status_filter = {"col": "status", "opr": "eq", "value": "failed"}
    request = ListQueriesRequest(page=1, page_size=10, filters=[status_filter])

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        queries = data["queries"]
        assert queries is not None
        assert len(queries) == 1
        assert queries[0]["status"] == "failed"
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_default_page_size(mock_list, mcp_server):
    """Calling list_queries with no arguments reports a page size of 25."""
    mock_list.return_value = ([], 0)
    no_arguments = {}

    async with Client(mcp_server) as client:
        result = await client.call_tool("list_queries", no_arguments)
        assert result.content is not None
        payload = json.loads(result.content[0].text)
        assert payload["page_size"] == 25
def test_list_queries_request_rejects_both_search_and_filters():
    """Supplying ``search`` together with ``filters`` fails validation."""
    conflicting_filters = [{"col": "status", "opr": "eq", "value": "success"}]
    with pytest.raises(ValidationError):
        ListQueriesRequest(search="test", filters=conflicting_filters)
@patch("superset.daos.query.QueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_query_info_basic(mock_find, mcp_server):
    """get_query_info returns the fields of the query found by the DAO."""
    mock_find.return_value = create_mock_query()
    arguments = {"request": {"identifier": 1}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_query_info", arguments)
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["id"] == 1
        assert data["status"] == "success"
        assert data["database_id"] == 1
@patch("superset.daos.query.QueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_query_info_not_found(mock_find, mcp_server):
    """A DAO lookup returning None yields a ``not_found`` error."""
    mock_find.return_value = None
    arguments = {"request": {"identifier": 999}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_query_info", arguments)
        assert result.data["error_type"] == "not_found"
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_empty(mock_list, mcp_server):
    """With no rows from the DAO, the list and both counts are empty/zero."""
    mock_list.return_value = ([], 0)
    request = ListQueriesRequest(page=1, page_size=10)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["queries"] == []
        assert data["count"] == 0
        assert data["total_count"] == 0
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_pagination_info(mock_list, mcp_server):
    """First page of 3 out of 100 rows reports a next page and no previous."""
    rows = []
    for query_id in range(1, 4):
        row = create_mock_query(query_id=query_id)
        row._mapping = {"id": row.id, "sql": row.sql, "status": row.status}
        rows.append(row)
    mock_list.return_value = (rows, 100)
    request = ListQueriesRequest(page=1, page_size=3)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_queries", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        assert data["total_count"] == 100
        assert data["page_size"] == 3
        assert data["has_next"] is True
        assert data["has_previous"] is False
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_default_order_is_changed_on_desc(mock_list, mcp_server):
    """With no arguments, the DAO is asked to order by changed_on descending."""
    mock_list.return_value = ([], 0)

    async with Client(mcp_server) as client:
        result = await client.call_tool("list_queries", {})
        assert result.content is not None
        mock_list.assert_called_once()
        dao_kwargs = mock_list.call_args.kwargs
        assert dao_kwargs.get("order_column") == "changed_on"
        assert dao_kwargs.get("order_direction") == "desc"
@patch("superset.daos.query.QueryDAO.list")
@pytest.mark.asyncio
async def test_list_queries_select_columns_projects_fields(mock_list, mcp_server):
    """Only the columns named in select_columns appear on each result item."""
    row = create_mock_query()
    row._mapping = {"id": row.id, "status": row.status}
    mock_list.return_value = ([row], 1)
    request = ListQueriesRequest(
        page=1, page_size=10, select_columns=["id", "status"]
    )

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_queries", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        assert data["queries"] is not None
        first = data["queries"][0]
        assert set(first.keys()) == {"id", "status"}
        assert first["id"] == 1
        assert first["status"] == "success"
@pytest.mark.asyncio
async def test_list_queries_invalid_order_column_raises(mcp_server):
    """Ordering by ``tab_name`` is refused with an "Invalid order_column" error."""
    request = ListQueriesRequest(page=1, page_size=10, order_column="tab_name")
    arguments = {"request": request.model_dump()}

    async with Client(mcp_server) as client:
        with pytest.raises(ToolError, match="Invalid order_column"):
            await client.call_tool("list_queries", arguments)
@patch("superset.daos.query.QueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_query_info_internal_error(mock_find, mcp_server):
    """An exception raised by the DAO is reported as a generic InternalError."""
    mock_find.side_effect = RuntimeError("unexpected db failure")
    arguments = {"request": {"identifier": 1}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_query_info", arguments)
        payload = json.loads(result.content[0].text)
        assert payload["error_type"] == "InternalError"
        assert payload["error"] == "Failed to get query info"

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,337 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from unittest.mock import MagicMock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from pydantic import ValidationError
from superset.mcp_service.app import mcp
from superset.mcp_service.saved_query.schemas import (
ListSavedQueriesRequest,
SavedQueryFilter,
)
from superset.utils import json
# Configure root logging at DEBUG level when this test module is imported.
logging.basicConfig(level=logging.DEBUG)
# Module-level logger named after this test module.
logger = logging.getLogger(__name__)
class TestSavedQueryFilterSchema:
    """Tests for SavedQueryFilter schema — filterable columns."""

    def test_invalid_filter_column_rejected(self):
        """A column name outside the allowed set fails validation."""
        with pytest.raises(ValidationError):
            SavedQueryFilter(col="not_a_real_column", opr="eq", value="test")

    def test_user_id_is_rejected_as_filter_column(self):
        """``user_id`` fails validation; ``created_by_fk`` is the owner column."""
        with pytest.raises(ValidationError):
            SavedQueryFilter(col="user_id", opr="eq", value=1)

    def test_valid_label_filter_accepted(self):
        """``label`` is accepted as a filter column."""
        accepted = SavedQueryFilter(col="label", opr="eq", value="my query")
        assert accepted.col == "label"

    def test_valid_db_id_filter_accepted(self):
        """``db_id`` is accepted as a filter column."""
        accepted = SavedQueryFilter(col="db_id", opr="eq", value=1)
        assert accepted.col == "db_id"

    def test_valid_schema_filter_accepted(self):
        """``schema`` is accepted as a filter column."""
        accepted = SavedQueryFilter(col="schema", opr="eq", value="public")
        assert accepted.col == "schema"

    def test_valid_catalog_filter_accepted(self):
        """``catalog`` is accepted as a filter column."""
        accepted = SavedQueryFilter(col="catalog", opr="eq", value="my_catalog")
        assert accepted.col == "catalog"

    def test_valid_created_by_fk_filter_accepted(self):
        """``created_by_fk`` is accepted, allowing filtering by owner user ID."""
        accepted = SavedQueryFilter(col="created_by_fk", opr="eq", value=42)
        assert accepted.col == "created_by_fk"
def create_mock_saved_query(
saved_query_id: int = 1,
label: str = "My Query",
sql: str = "SELECT 1",
db_id: int = 1,
schema: str = "public",
catalog: str | None = None,
description: str = "Test query",
uuid: str = "test-uuid-1234",
last_run: datetime | None = None,
) -> MagicMock:
"""Factory function to create mock saved query objects with sensible defaults."""
saved_query = MagicMock()
saved_query.id = saved_query_id
saved_query.label = label
saved_query.sql = sql
saved_query.db_id = db_id
saved_query.schema = schema
saved_query.catalog = catalog
saved_query.description = description
saved_query.uuid = uuid
saved_query.changed_on = None
saved_query.created_on = None
saved_query.last_run = last_run
return saved_query
@pytest.fixture
def mcp_server():
    """Provide the imported ``mcp`` application object to the tests."""
    return mcp
@pytest.fixture(autouse=True)
def mock_auth():
    """Patch ``get_user_from_request`` to return a stub user for every test.

    The stub has ``id=1`` and ``username="admin"``; the patch object is
    yielded so a test can inspect or reconfigure it.
    """
    from unittest.mock import Mock, patch

    stub_user = Mock(id=1, username="admin")
    with patch(
        "superset.mcp_service.auth.get_user_from_request",
        return_value=stub_user,
    ) as mock_get_user:
        yield mock_get_user
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_basic(mock_list, mcp_server):
    """list_saved_queries returns the single mocked row with its id and label."""
    row = create_mock_saved_query()
    row._mapping = {
        "id": row.id,
        "label": row.label,
        "db_id": row.db_id,
        "schema": row.schema,
        "uuid": row.uuid,
    }
    mock_list.return_value = ([row], 1)
    request = ListSavedQueriesRequest(page=1, page_size=10)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        saved_queries = data["saved_queries"]
        assert saved_queries is not None
        assert len(saved_queries) == 1
        assert saved_queries[0]["id"] == 1
        assert saved_queries[0]["label"] == "My Query"
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_with_search(mock_list, mcp_server):
    """A request carrying a search term returns the mocked matching row."""
    row = create_mock_saved_query(label="Production Query")
    row._mapping = {"id": row.id, "label": row.label}
    mock_list.return_value = ([row], 1)
    request = ListSavedQueriesRequest(page=1, page_size=10, search="Production")

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        saved_queries = data["saved_queries"]
        assert saved_queries is not None
        assert len(saved_queries) == 1
        assert saved_queries[0]["label"] == "Production Query"
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_with_filters(mock_list, mcp_server):
    """A request carrying a db_id filter returns the one mocked row."""
    row = create_mock_saved_query(db_id=2)
    row._mapping = {
        "id": row.id,
        "label": row.label,
        "db_id": row.db_id,
    }
    mock_list.return_value = ([row], 1)
    db_filter = {"col": "db_id", "opr": "eq", "value": 2}
    request = ListSavedQueriesRequest(page=1, page_size=10, filters=[db_filter])

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        saved_queries = data["saved_queries"]
        assert saved_queries is not None
        assert len(saved_queries) == 1
def test_list_saved_queries_request_rejects_both_search_and_filters():
    """Supplying ``search`` together with ``filters`` fails validation."""
    conflicting_filters = [{"col": "label", "opr": "eq", "value": "test"}]
    with pytest.raises(ValidationError):
        ListSavedQueriesRequest(search="test", filters=conflicting_filters)
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_saved_query_info_basic(mock_find, mcp_server):
    """get_saved_query_info returns the fields of the row found by the DAO."""
    mock_find.return_value = create_mock_saved_query()
    arguments = {"request": {"identifier": 1}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_saved_query_info", arguments)
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["id"] == 1
        assert data["label"] == "My Query"
        assert data["sql"] == "SELECT 1"
        assert data["db_id"] == 1
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_saved_query_info_not_found(mock_find, mcp_server):
    """A DAO lookup returning None yields a ``not_found`` error."""
    mock_find.return_value = None
    arguments = {"request": {"identifier": 999}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_saved_query_info", arguments)
        assert result.data["error_type"] == "not_found"
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_empty(mock_list, mcp_server):
    """With no rows from the DAO, the list and both counts are empty/zero."""
    mock_list.return_value = ([], 0)
    request = ListSavedQueriesRequest(page=1, page_size=10)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["saved_queries"] == []
        assert data["count"] == 0
        assert data["total_count"] == 0
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_saved_query_info_by_uuid(mock_find, mcp_server):
    """A UUID string identifier returns the mocked row, UUID included."""
    uuid_value = "a1b2c3d4-5678-90ab-cdef-1234567890ab"
    mock_find.return_value = create_mock_saved_query(uuid=uuid_value)
    arguments = {"request": {"identifier": uuid_value}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_saved_query_info", arguments)
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["id"] == 1
        assert data["uuid"] == uuid_value
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_pagination_info(mock_list, mcp_server):
    """First page of 3 out of 25 rows reports a next page and no previous."""
    rows = []
    for saved_query_id in range(1, 4):
        row = create_mock_saved_query(saved_query_id=saved_query_id)
        row._mapping = {"id": row.id, "label": row.label}
        rows.append(row)
    mock_list.return_value = (rows, 25)
    request = ListSavedQueriesRequest(page=1, page_size=3)

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        assert data["total_count"] == 25
        assert data["count"] == 3
        assert data["page_size"] == 3
        assert data["has_next"] is True
        assert data["has_previous"] is False
@patch("superset.daos.query.SavedQueryDAO.list")
@pytest.mark.asyncio
async def test_list_saved_queries_select_columns_projects_fields(mock_list, mcp_server):
    """Only the columns named in select_columns appear on each result item."""
    row = create_mock_saved_query()
    row._mapping = {"id": row.id, "label": row.label}
    mock_list.return_value = ([row], 1)
    request = ListSavedQueriesRequest(
        page=1, page_size=10, select_columns=["id", "label"]
    )

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "list_saved_queries", {"request": request.model_dump()}
        )
        data = json.loads(result.content[0].text)
        assert data["saved_queries"] is not None
        first = data["saved_queries"][0]
        assert set(first.keys()) == {"id", "label"}
        assert first["id"] == 1
        assert first["label"] == "My Query"
@pytest.mark.asyncio
async def test_list_saved_queries_invalid_order_column_raises(mcp_server):
    """Ordering by ``sql`` is refused with an "Invalid order_column" error."""
    request = ListSavedQueriesRequest(page=1, page_size=10, order_column="sql")
    arguments = {"request": request.model_dump()}

    async with Client(mcp_server) as client:
        with pytest.raises(ToolError, match="Invalid order_column"):
            await client.call_tool("list_saved_queries", arguments)
@patch("superset.daos.query.SavedQueryDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_saved_query_info_internal_error(mock_find, mcp_server):
    """An exception raised by the DAO is reported as a generic InternalError."""
    mock_find.side_effect = RuntimeError("unexpected db failure")
    arguments = {"request": {"identifier": 1}}

    async with Client(mcp_server) as client:
        result = await client.call_tool("get_saved_query_info", arguments)
        payload = json.loads(result.content[0].text)
        assert payload["error_type"] == "InternalError"
        assert payload["error"] == "Failed to get saved query info"