feat(mcp): add save_sql_query tool for SQL Lab saved queries (#38414)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Amin Ghadersohi
2026-03-13 22:02:04 +01:00
committed by GitHub
parent ed622e254a
commit 48220fb33f
5 changed files with 669 additions and 1 deletions

View File

@@ -70,6 +70,7 @@ Chart Management:
SQL Lab Integration:
- execute_sql: Execute SQL queries and get results (requires database_id)
- save_sql_query: Save a SQL query to Saved Queries list
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
Schema Discovery:
@@ -105,7 +106,8 @@ To find your own charts/dashboards:
To explore data with SQL:
1. list_datasets -> find a dataset and note its database_id
2. execute_sql(database_id, sql) -> run query
3. open_sql_lab_with_context(database_id) -> open SQL Lab UI
3. save_sql_query(database_id, label, sql) -> save query for later reuse
4. open_sql_lab_with_context(database_id) -> open SQL Lab UI
generate_explore_link vs generate_chart:
- Use generate_explore_link for exploration (no permanent chart created)
@@ -415,6 +417,7 @@ from superset.mcp_service.explore.tool import ( # noqa: F401, E402
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
execute_sql,
open_sql_lab_with_context,
save_sql_query,
)
from superset.mcp_service.system import ( # noqa: F401, E402
prompts as system_prompts,

View File

@@ -147,6 +147,65 @@ class ExecuteSqlResponse(BaseModel):
)
class SaveSqlQueryRequest(BaseModel):
"""Request schema for saving a SQL query."""
database_id: int = Field(
..., description="Database connection ID the query runs against"
)
label: str = Field(
...,
description="Name for the saved query (shown in Saved Queries list)",
min_length=1,
max_length=256,
)
sql: str = Field(
...,
description="SQL query text to save",
)
schema_name: str | None = Field(
None,
description="Schema the query targets",
alias="schema",
)
catalog: str | None = Field(None, description="Catalog name (if applicable)")
description: str | None = Field(
None, description="Optional description of the query"
)
@field_validator("sql")
@classmethod
def sql_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("SQL query cannot be empty")
return v.strip()
@field_validator("label")
@classmethod
def label_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("Label cannot be empty")
return v.strip()
class SaveSqlQueryResponse(BaseModel):
"""Response schema for a saved SQL query."""
id: int = Field(..., description="Saved query ID")
label: str = Field(..., description="Query name")
sql: str = Field(..., description="SQL query text")
database_id: int = Field(..., description="Database ID")
schema_name: str | None = Field(None, description="Schema name", alias="schema")
catalog: str | None = Field(None, description="Catalog name (if applicable)")
description: str | None = Field(None, description="Query description")
url: str = Field(
...,
description=(
"URL to open this saved query in SQL Lab (e.g., /sqllab?savedQueryId=42)"
),
)
class OpenSqlLabRequest(BaseModel):
"""Request schema for opening SQL Lab with context."""

View File

@@ -23,8 +23,10 @@ from superset.mcp_service.sql_lab.tool.execute_sql import execute_sql
from superset.mcp_service.sql_lab.tool.open_sql_lab_with_context import (
open_sql_lab_with_context,
)
from superset.mcp_service.sql_lab.tool.save_sql_query import save_sql_query
__all__ = [
"execute_sql",
"open_sql_lab_with_context",
"save_sql_query",
]

View File

@@ -0,0 +1,137 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Save SQL Query MCP Tool
Tool for saving a SQL query as a named SavedQuery in Superset,
so it appears in SQL Lab's "Saved Queries" list and can be
reloaded/shared via URL.
"""
from __future__ import annotations
import logging
from fastmcp import Context
from sqlalchemy.exc import SQLAlchemyError
from superset_core.mcp.decorators import tool
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetSecurityException
from superset.extensions import event_logger
from superset.mcp_service.sql_lab.schemas import (
SaveSqlQueryRequest,
SaveSqlQueryResponse,
)
from superset.mcp_service.utils.schema_utils import parse_request
logger = logging.getLogger(__name__)
@tool(tags=["mutate"])
@parse_request(SaveSqlQueryRequest)
async def save_sql_query(
request: SaveSqlQueryRequest, ctx: Context
) -> SaveSqlQueryResponse:
"""Save a SQL query so it appears in SQL Lab's Saved Queries list.
Creates a persistent SavedQuery that the user can reload from
SQL Lab, share via URL, and find in the Saved Queries page.
Requires a database_id, a label (name), and the SQL text.
"""
await ctx.info(
"Saving SQL query: database_id=%s, label=%r"
% (request.database_id, request.label)
)
try:
from flask import g
from superset import db, security_manager
from superset.daos.query import SavedQueryDAO
from superset.mcp_service.utils.url_utils import get_superset_base_url
from superset.models.core import Database
# 1. Validate database exists and user has access
with event_logger.log_context(action="mcp.save_sql_query.db_validation"):
database = (
db.session.query(Database).filter_by(id=request.database_id).first()
)
if not database:
raise SupersetErrorException(
SupersetError(
message=(f"Database with ID {request.database_id} not found"),
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
level=ErrorLevel.ERROR,
)
)
if not security_manager.can_access_database(database):
raise SupersetSecurityException(
SupersetError(
message=(f"Access denied to database {database.database_name}"),
error_type=(SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR),
level=ErrorLevel.ERROR,
)
)
# 2. Create the saved query
with event_logger.log_context(action="mcp.save_sql_query.create"):
saved_query = SavedQueryDAO.create(
attributes={
"user_id": g.user.id,
"db_id": request.database_id,
"label": request.label,
"sql": request.sql,
"schema": request.schema_name or "",
"catalog": request.catalog,
"description": request.description or "",
}
)
db.session.commit() # pylint: disable=consider-using-transaction
# 3. Build response
base_url = get_superset_base_url()
saved_query_url = f"{base_url}/sqllab?savedQueryId={saved_query.id}"
await ctx.info(
"Saved query created: id=%s, url=%s" % (saved_query.id, saved_query_url)
)
return SaveSqlQueryResponse(
id=saved_query.id,
label=saved_query.label,
sql=saved_query.sql,
database_id=request.database_id,
schema_name=request.schema_name,
catalog=getattr(saved_query, "catalog", None),
description=request.description,
url=saved_query_url,
)
except (SupersetErrorException, SupersetSecurityException):
raise
except SQLAlchemyError as e:
from superset import db
db.session.rollback()
await ctx.error(
"Failed to save SQL query: error=%s, database_id=%s"
% (str(e), request.database_id)
)
raise

View File

@@ -0,0 +1,467 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for save_sql_query MCP tool schemas and logic.
"""
import importlib
import sys
import types
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import ValidationError
from superset.mcp_service.sql_lab.schemas import (
SaveSqlQueryRequest,
SaveSqlQueryResponse,
)
class TestSaveSqlQueryRequest:
"""Test SaveSqlQueryRequest schema validation."""
def test_valid_request(self) -> None:
req = SaveSqlQueryRequest(
database_id=1,
label="Revenue Query",
sql="SELECT SUM(revenue) FROM sales",
)
assert req.database_id == 1
assert req.label == "Revenue Query"
assert req.sql == "SELECT SUM(revenue) FROM sales"
def test_with_optional_fields(self) -> None:
req = SaveSqlQueryRequest(
database_id=1,
label="Revenue Query",
sql="SELECT 1",
schema="public",
catalog="main",
description="Sums revenue",
)
assert req.schema_name == "public"
assert req.catalog == "main"
assert req.description == "Sums revenue"
def test_empty_sql_fails(self) -> None:
with pytest.raises(ValidationError, match="SQL query cannot be empty"):
SaveSqlQueryRequest(database_id=1, label="test", sql=" ")
def test_empty_label_fails(self) -> None:
with pytest.raises(ValidationError, match="Label cannot be empty"):
SaveSqlQueryRequest(database_id=1, label=" ", sql="SELECT 1")
def test_sql_is_stripped(self) -> None:
req = SaveSqlQueryRequest(database_id=1, label="test", sql=" SELECT 1 ")
assert req.sql == "SELECT 1"
def test_label_is_stripped(self) -> None:
req = SaveSqlQueryRequest(database_id=1, label=" My Query ", sql="SELECT 1")
assert req.label == "My Query"
def test_label_max_length(self) -> None:
with pytest.raises(ValidationError, match="String should have at most 256"):
SaveSqlQueryRequest(database_id=1, label="a" * 257, sql="SELECT 1")
def test_schema_alias(self) -> None:
"""The field accepts 'schema' as alias for 'schema_name'."""
req = SaveSqlQueryRequest(
database_id=1,
label="test",
sql="SELECT 1",
schema="public",
)
assert req.schema_name == "public"
class TestSaveSqlQueryResponse:
"""Test SaveSqlQueryResponse schema."""
def test_response_fields(self) -> None:
resp = SaveSqlQueryResponse(
id=42,
label="Revenue",
sql="SELECT 1",
database_id=1,
url="/sqllab?savedQueryId=42",
)
assert resp.id == 42
assert resp.label == "Revenue"
assert resp.url == "/sqllab?savedQueryId=42"
def test_response_with_optional_fields(self) -> None:
resp = SaveSqlQueryResponse(
id=42,
label="Revenue",
sql="SELECT 1",
database_id=1,
schema="public",
description="A query",
url="/sqllab?savedQueryId=42",
)
assert resp.schema_name == "public"
assert resp.description == "A query"
def _force_passthrough_decorators():
"""Force superset_core MCP tool decorator to be a passthrough.
In CI, superset_core is fully installed and the real @tool decorator
includes authentication middleware. For unit tests we want to bypass
auth and test the tool logic directly, so we always replace the
decorator with a passthrough regardless of installation state.
Returns a dict of original sys.modules entries so they can be restored.
"""
def _passthrough_tool(func=None, **kwargs):
if func is not None:
return func
return lambda f: f
mock_mcp = MagicMock()
mock_mcp.tool = _passthrough_tool
mock_decorators = MagicMock()
mock_decorators.tool = _passthrough_tool
mock_api = MagicMock()
mock_api.mcp = mock_mcp
# Save original modules so we can restore them later
saved_modules: dict[str, types.ModuleType] = {}
# Only mock the specific decorator submodules, NOT the top-level
# superset_core package. Replacing sys.modules["superset_core"] with
# a MagicMock causes 'superset_core' is not a package errors for
# other submodules (queries, common) that are imported by sibling
# tool files during test collection.
mock_keys = [
"superset_core.api",
"superset_core.api.mcp",
"superset_core.api.types",
"superset_core.mcp",
"superset_core.mcp.decorators",
]
for key in mock_keys:
if key in sys.modules:
saved_modules[key] = sys.modules[key]
sys.modules["superset_core.api"] = mock_api
sys.modules["superset_core.api.mcp"] = mock_mcp
sys.modules["superset_core.mcp"] = mock_mcp
sys.modules["superset_core.mcp.decorators"] = mock_decorators
sys.modules.setdefault("superset_core.api.types", MagicMock())
return saved_modules
def _restore_modules(saved_modules: dict[str, types.ModuleType]) -> None:
"""Restore original sys.modules entries after passthrough mocking."""
# Remove mock entries for decorator paths and tool modules imported
# under patched decorators. Do NOT remove the top-level superset_core
# package or unrelated submodules (queries, common, etc.).
mock_prefixes = (
"superset_core.api",
"superset_core.mcp",
"superset.mcp_service.sql_lab.tool",
)
for key in list(sys.modules.keys()):
if any(key.startswith(prefix) for prefix in mock_prefixes):
del sys.modules[key]
# Restore originals (including any previously-imported tool modules)
sys.modules.update(saved_modules)
def _get_tool_module():
"""Import save_sql_query with passthrough decorators (no auth).
Returns (module, saved_modules) so callers can restore sys.modules.
"""
saved_modules = _force_passthrough_decorators()
# Clear cached module imports so we get a fresh import with mocked
# decorators. This is necessary because in CI the real @tool decorator
# may have been applied during a previous import.
mod_name = "superset.mcp_service.sql_lab.tool.save_sql_query"
saved_tool_modules: dict[str, object] = {}
for key in list(sys.modules.keys()):
if key.startswith("superset.mcp_service.sql_lab.tool"):
saved_tool_modules[key] = sys.modules.pop(key)
saved_modules.update(saved_tool_modules)
mod = importlib.import_module(mod_name)
return mod, saved_modules
def _make_mock_ctx():
"""Create a mock FastMCP Context with awaitable methods."""
async def _noop(*args, **kwargs):
pass
ctx = MagicMock()
ctx.info = _noop
ctx.error = _noop
ctx.warning = _noop
return ctx
class TestSaveSqlQueryToolLogic:
"""Test save_sql_query tool internal logic.
The tool function uses lazy imports inside its body (from flask import g,
from superset import db, etc.). We patch at the import source so that
when the function runs, it picks up our mocks.
The @parse_request decorator injects ctx via get_context() and strips
__wrapped__, so we mock get_context and call the decorated function
directly (without unwrapping).
"""
@pytest.mark.anyio
async def test_save_query_creates_saved_query(self) -> None:
"""Verify the tool calls SavedQueryDAO.create with correct attrs."""
mod, saved = _get_tool_module()
try:
mock_ctx = _make_mock_ctx()
mock_db_obj = MagicMock()
mock_db_obj.id = 1
mock_db_obj.database_name = "test_db"
mock_sq = MagicMock()
mock_sq.id = 42
mock_sq.label = "Revenue Query"
mock_sq.sql = "SELECT SUM(revenue) FROM sales"
mock_sq.catalog = None
request = SaveSqlQueryRequest(
database_id=1,
label="Revenue Query",
sql="SELECT SUM(revenue) FROM sales",
)
mock_db_session = MagicMock()
(
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
) = mock_db_obj
mock_sm = MagicMock()
mock_sm.can_access_database.return_value = True
mock_dao = MagicMock()
mock_dao.create.return_value = mock_sq
mock_g = MagicMock()
mock_g.user = Mock(id=1)
mock_event_logger = MagicMock()
mock_event_logger.log_context.return_value.__enter__ = Mock()
mock_event_logger.log_context.return_value.__exit__ = Mock(
return_value=False
)
with (
patch(
"fastmcp.server.dependencies.get_context",
return_value=mock_ctx,
),
patch("superset.db", mock_db_session),
patch("superset.security_manager", mock_sm),
patch("superset.daos.query.SavedQueryDAO", mock_dao),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost:8088",
),
patch("flask.g", mock_g),
patch.object(mod, "event_logger", mock_event_logger),
):
result = await mod.save_sql_query(request)
assert result.id == 42
assert result.label == "Revenue Query"
assert "savedQueryId=42" in result.url
mock_dao.create.assert_called_once()
call_attrs = mock_dao.create.call_args[1]["attributes"]
assert call_attrs["db_id"] == 1
assert call_attrs["label"] == "Revenue Query"
assert call_attrs["sql"] == "SELECT SUM(revenue) FROM sales"
assert call_attrs["user_id"] == 1
mock_db_session.session.commit.assert_called_once()
finally:
_restore_modules(saved)
@pytest.mark.anyio
async def test_save_query_database_not_found(self) -> None:
mod, saved = _get_tool_module()
try:
mock_ctx = _make_mock_ctx()
request = SaveSqlQueryRequest(
database_id=999,
label="Test",
sql="SELECT 1",
)
mock_db_session = MagicMock()
(
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
) = None
mock_g = MagicMock()
mock_g.user = Mock(id=1)
mock_event_logger = MagicMock()
mock_event_logger.log_context.return_value.__enter__ = Mock()
mock_event_logger.log_context.return_value.__exit__ = Mock(
return_value=False
)
with (
patch(
"fastmcp.server.dependencies.get_context",
return_value=mock_ctx,
),
patch("superset.db", mock_db_session),
patch("flask.g", mock_g),
patch.object(mod, "event_logger", mock_event_logger),
):
from superset.exceptions import SupersetErrorException
with pytest.raises(SupersetErrorException, match="not found"):
await mod.save_sql_query(request)
finally:
_restore_modules(saved)
@pytest.mark.anyio
async def test_save_query_access_denied(self) -> None:
mod, saved = _get_tool_module()
try:
mock_ctx = _make_mock_ctx()
mock_db_obj = MagicMock()
mock_db_obj.id = 1
mock_db_obj.database_name = "test_db"
request = SaveSqlQueryRequest(
database_id=1,
label="Test",
sql="SELECT 1",
)
mock_db_session = MagicMock()
(
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
) = mock_db_obj
mock_sm = MagicMock()
mock_sm.can_access_database.return_value = False
mock_g = MagicMock()
mock_g.user = Mock(id=1)
mock_event_logger = MagicMock()
mock_event_logger.log_context.return_value.__enter__ = Mock()
mock_event_logger.log_context.return_value.__exit__ = Mock(
return_value=False
)
with (
patch(
"fastmcp.server.dependencies.get_context",
return_value=mock_ctx,
),
patch("superset.db", mock_db_session),
patch("superset.security_manager", mock_sm),
patch("flask.g", mock_g),
patch.object(mod, "event_logger", mock_event_logger),
):
from superset.exceptions import SupersetSecurityException
with pytest.raises(SupersetSecurityException, match="Access denied"):
await mod.save_sql_query(request)
finally:
_restore_modules(saved)
@pytest.mark.anyio
async def test_save_query_with_schema_and_description(self) -> None:
mod, saved = _get_tool_module()
try:
mock_ctx = _make_mock_ctx()
mock_db_obj = MagicMock()
mock_db_obj.id = 1
mock_db_obj.database_name = "test_db"
mock_sq = MagicMock()
mock_sq.id = 10
mock_sq.label = "Test"
mock_sq.sql = "SELECT 1"
mock_sq.catalog = None
request = SaveSqlQueryRequest(
database_id=1,
label="Test",
sql="SELECT 1",
schema="public",
description="A test query",
)
mock_db_session = MagicMock()
(
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
) = mock_db_obj
mock_sm = MagicMock()
mock_sm.can_access_database.return_value = True
mock_dao = MagicMock()
mock_dao.create.return_value = mock_sq
mock_g = MagicMock()
mock_g.user = Mock(id=1)
mock_event_logger = MagicMock()
mock_event_logger.log_context.return_value.__enter__ = Mock()
mock_event_logger.log_context.return_value.__exit__ = Mock(
return_value=False
)
with (
patch(
"fastmcp.server.dependencies.get_context",
return_value=mock_ctx,
),
patch("superset.db", mock_db_session),
patch("superset.security_manager", mock_sm),
patch("superset.daos.query.SavedQueryDAO", mock_dao),
patch(
"superset.mcp_service.utils.url_utils.get_superset_base_url",
return_value="http://localhost:8088",
),
patch("flask.g", mock_g),
patch.object(mod, "event_logger", mock_event_logger),
):
result = await mod.save_sql_query(request)
assert result.id == 10
call_attrs = mock_dao.create.call_args[1]["attributes"]
assert call_attrs["schema"] == "public"
assert call_attrs["description"] == "A test query"
finally:
_restore_modules(saved)