mirror of
https://github.com/apache/superset.git
synced 2026-04-07 18:35:15 +00:00
feat(mcp): add save_sql_query tool for SQL Lab saved queries (#38414)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -70,6 +70,7 @@ Chart Management:
|
||||
|
||||
SQL Lab Integration:
|
||||
- execute_sql: Execute SQL queries and get results (requires database_id)
|
||||
- save_sql_query: Save a SQL query to Saved Queries list
|
||||
- open_sql_lab_with_context: Generate SQL Lab URL with pre-filled sql
|
||||
|
||||
Schema Discovery:
|
||||
@@ -105,7 +106,8 @@ To find your own charts/dashboards:
|
||||
To explore data with SQL:
|
||||
1. list_datasets -> find a dataset and note its database_id
|
||||
2. execute_sql(database_id, sql) -> run query
|
||||
3. open_sql_lab_with_context(database_id) -> open SQL Lab UI
|
||||
3. save_sql_query(database_id, label, sql) -> save query for later reuse
|
||||
4. open_sql_lab_with_context(database_id) -> open SQL Lab UI
|
||||
|
||||
generate_explore_link vs generate_chart:
|
||||
- Use generate_explore_link for exploration (no permanent chart created)
|
||||
@@ -415,6 +417,7 @@ from superset.mcp_service.explore.tool import ( # noqa: F401, E402
|
||||
from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402
|
||||
execute_sql,
|
||||
open_sql_lab_with_context,
|
||||
save_sql_query,
|
||||
)
|
||||
from superset.mcp_service.system import ( # noqa: F401, E402
|
||||
prompts as system_prompts,
|
||||
|
||||
@@ -147,6 +147,65 @@ class ExecuteSqlResponse(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class SaveSqlQueryRequest(BaseModel):
|
||||
"""Request schema for saving a SQL query."""
|
||||
|
||||
database_id: int = Field(
|
||||
..., description="Database connection ID the query runs against"
|
||||
)
|
||||
label: str = Field(
|
||||
...,
|
||||
description="Name for the saved query (shown in Saved Queries list)",
|
||||
min_length=1,
|
||||
max_length=256,
|
||||
)
|
||||
sql: str = Field(
|
||||
...,
|
||||
description="SQL query text to save",
|
||||
)
|
||||
schema_name: str | None = Field(
|
||||
None,
|
||||
description="Schema the query targets",
|
||||
alias="schema",
|
||||
)
|
||||
catalog: str | None = Field(None, description="Catalog name (if applicable)")
|
||||
description: str | None = Field(
|
||||
None, description="Optional description of the query"
|
||||
)
|
||||
|
||||
@field_validator("sql")
|
||||
@classmethod
|
||||
def sql_not_empty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("SQL query cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
@field_validator("label")
|
||||
@classmethod
|
||||
def label_not_empty(cls, v: str) -> str:
|
||||
if not v or not v.strip():
|
||||
raise ValueError("Label cannot be empty")
|
||||
return v.strip()
|
||||
|
||||
|
||||
class SaveSqlQueryResponse(BaseModel):
|
||||
"""Response schema for a saved SQL query."""
|
||||
|
||||
id: int = Field(..., description="Saved query ID")
|
||||
label: str = Field(..., description="Query name")
|
||||
sql: str = Field(..., description="SQL query text")
|
||||
database_id: int = Field(..., description="Database ID")
|
||||
schema_name: str | None = Field(None, description="Schema name", alias="schema")
|
||||
catalog: str | None = Field(None, description="Catalog name (if applicable)")
|
||||
description: str | None = Field(None, description="Query description")
|
||||
url: str = Field(
|
||||
...,
|
||||
description=(
|
||||
"URL to open this saved query in SQL Lab (e.g., /sqllab?savedQueryId=42)"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class OpenSqlLabRequest(BaseModel):
|
||||
"""Request schema for opening SQL Lab with context."""
|
||||
|
||||
|
||||
@@ -23,8 +23,10 @@ from superset.mcp_service.sql_lab.tool.execute_sql import execute_sql
|
||||
from superset.mcp_service.sql_lab.tool.open_sql_lab_with_context import (
|
||||
open_sql_lab_with_context,
|
||||
)
|
||||
from superset.mcp_service.sql_lab.tool.save_sql_query import save_sql_query
|
||||
|
||||
__all__ = [
|
||||
"execute_sql",
|
||||
"open_sql_lab_with_context",
|
||||
"save_sql_query",
|
||||
]
|
||||
|
||||
137
superset/mcp_service/sql_lab/tool/save_sql_query.py
Normal file
137
superset/mcp_service/sql_lab/tool/save_sql_query.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Save SQL Query MCP Tool
|
||||
|
||||
Tool for saving a SQL query as a named SavedQuery in Superset,
|
||||
so it appears in SQL Lab's "Saved Queries" list and can be
|
||||
reloaded/shared via URL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from fastmcp import Context
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from superset_core.mcp.decorators import tool
|
||||
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.exceptions import SupersetErrorException, SupersetSecurityException
|
||||
from superset.extensions import event_logger
|
||||
from superset.mcp_service.sql_lab.schemas import (
|
||||
SaveSqlQueryRequest,
|
||||
SaveSqlQueryResponse,
|
||||
)
|
||||
from superset.mcp_service.utils.schema_utils import parse_request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@tool(tags=["mutate"])
|
||||
@parse_request(SaveSqlQueryRequest)
|
||||
async def save_sql_query(
|
||||
request: SaveSqlQueryRequest, ctx: Context
|
||||
) -> SaveSqlQueryResponse:
|
||||
"""Save a SQL query so it appears in SQL Lab's Saved Queries list.
|
||||
|
||||
Creates a persistent SavedQuery that the user can reload from
|
||||
SQL Lab, share via URL, and find in the Saved Queries page.
|
||||
Requires a database_id, a label (name), and the SQL text.
|
||||
"""
|
||||
await ctx.info(
|
||||
"Saving SQL query: database_id=%s, label=%r"
|
||||
% (request.database_id, request.label)
|
||||
)
|
||||
|
||||
try:
|
||||
from flask import g
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.daos.query import SavedQueryDAO
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
from superset.models.core import Database
|
||||
|
||||
# 1. Validate database exists and user has access
|
||||
with event_logger.log_context(action="mcp.save_sql_query.db_validation"):
|
||||
database = (
|
||||
db.session.query(Database).filter_by(id=request.database_id).first()
|
||||
)
|
||||
if not database:
|
||||
raise SupersetErrorException(
|
||||
SupersetError(
|
||||
message=(f"Database with ID {request.database_id} not found"),
|
||||
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
if not security_manager.can_access_database(database):
|
||||
raise SupersetSecurityException(
|
||||
SupersetError(
|
||||
message=(f"Access denied to database {database.database_name}"),
|
||||
error_type=(SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR),
|
||||
level=ErrorLevel.ERROR,
|
||||
)
|
||||
)
|
||||
|
||||
# 2. Create the saved query
|
||||
with event_logger.log_context(action="mcp.save_sql_query.create"):
|
||||
saved_query = SavedQueryDAO.create(
|
||||
attributes={
|
||||
"user_id": g.user.id,
|
||||
"db_id": request.database_id,
|
||||
"label": request.label,
|
||||
"sql": request.sql,
|
||||
"schema": request.schema_name or "",
|
||||
"catalog": request.catalog,
|
||||
"description": request.description or "",
|
||||
}
|
||||
)
|
||||
db.session.commit() # pylint: disable=consider-using-transaction
|
||||
|
||||
# 3. Build response
|
||||
base_url = get_superset_base_url()
|
||||
saved_query_url = f"{base_url}/sqllab?savedQueryId={saved_query.id}"
|
||||
|
||||
await ctx.info(
|
||||
"Saved query created: id=%s, url=%s" % (saved_query.id, saved_query_url)
|
||||
)
|
||||
|
||||
return SaveSqlQueryResponse(
|
||||
id=saved_query.id,
|
||||
label=saved_query.label,
|
||||
sql=saved_query.sql,
|
||||
database_id=request.database_id,
|
||||
schema_name=request.schema_name,
|
||||
catalog=getattr(saved_query, "catalog", None),
|
||||
description=request.description,
|
||||
url=saved_query_url,
|
||||
)
|
||||
|
||||
except (SupersetErrorException, SupersetSecurityException):
|
||||
raise
|
||||
except SQLAlchemyError as e:
|
||||
from superset import db
|
||||
|
||||
db.session.rollback()
|
||||
await ctx.error(
|
||||
"Failed to save SQL query: error=%s, database_id=%s"
|
||||
% (str(e), request.database_id)
|
||||
)
|
||||
raise
|
||||
467
tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py
Normal file
467
tests/unit_tests/mcp_service/sql_lab/tool/test_save_sql_query.py
Normal file
@@ -0,0 +1,467 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for save_sql_query MCP tool schemas and logic.
|
||||
"""
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
import types
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.sql_lab.schemas import (
|
||||
SaveSqlQueryRequest,
|
||||
SaveSqlQueryResponse,
|
||||
)
|
||||
|
||||
|
||||
class TestSaveSqlQueryRequest:
|
||||
"""Test SaveSqlQueryRequest schema validation."""
|
||||
|
||||
def test_valid_request(self) -> None:
|
||||
req = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="Revenue Query",
|
||||
sql="SELECT SUM(revenue) FROM sales",
|
||||
)
|
||||
assert req.database_id == 1
|
||||
assert req.label == "Revenue Query"
|
||||
assert req.sql == "SELECT SUM(revenue) FROM sales"
|
||||
|
||||
def test_with_optional_fields(self) -> None:
|
||||
req = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="Revenue Query",
|
||||
sql="SELECT 1",
|
||||
schema="public",
|
||||
catalog="main",
|
||||
description="Sums revenue",
|
||||
)
|
||||
assert req.schema_name == "public"
|
||||
assert req.catalog == "main"
|
||||
assert req.description == "Sums revenue"
|
||||
|
||||
def test_empty_sql_fails(self) -> None:
|
||||
with pytest.raises(ValidationError, match="SQL query cannot be empty"):
|
||||
SaveSqlQueryRequest(database_id=1, label="test", sql=" ")
|
||||
|
||||
def test_empty_label_fails(self) -> None:
|
||||
with pytest.raises(ValidationError, match="Label cannot be empty"):
|
||||
SaveSqlQueryRequest(database_id=1, label=" ", sql="SELECT 1")
|
||||
|
||||
def test_sql_is_stripped(self) -> None:
|
||||
req = SaveSqlQueryRequest(database_id=1, label="test", sql=" SELECT 1 ")
|
||||
assert req.sql == "SELECT 1"
|
||||
|
||||
def test_label_is_stripped(self) -> None:
|
||||
req = SaveSqlQueryRequest(database_id=1, label=" My Query ", sql="SELECT 1")
|
||||
assert req.label == "My Query"
|
||||
|
||||
def test_label_max_length(self) -> None:
|
||||
with pytest.raises(ValidationError, match="String should have at most 256"):
|
||||
SaveSqlQueryRequest(database_id=1, label="a" * 257, sql="SELECT 1")
|
||||
|
||||
def test_schema_alias(self) -> None:
|
||||
"""The field accepts 'schema' as alias for 'schema_name'."""
|
||||
req = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="test",
|
||||
sql="SELECT 1",
|
||||
schema="public",
|
||||
)
|
||||
assert req.schema_name == "public"
|
||||
|
||||
|
||||
class TestSaveSqlQueryResponse:
|
||||
"""Test SaveSqlQueryResponse schema."""
|
||||
|
||||
def test_response_fields(self) -> None:
|
||||
resp = SaveSqlQueryResponse(
|
||||
id=42,
|
||||
label="Revenue",
|
||||
sql="SELECT 1",
|
||||
database_id=1,
|
||||
url="/sqllab?savedQueryId=42",
|
||||
)
|
||||
assert resp.id == 42
|
||||
assert resp.label == "Revenue"
|
||||
assert resp.url == "/sqllab?savedQueryId=42"
|
||||
|
||||
def test_response_with_optional_fields(self) -> None:
|
||||
resp = SaveSqlQueryResponse(
|
||||
id=42,
|
||||
label="Revenue",
|
||||
sql="SELECT 1",
|
||||
database_id=1,
|
||||
schema="public",
|
||||
description="A query",
|
||||
url="/sqllab?savedQueryId=42",
|
||||
)
|
||||
assert resp.schema_name == "public"
|
||||
assert resp.description == "A query"
|
||||
|
||||
|
||||
def _force_passthrough_decorators():
|
||||
"""Force superset_core MCP tool decorator to be a passthrough.
|
||||
|
||||
In CI, superset_core is fully installed and the real @tool decorator
|
||||
includes authentication middleware. For unit tests we want to bypass
|
||||
auth and test the tool logic directly, so we always replace the
|
||||
decorator with a passthrough regardless of installation state.
|
||||
|
||||
Returns a dict of original sys.modules entries so they can be restored.
|
||||
"""
|
||||
|
||||
def _passthrough_tool(func=None, **kwargs):
|
||||
if func is not None:
|
||||
return func
|
||||
return lambda f: f
|
||||
|
||||
mock_mcp = MagicMock()
|
||||
mock_mcp.tool = _passthrough_tool
|
||||
|
||||
mock_decorators = MagicMock()
|
||||
mock_decorators.tool = _passthrough_tool
|
||||
|
||||
mock_api = MagicMock()
|
||||
mock_api.mcp = mock_mcp
|
||||
|
||||
# Save original modules so we can restore them later
|
||||
saved_modules: dict[str, types.ModuleType] = {}
|
||||
|
||||
# Only mock the specific decorator submodules, NOT the top-level
|
||||
# superset_core package. Replacing sys.modules["superset_core"] with
|
||||
# a MagicMock causes 'superset_core' is not a package errors for
|
||||
# other submodules (queries, common) that are imported by sibling
|
||||
# tool files during test collection.
|
||||
mock_keys = [
|
||||
"superset_core.api",
|
||||
"superset_core.api.mcp",
|
||||
"superset_core.api.types",
|
||||
"superset_core.mcp",
|
||||
"superset_core.mcp.decorators",
|
||||
]
|
||||
for key in mock_keys:
|
||||
if key in sys.modules:
|
||||
saved_modules[key] = sys.modules[key]
|
||||
|
||||
sys.modules["superset_core.api"] = mock_api
|
||||
sys.modules["superset_core.api.mcp"] = mock_mcp
|
||||
sys.modules["superset_core.mcp"] = mock_mcp
|
||||
sys.modules["superset_core.mcp.decorators"] = mock_decorators
|
||||
sys.modules.setdefault("superset_core.api.types", MagicMock())
|
||||
|
||||
return saved_modules
|
||||
|
||||
|
||||
def _restore_modules(saved_modules: dict[str, types.ModuleType]) -> None:
|
||||
"""Restore original sys.modules entries after passthrough mocking."""
|
||||
# Remove mock entries for decorator paths and tool modules imported
|
||||
# under patched decorators. Do NOT remove the top-level superset_core
|
||||
# package or unrelated submodules (queries, common, etc.).
|
||||
mock_prefixes = (
|
||||
"superset_core.api",
|
||||
"superset_core.mcp",
|
||||
"superset.mcp_service.sql_lab.tool",
|
||||
)
|
||||
for key in list(sys.modules.keys()):
|
||||
if any(key.startswith(prefix) for prefix in mock_prefixes):
|
||||
del sys.modules[key]
|
||||
# Restore originals (including any previously-imported tool modules)
|
||||
sys.modules.update(saved_modules)
|
||||
|
||||
|
||||
def _get_tool_module():
|
||||
"""Import save_sql_query with passthrough decorators (no auth).
|
||||
|
||||
Returns (module, saved_modules) so callers can restore sys.modules.
|
||||
"""
|
||||
saved_modules = _force_passthrough_decorators()
|
||||
# Clear cached module imports so we get a fresh import with mocked
|
||||
# decorators. This is necessary because in CI the real @tool decorator
|
||||
# may have been applied during a previous import.
|
||||
mod_name = "superset.mcp_service.sql_lab.tool.save_sql_query"
|
||||
saved_tool_modules: dict[str, object] = {}
|
||||
for key in list(sys.modules.keys()):
|
||||
if key.startswith("superset.mcp_service.sql_lab.tool"):
|
||||
saved_tool_modules[key] = sys.modules.pop(key)
|
||||
saved_modules.update(saved_tool_modules)
|
||||
mod = importlib.import_module(mod_name)
|
||||
return mod, saved_modules
|
||||
|
||||
|
||||
def _make_mock_ctx():
|
||||
"""Create a mock FastMCP Context with awaitable methods."""
|
||||
|
||||
async def _noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
ctx = MagicMock()
|
||||
ctx.info = _noop
|
||||
ctx.error = _noop
|
||||
ctx.warning = _noop
|
||||
return ctx
|
||||
|
||||
|
||||
class TestSaveSqlQueryToolLogic:
|
||||
"""Test save_sql_query tool internal logic.
|
||||
|
||||
The tool function uses lazy imports inside its body (from flask import g,
|
||||
from superset import db, etc.). We patch at the import source so that
|
||||
when the function runs, it picks up our mocks.
|
||||
|
||||
The @parse_request decorator injects ctx via get_context() and strips
|
||||
__wrapped__, so we mock get_context and call the decorated function
|
||||
directly (without unwrapping).
|
||||
"""
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_save_query_creates_saved_query(self) -> None:
|
||||
"""Verify the tool calls SavedQueryDAO.create with correct attrs."""
|
||||
mod, saved = _get_tool_module()
|
||||
try:
|
||||
mock_ctx = _make_mock_ctx()
|
||||
|
||||
mock_db_obj = MagicMock()
|
||||
mock_db_obj.id = 1
|
||||
mock_db_obj.database_name = "test_db"
|
||||
|
||||
mock_sq = MagicMock()
|
||||
mock_sq.id = 42
|
||||
mock_sq.label = "Revenue Query"
|
||||
mock_sq.sql = "SELECT SUM(revenue) FROM sales"
|
||||
mock_sq.catalog = None
|
||||
|
||||
request = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="Revenue Query",
|
||||
sql="SELECT SUM(revenue) FROM sales",
|
||||
)
|
||||
|
||||
mock_db_session = MagicMock()
|
||||
(
|
||||
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
|
||||
) = mock_db_obj
|
||||
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.can_access_database.return_value = True
|
||||
|
||||
mock_dao = MagicMock()
|
||||
mock_dao.create.return_value = mock_sq
|
||||
|
||||
mock_g = MagicMock()
|
||||
mock_g.user = Mock(id=1)
|
||||
|
||||
mock_event_logger = MagicMock()
|
||||
mock_event_logger.log_context.return_value.__enter__ = Mock()
|
||||
mock_event_logger.log_context.return_value.__exit__ = Mock(
|
||||
return_value=False
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastmcp.server.dependencies.get_context",
|
||||
return_value=mock_ctx,
|
||||
),
|
||||
patch("superset.db", mock_db_session),
|
||||
patch("superset.security_manager", mock_sm),
|
||||
patch("superset.daos.query.SavedQueryDAO", mock_dao),
|
||||
patch(
|
||||
"superset.mcp_service.utils.url_utils.get_superset_base_url",
|
||||
return_value="http://localhost:8088",
|
||||
),
|
||||
patch("flask.g", mock_g),
|
||||
patch.object(mod, "event_logger", mock_event_logger),
|
||||
):
|
||||
result = await mod.save_sql_query(request)
|
||||
|
||||
assert result.id == 42
|
||||
assert result.label == "Revenue Query"
|
||||
assert "savedQueryId=42" in result.url
|
||||
mock_dao.create.assert_called_once()
|
||||
call_attrs = mock_dao.create.call_args[1]["attributes"]
|
||||
assert call_attrs["db_id"] == 1
|
||||
assert call_attrs["label"] == "Revenue Query"
|
||||
assert call_attrs["sql"] == "SELECT SUM(revenue) FROM sales"
|
||||
assert call_attrs["user_id"] == 1
|
||||
mock_db_session.session.commit.assert_called_once()
|
||||
finally:
|
||||
_restore_modules(saved)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_save_query_database_not_found(self) -> None:
|
||||
mod, saved = _get_tool_module()
|
||||
try:
|
||||
mock_ctx = _make_mock_ctx()
|
||||
|
||||
request = SaveSqlQueryRequest(
|
||||
database_id=999,
|
||||
label="Test",
|
||||
sql="SELECT 1",
|
||||
)
|
||||
|
||||
mock_db_session = MagicMock()
|
||||
(
|
||||
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
|
||||
) = None
|
||||
|
||||
mock_g = MagicMock()
|
||||
mock_g.user = Mock(id=1)
|
||||
|
||||
mock_event_logger = MagicMock()
|
||||
mock_event_logger.log_context.return_value.__enter__ = Mock()
|
||||
mock_event_logger.log_context.return_value.__exit__ = Mock(
|
||||
return_value=False
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastmcp.server.dependencies.get_context",
|
||||
return_value=mock_ctx,
|
||||
),
|
||||
patch("superset.db", mock_db_session),
|
||||
patch("flask.g", mock_g),
|
||||
patch.object(mod, "event_logger", mock_event_logger),
|
||||
):
|
||||
from superset.exceptions import SupersetErrorException
|
||||
|
||||
with pytest.raises(SupersetErrorException, match="not found"):
|
||||
await mod.save_sql_query(request)
|
||||
finally:
|
||||
_restore_modules(saved)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_save_query_access_denied(self) -> None:
|
||||
mod, saved = _get_tool_module()
|
||||
try:
|
||||
mock_ctx = _make_mock_ctx()
|
||||
|
||||
mock_db_obj = MagicMock()
|
||||
mock_db_obj.id = 1
|
||||
mock_db_obj.database_name = "test_db"
|
||||
|
||||
request = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="Test",
|
||||
sql="SELECT 1",
|
||||
)
|
||||
|
||||
mock_db_session = MagicMock()
|
||||
(
|
||||
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
|
||||
) = mock_db_obj
|
||||
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.can_access_database.return_value = False
|
||||
|
||||
mock_g = MagicMock()
|
||||
mock_g.user = Mock(id=1)
|
||||
|
||||
mock_event_logger = MagicMock()
|
||||
mock_event_logger.log_context.return_value.__enter__ = Mock()
|
||||
mock_event_logger.log_context.return_value.__exit__ = Mock(
|
||||
return_value=False
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastmcp.server.dependencies.get_context",
|
||||
return_value=mock_ctx,
|
||||
),
|
||||
patch("superset.db", mock_db_session),
|
||||
patch("superset.security_manager", mock_sm),
|
||||
patch("flask.g", mock_g),
|
||||
patch.object(mod, "event_logger", mock_event_logger),
|
||||
):
|
||||
from superset.exceptions import SupersetSecurityException
|
||||
|
||||
with pytest.raises(SupersetSecurityException, match="Access denied"):
|
||||
await mod.save_sql_query(request)
|
||||
finally:
|
||||
_restore_modules(saved)
|
||||
|
||||
@pytest.mark.anyio
|
||||
async def test_save_query_with_schema_and_description(self) -> None:
|
||||
mod, saved = _get_tool_module()
|
||||
try:
|
||||
mock_ctx = _make_mock_ctx()
|
||||
|
||||
mock_db_obj = MagicMock()
|
||||
mock_db_obj.id = 1
|
||||
mock_db_obj.database_name = "test_db"
|
||||
|
||||
mock_sq = MagicMock()
|
||||
mock_sq.id = 10
|
||||
mock_sq.label = "Test"
|
||||
mock_sq.sql = "SELECT 1"
|
||||
mock_sq.catalog = None
|
||||
|
||||
request = SaveSqlQueryRequest(
|
||||
database_id=1,
|
||||
label="Test",
|
||||
sql="SELECT 1",
|
||||
schema="public",
|
||||
description="A test query",
|
||||
)
|
||||
|
||||
mock_db_session = MagicMock()
|
||||
(
|
||||
mock_db_session.session.query.return_value.filter_by.return_value.first.return_value
|
||||
) = mock_db_obj
|
||||
|
||||
mock_sm = MagicMock()
|
||||
mock_sm.can_access_database.return_value = True
|
||||
|
||||
mock_dao = MagicMock()
|
||||
mock_dao.create.return_value = mock_sq
|
||||
|
||||
mock_g = MagicMock()
|
||||
mock_g.user = Mock(id=1)
|
||||
|
||||
mock_event_logger = MagicMock()
|
||||
mock_event_logger.log_context.return_value.__enter__ = Mock()
|
||||
mock_event_logger.log_context.return_value.__exit__ = Mock(
|
||||
return_value=False
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastmcp.server.dependencies.get_context",
|
||||
return_value=mock_ctx,
|
||||
),
|
||||
patch("superset.db", mock_db_session),
|
||||
patch("superset.security_manager", mock_sm),
|
||||
patch("superset.daos.query.SavedQueryDAO", mock_dao),
|
||||
patch(
|
||||
"superset.mcp_service.utils.url_utils.get_superset_base_url",
|
||||
return_value="http://localhost:8088",
|
||||
),
|
||||
patch("flask.g", mock_g),
|
||||
patch.object(mod, "event_logger", mock_event_logger),
|
||||
):
|
||||
result = await mod.save_sql_query(request)
|
||||
|
||||
assert result.id == 10
|
||||
call_attrs = mock_dao.create.call_args[1]["attributes"]
|
||||
assert call_attrs["schema"] == "public"
|
||||
assert call_attrs["description"] == "A test query"
|
||||
finally:
|
||||
_restore_modules(saved)
|
||||
Reference in New Issue
Block a user