feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)

This commit is contained in:
Amin Ghadersohi
2025-11-01 02:33:21 +11:00
committed by GitHub
parent 30d584afd1
commit fee4e7d8e2
106 changed files with 21826 additions and 223 deletions

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,64 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Helper function to extract row data from MCP responses.
The MCP client seems to wrap dict rows in Root objects.
This helper handles the extraction properly.
"""
def extract_row_data(row):
"""Extract dictionary data from a row object."""
# Handle different possible formats
if isinstance(row, dict):
return row
# Check for Pydantic Root object
if hasattr(row, "__root__"):
return row.__root__
# Check if it's a Pydantic model with model_dump
if hasattr(row, "model_dump"):
return row.model_dump()
# Try to access __dict__ directly
if hasattr(row, "__dict__"):
# Filter out private attributes
return {k: v for k, v in row.__dict__.items() if not k.startswith("_")}
# Last resort - convert to string and parse
# This is for the Root object issue
row_str = str(row)
if row_str == "Root()":
# Empty Root object - the actual data might be elsewhere
# Let's check all attributes
attrs = dir(row)
for attr in attrs:
if not attr.startswith("_") and attr not in [
"model_dump",
"model_validate",
]:
try:
val = getattr(row, attr)
if isinstance(val, dict):
return val
except AttributeError:
pass
raise ValueError(f"Cannot extract data from row of type {type(row)}: {row}")

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,497 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Unit tests for execute_sql MCP tool
"""
import logging
from unittest.mock import MagicMock, Mock, patch
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
from superset.mcp_service.app import mcp
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.fixture
def mcp_server():
return mcp
@pytest.fixture(autouse=True)
def mock_auth():
"""Mock authentication for all tests."""
with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
mock_user = Mock()
mock_user.id = 1
mock_user.username = "admin"
mock_get_user.return_value = mock_user
yield mock_get_user
def _mock_database(
id: int = 1,
database_name: str = "test_db",
allow_dml: bool = False,
) -> Mock:
"""Create a mock database object."""
database = Mock()
database.id = id
database.database_name = database_name
database.allow_dml = allow_dml
# Mock raw connection context manager
mock_cursor = Mock()
mock_cursor.description = [
("id", "INTEGER", None, None, None, None, False),
("name", "VARCHAR", None, None, None, None, True),
]
mock_cursor.fetchmany.return_value = [(1, "test_name")]
mock_cursor.rowcount = 1
mock_conn = Mock()
mock_conn.cursor.return_value = mock_cursor
mock_conn.commit = Mock()
mock_context = MagicMock()
mock_context.__enter__.return_value = mock_conn
mock_context.__exit__.return_value = None
database.get_raw_connection.return_value = mock_context
return database
class TestExecuteSql:
"""Tests for execute_sql MCP tool."""
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_basic_select(
self, mock_db, mock_security_manager, mcp_server
):
"""Test basic SELECT query execution."""
# Setup mocks
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT id, name FROM users LIMIT 10",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 1
assert len(result.data.rows) == 1
assert result.data.rows[0]["id"] == 1
assert result.data.rows[0]["name"] == "test_name"
assert len(result.data.columns) == 2
assert result.data.columns[0].name == "id"
assert result.data.columns[0].type == "INTEGER"
assert result.data.execution_time > 0
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_parameters(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with parameter substitution."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}",
"parameters": {"table": "orders", "status": "active", "limit": "5"},
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify parameter substitution happened
mock_database.get_raw_connection.assert_called_once()
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
# Check that the SQL was formatted with parameters
executed_sql = cursor.execute.call_args[0][0]
assert "orders" in executed_sql
assert "active" in executed_sql
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_database_not_found(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when database is not found."""
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
None
)
request = {
"database_id": 999,
"sql": "SELECT 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Database with ID 999 not found" in result.data.error
assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
assert result.data.rows is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_access_denied(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when user lacks database access."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
# Use Mock instead of AsyncMock for synchronous call
from unittest.mock import Mock
mock_security_manager.can_access_database = Mock(return_value=False)
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Access denied to database" in result.data.error
assert result.data.error_type == "SECURITY_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_not_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when DML operations are not allowed."""
mock_database = _mock_database(allow_dml=False)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "UPDATE users SET name = 'test' WHERE id = 1",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert result.data.error_type == "DML_NOT_ALLOWED"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_dml_allowed(
self, mock_db, mock_security_manager, mcp_server
):
"""Test successful DML execution when allowed."""
mock_database = _mock_database(allow_dml=True)
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock cursor for DML operation
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.rowcount = 3 # 3 rows affected
request = {
"database_id": 1,
"sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.affected_rows == 3
assert result.data.rows == [] # Empty rows for DML
assert result.data.row_count == 0
# Verify commit was called
(
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
)
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_results(
self, mock_db, mock_security_manager, mcp_server
):
"""Test query that returns no results."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock empty results
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.fetchmany.return_value = []
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 999999",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
assert result.data.row_count == 0
assert len(result.data.rows) == 0
assert len(result.data.columns) == 2 # Column metadata still returned
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_missing_parameter(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when required parameter is missing."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
"parameters": {"table_name": "users"}, # Missing user_id
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "user_id" in result.data.error # Error contains parameter name
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_empty_parameters_with_placeholders(
self, mock_db, mock_security_manager, mcp_server
):
"""Test error when empty parameters dict is provided but SQL has
placeholders."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM {table_name} LIMIT 5",
"parameters": {}, # Empty dict but SQL has {table_name}
"limit": 5,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Missing parameter: table_name" in result.data.error
assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_with_schema(
self, mock_db, mock_security_manager, mcp_server
):
"""Test SQL execution with schema specification."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT COUNT(*) as total FROM orders",
"schema": "sales",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
assert result.data.error is None
# Verify schema was passed to get_raw_connection
# Verify schema was passed
call_args = mock_database.get_raw_connection.call_args
assert call_args[1]["schema"] == "sales"
assert call_args[1]["catalog"] is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_limit_enforcement(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that LIMIT is added to SELECT queries without one."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
request = {
"database_id": 1,
"sql": "SELECT * FROM users", # No LIMIT
"limit": 50,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is True
# Verify LIMIT was added
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
executed_sql = cursor.execute.call_args[0][0]
assert "LIMIT 50" in executed_sql
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_sql_injection_prevention(
self, mock_db, mock_security_manager, mcp_server
):
"""Test that SQL injection attempts are handled safely."""
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
# Mock execute to raise an exception
cursor = ( # fmt: skip
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
)
cursor.execute.side_effect = Exception("Syntax error")
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request": request})
assert result.data.success is False
assert result.data.error is not None
assert "Syntax error" in result.data.error # Contains actual error
assert result.data.error_type == "EXECUTION_ERROR"
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
"""Test validation of empty SQL query."""
request = {
"database_id": 1,
"sql": " ", # Empty/whitespace only
"limit": 10,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="SQL query cannot be empty"):
await client.call_tool("execute_sql", {"request": request})
@pytest.mark.asyncio
async def test_execute_sql_invalid_limit(self, mcp_server):
"""Test validation of invalid limit values."""
# Test limit too low
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 0,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="minimum of 1"):
await client.call_tool("execute_sql", {"request": request})
# Test limit too high
request = {
"database_id": 1,
"sql": "SELECT 1",
"limit": 20000,
}
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="maximum of 10000"):
await client.call_tool("execute_sql", {"request": request})