feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877)

This commit is contained in:
Amin Ghadersohi
2025-11-01 02:33:21 +11:00
committed by GitHub
parent 30d584afd1
commit fee4e7d8e2
106 changed files with 21826 additions and 223 deletions

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,221 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Generic SQL execution core for MCP service.
"""
import logging
from typing import Any
from superset.mcp_service.mcp_core import BaseCore
from superset.mcp_service.sql_lab.schemas import (
ExecuteSqlRequest,
ExecuteSqlResponse,
)
class ExecuteSqlCore(BaseCore):
"""
Generic tool for executing SQL queries with security validation.
This tool provides a high-level interface for SQL execution that can be used
by different MCP tools or other components. It handles:
- Database access validation
- SQL query validation (DML permissions, disallowed functions)
- Parameter substitution
- Query execution with timeout
- Result formatting
The tool can work in two modes:
1. Simple mode: Direct SQL execution using sql_lab_utils (default)
2. Command mode: Using ExecuteSqlCommand for full SQL Lab features
"""
def __init__(
self,
use_command_mode: bool = False,
logger: logging.Logger | None = None,
) -> None:
super().__init__(logger)
self.use_command_mode = use_command_mode
def run_tool(self, request: ExecuteSqlRequest) -> ExecuteSqlResponse:
"""
Execute SQL query and return results.
Args:
request: ExecuteSqlRequest with database_id, sql, and optional parameters
Returns:
ExecuteSqlResponse with success status, results, or error information
"""
try:
# Import inside method to avoid initialization issues
from superset.mcp_service.sql_lab.sql_lab_utils import check_database_access
# Check database access
database = check_database_access(request.database_id)
if self.use_command_mode:
# Use full SQL Lab command for complex queries
return self._execute_with_command(request, database)
else:
# Use simplified execution for basic queries
return self._execute_simple(request, database)
except Exception as e:
# Handle errors and return error response with proper error types
self._log_error(e, "executing SQL")
return self._handle_execution_error(e)
def _execute_simple(
self,
request: ExecuteSqlRequest,
database: Any,
) -> ExecuteSqlResponse:
"""Execute SQL using simplified sql_lab_utils."""
# Import inside method to avoid initialization issues
from superset.mcp_service.sql_lab.sql_lab_utils import execute_sql_query
results = execute_sql_query(
database=database,
sql=request.sql,
schema=request.schema_name,
limit=request.limit,
timeout=request.timeout,
parameters=request.parameters,
)
return ExecuteSqlResponse(
success=True,
rows=results.get("rows"),
columns=results.get("columns"),
row_count=results.get("row_count"),
affected_rows=results.get("affected_rows"),
query_id=None, # Not available in simple mode
execution_time=results.get("execution_time"),
error=None,
error_type=None,
)
def _execute_with_command(
self,
request: ExecuteSqlRequest,
database: Any,
) -> ExecuteSqlResponse:
"""Execute SQL using full SQL Lab command (not implemented yet)."""
# This would use ExecuteSqlCommand for full SQL Lab features
# Including query caching, async execution, complex parsing, etc.
# For now, we'll fall back to simple execution
self._log_info("Command mode not fully implemented, using simple mode")
return self._execute_simple(request, database)
# Future implementation would look like:
# context = SqlJsonExecutionContext(
# database_id=request.database_id,
# sql=request.sql,
# schema=request.schema_name,
# limit=request.limit,
# # ... other context fields
# )
#
# command = ExecuteSqlCommand(
# execution_context=context,
# query_dao=QueryDAO(),
# database_dao=DatabaseDAO(),
# # ... other dependencies
# )
#
# result = command.run()
# return self._format_command_result(result)
def _handle_execution_error(self, e: Exception) -> ExecuteSqlResponse:
"""Map exceptions to error responses."""
error_type = self._get_error_type(e)
return ExecuteSqlResponse(
success=False,
error=str(e),
error_type=error_type,
rows=None,
columns=None,
row_count=None,
affected_rows=None,
query_id=None,
execution_time=None,
)
def _get_error_type(self, e: Exception) -> str:
"""Determine error type from exception."""
# Import inside method to avoid initialization issues
from superset.exceptions import (
SupersetDisallowedSQLFunctionException,
SupersetDMLNotAllowedException,
SupersetErrorException,
SupersetSecurityException,
SupersetTimeoutException,
)
if isinstance(e, SupersetSecurityException):
return "SECURITY_ERROR"
elif isinstance(e, SupersetTimeoutException):
return "TIMEOUT"
elif isinstance(e, SupersetDMLNotAllowedException):
return "DML_NOT_ALLOWED"
elif isinstance(e, SupersetDisallowedSQLFunctionException):
return "DISALLOWED_FUNCTION"
elif isinstance(e, SupersetErrorException):
return self._extract_superset_error_type(e)
else:
return "EXECUTION_ERROR"
def _extract_superset_error_type(self, e: Exception) -> str:
"""Extract error type from SupersetErrorException."""
if hasattr(e, "error") and hasattr(e.error, "error_type"):
error_type_name = e.error.error_type.name
# Map common error type patterns
if "INVALID_PAYLOAD" in error_type_name:
return "INVALID_PAYLOAD_FORMAT_ERROR"
elif "DATABASE_NOT_FOUND" in error_type_name:
return "DATABASE_NOT_FOUND_ERROR"
elif "SECURITY" in error_type_name:
return "SECURITY_ERROR"
elif "TIMEOUT" in error_type_name:
return "TIMEOUT"
elif "DML_NOT_ALLOWED" in error_type_name:
return "DML_NOT_ALLOWED"
else:
return error_type_name
return "EXECUTION_ERROR"
def _format_command_result(
self, command_result: dict[str, Any]
) -> ExecuteSqlResponse:
"""Format ExecuteSqlCommand result into ExecuteSqlResponse."""
# This would extract relevant fields from command result
# Placeholder implementation for future use
return ExecuteSqlResponse(
success=command_result.get("success", False),
rows=command_result.get("data"),
columns=command_result.get("columns"),
row_count=command_result.get("row_count"),
affected_rows=command_result.get("affected_rows"),
query_id=command_result.get("query_id"),
execution_time=command_result.get("execution_time"),
error=command_result.get("error"),
error_type=command_result.get("error_type"),
)

View File

@@ -0,0 +1,109 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Schemas for SQL Lab MCP tools."""
from typing import Any
from pydantic import BaseModel, Field, field_validator
class ExecuteSqlRequest(BaseModel):
"""Request schema for executing SQL queries."""
database_id: int = Field(
..., description="Database connection ID to execute query against"
)
sql: str = Field(..., description="SQL query to execute")
schema_name: str | None = Field(
None, description="Schema to use for query execution", alias="schema"
)
limit: int = Field(
default=1000,
description="Maximum number of rows to return",
ge=1,
le=10000,
)
timeout: int = Field(
default=30, description="Query timeout in seconds", ge=1, le=300
)
parameters: dict[str, Any] | None = Field(
None, description="Parameters for query substitution"
)
@field_validator("sql")
@classmethod
def sql_not_empty(cls, v: str) -> str:
if not v or not v.strip():
raise ValueError("SQL query cannot be empty")
return v.strip()
class ColumnInfo(BaseModel):
"""Column metadata information."""
name: str = Field(..., description="Column name")
type: str = Field(..., description="Column data type")
is_nullable: bool | None = Field(None, description="Whether column allows NULL")
class ExecuteSqlResponse(BaseModel):
"""Response schema for SQL execution results."""
success: bool = Field(..., description="Whether query executed successfully")
rows: Any | None = Field(
None, description="Query result rows as list of dictionaries"
)
columns: list[ColumnInfo] | None = Field(
None, description="Column metadata information"
)
row_count: int | None = Field(None, description="Number of rows returned")
affected_rows: int | None = Field(
None, description="Number of rows affected (for DML queries)"
)
query_id: str | None = Field(None, description="Query tracking ID")
execution_time: float | None = Field(
None, description="Query execution time in seconds"
)
error: str | None = Field(None, description="Error message if query failed")
error_type: str | None = Field(None, description="Type of error if failed")
class OpenSqlLabRequest(BaseModel):
"""Request schema for opening SQL Lab with context."""
database_connection_id: int = Field(
..., description="Database connection ID to use in SQL Lab"
)
schema_name: str | None = Field(
None, description="Default schema to select in SQL Lab", alias="schema"
)
dataset_in_context: str | None = Field(
None, description="Dataset name/table to provide as context"
)
sql: str | None = Field(None, description="SQL query to pre-populate in the editor")
title: str | None = Field(None, description="Title for the SQL Lab tab/query")
class SqlLabResponse(BaseModel):
"""Response schema for SQL Lab URL generation."""
url: str = Field(..., description="URL to open SQL Lab with context")
database_id: int = Field(..., description="Database ID used")
schema_name: str | None = Field(None, description="Schema selected", alias="schema")
title: str | None = Field(None, description="Query title")
error: str | None = Field(None, description="Error message if failed")

View File

@@ -0,0 +1,243 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Utility functions for SQL Lab MCP tools.
This module contains helper functions for SQL execution, validation,
and database access that are shared across SQL Lab tools.
"""
import logging
from typing import Any
logger = logging.getLogger(__name__)
def check_database_access(database_id: int) -> Any:
"""Check if user has access to the database."""
# Import inside function to avoid initialization issues
from superset import db, security_manager
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException, SupersetSecurityException
from superset.models.core import Database
# Use session query to ensure relationships are loaded
database = db.session.query(Database).filter_by(id=database_id).first()
if not database:
raise SupersetErrorException(
SupersetError(
message=f"Database with ID {database_id} not found",
error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
level=ErrorLevel.ERROR,
)
)
# Check database access permissions
if not security_manager.can_access_database(database):
raise SupersetSecurityException(
SupersetError(
message=f"Access denied to database {database.database_name}",
error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR,
level=ErrorLevel.ERROR,
)
)
return database
def validate_sql_query(sql: str, database: Any) -> None:
"""Validate SQL query for security and syntax."""
# Import inside function to avoid initialization issues
from flask import current_app as app
from superset.exceptions import (
SupersetDisallowedSQLFunctionException,
SupersetDMLNotAllowedException,
)
# Simplified validation without complex parsing
sql_upper = sql.upper().strip()
# Check for DML operations if not allowed
dml_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE"]
if any(sql_upper.startswith(keyword) for keyword in dml_keywords):
if not database.allow_dml:
raise SupersetDMLNotAllowedException()
# Check for disallowed functions from config
disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get(
"sqlite",
set(), # Default to sqlite for now
)
if disallowed_functions:
sql_lower = sql.lower()
for func in disallowed_functions:
if f"{func.lower()}(" in sql_lower:
raise SupersetDisallowedSQLFunctionException(disallowed_functions)
def execute_sql_query(
database: Any,
sql: str,
schema: str | None,
limit: int,
timeout: int,
parameters: dict[str, Any] | None,
) -> dict[str, Any]:
"""Execute SQL query and return results."""
# Import inside function to avoid initialization issues
from superset.utils.dates import now_as_float
start_time = now_as_float()
# Apply parameters and validate
sql = _apply_parameters(sql, parameters)
validate_sql_query(sql, database)
# Apply limit for SELECT queries
rendered_sql = _apply_limit(sql, limit)
# Execute and get results
results = _execute_query(database, rendered_sql, schema, limit)
# Calculate execution time
end_time = now_as_float()
results["execution_time"] = end_time - start_time
return results
def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str:
"""Apply parameters to SQL query."""
# Import inside function to avoid initialization issues
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import SupersetErrorException
if parameters:
try:
return sql.format(**parameters)
except KeyError as e:
raise SupersetErrorException(
SupersetError(
message=f"Missing parameter: {e}",
error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR,
level=ErrorLevel.ERROR,
)
) from e
else:
# Check if SQL contains placeholders when no parameters provided
import re
placeholders = re.findall(r"{(\w+)}", sql)
if placeholders:
raise SupersetErrorException(
SupersetError(
message=f"Missing parameter: {placeholders[0]}",
error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR,
level=ErrorLevel.ERROR,
)
)
return sql
def _apply_limit(sql: str, limit: int) -> str:
"""Apply limit to SELECT queries if not already present."""
sql_lower = sql.lower().strip()
if sql_lower.startswith("select") and "limit" not in sql_lower:
return f"{sql.rstrip().rstrip(';')} LIMIT {limit}"
return sql
def _execute_query(
database: Any,
sql: str,
schema: str | None,
limit: int,
) -> dict[str, Any]:
"""Execute the query and process results."""
# Import inside function to avoid initialization issues
from superset.utils.core import QuerySource
results = {
"rows": [],
"columns": [],
"row_count": 0,
"affected_rows": None,
"execution_time": 0.0,
}
try:
# Execute query with timeout
with database.get_raw_connection(
catalog=None,
schema=schema,
source=QuerySource.SQL_LAB,
) as conn:
cursor = conn.cursor()
cursor.execute(sql)
# Process results based on query type
if _is_select_query(sql):
_process_select_results(cursor, results, limit)
else:
_process_dml_results(cursor, conn, results)
except Exception as e:
logger.error("Error executing SQL: %s", e)
raise
return results
def _is_select_query(sql: str) -> bool:
"""Check if SQL is a SELECT query."""
return sql.lower().strip().startswith("select")
def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None:
"""Process SELECT query results."""
# Fetch results
data = cursor.fetchmany(limit)
# Get column metadata
column_info = []
if cursor.description:
for col in cursor.description:
column_info.append(
{
"name": col[0],
"type": str(col[1]) if col[1] else "unknown",
"is_nullable": col[6] if len(col) > 6 else None,
}
)
# Set column info regardless of whether there's data
if column_info:
results["columns"] = column_info
# Convert rows to dictionaries
column_names = [col["name"] for col in column_info]
results["rows"] = [dict(zip(column_names, row, strict=False)) for row in data]
results["row_count"] = len(data)
def _process_dml_results(cursor: Any, conn: Any, results: dict[str, Any]) -> None:
"""Process DML query results."""
results["affected_rows"] = cursor.rowcount
conn.commit() # pylint: disable=consider-using-transaction

View File

@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
SQL Lab MCP Tools
"""
from superset.mcp_service.sql_lab.tool.execute_sql import execute_sql
from superset.mcp_service.sql_lab.tool.open_sql_lab_with_context import (
open_sql_lab_with_context,
)
__all__ = [
"execute_sql",
"open_sql_lab_with_context",
]

View File

@@ -0,0 +1,94 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Execute SQL MCP Tool
Tool for executing SQL queries against databases with security validation
and timeout protection.
"""
import logging
from fastmcp import Context
from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook
from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore
from superset.mcp_service.sql_lab.schemas import (
ExecuteSqlRequest,
ExecuteSqlResponse,
)
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse:
"""Execute SQL query against database.
Returns query results with security validation and timeout protection.
"""
await ctx.info(
"Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s"
% (request.database_id, request.timeout, request.limit, request.schema_name)
)
# Log SQL query details (truncated for security)
sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql
await ctx.debug(
"SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s"
% (
sql_preview,
len(request.sql),
bool(request.parameters),
)
)
logger.info("Executing SQL query on database ID: %s", request.database_id)
try:
# Use the ExecuteSqlCore to handle all the logic
sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger)
result = sql_tool.run_tool(request)
# Log successful execution
if hasattr(result, "data") and result.data:
row_count = len(result.data) if isinstance(result.data, list) else 1
await ctx.info(
"SQL execution completed successfully: rows_returned=%s, "
"query_duration_ms=%s"
% (
row_count,
getattr(result, "query_duration_ms", None),
)
)
else:
await ctx.info("SQL execution completed: status=no_data_returned")
return result
except Exception as e:
await ctx.error(
"SQL execution failed: error=%s, database_id=%s"
% (
str(e),
request.database_id,
)
)
raise

View File

@@ -0,0 +1,118 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Open SQL Lab with Context MCP Tool
Tool for generating SQL Lab URLs with pre-populated query and context.
"""
import logging
from urllib.parse import urlencode
from fastmcp import Context
from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook
from superset.mcp_service.sql_lab.schemas import (
OpenSqlLabRequest,
SqlLabResponse,
)
logger = logging.getLogger(__name__)
@mcp.tool
@mcp_auth_hook
def open_sql_lab_with_context(
request: OpenSqlLabRequest, ctx: Context
) -> SqlLabResponse:
"""Generate SQL Lab URL with pre-populated query and context.
Returns URL for direct navigation.
"""
try:
from superset.daos.database import DatabaseDAO
# Validate database exists and is accessible
database = DatabaseDAO.find_by_id(request.database_connection_id)
if not database:
return SqlLabResponse(
url="",
database_id=request.database_connection_id,
schema_name=request.schema_name,
title=request.title,
error=f"Database with ID {request.database_connection_id} not found",
)
# Build query parameters for SQL Lab URL
params = {
"dbid": str(request.database_connection_id),
}
if request.schema_name:
params["schema"] = request.schema_name
if request.sql:
params["sql"] = request.sql
if request.title:
params["title"] = request.title
if request.dataset_in_context:
# Add dataset context as a comment in the SQL if no SQL provided
if not request.sql:
context_comment = (
f"-- Context: Working with dataset '{request.dataset_in_context}'\n"
f"-- Database: {database.database_name}\n"
)
if request.schema_name:
context_comment += f"-- Schema: {request.schema_name}\n"
table_reference = (
f"{request.schema_name}.{request.dataset_in_context}"
)
else:
table_reference = request.dataset_in_context
context_comment += f"\nSELECT * FROM {table_reference} LIMIT 100;"
params["sql"] = context_comment
# Construct SQL Lab URL
query_string = urlencode(params)
url = f"/sqllab?{query_string}"
logger.info(
"Generated SQL Lab URL for database %s", request.database_connection_id
)
return SqlLabResponse(
url=url,
database_id=request.database_connection_id,
schema_name=request.schema_name,
title=request.title,
error=None,
)
except Exception as e:
logger.error("Error generating SQL Lab URL: %s", e)
return SqlLabResponse(
url="",
database_id=request.database_connection_id,
schema_name=request.schema_name,
title=request.title,
error=f"Failed to generate SQL Lab URL: {str(e)}",
)