mirror of
https://github.com/apache/superset.git
synced 2026-04-30 13:34:20 +00:00
Compare commits
58 Commits
extensions
...
mcp_servic
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6d7467bafd | ||
|
|
71f294b7d3 | ||
|
|
e839d0989a | ||
|
|
714e21b3ec | ||
|
|
bf78eb69ed | ||
|
|
a13c1ba8c2 | ||
|
|
422c34a6ee | ||
|
|
cd213fc57d | ||
|
|
0f20a88598 | ||
|
|
33d16eaca1 | ||
|
|
5b56bc622b | ||
|
|
8cbfe027b5 | ||
|
|
c3b3edc6ba | ||
|
|
aa06bb9fda | ||
|
|
0f222b9034 | ||
|
|
7044153ca4 | ||
|
|
ce82c35bb6 | ||
|
|
9b25bd973f | ||
|
|
5f7502b85c | ||
|
|
1d5372210f | ||
|
|
64af53f6f6 | ||
|
|
56b308340e | ||
|
|
3b2b7609a8 | ||
|
|
61a91f80fe | ||
|
|
ed3c5ecbc2 | ||
|
|
c8fbd4233c | ||
|
|
fc85f68585 | ||
|
|
e825bbe1f4 | ||
|
|
a80b637a2a | ||
|
|
541b3bd727 | ||
|
|
a909799e5c | ||
|
|
01329f1c62 | ||
|
|
67f621d360 | ||
|
|
b552dbf4a1 | ||
|
|
364af98c04 | ||
|
|
afdb8b38a6 | ||
|
|
fc7ea804bc | ||
|
|
7c256ae9aa | ||
|
|
1b190abc3b | ||
|
|
c6c71bf835 | ||
|
|
cd5ead7f11 | ||
|
|
e5eebe28f9 | ||
|
|
9eac6ef433 | ||
|
|
d523d523e5 | ||
|
|
91a3214ed4 | ||
|
|
95b787f024 | ||
|
|
39121791e8 | ||
|
|
9d40fe913f | ||
|
|
748ae49c8c | ||
|
|
a9d543b6f4 | ||
|
|
55d6130fc4 | ||
|
|
b98e3eb309 | ||
|
|
b469077e0e | ||
|
|
397b4e450b | ||
|
|
0f97002520 | ||
|
|
2312250127 | ||
|
|
cd52193869 | ||
|
|
9ffe680aaa |
19
.devcontainer/default/devcontainer.json
Normal file
19
.devcontainer/default/devcontainer.json
Normal file
@@ -0,0 +1,19 @@
|
||||
{
|
||||
// Extend the base configuration
|
||||
"extends": "../devcontainer-base.json",
|
||||
|
||||
"name": "Apache Superset Development (Default)",
|
||||
|
||||
// Forward ports for development
|
||||
"forwardPorts": [9001],
|
||||
"portsAttributes": {
|
||||
"9001": {
|
||||
"label": "Superset (via Webpack Dev Server)",
|
||||
"onAutoForward": "notify",
|
||||
"visibility": "public"
|
||||
}
|
||||
},
|
||||
|
||||
// Auto-start Superset on Codespace resume
|
||||
"postStartCommand": ".devcontainer/start-superset.sh"
|
||||
}
|
||||
39
.devcontainer/devcontainer-base.json
Normal file
39
.devcontainer/devcontainer-base.json
Normal file
@@ -0,0 +1,39 @@
|
||||
{
|
||||
"name": "Apache Superset Development",
|
||||
// Keep this in sync with the base image in Dockerfile (ARG PY_VER)
|
||||
// Using the same base as Dockerfile, but non-slim for dev tools
|
||||
"image": "python:3.11.13-bookworm",
|
||||
|
||||
"features": {
|
||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {
|
||||
"moby": true,
|
||||
"dockerDashComposeVersion": "v2"
|
||||
},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"version": "20"
|
||||
},
|
||||
"ghcr.io/devcontainers/features/git:1": {},
|
||||
"ghcr.io/devcontainers/features/common-utils:2": {
|
||||
"configureZshAsDefaultShell": true
|
||||
},
|
||||
"ghcr.io/devcontainers/features/sshd:1": {
|
||||
"version": "latest"
|
||||
}
|
||||
},
|
||||
|
||||
// Run commands after container is created
|
||||
"postCreateCommand": "chmod +x .devcontainer/setup-dev.sh && .devcontainer/setup-dev.sh",
|
||||
|
||||
// VS Code customizations
|
||||
"customizations": {
|
||||
"vscode": {
|
||||
"extensions": [
|
||||
"ms-python.python",
|
||||
"ms-python.vscode-pylance",
|
||||
"charliermarsh.ruff",
|
||||
"dbaeumer.vscode-eslint",
|
||||
"esbenp.prettier-vscode"
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,11 @@
|
||||
echo "🚀 Starting Superset in Codespaces..."
|
||||
echo "🌐 Frontend will be available at port 9001"
|
||||
|
||||
# Check if MCP is enabled
|
||||
if [ "$ENABLE_MCP" = "true" ]; then
|
||||
echo "🤖 MCP Service will be available at port 5008"
|
||||
fi
|
||||
|
||||
# Find the workspace directory (Codespaces clones as 'superset', not 'superset-2')
|
||||
WORKSPACE_DIR=$(find /workspaces -maxdepth 1 -name "superset*" -type d | head -1)
|
||||
if [ -n "$WORKSPACE_DIR" ]; then
|
||||
@@ -21,7 +26,7 @@ fi
|
||||
|
||||
# Clean up any existing containers
|
||||
echo "🧹 Cleaning up existing containers..."
|
||||
docker-compose -f docker-compose-light.yml down
|
||||
docker-compose -f docker-compose-light.yml --profile mcp down
|
||||
|
||||
# Start services
|
||||
echo "🏗️ Building and starting services..."
|
||||
@@ -33,7 +38,12 @@ echo ""
|
||||
echo "📋 Running in foreground with live logs (Ctrl+C to stop)..."
|
||||
|
||||
# Run docker-compose and capture exit code
|
||||
docker-compose -f docker-compose-light.yml up
|
||||
if [ "$ENABLE_MCP" = "true" ]; then
|
||||
echo "🤖 Starting with MCP Service enabled..."
|
||||
docker-compose -f docker-compose-light.yml --profile mcp up
|
||||
else
|
||||
docker-compose -f docker-compose-light.yml up
|
||||
fi
|
||||
EXIT_CODE=$?
|
||||
|
||||
# If it failed, provide helpful instructions
|
||||
|
||||
29
.devcontainer/with-mcp/devcontainer.json
Normal file
29
.devcontainer/with-mcp/devcontainer.json
Normal file
@@ -0,0 +1,29 @@
|
||||
{
|
||||
// Extend the base configuration
|
||||
"extends": "../devcontainer-base.json",
|
||||
|
||||
"name": "Apache Superset Development with MCP",
|
||||
|
||||
// Forward ports for development
|
||||
"forwardPorts": [9001, 5008],
|
||||
"portsAttributes": {
|
||||
"9001": {
|
||||
"label": "Superset (via Webpack Dev Server)",
|
||||
"onAutoForward": "notify",
|
||||
"visibility": "public"
|
||||
},
|
||||
"5008": {
|
||||
"label": "MCP Service (Model Context Protocol)",
|
||||
"onAutoForward": "notify",
|
||||
"visibility": "private"
|
||||
}
|
||||
},
|
||||
|
||||
// Auto-start Superset with MCP on Codespace resume
|
||||
"postStartCommand": "ENABLE_MCP=true .devcontainer/start-superset.sh",
|
||||
|
||||
// Environment variables
|
||||
"containerEnv": {
|
||||
"ENABLE_MCP": "true"
|
||||
}
|
||||
}
|
||||
215
CHART_METADATA_API.md
Normal file
215
CHART_METADATA_API.md
Normal file
@@ -0,0 +1,215 @@
|
||||
# Chart Metadata API Reference
|
||||
|
||||
The Superset MCP service provides rich metadata alongside chart generation to enable better UI integration and user experiences.
|
||||
|
||||
## Background & Design Philosophy
|
||||
|
||||
Modern chart systems need to provide more than just visual output. Inspired by contemporary web standards and LLM integration patterns, this metadata system addresses several key needs:
|
||||
|
||||
**Accessibility-First Design**: Following WCAG guidelines and `aria-*` attribute patterns, charts include semantic descriptions and accessibility metadata to ensure inclusive experiences.
|
||||
|
||||
**Rich Context for AI Systems**: Similar to how platforms like social media generate rich previews (OpenGraph, Twitter Cards), charts provide semantic understanding beyond just visual representation - enabling AI agents to reason about and describe visualizations meaningfully.
|
||||
|
||||
**Performance-Aware Integration**: Modern web APIs emphasize performance transparency (Core Web Vitals, etc.). Charts include execution metrics and optimization suggestions to help UIs make informed decisions about rendering and user feedback.
|
||||
|
||||
**Capability-Driven UX**: Rather than requiring UIs to hardcode chart type behaviors, the system exposes what each chart can actually do - enabling dynamic, contextual interfaces that adapt to chart capabilities.
|
||||
|
||||
## Overview
|
||||
|
||||
When generating charts via `generate_chart`, the response includes structured metadata that helps UIs:
|
||||
- Present appropriate controls and interactions
|
||||
- Generate accessible descriptions
|
||||
- Optimize rendering performance
|
||||
- Guide user workflows
|
||||
|
||||
## Metadata Types
|
||||
|
||||
### ChartCapabilities
|
||||
|
||||
Describes what interactions and features the chart supports.
|
||||
|
||||
```python
|
||||
{
|
||||
"supports_interaction": bool, # User can interact (zoom, pan, hover)
|
||||
"supports_real_time": bool, # Chart can update with live data
|
||||
"supports_drill_down": bool, # Can navigate to more detailed views
|
||||
"supports_export": bool, # Can be exported to other formats
|
||||
"optimal_formats": [ # Recommended preview formats
|
||||
"url", # Static image URL
|
||||
"interactive", # HTML with JavaScript controls
|
||||
"ascii", # Text-based representation
|
||||
"vega_lite" # Vega-Lite specification
|
||||
],
|
||||
"data_types": [ # Types of data visualized
|
||||
"time_series", # Time-based data
|
||||
"categorical", # Discrete categories
|
||||
"metric" # Numeric measurements
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**UI Integration:**
|
||||
- Show/hide interaction controls based on `supports_interaction`
|
||||
- Enable real-time updates if `supports_real_time`
|
||||
- Display drill-down options for `supports_drill_down`
|
||||
- Choose optimal preview format from `optimal_formats`
|
||||
|
||||
### ChartSemantics
|
||||
|
||||
Provides semantic understanding of what the chart represents and reveals.
|
||||
|
||||
```python
|
||||
{
|
||||
"primary_insight": "Shows trends and changes over time",
|
||||
"data_story": "This line chart analyzes sales, revenue over Q1-Q4",
|
||||
"recommended_actions": [
|
||||
"Review data patterns and trends",
|
||||
"Consider filtering for more detail",
|
||||
"Export chart for reporting"
|
||||
],
|
||||
"anomalies": [], # Notable outliers (future enhancement)
|
||||
"statistical_summary": {} # Key statistics (future enhancement)
|
||||
}
|
||||
```
|
||||
|
||||
**UI Integration:**
|
||||
- Display `primary_insight` as chart description
|
||||
- Use `data_story` for accessibility and tooltips
|
||||
- Show `recommended_actions` as suggested next steps
|
||||
- Highlight `anomalies` in the visualization
|
||||
|
||||
### AccessibilityMetadata
|
||||
|
||||
Information for creating inclusive, accessible chart experiences.
|
||||
|
||||
```python
|
||||
{
|
||||
"color_blind_safe": bool, # Uses colorblind-friendly palette
|
||||
"alt_text": "Chart showing Sales Data over time",
|
||||
"high_contrast_available": bool # High contrast version available
|
||||
}
|
||||
```
|
||||
|
||||
**UI Integration:**
|
||||
- Use `alt_text` for screen readers
|
||||
- Show accessibility indicators if `color_blind_safe`
|
||||
- Offer high contrast mode if available
|
||||
|
||||
### PerformanceMetadata
|
||||
|
||||
Performance information for optimization and user feedback.
|
||||
|
||||
```python
|
||||
{
|
||||
"query_duration_ms": 1250, # Time to generate chart data
|
||||
"cache_status": "hit|miss|error", # Whether data came from cache
|
||||
"optimization_suggestions": [ # Performance improvement tips
|
||||
"Consider adding date filters to reduce data volume",
|
||||
"Chart complexity may impact load time"
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
**UI Integration:**
|
||||
- Show loading indicators based on `query_duration_ms`
|
||||
- Display cache status for debugging
|
||||
- Present `optimization_suggestions` to users
|
||||
- Warn about slow queries
|
||||
|
||||
## Example Response
|
||||
|
||||
```json
|
||||
{
|
||||
"chart": {
|
||||
"id": 123,
|
||||
"slice_name": "Sales Trends Q1-Q4",
|
||||
"viz_type": "echarts_timeseries_line",
|
||||
"url": "/explore/?slice_id=123"
|
||||
},
|
||||
"capabilities": {
|
||||
"supports_interaction": true,
|
||||
"supports_real_time": false,
|
||||
"supports_drill_down": false,
|
||||
"supports_export": true,
|
||||
"optimal_formats": ["url", "interactive", "ascii"],
|
||||
"data_types": ["time_series", "metric"]
|
||||
},
|
||||
"semantics": {
|
||||
"primary_insight": "Shows trends and changes over time",
|
||||
"data_story": "This line chart analyzes sales over Q1-Q4",
|
||||
"recommended_actions": [
|
||||
"Review seasonal patterns",
|
||||
"Export for quarterly report"
|
||||
]
|
||||
},
|
||||
"accessibility": {
|
||||
"color_blind_safe": true,
|
||||
"alt_text": "Line chart showing sales trends from Q1 to Q4",
|
||||
"high_contrast_available": false
|
||||
},
|
||||
"performance": {
|
||||
"query_duration_ms": 450,
|
||||
"cache_status": "miss",
|
||||
"optimization_suggestions": []
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### React Component Integration
|
||||
|
||||
```jsx
|
||||
function ChartComponent({ chartData }) {
|
||||
const { capabilities, semantics, accessibility, performance } = chartData;
|
||||
|
||||
return (
|
||||
<div>
|
||||
{/* Accessibility */}
|
||||
<img
|
||||
src={chartData.chart.url}
|
||||
alt={accessibility.alt_text}
|
||||
aria-describedby="chart-description"
|
||||
/>
|
||||
|
||||
{/* Semantic description */}
|
||||
<p id="chart-description">{semantics.primary_insight}</p>
|
||||
|
||||
{/* Conditional controls based on capabilities */}
|
||||
{capabilities.supports_interaction && (
|
||||
<InteractiveControls />
|
||||
)}
|
||||
|
||||
{capabilities.supports_export && (
|
||||
<ExportButton />
|
||||
)}
|
||||
|
||||
{/* Performance feedback */}
|
||||
{performance.query_duration_ms > 2000 && (
|
||||
<SlowQueryWarning suggestions={performance.optimization_suggestions} />
|
||||
)}
|
||||
|
||||
{/* Recommended actions */}
|
||||
<ActionSuggestions actions={semantics.recommended_actions} />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
## Chart Type Mapping
|
||||
|
||||
Different chart types provide different capabilities:
|
||||
|
||||
| Chart Type | Interaction | Real-time | Drill-down | Optimal Formats |
|
||||
|------------|------------|-----------|------------|-----------------|
|
||||
| `echarts_timeseries_line` | ✅ | ✅ | ❌ | url, interactive, ascii |
|
||||
| `echarts_timeseries_bar` | ✅ | ✅ | ❌ | url, interactive, ascii |
|
||||
| `table` | ❌ | ❌ | ✅ | url, table, ascii |
|
||||
| `pie` | ✅ | ❌ | ❌ | url, interactive |
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
- **Statistical Summary**: Automatic calculation of mean, median, trends
|
||||
- **Anomaly Detection**: Identification of outliers and unusual patterns
|
||||
- **Smart Recommendations**: ML-powered suggestions for chart improvements
|
||||
- **Accessibility Scoring**: Automated accessibility compliance checking
|
||||
1
LLMS.md
1
LLMS.md
@@ -180,6 +180,7 @@ pre-commit run eslint # Frontend linting
|
||||
|
||||
## Platform-Specific Instructions
|
||||
|
||||
- **[LLMS.md](LLMS.md)** - General LLM development guide (READ THIS FIRST)
|
||||
- **[CLAUDE.md](CLAUDE.md)** - For Claude/Anthropic tools
|
||||
- **[.github/copilot-instructions.md](.github/copilot-instructions.md)** - For GitHub Copilot
|
||||
- **[GEMINI.md](GEMINI.md)** - For Google Gemini tools
|
||||
|
||||
@@ -25,6 +25,12 @@
|
||||
# - Volumes are isolated by project name (e.g., project1_db_home_light, project2_db_home_light)
|
||||
# - Database name is intentionally different (superset_light) to prevent accidental cross-connections
|
||||
#
|
||||
# MCP Service (Model Context Protocol):
|
||||
# - Optional service for LLM agent integration, available under 'mcp' profile
|
||||
# - To include MCP: docker-compose -f docker-compose-light.yml --profile mcp up
|
||||
# - MCP runs on port 5008 by default (customize with MCP_PORT=5009)
|
||||
# - Enable SQL debugging with MCP_SQL_DEBUG=true
|
||||
#
|
||||
# For verbose logging during development:
|
||||
# - Set SUPERSET_LOG_LEVEL=debug in docker/.env-local for detailed Superset logs
|
||||
# -----------------------------------------------------------------------
|
||||
@@ -150,6 +156,37 @@ services:
|
||||
required: false
|
||||
volumes: *superset-volumes
|
||||
|
||||
superset-mcp-light:
|
||||
profiles:
|
||||
- mcp
|
||||
build:
|
||||
<<: *common-build
|
||||
command: ["/app/docker/docker-bootstrap.sh", "mcp"]
|
||||
restart: unless-stopped
|
||||
ports:
|
||||
- "127.0.0.1:${MCP_PORT:-5008}:5008" # Parameterized port
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
user: *superset-user
|
||||
depends_on:
|
||||
superset-init-light:
|
||||
condition: service_completed_successfully
|
||||
volumes: *superset-volumes
|
||||
env_file:
|
||||
- path: docker/.env # default
|
||||
required: true
|
||||
- path: docker/.env-local # optional override
|
||||
required: false
|
||||
environment:
|
||||
# Override DB connection for light service
|
||||
DATABASE_HOST: db-light
|
||||
DATABASE_DB: superset_light
|
||||
POSTGRES_DB: superset_light
|
||||
# Use light-specific config that disables Redis
|
||||
SUPERSET_CONFIG_PATH: /app/docker/pythonpath_dev/superset_config_docker_light.py
|
||||
# Enable SQL debugging for MCP if needed
|
||||
SQLALCHEMY_DEBUG: ${MCP_SQL_DEBUG:-false}
|
||||
|
||||
volumes:
|
||||
superset_home_light:
|
||||
external: false
|
||||
|
||||
@@ -78,6 +78,10 @@ case "${1}" in
|
||||
echo "Starting web app..."
|
||||
/usr/bin/run-server.sh
|
||||
;;
|
||||
mcp)
|
||||
echo "Starting MCP service..."
|
||||
superset mcp run --host 0.0.0.0 --port ${MCP_PORT:-5008} --debug
|
||||
;;
|
||||
*)
|
||||
echo "Unknown Operation!!!"
|
||||
;;
|
||||
|
||||
741
docs/docs/mcp-service/api-reference.mdx
Normal file
741
docs/docs/mcp-service/api-reference.mdx
Normal file
@@ -0,0 +1,741 @@
|
||||
---
|
||||
title: API Reference
|
||||
sidebar_position: 3
|
||||
version: 1
|
||||
---
|
||||
|
||||
# MCP Tools API Reference
|
||||
|
||||
Complete reference for all 16 MCP tools with request/response examples.
|
||||
|
||||
> 🚀 **First time here?** Start with [Dashboard Tools](#dashboard-tools) or [Chart Tools](#chart-tools) to see the most commonly used features.
|
||||
>
|
||||
> 🔐 **Need authentication?** See the [Authentication Guide](./authentication) for JWT setup.
|
||||
>
|
||||
> 🔧 **Want to add tools?** Check the [Development Guide](./development#adding-new-tools) for step-by-step instructions.
|
||||
|
||||
## Dashboard Tools
|
||||
|
||||
### list_dashboards
|
||||
|
||||
List dashboards with search, filtering, and pagination support.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"search": "sales", // Optional: Search term
|
||||
"filters": [ // Optional: Advanced filters
|
||||
{
|
||||
"col": "published",
|
||||
"opr": "eq",
|
||||
"value": true
|
||||
}
|
||||
],
|
||||
"page": 1, // Optional: Page number (default: 1)
|
||||
"page_size": 20, // Optional: Items per page (default: 20)
|
||||
"select_columns": [ // Optional: Specific columns
|
||||
"id", "dashboard_title", "uuid"
|
||||
],
|
||||
"use_cache": true // Optional: Use cached data (default: true)
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"dashboards": [
|
||||
{
|
||||
"id": 1,
|
||||
"uuid": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
"dashboard_title": "Sales Performance",
|
||||
"url": "/superset/dashboard/1/",
|
||||
"published": true,
|
||||
"owners": ["admin"],
|
||||
"created_on": "2024-01-15T10:30:00Z",
|
||||
"changed_on": "2024-01-20T14:15:00Z"
|
||||
}
|
||||
],
|
||||
"total_count": 45,
|
||||
"page": 1,
|
||||
"page_size": 20,
|
||||
"cache_status": {
|
||||
"cache_hit": true,
|
||||
"cache_age_seconds": 300
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### get_dashboard_info
|
||||
|
||||
Get detailed information about a specific dashboard.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"identifier": "a1b2c3d4-e5f6-7890-abcd-ef1234567890", // ID, UUID, or slug
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"dashboard_id": 1,
|
||||
"uuid": "a1b2c3d4-e5f6-7890-abcd-ef1234567890",
|
||||
"dashboard_title": "Sales Performance Dashboard",
|
||||
"slug": "sales-performance",
|
||||
"url": "/superset/dashboard/1/",
|
||||
"published": true,
|
||||
"owners": ["admin", "analyst"],
|
||||
"roles": ["Sales Team"],
|
||||
"charts": [
|
||||
{
|
||||
"id": 10,
|
||||
"slice_name": "Monthly Revenue",
|
||||
"viz_type": "line"
|
||||
},
|
||||
{
|
||||
"id": 11,
|
||||
"slice_name": "Regional Sales",
|
||||
"viz_type": "bar"
|
||||
}
|
||||
],
|
||||
"filters": [
|
||||
{
|
||||
"column": "region",
|
||||
"type": "select"
|
||||
}
|
||||
],
|
||||
"created_on": "2024-01-15T10:30:00Z",
|
||||
"changed_on": "2024-01-20T14:15:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### generate_dashboard
|
||||
|
||||
Create a new dashboard with multiple charts.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"chart_ids": [10, 11, 12, 13],
|
||||
"dashboard_title": "Q4 Performance Dashboard",
|
||||
"description": "Quarterly performance metrics and KPIs",
|
||||
"published": true,
|
||||
"layout_type": "grid" // Optional: "grid" or "tabs"
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"dashboard_id": 25,
|
||||
"uuid": "new-dash-uuid-here",
|
||||
"dashboard_title": "Q4 Performance Dashboard",
|
||||
"url": "/superset/dashboard/25/",
|
||||
"charts_added": 4,
|
||||
"layout": {
|
||||
"type": "grid",
|
||||
"columns": 2,
|
||||
"rows": 2
|
||||
},
|
||||
"created_on": "2024-01-25T16:45:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
## Chart Tools
|
||||
|
||||
### list_charts
|
||||
|
||||
List charts with advanced filtering and search capabilities.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"search": "revenue",
|
||||
"filters": [
|
||||
{
|
||||
"col": "viz_type",
|
||||
"opr": "in",
|
||||
"value": ["line", "bar", "area"]
|
||||
}
|
||||
],
|
||||
"page": 1,
|
||||
"page_size": 25,
|
||||
"select_columns": ["id", "slice_name", "viz_type", "uuid"],
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"charts": [
|
||||
{
|
||||
"id": 10,
|
||||
"uuid": "chart-uuid-1",
|
||||
"slice_name": "Monthly Revenue Trend",
|
||||
"viz_type": "line",
|
||||
"datasource_name": "sales_data",
|
||||
"owners": ["admin"],
|
||||
"created_on": "2024-01-10T09:15:00Z"
|
||||
}
|
||||
],
|
||||
"total_count": 125,
|
||||
"page": 1,
|
||||
"page_size": 25
|
||||
}
|
||||
```
|
||||
|
||||
### get_chart_info
|
||||
|
||||
Get comprehensive chart information including configuration.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"identifier": 10, // ID or UUID
|
||||
"include_form_data": true, // Include chart configuration
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"chart_id": 10,
|
||||
"uuid": "chart-uuid-1",
|
||||
"slice_name": "Monthly Revenue Trend",
|
||||
"viz_type": "line",
|
||||
"datasource_id": 5,
|
||||
"datasource_name": "sales_data",
|
||||
"datasource_type": "table",
|
||||
"form_data": {
|
||||
"viz_type": "line",
|
||||
"x_axis": "month",
|
||||
"metrics": ["sum__revenue"],
|
||||
"time_range": "Last 12 months"
|
||||
},
|
||||
"query_context": {
|
||||
"datasource": {"id": 5, "type": "table"},
|
||||
"queries": [{"columns": [], "metrics": ["sum__revenue"]}]
|
||||
},
|
||||
"explore_url": "/superset/explore/?form_data=%7B%22slice_id%22%3A10%7D",
|
||||
"owners": ["admin"],
|
||||
"created_on": "2024-01-10T09:15:00Z",
|
||||
"changed_on": "2024-01-15T11:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### generate_chart
|
||||
|
||||
Create a new chart with specified configuration.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"dataset_id": "5",
|
||||
"config": {
|
||||
"chart_type": "xy",
|
||||
"x": {"name": "month", "label": "Month"},
|
||||
"y": [
|
||||
{
|
||||
"name": "revenue",
|
||||
"aggregate": "SUM",
|
||||
"label": "Total Revenue"
|
||||
},
|
||||
{
|
||||
"name": "orders",
|
||||
"aggregate": "COUNT",
|
||||
"label": "Order Count"
|
||||
}
|
||||
],
|
||||
"kind": "line",
|
||||
"x_axis": {
|
||||
"title": "Month",
|
||||
"format": "smart_date"
|
||||
},
|
||||
"y_axis": {
|
||||
"title": "Revenue ($)",
|
||||
"format": "$,.0f"
|
||||
},
|
||||
"legend": {
|
||||
"show": true,
|
||||
"position": "top"
|
||||
}
|
||||
},
|
||||
"slice_name": "Revenue and Orders Trend",
|
||||
"description": "Monthly revenue and order count comparison",
|
||||
"save_chart": true,
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["url", "ascii"]
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"chart_id": 45,
|
||||
"uuid": "new-chart-uuid",
|
||||
"slice_name": "Revenue and Orders Trend",
|
||||
"viz_type": "echarts_timeseries_line",
|
||||
"datasource_id": 5,
|
||||
"explore_url": "/superset/explore/?form_data=%7B%22slice_id%22%3A45%7D",
|
||||
"query_executed": true,
|
||||
"query_result": {
|
||||
"status": "success",
|
||||
"row_count": 12,
|
||||
"execution_time": 0.145
|
||||
},
|
||||
"preview": {
|
||||
"url": {
|
||||
"preview_url": "http://localhost:5008/screenshot/chart/45.png",
|
||||
"width": 800,
|
||||
"height": 600
|
||||
},
|
||||
"ascii": {
|
||||
"ascii_content": "Revenue Trend\n==============\nJan |████████████████ $125K\nFeb |██████████████████ $140K\n...",
|
||||
"width": 80,
|
||||
"height": 20
|
||||
}
|
||||
},
|
||||
"created_on": "2024-01-25T14:20:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
### get_chart_data
|
||||
|
||||
Export chart data in multiple formats.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"identifier": 10,
|
||||
"format": "json", // "json", "csv", "excel"
|
||||
"limit": 1000, // Optional: Row limit
|
||||
"offset": 0, // Optional: Row offset
|
||||
"filters": [ // Optional: Additional filters
|
||||
{
|
||||
"column": "region",
|
||||
"op": "=",
|
||||
"value": "US"
|
||||
}
|
||||
],
|
||||
"use_cache": true,
|
||||
"force_refresh": false
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"data": [
|
||||
{
|
||||
"month": "2024-01",
|
||||
"revenue": 125000,
|
||||
"orders": 450
|
||||
},
|
||||
{
|
||||
"month": "2024-02",
|
||||
"revenue": 140000,
|
||||
"orders": 520
|
||||
}
|
||||
],
|
||||
"total_rows": 12,
|
||||
"columns": [
|
||||
{"name": "month", "type": "DATE"},
|
||||
{"name": "revenue", "type": "BIGINT"},
|
||||
{"name": "orders", "type": "BIGINT"}
|
||||
],
|
||||
"query": {
|
||||
"sql": "SELECT month, SUM(revenue) as revenue, COUNT(*) as orders FROM sales_data GROUP BY month ORDER BY month",
|
||||
"execution_time": 0.089
|
||||
},
|
||||
"cache_status": {
|
||||
"cache_hit": false,
|
||||
"cache_type": "query",
|
||||
"refreshed": true
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### get_chart_preview
|
||||
|
||||
Generate chart previews in multiple formats.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"identifier": 10,
|
||||
"format": "url", // "url", "base64", "ascii", "table"
|
||||
"width": 800, // For image formats
|
||||
"height": 600, // For image formats
|
||||
"ascii_width": 80, // For ASCII format
|
||||
"ascii_height": 20, // For ASCII format
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Examples:**
|
||||
|
||||
**URL Format:**
|
||||
```json
|
||||
{
|
||||
"format": "url",
|
||||
"preview_url": "http://localhost:5008/screenshot/chart/10.png",
|
||||
"width": 800,
|
||||
"height": 600,
|
||||
"supports_interaction": false,
|
||||
"expires_at": "2024-01-26T14:20:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
**ASCII Format:**
|
||||
```json
|
||||
{
|
||||
"format": "ascii",
|
||||
"ascii_content": "Monthly Revenue Trend\n=====================\n\nJan |████████████████████ $125K\nFeb |██████████████████████ $140K\nMar |███████████████████ $135K\nApr |█████████████████████████ $155K\n\nRange: $125K to $155K\n▁▃▂▅▇▆▄▃▂▄▅▆▇▅▃▂",
|
||||
"width": 80,
|
||||
"height": 20,
|
||||
"supports_color": false
|
||||
}
|
||||
```
|
||||
|
||||
**Table Format:**
|
||||
```json
|
||||
{
|
||||
"format": "table",
|
||||
"table_data": "Monthly Revenue Data\n====================\n\nMonth | Revenue | Orders\n---------|----------|--------\nJan 2024 | $125,000 | 450\nFeb 2024 | $140,000 | 520\nMar 2024 | $135,000 | 495\n\nTotal: 12 rows × 3 columns",
|
||||
"row_count": 12,
|
||||
"supports_sorting": true
|
||||
}
|
||||
```
|
||||
|
||||
## Dataset Tools
|
||||
|
||||
### list_datasets
|
||||
|
||||
List available datasets with columns and metrics.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"search": "sales",
|
||||
"filters": [
|
||||
{
|
||||
"col": "is_active",
|
||||
"opr": "eq",
|
||||
"value": true
|
||||
}
|
||||
],
|
||||
"include_columns": true, // Include column metadata
|
||||
"include_metrics": true, // Include metric metadata
|
||||
"page": 1,
|
||||
"page_size": 15
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"datasets": [
|
||||
{
|
||||
"id": 1,
|
||||
"uuid": "dataset-uuid-1",
|
||||
"table_name": "sales_data",
|
||||
"database_name": "main_warehouse",
|
||||
"schema": "public",
|
||||
"owners": ["admin"],
|
||||
"columns": [
|
||||
{
|
||||
"column_name": "region",
|
||||
"type": "VARCHAR",
|
||||
"is_active": true,
|
||||
"is_dttm": false
|
||||
},
|
||||
{
|
||||
"column_name": "revenue",
|
||||
"type": "DECIMAL",
|
||||
"is_active": true,
|
||||
"is_dttm": false
|
||||
}
|
||||
],
|
||||
"metrics": [
|
||||
{
|
||||
"metric_name": "sum__revenue",
|
||||
"expression": "SUM(revenue)",
|
||||
"metric_type": "sum"
|
||||
}
|
||||
],
|
||||
"created_on": "2024-01-05T08:00:00Z"
|
||||
}
|
||||
],
|
||||
"total_count": 23,
|
||||
"page": 1,
|
||||
"page_size": 15
|
||||
}
|
||||
```
|
||||
|
||||
### get_dataset_info
|
||||
|
||||
Get detailed dataset information with full column/metric metadata.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"identifier": "dataset-uuid-1", // ID or UUID
|
||||
"include_columns": true,
|
||||
"include_metrics": true,
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"uuid": "dataset-uuid-1",
|
||||
"table_name": "sales_data",
|
||||
"database_name": "main_warehouse",
|
||||
"database_id": 1,
|
||||
"schema": "public",
|
||||
"sql": null,
|
||||
"is_active": true,
|
||||
"owners": ["admin", "data_team"],
|
||||
"columns": [
|
||||
{
|
||||
"id": 101,
|
||||
"column_name": "region",
|
||||
"type": "VARCHAR",
|
||||
"is_active": true,
|
||||
"is_dttm": false,
|
||||
"groupby": true,
|
||||
"filterable": true,
|
||||
"description": "Geographic region"
|
||||
},
|
||||
{
|
||||
"id": 102,
|
||||
"column_name": "order_date",
|
||||
"type": "DATE",
|
||||
"is_active": true,
|
||||
"is_dttm": true,
|
||||
"groupby": true,
|
||||
"filterable": true
|
||||
}
|
||||
],
|
||||
"metrics": [
|
||||
{
|
||||
"id": 201,
|
||||
"metric_name": "sum__revenue",
|
||||
"expression": "SUM(revenue)",
|
||||
"metric_type": "sum",
|
||||
"is_active": true,
|
||||
"description": "Total revenue"
|
||||
}
|
||||
],
|
||||
"created_on": "2024-01-05T08:00:00Z",
|
||||
"changed_on": "2024-01-18T12:30:00Z"
|
||||
}
|
||||
```
|
||||
|
||||
## System Tools
|
||||
|
||||
### get_superset_instance_info
|
||||
|
||||
Get Superset instance information and statistics.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"include_statistics": true, // Include usage statistics
|
||||
"include_tools": true, // Include available MCP tools
|
||||
"use_cache": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"version": "4.1.0",
|
||||
"build": "apache-superset-4.1.0",
|
||||
"mcp_service_version": "1.0.0",
|
||||
"authentication": {
|
||||
"enabled": true,
|
||||
"type": "jwt_bearer",
|
||||
"required_scopes": ["dashboard:read", "chart:read"]
|
||||
},
|
||||
"statistics": {
|
||||
"dashboards": {
|
||||
"total": 45,
|
||||
"published": 32
|
||||
},
|
||||
"charts": {
|
||||
"total": 125,
|
||||
"by_viz_type": {
|
||||
"line": 35,
|
||||
"bar": 28,
|
||||
"table": 42,
|
||||
"pie": 20
|
||||
}
|
||||
},
|
||||
"datasets": {
|
||||
"total": 23,
|
||||
"active": 18
|
||||
},
|
||||
"users": {
|
||||
"total": 15,
|
||||
"active": 12
|
||||
}
|
||||
},
|
||||
"mcp_tools": [
|
||||
{
|
||||
"name": "list_dashboards",
|
||||
"description": "List dashboards with search and filtering",
|
||||
"category": "dashboard"
|
||||
},
|
||||
{
|
||||
"name": "generate_chart",
|
||||
"description": "Create new charts programmatically",
|
||||
"category": "chart"
|
||||
}
|
||||
],
|
||||
"database_connections": [
|
||||
{
|
||||
"id": 1,
|
||||
"database_name": "main_warehouse",
|
||||
"backend": "postgresql",
|
||||
"status": "healthy"
|
||||
}
|
||||
],
|
||||
"cache_status": {
|
||||
"enabled": true,
|
||||
"backend": "redis",
|
||||
"hit_rate": 0.85
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### generate_explore_link
|
||||
|
||||
Generate Superset explore URLs with pre-configured chart settings.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"dataset_id": "1",
|
||||
"chart_config": {
|
||||
"viz_type": "line",
|
||||
"x_axis": "month",
|
||||
"metrics": ["sum__revenue"],
|
||||
"time_range": "Last 6 months"
|
||||
},
|
||||
"title": "Revenue Analysis",
|
||||
"cache_form_data": true
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"explore_url": "/superset/explore/?form_data_key=abc123def456",
|
||||
"full_url": "http://localhost:8088/superset/explore/?form_data_key=abc123def456",
|
||||
"form_data_key": "abc123def456",
|
||||
"expires_at": "2024-01-26T16:45:00Z",
|
||||
"chart_config": {
|
||||
"viz_type": "line",
|
||||
"datasource": "1__table",
|
||||
"x_axis": "month",
|
||||
"metrics": ["sum__revenue"],
|
||||
"time_range": "Last 6 months"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## SQL Lab Tools
|
||||
|
||||
### open_sql_lab_with_context
|
||||
|
||||
Open SQL Lab with pre-configured database, schema, and SQL.
|
||||
|
||||
**Request Schema:**
|
||||
```json
|
||||
{
|
||||
"database_connection_id": 1,
|
||||
"schema": "public",
|
||||
"dataset_in_context": "sales_data",
|
||||
"sql": "SELECT region, SUM(revenue) as total_revenue\nFROM sales_data \nWHERE order_date >= '2024-01-01'\nGROUP BY region\nORDER BY total_revenue DESC",
|
||||
"title": "Regional Sales Analysis"
|
||||
}
|
||||
```
|
||||
|
||||
**Response Example:**
|
||||
```json
|
||||
{
|
||||
"sql_lab_url": "/superset/sqllab/?dbid=1&schema=public&sql_template=encoded_sql_here",
|
||||
"full_url": "http://localhost:8088/superset/sqllab/?dbid=1&schema=public&sql_template=encoded_sql_here",
|
||||
"database_connection": {
|
||||
"id": 1,
|
||||
"database_name": "main_warehouse",
|
||||
"backend": "postgresql"
|
||||
},
|
||||
"schema": "public",
|
||||
"sql_template": "SELECT region, SUM(revenue) as total_revenue...",
|
||||
"context": {
|
||||
"dataset": "sales_data",
|
||||
"title": "Regional Sales Analysis"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Error Responses
|
||||
|
||||
All tools can return error responses with this structure:
|
||||
|
||||
```json
|
||||
{
|
||||
"error": "Chart not found with identifier: 999",
|
||||
"error_type": "NotFound",
|
||||
"suggestions": [
|
||||
"Verify the chart ID exists",
|
||||
"Check if you have permission to access this chart",
|
||||
"Try using the chart UUID instead of ID"
|
||||
],
|
||||
"details": {
|
||||
"identifier": 999,
|
||||
"identifier_type": "id"
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Cache Status
|
||||
|
||||
Many responses include cache status information:
|
||||
|
||||
```json
|
||||
{
|
||||
"cache_status": {
|
||||
"cache_hit": true, // Data served from cache
|
||||
"cache_type": "query", // Type: query, metadata, form_data
|
||||
"cache_age_seconds": 300, // Age of cached data
|
||||
"refreshed": false // Whether cache was refreshed
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
This API reference provides complete documentation for integrating with the Superset MCP service, including all request schemas, response formats, and error handling patterns.
|
||||
|
||||
## What's Next?
|
||||
|
||||
### 🔐 **Ready for Production?**
|
||||
Set up authentication and security with the [Authentication Guide](./authentication).
|
||||
|
||||
### 🔧 **Want to Add More Tools?**
|
||||
Learn how to extend the MCP service in the [Development Guide](./development).
|
||||
|
||||
### 🏗️ **Need Architecture Details?**
|
||||
Understand the system design in the [Architecture Overview](./architecture).
|
||||
|
||||
### 🏢 **Enterprise Features?**
|
||||
Explore advanced capabilities in the [Preset Integration Guide](./preset-integration).
|
||||
|
||||
> 📖 **Back to Documentation Index**: [MCP Service](./intro)
|
||||
191
docs/docs/mcp-service/architecture.mdx
Normal file
191
docs/docs/mcp-service/architecture.mdx
Normal file
@@ -0,0 +1,191 @@
|
||||
---
|
||||
title: Architecture Overview
|
||||
sidebar_position: 5
|
||||
version: 1
|
||||
---
|
||||
|
||||
# Architecture Overview
|
||||
|
||||
The Superset Model Context Protocol (MCP) service provides a modular, schema-driven interface for programmatic access to Superset dashboards, charts, datasets, and instance metadata. Built on FastMCP for LLM agents and automation tools.
|
||||
|
||||
**Status:** Phase 1 Complete. Core functionality stable, authentication production-ready. See [SIP-171](https://github.com/apache/superset/issues/33870) for roadmap.
|
||||
|
||||
## Core Architecture
|
||||
|
||||
### Tool Structure
|
||||
- **19 MCP tools** organized by domain: `dashboard/`, `chart/`, `dataset/`, `system/`, `sqllab/`
|
||||
- All tools decorated with `@mcp.tool` and `@mcp_auth_hook`
|
||||
- **Import inside functions**: All Superset DAOs/commands imported in function body to ensure proper app context
|
||||
- Pydantic v2 schemas with LLM/OpenAPI-compatible field descriptions
|
||||
|
||||
### Request Schema Pattern
|
||||
Eliminates LLM parameter validation issues using structured request objects:
|
||||
```python
|
||||
# New approach - single request object
|
||||
get_dataset_info(request={"identifier": 123}) # ID
|
||||
get_dataset_info(request={"identifier": "uuid-string"}) # UUID
|
||||
|
||||
# Old approach - replaced
|
||||
get_dataset_info(dataset_id=123)
|
||||
```
|
||||
|
||||
### Multi-Identifier Support
|
||||
- **Charts/Datasets**: ID (numeric) or UUID (string)
|
||||
- **Dashboards**: ID (numeric), UUID (string), or slug (string)
|
||||
- Validation prevents conflicting parameters (search + filters)
|
||||
|
||||
## Available Tools
|
||||
|
||||
### Dashboard Tools (5)
|
||||
- `list_dashboards` - List with search/filters/pagination
|
||||
- `get_dashboard_info` - Get by ID/UUID/slug
|
||||
- `get_dashboard_available_filters` - Discover filterable columns
|
||||
- `generate_dashboard` - Create dashboards with multiple charts
|
||||
- `add_chart_to_existing_dashboard` - Add charts to existing dashboards
|
||||
|
||||
### Chart Tools (8)
|
||||
- `list_charts` - List with search/filters/pagination
|
||||
- `get_chart_info` - Get by ID/UUID
|
||||
- `get_chart_available_filters` - Discover filterable columns
|
||||
- `generate_chart` - Create charts (table, line, bar, area, scatter)
|
||||
- `update_chart` - Update saved charts
|
||||
- `update_chart_preview` - Update cached previews
|
||||
- `get_chart_data` - Export data (JSON/CSV/Excel)
|
||||
- `get_chart_preview` - Screenshots, ASCII art, table previews
|
||||
|
||||
### Dataset Tools (3)
|
||||
- `list_datasets` - List with columns/metrics
|
||||
- `get_dataset_info` - Get by ID/UUID with metadata
|
||||
- `get_dataset_available_filters` - Discover filterable columns
|
||||
|
||||
### System Tools (2)
|
||||
- `get_superset_instance_info` - Instance statistics and version
|
||||
- `generate_explore_link` - Generate chart exploration URLs
|
||||
|
||||
### SQL Lab Tools (1)
|
||||
- `open_sql_lab_with_context` - Pre-configured SQL Lab sessions
|
||||
|
||||
## Authentication & Security
|
||||
|
||||
### JWT Bearer Authentication
|
||||
Production-ready authentication with configurable factory pattern:
|
||||
```python
|
||||
# In superset_config.py
|
||||
MCP_AUTH_ENABLED = True
|
||||
MCP_JWKS_URI = "https://auth.company.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://auth.company.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
```
|
||||
|
||||
### Scope-Based Authorization
|
||||
| Tool Category | Required Scope |
|
||||
|---------------|----------------|
|
||||
| Dashboard ops | `dashboard:read` |
|
||||
| Chart ops | `chart:read` / `chart:write` |
|
||||
| Dataset ops | `dataset:read` |
|
||||
| System ops | `instance:read` |
|
||||
|
||||
### Audit Logging
|
||||
All operations logged with MCP context:
|
||||
- User impersonation tracking
|
||||
- Tool execution details
|
||||
- Sanitized payloads (sensitive data redacted)
|
||||
|
||||
## Cache Control
|
||||
|
||||
Leverages Superset's existing cache layers with comprehensive control:
|
||||
|
||||
### Cache Types
|
||||
1. **Query Result Cache** - Database query results
|
||||
2. **Metadata Cache** - Table schemas, columns, metrics
|
||||
3. **Form Data Cache** - Chart configurations
|
||||
4. **Dashboard Cache** - Rendered components
|
||||
|
||||
### Cache Parameters
|
||||
Tools support cache control through request schemas:
|
||||
- `use_cache`: Enable/disable caching (default: true)
|
||||
- `force_refresh`: Force cache refresh (default: false)
|
||||
- `cache_timeout`: Override timeout in seconds
|
||||
- `refresh_metadata`: Force metadata refresh
|
||||
|
||||
### Cache Status Reporting
|
||||
```json
|
||||
{
|
||||
"cache_status": {
|
||||
"cache_hit": true,
|
||||
"cache_type": "query",
|
||||
"cache_age_seconds": 300,
|
||||
"refreshed": false
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Tool Abstractions
|
||||
|
||||
### Generic Base Classes
|
||||
- **ModelListTool**: Handles list/search/filter operations with pagination
|
||||
- **ModelGetInfoTool**: Single object retrieval by multiple identifier types
|
||||
- **ModelGetAvailableFiltersTool**: Returns filterable columns/operators
|
||||
|
||||
### Implementation Pattern
|
||||
```python
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
def my_tool(request: MyRequest) -> MyResponse:
|
||||
# Import Superset modules inside function
|
||||
from superset.daos.dashboard import DashboardDAO
|
||||
from superset.commands.chart.create import CreateChartCommand
|
||||
|
||||
# Tool implementation
|
||||
return response
|
||||
```
|
||||
|
||||
## Configuration
|
||||
|
||||
### URL Configuration
|
||||
Centralized URL management for consistent link generation:
|
||||
```python
|
||||
# In superset_config.py
|
||||
SUPERSET_WEBSERVER_ADDRESS = "http://localhost:8088" # Development
|
||||
SUPERSET_WEBSERVER_ADDRESS = "https://superset.company.com" # Production
|
||||
```
|
||||
|
||||
### Schema Design Principles
|
||||
- **Minimal columns** in list responses
|
||||
- **Optional fields** in info schemas for missing data handling
|
||||
- **Null exclusion** for cleaner JSON responses
|
||||
- **Type safety** with clear Pydantic validation
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
1. **Choose domain folder**: `dashboard/`, `chart/`, `dataset/`, or `system/`
|
||||
2. **Define schemas**: Use Pydantic with field descriptions
|
||||
3. **Implement tool**:
|
||||
- Decorate with `@mcp.tool` and `@mcp_auth_hook`
|
||||
- Import Superset modules inside function body
|
||||
- Use generic abstractions where applicable
|
||||
4. **Register**: Add to appropriate `__init__.py`
|
||||
5. **Test**: Add unit tests in `tests/unit_tests/mcp_service/`
|
||||
|
||||
## Current Status
|
||||
|
||||
### ✅ Phase 1 Complete
|
||||
- FastMCP server with CLI
|
||||
- JWT authentication with RBAC
|
||||
- All 19 core tools implemented
|
||||
- Request schema pattern
|
||||
- Cache control system
|
||||
- Audit logging
|
||||
- 194+ unit tests
|
||||
|
||||
### 🎯 Future Enhancements
|
||||
- Demo notebooks and video examples
|
||||
- OAuth integration for user impersonation
|
||||
- Enhanced chart rendering formats
|
||||
- Advanced security features
|
||||
|
||||
**Production Ready**: Core functionality stable with comprehensive testing and authentication.
|
||||
|
||||
---
|
||||
|
||||
For setup and usage, see the [MCP Service overview](./intro).
|
||||
434
docs/docs/mcp-service/authentication.mdx
Normal file
434
docs/docs/mcp-service/authentication.mdx
Normal file
@@ -0,0 +1,434 @@
|
||||
---
|
||||
title: Authentication & Security
|
||||
sidebar_position: 4
|
||||
version: 1
|
||||
---
|
||||
|
||||
# Authentication & Security
|
||||
|
||||
The MCP service provides enterprise-grade JWT Bearer authentication with flexible configuration options and comprehensive security controls.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Development Mode (Default)
|
||||
|
||||
:::tip
|
||||
Authentication is **disabled by default** for local development - no configuration needed.
|
||||
:::
|
||||
|
||||
```bash
|
||||
# No configuration needed - service runs without authentication
|
||||
superset mcp run --port 5008 --debug
|
||||
```
|
||||
|
||||
### Production Mode
|
||||
|
||||
:::warning
|
||||
Always enable authentication for production deployments to secure your Superset instance.
|
||||
:::
|
||||
|
||||
Enable JWT authentication in your Superset configuration:
|
||||
|
||||
```python
|
||||
# In superset_config.py
|
||||
MCP_AUTH_ENABLED = True
|
||||
MCP_JWKS_URI = "https://auth.company.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://auth.company.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
MCP_REQUIRED_SCOPES = ["dashboard:read", "chart:read", "dataset:read"]
|
||||
```
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Option 1: Simple Configuration
|
||||
|
||||
Add to your `superset_config.py`:
|
||||
|
||||
```python
|
||||
# Enable authentication
|
||||
MCP_AUTH_ENABLED = True
|
||||
|
||||
# JWT settings
|
||||
MCP_JWKS_URI = "https://auth.company.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://auth.company.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
MCP_REQUIRED_SCOPES = ["dashboard:read", "chart:read"]
|
||||
|
||||
# Optional: User resolution
|
||||
MCP_JWT_USER_CLAIM = "sub" # JWT claim for username (default: "sub")
|
||||
MCP_JWT_EMAIL_CLAIM = "email" # JWT claim for email (default: "email")
|
||||
MCP_FALLBACK_USER = "admin" # Fallback user if JWT user not found
|
||||
```
|
||||
|
||||
### Option 2: Custom Factory
|
||||
|
||||
For advanced authentication requirements:
|
||||
|
||||
```python
|
||||
def create_custom_mcp_auth(app):
|
||||
"""Custom auth factory for enterprise environments."""
|
||||
from fastmcp.server.auth.providers.bearer import BearerAuthProvider
|
||||
|
||||
return BearerAuthProvider(
|
||||
jwks_uri=app.config["MCP_JWKS_URI"],
|
||||
issuer=app.config["MCP_JWT_ISSUER"],
|
||||
audience=app.config["MCP_JWT_AUDIENCE"],
|
||||
required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
|
||||
user_resolver=custom_user_resolver,
|
||||
cache_ttl=300 # Cache JWKS for 5 minutes
|
||||
)
|
||||
|
||||
MCP_AUTH_FACTORY = create_custom_mcp_auth
|
||||
```
|
||||
|
||||
### Option 3: Environment Variables
|
||||
|
||||
For containerized deployments:
|
||||
|
||||
```bash
|
||||
# Environment variables
|
||||
export MCP_AUTH_ENABLED=true
|
||||
export MCP_JWKS_URI=https://auth.company.com/.well-known/jwks.json
|
||||
export MCP_JWT_ISSUER=https://auth.company.com/
|
||||
export MCP_JWT_AUDIENCE=superset-mcp-api
|
||||
export MCP_REQUIRED_SCOPES=dashboard:read,chart:read,dataset:read
|
||||
```
|
||||
|
||||
## Identity Provider Integration
|
||||
|
||||
### Auth0
|
||||
|
||||
```python
|
||||
# Auth0 configuration
|
||||
MCP_JWKS_URI = "https://your-tenant.auth0.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://your-tenant.auth0.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
```
|
||||
|
||||
### Okta
|
||||
|
||||
```python
|
||||
# Okta configuration
|
||||
MCP_JWKS_URI = "https://your-org.okta.com/oauth2/default/v1/keys"
|
||||
MCP_JWT_ISSUER = "https://your-org.okta.com/oauth2/default"
|
||||
MCP_JWT_AUDIENCE = "api://superset-mcp"
|
||||
```
|
||||
|
||||
### AWS Cognito
|
||||
|
||||
```python
|
||||
# Cognito configuration
|
||||
MCP_JWKS_URI = "https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://cognito-idp.{region}.amazonaws.com/{userPoolId}"
|
||||
MCP_JWT_AUDIENCE = "your-app-client-id"
|
||||
```
|
||||
|
||||
### Azure AD
|
||||
|
||||
```python
|
||||
# Azure AD configuration
|
||||
MCP_JWKS_URI = "https://login.microsoftonline.com/{tenant}/discovery/v2.0/keys"
|
||||
MCP_JWT_ISSUER = "https://login.microsoftonline.com/{tenant}/v2.0"
|
||||
MCP_JWT_AUDIENCE = "api://superset-mcp"
|
||||
```
|
||||
|
||||
## Scope-Based Authorization
|
||||
|
||||
### Standard Scopes
|
||||
|
||||
The MCP service defines these standard scopes:
|
||||
|
||||
| Scope | Description | Required For |
|
||||
|-------|-------------|--------------|
|
||||
| `dashboard:read` | Read dashboard information | `list_dashboards`, `get_dashboard_info` |
|
||||
| `dashboard:write` | Create/modify dashboards | `generate_dashboard`, `add_chart_to_existing_dashboard` |
|
||||
| `chart:read` | Read chart information | `list_charts`, `get_chart_info`, `get_chart_data` |
|
||||
| `chart:write` | Create/modify charts | `generate_chart`, `update_chart` |
|
||||
| `dataset:read` | Read dataset information | `list_datasets`, `get_dataset_info` |
|
||||
| `instance:read` | Read instance information | `get_superset_instance_info` |
|
||||
|
||||
### Custom Scopes
|
||||
|
||||
Define custom scopes for specific use cases:
|
||||
|
||||
```python
|
||||
# Custom scope definitions
|
||||
CUSTOM_MCP_SCOPES = {
|
||||
"analytics:export": "Export analytical data",
|
||||
"reports:generate": "Generate automated reports",
|
||||
"admin:config": "Access administrative configuration"
|
||||
}
|
||||
|
||||
# Map tools to custom scopes
|
||||
def get_custom_required_scopes(tool_name: str) -> List[str]:
|
||||
scope_map = {
|
||||
"get_chart_data": ["chart:read", "analytics:export"],
|
||||
"generate_dashboard": ["dashboard:write", "reports:generate"],
|
||||
"get_superset_instance_info": ["instance:read", "admin:config"]
|
||||
}
|
||||
return scope_map.get(tool_name, [])
|
||||
|
||||
MCP_SCOPE_RESOLVER = get_custom_required_scopes
|
||||
```
|
||||
|
||||
## JWT Token Format
|
||||
|
||||
### Required Claims
|
||||
|
||||
Your JWT tokens must include these standard claims:
|
||||
|
||||
```json
|
||||
{
|
||||
"iss": "https://auth.company.com/", // Issuer
|
||||
"aud": "superset-mcp-api", // Audience
|
||||
"sub": "user@company.com", // Subject (username)
|
||||
"exp": 1704118800, // Expiration timestamp
|
||||
"iat": 1704115200, // Issued at timestamp
|
||||
"scope": "dashboard:read chart:read" // Space-separated scopes
|
||||
}
|
||||
```
|
||||
|
||||
### Optional Claims
|
||||
|
||||
Additional claims for enhanced functionality:
|
||||
|
||||
```json
|
||||
{
|
||||
"email": "user@company.com", // User email
|
||||
"name": "John Doe", // Full name
|
||||
"groups": ["analysts", "sales_team"], // User groups
|
||||
"tenant_id": "company_123", // Multi-tenant ID
|
||||
"role": "analyst" // User role
|
||||
}
|
||||
```
|
||||
|
||||
## Client Integration
|
||||
|
||||
### API Client Usage
|
||||
|
||||
```python
|
||||
import requests
|
||||
|
||||
# Get JWT token from your identity provider
|
||||
token = get_jwt_token()
|
||||
|
||||
# Call MCP service with Bearer authentication
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
|
||||
response = requests.post(
|
||||
"http://localhost:5008/call_tool",
|
||||
headers=headers,
|
||||
json={
|
||||
"tool": "list_dashboards",
|
||||
"arguments": {"search": "sales"}
|
||||
}
|
||||
)
|
||||
|
||||
data = response.json()
|
||||
```
|
||||
|
||||
### Claude Desktop with Authentication
|
||||
|
||||
For Claude Desktop, the proxy script handles authentication:
|
||||
|
||||
```bash
|
||||
#!/bin/bash
|
||||
# run_proxy_with_auth.sh
|
||||
|
||||
# Get token from environment or file
|
||||
if [ -f ~/.superset_mcp_token ]; then
|
||||
TOKEN=$(cat ~/.superset_mcp_token)
|
||||
else
|
||||
TOKEN=${SUPERSET_MCP_TOKEN}
|
||||
fi
|
||||
|
||||
# Export token for proxy
|
||||
export MCP_AUTH_TOKEN="$TOKEN"
|
||||
|
||||
cd /path/to/superset
|
||||
source venv/bin/activate
|
||||
exec fastmcp proxy http://localhost:5008 --auth-header "Authorization: Bearer $TOKEN"
|
||||
```
|
||||
|
||||
## User Resolution
|
||||
|
||||
### Default User Resolution
|
||||
|
||||
The service maps JWT claims to Superset users:
|
||||
|
||||
```python
|
||||
def default_user_resolver(claims: Dict[str, Any]) -> User:
|
||||
"""Default user resolution from JWT claims."""
|
||||
|
||||
# Extract username from configurable claim
|
||||
username = claims.get(app.config.get("MCP_JWT_USER_CLAIM", "sub"))
|
||||
|
||||
# Find Superset user
|
||||
user = security_manager.find_user(username=username)
|
||||
|
||||
if not user:
|
||||
# Try email lookup
|
||||
email = claims.get(app.config.get("MCP_JWT_EMAIL_CLAIM", "email"))
|
||||
if email:
|
||||
user = security_manager.find_user(email=email)
|
||||
|
||||
if not user and app.config.get("MCP_FALLBACK_USER"):
|
||||
# Use fallback user for development
|
||||
user = security_manager.find_user(username=app.config["MCP_FALLBACK_USER"])
|
||||
|
||||
return user
|
||||
```
|
||||
|
||||
### Custom User Resolution
|
||||
|
||||
Implement custom user resolution logic:
|
||||
|
||||
```python
|
||||
def custom_user_resolver(claims: Dict[str, Any]) -> User:
|
||||
"""Custom user resolution for enterprise environments."""
|
||||
|
||||
# Extract custom claims
|
||||
employee_id = claims.get("employee_id")
|
||||
tenant_id = claims.get("tenant_id")
|
||||
|
||||
# Multi-tenant user lookup
|
||||
user = find_user_by_employee_id(employee_id, tenant_id)
|
||||
|
||||
if user:
|
||||
# Set additional context
|
||||
user.mcp_tenant_id = tenant_id
|
||||
user.mcp_groups = claims.get("groups", [])
|
||||
|
||||
return user
|
||||
|
||||
# Use custom resolver
|
||||
MCP_USER_RESOLVER = custom_user_resolver
|
||||
```
|
||||
|
||||
## Security Features
|
||||
|
||||
### Token Validation
|
||||
|
||||
Comprehensive JWT validation:
|
||||
|
||||
- **Signature verification**: RS256 with JWKS key rotation support
|
||||
- **Expiration checking**: Automatic token expiry validation
|
||||
- **Audience validation**: Prevents token reuse across services
|
||||
- **Issuer validation**: Ensures tokens from trusted sources only
|
||||
- **Scope validation**: Enforces tool-level permissions
|
||||
|
||||
### Request Security
|
||||
|
||||
- **HTTPS enforcement**: Production deployments should use HTTPS
|
||||
- **Rate limiting**: Configurable per-user rate limits
|
||||
- **Request logging**: All authenticated requests logged with user context
|
||||
- **Input validation**: Comprehensive request schema validation
|
||||
|
||||
### Audit Logging
|
||||
|
||||
Every tool call is logged with security context:
|
||||
|
||||
```json
|
||||
{
|
||||
"timestamp": "2024-01-25T14:30:00Z",
|
||||
"user_id": "user@company.com",
|
||||
"tool_name": "get_chart_data",
|
||||
"source": "mcp",
|
||||
"jwt_subject": "user@company.com",
|
||||
"jwt_scopes": ["chart:read", "analytics:export"],
|
||||
"tenant_id": "company_123",
|
||||
"request_id": "req_12345",
|
||||
"execution_time": 0.145,
|
||||
"status": "success"
|
||||
}
|
||||
```
|
||||
|
||||
## Testing Authentication
|
||||
|
||||
### Generate Test Tokens
|
||||
|
||||
For development and testing:
|
||||
|
||||
```python
|
||||
from fastmcp.server.auth.providers.bearer import RSAKeyPair
|
||||
|
||||
# Generate test keypair
|
||||
keypair = RSAKeyPair.generate()
|
||||
print("Public key:", keypair.public_key)
|
||||
|
||||
# Create test token
|
||||
token = keypair.create_token(
|
||||
subject="test@example.com",
|
||||
issuer="https://test.example.com",
|
||||
audience="superset-mcp-api",
|
||||
scopes=["dashboard:read", "chart:read", "dataset:read"],
|
||||
expires_in=3600 # 1 hour
|
||||
)
|
||||
print("Test token:", token)
|
||||
```
|
||||
|
||||
### Test Configuration
|
||||
|
||||
```python
|
||||
# Test configuration with generated keypair
|
||||
MCP_AUTH_ENABLED = True
|
||||
MCP_JWT_PUBLIC_KEY = """-----BEGIN PUBLIC KEY-----
|
||||
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA...
|
||||
-----END PUBLIC KEY-----"""
|
||||
MCP_JWT_ISSUER = "https://test.example.com"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
MCP_FALLBACK_USER = "admin"
|
||||
```
|
||||
|
||||
### Manual Testing
|
||||
|
||||
```bash
|
||||
# Test with curl
|
||||
curl -X POST http://localhost:5008/call_tool \
|
||||
-H "Authorization: Bearer $TEST_TOKEN" \
|
||||
-H "Content-Type: application/json" \
|
||||
-d '{
|
||||
"tool": "get_superset_instance_info",
|
||||
"arguments": {"include_statistics": true}
|
||||
}'
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Token Validation Errors:**
|
||||
```
|
||||
Error: Invalid JWT signature
|
||||
Solution: Verify JWKS_URI is accessible and contains correct keys
|
||||
```
|
||||
|
||||
**User Not Found:**
|
||||
```
|
||||
Error: User not found for JWT subject
|
||||
Solution: Check MCP_JWT_USER_CLAIM configuration and user exists in Superset
|
||||
```
|
||||
|
||||
**Insufficient Scopes:**
|
||||
```
|
||||
Error: Missing required scope 'chart:read'
|
||||
Solution: Update JWT token to include required scopes
|
||||
```
|
||||
|
||||
### Debug Configuration
|
||||
|
||||
Enable debug logging for authentication issues:
|
||||
|
||||
```python
|
||||
# Enhanced logging for auth debugging
|
||||
import logging
|
||||
logging.getLogger('superset.mcp_service.auth').setLevel(logging.DEBUG)
|
||||
|
||||
# Log all JWT validation steps
|
||||
MCP_AUTH_DEBUG = True
|
||||
```
|
||||
|
||||
This authentication guide provides comprehensive coverage for securing the MCP service in production environments while maintaining development flexibility.
|
||||
705
docs/docs/mcp-service/development.mdx
Normal file
705
docs/docs/mcp-service/development.mdx
Normal file
@@ -0,0 +1,705 @@
|
||||
---
|
||||
title: Development Guide
|
||||
sidebar_position: 2
|
||||
version: 1
|
||||
---
|
||||
|
||||
# MCP Service Development Guide
|
||||
|
||||
This guide covers the internal architecture, development workflows, and patterns for extending the Superset MCP service.
|
||||
|
||||
> 🚀 **New to MCP?** Start with the [Overview](./intro) to understand what the service does before diving into development.
|
||||
>
|
||||
> 📚 **Need API examples?** Check the [API Reference](./api-reference) to see how existing tools work.
|
||||
>
|
||||
> 🔐 **Planning production use?** Review [Authentication](./authentication) for security considerations.
|
||||
|
||||
## Internal Architecture
|
||||
|
||||
### Component Overview
|
||||
|
||||
The MCP service follows a layered architecture with clear separation of concerns:
|
||||
|
||||
```mermaid
|
||||
graph TB
|
||||
subgraph "Transport Layer"
|
||||
HTTP[HTTP Server :5008]
|
||||
FastMCP[FastMCP Protocol Handler]
|
||||
end
|
||||
|
||||
subgraph "Auth & Middleware Layer"
|
||||
AuthHook[Auth Hook Decorator]
|
||||
JWT[JWT Validator]
|
||||
RBAC[RBAC Engine]
|
||||
Audit[Audit Logger]
|
||||
end
|
||||
|
||||
subgraph "Tool Layer"
|
||||
        Tools[19 MCP Tools<br/>Tool Decorated]
|
||||
Schemas[Pydantic Schemas]
|
||||
Validation[Request Validation]
|
||||
end
|
||||
|
||||
subgraph "Business Logic Layer"
|
||||
Generic[Generic Tool Abstractions]
|
||||
ModelList[ModelListTool]
|
||||
ModelGet[ModelGetInfoTool]
|
||||
ModelFilter[ModelGetAvailableFiltersTool]
|
||||
end
|
||||
|
||||
subgraph "Data Access Layer"
|
||||
DAOs[Superset DAOs]
|
||||
Commands[Superset Commands]
|
||||
Cache[Cache Manager]
|
||||
end
|
||||
|
||||
subgraph "Storage Layer"
|
||||
MetaDB[(Metadata DB)]
|
||||
DataWH[(Data Warehouse)]
|
||||
Redis[(Redis Cache)]
|
||||
end
|
||||
|
||||
HTTP --> FastMCP
|
||||
FastMCP --> AuthHook
|
||||
AuthHook --> JWT
|
||||
JWT --> RBAC
|
||||
RBAC --> Audit
|
||||
Audit --> Tools
|
||||
|
||||
Tools --> Schemas
|
||||
Schemas --> Validation
|
||||
Validation --> Generic
|
||||
|
||||
Generic --> ModelList
|
||||
Generic --> ModelGet
|
||||
Generic --> ModelFilter
|
||||
|
||||
ModelList --> DAOs
|
||||
ModelGet --> DAOs
|
||||
ModelFilter --> DAOs
|
||||
|
||||
Tools --> Commands
|
||||
Commands --> Cache
|
||||
|
||||
DAOs --> MetaDB
|
||||
Commands --> MetaDB
|
||||
Commands --> DataWH
|
||||
Cache --> Redis
|
||||
```
|
||||
|
||||
### Request Flow
|
||||
|
||||
Every MCP tool call follows this execution pattern:
|
||||
|
||||
```mermaid
|
||||
sequenceDiagram
|
||||
participant Client as LLM Client
|
||||
participant MCP as FastMCP Server
|
||||
participant Auth as Auth Hook
|
||||
participant Tool as MCP Tool
|
||||
participant Generic as Generic Abstraction
|
||||
participant DAO as Superset DAO
|
||||
participant DB as Database
|
||||
|
||||
Client->>+MCP: tool_call(request)
|
||||
MCP->>+Auth: validate_and_authorize()
|
||||
Auth->>Auth: Validate JWT token
|
||||
Auth->>Auth: Check required scopes
|
||||
Auth->>Auth: Set Flask g.user context
|
||||
Auth->>Auth: Log audit event
|
||||
Auth->>+Tool: execute_tool(validated_request)
|
||||
|
||||
Tool->>Tool: Parse Pydantic request schema
|
||||
Tool->>+Generic: Use generic abstraction
|
||||
Generic->>+DAO: Query Superset data
|
||||
DAO->>+DB: Execute SQL
|
||||
DB-->>-DAO: Return results
|
||||
DAO-->>-Generic: Return objects
|
||||
Generic->>Generic: Apply pagination/filtering
|
||||
Generic-->>-Tool: Return formatted data
|
||||
|
||||
Tool->>Tool: Build Pydantic response schema
|
||||
Tool-->>-Auth: Return response
|
||||
Auth->>Auth: Log success audit event
|
||||
Auth-->>-MCP: Return validated response
|
||||
MCP-->>-Client: JSON response
|
||||
```
|
||||
|
||||
### Tool Registration System
|
||||
|
||||
Tools are automatically discovered and registered through the decorator pattern:
|
||||
|
||||
```python
|
||||
# superset/mcp_service/mcp_app.py
|
||||
from fastmcp import FastMCP
|
||||
|
||||
# Global MCP instance
|
||||
mcp = FastMCP("Superset MCP Service")
|
||||
|
||||
# Tools register themselves via decorators
|
||||
@mcp.tool
|
||||
@mcp_auth_hook(['chart:read'])
|
||||
def get_chart_info(request: GetChartInfoRequest) -> GetChartInfoResponse:
|
||||
# Tool implementation
|
||||
pass
|
||||
|
||||
# All tool modules imported to trigger registration
|
||||
from superset.mcp_service.chart.tool import *
|
||||
from superset.mcp_service.dashboard.tool import *
|
||||
from superset.mcp_service.dataset.tool import *
|
||||
from superset.mcp_service.system.tool import *
|
||||
```
|
||||
|
||||
## Development Patterns
|
||||
|
||||
### Tool Implementation Pattern
|
||||
|
||||
All tools follow this standardized pattern:
|
||||
|
||||
```python
|
||||
# Example: superset/mcp_service/chart/tool/get_chart_info.py
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
GetChartInfoRequest,
|
||||
GetChartInfoResponse,
|
||||
ChartError
|
||||
)
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook(['chart:read'])
|
||||
def get_chart_info(request: GetChartInfoRequest) -> GetChartInfoResponse:
|
||||
"""
|
||||
Get detailed information about a specific chart.
|
||||
|
||||
Supports lookup by ID or UUID with comprehensive metadata.
|
||||
"""
|
||||
try:
|
||||
# CRITICAL: Import Superset modules inside function
|
||||
from superset.daos.chart import ChartDAO
|
||||
from superset.models.slice import Slice
|
||||
|
||||
# Use generic abstraction for common operations
|
||||
from superset.mcp_service.generic_tools import ModelGetInfoTool
|
||||
|
||||
tool = ModelGetInfoTool(
|
||||
dao=ChartDAO,
|
||||
model=Slice,
|
||||
response_schema=GetChartInfoResponse,
|
||||
identifier_field_map={
|
||||
'id': 'id',
|
||||
'uuid': 'uuid'
|
||||
}
|
||||
)
|
||||
|
||||
return tool.execute(request)
|
||||
|
||||
except Exception as e:
|
||||
return ChartError(
|
||||
error=f"Failed to get chart info: {str(e)}",
|
||||
error_type="ChartInfoError"
|
||||
)
|
||||
```
|
||||
|
||||
### Schema Design Patterns
|
||||
|
||||
Pydantic schemas follow these conventions:
|
||||
|
||||
```python
|
||||
# Request Schema Pattern
|
||||
class GetChartInfoRequest(BaseModel):
|
||||
"""Request to get detailed chart information."""
|
||||
|
||||
identifier: Union[int, str] = Field(
|
||||
...,
|
||||
description="Chart ID (numeric) or UUID (string)"
|
||||
)
|
||||
|
||||
include_form_data: bool = Field(
|
||||
default=True,
|
||||
description="Whether to include chart configuration"
|
||||
)
|
||||
|
||||
use_cache: bool = Field(
|
||||
default=True,
|
||||
description="Whether to use cached data"
|
||||
)
|
||||
|
||||
# Response Schema Pattern
|
||||
class GetChartInfoResponse(BaseModel):
|
||||
"""Detailed chart information response."""
|
||||
|
||||
chart_id: int = Field(description="Chart numeric ID")
|
||||
uuid: Optional[str] = Field(description="Chart UUID")
|
||||
slice_name: str = Field(description="Chart display name")
|
||||
viz_type: str = Field(description="Visualization type")
|
||||
datasource_id: Optional[int] = Field(description="Dataset ID")
|
||||
form_data: Optional[Dict[str, Any]] = Field(description="Chart configuration")
|
||||
explore_url: Optional[str] = Field(description="Explore URL for editing")
|
||||
|
||||
# Cache status for transparency
|
||||
cache_status: Optional[CacheStatus] = Field(description="Cache hit information")
|
||||
|
||||
# Error Schema Pattern
|
||||
class ChartError(BaseModel):
|
||||
"""Chart operation error response."""
|
||||
|
||||
error: str = Field(description="Error message")
|
||||
error_type: str = Field(description="Error type identifier")
|
||||
suggestions: Optional[List[str]] = Field(description="Suggested fixes")
|
||||
```
|
||||
|
||||
### Generic Tool Abstractions
|
||||
|
||||
Common operations are abstracted into reusable classes:
|
||||
|
||||
```python
|
||||
# superset/mcp_service/generic_tools.py
|
||||
from typing import Type, Dict, Any, List, Optional
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ModelListTool:
|
||||
"""Generic tool for list operations with pagination and filtering."""
|
||||
|
||||
def __init__(self,
|
||||
dao: Type,
|
||||
model: Type,
|
||||
response_schema: Type[BaseModel],
|
||||
default_columns: List[str] = None,
|
||||
searchable_columns: List[str] = None):
|
||||
self.dao = dao
|
||||
self.model = model
|
||||
self.response_schema = response_schema
|
||||
self.default_columns = default_columns or []
|
||||
self.searchable_columns = searchable_columns or []
|
||||
|
||||
def execute(self, request: BaseModel) -> BaseModel:
|
||||
"""Execute list operation with pagination and filtering."""
|
||||
|
||||
# Build query with filters
|
||||
query = self.dao.find_all()
|
||||
|
||||
# Apply search if provided
|
||||
if hasattr(request, 'search') and request.search:
|
||||
query = self._apply_search(query, request.search)
|
||||
|
||||
# Apply filters if provided
|
||||
if hasattr(request, 'filters') and request.filters:
|
||||
query = self._apply_filters(query, request.filters)
|
||||
|
||||
# Apply pagination
|
||||
total = query.count()
|
||||
|
||||
if hasattr(request, 'page') and hasattr(request, 'page_size'):
|
||||
offset = (request.page - 1) * request.page_size
|
||||
query = query.offset(offset).limit(request.page_size)
|
||||
|
||||
# Execute query and serialize
|
||||
results = query.all()
|
||||
serialized = [self._serialize_model(obj) for obj in results]
|
||||
|
||||
return self.response_schema(
|
||||
results=serialized,
|
||||
total_count=total,
|
||||
page=getattr(request, 'page', 1),
|
||||
page_size=getattr(request, 'page_size', len(serialized))
|
||||
)
|
||||
|
||||
class ModelGetInfoTool:
|
||||
"""Generic tool for getting single object by multiple identifier types."""
|
||||
|
||||
def __init__(self,
|
||||
dao: Type,
|
||||
model: Type,
|
||||
response_schema: Type[BaseModel],
|
||||
identifier_field_map: Dict[str, str]):
|
||||
self.dao = dao
|
||||
self.model = model
|
||||
self.response_schema = response_schema
|
||||
self.identifier_field_map = identifier_field_map
|
||||
|
||||
def execute(self, request: BaseModel) -> BaseModel:
|
||||
"""Execute get operation with multi-identifier support."""
|
||||
|
||||
identifier = request.identifier
|
||||
|
||||
# Determine identifier type and field
|
||||
if isinstance(identifier, int):
|
||||
field = self.identifier_field_map.get('id', 'id')
|
||||
obj = self.dao.find_by_id(identifier)
|
||||
elif isinstance(identifier, str):
|
||||
if len(identifier) == 36 and '-' in identifier: # UUID format
|
||||
field = self.identifier_field_map.get('uuid', 'uuid')
|
||||
obj = self.dao.find_by_uuid(identifier)
|
||||
else: # Assume slug
|
||||
field = self.identifier_field_map.get('slug', 'slug')
|
||||
obj = getattr(self.dao, 'find_by_slug', lambda x: None)(identifier)
|
||||
|
||||
if not obj:
|
||||
raise ValueError(f"Object not found with identifier: {identifier}")
|
||||
|
||||
# Serialize and return
|
||||
serialized = self._serialize_model(obj)
|
||||
return self.response_schema(**serialized)
|
||||
```
|
||||
|
||||
## Adding New Tools
|
||||
|
||||
### Step-by-Step Process
|
||||
|
||||
1. **Define the Domain**
|
||||
|
||||
Choose the appropriate domain folder:
|
||||
- `dashboard/` - Dashboard operations
|
||||
- `chart/` - Chart operations
|
||||
- `dataset/` - Dataset operations
|
||||
- `system/` - System-level operations
|
||||
|
||||
2. **Create Schemas**
|
||||
|
||||
```bash
|
||||
# Create schema file
|
||||
touch superset/mcp_service/schemas/my_domain_schemas.py
|
||||
```
|
||||
|
||||
```python
|
||||
# Define request/response schemas
|
||||
class MyToolRequest(BaseModel):
|
||||
param1: str = Field(description="Parameter description")
|
||||
param2: Optional[int] = Field(default=None, description="Optional parameter")
|
||||
|
||||
class MyToolResponse(BaseModel):
|
||||
result: str = Field(description="Result description")
|
||||
metadata: Dict[str, Any] = Field(description="Additional metadata")
|
||||
```
|
||||
|
||||
3. **Implement the Tool**
|
||||
|
||||
```bash
|
||||
# Create tool file
|
||||
touch superset/mcp_service/my_domain/tool/my_tool.py
|
||||
```
|
||||
|
||||
```python
|
||||
@mcp.tool
|
||||
@mcp_auth_hook(['required:scope'])
|
||||
def my_tool(request: MyToolRequest) -> MyToolResponse:
|
||||
"""Tool description for LLM."""
|
||||
|
||||
# Import Superset modules inside function
|
||||
from superset.daos.my_dao import MyDAO
|
||||
|
||||
# Implement business logic
|
||||
result = MyDAO.do_something(request.param1)
|
||||
|
||||
return MyToolResponse(
|
||||
result=result,
|
||||
metadata={"processed_at": datetime.utcnow()}
|
||||
)
|
||||
```
|
||||
|
||||
4. **Register the Tool**
|
||||
|
||||
```python
|
||||
# Add to superset/mcp_service/my_domain/tool/__init__.py
|
||||
from .my_tool import my_tool
|
||||
|
||||
__all__ = ['my_tool']
|
||||
```
|
||||
|
||||
```python
|
||||
# Import in superset/mcp_service/mcp_app.py
|
||||
from superset.mcp_service.my_domain.tool import *
|
||||
```
|
||||
|
||||
5. **Add Tests**
|
||||
|
||||
```bash
|
||||
# Create test file
|
||||
touch tests/unit_tests/mcp_service/test_my_tool.py
|
||||
```
|
||||
|
||||
```python
|
||||
import pytest
|
||||
from superset.mcp_service.my_domain.tool.my_tool import my_tool
|
||||
from superset.mcp_service.schemas.my_domain_schemas import MyToolRequest
|
||||
|
||||
class TestMyTool:
|
||||
def test_my_tool_success(self):
|
||||
request = MyToolRequest(param1="test")
|
||||
response = my_tool(request)
|
||||
assert response.result == "expected_result"
|
||||
```
|
||||
|
||||
### Tool Best Practices
|
||||
|
||||
1. **Import Inside Functions**
|
||||
```python
|
||||
# ❌ DON'T: Import at module level
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
@mcp.tool
|
||||
def my_tool():
|
||||
# Tool implementation
|
||||
pass
|
||||
|
||||
# ✅ DO: Import inside function
|
||||
@mcp.tool
|
||||
def my_tool():
|
||||
from superset.daos.chart import ChartDAO
|
||||
# Tool implementation
|
||||
pass
|
||||
```
|
||||
|
||||
2. **Use Generic Abstractions**
|
||||
```python
|
||||
# ✅ Leverage existing patterns
|
||||
@mcp.tool
|
||||
def list_my_objects(request):
|
||||
from superset.mcp_service.generic_tools import ModelListTool
|
||||
|
||||
tool = ModelListTool(
|
||||
dao=MyDAO,
|
||||
model=MyModel,
|
||||
response_schema=ListMyObjectsResponse
|
||||
)
|
||||
return tool.execute(request)
|
||||
```
|
||||
|
||||
3. **Comprehensive Error Handling**
|
||||
```python
|
||||
@mcp.tool
|
||||
def my_tool(request):
|
||||
try:
|
||||
# Tool implementation
|
||||
return success_response
|
||||
except PermissionError as e:
|
||||
return MyToolError(
|
||||
error="Permission denied",
|
||||
error_type="PermissionError",
|
||||
suggestions=["Check user permissions"]
|
||||
)
|
||||
except Exception as e:
|
||||
return MyToolError(
|
||||
error=f"Unexpected error: {str(e)}",
|
||||
error_type="InternalError"
|
||||
)
|
||||
```
|
||||
|
||||
## Testing Patterns
|
||||
|
||||
### Unit Test Structure
|
||||
|
||||
```python
|
||||
# tests/unit_tests/mcp_service/test_chart_tools.py
|
||||
import pytest
|
||||
from unittest.mock import Mock, patch
|
||||
from superset.mcp_service.chart.tool.get_chart_info import get_chart_info
|
||||
from superset.mcp_service.schemas.chart_schemas import GetChartInfoRequest
|
||||
|
||||
class TestGetChartInfo:
|
||||
"""Test suite for get_chart_info tool."""
|
||||
|
||||
@patch('superset.mcp_service.chart.tool.get_chart_info.ChartDAO')
|
||||
def test_get_chart_info_by_id_success(self, mock_dao):
|
||||
"""Test successful chart lookup by ID."""
|
||||
|
||||
# Setup mock
|
||||
mock_chart = Mock()
|
||||
mock_chart.id = 1
|
||||
mock_chart.slice_name = "Test Chart"
|
||||
mock_chart.viz_type = "line"
|
||||
mock_dao.find_by_id.return_value = mock_chart
|
||||
|
||||
# Execute
|
||||
request = GetChartInfoRequest(identifier=1)
|
||||
response = get_chart_info(request)
|
||||
|
||||
# Verify
|
||||
assert response.chart_id == 1
|
||||
assert response.slice_name == "Test Chart"
|
||||
mock_dao.find_by_id.assert_called_once_with(1)
|
||||
|
||||
@patch('superset.mcp_service.chart.tool.get_chart_info.ChartDAO')
|
||||
def test_get_chart_info_not_found(self, mock_dao):
|
||||
"""Test chart not found scenario."""
|
||||
|
||||
# Setup mock
|
||||
mock_dao.find_by_id.return_value = None
|
||||
|
||||
# Execute
|
||||
request = GetChartInfoRequest(identifier=999)
|
||||
response = get_chart_info(request)
|
||||
|
||||
# Verify error response
|
||||
assert hasattr(response, 'error')
|
||||
assert "not found" in response.error.lower()
|
||||
```
|
||||
|
||||
### Integration Test Patterns
|
||||
|
||||
```python
|
||||
# tests/integration_tests/mcp_service/test_chart_integration.py
|
||||
import pytest
|
||||
from superset.app import create_app
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
class TestChartIntegration(SupersetTestCase):
|
||||
"""Integration tests for chart tools."""
|
||||
|
||||
def setUp(self):
|
||||
super().setUp()
|
||||
self.app = create_app()
|
||||
self.app_context = self.app.app_context()
|
||||
self.app_context.push()
|
||||
|
||||
def tearDown(self):
|
||||
self.app_context.pop()
|
||||
super().tearDown()
|
||||
|
||||
def test_chart_workflow_integration(self):
|
||||
"""Test complete chart workflow."""
|
||||
|
||||
# Create chart
|
||||
create_request = {
|
||||
"dataset_id": "1",
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "region"}]
|
||||
}
|
||||
}
|
||||
|
||||
create_response = mcp.call_tool("generate_chart", create_request)
|
||||
chart_id = create_response["chart_id"]
|
||||
|
||||
# Get chart info
|
||||
info_request = {"identifier": chart_id}
|
||||
info_response = mcp.call_tool("get_chart_info", info_request)
|
||||
|
||||
assert info_response["chart_id"] == chart_id
|
||||
assert info_response["viz_type"] == "table"
|
||||
|
||||
# Get chart data
|
||||
data_request = {"identifier": chart_id, "limit": 10}
|
||||
data_response = mcp.call_tool("get_chart_data", data_request)
|
||||
|
||||
assert "data" in data_response
|
||||
assert len(data_response["data"]) <= 10
|
||||
```
|
||||
|
||||
## Performance Considerations
|
||||
|
||||
### Caching Strategy
|
||||
|
||||
The MCP service leverages Superset's existing cache layers:
|
||||
|
||||
```python
|
||||
# Cache control in tools
|
||||
@mcp.tool
|
||||
def get_chart_data(request: GetChartDataRequest):
|
||||
"""Tool with cache control."""
|
||||
|
||||
cache_config = {
|
||||
'use_cache': request.use_cache,
|
||||
'force_refresh': request.force_refresh,
|
||||
'cache_timeout': request.cache_timeout
|
||||
}
|
||||
|
||||
# Use Superset's cache infrastructure
|
||||
result = execute_with_cache(query, cache_config)
|
||||
|
||||
return ChartDataResponse(
|
||||
data=result.data,
|
||||
cache_status=result.cache_status
|
||||
)
|
||||
```
|
||||
|
||||
### Query Optimization
|
||||
|
||||
```python
|
||||
# Efficient pagination
|
||||
def list_objects(query, page, page_size):
|
||||
"""Optimized pagination pattern."""
|
||||
|
||||
# Count query optimization
|
||||
total = query.count()
|
||||
|
||||
# Limit columns for list operations
|
||||
query = query.options(load_only('id', 'name', 'created_on'))
|
||||
|
||||
# Apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
results = query.offset(offset).limit(page_size).all()
|
||||
|
||||
return results, total
|
||||
```
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### Authentication Flow
|
||||
|
||||
```python
|
||||
# JWT validation and user context
|
||||
@mcp_auth_hook(['chart:read'])
|
||||
def secure_tool(request):
|
||||
"""Tool with proper security context."""
|
||||
|
||||
# g.user is set by auth hook
|
||||
user_id = g.user.id
|
||||
|
||||
# Apply user-specific filtering
|
||||
query = ChartDAO.find_all().filter(
|
||||
Chart.owners.contains(g.user)
|
||||
)
|
||||
|
||||
return execute_query(query)
|
||||
```
|
||||
|
||||
### Input Validation
|
||||
|
||||
```python
|
||||
# Comprehensive request validation
|
||||
class CreateChartRequest(BaseModel):
|
||||
"""Validated chart creation request."""
|
||||
|
||||
dataset_id: Union[int, str] = Field(
|
||||
...,
|
||||
description="Dataset ID or UUID"
|
||||
)
|
||||
|
||||
config: ChartConfig = Field(
|
||||
...,
|
||||
description="Chart configuration"
|
||||
)
|
||||
|
||||
@validator('dataset_id')
|
||||
def validate_dataset_id(cls, v):
|
||||
"""Validate dataset exists and user has access."""
|
||||
# Validation logic
|
||||
return v
|
||||
|
||||
@validator('config')
|
||||
def validate_chart_config(cls, v):
|
||||
"""Validate chart configuration."""
|
||||
# Configuration validation
|
||||
return v
|
||||
```
|
||||
|
||||
This development guide provides comprehensive coverage of the MCP service's internal architecture and development patterns, enabling team members to effectively extend and maintain the system.
|
||||
|
||||
## Related Documentation
|
||||
|
||||
### 📚 **Ready to Use Your New Tools?**
|
||||
Test your implementations with examples from the [API Reference](./api-reference).
|
||||
|
||||
### 🔐 **Securing Your Extensions?**
|
||||
Add authentication to your tools using the [Authentication Guide](./authentication).
|
||||
|
||||
### 🏗️ **Understanding the Big Picture?**
|
||||
See the complete system design in the [Architecture Overview](./architecture).
|
||||
|
||||
### 🏢 **Building Enterprise Features?**
|
||||
Explore advanced patterns in the [Preset Integration Guide](./preset-integration).
|
||||
|
||||
> 📖 **Back to Documentation Index**: [MCP Service](./intro)
|
||||
124
docs/docs/mcp-service/intro.mdx
Normal file
124
docs/docs/mcp-service/intro.mdx
Normal file
@@ -0,0 +1,124 @@
|
||||
---
|
||||
title: MCP Service
|
||||
sidebar_position: 1
|
||||
version: 1
|
||||
---
|
||||
|
||||
# Superset MCP Service
|
||||
|
||||
The Superset Model Context Protocol (MCP) service provides programmatic access to Superset dashboards, charts, datasets, and instance metadata. Built for LLM agents and automation tools.
|
||||
|
||||
## What is MCP?
|
||||
|
||||
The Model Context Protocol (MCP) is an open standard that allows AI assistants to securely connect to data sources and tools. Superset's MCP service exposes **16 production-ready tools** that enable:
|
||||
|
||||
- 📊 **Data Exploration**: List and query dashboards, charts, and datasets
|
||||
- 🔧 **Chart Creation**: Generate visualizations programmatically
|
||||
- 📈 **Data Export**: Extract data in multiple formats (JSON, CSV, Excel)
|
||||
- 🔗 **Navigation**: Generate explore links and SQL Lab sessions
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
|
||||
:::note
|
||||
The MCP service is included with Superset development setup. FastMCP dependencies are installed automatically with `make install`.
|
||||
:::
|
||||
|
||||
```bash
|
||||
# MCP service is included with Superset development setup
|
||||
git clone https://github.com/apache/superset.git
|
||||
cd superset
|
||||
make venv && source venv/bin/activate
|
||||
make install
|
||||
|
||||
# Start Superset
|
||||
superset run -p 8088 --with-threads --reload --debugger
|
||||
|
||||
# Start MCP service (separate terminal)
|
||||
source venv/bin/activate
|
||||
superset mcp run --port 5008 --debug
|
||||
```
|
||||
|
||||
### Claude Desktop Integration
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"Superset MCP": {
|
||||
"command": "/path/to/superset/superset/mcp_service/run_proxy.sh",
|
||||
"args": [],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Key Features
|
||||
|
||||
### 🔧 **16 Production Tools**
|
||||
| Category | Tools | Purpose |
|----------|-------|---------|
| **Dashboard** (5) | List, get info, create, add charts | Dashboard management |
| **Chart** (8) | Full CRUD, data export, previews | Chart operations |
| **Dataset** (3) | List, get info, discover filters | Dataset exploration |
| **System** (2) | Instance info, explore links | System integration |
| **SQL Lab** (1) | Pre-configured sessions | SQL development |
|
||||
|
||||
### 🔐 **Enterprise Security**
|
||||
- **JWT Bearer Authentication**: Production-ready with configurable factory pattern
|
||||
- **RBAC Integration**: Scope-based permissions with Superset's security model
|
||||
- **Audit Logging**: Comprehensive MCP context tracking
|
||||
|
||||
### 📊 **Advanced Capabilities**
|
||||
- **Multi-format Export**: JSON, CSV, Excel data export
|
||||
- **Chart Previews**: Screenshots, ASCII art, table representations
|
||||
- **Cache Control**: Leverage Superset's existing cache infrastructure
|
||||
- **Request Schemas**: Eliminates LLM parameter validation issues
|
||||
|
||||
## Example Usage
|
||||
|
||||
```python
|
||||
# List dashboards
|
||||
dashboards = client.call_tool("list_dashboards", {
|
||||
"search": "sales",
|
||||
"page_size": 10
|
||||
})
|
||||
|
||||
# Create a chart
|
||||
chart = client.call_tool("generate_chart", {
|
||||
"dataset_id": "1",
|
||||
"config": {
|
||||
"chart_type": "line",
|
||||
"x": {"name": "date"},
|
||||
"y": [{"name": "revenue", "aggregate": "SUM"}]
|
||||
}
|
||||
})
|
||||
|
||||
# Export chart data
|
||||
data = client.call_tool("get_chart_data", {
|
||||
"identifier": chart["chart_id"],
|
||||
"format": "json",
|
||||
"limit": 1000
|
||||
})
|
||||
```
|
||||
|
||||
## Status
|
||||
|
||||
✅ **Phase 1 Complete** - Core functionality stable, authentication production-ready, comprehensive testing coverage.
|
||||
|
||||
## Documentation Structure
|
||||
|
||||
### Getting Started
|
||||
- **[Overview](./overview)** - Features, use cases, and examples
|
||||
- **[API Reference](./api-reference)** - Complete tool documentation
|
||||
|
||||
### Development
|
||||
- **[Development Guide](./development)** - Internal architecture and adding tools
|
||||
- **[Architecture](./architecture)** - System design and patterns
|
||||
|
||||
### Production
|
||||
- **[Authentication](./authentication)** - JWT setup and security
|
||||
- **[Preset Integration](./preset-integration)** - Enterprise features
|
||||
|
||||
> 🚀 **Ready to start?** Continue with the [Overview](./overview) for detailed examples and use cases.
|
||||
196
docs/docs/mcp-service/overview.mdx
Normal file
196
docs/docs/mcp-service/overview.mdx
Normal file
@@ -0,0 +1,196 @@
|
||||
---
|
||||
title: MCP Service Overview
|
||||
sidebar_position: 2
|
||||
version: 1
|
||||
---
|
||||
|
||||
# Superset MCP Service
|
||||
|
||||
The Superset Model Context Protocol (MCP) service provides a modular, schema-driven interface for programmatic access to Superset dashboards, charts, datasets, and instance metadata. Built on FastMCP, it's designed for LLM agents and automation tools.
|
||||
|
||||
**Status:** ✅ Phase 1 Complete. Core functionality stable, authentication production-ready, comprehensive testing coverage.
|
||||
|
||||
## What is MCP?
|
||||
|
||||
The Model Context Protocol (MCP) is an open standard for connecting AI assistants to data sources and tools. Superset's MCP service exposes 16 tools that allow LLM agents to:
|
||||
|
||||
- **Explore data**: List and query dashboards, charts, and datasets
|
||||
- **Create visualizations**: Generate charts and dashboards programmatically
|
||||
- **Export data**: Extract chart data in multiple formats
|
||||
- **Navigate interfaces**: Generate explore links and SQL Lab sessions
|
||||
|
||||
## Key Features
|
||||
|
||||
### 🔧 **16 Production-Ready Tools**
|
||||
- **Dashboard Tools (5)**: List, get info, create dashboards, add charts
|
||||
- **Chart Tools (8)**: Full CRUD operations, data export, screenshot previews
|
||||
- **Dataset Tools (3)**: List, get info, discover filterable columns
|
||||
- **System Tools (2)**: Instance info, explore link generation
|
||||
- **SQL Lab Tools (1)**: Pre-configured SQL sessions
|
||||
|
||||
### 🔐 **Enterprise Authentication**
|
||||
- **JWT Bearer Authentication**: Production-ready with configurable factory pattern
|
||||
- **RBAC Integration**: Scope-based permissions with Superset's security model
|
||||
- **Audit Logging**: Comprehensive MCP context tracking with impersonation support
|
||||
|
||||
### 📊 **Advanced Capabilities**
|
||||
- **Multi-format Export**: JSON, CSV, Excel data export
|
||||
- **Chart Previews**: Screenshots, ASCII art, and table representations
|
||||
- **Cache Control**: Comprehensive control over Superset's cache layers
|
||||
- **Request Schema Pattern**: Eliminates LLM parameter validation issues
|
||||
|
||||
## Architecture Overview
|
||||
|
||||
```mermaid
|
||||
graph TB
    subgraph "Client Layer"
        LLM[LLM/Agent Client]
        Claude[Claude Desktop]
        SDK[Custom SDK]
    end

    subgraph "MCP Service Layer"
        FastMCP[FastMCP Server<br/>Port 5008]
        Auth[JWT Auth Hook]
        Tools[16 MCP Tools]
    end

    subgraph "Superset Integration"
        DAOs[Superset DAOs]
        Commands[Superset Commands]
        Cache[Cache Layer]
    end

    subgraph "Data Layer"
        DB[(Superset Database)]
        DataWarehouse[(Data Warehouse)]
    end

    LLM --> FastMCP
    Claude --> FastMCP
    SDK --> FastMCP

    FastMCP --> Auth
    Auth --> Tools

    Tools --> DAOs
    Tools --> Commands
    Tools --> Cache

    DAOs --> DB
    Commands --> DB
    Commands --> DataWarehouse
|
||||
```
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Quick Setup
|
||||
|
||||
```bash
|
||||
# Clone and install Superset
|
||||
git clone https://github.com/apache/superset.git
|
||||
cd superset
|
||||
make venv && source venv/bin/activate
|
||||
make install
|
||||
|
||||
# Start Superset
|
||||
superset run -p 8088 --with-threads --reload --debugger
|
||||
|
||||
# Start MCP service (in separate terminal)
|
||||
source venv/bin/activate
|
||||
superset mcp run --port 5008 --debug
|
||||
```
|
||||
|
||||
### Connect to Claude Desktop
|
||||
|
||||
:::note
|
||||
The MCP service runs on HTTP and requires a proxy for Claude Desktop integration.
|
||||
:::
|
||||
|
||||
```bash
|
||||
# Install FastMCP proxy
|
||||
pip install fastmcp
|
||||
```
|
||||
|
||||
Configure Claude Desktop (`~/.config/Claude/claude_desktop_config.json`):
|
||||
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"Superset MCP": {
|
||||
"command": "/path/to/superset/superset/mcp_service/run_proxy.sh",
|
||||
"args": [],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Use Cases
|
||||
|
||||
### Data Exploration
|
||||
- "List all dashboards related to sales"
|
||||
- "Show me the charts in the Q4 Performance dashboard"
|
||||
- "What datasets are available for customer analysis?"
|
||||
|
||||
### Chart Creation
|
||||
- "Create a line chart showing revenue trends by month"
|
||||
- "Generate a table showing top 10 products by sales"
|
||||
- "Build a bar chart comparing regional performance"
|
||||
|
||||
### Data Export
|
||||
- "Export the sales data from this chart as CSV"
|
||||
- "Get the underlying data for this dashboard as JSON"
|
||||
- "Show me a preview of this chart as ASCII art"
|
||||
|
||||
### Dashboard Management
|
||||
- "Create a new dashboard with these 4 charts"
|
||||
- "Add this revenue chart to the executive dashboard"
|
||||
- "Generate an explore link for this chart configuration"
|
||||
|
||||
## Example Workflow
|
||||
|
||||
```python
|
||||
# List available dashboards
|
||||
dashboards = client.call_tool("list_dashboards", {
|
||||
"search": "sales",
|
||||
"page_size": 10
|
||||
})
|
||||
|
||||
# Get detailed dashboard info
|
||||
dashboard = client.call_tool("get_dashboard_info", {
|
||||
"identifier": dashboards["dashboards"][0]["id"]
|
||||
})
|
||||
|
||||
# Create a new chart
|
||||
chart = client.call_tool("generate_chart", {
|
||||
"dataset_id": "1",
|
||||
"config": {
|
||||
"chart_type": "line",
|
||||
"x": {"name": "date"},
|
||||
"y": [{"name": "revenue", "aggregate": "SUM"}]
|
||||
}
|
||||
})
|
||||
|
||||
# Export chart data
|
||||
data = client.call_tool("get_chart_data", {
|
||||
"identifier": chart["chart_id"],
|
||||
"format": "json",
|
||||
"limit": 1000
|
||||
})
|
||||
```
|
||||
|
||||
## Next Steps
|
||||
|
||||
### Ready to Use MCP?
|
||||
- **[📚 API Reference](./api-reference)** - Try all 16 tools with request/response examples
|
||||
- **[🔐 Authentication](./authentication)** - Set up JWT security for production use
|
||||
|
||||
### Want to Extend MCP?
|
||||
- **[🔧 Development Guide](./development)** - Learn internal architecture and add new tools
|
||||
- **[🏗️ Architecture](./architecture)** - Understand system design and deployment patterns
|
||||
|
||||
### Enterprise Deployment?
|
||||
- **[🏢 Preset Integration](./preset-integration)** - RBAC extensions and OIDC integration for enterprise
|
||||
|
||||
> 💡 **Getting started?** Return to the [MCP Service intro](./intro) for a complete overview.
|
||||
483
docs/docs/mcp-service/preset-integration.mdx
Normal file
483
docs/docs/mcp-service/preset-integration.mdx
Normal file
@@ -0,0 +1,483 @@
|
||||
---
|
||||
title: Preset.io Integration
|
||||
sidebar_position: 6
|
||||
version: 1
|
||||
---
|
||||
|
||||
# Preset.io Integration Guide
|
||||
|
||||
This document outlines integration points for the Preset.io team to extend the Superset MCP service with enterprise features, RBAC customizations, and OIDC integration.
|
||||
|
||||
## RBAC Extension Points
|
||||
|
||||
### Custom Authorization Factory
|
||||
|
||||
The MCP service supports custom authorization logic through the factory pattern:
|
||||
|
||||
```python
|
||||
# In preset_config.py or superset_config.py
|
||||
def create_preset_mcp_auth(app):
|
||||
"""Custom auth factory for Preset.io environments."""
|
||||
from superset.mcp_service.auth import create_auth_provider
|
||||
from preset.auth.mcp import PresetMCPAuthProvider
|
||||
|
||||
return PresetMCPAuthProvider(
|
||||
jwks_uri=app.config["PRESET_JWKS_URI"],
|
||||
issuer=app.config["PRESET_JWT_ISSUER"],
|
||||
audience=app.config["PRESET_JWT_AUDIENCE"],
|
||||
tenant_resolver=preset_tenant_resolver,
|
||||
rbac_manager=app.security_manager,
|
||||
)
|
||||
|
||||
MCP_AUTH_FACTORY = create_preset_mcp_auth
|
||||
```
|
||||
|
||||
### Multi-Tenant RBAC
|
||||
|
||||
Extend the base auth hook for tenant-aware permissions:
|
||||
|
||||
```python
|
||||
# preset/mcp/auth.py
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from functools import wraps
|
||||
|
||||
def preset_tenant_auth_hook(required_permissions=None):
|
||||
"""Preset-specific auth hook with tenant isolation."""
|
||||
def decorator(func):
|
||||
@wraps(func)
|
||||
@mcp_auth_hook(required_permissions)
|
||||
def wrapper(*args, **kwargs):
|
||||
# Extract tenant from JWT claims
|
||||
tenant_id = g.user.tenant_id if hasattr(g.user, 'tenant_id') else None
|
||||
|
||||
# Inject tenant context
|
||||
g.mcp_tenant_id = tenant_id
|
||||
g.mcp_tenant_context = get_tenant_context(tenant_id)
|
||||
|
||||
return func(*args, **kwargs)
|
||||
return wrapper
|
||||
return decorator
|
||||
```
|
||||
|
||||
### Custom Permission Scopes
|
||||
|
||||
Define Preset-specific permission scopes:
|
||||
|
||||
```python
|
||||
# preset/mcp/permissions.py
|
||||
PRESET_MCP_SCOPES = {
|
||||
# Tenant-level permissions
|
||||
"tenant:admin": "Full tenant administration",
|
||||
"tenant:read": "Read tenant resources",
|
||||
|
||||
# Workspace-level permissions
|
||||
"workspace:admin": "Full workspace administration",
|
||||
"workspace:read": "Read workspace resources",
|
||||
|
||||
# Enhanced dashboard permissions
|
||||
"dashboard:publish": "Publish dashboards to marketplace",
|
||||
"dashboard:embed": "Generate embed tokens",
|
||||
|
||||
# Enhanced chart permissions
|
||||
"chart:export": "Export chart data and configs",
|
||||
"chart:alerts": "Manage chart alerts and notifications",
|
||||
|
||||
# Dataset permissions with row-level security
|
||||
"dataset:rls": "Apply row-level security filters",
|
||||
"dataset:pii": "Access PII-flagged columns",
|
||||
}
|
||||
|
||||
def get_preset_required_scopes(tool_name: str, context: dict = None) -> List[str]:
|
||||
"""Map tool calls to Preset-specific permission requirements."""
|
||||
base_scopes = get_base_required_scopes(tool_name)
|
||||
|
||||
# Add tenant-aware scopes
|
||||
if context and context.get('tenant_id'):
|
||||
base_scopes.append(f"tenant:{context['tenant_id']}")
|
||||
|
||||
# Add workspace-aware scopes
|
||||
if context and context.get('workspace_id'):
|
||||
base_scopes.append(f"workspace:{context['workspace_id']}")
|
||||
|
||||
return base_scopes
|
||||
```
|
||||
|
||||
### Row-Level Security Integration
|
||||
|
||||
Extend data access tools with RLS:
|
||||
|
||||
```python
|
||||
# preset/mcp/rls.py
|
||||
def apply_preset_rls_filters(query_context: dict, user_context: dict) -> dict:
|
||||
"""Apply Preset row-level security filters to query context."""
|
||||
|
||||
# Get user's RLS rules from Preset metadata
|
||||
rls_rules = get_user_rls_rules(
|
||||
user_id=user_context['user_id'],
|
||||
tenant_id=user_context['tenant_id'],
|
||||
workspace_id=user_context.get('workspace_id')
|
||||
)
|
||||
|
||||
# Apply RLS filters to query
|
||||
for rule in rls_rules:
|
||||
if rule.applies_to_dataset(query_context['datasource']['id']):
|
||||
query_context = rule.apply_filter(query_context)
|
||||
|
||||
return query_context
|
||||
|
||||
# Usage in custom tools
|
||||
@mcp.tool
|
||||
@preset_tenant_auth_hook(['dataset:read', 'dataset:rls'])
|
||||
def preset_get_chart_data(request: GetChartDataRequest) -> ChartDataResponse:
|
||||
"""Get chart data with Preset RLS applied."""
|
||||
|
||||
# Apply RLS before executing query
|
||||
query_context = build_query_context(request)
|
||||
query_context = apply_preset_rls_filters(
|
||||
query_context,
|
||||
{'user_id': g.user.id, 'tenant_id': g.mcp_tenant_id}
|
||||
)
|
||||
|
||||
return execute_chart_data_query(query_context)
|
||||
```
|
||||
|
||||
## OIDC Integration Points
|
||||
|
||||
### Preset OIDC Provider
|
||||
|
||||
Custom OIDC integration for Preset environments:
|
||||
|
||||
```python
|
||||
# preset/mcp/oidc.py
|
||||
from superset.mcp_service.auth.providers.bearer import BearerAuthProvider
|
||||
import requests
|
||||
from typing import Dict, Any
|
||||
|
||||
class PresetOIDCAuthProvider(BearerAuthProvider):
|
||||
"""OIDC-specific auth provider for Preset.io."""
|
||||
|
||||
def __init__(self,
|
||||
oidc_discovery_url: str,
|
||||
client_id: str,
|
||||
client_secret: str = None,
|
||||
**kwargs):
|
||||
|
||||
# Discover OIDC endpoints
|
||||
self.discovery_doc = self._fetch_discovery_document(oidc_discovery_url)
|
||||
|
||||
super().__init__(
|
||||
jwks_uri=self.discovery_doc['jwks_uri'],
|
||||
issuer=self.discovery_doc['issuer'],
|
||||
**kwargs
|
||||
)
|
||||
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
|
||||
def _fetch_discovery_document(self, discovery_url: str) -> Dict[str, Any]:
|
||||
"""Fetch OIDC discovery document."""
|
||||
response = requests.get(discovery_url)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def validate_token(self, token: str) -> Dict[str, Any]:
|
||||
"""Validate JWT token with OIDC-specific claims."""
|
||||
claims = super().validate_token(token)
|
||||
|
||||
# Validate OIDC-specific claims
|
||||
if claims.get('aud') != self.client_id:
|
||||
raise ValueError("Invalid audience claim")
|
||||
|
||||
# Extract Preset-specific claims
|
||||
claims['preset_tenant_id'] = claims.get('tenant_id')
|
||||
claims['preset_workspace_id'] = claims.get('workspace_id')
|
||||
claims['preset_roles'] = claims.get('roles', [])
|
||||
|
||||
return claims
|
||||
|
||||
def resolve_user(self, claims: Dict[str, Any]) -> Any:
|
||||
"""Resolve Superset user from OIDC claims."""
|
||||
from preset.auth.user_resolver import resolve_preset_user
|
||||
|
||||
return resolve_preset_user(
|
||||
subject=claims['sub'],
|
||||
email=claims.get('email'),
|
||||
tenant_id=claims.get('preset_tenant_id'),
|
||||
roles=claims.get('preset_roles', [])
|
||||
)
|
||||
```
|
||||
|
||||
### Configuration for OIDC
|
||||
|
||||
```python
|
||||
# In preset_config.py
|
||||
def create_preset_oidc_auth(app):
|
||||
"""Factory for Preset OIDC authentication."""
|
||||
from preset.mcp.oidc import PresetOIDCAuthProvider
|
||||
|
||||
return PresetOIDCAuthProvider(
|
||||
oidc_discovery_url=app.config["PRESET_OIDC_DISCOVERY_URL"],
|
||||
client_id=app.config["PRESET_OIDC_CLIENT_ID"],
|
||||
client_secret=app.config["PRESET_OIDC_CLIENT_SECRET"],
|
||||
audience=app.config["PRESET_MCP_AUDIENCE"],
|
||||
required_scopes=app.config.get("PRESET_MCP_REQUIRED_SCOPES", [])
|
||||
)
|
||||
|
||||
# MCP Configuration
|
||||
MCP_AUTH_ENABLED = True
|
||||
MCP_AUTH_FACTORY = create_preset_oidc_auth
|
||||
|
||||
# OIDC Configuration
|
||||
PRESET_OIDC_DISCOVERY_URL = "https://auth.preset.io/.well-known/openid_configuration"
|
||||
PRESET_OIDC_CLIENT_ID = "preset-mcp-service"
|
||||
PRESET_OIDC_CLIENT_SECRET = os.environ.get("PRESET_OIDC_CLIENT_SECRET")
|
||||
PRESET_MCP_AUDIENCE = "preset-superset-mcp"
|
||||
PRESET_MCP_REQUIRED_SCOPES = [
|
||||
"openid", "profile", "email",
|
||||
"superset:read", "superset:write"
|
||||
]
|
||||
```
|
||||
|
||||
## Preset-Specific Tools
|
||||
|
||||
### Tenant Management Tools
|
||||
|
||||
```python
|
||||
# preset/mcp/tools/tenant.py
|
||||
@mcp.tool
|
||||
@preset_tenant_auth_hook(['tenant:read'])
|
||||
def get_tenant_info(request: GetTenantInfoRequest) -> TenantInfoResponse:
|
||||
"""Get Preset tenant information and quotas."""
|
||||
|
||||
tenant_id = g.mcp_tenant_id
|
||||
tenant = get_tenant_by_id(tenant_id)
|
||||
|
||||
return TenantInfoResponse(
|
||||
tenant_id=tenant.id,
|
||||
name=tenant.name,
|
||||
plan=tenant.plan,
|
||||
quotas=tenant.quotas,
|
||||
usage=get_tenant_usage(tenant_id),
|
||||
workspaces=list_tenant_workspaces(tenant_id)
|
||||
)
|
||||
|
||||
@mcp.tool
|
||||
@preset_tenant_auth_hook(['workspace:read'])
|
||||
def list_workspace_assets(request: ListWorkspaceAssetsRequest) -> ListWorkspaceAssetsResponse:
|
||||
"""List all assets in a Preset workspace."""
|
||||
|
||||
workspace_id = request.workspace_id
|
||||
tenant_id = g.mcp_tenant_id
|
||||
|
||||
# Validate workspace belongs to tenant
|
||||
validate_workspace_access(workspace_id, tenant_id)
|
||||
|
||||
assets = {
|
||||
'dashboards': list_workspace_dashboards(workspace_id),
|
||||
'charts': list_workspace_charts(workspace_id),
|
||||
'datasets': list_workspace_datasets(workspace_id)
|
||||
}
|
||||
|
||||
return ListWorkspaceAssetsResponse(
|
||||
workspace_id=workspace_id,
|
||||
assets=assets,
|
||||
total_count=sum(len(v) for v in assets.values())
|
||||
)
|
||||
```
|
||||
|
||||
### Embed Token Generation
|
||||
|
||||
```python
|
||||
# preset/mcp/tools/embed.py
|
||||
@mcp.tool
|
||||
@preset_tenant_auth_hook(['dashboard:embed'])
|
||||
def generate_embed_token(request: GenerateEmbedTokenRequest) -> EmbedTokenResponse:
|
||||
"""Generate secure embed token for dashboard/chart."""
|
||||
|
||||
# Validate resource access
|
||||
resource = validate_embed_resource_access(
|
||||
resource_type=request.resource_type,
|
||||
resource_id=request.resource_id,
|
||||
tenant_id=g.mcp_tenant_id
|
||||
)
|
||||
|
||||
# Generate signed embed token
|
||||
embed_token = create_embed_token(
|
||||
resource=resource,
|
||||
user_id=g.user.id,
|
||||
tenant_id=g.mcp_tenant_id,
|
||||
permissions=request.permissions,
|
||||
expiry=request.expiry_hours
|
||||
)
|
||||
|
||||
return EmbedTokenResponse(
|
||||
embed_token=embed_token,
|
||||
embed_url=f"{get_preset_base_url()}/embed/{embed_token}",
|
||||
expires_at=embed_token.expires_at
|
||||
)
|
||||
```
|
||||
|
||||
## Audit and Compliance Extensions
|
||||
|
||||
### Enhanced Audit Logging
|
||||
|
||||
```python
|
||||
# preset/mcp/audit.py
|
||||
from superset.mcp_service.auth import get_audit_context
|
||||
|
||||
def create_preset_audit_context(user_context: dict, tool_name: str,
|
||||
request_data: dict) -> dict:
|
||||
"""Create Preset-specific audit context."""
|
||||
|
||||
base_context = get_audit_context(user_context, tool_name, request_data)
|
||||
|
||||
# Add Preset-specific fields
|
||||
preset_context = {
|
||||
**base_context,
|
||||
'tenant_id': user_context.get('tenant_id'),
|
||||
'workspace_id': user_context.get('workspace_id'),
|
||||
'preset_user_role': user_context.get('preset_role'),
|
||||
'data_classification': classify_request_data(request_data),
|
||||
'compliance_flags': get_compliance_flags(tool_name, request_data)
|
||||
}
|
||||
|
||||
return preset_context
|
||||
|
||||
def log_preset_mcp_access(audit_context: dict):
|
||||
"""Log MCP access to Preset audit systems."""
|
||||
|
||||
# Log to Superset's audit system
|
||||
log_superset_audit_event(audit_context)
|
||||
|
||||
# Log to Preset's compliance system
|
||||
log_preset_compliance_event(audit_context)
|
||||
|
||||
# Log to external SIEM if configured
|
||||
if app.config.get('PRESET_SIEM_ENABLED'):
|
||||
log_to_siem(audit_context)
|
||||
```
|
||||
|
||||
### Data Classification
|
||||
|
||||
```python
|
||||
# preset/mcp/classification.py
|
||||
def classify_request_data(request_data: dict) -> dict:
|
||||
"""Classify data sensitivity in MCP requests."""
|
||||
|
||||
classification = {
|
||||
'contains_pii': False,
|
||||
'data_level': 'public',
|
||||
'retention_policy': 'standard'
|
||||
}
|
||||
|
||||
# Check for PII in request
|
||||
if contains_pii_fields(request_data):
|
||||
classification['contains_pii'] = True
|
||||
classification['data_level'] = 'restricted'
|
||||
classification['retention_policy'] = 'pii_compliant'
|
||||
|
||||
# Check for sensitive datasets
|
||||
if references_sensitive_datasets(request_data):
|
||||
classification['data_level'] = 'confidential'
|
||||
|
||||
return classification
|
||||
```
|
||||
|
||||
## Deployment Considerations
|
||||
|
||||
### Multi-Region Deployment
|
||||
|
||||
```python
|
||||
# preset/mcp/deployment.py
|
||||
def get_region_specific_config():
|
||||
"""Get region-specific MCP configuration."""
|
||||
|
||||
region = os.environ.get('PRESET_REGION', 'us-east-1')
|
||||
|
||||
config_map = {
|
||||
'us-east-1': {
|
||||
'jwks_uri': 'https://auth-us.preset.io/.well-known/jwks.json',
|
||||
'base_url': 'https://app.preset.io',
|
||||
'data_residency': 'US'
|
||||
},
|
||||
'eu-west-1': {
|
||||
'jwks_uri': 'https://auth-eu.preset.io/.well-known/jwks.json',
|
||||
'base_url': 'https://eu.preset.io',
|
||||
'data_residency': 'EU'
|
||||
}
|
||||
}
|
||||
|
||||
return config_map.get(region, config_map['us-east-1'])
|
||||
|
||||
# Usage in config
|
||||
region_config = get_region_specific_config()
|
||||
PRESET_JWKS_URI = region_config['jwks_uri']
|
||||
SUPERSET_WEBSERVER_ADDRESS = region_config['base_url']
|
||||
```
|
||||
|
||||
### Health Check Extensions
|
||||
|
||||
```python
|
||||
# preset/mcp/health.py
|
||||
@mcp.tool
|
||||
def preset_health_check() -> HealthCheckResponse:
|
||||
"""Preset-specific health check for MCP service."""
|
||||
|
||||
checks = {
|
||||
'mcp_service': check_mcp_service_health(),
|
||||
'database': check_database_health(),
|
||||
'auth_provider': check_auth_provider_health(),
|
||||
'tenant_isolation': check_tenant_isolation(),
|
||||
'rls_engine': check_rls_engine_health()
|
||||
}
|
||||
|
||||
overall_status = 'healthy' if all(
|
||||
check['status'] == 'healthy' for check in checks.values()
|
||||
) else 'degraded'
|
||||
|
||||
return HealthCheckResponse(
|
||||
status=overall_status,
|
||||
checks=checks,
|
||||
region=os.environ.get('PRESET_REGION'),
|
||||
version=get_preset_mcp_version()
|
||||
)
|
||||
```
|
||||
|
||||
## Configuration Templates
|
||||
|
||||
### Production Configuration
|
||||
|
||||
```python
|
||||
# preset_production_config.py
|
||||
from preset.mcp.auth import create_preset_oidc_auth
|
||||
from preset.mcp.audit import create_preset_audit_context
|
||||
|
||||
# MCP Service Configuration
|
||||
MCP_AUTH_ENABLED = True
|
||||
MCP_AUTH_FACTORY = create_preset_oidc_auth
|
||||
MCP_AUDIT_CONTEXT_FACTORY = create_preset_audit_context
|
||||
|
||||
# Preset OIDC Configuration
|
||||
PRESET_OIDC_DISCOVERY_URL = "https://auth.preset.io/.well-known/openid_configuration"
|
||||
PRESET_OIDC_CLIENT_ID = "preset-mcp-production"
|
||||
PRESET_MCP_AUDIENCE = "preset-superset-mcp"
|
||||
|
||||
# Security Configuration
|
||||
PRESET_MCP_REQUIRED_SCOPES = [
|
||||
"openid", "profile", "email",
|
||||
"tenant:read", "workspace:read",
|
||||
"dashboard:read", "chart:read", "dataset:read"
|
||||
]
|
||||
|
||||
# Audit Configuration
|
||||
PRESET_AUDIT_ENABLED = True
|
||||
PRESET_SIEM_ENABLED = True
|
||||
PRESET_COMPLIANCE_MODE = "SOC2"
|
||||
|
||||
# Performance Configuration
|
||||
PRESET_MCP_CACHE_ENABLED = True
|
||||
PRESET_MCP_RATE_LIMIT = "1000/hour"
|
||||
PRESET_MCP_TIMEOUT = 30
|
||||
```
|
||||
|
||||
This integration guide provides the Preset.io team with concrete extension points for implementing enterprise features while maintaining compatibility with the base MCP service architecture.
|
||||
@@ -87,6 +87,16 @@ const sidebars = {
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
type: 'category',
|
||||
label: 'MCP Service',
|
||||
items: [
|
||||
{
|
||||
type: 'autogenerated',
|
||||
dirName: 'mcp-service',
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
type: 'doc',
|
||||
label: 'FAQ',
|
||||
|
||||
@@ -133,6 +133,7 @@ solr = ["sqlalchemy-solr >= 0.2.0"]
|
||||
elasticsearch = ["elasticsearch-dbapi>=0.2.9, <0.3.0"]
|
||||
exasol = ["sqlalchemy-exasol >= 2.4.0, <3.0"]
|
||||
excel = ["xlrd>=1.2.0, <1.3"]
|
||||
fastmcp = ["fastmcp>=2.8.1"]
|
||||
firebird = ["sqlalchemy-firebird>=0.7.0, <0.8"]
|
||||
firebolt = ["firebolt-sqlalchemy>=1.0.0, <2"]
|
||||
gevent = ["gevent>=23.9.1"]
|
||||
@@ -202,6 +203,7 @@ development = [
|
||||
"pyinstrument>=4.0.2,<5",
|
||||
"pylint",
|
||||
"pytest<8.0.0", # hairy issue with pytest >=8 where current_app proxies are not set in time
|
||||
"pytest-asyncio", # need this due to not using latest pytest
|
||||
"pytest-cov",
|
||||
"pytest-mock",
|
||||
"python-ldap>=3.4.4",
|
||||
|
||||
@@ -16,4 +16,4 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
#
|
||||
-e .[development,bigquery,druid,gevent,gsheets,mysql,postgres,presto,prophet,trino,thumbnails]
|
||||
-e .[development,bigquery,druid,fastmcp,gevent,gsheets,mysql,postgres,presto,prophet,trino,thumbnails]
|
||||
|
||||
@@ -10,6 +10,14 @@ amqp==5.3.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# kombu
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anyio==4.9.0
|
||||
# via
|
||||
# httpx
|
||||
# mcp
|
||||
# sse-starlette
|
||||
# starlette
|
||||
apispec==6.6.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -24,11 +32,14 @@ attrs==25.3.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# cattrs
|
||||
# cyclopts
|
||||
# jsonschema
|
||||
# outcome
|
||||
# referencing
|
||||
# requests-cache
|
||||
# trio
|
||||
authlib==1.6.1
|
||||
# via fastmcp
|
||||
babel==2.17.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -77,6 +88,8 @@ celery==5.5.2
|
||||
certifi==2025.6.15
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
# selenium
|
||||
cffi==1.17.1
|
||||
@@ -101,6 +114,7 @@ click==8.2.1
|
||||
# click-repl
|
||||
# flask
|
||||
# flask-appbuilder
|
||||
# uvicorn
|
||||
click-didyoumean==0.3.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -140,10 +154,13 @@ cryptography==44.0.3
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
# authlib
|
||||
# paramiko
|
||||
# pyopenssl
|
||||
cycler==0.12.1
|
||||
# via matplotlib
|
||||
cyclopts==3.22.2
|
||||
# via fastmcp
|
||||
db-dtypes==1.3.1
|
||||
# via pandas-gbq
|
||||
defusedxml==0.7.1
|
||||
@@ -168,14 +185,23 @@ dnspython==2.7.0
|
||||
# email-validator
|
||||
docker==7.0.0
|
||||
# via apache-superset
|
||||
docstring-parser==0.17.0
|
||||
# via cyclopts
|
||||
docutils==0.21.2
|
||||
# via rich-rst
|
||||
email-validator==2.2.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# flask-appbuilder
|
||||
# pydantic
|
||||
et-xmlfile==2.0.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# openpyxl
|
||||
exceptiongroup==1.3.0
|
||||
# via fastmcp
|
||||
fastmcp==2.10.6
|
||||
# via apache-superset
|
||||
filelock==3.12.2
|
||||
# via virtualenv
|
||||
flask==2.3.3
|
||||
@@ -327,6 +353,8 @@ gunicorn==23.0.0
|
||||
h11==0.16.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# httpcore
|
||||
# uvicorn
|
||||
# wsproto
|
||||
hashids==1.3.1
|
||||
# via
|
||||
@@ -337,6 +365,14 @@ holidays==0.25
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
# prophet
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# fastmcp
|
||||
# mcp
|
||||
httpx-sse==0.4.1
|
||||
# via mcp
|
||||
humanize==4.12.3
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -346,7 +382,9 @@ identify==2.5.36
|
||||
idna==3.10
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# anyio
|
||||
# email-validator
|
||||
# httpx
|
||||
# requests
|
||||
# trio
|
||||
# url-normalize
|
||||
@@ -378,6 +416,7 @@ jsonschema==4.23.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# flask-appbuilder
|
||||
# mcp
|
||||
# openapi-schema-validator
|
||||
# openapi-spec-validator
|
||||
jsonschema-path==0.3.4
|
||||
@@ -437,6 +476,8 @@ matplotlib==3.9.0
|
||||
# via prophet
|
||||
mccabe==0.7.0
|
||||
# via pylint
|
||||
mcp==1.12.0
|
||||
# via fastmcp
|
||||
mdurl==0.1.2
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -475,6 +516,8 @@ odfpy==1.4.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# pandas
|
||||
openapi-pydantic==0.5.1
|
||||
# via fastmcp
|
||||
openapi-schema-validator==0.6.3
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -607,6 +650,16 @@ pycparser==2.22
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# cffi
|
||||
pydantic==2.11.7
|
||||
# via
|
||||
# fastmcp
|
||||
# mcp
|
||||
# openapi-pydantic
|
||||
# pydantic-settings
|
||||
pydantic-core==2.33.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.10.1
|
||||
# via mcp
|
||||
pydata-google-auth==1.9.0
|
||||
# via pandas-gbq
|
||||
pydruid==0.6.9
|
||||
@@ -642,6 +695,8 @@ pyparsing==3.2.3
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
# matplotlib
|
||||
pyperclip==1.9.0
|
||||
# via fastmcp
|
||||
pysocks==1.7.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -649,8 +704,11 @@ pysocks==1.7.1
|
||||
pytest==7.4.4
|
||||
# via
|
||||
# apache-superset
|
||||
# pytest-asyncio
|
||||
# pytest-cov
|
||||
# pytest-mock
|
||||
pytest-asyncio==0.23.8
|
||||
# via apache-superset
|
||||
pytest-cov==6.0.0
|
||||
# via apache-superset
|
||||
pytest-mock==3.10.0
|
||||
@@ -674,12 +732,16 @@ python-dotenv==1.1.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
# fastmcp
|
||||
# pydantic-settings
|
||||
python-geohash==0.8.5
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
python-ldap==3.4.4
|
||||
# via apache-superset
|
||||
python-multipart==0.0.20
|
||||
# via mcp
|
||||
pytz==2025.2
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -734,7 +796,12 @@ rfc3339-validator==0.1.4
|
||||
rich==13.9.4
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# cyclopts
|
||||
# fastmcp
|
||||
# flask-limiter
|
||||
# rich-rst
|
||||
rich-rst==1.3.1
|
||||
# via cyclopts
|
||||
rpds-py==0.25.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -779,6 +846,7 @@ slack-sdk==3.35.0
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# anyio
|
||||
# trio
|
||||
sortedcontainers==2.4.0
|
||||
# via
|
||||
@@ -808,10 +876,14 @@ sqlglot==27.3.0
|
||||
# apache-superset
|
||||
sqloxide==0.1.51
|
||||
# via apache-superset
|
||||
sse-starlette==2.4.1
|
||||
# via mcp
|
||||
sshtunnel==0.4.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# apache-superset
|
||||
starlette==0.47.2
|
||||
# via mcp
|
||||
statsd==4.0.1
|
||||
# via apache-superset
|
||||
tabulate==0.9.0
|
||||
@@ -839,13 +911,23 @@ typing-extensions==4.14.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
# alembic
|
||||
# anyio
|
||||
# apache-superset
|
||||
# cattrs
|
||||
# exceptiongroup
|
||||
# limits
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pyopenssl
|
||||
# referencing
|
||||
# selenium
|
||||
# shillelagh
|
||||
# starlette
|
||||
# typing-inspection
|
||||
typing-inspection==0.4.1
|
||||
# via
|
||||
# pydantic
|
||||
# pydantic-settings
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
@@ -864,6 +946,8 @@ urllib3==2.5.0
|
||||
# requests
|
||||
# requests-cache
|
||||
# selenium
|
||||
uvicorn==0.35.0
|
||||
# via mcp
|
||||
vine==5.1.0
|
||||
# via
|
||||
# -c requirements/base.txt
|
||||
|
||||
43
superset/cli/mcp.py
Normal file
43
superset/cli/mcp.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""CLI module for MCP service"""
|
||||
|
||||
import os
|
||||
|
||||
import click
|
||||
|
||||
from superset.mcp_service.server import run_server
|
||||
|
||||
|
||||
@click.group()
|
||||
def mcp() -> None:
|
||||
"""Model Context Protocol service commands"""
|
||||
pass
|
||||
|
||||
|
||||
@mcp.command()
|
||||
@click.option("--host", default="127.0.0.1", help="Host to bind to")
|
||||
@click.option("--port", default=5008, help="Port to bind to")
|
||||
@click.option("--debug", is_flag=True, help="Enable debug mode")
|
||||
@click.option("--sql-debug", is_flag=True, help="Enable SQL query logging")
|
||||
def run(host: str, port: int, debug: bool, sql_debug: bool) -> None:
|
||||
"""Run the MCP service"""
|
||||
if sql_debug:
|
||||
os.environ["SQLALCHEMY_DEBUG"] = "1"
|
||||
click.echo("🔍 SQL Debug mode enabled")
|
||||
|
||||
run_server(host=host, port=port, debug=debug)
|
||||
@@ -37,6 +37,10 @@ class CreateFormDataCommand(BaseCommand):
|
||||
def __init__(self, cmd_params: CommandParameters):
|
||||
self._cmd_params = cmd_params
|
||||
|
||||
def _get_session_id(self) -> str:
|
||||
"""Get session ID. Can be overridden in subclasses."""
|
||||
return session.get("_id")
|
||||
|
||||
def run(self) -> str:
|
||||
self.validate()
|
||||
try:
|
||||
@@ -47,7 +51,7 @@ class CreateFormDataCommand(BaseCommand):
|
||||
form_data = self._cmd_params.form_data
|
||||
check_access(datasource_id, chart_id, datasource_type)
|
||||
contextual_key = cache_key(
|
||||
session.get("_id"), tab_id, datasource_id, chart_id, datasource_type
|
||||
self._get_session_id(), tab_id, datasource_id, chart_id, datasource_type
|
||||
)
|
||||
key = cache_manager.explore_form_data_cache.get(contextual_key)
|
||||
if not key or not tab_id:
|
||||
|
||||
@@ -16,18 +16,87 @@
|
||||
# under the License.
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, get_args, TypeVar
|
||||
import logging
|
||||
import uuid as uuid_lib
|
||||
from enum import Enum
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
Generic,
|
||||
get_args,
|
||||
List,
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
TypeVar,
|
||||
)
|
||||
|
||||
from flask_appbuilder.models.filters import BaseFilter
|
||||
from flask_appbuilder.models.sqla import Model
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import asc, cast, desc, or_, Text
|
||||
from sqlalchemy.exc import StatementError
|
||||
from sqlalchemy.inspection import inspect
|
||||
from sqlalchemy.orm import ColumnProperty, joinedload, RelationshipProperty
|
||||
|
||||
from superset.extensions import db
|
||||
|
||||
T = TypeVar("T", bound=Model)
|
||||
|
||||
|
||||
class ColumnOperatorEnum(str, Enum):
|
||||
eq = "eq"
|
||||
ne = "ne"
|
||||
sw = "sw"
|
||||
ew = "ew"
|
||||
in_ = "in"
|
||||
nin = "nin"
|
||||
gt = "gt"
|
||||
gte = "gte"
|
||||
lt = "lt"
|
||||
lte = "lte"
|
||||
like = "like"
|
||||
ilike = "ilike"
|
||||
is_null = "is_null"
|
||||
is_not_null = "is_not_null"
|
||||
|
||||
@classmethod
|
||||
def operator_map(cls) -> Dict[ColumnOperatorEnum, Any]:
|
||||
return {
|
||||
cls.eq: lambda col, val: col == val,
|
||||
cls.ne: lambda col, val: col != val,
|
||||
cls.sw: lambda col, val: col.like(f"{val}%"),
|
||||
cls.ew: lambda col, val: col.like(f"%{val}"),
|
||||
cls.in_: lambda col, val: col.in_(
|
||||
val if isinstance(val, (list, tuple)) else [val]
|
||||
),
|
||||
cls.nin: lambda col, val: ~col.in_(
|
||||
val if isinstance(val, (list, tuple)) else [val]
|
||||
),
|
||||
cls.gt: lambda col, val: col > val,
|
||||
cls.gte: lambda col, val: col >= val,
|
||||
cls.lt: lambda col, val: col < val,
|
||||
cls.lte: lambda col, val: col <= val,
|
||||
cls.like: lambda col, val: col.like(f"%{val}%"),
|
||||
cls.ilike: lambda col, val: col.ilike(f"%{val}%"),
|
||||
cls.is_null: lambda col, _: col.is_(None),
|
||||
cls.is_not_null: lambda col, _: col.isnot(None),
|
||||
}
|
||||
|
||||
def apply(self, column: Any, value: Any) -> Any:
|
||||
op_func = self.operator_map().get(self)
|
||||
if not op_func:
|
||||
raise ValueError(f"Unsupported operator: {self}")
|
||||
return op_func(column, value)
|
||||
|
||||
|
||||
class ColumnOperator(BaseModel):
|
||||
col: str = Field(..., description="Column name to filter on")
|
||||
opr: ColumnOperatorEnum = Field(..., description="Operator")
|
||||
value: Any = Field(None, description="Value for the filter")
|
||||
|
||||
|
||||
class BaseDAO(Generic[T]):
|
||||
"""
|
||||
Base DAO, implement base CRUD sqlalchemy operations
|
||||
@@ -50,45 +119,128 @@ class BaseDAO(Generic[T]):
|
||||
)[0]
|
||||
|
||||
@classmethod
|
||||
def find_by_id(
|
||||
cls,
|
||||
model_id: str | int,
|
||||
skip_base_filter: bool = False,
|
||||
) -> T | None:
|
||||
def _apply_base_filter(
|
||||
cls, query: Any, skip_base_filter: bool = False, data_model: Any = None
|
||||
) -> Any:
|
||||
"""
|
||||
Find a model by id, if defined applies `base_filter`
|
||||
Apply the base_filter to the query if it exists and skip_base_filter is False.
|
||||
"""
|
||||
query = db.session.query(cls.model_cls)
|
||||
if cls.base_filter and not skip_base_filter:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
if data_model is None:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
cls.id_column_name, data_model
|
||||
).apply(query, None)
|
||||
id_column = getattr(cls.model_cls, cls.id_column_name)
|
||||
return query
|
||||
|
||||
@classmethod
|
||||
def _convert_value_for_column(cls, column: Any, value: Any) -> Any:
|
||||
"""
|
||||
Convert a value to the appropriate type for a given SQLAlchemy column.
|
||||
|
||||
Args:
|
||||
column: SQLAlchemy column object
|
||||
value: Value to convert
|
||||
|
||||
Returns:
|
||||
Converted value or None if conversion fails
|
||||
"""
|
||||
if (
|
||||
hasattr(column.type, "python_type")
|
||||
and column.type.python_type == uuid_lib.UUID
|
||||
):
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
return uuid_lib.UUID(value)
|
||||
except (ValueError, AttributeError):
|
||||
return None
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
def _find_by_column(
|
||||
cls,
|
||||
column_name: str,
|
||||
value: str | int,
|
||||
skip_base_filter: bool = False,
|
||||
) -> T | None:
|
||||
"""
|
||||
Private method to find a model by any column value.
|
||||
|
||||
Args:
|
||||
column_name: Name of the column to search by
|
||||
value: Value to search for
|
||||
skip_base_filter: Whether to skip base filtering
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
"""
|
||||
query = db.session.query(cls.model_cls)
|
||||
query = cls._apply_base_filter(query, skip_base_filter)
|
||||
|
||||
if not hasattr(cls.model_cls, column_name):
|
||||
return None
|
||||
|
||||
column = getattr(cls.model_cls, column_name)
|
||||
converted_value = cls._convert_value_for_column(column, value)
|
||||
if converted_value is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
return query.filter(id_column == model_id).one_or_none()
|
||||
return query.filter(column == converted_value).one_or_none()
|
||||
except StatementError:
|
||||
# can happen if int is passed instead of a string or similar
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def find_by_id(
|
||||
cls,
|
||||
model_id: str | int,
|
||||
skip_base_filter: bool = False,
|
||||
id_column: str | None = None,
|
||||
) -> T | None:
|
||||
"""
|
||||
Find a model by ID using specified or default ID column.
|
||||
|
||||
Args:
|
||||
model_id: ID value to search for
|
||||
skip_base_filter: Whether to skip base filtering
|
||||
id_column: Column name to use (defaults to cls.id_column_name)
|
||||
|
||||
Returns:
|
||||
Model instance or None if not found
|
||||
"""
|
||||
column = id_column or cls.id_column_name
|
||||
return cls._find_by_column(column, model_id, skip_base_filter)
|
||||
|
||||
@classmethod
|
||||
def find_by_ids(
|
||||
cls,
|
||||
model_ids: list[str] | list[int],
|
||||
model_ids: Sequence[str | int],
|
||||
skip_base_filter: bool = False,
|
||||
id_column: str | None = None,
|
||||
) -> list[T]:
|
||||
"""
|
||||
Find a List of models by a list of ids, if defined applies `base_filter`
|
||||
|
||||
:param model_ids: List of IDs to find
|
||||
:param skip_base_filter: If true, skip applying the base filter
|
||||
:param id_column: Optional column name to use for ID lookup
|
||||
(defaults to id_column_name)
|
||||
"""
|
||||
id_col = getattr(cls.model_cls, cls.id_column_name, None)
|
||||
column = id_column or cls.id_column_name
|
||||
id_col = getattr(cls.model_cls, column, None)
|
||||
if id_col is None:
|
||||
return []
|
||||
query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids))
|
||||
if cls.base_filter and not skip_base_filter:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
cls.id_column_name, data_model
|
||||
).apply(query, None)
|
||||
|
||||
# Convert IDs to appropriate types based on column type
|
||||
converted_ids: list[str | int | uuid_lib.UUID] = []
|
||||
for id_val in model_ids:
|
||||
converted_value = cls._convert_value_for_column(id_col, id_val)
|
||||
if converted_value is not None:
|
||||
converted_ids.append(converted_value)
|
||||
|
||||
query = db.session.query(cls.model_cls).filter(id_col.in_(converted_ids))
|
||||
query = cls._apply_base_filter(query, skip_base_filter)
|
||||
return query.all()
|
||||
|
||||
@classmethod
|
||||
@@ -97,11 +249,7 @@ class BaseDAO(Generic[T]):
|
||||
Get all that fit the `base_filter`
|
||||
"""
|
||||
query = db.session.query(cls.model_cls)
|
||||
if cls.base_filter:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
cls.id_column_name, data_model
|
||||
).apply(query, None)
|
||||
query = cls._apply_base_filter(query)
|
||||
return query.all()
|
||||
|
||||
@classmethod
|
||||
@@ -110,11 +258,7 @@ class BaseDAO(Generic[T]):
|
||||
Get the first that fit the `base_filter`
|
||||
"""
|
||||
query = db.session.query(cls.model_cls)
|
||||
if cls.base_filter:
|
||||
data_model = SQLAInterface(cls.model_cls, db.session)
|
||||
query = cls.base_filter( # pylint: disable=not-callable
|
||||
cls.id_column_name, data_model
|
||||
).apply(query, None)
|
||||
query = cls._apply_base_filter(query)
|
||||
return query.filter_by(**filter_by).one_or_none()
|
||||
|
||||
@classmethod
|
||||
@@ -184,3 +328,247 @@ class BaseDAO(Generic[T]):
|
||||
|
||||
for item in items:
|
||||
db.session.delete(item)
|
||||
|
||||
@classmethod
|
||||
def apply_column_operators(
|
||||
cls, query: Any, column_operators: Optional[List[ColumnOperator]] = None
|
||||
) -> Any:
|
||||
"""
|
||||
Apply column operators (list of ColumnOperator) to the query using
|
||||
ColumnOperatorEnum logic. Raises ValueError if a filter references a
|
||||
non-existent column.
|
||||
"""
|
||||
if not column_operators:
|
||||
return query
|
||||
for c in column_operators:
|
||||
if not isinstance(c, ColumnOperator):
|
||||
continue
|
||||
col = c.col
|
||||
opr = c.opr
|
||||
value = c.value
|
||||
if not col or not hasattr(cls.model_cls, col):
|
||||
model_name = cls.model_cls.__name__ if cls.model_cls else "Unknown"
|
||||
logging.error(
|
||||
f"Invalid filter: column '{col}' does not exist on {model_name}"
|
||||
)
|
||||
raise ValueError(
|
||||
f"Invalid filter: column '{col}' does not exist on {model_name}"
|
||||
)
|
||||
column = getattr(cls.model_cls, col)
|
||||
try:
|
||||
# Always use ColumnOperatorEnum's apply method
|
||||
operator_enum = ColumnOperatorEnum(opr)
|
||||
query = query.filter(operator_enum.apply(column, value))
|
||||
except Exception as e:
|
||||
logging.error(f"Error applying filter on column '{col}': {e}")
|
||||
raise
|
||||
return query
|
||||
|
||||
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """
    Return a mapping of filterable column names to supported operators.

    Covers model columns, hybrid properties (exposed with string
    semantics), and any DAO-specific custom fields. Used by MCP tools to
    dynamically expose filter options. Custom fields supported by the DAO
    but not present on the model should be documented here.
    """
    import sqlalchemy as sa
    from sqlalchemy.ext.hybrid import hybrid_property

    mapper = inspect(cls.model_cls)
    columns = {c.key: c for c in mapper.columns}
    # Hybrid properties declared directly on the model class.
    # NOTE(review): vars() does not see hybrids inherited from mixins —
    # confirm whether that is intentional.
    hybrids = {
        name: attr
        for name, attr in vars(cls.model_cls).items()
        if isinstance(attr, hybrid_property)
    }
    # Subclasses may add custom fields, e.g.:
    # custom_fields = {"tags": ["eq", "in_", "like"], ...}
    custom_fields: Dict[str, List[str]] = {}

    null_ops = ["is_null", "is_not_null"]
    comparison_ops = ["eq", "ne", "gt", "gte", "lt", "lte", "in_", "nin"]
    # Map broad SQLAlchemy type families to their supported operators.
    type_operator_map = {
        "string": [
            "eq",
            "ne",
            "sw",
            "ew",
            "in_",
            "nin",
            "like",
            "ilike",
            *null_ops,
        ],
        "boolean": ["eq", "ne", *null_ops],
        "number": [*comparison_ops, *null_ops],
        "datetime": [*comparison_ops, *null_ops],
    }

    filterable = {}
    for name, col in columns.items():
        if isinstance(col.type, (sa.String, sa.Text)):
            ops = type_operator_map["string"]
        elif isinstance(col.type, (sa.Boolean,)):
            ops = type_operator_map["boolean"]
        elif isinstance(col.type, (sa.Integer, sa.Float, sa.Numeric)):
            ops = type_operator_map["number"]
        elif isinstance(col.type, (sa.DateTime, sa.Date, sa.Time)):
            ops = type_operator_map["datetime"]
        else:
            # Fallback for unrecognised column types: equality and
            # null checks only.
            ops = ["eq", "ne", *null_ops]
        # Copy per key: the original aliased one shared list across many
        # entries, so a caller mutating one entry corrupted the others.
        filterable[name] = list(ops)
    # Hybrid properties are exposed as string fields by default.
    for name in hybrids:
        filterable[name] = list(type_operator_map["string"])
    # Add custom fields
    filterable.update(custom_fields)
    return filterable
|
||||
|
||||
@classmethod
def _build_query(
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    search: Optional[str] = None,
    search_columns: Optional[List[str]] = None,
    custom_filters: Optional[Dict[str, BaseFilter]] = None,
    skip_base_filter: bool = False,
    data_model: Optional[SQLAInterface] = None,
) -> Any:
    """
    Assemble a SQLAlchemy query combining the DAO base filter, free-text
    search, custom filter classes, and column operators.
    """
    model = cls.model_cls
    if data_model is None:
        data_model = SQLAInterface(model, db.session)
    query = data_model.session.query(model)
    query = cls._apply_base_filter(
        query, skip_base_filter=skip_base_filter, data_model=data_model
    )
    if search and search_columns:
        # Case-insensitive substring match across every searchable column
        # that actually exists on the model.
        clauses = [
            cast(getattr(model, name), Text).ilike(f"%{search}%")
            for name in search_columns
            if hasattr(model, name)
        ]
        if clauses:
            query = query.filter(or_(*clauses))
    if custom_filters:
        for custom_filter in custom_filters.values():
            query = custom_filter.apply(query, None)
    if column_operators:
        query = cls.apply_column_operators(query, column_operators)
    return query
|
||||
|
||||
@classmethod
def list(  # noqa: C901
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    order_column: str = "changed_on",
    order_direction: str = "desc",
    page: int = 0,
    page_size: int = 100,
    search: Optional[str] = None,
    search_columns: Optional[List[str]] = None,
    custom_filters: Optional[Dict[str, BaseFilter]] = None,
    columns: Optional[List[str]] = None,
) -> Tuple[List[Any], int]:
    """
    Generic list method for filtered, sorted, and paginated results.

    If ``columns`` is specified, returns a list of Row objects (one per
    row), otherwise returns model instances.

    :param column_operators: structured filters applied via
        ``apply_column_operators``.
    :param order_column: model attribute to sort by (ignored if absent).
    :param order_direction: ``"asc"`` or ``"desc"``.
    :param page: zero-based page index; negative values are treated as 0.
    :param page_size: rows per page, clamped to a minimum of 1.
    :param search: free-text term matched against ``search_columns``.
    :param search_columns: model columns searched with ILIKE.
    :param custom_filters: extra FAB ``BaseFilter`` instances to apply.
    :param columns: column/relationship names to load selectively.
    :returns: tuple of (items, total count before pagination).
    """
    data_model = SQLAInterface(cls.model_cls, db.session)

    # Split the requested names into plain columns vs. relationships;
    # anything else (plain properties, unknown names) is ignored.
    column_attrs = []
    relationship_loads = []
    if columns is None:
        columns = []
    for name in columns:
        attr = getattr(cls.model_cls, name, None)
        if attr is None:
            continue
        prop = getattr(attr, "property", None)
        if isinstance(prop, ColumnProperty):
            column_attrs.append(attr)
        elif isinstance(prop, RelationshipProperty):
            relationship_loads.append(joinedload(attr))
        # Ignore properties and other non-queryable attributes

    if relationship_loads:
        # If any relationships are requested, query the full model and
        # joinedload the relationships.
        query = data_model.session.query(cls.model_cls)
        for loader in relationship_loads:
            query = query.options(loader)
    elif column_attrs:
        # Only plain columns requested: project just those columns.
        query = data_model.session.query(*column_attrs)
    else:
        # Fallback: query the full model
        query = data_model.session.query(cls.model_cls)
    query = cls._apply_base_filter(query, data_model=data_model)
    if search and search_columns:
        search_filters = []
        for column_name in search_columns:
            if hasattr(cls.model_cls, column_name):
                column = getattr(cls.model_cls, column_name)
                search_filters.append(cast(column, Text).ilike(f"%{search}%"))
        if search_filters:
            query = query.filter(or_(*search_filters))
    if custom_filters:
        for filter_class in custom_filters.values():
            query = filter_class.apply(query, None)
    if column_operators:
        query = cls.apply_column_operators(query, column_operators)
    # Count before pagination so callers can compute total pages.
    total_count = query.count()
    if hasattr(cls.model_cls, order_column):
        column = getattr(cls.model_cls, order_column)
        if order_direction.lower() == "desc":
            query = query.order_by(desc(column))
        else:
            query = query.order_by(asc(column))
    # Clamp paging inputs. The original `page = page` was a no-op that let
    # a negative page produce a negative OFFSET, which databases reject.
    page = max(page, 0)
    page_size = max(page_size, 1)
    query = query.offset(page * page_size).limit(page_size)
    items = query.all()
    # If columns are specified, SQLAlchemy returns Row objects (not tuples
    # or model instances)
    return items, total_count
|
||||
|
||||
@classmethod
def count(
    cls,
    column_operators: Optional[List[ColumnOperator]] = None,
    skip_base_filter: bool = False,
) -> int:
    """
    Return the number of records for the model, optionally restricted by
    column operators and/or with the base filter skipped.
    """
    filtered_query = cls._build_query(
        column_operators=column_operators,
        skip_base_filter=skip_base_filter,
    )
    return filtered_query.count()
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Dict, List, TYPE_CHECKING
|
||||
|
||||
from superset.charts.filters import ChartFilter
|
||||
from superset.daos.base import BaseDAO
|
||||
@@ -36,6 +36,20 @@ logger = logging.getLogger(__name__)
|
||||
class ChartDAO(BaseDAO[Slice]):
|
||||
base_filter = ChartFilter
|
||||
|
||||
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Extend the base filterable columns with chart-specific fields."""
    filterable = super().get_filterable_columns_and_operators()
    # Custom chart fields not present as plain model columns.
    chart_fields: Dict[str, List[str]] = {
        "tags": ["eq", "in_", "like"],
        "owner": ["eq", "in_"],
        "viz_type": ["eq", "in_", "like"],
        "datasource_name": ["eq", "in_", "like"],
    }
    filterable.update(chart_fields)
    return filterable
|
||||
|
||||
@staticmethod
|
||||
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
|
||||
ids = [chart.id for chart in charts]
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import g
|
||||
from flask_appbuilder.models.sqla.interface import SQLAInterface
|
||||
@@ -48,6 +48,20 @@ logger = logging.getLogger(__name__)
|
||||
class DashboardDAO(BaseDAO[Dashboard]):
|
||||
base_filter = DashboardAccessFilter
|
||||
|
||||
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Extend the base filterable columns with dashboard-specific fields."""
    filterable = super().get_filterable_columns_and_operators()
    # Custom dashboard fields not present as plain model columns.
    dashboard_fields: Dict[str, List[str]] = {
        "tags": ["eq", "in_", "like"],
        "owner": ["eq", "in_"],
        "published": ["eq"],
        "favorite": ["eq"],
    }
    filterable.update(dashboard_fields)
    return filterable
|
||||
|
||||
@classmethod
|
||||
def get_by_id_or_slug(cls, id_or_slug: int | str) -> Dashboard:
|
||||
if is_uuid(id_or_slug):
|
||||
|
||||
@@ -18,7 +18,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import dateutil.parser
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
@@ -37,6 +37,13 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetDAO(BaseDAO[SqlaTable]):
|
||||
"""
|
||||
DAO for datasets. Supports filtering on model fields, hybrid properties, and custom
|
||||
fields:
|
||||
- tags: list of tags (eq, in_, like)
|
||||
- owner: user id (eq, in_)
|
||||
"""
|
||||
|
||||
base_filter = DatasourceFilter
|
||||
|
||||
@staticmethod
|
||||
@@ -351,6 +358,18 @@ class DatasetDAO(BaseDAO[SqlaTable]):
|
||||
.one_or_none()
|
||||
)
|
||||
|
||||
@classmethod
def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
    """Extend the base filterable columns with dataset-specific fields."""
    filterable = super().get_filterable_columns_and_operators()
    # Custom dataset fields not present as plain model columns.
    dataset_fields: Dict[str, List[str]] = {
        "tags": ["eq", "in_", "like"],
        "owner": ["eq", "in_"],
    }
    filterable.update(dataset_fields)
    return filterable
|
||||
|
||||
|
||||
class DatasetColumnDAO(BaseDAO[TableColumn]):
    """DAO for individual dataset columns; uses the default BaseDAO behavior."""
|
||||
|
||||
126
superset/mcp_service/CHART_GENERATION_TODO.md
Normal file
126
superset/mcp_service/CHART_GENERATION_TODO.md
Normal file
@@ -0,0 +1,126 @@
|
||||
# Chart Generation Improvement Plan
|
||||
|
||||
Based on user feedback from testing the chart generation API, this document tracks improvements needed for the MCP service chart generation functionality.
|
||||
|
||||
## Status: Active Development
|
||||
|
||||
### 1. **Fix ASCII Preview Rendering** 🟡 In Progress
|
||||
**Status**: Partially Complete
|
||||
**Issue**: ASCII previews show "Range: nan to nan" for time series data
|
||||
**Tasks:**
|
||||
- [x] Fix ASCII chart rendering for time series/datetime data
|
||||
- [x] Add proper NaN/null value handling in ASCII generation
|
||||
- [x] Implement fallback messages when data can't be visualized
|
||||
- [ ] Add unit tests for edge cases (empty data, NaN values, date formats)
|
||||
|
||||
### 2. **Enhance Error Messages with Context** 🟢 Complete
|
||||
**Status**: Complete
|
||||
**Issue**: Generic error messages without helpful context
|
||||
**Tasks:**
|
||||
- [x] Create detailed error response schema with:
|
||||
- [x] Invalid field name
|
||||
- [x] Available columns list
|
||||
- [x] Fuzzy matching suggestions for typos
|
||||
- [x] Data type mismatches
|
||||
- [x] Implement column validation with helpful error messages
|
||||
- [x] Add dataset schema introspection for better error context
|
||||
|
||||
### 3. **Fix Table Chart Aggregation** 🟢 Complete
|
||||
**Status**: Complete
|
||||
**Issue**: Table previews truncate headers, unexpected aggregation behavior
|
||||
**Tasks:**
|
||||
- [x] Fix column header truncation in table previews
|
||||
- [x] Clarify GROUP BY behavior for non-aggregated columns
|
||||
- [x] Improve table formatting with proper column width calculation
|
||||
- [x] Add option to control grouping behavior explicitly
|
||||
- [x] Document expected table aggregation behavior
|
||||
|
||||
### 4. **Fix Preview Generation Consistency** 🟢 Complete
|
||||
**Status**: Complete
|
||||
**Issue**: Previews not generated when `save_chart=false`
|
||||
**Tasks:**
|
||||
- [x] Ensure preview generation works regardless of save_chart flag
|
||||
- [x] Fix the logic flow to generate previews before/after save
|
||||
- [x] Add preview generation from form data for unsaved charts
|
||||
- [x] Remove base64 preview support (never return base64)
|
||||
- [ ] Add integration tests for all preview generation scenarios
|
||||
- [ ] Validate preview_formats parameter is respected
|
||||
|
||||
### 5. **Implement Rich Performance Analytics** 🟢 Medium Priority
|
||||
**Status**: Not Started
|
||||
**Issue**: Generic performance feedback
|
||||
**Tasks:**
|
||||
- [ ] Add query analysis with specific optimization suggestions:
|
||||
- [ ] Index recommendations based on filter columns
|
||||
- [ ] Partitioning suggestions for large datasets
|
||||
- [ ] Caching recommendations with specific TTL values
|
||||
- [ ] Include metrics:
|
||||
- [ ] Rows processed
|
||||
- [ ] Bytes scanned
|
||||
- [ ] Execution plan hints
|
||||
- [ ] Implement cost estimation when available
|
||||
|
||||
### 6. **Enhance Semantic Analysis** 🟢 Medium Priority
|
||||
**Status**: Not Started
|
||||
**Issue**: Basic semantic responses without actual insights
|
||||
**Tasks:**
|
||||
- [ ] Implement statistical analysis:
|
||||
- [ ] Trend detection (increasing/decreasing/stable)
|
||||
- [ ] Seasonality detection
|
||||
- [ ] Outlier detection with specific values
|
||||
- [ ] Growth rate calculations
|
||||
- [ ] Add data storytelling:
|
||||
- [ ] Key insights based on actual data
|
||||
- [ ] Anomaly descriptions with context
|
||||
- [ ] Comparative analysis (YoY, MoM)
|
||||
- [ ] Include summary statistics in response
|
||||
|
||||
### 7. **Additional Improvements** 🔵 Low Priority
|
||||
**Status**: Not Started
|
||||
- [ ] Add preview format validation
|
||||
- [ ] Implement preview size constraints
|
||||
- [ ] Add chart type validation against dataset characteristics
|
||||
- [ ] Improve caching for preview generation
|
||||
- [ ] Add preview quality options (low/medium/high)
|
||||
|
||||
## Implementation Timeline
|
||||
|
||||
### Phase 1 (Critical Fixes - Current Sprint)
|
||||
1. Fix preview generation when `save_chart=false` (#4)
|
||||
2. Fix ASCII preview NaN handling (#1)
|
||||
3. Fix table header truncation (#3)
|
||||
|
||||
### Phase 2 (Error Handling - Next Sprint)
|
||||
1. Implement enhanced error response schema (#2)
|
||||
2. Add column validation with suggestions (#2)
|
||||
3. Add comprehensive error tests (#2)
|
||||
|
||||
### Phase 3 (Data Quality - Sprint 3)
|
||||
1. Fix table aggregation behavior (#3)
|
||||
2. Implement semantic analysis engine (#6)
|
||||
3. Add statistical calculations (#6)
|
||||
|
||||
### Phase 4 (Performance - Sprint 4)
|
||||
1. Add performance analytics (#5)
|
||||
2. Implement optimization suggestions (#5)
|
||||
3. Add cost estimation (#5)
|
||||
|
||||
## Testing Requirements
|
||||
- Unit tests for each component
|
||||
- Integration tests for full chart generation flow
|
||||
- Edge case testing (empty data, large datasets, special characters)
|
||||
- Performance benchmarking
|
||||
|
||||
## Documentation Requirements
|
||||
- API documentation with examples
|
||||
- Error response catalog
|
||||
- Best practices guide
|
||||
- Migration guide for breaking changes
|
||||
|
||||
## Progress Tracking
|
||||
- 🔴 Critical - Must fix immediately
|
||||
- 🟡 High Priority - Fix in current release
|
||||
- 🟢 Medium Priority - Plan for next release
|
||||
- 🔵 Low Priority - Nice to have
|
||||
|
||||
Last Updated: 2025-07-30
|
||||
150
superset/mcp_service/DEMO_SCRIPT.md
Normal file
150
superset/mcp_service/DEMO_SCRIPT.md
Normal file
@@ -0,0 +1,150 @@
|
||||
# MCP Service Demo Script for Claude Desktop
|
||||
|
||||
This is a safe, read-only demo script to showcase the MCP service capabilities. Run these commands in sequence in Claude Desktop.
|
||||
|
||||
## Prerequisites
|
||||
- Ensure Superset is running locally on port 8088
|
||||
- MCP service should be running on port 5008
|
||||
- You should have some sample data loaded
|
||||
|
||||
## Demo Script
|
||||
|
||||
### 1. Check Instance Health
|
||||
```
|
||||
First, let's verify the Superset instance is running and get some basic stats:
|
||||
|
||||
Use the get_superset_instance_info tool
|
||||
```
|
||||
|
||||
### 2. List Available Datasets
|
||||
```
|
||||
Now let's see what datasets are available:
|
||||
|
||||
Use the list_datasets tool with these parameters:
|
||||
- page: 1
|
||||
- page_size: 5
|
||||
```
|
||||
|
||||
### 3. Get Dataset Details
|
||||
```
|
||||
Pick a dataset ID from the list above and get detailed information:
|
||||
|
||||
Use the get_dataset_info tool with:
|
||||
- dataset_id: [ID from previous list]
|
||||
```
|
||||
|
||||
### 4. List Dashboards
|
||||
```
|
||||
Let's explore existing dashboards:
|
||||
|
||||
Use the list_dashboards tool with:
|
||||
- page: 1
|
||||
- page_size: 5
|
||||
```
|
||||
|
||||
### 5. Get Dashboard Details
|
||||
```
|
||||
Get details about a specific dashboard:
|
||||
|
||||
Use the get_dashboard_info tool with:
|
||||
- dashboard_id: [ID from dashboard list]
|
||||
```
|
||||
|
||||
### 6. List Charts with Filters
|
||||
```
|
||||
Let's see charts, filtered by a specific dataset:
|
||||
|
||||
Use the list_charts tool with:
|
||||
- page: 1
|
||||
- page_size: 5
|
||||
- filters: {"datasource_id": [dataset_id from step 3]}
|
||||
```
|
||||
|
||||
### 7. Get Chart Preview
|
||||
```
|
||||
Get a visual preview of a chart:
|
||||
|
||||
Use the get_chart_preview tool with:
|
||||
- chart_id: [ID from chart list]
|
||||
- format: "url"
|
||||
```
|
||||
|
||||
### 8. Generate Explore Link
|
||||
```
|
||||
Create a custom explore link for data analysis:
|
||||
|
||||
Use the generate_explore_link tool with:
|
||||
- dataset_id: [ID from step 2]
|
||||
- metrics: ["COUNT(*)"]
|
||||
- dimensions: ["[column_name from dataset info]"]
|
||||
- time_range: "Last week"
|
||||
```
|
||||
|
||||
### 9. Check Available Filters
|
||||
```
|
||||
See what filtering options are available for datasets:
|
||||
|
||||
Use the get_dataset_available_filters tool
|
||||
```
|
||||
|
||||
### 10. Advanced Dataset Search
|
||||
```
|
||||
Search for datasets with specific criteria:
|
||||
|
||||
Use the list_datasets tool with:
|
||||
- page: 1
|
||||
- page_size: 10
|
||||
- filters: {
|
||||
"database_name": {"operator": "contains", "value": "examples"},
|
||||
"table_name": {"operator": "contains", "value": "sales"}
|
||||
}
|
||||
- sort_by: "changed_on_delta_humanized"
|
||||
- sort_desc: true
|
||||
```
|
||||
|
||||
## Expected Results
|
||||
|
||||
Each command should return:
|
||||
- ✅ Structured JSON responses with detailed information
|
||||
- ✅ Preview URLs for charts (viewable in browser)
|
||||
- ✅ Metadata about relationships between entities
|
||||
- ✅ Human-readable timestamps and descriptions
|
||||
|
||||
## Safety Notes
|
||||
|
||||
- All operations in this demo are **read-only**
|
||||
- No data is modified or created
|
||||
- Preview URLs expire after cache timeout
|
||||
- Filters validate column names to prevent errors
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If you get errors:
|
||||
1. Verify Superset is running: `curl http://localhost:8088/health`
|
||||
2. Check MCP service is running on port 5008
|
||||
3. Ensure you have datasets loaded in Superset
|
||||
4. Use valid IDs from the list responses
|
||||
|
||||
## Advanced Demo (Optional)
|
||||
|
||||
For a more advanced demo showing chart generation capabilities:
|
||||
|
||||
```
|
||||
Create a simple table chart:
|
||||
|
||||
Use the generate_chart tool with:
|
||||
- dataset_id: [valid dataset ID]
|
||||
- config: {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "[column1_name]"},
|
||||
{"name": "[column2_name]", "aggregate": "COUNT"}
|
||||
]
|
||||
}
|
||||
- chart_name: "Demo Table Chart"
|
||||
- save_chart: false
|
||||
- generate_preview: true
|
||||
- preview_formats: ["table", "url"]
|
||||
```
|
||||
|
||||
This will generate a preview without saving the chart to the database.
|
||||
262
superset/mcp_service/ENTITY_TESTING_PLAN.md
Normal file
262
superset/mcp_service/ENTITY_TESTING_PLAN.md
Normal file
@@ -0,0 +1,262 @@
|
||||
# Superset Entity Testing Plan
|
||||
|
||||
## Overview
|
||||
This plan provides a systematic approach to test all parameter combinations for Superset list endpoints (`list_datasets`, `list_charts`, `list_dashboards`, etc.). Each test validates different aspects of the API functionality.
|
||||
|
||||
## Prerequisites
|
||||
1. Access to Superset MCP Proxy tools
|
||||
2. At least 10+ entities in the target category for meaningful testing
|
||||
3. Knowledge of available filter fields (use `get_[entity]_available_filters` first)
|
||||
|
||||
## Test Execution Steps
|
||||
|
||||
### Step 0: Preparation
|
||||
**Get Available Filters**
|
||||
```
|
||||
Tool: get_[entity]_available_filters
|
||||
Purpose: Understand filterable fields and supported operators
|
||||
```
|
||||
|
||||
### Step 1: Basic Default Parameters
|
||||
**Objective:** Validate basic functionality with minimal parameters
|
||||
```json
|
||||
{}
|
||||
```
|
||||
**Validates:**
|
||||
- Default pagination (page 1, page_size 100)
|
||||
- Default ordering (usually by changed_on desc)
|
||||
- Total count and basic entity structure
|
||||
|
||||
### Step 2: Pagination Parameters
|
||||
**Objective:** Test pagination controls
|
||||
```json
|
||||
{"page": 2, "page_size": 5}
|
||||
```
|
||||
**Validates:**
|
||||
- Custom page size working
|
||||
- Page navigation
|
||||
- Pagination metadata (total_pages, has_next, has_previous)
|
||||
|
||||
### Step 3: Ordering Parameters
|
||||
**Objective:** Test sorting functionality
|
||||
```json
|
||||
{"page_size": 10, "order_column": "[sortable_field]", "order_direction": "asc"}
|
||||
```
|
||||
**Common sortable fields:**
|
||||
- Datasets: `table_name`, `id`, `changed_on`, `created_on`
|
||||
- Charts: `slice_name`, `id`, `changed_on`, `created_on`
|
||||
- Dashboards: `dashboard_title`, `id`, `changed_on`, `created_on`
|
||||
|
||||
**Validates:**
|
||||
- Custom sorting working
|
||||
- Result order matches requested direction
|
||||
|
||||
### Step 4: Text Search
|
||||
**Objective:** Test search functionality
|
||||
```json
|
||||
{"search": "[common_term]", "page_size": 5}
|
||||
```
|
||||
**Common search terms:**
|
||||
- Datasets: "birth", "sales", "user"
|
||||
- Charts: "revenue", "sales", "time"
|
||||
- Dashboards: "dashboard", "overview"
|
||||
|
||||
**Validates:**
|
||||
- Search filtering working
|
||||
- Reduced total_count from filtering
|
||||
|
||||
### Step 5: Basic Filters
|
||||
**Objective:** Test single filter functionality
|
||||
```json
|
||||
{"filters": [{"col": "[filter_field]", "opr": "eq", "value": "[filter_value]"}], "page_size": 5}
|
||||
```
|
||||
**Common filters:**
|
||||
- Datasets: `{"col": "schema", "opr": "eq", "value": "main"}`
|
||||
- Charts: `{"col": "viz_type", "opr": "eq", "value": "line"}`
|
||||
- Dashboards: `{"col": "published", "opr": "eq", "value": true}`
|
||||
|
||||
**Validates:**
|
||||
- Single filter application
|
||||
- `filters_applied` metadata
|
||||
- Filtered result count
|
||||
|
||||
### Step 6: Multiple Filters with Different Operators
|
||||
**Objective:** Test multiple filters and different operators
|
||||
```json
|
||||
{
|
||||
"filters": [
|
||||
{"col": "[field1]", "opr": "sw", "value": "[prefix]"},
|
||||
{"col": "[field2]", "opr": "eq", "value": "[exact_value]"}
|
||||
],
|
||||
"page_size": 5
|
||||
}
|
||||
```
|
||||
**Example combinations:**
|
||||
- Datasets: `table_name` starts with + `schema` equals
|
||||
- Charts: `slice_name` starts with + `viz_type` equals
|
||||
- Dashboards: `dashboard_title` contains + `published` equals
|
||||
|
||||
**Validates:**
|
||||
- Multiple filter combination (AND logic)
|
||||
- Different operator types working
|
||||
- Complex filtering accuracy
|
||||
|
||||
### Step 7: Custom Column Selection
|
||||
**Objective:** Test selective field retrieval
|
||||
```json
|
||||
{"page_size": 8, "select_columns": ["id", "[name_field]", "[key_fields]"]}
|
||||
```
|
||||
**Common column selections:**
|
||||
- Datasets: `["id", "table_name", "database_name", "is_virtual"]`
|
||||
- Charts: `["id", "slice_name", "viz_type", "datasource_name"]`
|
||||
- Dashboards: `["id", "dashboard_title", "published", "slug"]`
|
||||
|
||||
**Validates:**
|
||||
- Column selection working
|
||||
- `columns_requested` vs `columns_loaded` metadata
|
||||
- Response structure with limited fields
|
||||
|
||||
### Step 8: Cache Control Parameters
|
||||
**Objective:** Test caching behavior
|
||||
```json
|
||||
{"page_size": 3, "use_cache": false, "force_refresh": true}
|
||||
```
|
||||
**Validates:**
|
||||
- Cache bypass functionality
|
||||
- Fresh data retrieval
|
||||
- Performance impact of cache settings
|
||||
|
||||
### Step 9: Metadata Refresh Parameters
|
||||
**Objective:** Test metadata refresh functionality
|
||||
```json
|
||||
{
|
||||
"page_size": 4,
|
||||
"order_column": "id",
|
||||
"order_direction": "asc",
|
||||
"refresh_metadata": true
|
||||
}
|
||||
```
|
||||
**Validates:**
|
||||
- Metadata refresh working
|
||||
- Fresh schema/column information
|
||||
- Impact on response completeness
|
||||
|
||||
### Step 10: Complex Parameter Combination
|
||||
**Objective:** Test all parameter types working together
|
||||
```json
|
||||
{
|
||||
"page": 2,
|
||||
"filters": [{"col": "[field]", "opr": "like", "value": "%[pattern]%"}],
|
||||
"page_size": 3,
|
||||
"use_cache": false,
|
||||
"order_column": "changed_on",
|
||||
"force_refresh": true,
|
||||
"order_direction": "desc",
|
||||
"refresh_metadata": true
|
||||
}
|
||||
```
|
||||
**Validates:**
|
||||
- Complex parameter interaction
|
||||
- No conflicts between parameter types
|
||||
- All functionality working simultaneously
|
||||
|
||||
## Entity-Specific Adaptations
|
||||
|
||||
### For Datasets (`list_datasets`)
|
||||
- **Tool:** `Superset MCP Proxy:list_datasets`
|
||||
- **Filter prep:** `get_dataset_available_filters`
|
||||
- **Key fields:** `table_name`, `schema`, `database_name`
|
||||
- **Search terms:** Table/dataset names
|
||||
- **Common filters:** `schema`, `table_name`, `owner`
|
||||
|
||||
### For Charts (`list_charts`)
|
||||
- **Tool:** `Superset MCP Proxy:list_charts`
|
||||
- **Filter prep:** `get_chart_available_filters`
|
||||
- **Key fields:** `slice_name`, `viz_type`, `datasource_name`
|
||||
- **Search terms:** Chart names, visualization types
|
||||
- **Common filters:** `viz_type`, `slice_name`, `datasource_name`
|
||||
|
||||
### For Dashboards (`list_dashboards`)
|
||||
- **Tool:** `Superset MCP Proxy:list_dashboards`
|
||||
- **Filter prep:** `get_dashboard_available_filters`
|
||||
- **Key fields:** `dashboard_title`, `published`, `slug`
|
||||
- **Search terms:** Dashboard names
|
||||
- **Common filters:** `published`, `dashboard_title`, `favorite`
|
||||
|
||||
### For New Entities
|
||||
1. Identify the `list_[entity]` tool
|
||||
2. Check if `get_[entity]_available_filters` exists
|
||||
3. Examine initial response to understand:
|
||||
- Key identifying fields
|
||||
- Available sortable columns
|
||||
- Common filterable fields
|
||||
- Typical data patterns
|
||||
4. Adapt test values accordingly
|
||||
|
||||
## Validation Checklist
|
||||
|
||||
For each test, verify:
|
||||
- ✅ **Response Structure:** Proper JSON with expected fields
|
||||
- ✅ **Status:** No errors returned
|
||||
- ✅ **Data Integrity:** Results match expected parameters
|
||||
- ✅ **Metadata:** Pagination, filtering, and sorting metadata accurate
|
||||
- ✅ **Count Consistency:** `count` matches actual results returned
|
||||
- ✅ **Pagination Logic:** Page boundaries and navigation work correctly
|
||||
|
||||
## Common Issues to Watch For
|
||||
|
||||
1. **Empty Results:** Page 2+ with small result sets
|
||||
2. **Filter Mismatches:** Case sensitivity in string filters
|
||||
3. **Column Selection:** Some fields may not populate as expected
|
||||
4. **Cache Behavior:** Performance differences with cache settings
|
||||
5. **Operator Support:** Not all operators work with all field types
|
||||
|
||||
## Automation Considerations
|
||||
|
||||
This plan can be automated by:
|
||||
1. Creating parameterized test functions
|
||||
2. Building entity-specific configuration objects
|
||||
3. Implementing validation assertion helpers
|
||||
4. Adding performance timing measurements
|
||||
5. Generating test reports with pass/fail status
|
||||
|
||||
## Example Execution Flow
|
||||
|
||||
```
|
||||
1. Run get_[entity]_available_filters
|
||||
2. Execute Steps 1-10 sequentially
|
||||
3. Wait for "next" confirmation between steps
|
||||
4. Document any unexpected behaviors
|
||||
5. Verify all parameter combinations work
|
||||
6. Generate summary report
|
||||
```
|
||||
|
||||
This plan ensures comprehensive testing of all Superset list endpoint functionality while being adaptable to any current or future entity type.
|
||||
|
||||
## Improvements and Enhancements
|
||||
|
||||
### Suggested Improvements
|
||||
|
||||
1. **Performance Testing:** Add response time measurements for cache vs non-cache scenarios
|
||||
2. **Edge Case Testing:** Test with extreme values (very large page_size, invalid dates, etc.)
|
||||
3. **Error Handling Testing:** Test invalid parameters to verify proper error responses
|
||||
4. **Data Quality Testing:** Verify data consistency across different parameter combinations
|
||||
5. **Concurrent Testing:** Test multiple simultaneous requests to check for race conditions
|
||||
6. **Memory Usage Testing:** Monitor memory consumption with large result sets
|
||||
7. **Backward Compatibility:** Test with legacy parameter formats if applicable
|
||||
|
||||
### Implementation Suggestions
|
||||
|
||||
1. **Test Configuration Files:** Create JSON configs for each entity type with common test values
|
||||
2. **Result Comparison:** Add utilities to compare results across different parameter combinations
|
||||
3. **Regression Testing:** Save baseline results to detect unexpected changes
|
||||
4. **Visual Reports:** Generate HTML reports with pass/fail status and performance metrics
|
||||
5. **CI Integration:** Automate this test plan as part of continuous integration
|
||||
|
||||
### Advanced Testing Scenarios
|
||||
|
||||
1. **Load Testing:** Test with hundreds of concurrent requests
|
||||
2. **Data Volume Testing:** Test with databases containing millions of records
|
||||
3. **Network Failure Testing:** Test behavior with intermittent network issues
|
||||
4. **Permission Testing:** Test with different user roles and permissions
|
||||
5. **Multi-tenant Testing:** Test across different organization contexts
|
||||
941
superset/mcp_service/MCP_CHART_TEST_PLAN.md
Normal file
941
superset/mcp_service/MCP_CHART_TEST_PLAN.md
Normal file
@@ -0,0 +1,941 @@
|
||||
# MCP Chart Tools Test Plan
|
||||
|
||||
This document provides a comprehensive test plan for testing the MCP chart tools with Claude Desktop.
|
||||
|
||||
## Important Test Instructions
|
||||
|
||||
### 🔗 **ALWAYS SHOW URLs**
|
||||
When any tool returns a URL (e.g., `url`, `preview_url`, `explore_url`), **always display the complete URL** in your response. For example:
|
||||
- "Chart created successfully! View it at: http://localhost:8088/explore/?slice_id=123"
|
||||
- "Preview URL: http://localhost:8088/superset/slice/123/"
|
||||
|
||||
### 🖼️ **EMBED IMAGES WHEN POSSIBLE**
|
||||
When testing preview tools:
|
||||
1. For `format: "url"` - Display the preview URL and attempt to embed the image
|
||||
2. For `format: "base64"` - Note: base64 is NOT a supported preview format (see Important Schema Notes); expect an "Unsupported preview format" error rather than an image
|
||||
3. For `format: "ascii"` or `format: "table"` - Display the text representation in a code block
|
||||
|
||||
## Prerequisites
|
||||
|
||||
1. Ensure Superset is running locally on http://localhost:8088
|
||||
2. Have at least one dataset available (e.g., "examples.births_2008" or any dataset ID)
|
||||
3. Have some existing charts in your Superset instance
|
||||
4. Verify MCP service is running on port 5008
|
||||
|
||||
## Important Schema Notes
|
||||
|
||||
- **Filter operator field**: Use `op` not `operator` in filter objects
|
||||
- **Data format**: Use `excel` not `xlsx` for Excel export
|
||||
- **Preview formats**: Only `url`, `ascii`, and `table` are supported (NOT `base64`, `interactive`, or `vega_lite`)
|
||||
- **Column selection**: The `url` field is not in default columns - must be explicitly requested
|
||||
- **Sort parameters**: Use `order_column` and `order_direction`, not `sort_columns`
|
||||
|
||||
## Test Coverage Overview
|
||||
|
||||
| Tool | Basic | Advanced | Error Cases | Performance |
|
||||
|------|-------|----------|-------------|-------------|
|
||||
| list_charts | ✓ | ✓ | ✓ | ✓ |
|
||||
| get_chart_info | ✓ | ✓ | ✓ | ✓ |
|
||||
| get_chart_available_filters | ✓ | ✓ | ✓ | - |
|
||||
| generate_chart | ✓ | ✓ | ✓ | ✓ |
|
||||
| update_chart | ✓ | ✓ | ✓ | - |
|
||||
| update_chart_preview | ✓ | ✓ | ✓ | - |
|
||||
| get_chart_data | ✓ | ✓ | ✓ | ✓ |
|
||||
| get_chart_preview | ✓ | ✓ | ✓ | ✓ |
|
||||
|
||||
## 1. Test list_charts
|
||||
|
||||
### Basic Listing
|
||||
```
|
||||
Test: List all charts with default pagination
|
||||
Expected: Returns first 20 charts with metadata including URLs
|
||||
Action: Display the URL for at least one chart
|
||||
```
|
||||
|
||||
### Pagination Tests
|
||||
```
|
||||
Test: List charts with page=2, page_size=5
|
||||
Expected: Returns charts 6-10
|
||||
|
||||
Test: List charts with page_size=50
|
||||
Expected: Returns up to 50 charts on first page
|
||||
|
||||
Test: List with cache control use_cache=false
|
||||
Expected: Fresh data with cache_status showing cache_hit=false
|
||||
```
|
||||
|
||||
### Search Tests
|
||||
```
|
||||
Test: Search for charts with search="sales"
|
||||
Expected: Returns charts with "sales" in name or description
|
||||
|
||||
Test: Search with UUID/slug search="abc-123-def"
|
||||
Expected: Searches across UUID and slug fields
|
||||
|
||||
Test: Search with no results search="xyz123nonexistent"
|
||||
Expected: Returns empty list with count=0
|
||||
```
|
||||
|
||||
### Filter Tests
|
||||
```
|
||||
Test: Filter by viz_type with filters=[{"col": "viz_type", "opr": "eq", "value": "table"}]
|
||||
Expected: Returns only table charts
|
||||
|
||||
Test: Filter by multiple conditions
|
||||
filters=[
|
||||
{"col": "viz_type", "opr": "eq", "value": "line"},
|
||||
{"col": "datasource_name", "opr": "sw", "value": "births"}
|
||||
]
|
||||
Expected: Returns line charts from births dataset
|
||||
|
||||
Test: Filter with IN operator
|
||||
filters=[{"col": "viz_type", "opr": "in", "value": ["line", "bar", "area"]}]
|
||||
Expected: Returns charts matching any of the specified types
|
||||
```
|
||||
|
||||
### Column Selection
|
||||
```
|
||||
Test: Select specific columns with select_columns=["id", "slice_name", "viz_type", "url"]
|
||||
Expected: Returns only requested fields - DISPLAY THE URL
|
||||
|
||||
Test: Include UUID with select_columns=["id", "slice_name", "uuid", "url"]
|
||||
Expected: Returns charts with UUID field populated - DISPLAY THE URL
|
||||
```
|
||||
|
||||
### Sort Options
|
||||
```
|
||||
Test: Sort by name ascending sort_columns=[{"col": "slice_name", "order": "asc"}]
|
||||
Expected: Charts ordered alphabetically
|
||||
|
||||
Test: Sort by updated date sort_columns=[{"col": "changed_on", "order": "desc"}]
|
||||
Expected: Most recently updated charts first
|
||||
```
|
||||
|
||||
## 2. Test get_chart_info
|
||||
|
||||
### Valid Chart Lookup
|
||||
```
|
||||
Test: Get info for existing chart by numeric ID (e.g., 1)
|
||||
Expected: Returns full chart details including form_data and URLs
|
||||
Action: DISPLAY the chart URL
|
||||
|
||||
Test: Get info for chart by UUID (if you have one)
|
||||
Expected: Returns same chart info using UUID identifier
|
||||
Action: DISPLAY the chart URL
|
||||
```
|
||||
|
||||
### Error Cases
|
||||
```
|
||||
Test: Get info for non-existent chart ID 99999
|
||||
Expected: Returns error with type "NotFound"
|
||||
|
||||
Test: Get info with invalid identifier "not-a-valid-id"
|
||||
Expected: Returns appropriate validation error
|
||||
```
|
||||
|
||||
## 3. Test get_chart_available_filters
|
||||
|
||||
### Basic Filter Discovery
|
||||
```
|
||||
Test: Get available filters for a chart
|
||||
Request: {"identifier": 1}
|
||||
Expected: Returns filterable columns with operators and current values
|
||||
```
|
||||
|
||||
### With Current Filters
|
||||
```
|
||||
Test: See interaction with existing filters
|
||||
Request: {"identifier": 1, "include_filter_values": true}
|
||||
Expected: Shows columns, operators, and any applied filter values
|
||||
```
|
||||
|
||||
## 4. Test generate_chart
|
||||
|
||||
### Table Chart Generation
|
||||
|
||||
#### Basic Table
|
||||
```
|
||||
Test: Generate simple table chart
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1, // Use your dataset ID
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "region", "label": "Region"},
|
||||
{"name": "sales", "label": "Sales"}
|
||||
]
|
||||
}
|
||||
}
|
||||
Expected: Creates table chart with selected columns
|
||||
Action: DISPLAY THE CHART URL from response
|
||||
```
|
||||
|
||||
#### Table with Aggregation
|
||||
```
|
||||
Test: Generate table with aggregated metrics
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "region", "label": "Region"},
|
||||
{"name": "sales", "label": "Total Sales", "aggregate": "SUM"},
|
||||
{"name": "quantity", "label": "Avg Quantity", "aggregate": "AVG"}
|
||||
],
|
||||
"filters": [{"column": "year", "operator": "==", "value": 2024}],
|
||||
"time_range": "Last quarter"
|
||||
}
|
||||
}
|
||||
Expected: Creates table with aggregated data, filtered and time-scoped
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
#### Table with All Options
|
||||
```
|
||||
Test: Comprehensive table configuration
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "date", "label": "Date"},
|
||||
{"name": "category", "label": "Category"},
|
||||
{"name": "sales", "label": "Sales", "aggregate": "SUM"},
|
||||
{"name": "profit", "label": "Profit %", "aggregate": "AVG"}
|
||||
],
|
||||
"filters": [
|
||||
{"column": "region", "operator": "IN", "value": ["East", "West"]},
|
||||
{"column": "sales", "operator": ">", "value": 1000}
|
||||
],
|
||||
"order_by": [
|
||||
{"column": "sales", "desc": true}
|
||||
],
|
||||
"row_limit": 100,
|
||||
"show_totals": true,
|
||||
"conditional_formatting": [
|
||||
{
|
||||
"column": "profit",
|
||||
"operator": "<",
|
||||
"value": 0,
|
||||
"color": "#FF0000"
|
||||
}
|
||||
]
|
||||
},
|
||||
"save_chart": true,
|
||||
"slice_name": "Regional Sales Analysis"
|
||||
}
|
||||
Expected: Creates fully configured table
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
### Line Chart Generation
|
||||
|
||||
#### Time Series Line Chart
|
||||
```
|
||||
Test: Generate time series line chart
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "line",
|
||||
"x": {"name": "date"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"time_grain": "P1D",
|
||||
"time_range": "Last 30 days"
|
||||
}
|
||||
}
|
||||
Expected: Creates line chart with daily granularity
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
#### Multi-Metric Line Chart
|
||||
```
|
||||
Test: Generate chart with multiple metrics
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "line",
|
||||
"x": {"name": "date"},
|
||||
"y": [
|
||||
{"name": "sales", "aggregate": "SUM", "label": "Total Sales"},
|
||||
{"name": "profit", "aggregate": "SUM", "label": "Total Profit"},
|
||||
{"name": "orders", "aggregate": "COUNT", "label": "Order Count"}
|
||||
],
|
||||
"group_by": ["region"],
|
||||
"show_legend": true,
|
||||
"y_axis_format": ",.0f"
|
||||
}
|
||||
}
|
||||
Expected: Creates multi-line chart with grouping
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
### Bar Chart Generation
|
||||
|
||||
#### Simple Bar Chart
|
||||
```
|
||||
Test: Generate bar chart
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "bar",
|
||||
"x": {"name": "category"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}]
|
||||
}
|
||||
}
|
||||
Expected: Creates vertical bar chart
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
#### Stacked Bar Chart
|
||||
```
|
||||
Test: Generate stacked bar chart
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "bar",
|
||||
"x": {"name": "month"},
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}],
|
||||
"group_by": ["product_line"],
|
||||
"stack": true,
|
||||
"show_values": true
|
||||
}
|
||||
}
|
||||
Expected: Creates stacked bar chart with values
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
### Area Chart Generation
|
||||
```
|
||||
Test: Generate area chart
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "area",
|
||||
"x": {"name": "date"},
|
||||
"y": [{"name": "revenue", "aggregate": "SUM"}],
|
||||
"group_by": ["segment"],
|
||||
"opacity": 0.7,
|
||||
"show_brush": true
|
||||
}
|
||||
}
|
||||
Expected: Creates area chart with brush selection
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
### Scatter Plot Generation
|
||||
```
|
||||
Test: Generate scatter plot
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "scatter",
|
||||
"x": {"name": "price", "label": "Price"},
|
||||
"y": [{"name": "quantity", "label": "Quantity Sold"}],
|
||||
"size": {"name": "profit", "aggregate": "SUM"},
|
||||
"color": {"name": "category"},
|
||||
"max_bubble_size": 50
|
||||
}
|
||||
}
|
||||
Expected: Creates scatter plot (limited to 50 data points in ASCII preview)
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
### Preview Without Saving
|
||||
```
|
||||
Test: Generate chart without saving (preview only)
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "region"}, {"name": "sales", "aggregate": "SUM"}]
|
||||
},
|
||||
"save_chart": false
|
||||
}
|
||||
Expected: Returns preview data without saving
|
||||
Action: Note that no permanent URL is created
|
||||
```
|
||||
|
||||
### Error Cases
|
||||
|
||||
#### Invalid Dataset
|
||||
```
|
||||
Test: Generate chart with non-existent dataset
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 99999,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "col1"}]
|
||||
}
|
||||
}
|
||||
Expected: Returns error with type "DatasetNotFound"
|
||||
```
|
||||
|
||||
#### Invalid Column
|
||||
```
|
||||
Test: Generate chart with non-existent column
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "nonexistent_column"}]
|
||||
}
|
||||
}
|
||||
Expected: Returns validation error with column suggestions
|
||||
Action: Note the suggested column names for next test
|
||||
```
|
||||
|
||||
#### Invalid Aggregation
|
||||
```
|
||||
Test: Use SUM on text column
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "region", "aggregate": "SUM"} // Text column with numeric aggregate
|
||||
]
|
||||
}
|
||||
}
|
||||
Expected: Returns validation error about aggregate type mismatch
|
||||
```
|
||||
|
||||
#### Missing Required Fields
|
||||
```
|
||||
Test: Generate chart without required x axis
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "line",
|
||||
"y": [{"name": "sales", "aggregate": "SUM"}]
|
||||
// Missing x field
|
||||
}
|
||||
}
|
||||
Expected: Returns validation error about missing x axis
|
||||
```
|
||||
|
||||
## 5. Test update_chart
|
||||
|
||||
### Basic Update
|
||||
```
|
||||
Test: Update chart name and description
|
||||
Request:
|
||||
{
|
||||
"identifier": 1, // Use existing chart ID
|
||||
"updates": {
|
||||
"slice_name": "Updated Chart Name",
|
||||
"description": "This chart has been updated via MCP"
|
||||
}
|
||||
}
|
||||
Expected: Updates chart metadata
|
||||
Action: DISPLAY THE UPDATED CHART URL
|
||||
```
|
||||
|
||||
### Update Visualization
|
||||
```
|
||||
Test: Change chart type from bar to line
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"updates": {
|
||||
"viz_type": "line",
|
||||
"params": {
|
||||
"viz_type": "line",
|
||||
"line_interpolation": "smooth"
|
||||
}
|
||||
}
|
||||
}
|
||||
Expected: Changes chart visualization type
|
||||
Action: DISPLAY THE URL to see the change
|
||||
```
|
||||
|
||||
### Update with Cache Refresh
|
||||
```
|
||||
Test: Update and force cache refresh
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"updates": {
|
||||
"slice_name": "Fresh Data Chart"
|
||||
},
|
||||
"force_refresh": true
|
||||
}
|
||||
Expected: Updates chart and refreshes cache
|
||||
```
|
||||
|
||||
## 6. Test update_chart_preview
|
||||
|
||||
### Update Existing Preview
|
||||
```
|
||||
Test: Refresh chart preview
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"force_refresh": true
|
||||
}
|
||||
Expected: Regenerates preview with fresh data
|
||||
Action: Note cache_status in response
|
||||
```
|
||||
|
||||
### Update Preview Format
|
||||
```
|
||||
Test: Change preview settings
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"width": 1200,
|
||||
"height": 800,
|
||||
"force_refresh": true
|
||||
}
|
||||
Expected: Updates preview with new dimensions
|
||||
```
|
||||
|
||||
## 7. Test get_chart_preview
|
||||
|
||||
### URL Preview (Screenshot)
|
||||
```
|
||||
Test: Get chart preview as URL
|
||||
Request:
|
||||
{
|
||||
"identifier": 1, // Use existing chart ID
|
||||
"format": "url",
|
||||
"width": 800,
|
||||
"height": 600
|
||||
}
|
||||
Expected: Returns preview_url
|
||||
Action: DISPLAY THE PREVIEW URL and attempt to embed the image:
|
||||

|
||||
```
|
||||
|
||||
### Base64 Preview
|
||||
```
|
||||
Test: Get chart as base64 image
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "base64"
|
||||
}
|
||||
Expected: Returns error "Unsupported preview format: base64" (per the Important Schema Notes, only `url`, `ascii`, and `table` are supported)
|
||||
Action: Confirm the error message is returned instead of an image
|
||||
```
|
||||
|
||||
### ASCII Preview
|
||||
```
|
||||
Test: Get chart as ASCII art
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "ascii",
|
||||
"ascii_width": 80,
|
||||
"ascii_height": 20
|
||||
}
|
||||
Expected: Returns ASCII representation (limited to 50 rows)
|
||||
Action: Display in a code block:
|
||||
```
|
||||
[ASCII art will appear here]
|
||||
```
|
||||
|
||||
### Table Preview
|
||||
```
|
||||
Test: Get chart data as formatted table
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "table",
|
||||
"max_rows": 10
|
||||
}
|
||||
Expected: Returns tabular data (limited to 20 rows)
|
||||
Action: Display the table in a formatted code block
|
||||
```
|
||||
|
||||
### Cache Control in Preview
|
||||
```
|
||||
Test: Force fresh preview
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "url",
|
||||
"force_refresh": true,
|
||||
"cache_timeout": 0
|
||||
}
|
||||
Expected: Returns fresh preview with cache_hit=false
|
||||
Action: DISPLAY THE PREVIEW URL
|
||||
```
|
||||
|
||||
### Error Cases
|
||||
```
|
||||
Test: Get preview for non-existent chart
|
||||
Request:
|
||||
{
|
||||
"identifier": 99999,
|
||||
"format": "url"
|
||||
}
|
||||
Expected: Returns error with type "NotFound"
|
||||
|
||||
Test: Invalid format
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "invalid_format"
|
||||
}
|
||||
Expected: Returns error "Unsupported preview format: invalid_format"
|
||||
|
||||
Test: Unsupported format (base64)
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "base64"
|
||||
}
|
||||
Expected: Returns error "Unsupported preview format: base64"
|
||||
```
|
||||
|
||||
## 8. Test get_chart_data
|
||||
|
||||
### Basic Data Retrieval
|
||||
```
|
||||
Test: Get data for existing chart
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "json",
|
||||
"limit": 100
|
||||
}
|
||||
Expected: Returns chart data with metadata
|
||||
Action: Display sample of data and note total_rows
|
||||
```
|
||||
|
||||
### CSV Export
|
||||
```
|
||||
Test: Get data as CSV
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "csv"
|
||||
}
|
||||
Expected: Returns CSV formatted data
|
||||
Action: Display first few lines of CSV
|
||||
```
|
||||
|
||||
### Excel Export
|
||||
```
|
||||
Test: Get data as Excel
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "excel" // Note: use "excel" not "xlsx"
|
||||
}
|
||||
Expected: Returns base64 encoded Excel file
|
||||
Action: Note that Excel file was generated
|
||||
```
|
||||
|
||||
### With Additional Processing
|
||||
```
|
||||
Test: Get data with insights
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "json",
|
||||
"include_column_metadata": true,
|
||||
"generate_insights": true,
|
||||
"limit": 50
|
||||
}
|
||||
Expected: Returns data with column analysis and insights
|
||||
Action: Display the insights and column metadata
|
||||
```
|
||||
|
||||
### Cache Control
|
||||
```
|
||||
Test: Force fresh data
|
||||
Request:
|
||||
{
|
||||
"identifier": 1,
|
||||
"format": "json",
|
||||
"force_refresh": true,
|
||||
"use_cache": false
|
||||
}
|
||||
Expected: Returns fresh data with cache_hit=false
|
||||
Action: Note the cache_status details
|
||||
```
|
||||
|
||||
### Big Number Chart Handling
|
||||
```
|
||||
Test: Get data for big_number chart type
|
||||
Request:
|
||||
{
|
||||
"identifier": [ID of a big_number chart],
|
||||
"format": "json"
|
||||
}
|
||||
Expected: Should handle appropriately or return specific error
|
||||
```
|
||||
|
||||
## 9. Integration Test Scenarios
|
||||
|
||||
### Complete Chart Lifecycle
|
||||
```
|
||||
1. Generate a new chart with save_chart=true
|
||||
- DISPLAY THE CHART URL
|
||||
2. Use returned chart_id to get_chart_info
|
||||
- Verify all details match
|
||||
3. Update the chart with update_chart
|
||||
- DISPLAY THE UPDATED URL
|
||||
4. Get preview in multiple formats
|
||||
- DISPLAY URL preview and embed image
|
||||
- Show ASCII preview in code block
|
||||
5. Get chart data in JSON and CSV formats
|
||||
- Display sample data
|
||||
6. Update chart preview with new dimensions
|
||||
- DISPLAY new preview URL
|
||||
```
|
||||
|
||||
### Error Recovery Flow
|
||||
```
|
||||
1. Try to generate chart with invalid column
|
||||
- Note the error and suggestions
|
||||
2. Use list_datasets to find correct dataset
|
||||
3. Use get_dataset_info to see columns
|
||||
4. Generate chart with correct column names
|
||||
- DISPLAY THE SUCCESSFUL CHART URL
|
||||
```
|
||||
|
||||
### Cache Testing Flow
|
||||
```
|
||||
1. Get chart data with use_cache=true
|
||||
- Note cache_hit status
|
||||
2. Get same data again
|
||||
- Verify cache_hit=true
|
||||
3. Get data with force_refresh=true
|
||||
- Verify cache_hit=false
|
||||
4. Check cache_age_seconds values
|
||||
```
|
||||
|
||||
### Multi-Format Export
|
||||
```
|
||||
1. Create a complex chart with multiple metrics
|
||||
- DISPLAY THE CHART URL
|
||||
2. Export as:
|
||||
- JSON (display sample)
|
||||
- CSV (display headers)
|
||||
- Excel (note generation)
|
||||
3. Get preview as:
|
||||
- URL (embed image)
|
||||
- ASCII (show in code block)
|
||||
- Table (display formatted)
|
||||
```
|
||||
|
||||
## 10. Performance and Load Tests
|
||||
|
||||
### Large Dataset Handling
|
||||
```
|
||||
Test: Generate chart with row_limit=10000
|
||||
Expected: Handles gracefully, returns data or appropriate limit
|
||||
|
||||
Test: Get data with limit=50000
|
||||
Expected: Returns data or indicates maximum allowed
|
||||
```
|
||||
|
||||
### Concurrent Operations
|
||||
```
|
||||
Test: Generate 5 charts in rapid succession (back-to-back requests simulating concurrent usage)
|
||||
Expected: All succeed without conflicts
|
||||
Action: DISPLAY ALL CHART URLs
|
||||
```
|
||||
|
||||
### Complex Aggregations
|
||||
```
|
||||
Test: Chart with multiple groupings and aggregations
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "region"},
|
||||
{"name": "category"},
|
||||
{"name": "sales", "aggregate": "SUM"},
|
||||
{"name": "sales", "aggregate": "AVG", "label": "Avg Sale"},
|
||||
{"name": "sales", "aggregate": "MAX", "label": "Max Sale"},
|
||||
{"name": "sales", "aggregate": "MIN", "label": "Min Sale"},
|
||||
{"name": "sales", "aggregate": "COUNT", "label": "Sale Count"}
|
||||
],
|
||||
"order_by": [{"column": "sales", "desc": true}],
|
||||
"row_limit": 500
|
||||
}
|
||||
}
|
||||
Expected: Handles complex aggregations efficiently
|
||||
Action: DISPLAY THE CHART URL
|
||||
```
|
||||
|
||||
## 11. Special Cases and Edge Cases
|
||||
|
||||
### Unicode and Special Characters
|
||||
```
|
||||
Test: Chart with unicode in name
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "region"}]
|
||||
},
|
||||
"slice_name": "Sales 销售 🌏 Report"
|
||||
}
|
||||
Expected: Handles unicode correctly
|
||||
Action: DISPLAY THE CHART URL with unicode name
|
||||
```
|
||||
|
||||
### Very Long Names
|
||||
```
|
||||
Test: Chart with very long name
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "region"}]
|
||||
},
|
||||
"slice_name": "This is a very long chart name that exceeds typical length limits and should be handled gracefully by the system without causing any errors or truncation issues"
|
||||
}
|
||||
Expected: Handles or truncates appropriately
|
||||
```
|
||||
|
||||
### SQL Injection Prevention
|
||||
```
|
||||
Test: Attempt SQL injection in filter
|
||||
Request:
|
||||
{
|
||||
"dataset_id": 1,
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [{"name": "region"}],
|
||||
"filters": [{"column": "region", "operator": "==", "value": "'; DROP TABLE users; --"}]
|
||||
}
|
||||
}
|
||||
Expected: Safely handles without executing SQL
|
||||
```
|
||||
|
||||
## Expected Response Patterns
|
||||
|
||||
### Successful Chart Creation
|
||||
```json
|
||||
{
|
||||
"chart": {
|
||||
"id": 123,
|
||||
"slice_name": "My Chart",
|
||||
"viz_type": "table",
|
||||
"url": "http://localhost:8088/explore/?slice_id=123",
|
||||
"uuid": "abc-123-def",
|
||||
"saved": true
|
||||
},
|
||||
"success": true
|
||||
}
|
||||
```
|
||||
**Action: ALWAYS DISPLAY THE URL**
|
||||
|
||||
### Successful Preview
|
||||
```json
|
||||
{
|
||||
"chart_id": 123,
|
||||
"format": "url",
|
||||
"content": {
|
||||
"type": "url",
|
||||
"preview_url": "http://localhost:8088/api/v1/chart/123/screenshot/...",
|
||||
"expires_at": "2024-01-01T12:00:00Z"
|
||||
}
|
||||
}
|
||||
```
|
||||
**Action: DISPLAY AND EMBED THE preview_url**
|
||||
|
||||
### Validation Error
|
||||
```json
|
||||
{
|
||||
"error": "validation_error",
|
||||
"message": "Chart configuration validation failed",
|
||||
"validation_errors": [
|
||||
{
|
||||
"field": "columns[0]",
|
||||
"error_type": "column_not_found",
|
||||
"message": "Column 'nonexistent' not found",
|
||||
"suggestions": ["region", "sales", "profit"]
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
|
||||
### Data Response with Cache Info
|
||||
```json
|
||||
{
|
||||
"chart_id": 123,
|
||||
"data": [...],
|
||||
"row_count": 100,
|
||||
"total_rows": 5000,
|
||||
"cache_status": {
|
||||
"cache_hit": true,
|
||||
"cache_type": "query",
|
||||
"cache_age_seconds": 300
|
||||
},
|
||||
"insights": ["Data served from cache", "Large dataset - consider filtering"]
|
||||
}
|
||||
```
|
||||
|
||||
## Test Execution Checklist
|
||||
|
||||
- [ ] Environment setup verified
|
||||
- [ ] Basic CRUD operations tested
|
||||
- [ ] All chart types tested
|
||||
- [ ] Error handling verified
|
||||
- [ ] Cache behavior confirmed
|
||||
- [ ] Preview formats working
|
||||
- [ ] URLs displayed for all operations
|
||||
- [ ] Images embedded where possible
|
||||
- [ ] Performance acceptable
|
||||
- [ ] Edge cases handled
|
||||
|
||||
## Debugging Tips
|
||||
|
||||
1. **Always display returned URLs** - They're crucial for verification
|
||||
2. **For image previews** - Try to embed using markdown: ``
|
||||
3. **For errors** - Show the complete error response
|
||||
4. **For data** - Show a representative sample, not everything
|
||||
5. **Check cache_status** - Helps understand performance
|
||||
6. **Save successful IDs** - Reuse for subsequent tests
|
||||
7. **Note patterns** - Errors often reveal API patterns
|
||||
|
||||
## Summary Report Template
|
||||
|
||||
After running tests, summarize:
|
||||
|
||||
```
|
||||
Test Summary for MCP Chart Tools
|
||||
================================
|
||||
Total Tests Run: X
|
||||
Passed: X
|
||||
Failed: X
|
||||
|
||||
Working Features:
|
||||
- ✅ Feature 1 (with URL: ...)
|
||||
- ✅ Feature 2 (with preview: ...)
|
||||
|
||||
Issues Found:
|
||||
- ❌ Issue 1: Description
|
||||
- ❌ Issue 2: Description
|
||||
|
||||
Performance Notes:
|
||||
- Average response time: Xs
|
||||
- Cache hit rate: X%
|
||||
|
||||
Recommendations:
|
||||
- ...
|
||||
```
|
||||
superset/mcp_service/README.md (new file, 688 lines)
@@ -0,0 +1,688 @@
|
||||
# Superset MCP Service
|
||||
|
||||
The Superset Model Context Protocol (MCP) service provides a modular, schema-driven interface for programmatic access to Superset dashboards, charts, datasets, and instance metadata. It is designed for LLM agents and automation tools, and is built on the FastMCP protocol.
|
||||
|
||||
**✅ Phase 1 Complete. Core functionality stable, authentication production-ready, comprehensive testing coverage.**
|
||||
|
||||
## 🚀 Quickstart
|
||||
|
||||
### 1. Install Superset Locally
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone https://github.com/apache/superset.git
|
||||
cd superset
|
||||
|
||||
# Create virtual environment and install (Python 3.10 or 3.11 required)
|
||||
make venv
|
||||
source venv/bin/activate
|
||||
make install
|
||||
|
||||
# Start Superset (in a separate terminal)
|
||||
source venv/bin/activate
|
||||
superset run -p 8088 --with-threads --reload --debugger
|
||||
```
|
||||
|
||||
For alternative installation methods, see the [official Superset development guide](https://superset.apache.org/docs/contributing/development).
|
||||
|
||||
### 2. Run the MCP Service
|
||||
|
||||
The MCP service runs as an HTTP server (not stdout) and requires a proxy for Claude Desktop:
|
||||
|
||||
```bash
|
||||
# In a new terminal, with your virtual environment activated
|
||||
source venv/bin/activate # if using make venv
|
||||
# OR
|
||||
# pyenv activate superset-mcp # if using pyenv
|
||||
|
||||
# Run the MCP service
|
||||
superset mcp run --port 5008 --debug
|
||||
```
|
||||
|
||||
The service will start on http://localhost:5008
|
||||
|
||||
### 3. Connect to Claude Desktop
|
||||
|
||||
Since the MCP service runs on HTTP (not stdout), you need to use the FastMCP proxy:
|
||||
|
||||
**Step 1: Configure the existing proxy script**
|
||||
The proxy script `superset/mcp_service/run_proxy.sh` is already provided. Update the paths in it if needed for your environment.
|
||||
|
||||
**Step 2: Configure Claude Desktop**
|
||||
Add to your Claude Desktop config (~/Library/Application Support/Claude/claude_desktop_config.json):
|
||||
```json
|
||||
{
|
||||
"mcpServers": {
|
||||
"Superset MCP Proxy": {
|
||||
"command": "/path/to/your/superset/superset/mcp_service/run_proxy.sh",
|
||||
"args": [],
|
||||
"env": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**Step 3: Restart Claude Desktop**
|
||||
- Quit Claude Desktop completely
|
||||
- Start it again
|
||||
- The Superset MCP tools should now be available
|
||||
|
||||
### 4. Install Browser Dependencies (Optional - for chart screenshots)
|
||||
|
||||
The chart preview functionality requires Firefox and geckodriver. See the [Superset documentation](https://superset.apache.org/docs/contributing/development) for installation instructions.
|
||||
|
||||
**Quick install on macOS:**
|
||||
```bash
|
||||
brew install --cask firefox
|
||||
brew install geckodriver
|
||||
```
|
||||
|
||||
### 5. Verify Your Setup
|
||||
|
||||
**Check that Superset is running:**
|
||||
```bash
|
||||
curl http://localhost:8088/health
|
||||
# Should return {"status": "OK"}
|
||||
```
|
||||
|
||||
**Check that MCP service is running:**
|
||||
```bash
|
||||
# Check if the MCP service port is listening
|
||||
lsof -i :5008
|
||||
# Should show the superset mcp process listening on port 5008
|
||||
|
||||
# Or check the process directly
|
||||
ps aux | grep "superset mcp"
|
||||
```
|
||||
|
||||
**Test in Claude Desktop:**
|
||||
- Ask Claude to "list dashboards" or "get superset instance info"
|
||||
- Claude should be able to use the MCP tools to query your Superset instance
|
||||
|
||||
### 6. Run Tests (Optional)
|
||||
|
||||
Run the unit tests to verify your environment:
|
||||
|
||||
```bash
|
||||
# Unit tests
|
||||
pytest tests/unit_tests/mcp_service/ --maxfail=1 -v
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
**If Claude Desktop can't connect:**
|
||||
1. Ensure both Superset (port 8088) and MCP service (port 5008) are running
|
||||
2. Check the proxy script has the correct path to your virtual environment
|
||||
3. Look at Claude Desktop logs: `tail -f ~/Library/Logs/Claude/mcp-server-Superset MCP Proxy.log`
|
||||
4. Verify the proxy works manually: `./run_proxy.sh` (should show MCP protocol messages)
|
||||
|
||||
**If screenshots don't work:**
|
||||
1. Ensure Firefox and geckodriver are installed and in PATH
|
||||
2. Check `which geckodriver` returns a valid path
|
||||
3. Try running Firefox manually to ensure it works
|
||||
|
||||
## Available Tools
|
||||
|
||||
**16 MCP tools** with Pydantic v2 schemas and comprehensive field documentation for LLM compatibility.
|
||||
|
||||
### 📊 Dashboard Tools (5)
|
||||
- **`list_dashboards`** - List with search/filters/pagination, UUID/slug support
|
||||
- **`get_dashboard_info`** - Get by ID/UUID/slug with metadata
|
||||
- **`get_dashboard_available_filters`** - Discover filterable columns
|
||||
- **`generate_dashboard`** - Create dashboards with multiple charts
|
||||
- **`add_chart_to_existing_dashboard`** - Add charts to existing dashboards
|
||||
|
||||
### 📈 Chart Tools (8)
|
||||
- **`list_charts`** - List with search/filters/pagination, UUID support
|
||||
- **`get_chart_info`** - Get by ID/UUID with full metadata
|
||||
- **`get_chart_available_filters`** - Discover filterable columns
|
||||
- **`generate_chart`** - Create charts (table, line, bar, area, scatter)
|
||||
- **`update_chart`** - Update existing saved charts
|
||||
- **`update_chart_preview`** - Update cached chart previews
|
||||
- **`get_chart_data`** - Export data (JSON/CSV/Excel)
|
||||
- **`get_chart_preview`** - Screenshots, ASCII art, table previews
|
||||
|
||||
### 🗂️ Dataset Tools (3)
|
||||
- **`list_datasets`** - List with columns/metrics, UUID support
|
||||
- **`get_dataset_info`** - Get by ID/UUID with columns/metrics metadata
|
||||
- **`get_dataset_available_filters`** - Discover filterable columns
|
||||
|
||||
### 🖥️ System Tools (2)
|
||||
- **`get_superset_instance_info`** - Instance statistics and version info
|
||||
- **`generate_explore_link`** - Generate chart exploration URLs
|
||||
|
||||
### 🧪 SQL Lab Tools (1)
|
||||
- **`open_sql_lab_with_context`** - Pre-configured SQL Lab sessions
|
||||
|
||||
## Available Operations
|
||||
|
||||
### ✅ Read Operations (All entities)
|
||||
- **List**: Paginated lists with filtering, search, and UUID/slug support
|
||||
- **Get Info**: Detailed information by ID, UUID, or slug
|
||||
- **Get Filters**: Discover available filter columns and operators
|
||||
- **Get Data**: Export chart data in multiple formats
|
||||
- **Get Previews**: Chart screenshots, ASCII art, and table representations
|
||||
|
||||
### ✅ Create Operations
|
||||
- **Charts**: Create charts with 5 visualization types (table, line, bar, area, scatter)
|
||||
- **Dashboards**: Generate dashboards with multiple charts and automatic layout
|
||||
- **Add to Dashboard**: Add existing charts to dashboards with smart positioning
|
||||
|
||||
### ✅ Update Operations
|
||||
- **Charts**: Update saved charts and cached chart previews
|
||||
- **Navigation**: Generate explore links and SQL Lab sessions
|
||||
|
||||
### ❌ Not Available (Future phases)
|
||||
- **Update/Delete**: Dashboard and dataset modifications
|
||||
- **SQL Execution**: Query execution in SQL Lab (opens sessions only)
|
||||
|
||||
## 📖 Complete Documentation
|
||||
|
||||
The MCP service is fully documented on the **[official Superset documentation site](https://superset.apache.org/docs/mcp-service/intro)**:
|
||||
|
||||
### Quick Access
|
||||
- **[🚀 MCP Service Overview](https://superset.apache.org/docs/mcp-service/intro)** - Complete introduction and features
|
||||
- **[📚 API Reference](https://superset.apache.org/docs/mcp-service/api-reference)** - All 16 tools with examples
|
||||
- **[🔧 Development Guide](https://superset.apache.org/docs/mcp-service/development)** - Adding new tools and architecture
|
||||
- **[🔐 Authentication](https://superset.apache.org/docs/mcp-service/authentication)** - Production security setup
|
||||
|
||||
### By Role
|
||||
**👩💻 Developers & Integrators:** [Overview](https://superset.apache.org/docs/mcp-service/overview) → [API Reference](https://superset.apache.org/docs/mcp-service/api-reference) → [Development Guide](https://superset.apache.org/docs/mcp-service/development)
|
||||
|
||||
**🔒 DevOps & Production:** [Authentication](https://superset.apache.org/docs/mcp-service/authentication) → [Architecture](https://superset.apache.org/docs/mcp-service/architecture)
|
||||
|
||||
**🏢 Enterprise Teams:** [Preset Integration](https://superset.apache.org/docs/mcp-service/preset-integration)
|
||||
|
||||
> 💡 **Local Development?** See the [local docs folder](./docs/) for markdown versions during development.
|
||||
|
||||
## Enhanced Parameter Handling
|
||||
|
||||
All MCP tools now use the **FastMCP Complex Inputs Pattern** to eliminate LLM parameter validation issues:
|
||||
|
||||
### Request Schema Pattern
|
||||
Instead of individual parameters, tools use structured request objects:
|
||||
```python
|
||||
# New approach (current)
|
||||
get_dataset_info(request={"identifier": 123}) # ID
|
||||
get_dataset_info(request={"identifier": "a1b2c3d4-e5f6-7890-abcd-ef1234567890"}) # UUID
|
||||
|
||||
# Old approach (replaced)
|
||||
get_dataset_info(dataset_id=123)
|
||||
```
|
||||
|
||||
### Multi-Identifier Support
|
||||
All `get_*_info` tools now support multiple identifier types:
|
||||
- **Datasets/Charts**: ID (numeric) or UUID (string)
|
||||
- **Dashboards**: ID (numeric), UUID (string), or slug (string)
|
||||
|
||||
### Filtering & Search
|
||||
All `list_*` tools support:
|
||||
- **Filters**: Structured filter objects with validation to prevent conflicts
|
||||
- **Search**: Free-text search across key fields (including UUID and slug)
|
||||
- **Validation**: Cannot use both `search` and `filters` simultaneously
|
||||
|
||||
Example:
|
||||
```python
|
||||
# Using request schema with filters
|
||||
list_dashboards(request={
|
||||
"search": "sales",
|
||||
"page": 1,
|
||||
"page_size": 20
|
||||
})
|
||||
|
||||
# Or with filters (but not both)
|
||||
list_dashboards(request={
|
||||
"filters": [{"col": "published", "opr": "eq", "value": True}],
|
||||
"page": 1,
|
||||
"page_size": 20
|
||||
})
|
||||
```
|
||||
|
||||
## Chart Creation
|
||||
|
||||
The `generate_chart` tool supports chart creation with:
|
||||
|
||||
### Supported Chart Types
|
||||
- **Table charts** — Simple column display with filters and sorting
|
||||
- **Line charts** — Time series line charts
|
||||
- **Bar charts** — Time series bar charts
|
||||
- **Area charts** — Time series area charts
|
||||
- **Scatter charts** — Time series scatter charts
|
||||
|
||||
### Chart Creation
|
||||
The tool creates and saves permanent charts in Superset with automatically generated explore URLs.
|
||||
|
||||
### Intelligent Metric Handling
|
||||
The tool automatically handles two metric formats:
|
||||
1. **Simple metrics** (like `["count"]`) — Passed as simple strings
|
||||
2. **Complex metrics** (like column names) — Converted to full Superset metric objects with SQL aggregators (SUM, COUNT, AVG, MIN, MAX)
|
||||
|
||||
### Example Usage
|
||||
```python
|
||||
# Create a line chart with SQL aggregators
|
||||
config = XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[
|
||||
ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
|
||||
ColumnRef(name="orders", aggregate="COUNT", label="Order Count")
|
||||
],
|
||||
kind="line"
|
||||
)
|
||||
request = GenerateChartRequest(dataset_id="1", config=config)
|
||||
|
||||
# Create a table chart
|
||||
table_config = TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region", label="Region"),
|
||||
ColumnRef(name="sales", label="Sales")
|
||||
]
|
||||
)
|
||||
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
|
||||
```
|
||||
|
||||
## Dashboard Generation & Management
|
||||
|
||||
New dashboard management tools provide dashboard creation and chart addition capabilities:
|
||||
|
||||
### Dashboard Creation
|
||||
```python
|
||||
# Generate a dashboard with multiple charts
|
||||
generate_dashboard(request={
|
||||
"chart_ids": [1, 2, 3, 4],
|
||||
"dashboard_title": "Sales Analytics Dashboard",
|
||||
"description": "Sales performance metrics dashboard",
|
||||
"published": True
|
||||
})
|
||||
```
|
||||
|
||||
### Chart Addition to Existing Dashboards
|
||||
```python
|
||||
# Add a chart to an existing dashboard
|
||||
add_chart_to_existing_dashboard(request={
|
||||
"dashboard_id": 123,
|
||||
"chart_id": 456,
|
||||
"target_tab": "Overview" # Optional
|
||||
})
|
||||
```
|
||||
|
||||
## SQL Lab Integration
|
||||
|
||||
Direct integration with Superset's SQL Lab for seamless development workflows:
|
||||
|
||||
```python
|
||||
# Open SQL Lab with context
|
||||
open_sql_lab_with_context(request={
|
||||
"database_connection_id": 1,
|
||||
"schema": "public",
|
||||
"dataset_in_context": "sales_data",
|
||||
"sql": "SELECT * FROM sales_data WHERE region = 'US'",
|
||||
"title": "US Sales Analysis"
|
||||
})
|
||||
```
|
||||
|
||||
**Features:**
|
||||
- Pre-selected database and schema
|
||||
- Contextual SQL templates
|
||||
- Dataset-aware query generation
|
||||
- Proper URL parameter handling (`dbid` for compatibility)
|
||||
|
||||
## Chart Data & Preview System
|
||||
|
||||
Advanced chart preview and data extraction capabilities:
|
||||
|
||||
### Chart Data Retrieval
|
||||
```python
|
||||
# Get chart data in multiple formats
|
||||
get_chart_data(request={
|
||||
"identifier": "chart-uuid-or-id",
|
||||
"format": "json", # json, csv, excel
|
||||
"row_count": 1000,
|
||||
"row_offset": 0
|
||||
})
|
||||
```
|
||||
|
||||
### Chart Preview Generation
|
||||
```python
|
||||
# Generate chart previews
|
||||
get_chart_preview(request={
|
||||
"identifier": "chart-uuid-or-id",
|
||||
"format": "url", # url, base64, ascii, table
|
||||
"width": 800,
|
||||
"height": 600
|
||||
})
|
||||
```
|
||||
|
||||
**Preview Formats:**
|
||||
- **URL**: Screenshot URLs served by MCP service
|
||||
- **Base64**: Embedded image data for direct display
|
||||
- **ASCII**: Text-based charts for terminal/chat display
|
||||
- **Table**: Structured data representation
|
||||
|
||||
## Modular Structure & Best Practices
|
||||
|
||||
- Tools are organized by domain: `dashboard/`, `dataset/`, `chart/`, `system/`.
|
||||
- All input/output is validated with Pydantic v2.
|
||||
- Shared schemas live in `schemas/`.
|
||||
- All tool calls are logged and RBAC/auth hooks are pluggable.
|
||||
- **All tool functions must be decorated with `@mcp.tool` and `@mcp_auth_hook`.**
|
||||
- **All Superset DAOs, command classes, and most Superset modules must be imported inside the function body, not at the top of the file.** This ensures proper app context and avoids initialization errors.
|
||||
|
||||
## Current Status
|
||||
|
||||
### ✅ Phase 1 Complete
|
||||
- **FastMCP Server**: CLI with `superset mcp run`, HTTP service on port 5008
|
||||
- **Authentication**: Production-ready JWT Bearer with configurable factory pattern
|
||||
- **16 Core Tools**: All list/info/filter tools, chart creation, dashboard generation
|
||||
- **Request Schema Pattern**: Eliminates LLM parameter validation issues
|
||||
- **Cache Control**: Comprehensive control over Superset's existing cache layers
|
||||
- **Audit Logging**: MCP context tracking with impersonation and payload sanitization
|
||||
- **Testing**: 194+ unit tests with full pre-commit compliance
|
||||
|
||||
### 🎯 Future Enhancements
|
||||
- Demo notebooks and interactive examples
|
||||
- OAuth integration for user impersonation
|
||||
- Enhanced chart rendering formats (Vega-Lite, Plotly JSON)
|
||||
- Advanced security features and tool poisoning prevention
|
||||
|
||||
## Security & Authentication
|
||||
|
||||
The MCP service supports **configurable JWT Bearer authentication** following Superset's factory pattern. Authentication is **disabled by default** for development convenience.
|
||||
|
||||
### Configuration Options
|
||||
|
||||
**Option 1: Simple Configuration** (Add to `superset_config.py`):
|
||||
```python
|
||||
# Enable authentication
|
||||
MCP_AUTH_ENABLED = True
|
||||
|
||||
# JWT settings
|
||||
MCP_JWKS_URI = "https://auth.company.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://auth.company.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-api"
|
||||
MCP_REQUIRED_SCOPES = ["dashboard:read", "chart:read"]
|
||||
```
|
||||
|
||||
**Option 2: Custom Factory** (Advanced):
|
||||
```python
|
||||
def create_custom_mcp_auth(app):
|
||||
"""Custom auth logic for your environment."""
|
||||
from fastmcp.server.auth.providers.bearer import BearerAuthProvider
|
||||
|
||||
return BearerAuthProvider(
|
||||
jwks_uri=app.config["MCP_JWKS_URI"],
|
||||
issuer=app.config["MCP_JWT_ISSUER"],
|
||||
audience=app.config["MCP_JWT_AUDIENCE"],
|
||||
)
|
||||
|
||||
MCP_AUTH_FACTORY = create_custom_mcp_auth
|
||||
```
|
||||
|
||||
**Option 3: Environment Variables** (Legacy):
|
||||
```bash
|
||||
MCP_AUTH_ENABLED=true
|
||||
MCP_JWKS_URI=https://auth.company.com/.well-known/jwks.json
|
||||
MCP_JWT_ISSUER=https://auth.company.com/
|
||||
MCP_JWT_AUDIENCE=superset-mcp-api
|
||||
MCP_REQUIRED_SCOPES=dashboard:read,chart:read
|
||||
```
|
||||
|
||||
### Security Features
|
||||
|
||||
**JWT Authentication**: RS256 tokens validated against JWKS or public key
|
||||
|
||||
**User Context**: JWT claims mapped to Superset users for proper permissions
|
||||
|
||||
**Scope-Based Authorization**:
|
||||
| Tool | Required Scope |
|
||||
|------|----------------|
|
||||
| `list_dashboards`, `get_dashboard_info` | `dashboard:read` |
|
||||
| `list_charts`, `get_chart_info` | `chart:read` |
|
||||
| `generate_chart` | `chart:write` |
|
||||
| `list_datasets`, `get_dataset_info` | `dataset:read` |
|
||||
| `get_superset_instance_info` | `instance:read` |
|
||||
|
||||
**MCP Audit Logging**: All operations logged with MCP-specific context including impersonation tracking, source identification, and sanitized payloads
|
||||
|
||||
**Flexible User Resolution**: Configurable JWT claim extraction
|
||||
|
||||
### For Testing & Development
|
||||
|
||||
Generate test credentials using FastMCP's built-in utilities:
|
||||
|
||||
```python
|
||||
from fastmcp.server.auth.providers.bearer import RSAKeyPair
|
||||
|
||||
# Generate test keypair
|
||||
keypair = RSAKeyPair.generate()
|
||||
print("Public key:", keypair.public_key)
|
||||
|
||||
# Create test token
|
||||
token = keypair.create_token(
|
||||
subject="john.doe",
|
||||
issuer="https://test.example.com",
|
||||
audience="superset-mcp-api",
|
||||
scopes=["dashboard:read", "chart:read", "dataset:read"]
|
||||
)
|
||||
print("Test token:", token)
|
||||
```
|
||||
|
||||
### Integration with Identity Providers
|
||||
|
||||
This authentication works with any JWT-compatible identity provider:
|
||||
- **Auth0**: Use your tenant's JWKS URL
|
||||
- **Okta**: Configure with your Okta domain JWKS endpoint
|
||||
- **AWS Cognito**: Use your user pool's JWKS URL
|
||||
- **Azure AD**: Configure with Microsoft identity platform
|
||||
- **Custom JWT**: Use your own public key for validation
|
||||
|
||||
The MCP service extracts user identity from standard JWT claims and doesn't require complex integration - just valid JWT tokens with appropriate scopes.
|
||||
|
||||
## MCP Audit Logging
|
||||
|
||||
The MCP service implements comprehensive audit logging to distinguish MCP requests from regular user requests in audit trails:
|
||||
|
||||
### Required Context Fields
|
||||
- **`log_source`**: Always set to "mcp" to identify MCP requests
|
||||
- **`impersonation`**: Username of the authenticated user making the MCP request
|
||||
- **`mcp_tool`**: Name of the specific MCP tool being executed
|
||||
|
||||
### Optional Enhanced Fields
|
||||
- **`model_info`**: LLM model information from User-Agent header
|
||||
- **`session_info`**: Session tracking from X-Session-ID header
|
||||
- **`whitelisted_payload`**: Sanitized tool parameters (sensitive data redacted)
|
||||
|
||||
### Payload Sanitization
|
||||
- **Sensitive keys redacted**: password, token, secret, key, auth fields
|
||||
- **Large content truncated**: Strings over 1000 characters truncated
|
||||
- **Security-first approach**: Better to over-redact than expose sensitive data
|
||||
|
||||
### Usage
|
||||
All MCP tools automatically include audit context via the `@mcp_auth_hook` decorator:
|
||||
|
||||
```python
|
||||
@mcp.tool
|
||||
@mcp_auth_hook # Automatically adds MCP audit context
|
||||
def my_tool(request: MyRequest) -> Dict[str, Any]:
|
||||
# Tool implementation
|
||||
pass
|
||||
```
|
||||
|
||||
This enables enterprise audit compliance and helps distinguish automated MCP requests from interactive user sessions.
|
||||
|
||||
## Cache Control & Performance
|
||||
|
||||
The MCP service provides comprehensive cache control that leverages Superset's existing cache infrastructure for optimal performance:
|
||||
|
||||
### Superset Cache Layers
|
||||
|
||||
Superset has multiple cache layers that the MCP service leverages:
|
||||
1. **Query Result Cache** - Caches actual query results from customer databases
|
||||
2. **Metadata Cache** - Caches table schemas, column info, etc.
|
||||
3. **Form Data Cache** - Caches chart configurations for explore URLs
|
||||
4. **Dashboard Cache** - Caches rendered dashboard components
|
||||
|
||||
### Cache Control Parameters
|
||||
|
||||
All MCP tools support cache control through request parameters:
|
||||
|
||||
#### Query Cache Control
|
||||
For tools that execute SQL queries (`get_chart_data`, `get_chart_data_cached`, `generate_chart`, `update_chart`):
|
||||
|
||||
```python
|
||||
{
|
||||
"use_cache": true, # Whether to use Superset's cache layers
|
||||
"force_refresh": false, # Force refresh cached data
|
||||
"cache_timeout": 3600 # Override cache timeout for this query (seconds)
|
||||
}
|
||||
```
|
||||
|
||||
#### Metadata Cache Control
|
||||
For tools that fetch metadata (`list_dashboards`, `list_charts`, `list_datasets`, `get_*_info`):
|
||||
|
||||
```python
|
||||
{
|
||||
"use_cache": true, # Whether to use metadata cache
|
||||
"refresh_metadata": false # Force refresh metadata for datasets/tables
|
||||
}
|
||||
```
|
||||
|
||||
#### Form Data Cache Control
|
||||
For tools that work with chart configurations (`generate_explore_link`, `update_chart_preview`):
|
||||
|
||||
```python
|
||||
{
|
||||
"cache_form_data": true # Whether to cache form data configurations
|
||||
}
|
||||
```
|
||||
|
||||
### Cache Status Information
|
||||
|
||||
Tools return detailed cache status to help understand data freshness:
|
||||
|
||||
```python
|
||||
{
|
||||
"cache_status": {
|
||||
"cache_hit": true, # Whether data was served from cache
|
||||
"cache_type": "query", # Type of cache used (query, metadata, form_data)
|
||||
"cache_age_seconds": 300, # Age of cached data in seconds
|
||||
"refreshed": false # Whether cache was refreshed in this request
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Usage Examples
|
||||
|
||||
```python
|
||||
# Get fresh data, bypassing cache
|
||||
get_chart_data({
|
||||
"identifier": 123,
|
||||
"use_cache": false,
|
||||
"force_refresh": true
|
||||
})
|
||||
|
||||
# Use cache but with custom timeout
|
||||
get_chart_data({
|
||||
"identifier": 123,
|
||||
"cache_timeout": 1800, # 30 minutes
|
||||
"use_cache": true
|
||||
})
|
||||
|
||||
# Refresh metadata for datasets
|
||||
list_datasets({
|
||||
"refresh_metadata": true,
|
||||
"use_cache": false
|
||||
})
|
||||
|
||||
# Fast metadata queries from cache
|
||||
list_charts({
|
||||
"use_cache": true,
|
||||
"refresh_metadata": false
|
||||
})
|
||||
```
|
||||
|
||||
### Performance Benefits
|
||||
|
||||
- **Faster Response Times**: Cached queries return instantly without database execution
|
||||
- **Reduced Database Load**: Identical queries hit cache regardless of how they were created (UI vs MCP)
|
||||
- **Smart Cache Keys**: Cache based on query hash, so identical SQL queries share cache entries
|
||||
- **Configurable TTL**: Per-dataset and global cache timeout configuration
|
||||
- **Cache Transparency**: Clear cache status reporting helps users understand data freshness
|
||||
|
||||
### Cache Iteration Support
|
||||
|
||||
Chart iterations can effectively utilize the cache layer:
|
||||
- When you modify a chart through MCP tools, if the underlying SQL query hasn't changed (same metrics, filters, time range), Superset serves from its query result cache
|
||||
- The cache key is based on query hash, so identical queries hit the cache regardless of how they were created (UI vs MCP)
|
||||
- This enables rapid chart iteration and preview generation
|
||||
|
||||
## Configuration & Deployment
|
||||
|
||||
### URL Configuration
|
||||
The MCP service now uses centralized URL configuration for consistency across all tools:
|
||||
|
||||
```python
|
||||
# In superset_config.py
|
||||
SUPERSET_WEBSERVER_ADDRESS = "http://localhost:8088" # Development
|
||||
SUPERSET_WEBSERVER_ADDRESS = "https://superset.company.com" # Production
|
||||
```
|
||||
|
||||
**Key Features:**
|
||||
- **Centralized URL management**: All tools use `get_superset_base_url()` utility
|
||||
- **Environment flexibility**: Fallback to `SUPERSET_WEBSERVER_SCHEME`/`SUPERSET_WEBSERVER_HOST`/`SUPERSET_WEBSERVER_PORT`
|
||||
- **Screenshot service integration**: MCP service serves screenshots on same port as WSGI endpoint
|
||||
- **Configuration hierarchy**: `SUPERSET_WEBSERVER_ADDRESS` → component URLs → localhost:8088 fallback
|
||||
|
||||
### Agent Integration Options
|
||||
1. **Claude Agent SDK**: Create cloud agent connecting to local/deployed MCP service
|
||||
2. **LangChain Toolkit**: Use `langchain-mcp` for chatbot integration
|
||||
3. **Direct MCP Connection**: Connect any MCP-compatible client to service
|
||||
|
||||
## Future Milestones
|
||||
|
||||
### 🔒 Enterprise Security (Future Phase)
|
||||
- **Advanced Security Hooks**: Tool poisoning attack prevention and rate limiting
|
||||
- **Comprehensive Audit Logging**: Enhanced logging and monitoring for enterprise environments
|
||||
- **RBAC Extensions**: Advanced permission models and user role management
|
||||
- **Multi-tenant Support**: Isolated environments for enterprise deployments
|
||||
|
||||
### 📊 Advanced Analytics (Future Phase)
|
||||
- **Smart Cache Management**: Intelligent caching strategies with force refresh capabilities
|
||||
- **Dashboard Creation**: Automated dashboard generation with multiple related charts
|
||||
- **Advanced Chart Types**: Support for complex visualizations (maps, 3D, custom viz)
|
||||
- **Business Intelligence**: Natural language to SQL query generation
|
||||
- **End-to-End Testing**: Synthetic environments with example database integration
|
||||
|
||||
## Recent Major Improvements
|
||||
|
||||
### 🔧 **BaseDAO Type Safety & Performance**
|
||||
- **UUID Type Conversion**: Centralized `_convert_value_for_column()` method for type-safe UUID handling
|
||||
- **Flexible Column Support**: Enhanced `find_by_id()` and `find_by_ids()` with customizable column lookups
|
||||
- **Test Coverage**: 185+ passing unit tests including edge cases and error scenarios
|
||||
- **Code Quality**: Eliminated code duplication and hardcoded string checks
|
||||
|
||||
### 🚀 **MCP Service Consistency & Reliability**
|
||||
- **Async Pattern Cleanup**: Removed unnecessary async declarations for better performance
|
||||
- **SQL Lab Integration**: Fixed parameter naming (`dbid`) for proper frontend compatibility
|
||||
- **Error Handling**: Robust UUID conversion with graceful fallbacks for malformed data
|
||||
- **Type Validation**: Enhanced SQLAlchemy column type inspection for safer operations
|
||||
|
||||
### 🆕 **New Dashboard & SQL Lab Tools**
|
||||
- **Dashboard Generation**: Dashboard creation with automatic chart layout
|
||||
- **Chart Management**: Add charts to existing dashboards with intelligent positioning
|
||||
- **SQL Lab Context**: Pre-configured SQL Lab sessions with database/schema selection
|
||||
- **Preview System**: Chart screenshots, ASCII art, and data extraction capabilities
|
||||
|
||||
### 📊 **Enhanced Chart & Data Capabilities**
|
||||
- **Multi-format Data Export**: JSON, CSV, Excel export with pagination support
|
||||
- **Preview Generation**: URL screenshots, base64 images, ASCII charts, and table data
|
||||
- **Smart Layout**: Automatic 2-column dashboard layouts with optimized positioning
|
||||
- **Context Preservation**: Seamless navigation between Superset interfaces
|
||||
|
||||
### 🔒 **Production-Ready Architecture**
|
||||
- **Configurable Auth Factory**: Enterprise JWT authentication following Superset patterns
|
||||
- **Request Schema Pattern**: Structured inputs eliminating LLM parameter validation issues
|
||||
- **Multi-identifier Support**: ID, UUID, and slug lookups across all tools with type safety
|
||||
- **Professional Testing**: Integration tests, mocking patterns, and edge case coverage
|
||||
- **MCP Audit Logging**: Comprehensive audit trails with MCP context, payload sanitization, and impersonation tracking
|
||||
superset/mcp_service/README_PHASE1_STATUS.md (new file, 175 lines)
@@ -0,0 +1,175 @@
|
||||
# Superset MCP Service – Phase 1 Status Update
|
||||
|
||||
## Background
|
||||
The Model Context Protocol (MCP) is a new protocol for exposing high-level, structured actions in Superset, designed for LLM agents and automation. Phase 1 delivers a foundational, extensible MCP service in Superset, leveraging internal APIs (DAOs/commands) and providing a versioned, developer-friendly interface for both Apache and Preset use cases. ([SIP-171](https://github.com/apache/superset/issues/33870))
|
||||
|
||||
## Phase 1 Objectives (from SoW/SIP-171)
|
||||
- Standalone MCP service, config flag, CLI, modular, stateless
|
||||
- Strong typing: all actions use DAOs/commands and Pydantic schemas
|
||||
- Clear extension points for Preset-specific auth, RBAC, and logging
|
||||
- 3+ high-value MCP actions (list, info, mutation)
|
||||
- Developer experience: easy to run, clear docs, tests
|
||||
- Auth/RBAC/logging hooks stubbed, ready for enterprise
|
||||
- Out of scope: full RBAC, impersonation, logging, external identity provider integration
|
||||
|
||||
## What's Delivered (Phase 1)
|
||||
|
||||
### ✅ Completed Epics
|
||||
| Epic ID | Name | Status | Key Deliverables |
|
||||
|---------|------|--------|-----------------|
|
||||
| 90298 | **Implement Standalone MCP Service CLI** | ✅ Complete | ASGI-based FastMCP server, config flag, CLI (`superset mcp run`) |
|
||||
| 90301 | **Add Auth/RBAC Hooks** | ✅ Complete | JWT Bearer authentication, configurable factory pattern, scope-based authorization |
|
||||
|
||||
### ✅ Recently Completed
|
||||
| Epic ID | Name | Status | Progress |
|
||||
|---------|------|--------|---------|
|
||||
| 90300 | **Implement list/info tools for dataset, dashboard, chart** | ✅ Completed | All tools with multi-identifier support, enhanced search/filtering |
|
||||
| 90299 | **Define Modular, Typed Schemas** | ✅ Completed | Pydantic v2 schemas, FastMCP Complex Inputs Pattern |
|
||||
| 90302 | **Write Dev Guide and Docs** | 🔧 QA | Comprehensive documentation integrated into Superset Docusaurus |
|
||||
| 90304 | **Implement Chart Creation Mutation** | 🔧 In Review | Chart creation, dashboard generation, update operations |
|
||||
| 90305 | **Implement Navigation Actions** | 🔧 In Review | `generate_explore_link` and `open_sql_lab_with_context` |
|
||||
| 90303 | **Document Preset Extension Points** | 🔧 In Review | RBAC, OIDC integration design for enterprise |
|
||||
| 90511 | **Backend Chart Rendering** | 🔧 QA | Chart data/preview with screenshots, ASCII, table formats |
|
||||
| 90509 | **Support for Bearer Authentication** | 🔧 QA | JWT Bearer authentication with configurable factory |
|
||||
| 90510 | **Caching and Refresh** | 🔧 QA | Cache control parameters leveraging Superset infrastructure |
|
||||
| 90548 | **Audit Logging** | 🔧 In Review | MCP context tracking with impersonation support |
|
||||
|
||||
### 🔧 Technical Achievements
|
||||
- **Service Infrastructure**: ASGI-based FastMCP server, stateless design, professional CLI
|
||||
- **Production Auth**: JWT Bearer authentication with configurable factory pattern (per @dpgaspar's design)
|
||||
- **Code Quality**: 149 passing unit tests, full pre-commit compliance, professional error handling
|
||||
- **Strong Typing**: All input/output uses Pydantic v2 with detailed field descriptions
|
||||
- **Modular Architecture**: Domain-grouped tools (`dashboard/`, `dataset/`, `chart/`, `system/`)
|
||||
- **Request Schema Pattern**: Eliminates LLM parameter validation issues with structured requests
|
||||
- **Multi-Identifier Support**: ID/UUID/slug lookups across all get_*_info tools
|
||||
- **Enhanced Search**: UUID/slug fields included in search and default response columns
|
||||
- **Cache Control**: Comprehensive cache control parameters across all tools leveraging Superset's existing cache layers
|
||||
|
||||
### 🛠️ Core Tools Implemented (18 Total)
|
||||
- **Dashboard Tools**: `list_dashboards`, `get_dashboard_info`, `get_dashboard_available_filters`, `generate_dashboard`, `add_chart_to_existing_dashboard`
|
||||
- **Chart Tools**: `list_charts`, `get_chart_info`, `get_chart_available_filters`, `generate_chart`, `update_chart`, `update_chart_preview`, `get_chart_data`, `get_chart_preview`
|
||||
- **Dataset Tools**: `list_datasets`, `get_dataset_info`, `get_dataset_available_filters`
|
||||
- **System Tools**: `get_superset_instance_info`, `generate_explore_link`
|
||||
- **SQL Lab Tools**: `open_sql_lab_with_context`
|
||||
|
||||
## Phase 1 Completion Status
|
||||
|
||||
**Overall Progress: 95% Complete** (All core epics complete, finalization tasks remaining)
|
||||
|
||||
**Phase 1 Status**: Core features complete, demo and testing needed for finalization
|
||||
|
||||
### ✅ Recent Technical Completions
|
||||
- **BaseDAO Type Safety**: Enhanced UUID handling with extensive test coverage ✅
|
||||
- **URL Configuration**: `SUPERSET_WEBSERVER_ADDRESS` support with centralized URL management ✅
|
||||
- **MCP Audit Logging**: Comprehensive audit trails with impersonation tracking and payload sanitization ✅
|
||||
- **Chart Update Operations**: `update_chart` and `update_chart_preview` for modifying saved and cached charts ✅
|
||||
- **Schema Optimization**: Optional fields, minimal columns, null value handling ✅
|
||||
- **Chart Embedding**: Screenshot URLs and backend rendering with Firefox WebDriver for LLM chat integration ✅
|
||||
- **SQL Lab Integration**: Pre-configured SQL Lab sessions with database/schema selection ✅
|
||||
- **Dashboard Management**: Dashboard creation and chart addition capabilities ✅
|
||||
- **Cache Control Implementation**: Integrated cache control parameters across all tools with schema inheritance pattern ✅
|
||||
- Query cache control for chart data and generation tools
|
||||
- Metadata cache control for list and get_info tools
|
||||
- Form data cache control for explore link and preview tools
|
||||
- Cache status reporting in tool responses
|
||||
|
||||
### 🎯 Phase 1 Finalization Remaining
|
||||
| Epic ID | Task | Status | Description |
|
||||
|---------|------|--------|-------------|
|
||||
| 90306 | **Create Demo Script/Notebook** | 📋 Procurement | Interactive demo showing bot capabilities |
|
||||
| 90527 | **End-to-End Prompt Testing** | 🔧 In Development | At least one complete LLM agent workflow test |
|
||||
|
||||
### 🚫 Out of Scope Items
|
||||
| Epic ID | Name | Status | Reason |
|
||||
|---------|------|--------|--------|
|
||||
| 90508 | **LLM/Chat Friendly Backend Rendered Charts** | 🔧 QA | Vega-Lite/Plotly JSON for enhanced LLM integration |
|
||||
| 90398 | **Security Hooks for Tool Poisoning Attacks** | 📋 Procurement | Advanced security feature for future phase |
|
||||
| 90397 | **In-Preset Hosted Demo (OAuth, impersonation)** | 📋 Procurement | Cloud deployment with proper authentication |
|
||||
|
||||
|
||||
## Phase 1 Finalization Tasks
|
||||
|
||||
**Remaining work to complete Phase 1:**
|
||||
|
||||
1. **Demo Video/Script** - Create comprehensive demonstration
|
||||
- Video walkthrough of all 18 MCP tools working end-to-end
|
||||
- Claude Desktop integration examples
|
||||
- Complete workflow from data exploration to chart creation
|
||||
|
||||
2. **End-to-End Prompt Test** - Validate complete LLM workflow
|
||||
- At least one complete multi-step agent interaction
|
||||
- Test real-world use case: "Create a sales dashboard with 3 charts"
|
||||
- Verify all tools work together seamlessly
|
||||
|
||||
## Team Meeting Notes - Future Considerations
|
||||
|
||||
### LangChain Integration Ideas
|
||||
1. **Create a chat bot with LangChain**
|
||||
2. **Tool Discovery**: When user chats, append/prepend message saying "hey you have these tools you can use"
|
||||
3. **Tool Mapping**: Map tools to what people want to do
|
||||
4. **Diego's Note**: This works well for lots of tools when we don't know which one to use, but for our cases we might get away without the mapping. Later we can use custom prompts to figure out exact tools
|
||||
|
||||
### Max's Priority Areas
|
||||
1. **Reference**: https://context7.com/
|
||||
2. **Playwright MCP** - Could it be leveraged for chart generation?
|
||||
3. **Easy Setup** - Making it easy for anyone to pull branch and get going
|
||||
4. **Focus on Quality over Coverage** - Instead of coverage, focus on getting the tools we have already right
|
||||
- **Communication Layer**:
|
||||
- Error handling improvements
|
||||
- Ensure LLM gives proper JSON/object format
|
||||
- Return clear/direct messages: "hey you can't pass it with quotes you need to pass it this way"
|
||||
5. **Next 20 Tools** - Define the semantics and schemas for these tools
|
||||
6. **GitHub Codespaces** - Uses docker compose lite.yaml
|
||||
7. **Next Step**: Build UI chat in Superset
|
||||
|
||||
### Diego's Agentic System Questions
|
||||
1. Right now we just get Claude to basically figure out what to call for us
|
||||
2. Do you think having a proper agentic system that does multiple passes on the user input and is specialized would help?
|
||||
3. Do you think it's an MCP service level thing?
|
||||
4. Would this be a middle layer? (let's work together on how this would work out)
|
||||
5. Make sure README is updated
|
||||
|
||||
## Future Development (Post-Phase 1)
|
||||
|
||||
### Enterprise & Security
|
||||
- **Advanced Security Hooks**: Tool poisoning prevention, rate limiting
|
||||
- **Enhanced RBAC**: Advanced permission models, multi-tenant support
|
||||
- **Audit**: Enterprise logging and monitoring
|
||||
|
||||
### Advanced Features
|
||||
- **Dashboard Creation**: Multi-chart dashboard generation
|
||||
- **Advanced Chart Types**: Maps, 3D visualizations, custom components
|
||||
- **Business Intelligence**: Natural language to SQL query generation
|
||||
|
||||
## Summary Table
|
||||
| Epic/Deliverable | Epic ID | Status | Completion |
|
||||
|------------------|---------|--------|-----------|
|
||||
| **Standalone MCP Service CLI** | 90298 | ✅ Complete | 100% |
|
||||
| **Add Auth/RBAC Hooks** | 90301 | ✅ Complete | 100% |
|
||||
| **List/Info Tools** | 90300 | 🟡 In QA | 95% |
|
||||
| **Define Modular Schemas** | 90299 | 🟡 In Review | 90% |
|
||||
| **Write Dev Guide and Docs** | 90302 | 🟡 In Review | 90% |
|
||||
| **Chart Creation Mutation** | 90304 | 🟡 In Review | 85% |
|
||||
| **Navigation Actions** | 90305 | 🟡 In Review | 75% |
|
||||
| **Document Preset Extensions** | 90303 | 🟡 In Review | 80% |
|
||||
| **Backend Chart Rendering** | 90511 | 🔧 In Development | 20% |
|
||||
| **Bearer Authentication** | 90509 | 🔧 In Development | 60% |
|
||||
| **Demo Script/Notebook** | 90306 | 📋 Stretch Goal | 0% |
|
||||
| **In-Preset OAuth Demo** | 90397 | 📋 Stretch Goal | 0% |
|
||||
| **LLM-Friendly Rendering** | 90508 | 📋 Stretch Goal | 0% |
|
||||
| **Security Hooks** | 90398 | 🚫 Out of Scope | 0% |
|
||||
|
||||
**Phase 1 Core: 95% Complete** | **Stretch Goals: Available for additional polish**
|
||||
|
||||
## Key Metrics
|
||||
- **194+ Unit Tests**: All passing with extensive coverage including URL utils and audit logging
|
||||
- **18 Core Tools**: List/read/update operations for all entities, chart creation/updates, dashboard generation, SQL Lab integration
|
||||
- **Production Auth**: JWT Bearer with configurable factory pattern and MCP audit logging
|
||||
- **Zero Breaking Changes**: Stable API ready for Phase 2 enhancements
|
||||
- **Developer Experience**: Single command setup, detailed docs, clear extension points
|
||||
- **Type Safety**: Enhanced BaseDAO with UUID handling and robust error handling
|
||||
- **Enterprise Audit**: MCP-specific audit logging with impersonation tracking and payload sanitization
|
||||
|
||||
## Reference
|
||||
- [SIP-171: MCP Service Proposal](https://github.com/apache/superset/issues/33870)
|
||||
- [Epic Tracking CSV](project-epic-status.csv) - Updated July 28, 2025
|
||||
444
superset/mcp_service/README_SCHEMAS.md
Normal file
444
superset/mcp_service/README_SCHEMAS.md
Normal file
@@ -0,0 +1,444 @@
|
||||
# Superset MCP Service: Tool Schemas Reference
|
||||
|
||||
This document provides a reference for the input and output schemas of all MCP tools in the Superset MCP service. All schemas are Pydantic v2 models with field descriptions for LLM/OpenAPI compatibility.
|
||||
|
||||
**Status**: Phase 1 nearly complete (95% done). All core schemas stable and production-ready with extensive testing coverage.
|
||||
|
||||
## FastMCP Complex Inputs Pattern
|
||||
|
||||
All MCP tools use **structured request objects** instead of individual parameters to eliminate LLM validation issues:
|
||||
|
||||
```python
|
||||
# All list tools use request objects
|
||||
list_dashboards(request=ListDashboardsRequest(...))
|
||||
list_datasets(request=ListDatasetsRequest(...))
|
||||
list_charts(request=ListChartsRequest(...))
|
||||
|
||||
# All get_info tools use request objects with multi-identifier support
|
||||
get_dashboard_info(request=GetDashboardInfoRequest(identifier="123")) # ID
|
||||
get_dashboard_info(request=GetDashboardInfoRequest(identifier="uuid-string")) # UUID
|
||||
get_dashboard_info(request=GetDashboardInfoRequest(identifier="slug-string")) # Slug
|
||||
|
||||
# Chart creation with detailed config
|
||||
generate_chart(request=GenerateChartRequest(
|
||||
dataset_id="1",
|
||||
config=XYChartConfig(
|
||||
chart_type="xy",
|
||||
x=ColumnRef(name="date"),
|
||||
y=[ColumnRef(name="sales", aggregate="SUM")],
|
||||
kind="line"
|
||||
)
|
||||
))
|
||||
```
|
||||
|
||||
### Key Benefits
|
||||
- **No parameter ambiguity**: Filters are always arrays, never strings
|
||||
- **Clear validation**: Cannot use both search and filters simultaneously
|
||||
- **Multi-identifier support**: ID, UUID, and slug (where applicable) in single interface
|
||||
- **LLM-friendly**: Unambiguous types prevent common LLM validation errors
|
||||
- **Production-ready**: 194+ unit tests ensure schema reliability
|
||||
|
||||
## Cache Control Schemas
|
||||
|
||||
All MCP tools support cache control through schema inheritance:
|
||||
|
||||
### CacheControlMixin
|
||||
Base mixin for all cache control:
|
||||
- `use_cache`: `bool = True` — Whether to use Superset's cache layers
|
||||
- `force_refresh`: `bool = False` — Whether to force refresh cached data
|
||||
|
||||
### QueryCacheControl
|
||||
For tools that execute SQL queries (`get_chart_data`, `generate_chart`, `update_chart`):
|
||||
- Inherits: `CacheControlMixin`
|
||||
- `cache_timeout`: `Optional[int]` — Override cache timeout for this query (seconds)
|
||||
|
||||
### MetadataCacheControl
|
||||
For tools that fetch metadata (`list_*`, `get_*_info` tools):
|
||||
- Inherits: `CacheControlMixin`
|
||||
- `refresh_metadata`: `bool = False` — Force refresh metadata from database
|
||||
|
||||
### FormDataCacheControl
|
||||
For tools working with chart configurations (`generate_explore_link`, `update_chart_preview`):
|
||||
- Inherits: `CacheControlMixin`
|
||||
- `cache_form_data`: `bool = True` — Whether to cache form data configurations
|
||||
|
||||
### CacheStatus
|
||||
Returned by tools to indicate cache usage:
|
||||
- `cache_hit`: `bool` — Whether data was served from cache
|
||||
- `cache_type`: `Literal["query", "metadata", "form_data", "none"]` — Type of cache used
|
||||
- `cache_age_seconds`: `Optional[int]` — Age of cached data in seconds
|
||||
- `refreshed`: `bool` — Whether cache was refreshed in this request
|
||||
|
||||
## Dashboards
|
||||
|
||||
### list_dashboards
|
||||
|
||||
**Input:** `ListDashboardsRequest` (inherits `MetadataCacheControl`)
|
||||
- `filters`: `List[DashboardFilter]` — List of filter objects (cannot be used with search)
|
||||
- `search`: `Optional[str]` — Free-text search string (cannot be used with filters)
|
||||
- `select_columns`: `List[str]` — Columns to select (defaults include id, dashboard_title, slug, uuid)
|
||||
- `order_column`: `Optional[str]` — Column to order results by (valid: id, dashboard_title, slug, published, changed_on, created_on)
|
||||
- `order_direction`: `Optional[Literal['asc', 'desc']]` — Order direction
|
||||
- `page`: `int` — Page number (0-based)
|
||||
- `page_size`: `int` — Number of items per page (default 100)
|
||||
- `use_cache`: `bool = True` — Whether to use metadata cache
|
||||
- `refresh_metadata`: `bool = False` — Force refresh metadata from database
|
||||
|
||||
**Returns:** `DashboardList`
|
||||
- `dashboards`: `List[DashboardListItem]`
|
||||
- `count`: `int`
|
||||
- `total_count`: `int`
|
||||
- `page`: `int`
|
||||
- `page_size`: `int`
|
||||
- `total_pages`: `int`
|
||||
- `has_previous`: `bool`
|
||||
- `has_next`: `bool`
|
||||
- `columns_requested`: `List[str]`
|
||||
- `columns_loaded`: `List[str]`
|
||||
- `filters_applied`: `List[Any]`
|
||||
- `pagination`: `PaginationInfo`
|
||||
- `timestamp`: `datetime`
|
||||
|
||||
### get_dashboard_info
|
||||
|
||||
**Input:** `GetDashboardInfoRequest`
|
||||
- `identifier`: `Union[int, str]` — Dashboard identifier (supports ID, UUID, or slug)
|
||||
|
||||
**Returns:** `DashboardInfo` or `DashboardError`
|
||||
|
||||
**Multi-Identifier Support:**
|
||||
- **ID**: Numeric dashboard ID (e.g., `123`)
|
||||
- **UUID**: Dashboard UUID string (e.g., `"a1b2c3d4-e5f6-7890-abcd-ef1234567890"`)
|
||||
- **Slug**: Dashboard slug string (e.g., `"sales-dashboard"`)
|
||||
|
||||
### get_dashboard_available_filters
|
||||
|
||||
**Input:** `GetDashboardAvailableFiltersRequest` (API consistency)
|
||||
- No parameters required (empty request object for consistent API design)
|
||||
|
||||
**Returns:** `DashboardAvailableFilters`
|
||||
- `column_operators`: `Dict[str, Any]` — Available filter operators and metadata for each column
|
||||
|
||||
## Datasets
|
||||
|
||||
### list_datasets
|
||||
|
||||
**Input:** `ListDatasetsRequest` (inherits `MetadataCacheControl`)
|
||||
- `filters`: `List[DatasetFilter]` — List of filter objects (cannot be used with search)
|
||||
- `search`: `Optional[str]` — Free-text search string (cannot be used with filters)
|
||||
- `select_columns`: `List[str]` — Columns to select (defaults include id, table_name, uuid)
|
||||
- `order_column`: `Optional[str]` — Column to order results by (valid: id, table_name, schema, changed_on, created_on)
|
||||
- `order_direction`: `Optional[Literal['asc', 'desc']]` — Order direction
|
||||
- `page`: `int` — Page number (0-based)
|
||||
- `page_size`: `int` — Number of items per page (default 100)
|
||||
- `use_cache`: `bool = True` — Whether to use metadata cache
|
||||
- `refresh_metadata`: `bool = False` — Force refresh metadata from database
|
||||
|
||||
**Returns:** `DatasetList`
|
||||
- `datasets`: `List[DatasetListItem]` (each includes columns and metrics)
|
||||
- `count`: `int`
|
||||
- `total_count`: `int`
|
||||
- `page`: `int`
|
||||
- `page_size`: `int`
|
||||
- `total_pages`: `int`
|
||||
- `has_previous`: `bool`
|
||||
- `has_next`: `bool`
|
||||
- `columns_requested`: `List[str]`
|
||||
- `columns_loaded`: `List[str]`
|
||||
- `filters_applied`: `List[Any]`
|
||||
- `pagination`: `PaginationInfo`
|
||||
- `timestamp`: `datetime`
|
||||
|
||||
### get_dataset_info
|
||||
|
||||
**Input:** `GetDatasetInfoRequest`
|
||||
- `identifier`: `Union[int, str]` — Dataset identifier (supports ID or UUID)
|
||||
|
||||
**Returns:** `DatasetInfo` or `DatasetError` (now includes columns and metrics)
|
||||
|
||||
**Multi-Identifier Support:**
|
||||
- **ID**: Numeric dataset ID (e.g., `123`)
|
||||
- **UUID**: Dataset UUID string (e.g., `"a1b2c3d4-e5f6-7890-abcd-ef1234567890"`)
|
||||
|
||||
#### DatasetInfo fields (new):
|
||||
- `columns`: `List[TableColumnInfo]` — List of columns with name, type, verbose name, etc.
|
||||
- `metrics`: `List[SqlMetricInfo]` — List of metrics with name, expression, verbose name, etc.
|
||||
|
||||
#### TableColumnInfo
|
||||
- `column_name`: `str` — Column name
|
||||
- `verbose_name`: `Optional[str]` — Verbose name
|
||||
- `type`: `Optional[str]` — Column type
|
||||
- `is_dttm`: `Optional[bool]` — Is datetime column
|
||||
- `groupby`: `Optional[bool]` — Is groupable
|
||||
- `filterable`: `Optional[bool]` — Is filterable
|
||||
- `description`: `Optional[str]` — Column description
|
||||
|
||||
#### SqlMetricInfo
|
||||
- `metric_name`: `str` — Metric name
|
||||
- `verbose_name`: `Optional[str]` — Verbose name
|
||||
- `expression`: `Optional[str]` — SQL expression
|
||||
- `description`: `Optional[str]` — Metric description
|
||||
|
||||
> **Note:** All dataset list/info responses now include full column and metric metadata for each dataset.
|
||||
|
||||
### get_dataset_available_filters
|
||||
|
||||
**Input:** `GetDatasetAvailableFiltersRequest` (API consistency)
|
||||
- No parameters required (empty request object for consistent API design)
|
||||
|
||||
**Returns:** `DatasetAvailableFilters`
|
||||
- `column_operators`: `Dict[str, Any]` — Available filter operators and metadata for each column
|
||||
|
||||
## Charts
|
||||
|
||||
### list_charts
|
||||
|
||||
**Input:** `ListChartsRequest` (inherits `MetadataCacheControl`)
|
||||
- `filters`: `List[ChartFilter]` — List of filter objects (cannot be used with search)
|
||||
- `search`: `Optional[str]` — Free-text search string (cannot be used with filters)
|
||||
- `select_columns`: `List[str]` — Columns to select (defaults include id, slice_name, uuid)
|
||||
- `order_column`: `Optional[str]` — Column to order results by (valid: id, slice_name, viz_type, datasource_name, description, changed_on, created_on)
|
||||
- `order_direction`: `Optional[Literal['asc', 'desc']]` — Order direction
|
||||
- `page`: `int` — Page number (0-based)
|
||||
- `page_size`: `int` — Number of items per page (default 100)
|
||||
- `use_cache`: `bool = True` — Whether to use metadata cache
|
||||
- `refresh_metadata`: `bool = False` — Force refresh metadata from database
|
||||
|
||||
**Returns:** `ChartList`
|
||||
- `charts`: `List[ChartListItem]`
|
||||
- `count`: `int`
|
||||
- `total_count`: `int`
|
||||
- `page`: `int`
|
||||
- `page_size`: `int`
|
||||
- `total_pages`: `int`
|
||||
- `has_previous`: `bool`
|
||||
- `has_next`: `bool`
|
||||
- `columns_requested`: `List[str]`
|
||||
- `columns_loaded`: `List[str]`
|
||||
- `filters_applied`: `List[Any]`
|
||||
- `pagination`: `PaginationInfo`
|
||||
- `timestamp`: `datetime`
|
||||
|
||||
### get_chart_info
|
||||
|
||||
**Input:** `GetChartInfoRequest`
|
||||
- `identifier`: `Union[int, str]` — Chart identifier (supports ID or UUID)
|
||||
|
||||
**Returns:** `ChartInfo` or `ChartError`
|
||||
|
||||
**Multi-Identifier Support:**
|
||||
- **ID**: Numeric chart ID (e.g., `123`)
|
||||
- **UUID**: Chart UUID string (e.g., `"a1b2c3d4-e5f6-7890-abcd-ef1234567890"`)
|
||||
|
||||
### get_chart_available_filters
|
||||
|
||||
**Input:** `GetChartAvailableFiltersRequest` (API consistency)
|
||||
- No parameters required (empty request object for consistent API design)
|
||||
|
||||
**Returns:** `ChartAvailableFiltersResponse`
|
||||
- `column_operators`: `Dict[str, Any]` — Available filter operators and metadata for each column
|
||||
|
||||
### generate_chart
|
||||
|
||||
**Input:** `GenerateChartRequest`
|
||||
- `dataset_id`: `str` — ID of the dataset to use
|
||||
- `config`: `ChartConfig` — Chart configuration (supports table and XY charts)
|
||||
|
||||
**Returns:** `Dict[str, Any]`
|
||||
- `chart`: `Optional[Dict]` — The created chart info with id, slice_name, viz_type, and url
|
||||
- `error`: `Optional[str]` — Error message, if creation failed
|
||||
|
||||
#### ChartConfig (Union of TableChartConfig and XYChartConfig)
|
||||
|
||||
#### TableChartConfig
|
||||
- `chart_type`: `Literal["table"]` — Chart type
|
||||
- `columns`: `List[ColumnRef]` — Columns to display
|
||||
- `filters`: `Optional[List[FilterConfig]]` — Filters to apply
|
||||
- `sort_by`: `Optional[List[str]]` — Columns to sort by
|
||||
|
||||
#### XYChartConfig
|
||||
- `chart_type`: `Literal["xy"]` — Chart type
|
||||
- `x`: `ColumnRef` — X-axis column
|
||||
- `y`: `List[ColumnRef]` — Y-axis columns
|
||||
- `kind`: `Literal["line", "bar", "area", "scatter"]` — Chart visualization type
|
||||
- `group_by`: `Optional[ColumnRef]` — Column to group by
|
||||
- `x_axis`: `Optional[AxisConfig]` — X-axis configuration
|
||||
- `y_axis`: `Optional[AxisConfig]` — Y-axis configuration
|
||||
- `legend`: `Optional[LegendConfig]` — Legend configuration
|
||||
- `filters`: `Optional[List[FilterConfig]]` — Filters to apply
|
||||
|
||||
#### ColumnRef
|
||||
- `name`: `str` — Column name
|
||||
- `label`: `Optional[str]` — Display label for the column
|
||||
- `dtype`: `Optional[str]` — Data type hint
|
||||
- `aggregate`: `Optional[str]` — SQL aggregation function (SUM, COUNT, AVG, MIN, MAX, etc.)
|
||||
|
||||
#### AxisConfig
|
||||
- `title`: `Optional[str]` — Axis title
|
||||
- `scale`: `Optional[Literal["linear", "log"]]` — Axis scale type
|
||||
- `format`: `Optional[str]` — Format string (e.g. '$,.2f')
|
||||
|
||||
#### LegendConfig
|
||||
- `show`: `bool` — Whether to show legend
|
||||
- `position`: `Optional[Literal["top", "bottom", "left", "right"]]` — Legend position
|
||||
|
||||
#### FilterConfig
|
||||
- `column`: `str` — Column to filter on
|
||||
- `op`: `Literal["=", ">", "<", ">=", "<=", "!="]` — Filter operator
|
||||
- `value`: `Union[str, int, float, bool]` — Filter value
|
||||
|
||||
#### Supported Chart Types
|
||||
- **Table charts** (`table`) — Simple column display with filters and sorting
|
||||
- **Line charts** (`echarts_timeseries_line`) — Time series line charts
|
||||
- **Bar charts** (`echarts_timeseries_bar`) — Time series bar charts
|
||||
- **Area charts** (`echarts_area`) — Time series area charts
|
||||
- **Scatter charts** (`echarts_timeseries_scatter`) — Time series scatter charts
|
||||
|
||||
#### Metric Handling
|
||||
The tool intelligently handles two metric formats:
|
||||
1. **Simple metrics** (like `["count"]`) — Passed as simple strings
|
||||
2. **Complex metrics** (like column names) — Converted to full Superset metric objects with SQL aggregators (SUM, COUNT, AVG, MIN, MAX)
|
||||
|
||||
#### Chart Creation Output
|
||||
```python
|
||||
{
|
||||
"chart": {
|
||||
"id": 123,
|
||||
"slice_name": "Sales Over Time",
|
||||
"viz_type": "echarts_timeseries_line",
|
||||
"url": "/explore/?form_data=...",
|
||||
"explore_url": "http://localhost:8088/explore/?form_data=..."
|
||||
},
|
||||
"error": None
|
||||
}
|
||||
```
|
||||
|
||||
## System Tools
|
||||
|
||||
### get_superset_instance_info
|
||||
|
||||
**Input:** `GetSupersetInstanceInfoRequest` (API consistency)
|
||||
- No parameters required (empty request object for consistent API design)
|
||||
|
||||
**Returns:** `SupersetInstanceInfo`
|
||||
- `version`: `str` — Superset version
|
||||
- `build_number`: `Optional[str]` — Build identifier
|
||||
- `instance_id`: `str` — Unique instance identifier
|
||||
- `mcp_service_version`: `str` — MCP service version
|
||||
- `authentication_enabled`: `bool` — Whether JWT authentication is enabled
|
||||
- `available_tools`: `List[str]` — List of available MCP tools
|
||||
- `supported_chart_types`: `List[str]` — Supported chart types for creation
|
||||
|
||||
### generate_explore_link
|
||||
|
||||
**Input:** `GenerateExploreLinkRequest`
|
||||
- `dataset_id`: `str` — Dataset ID to explore
|
||||
- `config`: `ChartConfig` — Chart configuration (same as generate_chart)
|
||||
|
||||
**Returns:** `ExploreLinkResponse`
|
||||
- `explore_url`: `str` — Full URL to Superset explore interface with chart configuration
|
||||
- `form_data`: `Dict[str, Any]` — Serialized form data for the chart
|
||||
|
||||
## Authentication Context
|
||||
|
||||
When authentication is enabled, all tools receive additional context:
|
||||
|
||||
### JWT Authentication
|
||||
- **User Extraction**: JWT claims (subject, client_id, email, username) mapped to Superset users
|
||||
- **Scope Validation**: Each tool validates required scopes before execution
|
||||
- **Audit Logging**: All operations logged with user context and JWT metadata
|
||||
- **Impersonation**: Optional `run_as` parameter for user impersonation (where permitted)
|
||||
|
||||
### Error Responses
|
||||
When authentication fails or permissions are insufficient:
|
||||
```python
|
||||
{
|
||||
"error": "Access denied: user lacks permission for tool_name",
|
||||
"error_type": "PermissionError",
|
||||
"required_scopes": ["chart:read"],
|
||||
"user_scopes": ["dashboard:read"]
|
||||
}
|
||||
```
|
||||
|
||||
## Model Relationships
|
||||
|
||||
```mermaid
|
||||
flowchart TD
|
||||
subgraph Schema Types
|
||||
A["DashboardListItem"]
|
||||
B["DatasetListItem"]
|
||||
C["ChartListItem"]
|
||||
D["UserInfo"]
|
||||
E["TagInfo"]
|
||||
F["RoleInfo"]
|
||||
G["TableColumnInfo"]
|
||||
H["SqlMetricInfo"]
|
||||
I["ChartConfig"]
|
||||
J["GenerateChartRequest"]
|
||||
end
|
||||
A -- owners --> D
|
||||
A -- tags --> E
|
||||
A -- roles --> F
|
||||
B -- owners --> D
|
||||
B -- tags --> E
|
||||
B -- columns --> G
|
||||
B -- metrics --> H
|
||||
C -- owners --> D
|
||||
C -- tags --> E
|
||||
J -- config --> I
|
||||
I -- columns --> G
|
||||
```
|
||||
|
||||
## Request Schema Pattern Benefits
|
||||
|
||||
All tools using the FastMCP Complex Inputs Pattern provide:
|
||||
|
||||
### For List Tools (`list_*`)
|
||||
- **Clear array types**: `filters` is always `List[Filter]`, never a string
|
||||
- **Mutual exclusion**: Cannot use both `search` and `filters` simultaneously
|
||||
- **Default columns**: Include UUID/slug in default responses for better searchability
|
||||
- **Validation messages**: Clear error messages guide LLM usage
|
||||
|
||||
### For Get Info Tools (`get_*_info`)
|
||||
- **Multi-identifier support**: Single interface for ID, UUID, and slug lookup
|
||||
- **Intelligent detection**: Automatically determines identifier type based on format
|
||||
- **Enhanced flexibility**: Works with LLM-generated identifiers of any supported type
|
||||
- **Rich metadata**: Full object details including relationships (columns, metrics, owners)
|
||||
- **Error handling**: Clear error responses when objects not found or access denied
|
||||
|
||||
### ModelListTool and Schema Consistency
|
||||
|
||||
All list tools use the `ModelListTool` abstraction, which enforces:
|
||||
- Consistent parameter order and types via request schemas
|
||||
- Strongly-typed Pydantic input/output models
|
||||
- LLM/OpenAPI-friendly field names
|
||||
- Validation logic preventing parameter conflicts
|
||||
- Enhanced search including UUID/slug fields
|
||||
- Detailed metadata in responses (columns_requested, columns_loaded, etc.)
|
||||
|
||||
## Schema Validation & Testing
|
||||
|
||||
All schemas are thoroughly tested with:
|
||||
- **194+ unit tests** covering all input/output combinations including URL utils and audit logging
|
||||
- **Multi-identifier testing** for all get_*_info tools (ID, UUID, slug)
|
||||
- **Request schema validation** preventing parameter conflicts
|
||||
- **Authentication integration** testing with JWT contexts
|
||||
- **Error response validation** for permission and authentication failures
|
||||
- **Chart creation and update testing** covering all supported chart types and aggregators
|
||||
- **Dashboard generation testing** for workflow validation
|
||||
- **SQL Lab integration testing** with proper parameter handling
|
||||
|
||||
## Future Schema Enhancements
|
||||
|
||||
### Phase 1 Recently Completed ✅
|
||||
- **Backend rendering schemas**: Chart screenshot and image response formats ✅
|
||||
- **SQL Lab schemas**: Context-aware query session parameters ✅
|
||||
- **Dashboard generation schemas**: Complete dashboard creation and chart addition ✅
|
||||
- **Chart data/preview schemas**: Multi-format data export and preview generation ✅
|
||||
- **Enhanced error responses**: More detailed validation and permission error details ✅
|
||||
|
||||
### Future Phases
|
||||
- **Advanced chart types**: Maps, 3D visualizations, custom components
|
||||
- **Vega-Lite/Plotly output**: LLM-friendly chart rendering formats
|
||||
- **Advanced dashboard layouts**: Custom positioning and grid configurations
|
||||
- **Business intelligence schemas**: Natural language to SQL query generation
|
||||
261
superset/mcp_service/TABLE_CHART_GUIDE.md
Normal file
261
superset/mcp_service/TABLE_CHART_GUIDE.md
Normal file
@@ -0,0 +1,261 @@
|
||||
# Table Chart Configuration Guide
|
||||
|
||||
This guide explains how table charts work in the Superset MCP service, including the improved aggregation behavior and formatting.
|
||||
|
||||
## Table Chart Behavior
|
||||
|
||||
### Column Types
|
||||
|
||||
Table charts support two types of columns:
|
||||
|
||||
#### 1. Raw Columns (No Aggregation)
|
||||
```python
|
||||
ColumnRef(name="customer_name") # No aggregate specified
|
||||
ColumnRef(name="order_date") # Raw date values
|
||||
```
|
||||
- Shows individual row values
|
||||
- No grouping applied
|
||||
- Displays data as-is from the dataset
|
||||
|
||||
#### 2. Aggregated Columns (With Aggregation)
|
||||
```python
|
||||
ColumnRef(name="revenue", aggregate="SUM") # Sum of revenue
|
||||
ColumnRef(name="orders", aggregate="COUNT") # Count of orders
|
||||
ColumnRef(name="price", aggregate="AVG") # Average price
|
||||
```
|
||||
- Applies specified aggregation function
|
||||
- Groups data when mixed with raw columns
|
||||
- Supported aggregates: SUM, COUNT, AVG, MIN, MAX, COUNT_DISTINCT
|
||||
|
||||
### Mixed Column Behavior
|
||||
|
||||
When you mix raw and aggregated columns, the table automatically groups by the raw columns:
|
||||
|
||||
#### Example 1: Pure Raw Columns
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="customer_name"),
|
||||
ColumnRef(name="order_date"),
|
||||
ColumnRef(name="product_name")
|
||||
]
|
||||
)
|
||||
```
|
||||
**Result**: Shows individual rows, no grouping
|
||||
```
|
||||
customer_name | order_date | product_name
|
||||
John Smith | 2024-01-15 | Widget A
|
||||
Jane Doe | 2024-01-16 | Widget B
|
||||
John Smith | 2024-01-17 | Widget C
|
||||
```
|
||||
|
||||
#### Example 2: Pure Aggregated Columns
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="revenue", aggregate="SUM"),
|
||||
ColumnRef(name="orders", aggregate="COUNT")
|
||||
]
|
||||
)
|
||||
```
|
||||
**Result**: Single row with aggregated totals
|
||||
```
|
||||
SUM(revenue) | COUNT(orders)
|
||||
45,250.00 | 1,247
|
||||
```
|
||||
|
||||
#### Example 3: Mixed Raw + Aggregated (Recommended)
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="customer_name"), # Raw (becomes GROUP BY)
|
||||
ColumnRef(name="revenue", aggregate="SUM"), # Aggregated
|
||||
ColumnRef(name="orders", aggregate="COUNT") # Aggregated
|
||||
]
|
||||
)
|
||||
```
|
||||
**Result**: Groups by customer_name, aggregates metrics
|
||||
```
|
||||
customer_name | SUM(revenue) | COUNT(orders)
|
||||
John Smith | 15,750.00 | 8
|
||||
Jane Doe | 29,500.00 | 12
|
||||
```
|
||||
|
||||
## Aggregation Functions
|
||||
|
||||
### Supported Aggregates
|
||||
|
||||
| Function | Description | Works With |
|
||||
|----------|-------------|------------|
|
||||
| `SUM` | Sum of values | Numeric columns |
|
||||
| `COUNT` | Count of rows | All column types |
|
||||
| `COUNT_DISTINCT` | Count unique values | All column types |
|
||||
| `AVG` | Average value | Numeric columns |
|
||||
| `MIN` | Minimum value | Numeric, date columns |
|
||||
| `MAX` | Maximum value | Numeric, date columns |
|
||||
|
||||
### Type Compatibility
|
||||
|
||||
The validation system prevents incompatible aggregations:
|
||||
- ✅ `SUM(revenue)` - numeric column
|
||||
- ✅ `COUNT(customer_name)` - text column
|
||||
- ❌ `SUM(customer_name)` - invalid (text column)
|
||||
- ✅ `MIN(order_date)` - date column
|
||||
- ❌ `AVG(customer_name)` - invalid (text column)
|
||||
|
||||
## Improved Table Preview
|
||||
|
||||
### Enhanced Formatting Features
|
||||
|
||||
1. **Dynamic Column Widths**: Columns adjust width based on content
|
||||
2. **Better Number Formatting**:
|
||||
- Thousands separators: `1,234.56`
|
||||
- Scientific notation for large numbers: `1.23e+06`
|
||||
- Proper decimal places for floats
|
||||
3. **More Columns Shown**: Up to 8 columns (was 5)
|
||||
4. **More Rows Shown**: Up to 15 rows (was 10)
|
||||
5. **Smart Truncation**: Uses `..` to indicate truncated content
|
||||
6. **NULL Handling**: Shows `NULL` for null values
|
||||
|
||||
### Example Enhanced Preview
|
||||
```
|
||||
Table Preview
|
||||
================================================================================
|
||||
customer_name | region | SUM(revenue) | COUNT(orders) | AVG(rating)
|
||||
------------------+-----------+--------------+---------------+-------------
|
||||
John Smith | North | 15,750.00 | 8 | 4.25
|
||||
Jane Doe | South | 29,500.00 | 12 | 4.67
|
||||
Mike Johnson | West | 8,900.00 | 5 | 3.80
|
||||
Sarah Wilson | East | 22,100.00 | 9 | 4.44
|
||||
... and 146 more rows
|
||||
... and 3 more columns
|
||||
|
||||
Total: 150 rows × 11 columns
|
||||
```
|
||||
|
||||
## Configuration Examples
|
||||
|
||||
### Basic Customer Report
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="customer_name"),
|
||||
ColumnRef(name="total_orders", aggregate="COUNT"),
|
||||
ColumnRef(name="total_revenue", aggregate="SUM"),
|
||||
ColumnRef(name="avg_order_value", aggregate="AVG")
|
||||
],
|
||||
sort_by=["total_revenue"] # Sort by revenue descending
|
||||
)
|
||||
```
|
||||
|
||||
### Regional Sales Summary
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="region"),
|
||||
ColumnRef(name="sales_rep"),
|
||||
ColumnRef(name="revenue", aggregate="SUM"),
|
||||
ColumnRef(name="deals_closed", aggregate="COUNT")
|
||||
],
|
||||
filters=[
|
||||
FilterConfig(column="region", op="!=", value="Unknown")
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Product Performance Analysis
|
||||
```python
|
||||
TableChartConfig(
|
||||
chart_type="table",
|
||||
columns=[
|
||||
ColumnRef(name="product_category"),
|
||||
ColumnRef(name="product_name"),
|
||||
ColumnRef(name="units_sold", aggregate="SUM"),
|
||||
ColumnRef(name="revenue", aggregate="SUM"),
|
||||
ColumnRef(name="profit_margin", aggregate="AVG")
|
||||
],
|
||||
sort_by=["revenue", "units_sold"]
|
||||
)
|
||||
```
|
||||
|
||||
## Migration from Previous Behavior
|
||||
|
||||
### Before (Problematic)
|
||||
- All columns were forced to have aggregation (defaulted to SUM)
|
||||
- Mixed raw and aggregated behavior was unclear
|
||||
- Headers were truncated to 15 characters
|
||||
- Only 5 columns and 10 rows shown
|
||||
|
||||
### After (Fixed)
|
||||
- Raw columns stay raw, aggregated columns stay aggregated
|
||||
- Clear grouping behavior when mixing column types
|
||||
- Dynamic column widths with smart truncation
|
||||
- Better number formatting and more data shown
|
||||
- Detailed preview with summary statistics
|
||||
|
||||
## Best Practices
|
||||
|
||||
### 1. Choose Column Types Intentionally
|
||||
- Use raw columns for dimensional data (names, categories, dates)
|
||||
- Use aggregated columns for metrics (revenue, counts, averages)
|
||||
|
||||
### 2. Meaningful Grouping
|
||||
```python
|
||||
# Good: Groups customers by region, shows metrics
|
||||
columns=[
|
||||
ColumnRef(name="region"), # GROUP BY
|
||||
ColumnRef(name="revenue", aggregate="SUM")
|
||||
]
|
||||
|
||||
# Bad: Mixing unrelated raw columns
|
||||
columns=[
|
||||
ColumnRef(name="customer_name"), # Will group by this
|
||||
ColumnRef(name="product_name"), # And this (Cartesian product!)
|
||||
ColumnRef(name="revenue", aggregate="SUM")
|
||||
]
|
||||
```
|
||||
|
||||
### 3. Use Appropriate Aggregates
|
||||
```python
|
||||
# Good
|
||||
ColumnRef(name="price", aggregate="AVG") # Average price
|
||||
ColumnRef(name="order_count", aggregate="COUNT") # Count orders
|
||||
|
||||
# Bad
|
||||
ColumnRef(name="customer_name", aggregate="SUM") # Invalid!
|
||||
```
|
||||
|
||||
### 4. Sort by Important Metrics
|
||||
```python
|
||||
TableChartConfig(
|
||||
# ... columns ...
|
||||
sort_by=["revenue", "customer_name"] # Sort by revenue desc, then name asc
|
||||
)
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Issue: Too Many Grouped Rows
|
||||
**Problem**: When mixing many raw columns, you get too many groups
|
||||
**Solution**: Reduce raw columns or use filters to limit data
|
||||
|
||||
### Issue: Unexpected Aggregation
|
||||
**Problem**: Getting aggregated data when you want raw rows
|
||||
**Solution**: Remove `aggregate` parameter from column definitions
|
||||
|
||||
### Issue: Missing Data in Groups
|
||||
**Problem**: Some combinations showing NULL
|
||||
**Solution**: This is normal - not all combinations exist in your data
|
||||
|
||||
### Issue: Performance Problems
|
||||
**Problem**: Table takes too long to load
|
||||
**Solution**: Add filters to reduce data volume, or use fewer grouping columns
|
||||
|
||||
---
|
||||
|
||||
This improved table chart implementation provides clearer behavior, better formatting, and more predictable results for users.
|
||||
247
superset/mcp_service/TABLE_CHART_TEST_PLAN.md
Normal file
247
superset/mcp_service/TABLE_CHART_TEST_PLAN.md
Normal file
@@ -0,0 +1,247 @@
|
||||
# Table Chart Aggregation Fixes - Test Plan (UPDATED)
|
||||
|
||||
This test plan validates the table chart aggregation improvements including **CRITICAL FIXES** for raw columns and numeric type validation. Run these tests to verify the fixes work correctly.
|
||||
|
||||
## Test Overview
|
||||
|
||||
**Purpose**: Verify table chart aggregation behavior and preview formatting improvements
|
||||
**Time**: ~10-15 minutes
|
||||
**Prerequisites**: Superset MCP service running with sample data
|
||||
**Updated**: Includes fixes for raw column empty queries and DOUBLE PRECISION aggregation support
|
||||
|
||||
## Test Cases
|
||||
|
||||
### Test 1: Raw Columns Only (No Aggregation)
|
||||
**Expected**: Show individual rows without grouping
|
||||
|
||||
```python
|
||||
# Generate a table with only raw columns
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "customer_name"},
|
||||
{"name": "order_date"},
|
||||
{"name": "product_name"}
|
||||
]
|
||||
},
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["table"]
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results** (CRITICAL FIX):
|
||||
- Shows individual rows (not aggregated)
|
||||
- All customer names appear (no grouping)
|
||||
- **Table preview contains actual data** (not empty)
|
||||
- **No "Empty query?" errors**
|
||||
- Preview shows multiple rows with different customers
|
||||
- Column headers are not truncated
|
||||
- **Form data includes query_mode="raw" and row_limit**
|
||||
|
||||
### Test 2: Numeric Aggregates (CRITICAL FIX)
|
||||
**Expected**: SUM/AVG work on DOUBLE PRECISION, BIGINT, etc.
|
||||
|
||||
```python
|
||||
# Test SUM on DOUBLE PRECISION (previously failing)
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "sales", "aggregate": "SUM"}, # DOUBLE PRECISION type
|
||||
{"name": "price_each", "aggregate": "AVG"} # DOUBLE PRECISION type
|
||||
]
|
||||
},
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["table"]
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results** (CRITICAL FIX):
|
||||
- **No validation errors** (previously rejected DOUBLE PRECISION)
|
||||
- Shows single row with aggregated totals
|
||||
- Column headers show "SUM(sales)", "AVG(price_each)"
|
||||
- **Numeric types properly recognized**: DOUBLE PRECISION, BIGINT, INTEGER, FLOAT all work
|
||||
- Values are properly aggregated across all data
|
||||
|
||||
### Test 3: Aggregated Columns Only
|
||||
**Expected**: Show single summary row with totals
|
||||
|
||||
```python
|
||||
# Generate a table with only aggregated columns
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "revenue", "aggregate": "SUM"},
|
||||
{"name": "order_id", "aggregate": "COUNT"}
|
||||
]
|
||||
},
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["table"]
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results**:
|
||||
- Shows single row with totals
|
||||
- Column headers show "SUM(revenue)", "COUNT(order_id)"
|
||||
- Values are properly aggregated across all data
|
||||
|
||||
### Test 3b: Mixed Columns (Raw + Aggregated)
|
||||
**Expected**: Group by raw columns, aggregate metrics
|
||||
|
||||
```python
|
||||
# Generate a table mixing raw and aggregated columns
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "customer_name"},
|
||||
{"name": "revenue", "aggregate": "SUM"},
|
||||
{"name": "order_id", "aggregate": "COUNT"}
|
||||
]
|
||||
},
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["table"]
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results**:
|
||||
- One row per customer (grouped by customer_name)
|
||||
- Revenue and order counts aggregated per customer
|
||||
- Multiple customers visible in preview
|
||||
- Clear grouping behavior
|
||||
|
||||
### Test 4: Enhanced Table Preview Formatting
|
||||
**Expected**: Better formatting and more information
|
||||
|
||||
```python
|
||||
# Generate table with various data types to test formatting
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "customer_name"},
|
||||
{"name": "revenue", "aggregate": "SUM"},
|
||||
{"name": "avg_rating", "aggregate": "AVG"}
|
||||
]
|
||||
},
|
||||
"generate_preview": true,
|
||||
"preview_formats": ["table"]
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results**:
|
||||
- Column widths adjust to content (not fixed 15 chars)
|
||||
- Numbers show thousands separators (e.g., "1,234.56")
|
||||
- Large numbers use scientific notation if needed
|
||||
- Table shows "Total: X rows × Y columns" at bottom
|
||||
- No harsh truncation of column names
|
||||
|
||||
### Test 5: Error Validation (Enhanced Error Messages)
|
||||
**Expected**: Helpful error messages with suggestions
|
||||
|
||||
```python
|
||||
# Try invalid column name to test error handling
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "invalid_column_name"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results**:
|
||||
- Clear error message about invalid column
|
||||
- Suggestions for similar column names (fuzzy matching)
|
||||
- List of available columns
|
||||
- Helpful context about the dataset
|
||||
|
||||
### Test 6: Invalid Aggregation (Type Checking)
|
||||
**Expected**: Prevent incompatible aggregations
|
||||
|
||||
```python
|
||||
# Try invalid aggregation (SUM on text column)
|
||||
{
|
||||
"dataset_id": [your_dataset_id],
|
||||
"config": {
|
||||
"chart_type": "table",
|
||||
"columns": [
|
||||
{"name": "customer_name", "aggregate": "SUM"}
|
||||
]
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**✅ Expected Results**:
|
||||
- Error about incompatible aggregation
|
||||
- Suggestion to use COUNT/COUNT_DISTINCT for text columns
|
||||
- Clear explanation of what went wrong
|
||||
|
||||
## Success Criteria
|
||||
|
||||
### ✅ Core Functionality
|
||||
- [ ] Raw columns show individual rows (no forced aggregation)
|
||||
- [ ] Aggregated columns show proper totals
|
||||
- [ ] Mixed columns group correctly
|
||||
- [ ] No unexpected SUM() wrapping of text columns
|
||||
|
||||
### ✅ Preview Quality
|
||||
- [ ] Table previews show more than 5 columns
|
||||
- [ ] Column headers not truncated at 15 characters
|
||||
- [ ] Numbers formatted with thousands separators
|
||||
- [ ] Dynamic column widths based on content
|
||||
- [ ] Summary information at bottom
|
||||
|
||||
### ✅ Error Handling
|
||||
- [ ] Invalid columns give helpful suggestions
|
||||
- [ ] Invalid aggregations provide clear guidance
|
||||
- [ ] Error messages include available options
|
||||
- [ ] Fuzzy matching suggests corrections
|
||||
|
||||
## Quick Validation Commands
|
||||
|
||||
Use these MCP tool calls to quickly test the fixes:
|
||||
|
||||
```bash
|
||||
# Test 1: Raw columns
|
||||
generate_chart with dataset_id and raw columns only
|
||||
|
||||
# Test 2: Aggregated columns
|
||||
generate_chart with dataset_id and aggregated columns only
|
||||
|
||||
# Test 3: Mixed columns
|
||||
generate_chart with dataset_id mixing raw and aggregated
|
||||
|
||||
# Test 4: Invalid column
|
||||
generate_chart with invalid column name (expect helpful error)
|
||||
```
|
||||
|
||||
## Regression Tests
|
||||
|
||||
Ensure these still work:
|
||||
- [ ] Chart screenshots still generate
|
||||
- [ ] Explore page screenshots still work
|
||||
- [ ] Other chart types (XY charts) unaffected
|
||||
- [ ] Filters still work with table charts
|
||||
|
||||
## Notes for Testing
|
||||
|
||||
1. **Use a dataset with multiple rows and columns** for best results
|
||||
2. **Check both preview and actual chart generation**
|
||||
3. **Test with different data types** (text, numbers, dates)
|
||||
4. **Verify error messages are user-friendly**
|
||||
5. **Confirm no performance regression**
|
||||
|
||||
---
|
||||
|
||||
**Expected Test Duration**: 10-15 minutes
|
||||
**Pass Criteria**: All ✅ checkboxes completed successfully
|
||||
262
superset/mcp_service/WEBDRIVER_POOLING.md
Normal file
262
superset/mcp_service/WEBDRIVER_POOLING.md
Normal file
@@ -0,0 +1,262 @@
|
||||
# WebDriver Connection Pooling for Screenshot Performance
|
||||
|
||||
This document describes the WebDriver connection pooling implementation that significantly improves screenshot generation performance in the Superset MCP service.
|
||||
|
||||
## Problem Statement
|
||||
|
||||
Previously, each screenshot request would:
|
||||
1. Create a new WebDriver instance (browser startup: 2-5 seconds)
|
||||
2. Navigate to the URL and take screenshot
|
||||
3. Destroy the WebDriver instance (browser shutdown: 1-2 seconds)
|
||||
|
||||
This resulted in **3-7 seconds overhead per screenshot**, making the service slow and resource-intensive.
|
||||
|
||||
## Solution: WebDriver Connection Pooling
|
||||
|
||||
The pooling solution reuses WebDriver instances across requests, reducing screenshot generation time by **80-90%**.
|
||||
|
||||
### Key Components
|
||||
|
||||
#### 1. WebDriverPool (`webdriver_pool.py`)
|
||||
- **Thread-safe** connection pool for WebDriver instances
|
||||
- **Automatic health checking** and recovery of browser instances
|
||||
- **TTL-based expiration** to prevent memory leaks
|
||||
- **Usage-based rotation** to prevent browser degradation
|
||||
- **Configurable pool size** and behavior
|
||||
|
||||
#### 2. PooledScreenshot Classes (`pooled_screenshot.py`)
|
||||
- `PooledBaseScreenshot` - Base class with pooling logic
|
||||
- `PooledChartScreenshot` - Drop-in replacement for `ChartScreenshot`
|
||||
- `PooledExploreScreenshot` - Enhanced explore page screenshots with UI hiding
|
||||
- `PooledDashboardScreenshot` - Dashboard screenshot support
|
||||
|
||||
#### 3. Configuration (`webdriver_config.py`)
|
||||
- Pre-configured settings for different traffic levels
|
||||
- Environment-specific optimizations
|
||||
- Monitoring and debugging utilities
|
||||
|
||||
## Performance Improvements
|
||||
|
||||
| Metric | Before (No Pool) | After (With Pool) | Improvement |
|
||||
|--------|------------------|-------------------|-------------|
|
||||
| First screenshot | 5-7 seconds | 5-7 seconds | Same (cold start) |
|
||||
| Subsequent screenshots | 5-7 seconds | 0.5-1 second | **85-90% faster** |
|
||||
| Resource usage | High (constant browser startup/shutdown) | Low (reused browsers) | **70-80% reduction** |
|
||||
| Concurrent requests | Limited by startup time | Higher throughput | **3-5x improvement** |
|
||||
|
||||
## Configuration Options
|
||||
|
||||
### Basic Configuration
|
||||
```python
|
||||
# In superset_config.py
|
||||
WEBDRIVER_POOL = {
|
||||
"MAX_POOL_SIZE": 5, # Maximum browsers in pool
|
||||
"MAX_AGE_SECONDS": 3600, # Browser lifetime (1 hour)
|
||||
"MAX_USAGE_COUNT": 50, # Max reuses before recreation
|
||||
"IDLE_TIMEOUT_SECONDS": 300, # Idle timeout (5 minutes)
|
||||
"HEALTH_CHECK_INTERVAL": 60, # Health check frequency
|
||||
}
|
||||
```
|
||||
|
||||
### Environment-Specific Configurations
|
||||
|
||||
#### Development
|
||||
```python
|
||||
from superset.mcp_service.webdriver_config import configure_for_environment
|
||||
configure_for_environment(config, "development")
|
||||
```
|
||||
- Small pool size (2 browsers)
|
||||
- Short lifetimes for faster iteration
|
||||
- Frequent health checks
|
||||
|
||||
#### Production - Low Traffic
|
||||
```python
|
||||
configure_for_environment(config, "low_traffic")
|
||||
```
|
||||
- Conservative resource usage
|
||||
- Longer idle timeouts
|
||||
- 2-3 browsers maximum
|
||||
|
||||
#### Production - High Traffic
|
||||
```python
|
||||
configure_for_environment(config, "high_traffic")
|
||||
```
|
||||
- Larger pool (10 browsers)
|
||||
- Extended lifetimes
|
||||
- Optimized for throughput
|
||||
|
||||
## Usage Examples
|
||||
|
||||
### Before (Original Implementation)
|
||||
```python
|
||||
from superset.utils.screenshots import ChartScreenshot
|
||||
|
||||
# Creates new browser, takes screenshot, destroys browser (slow)
|
||||
screenshot = ChartScreenshot(chart_url, chart.digest)
|
||||
image_data = screenshot.get_screenshot(user=g.user)
|
||||
```
|
||||
|
||||
### After (Pooled Implementation)
|
||||
```python
|
||||
from superset.mcp_service.pooled_screenshot import PooledChartScreenshot
|
||||
|
||||
# Reuses browser from pool (fast)
|
||||
screenshot = PooledChartScreenshot(chart_url, chart.digest)
|
||||
image_data = screenshot.get_screenshot(user=g.user)
|
||||
```
|
||||
|
||||
### Context Manager Usage (Advanced)
|
||||
```python
|
||||
from superset.mcp_service.webdriver_pool import get_webdriver_pool
|
||||
|
||||
pool = get_webdriver_pool()
|
||||
with pool.get_driver((800, 600), user_id=user.id) as driver:
|
||||
# Driver is authenticated and ready to use
|
||||
driver.get(url)
|
||||
screenshot = driver.get_screenshot_as_png()
|
||||
# Driver automatically returned to pool
|
||||
```
|
||||
|
||||
## Monitoring and Debugging
|
||||
|
||||
### Pool Statistics
|
||||
```python
|
||||
from superset.mcp_service.webdriver_pool import get_webdriver_pool
|
||||
|
||||
pool = get_webdriver_pool()
|
||||
stats = pool.get_stats()
|
||||
print(stats)
|
||||
# Output:
|
||||
# {
|
||||
# "pool_size": 3,
|
||||
# "active_count": 1,
|
||||
# "created": 15,
|
||||
# "destroyed": 12,
|
||||
# "borrowed": 150,
|
||||
# "returned": 149,
|
||||
# "health_check_failures": 2,
|
||||
# "evictions": 5
|
||||
# }
|
||||
```
|
||||
|
||||
### Health Monitoring
|
||||
The pool automatically:
|
||||
- **Health checks** browsers every minute
|
||||
- **Evicts** unhealthy or expired browsers
|
||||
- **Recreates** browsers as needed
|
||||
- **Logs** all pool operations for debugging
|
||||
|
||||
### Debug Endpoint (Optional)
|
||||
```python
|
||||
from superset.mcp_service.webdriver_config import get_pool_stats_endpoint
|
||||
|
||||
# Register debug endpoint
|
||||
app.route('/debug/webdriver-pool')(get_pool_stats_endpoint())
|
||||
```
|
||||
|
||||
## Architecture Integration
|
||||
|
||||
### MCP Service Integration
|
||||
The pooled screenshots are integrated into:
|
||||
- `serve_chart_screenshot()` - Chart screenshot endpoint
|
||||
- `serve_explore_screenshot()` - Explore screenshot endpoint
|
||||
- `get_chart_preview` tool - Chart preview generation
|
||||
- `generate_chart` tool - Chart creation with previews
|
||||
|
||||
### Backward Compatibility
|
||||
- **Drop-in replacement** for existing screenshot classes
|
||||
- **Same API** as original implementations
|
||||
- **No breaking changes** to existing code
|
||||
|
||||
## Resource Management
|
||||
|
||||
### Memory Management
|
||||
- **Automatic cleanup** of expired browsers
|
||||
- **Configurable limits** on pool size
|
||||
- **Usage tracking** to prevent memory leaks
|
||||
|
||||
### Error Handling
|
||||
- **Graceful degradation** if pool is unavailable
|
||||
- **Automatic recovery** from browser crashes
|
||||
- **Fallback** to single-use browsers if needed
|
||||
|
||||
### Shutdown Handling
|
||||
```python
|
||||
from superset.mcp_service.webdriver_pool import shutdown_webdriver_pool
|
||||
|
||||
# Clean shutdown (call during app teardown)
|
||||
shutdown_webdriver_pool()
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
### Pool Sizing
|
||||
- **Start small** (2-3 browsers) and monitor
|
||||
- **Scale up** based on concurrent screenshot requests
|
||||
- **Consider memory** (each browser uses ~100-200MB)
|
||||
|
||||
### Health Monitoring
|
||||
- **Monitor pool statistics** regularly
|
||||
- **Watch for** high eviction rates (indicates configuration issues)
|
||||
- **Alert on** health check failures
|
||||
|
||||
### Configuration Tuning
|
||||
- **Development**: Use short lifetimes for faster iteration
|
||||
- **Low traffic**: Conservative settings to save resources
|
||||
- **High traffic**: Larger pools and longer lifetimes
|
||||
- **Debugging**: Enable more frequent health checks
|
||||
|
||||
## Security Considerations
|
||||
|
||||
### User Isolation
|
||||
- WebDriver instances are **not shared between users**
|
||||
- Each request gets a **fresh authentication**
|
||||
- **No cross-user data leakage** through browser state
|
||||
|
||||
### Resource Limits
|
||||
- **Pool size limits** prevent resource exhaustion
|
||||
- **TTL limits** prevent indefinite resource holding
|
||||
- **Health checks** detect and remove compromised browsers
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
#### Pool Not Creating Browsers
|
||||
- Check WebDriver configuration (`WEBDRIVER_TYPE`, etc.)
|
||||
- Verify browser binaries are installed
|
||||
- Check system resources (memory, CPU)
|
||||
|
||||
#### High Eviction Rates
|
||||
- Increase `MAX_AGE_SECONDS` or `MAX_USAGE_COUNT`
|
||||
- Check for memory pressure
|
||||
- Monitor browser health
|
||||
|
||||
#### Performance Not Improving
|
||||
- Verify pooled classes are being used
|
||||
- Check pool statistics for reuse rates
|
||||
- Ensure adequate pool size for load
|
||||
|
||||
### Debug Steps
|
||||
1. **Check pool stats** to see activity
|
||||
2. **Enable debug logging** for WebDriver operations
|
||||
3. **Monitor system resources** during operation
|
||||
4. **Test with single browser** to isolate issues
|
||||
|
||||
## Future Enhancements
|
||||
|
||||
### Potential Improvements
|
||||
- **Multi-window support** for parallel screenshots
|
||||
- **Browser-specific pools** (Chrome vs Firefox)
|
||||
- **Dynamic scaling** based on load
|
||||
- **Persistent pools** across service restarts
|
||||
- **Integration with container orchestration**
|
||||
|
||||
### Metrics Integration
|
||||
- **Prometheus metrics** for pool statistics
|
||||
- **Performance tracking** for screenshot timing
|
||||
- **Alert integration** for pool health
|
||||
|
||||
---
|
||||
|
||||
This WebDriver pooling implementation provides significant performance improvements while maintaining reliability and security. The modular design allows for easy configuration and monitoring in production environments.
|
||||
32
superset/mcp_service/__init__.py
Normal file
32
superset/mcp_service/__init__.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
# superset/mcp_service/__init__.py
|
||||
|
||||
"""
|
||||
Model Context Protocol (MCP) service for Apache Superset.
|
||||
|
||||
This service provides a structured interface for AI agents to interact with Superset's
|
||||
core functionality through well-defined, high-level actions.
|
||||
|
||||
The service runs as a standalone server.
|
||||
|
||||
To start the service, run:
|
||||
superset mcp run
|
||||
"""
|
||||
|
||||
__version__ = "0.1.0"
|
||||
283
superset/mcp_service/auth.py
Normal file
283
superset/mcp_service/auth.py
Normal file
@@ -0,0 +1,283 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from flask import Flask
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from mcp.server.auth.provider import AccessToken
|
||||
|
||||
from superset.extensions import event_logger
|
||||
|
||||
# Type variable for decorated functions
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_user_from_request() -> User:
    """Resolve the acting Superset user for the current MCP request.

    The username is extracted from the request's JWT access token (or the
    configured fallback username), then looked up through the security
    manager. If no matching account exists, any available admin account
    is used instead.

    Returns:
        The resolved ``User`` record.

    Raises:
        ValueError: Propagated from ``_find_fallback_admin_user`` when
            neither the resolved username nor any admin account exists.
    """
    from flask import current_app

    from superset import security_manager

    resolved_name = _extract_username_from_jwt(current_app)
    return security_manager.find_user(resolved_name) or _find_fallback_admin_user(
        resolved_name
    )
|
||||
|
||||
|
||||
def _extract_username_from_jwt(app: Flask) -> str:
    """Extract the username from the request's JWT access token.

    Resolution order:
      1. If ``MCP_USER_RESOLVER`` is configured (a callable), invoke it
         with the access token.
      2. Otherwise use the token's ``client_id`` attribute.
      3. On any failure (no token, resolver error, missing attribute),
         fall back to ``MCP_ADMIN_USERNAME`` (default ``"admin"``).

    Args:
        app: Flask application whose config supplies the optional
            resolver and the fallback username.

    Returns:
        The resolved username; never empty, since the fallback always
        yields a value.
    """
    try:
        from fastmcp.server.dependencies import get_access_token

        access_token = get_access_token()
        user_resolver = app.config.get("MCP_USER_RESOLVER")

        # callable(None) is False, so an unset resolver falls through to
        # the token's client_id claim — no separate truthiness check needed.
        username = (
            user_resolver(access_token)
            if callable(user_resolver)
            else getattr(access_token, "client_id", None)
        )

        if username:
            # Lazy %-style args defer formatting until the record is emitted.
            logger.info("MCP auth: JWT user '%s'", username)
            return username

    except Exception as e:
        logger.debug("JWT extraction failed: %s", e)

    fallback = app.config.get("MCP_ADMIN_USERNAME", "admin")
    logger.debug("MCP auth: Using fallback user '%s'", fallback)
    return fallback
|
||||
|
||||
|
||||
def _find_fallback_admin_user(username: str) -> User:
    """Locate any account holding the Admin role as a last-resort fallback.

    Called when the username resolved from the JWT/config does not match
    an existing account.

    Args:
        username: The originally requested (missing) username; used only
            in log and error messages.

    Returns:
        The first ``User`` with the Admin role found in the metadata
        database.

    Raises:
        ValueError: If no admin account is found or the lookup fails.
    """
    from flask_appbuilder.security.sqla.models import User

    from superset import db

    fallback: Optional[User] = None
    try:
        fallback = (
            db.session.query(User).filter(User.roles.any(name="Admin")).first()
        )
        if fallback:
            logger.warning(
                f"User '{username}' not found, using '{fallback.username}'"
            )
            return fallback
    except Exception as e:
        logger.debug(f"Failed to find admin user: {e}")

    raise ValueError(
        f"User '{username}' not found. Create user or update MCP_ADMIN_USERNAME. "
        f"Use: superset fab create-admin"
    )
|
||||
|
||||
|
||||
def impersonate_user(user: User, run_as: Optional[str] = None) -> User:
    """Swap *user* for the account named by *run_as*, when impersonation is asked.

    Returns the original user untouched when *run_as* is falsy.

    Raises:
        ValueError: if *run_as* names a user that does not exist.
    """
    if not run_as:
        return user

    from superset import security_manager

    target = security_manager.find_user(run_as)
    if target is None:
        raise ValueError(f"Impersonation target '{run_as}' not found")

    logger.info(f"Impersonating {run_as} as {user.username}")
    return target
|
||||
|
||||
|
||||
def has_permission(user: User, tool_func: Callable[..., Any]) -> bool:
    """Decide whether *user* may invoke *tool_func*.

    Missing or inactive users are always denied.  When a JWT can be obtained,
    its scopes are checked against the tool's requirements; when no token is
    available the call is permitted, since scope enforcement only applies to
    JWT-authenticated requests.
    """
    if not user or not user.is_active:
        return False

    try:
        from fastmcp.server.dependencies import get_access_token

        token = get_access_token()
        if not token:
            return True  # No JWT means no scope restrictions
        return _check_jwt_scopes(user, tool_func, token)
    except Exception:
        return True  # Allow access when JWT unavailable
|
||||
|
||||
|
||||
def _check_jwt_scopes(
    user: User, tool_func: Callable[..., Any], access_token: AccessToken
) -> bool:
    """Return True when the token's scopes cover *tool_func*'s requirements.

    Tools absent from the requirement table are unrestricted; for listed
    tools, a single matching scope (any-of semantics) is sufficient.
    """
    # Per-tool scope requirements; unlisted tools need no scope at all.
    scope_requirements = {
        "list_dashboards": ["dashboard:read"],
        "get_dashboard_info": ["dashboard:read"],
        "list_charts": ["chart:read"],
        "get_chart_info": ["chart:read"],
        "generate_chart": ["chart:write"],
        "list_datasets": ["dataset:read"],
        "get_dataset_info": ["dataset:read"],
        "get_superset_instance_info": ["instance:read"],
    }

    required_scopes = scope_requirements.get(tool_func.__name__)
    if not required_scopes:
        return True

    granted = access_token.scopes or []
    allowed = any(scope in granted for scope in required_scopes)
    if not allowed:
        logger.warning(
            f"User {user.username} lacks scopes {required_scopes} "
            f"for {tool_func.__name__}"
        )

    return allowed
|
||||
|
||||
|
||||
def sanitize_mcp_payload(kwargs: dict[str, Any]) -> dict[str, Any]:
    """Sanitize an MCP tool payload for audit logging.

    Redacts values whose key looks sensitive, truncates oversized strings,
    and — unlike a shallow scan — recurses into nested dicts (and dict/str
    items inside lists) so that secrets buried in structured payloads cannot
    leak into audit logs.

    Args:
        kwargs: Raw tool keyword arguments.

    Returns:
        A new dict safe to attach to audit log records; the input is not
        mutated.
    """
    sensitive_keys = {"password", "token", "secret", "key", "auth"}

    def _clean_value(key: str, value: Any) -> Any:
        # Key-based redaction wins over any value inspection.
        if any(sensitive in key.lower() for sensitive in sensitive_keys):
            return "[REDACTED]"
        if isinstance(value, str) and len(value) > 1000:
            return value[:1000] + "...[TRUNCATED]"
        if isinstance(value, dict):
            # Recurse so nested payloads are sanitized too.
            return {k: _clean_value(k, v) for k, v in value.items()}
        if isinstance(value, list):
            # Sanitize dicts and strings inside lists; other items pass through.
            return [
                _clean_value(key, item) if isinstance(item, (dict, str)) else item
                for item in value
            ]
        return value

    return {key: _clean_value(key, value) for key, value in kwargs.items()}
|
||||
|
||||
|
||||
def get_mcp_audit_context(
    tool_func: Callable[..., Any], kwargs: dict[str, Any]
) -> dict[str, Any]:
    """Build the MCP-specific audit context attached to event-logger records.

    Always includes the log source, the tool name, and the user currently on
    ``flask.g``; request headers, a sanitized payload, and JWT details are
    added on a best-effort basis — any failure there is logged at debug
    level and the partial context is returned.

    Args:
        tool_func: The MCP tool being invoked; only its ``__name__`` is used.
        kwargs: The tool's keyword arguments, sanitized before inclusion.

    Returns:
        A dict of audit fields suitable for structured logging.
    """
    from flask import g, request

    # Get JWT context if available
    jwt_context = _get_jwt_context()

    # NOTE(review): "impersonation" records whoever is on flask.g at call
    # time — presumably the authenticated (pre-impersonation) user; confirm
    # against the caller that sets g.user.
    context = {
        "log_source": "mcp",
        "impersonation": getattr(g.user, "username", "unknown")
        if hasattr(g, "user") and g.user
        else "unknown",
        "mcp_tool": tool_func.__name__,
    }

    # Add ideally available fields
    try:
        # Outside an active request context, touching `request` raises;
        # that lands in the except below and the fields are simply omitted.
        if hasattr(request, "headers"):
            context["model_info"] = request.headers.get("User-Agent", "unknown")
            context["session_info"] = request.headers.get("X-Session-ID")

        # Redact secrets / truncate oversized values before logging.
        context["whitelisted_payload"] = sanitize_mcp_payload(kwargs)

        # Add JWT context if available
        if jwt_context:
            context["jwt_user"] = jwt_context.get("user")
            context["jwt_scopes"] = jwt_context.get("scopes", [])

    except Exception as e:
        logger.debug(f"Error getting MCP audit context: {e}")

    return context
|
||||
|
||||
|
||||
def log_access(user: User, tool_name: str, args: Any, kwargs: Any) -> None:
    """Emit an info-level access log line for a tool invocation.

    *args* and *kwargs* are accepted for interface symmetry but are not
    currently included in the log line; JWT details are appended when a
    token context is available.
    """
    jwt_context = _get_jwt_context()
    if jwt_context:
        logger.info(
            f"MCP access: user={user.username}, jwt={jwt_context['user']}, "
            f"tool={tool_name}, scopes={jwt_context['scopes']}"
        )
    else:
        logger.info(f"MCP access: user={user.username}, tool={tool_name}")
|
||||
|
||||
|
||||
def _get_jwt_context() -> Optional[dict[str, Any]]:
    """Return ``{"user", "scopes"}`` from the current JWT, or None when absent.

    All failures (fastmcp not installed, no active token, attribute errors)
    are swallowed and reported at debug level only — this helper exists
    purely for log enrichment.
    """
    try:
        from fastmcp.server.dependencies import get_access_token

        current_token = get_access_token()
        if current_token:
            return {
                "user": getattr(current_token, "client_id", None),
                "scopes": getattr(current_token, "scopes", []),
            }
    except Exception as e:
        logger.debug(f"JWT context extraction failed: {e}")

    return None
|
||||
|
||||
|
||||
def mcp_auth_hook(tool_func: F) -> F:
    """Authentication and authorization decorator for MCP tools with audit logging.

    Wraps a tool so that each call: resolves the acting user from the JWT
    (with admin fallback), exposes it on ``flask.g``, optionally impersonates
    a ``run_as`` user, enforces scope-based permissions, and logs the access.
    The wrapper is additionally routed through Superset's event logger when
    one is configured.

    Raises (from the wrapper):
        PermissionError: when the effective user may not call the tool.
    """
    import functools

    from flask import g

    # Apply event logger decorator if available, otherwise proceed without it
    def apply_audit_logging(func: Callable[..., Any]) -> Callable[..., Any]:
        try:
            if event_logger and hasattr(event_logger, "log_this_with_context"):
                # Audit action name becomes e.g. "mcp.list_dashboards".
                return event_logger.log_this_with_context(
                    action=lambda *args, **kwargs: f"mcp.{tool_func.__name__}",
                    log_to_statsd=False,
                )(func)
        except Exception as e:
            logger.debug(f"Event logger not available: {e}")
        return func

    # Decorator order matters: functools.wraps runs first so the event
    # logger sees a wrapper carrying the tool's real name and docstring.
    @apply_audit_logging
    @functools.wraps(tool_func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        user = get_user_from_request()
        # NOTE(review): g.user keeps the authenticated user even after the
        # run_as impersonation below — confirm this is intended for audit.
        g.user = user

        # Add MCP audit context to Flask g for event logger
        g.mcp_audit_context = get_mcp_audit_context(tool_func, kwargs)

        if run_as := kwargs.get("run_as"):
            user = impersonate_user(user, run_as)

        if not has_permission(user, tool_func):
            raise PermissionError(
                f"Access denied: {user.username} lacks permission "
                f"for {tool_func.__name__}"
            )

        log_access(user, tool_func.__name__, args, kwargs)
        return tool_func(*args, **kwargs)

    return wrapper  # type: ignore[return-value]
|
||||
125
superset/mcp_service/auth_config_example.py
Normal file
125
superset/mcp_service/auth_config_example.py
Normal file
@@ -0,0 +1,125 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Example MCP Authentication Configuration for superset_config.py
|
||||
|
||||
Add these settings to your superset_config.py to enable MCP authentication.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
# ---------------------------------------------------
|
||||
# Example 1: Simple Configuration (Recommended)
|
||||
# ---------------------------------------------------
|
||||
|
||||
# Enable MCP authentication
|
||||
MCP_AUTH_ENABLED = True
|
||||
|
||||
# JWT configuration
|
||||
MCP_JWKS_URI = "https://your-auth-provider.com/.well-known/jwks.json"
|
||||
MCP_JWT_ISSUER = "https://your-auth-provider.com/"
|
||||
MCP_JWT_AUDIENCE = "superset-mcp-server"
|
||||
MCP_JWT_ALGORITHM = "RS256"
|
||||
MCP_REQUIRED_SCOPES = ["superset:read", "superset:query"]
|
||||
|
||||
# The default factory will use these values automatically
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# Example 2: Custom Factory (Following @dpgaspar's pattern)
|
||||
# ---------------------------------------------------
|
||||
|
||||
|
||||
def create_custom_mcp_auth(app: Flask) -> Optional[Any]:
    """Custom MCP auth factory following Superset patterns.

    Reads the required JWT settings straight from ``app.config`` (raising
    ``KeyError`` when one is absent) and returns a ready-to-use bearer
    provider.
    """
    # Required settings — fail fast with KeyError if any is missing.
    jwks_uri = app.config["MCP_JWKS_URI"]
    issuer = app.config["MCP_JWT_ISSUER"]
    audience = app.config["MCP_JWT_AUDIENCE"]

    # Add custom logic here
    if app.debug:
        app.logger.info("MCP auth in debug mode")

    from fastmcp.server.auth.providers.bearer import BearerAuthProvider

    return BearerAuthProvider(
        jwks_uri=jwks_uri,
        issuer=issuer,
        audience=audience,
        algorithm=app.config.get("MCP_JWT_ALGORITHM", "RS256"),
        required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
    )
|
||||
|
||||
|
||||
# Override the default factory
|
||||
# MCP_AUTH_FACTORY = create_custom_mcp_auth
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# Example 3: Environment-based Factory
|
||||
# ---------------------------------------------------
|
||||
|
||||
|
||||
def create_env_based_mcp_auth(app: Flask) -> Optional[Any]:
    """Factory that resolves MCP auth settings from the environment first.

    ``DISABLE_MCP_AUTH=true`` (case-insensitive) turns auth off entirely.
    Each setting prefers its environment variable and falls back to the
    matching ``app.config`` entry; without any JWKS URI, auth is disabled
    with a warning.
    """
    import os

    # Explicit kill switch.
    if os.getenv("DISABLE_MCP_AUTH", "").lower() == "true":
        return None

    jwks_uri = os.getenv("MCP_JWKS_URI") or app.config.get("MCP_JWKS_URI")
    if not jwks_uri:
        app.logger.warning("No JWKS URI configured for MCP auth")
        return None

    issuer = os.getenv("MCP_JWT_ISSUER") or app.config.get("MCP_JWT_ISSUER")
    audience = os.getenv("MCP_JWT_AUDIENCE") or app.config.get("MCP_JWT_AUDIENCE")

    from fastmcp.server.auth.providers.bearer import BearerAuthProvider

    return BearerAuthProvider(jwks_uri=jwks_uri, issuer=issuer, audience=audience)
|
||||
|
||||
|
||||
# MCP_AUTH_FACTORY = create_env_based_mcp_auth
|
||||
|
||||
|
||||
# ---------------------------------------------------
|
||||
# Example 4: Custom JWT User Resolution
|
||||
# ---------------------------------------------------
|
||||
|
||||
|
||||
def custom_user_resolver(access_token: Any) -> Optional[str]:
    """Resolve a username from a JWT access token with non-standard claims.

    Prefers ``payload.user_info.preferred_username`` (then ``email``) when a
    raw payload is attached to the token; otherwise falls back to the
    standard ``subject`` / ``client_id`` claims.
    """
    # Tokens carrying a raw payload: look inside the nested user_info claim.
    if hasattr(access_token, "payload"):
        info = access_token.payload.get("user_info", {})
        username = info.get("preferred_username")
        return username if username else info.get("email")

    # Standard claims fallback.
    return access_token.subject or access_token.client_id
|
||||
|
||||
|
||||
# MCP_USER_RESOLVER = custom_user_resolver
|
||||
143
superset/mcp_service/cache_utils.py
Normal file
143
superset/mcp_service/cache_utils.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Cache utilities for MCP tools.
|
||||
|
||||
This module provides utilities for working with Superset's cache layers
|
||||
and implementing cache control in MCP tools.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from superset.mcp_service.schemas.cache_schemas import CacheStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def get_cache_status_from_result(
    result: Dict[str, Any], force_refresh: bool = False
) -> CacheStatus:
    """
    Build a CacheStatus object from a Superset query result payload.

    Accepts either a full chart-data response (with a ``queries`` list) or a
    bare single-query result dict.  Cache age is derived from the result's
    ``cache_dttm`` field when it can be parsed; parse failures are logged at
    debug level and simply leave the age unset.

    Args:
        result: Query result dictionary from Superset.
        force_refresh: Whether the cache was force refreshed.

    Returns:
        CacheStatus describing cache hit/miss, type, age, and refresh state.
    """
    # Chart-data responses wrap per-query results in a "queries" list.
    if "queries" in result and len(result["queries"]) > 0:
        query_result = result["queries"][0]
    else:
        query_result = result

    cache_hit = bool(query_result.get("is_cached", False))

    # Derive the cache age in seconds from cache_dttm, when present.
    cache_age_seconds = None
    if cache_age := query_result.get("cache_dttm"):
        try:
            from datetime import datetime

            cached_at = None
            if isinstance(cache_age, str):
                # ISO timestamp; normalize a trailing "Z" to an explicit offset.
                cached_at = datetime.fromisoformat(cache_age.replace("Z", "+00:00"))
            elif isinstance(cache_age, datetime):
                cached_at = cache_age
            if cached_at is not None:
                elapsed = datetime.now(cached_at.tzinfo) - cached_at
                cache_age_seconds = int(elapsed.total_seconds())
        except Exception as e:
            logger.debug(f"Could not parse cache age: {e}")

    return CacheStatus(
        cache_hit=cache_hit,
        cache_type="query" if cache_hit else "none",
        cache_age_seconds=cache_age_seconds,
        refreshed=force_refresh,
    )
|
||||
|
||||
|
||||
def apply_cache_control_to_query_context(
    query_context: Dict[str, Any],
    use_cache: bool = True,
    force_refresh: bool = False,
    cache_timeout: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Mutate *query_context* in place to honour the given cache controls.

    Setting ``force`` bypasses the cache (either because caching was turned
    off or a refresh was requested); a non-None *cache_timeout* is stamped
    onto every query in the context.

    Args:
        query_context: Query context dictionary.
        use_cache: Whether to use cache.
        force_refresh: Whether to force refresh.
        cache_timeout: Cache timeout override.

    Returns:
        The same (mutated) query context, for convenient chaining.
    """
    bypass_cache = force_refresh or not use_cache
    if bypass_cache:
        query_context["force"] = True

    if cache_timeout is not None:
        # Stamp the override onto every query in the context.
        for query in query_context.get("queries", []):
            query["cache_timeout"] = cache_timeout

    return query_context
|
||||
|
||||
|
||||
def should_use_metadata_cache(
    use_cache: bool = True,
    refresh_metadata: bool = False,
) -> bool:
    """
    Decide whether cached metadata may be served.

    The cache is used only when caching is enabled AND no explicit metadata
    refresh was requested.

    Args:
        use_cache: Whether to use cache.
        refresh_metadata: Whether to refresh metadata.

    Returns:
        True if metadata cache should be used.
    """
    if refresh_metadata:
        return False
    return use_cache
|
||||
|
||||
|
||||
def get_cache_key_info(cache_key: Optional[str]) -> Optional[str]:
    """
    Shorten a cache key for display in debug output.

    Keys longer than 50 characters are cut to 47 characters plus an
    ellipsis; falsy keys yield None.

    Args:
        cache_key: Full cache key.

    Returns:
        Truncated cache key or None.
    """
    if not cache_key:
        return None
    if len(cache_key) <= 50:
        return cache_key
    # Truncate long cache keys for readability.
    return cache_key[:47] + "..."
|
||||
429
superset/mcp_service/chart/chart_utils.py
Normal file
429
superset/mcp_service/chart/chart_utils.py
Normal file
@@ -0,0 +1,429 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Shared chart utilities for MCP tools
|
||||
|
||||
This module contains shared logic for chart configuration mapping and explore link
|
||||
generation that can be used by both generate_chart and generate_explore_link tools.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
ChartCapabilities,
|
||||
ChartSemantics,
|
||||
ColumnRef,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.url_utils import get_superset_base_url
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
def generate_explore_link(dataset_id: int | str, form_data: Dict[str, Any]) -> str:
    """Generate a Superset Explore URL for the given dataset and form data.

    Preferred path: resolve the dataset (by numeric id or UUID), stash the
    form data in the Explore form-data cache, and return a short URL that
    carries only the cache key.  On any failure, fall back to a plain
    Explore URL that passes the datasource id directly.

    Args:
        dataset_id: Numeric dataset id, a numeric string, or a dataset UUID.
        form_data: Chart form data to pre-populate the Explore view with.

    Returns:
        An absolute Explore URL; this function always returns, never raises.
    """
    base_url = get_superset_base_url()
    # Initialized outside the try block so the except handler can use it.
    numeric_dataset_id = None

    try:
        from superset.commands.explore.form_data.parameters import CommandParameters

        # Find the dataset to get its numeric ID
        from superset.daos.dataset import DatasetDAO
        from superset.mcp_service.commands.create_form_data import (
            MCPCreateFormDataCommand,
        )
        from superset.utils.core import DatasourceType

        dataset = None

        if isinstance(dataset_id, int) or (
            isinstance(dataset_id, str) and dataset_id.isdigit()
        ):
            numeric_dataset_id = (
                int(dataset_id) if isinstance(dataset_id, str) else dataset_id
            )
            dataset = DatasetDAO.find_by_id(numeric_dataset_id)
        else:
            # Try UUID lookup using DAO flexible method
            dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")
            if dataset:
                numeric_dataset_id = dataset.id

        if not dataset or numeric_dataset_id is None:
            # Fallback to basic explore URL
            return (
                f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}"
            )

        # Add datasource to form_data
        form_data_with_datasource = {
            **form_data,
            "datasource": f"{numeric_dataset_id}__table",
        }

        # Try to create form_data in cache using MCP-specific CreateFormDataCommand
        cmd_params = CommandParameters(
            datasource_type=DatasourceType.TABLE,
            datasource_id=numeric_dataset_id,
            chart_id=0,  # 0 for new charts
            tab_id=None,
            form_data=json.dumps(form_data_with_datasource),
        )

        # Create the form_data cache entry and get the key
        form_data_key = MCPCreateFormDataCommand(cmd_params).run()

        # Return URL with just the form_data_key
        return f"{base_url}/explore/?form_data_key={form_data_key}"

    except Exception:
        # Fallback to basic explore URL with numeric ID if available
        if numeric_dataset_id is not None:
            return (
                f"{base_url}/explore/?datasource_type=table"
                f"&datasource_id={numeric_dataset_id}"
            )
        else:
            return (
                f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}"
            )
|
||||
|
||||
|
||||
def map_config_to_form_data(
    config: TableChartConfig | XYChartConfig,
) -> Dict[str, Any]:
    """Dispatch a chart config to its type-specific form_data mapper.

    Raises:
        ValueError: for config types that have no mapper.
    """
    if isinstance(config, TableChartConfig):
        return map_table_config(config)
    if isinstance(config, XYChartConfig):
        return map_xy_config(config)
    raise ValueError(f"Unsupported config type: {type(config)}")
|
||||
|
||||
|
||||
def map_table_config(config: TableChartConfig) -> Dict[str, Any]:
    """Translate a TableChartConfig into Superset table-chart form_data.

    Columns carrying an aggregate become adhoc metrics; bare columns are
    shown raw.  The resulting query mode depends on the mix:

    * only raw columns  -> raw row listing (capped at 1000 rows)
    * only aggregates   -> aggregate totals
    * both              -> aggregates grouped by the raw columns
    """
    raw_columns = []
    aggregated_metrics = []
    for col in config.columns:
        if col.aggregate:
            # Column has aggregation - treat as metric
            aggregated_metrics.append(create_metric_object(col))
        else:
            # No aggregation - treat as raw column
            raw_columns.append(col.name)

    form_data: Dict[str, Any] = {"viz_type": "table"}

    if raw_columns and not aggregated_metrics:
        # Pure raw columns - show individual rows
        form_data["all_columns"] = raw_columns
        form_data["query_mode"] = "raw"
        form_data["include_time"] = False
        form_data["order_desc"] = True
        form_data["row_limit"] = 1000  # Reasonable limit for raw data
    elif aggregated_metrics and not raw_columns:
        # Pure aggregation - show totals
        form_data["metrics"] = aggregated_metrics
        form_data["query_mode"] = "aggregate"
    elif raw_columns and aggregated_metrics:
        # Mixed mode - group by raw columns, aggregate metrics
        form_data["all_columns"] = raw_columns
        form_data["metrics"] = aggregated_metrics
        form_data["groupby"] = raw_columns
        form_data["query_mode"] = "aggregate"

    if config.filters:
        form_data["adhoc_filters"] = [
            {
                "expressionType": "SIMPLE",
                "subject": flt.column,
                "operator": map_filter_operator(flt.op),
                "comparator": flt.value,
            }
            for flt in config.filters
        ]

    if config.sort_by:
        form_data["order_by_cols"] = config.sort_by

    return form_data
|
||||
|
||||
|
||||
def create_metric_object(col: ColumnRef) -> Dict[str, Any]:
    """Build a Superset SIMPLE adhoc-metric dict for *col*.

    Defaults to SUM when the column specifies no aggregate; the label falls
    back to ``"AGG(column)"`` when none was provided.
    """
    aggregate = col.aggregate or "SUM"
    label = col.label or f"{aggregate}({col.name})"
    return {
        "aggregate": aggregate,
        "column": {"column_name": col.name},
        "expressionType": "SIMPLE",
        "label": label,
        "optionName": f"metric_{col.name}",
        "sqlExpression": None,
        "hasCustomLabel": False,
        "datasourceWarning": False,
    }
|
||||
|
||||
|
||||
def add_axis_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
    """Copy axis titles/formats (and log y-scale) from *config* into *form_data*."""
    x_axis = config.x_axis
    if x_axis:
        if x_axis.title:
            form_data["x_axis_title"] = x_axis.title
        if x_axis.format:
            form_data["x_axis_format"] = x_axis.format

    y_axis = config.y_axis
    if y_axis:
        if y_axis.title:
            form_data["y_axis_title"] = y_axis.title
        if y_axis.format:
            form_data["y_axis_format"] = y_axis.format
        # Only the logarithmic scale needs an explicit flag.
        if y_axis.scale == "log":
            form_data["y_axis_scale"] = "log"
|
||||
|
||||
|
||||
def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
    """Copy legend visibility/position from *config* into *form_data*."""
    legend = config.legend
    if not legend:
        return
    if not legend.show:
        form_data["show_legend"] = False
    if legend.position:
        form_data["legend_orientation"] = legend.position
|
||||
|
||||
|
||||
def map_xy_config(config: XYChartConfig) -> Dict[str, Any]:
    """Translate an XYChartConfig into ECharts timeseries form_data.

    Unknown chart kinds fall back to the line visualization; Y columns
    become adhoc metrics and optional group-by/filter/axis/legend settings
    are layered on top.
    """
    # Map chart kind to viz_type
    viz_type_map = {
        "line": "echarts_timeseries_line",
        "bar": "echarts_timeseries_bar",
        "area": "echarts_area",
        "scatter": "echarts_timeseries_scatter",
    }

    form_data: Dict[str, Any] = {
        "viz_type": viz_type_map.get(config.kind, "echarts_timeseries_line"),
        "x_axis": config.x.name,
        # Every Y column becomes an adhoc metric.
        "metrics": [create_metric_object(col) for col in config.y],
    }

    if config.group_by:
        form_data["groupby"] = [config.group_by.name]

    if config.filters:
        form_data["adhoc_filters"] = [
            {
                "expressionType": "SIMPLE",
                "subject": flt.column,
                "operator": map_filter_operator(flt.op),
                "comparator": flt.value,
            }
            for flt in config.filters
        ]

    # Layer on axis and legend settings.
    add_axis_config(form_data, config)
    add_legend_config(form_data, config)

    return form_data
|
||||
|
||||
|
||||
def map_filter_operator(op: str) -> str:
    """Translate a filter operator into Superset's adhoc-filter notation.

    Only equality differs ("=" becomes "=="); every other operator —
    known or unknown — passes through unchanged.
    """
    if op == "=":
        return "=="
    return op
|
||||
|
||||
|
||||
def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str:
    """Derive a human-readable default chart name from the configuration."""
    if isinstance(config, TableChartConfig):
        column_names = ", ".join(col.name for col in config.columns)
        return f"Table Chart - {column_names}"
    if isinstance(config, XYChartConfig):
        chart_type = config.kind.capitalize()
        x_col = config.x.name
        y_cols = ", ".join(col.name for col in config.y)
        return f"{chart_type} Chart - {x_col} vs {y_cols}"
    # Unknown config type: generic fallback name.
    return "Chart"
|
||||
|
||||
|
||||
def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
    """Derive interaction/format capabilities for a chart or a chart config.

    A saved *chart* supplies its viz_type directly; otherwise the viz_type
    is inferred from *config*'s ``chart_type`` (and ``kind`` for XY charts).
    """
    if chart:
        viz_type = getattr(chart, "viz_type", "unknown")
    else:
        # No saved chart: infer the viz_type from the config.
        chart_type = getattr(config, "chart_type", "unknown")
        if chart_type == "xy":
            viz_type_map = {
                "line": "echarts_timeseries_line",
                "bar": "echarts_timeseries_bar",
                "area": "echarts_area",
                "scatter": "echarts_timeseries_scatter",
            }
            viz_type = viz_type_map.get(
                getattr(config, "kind", "line"), "echarts_timeseries_line"
            )
        elif chart_type == "table":
            viz_type = "table"
        else:
            viz_type = "unknown"

    # Chart families that support client-side interaction.
    interactive_types = [
        "echarts_timeseries_line",
        "echarts_timeseries_bar",
        "echarts_area",
        "echarts_timeseries_scatter",
        "deck_scatter",
        "deck_hex",
    ]

    supports_interaction = viz_type in interactive_types
    supports_drill_down = viz_type in ["table", "pivot_table_v2"]
    supports_real_time = viz_type in [
        "echarts_timeseries_line",
        "echarts_timeseries_bar",
    ]

    # Static image URLs work for everything; interactive charts add more.
    optimal_formats = ["url"]
    if supports_interaction:
        optimal_formats.extend(["interactive", "vega_lite"])
    optimal_formats.extend(["ascii", "table"])

    # Classify the data types involved, based on the config's axes.
    data_types = []
    if hasattr(config, "x") and config.x:
        data_types.append("categorical" if not config.x.aggregate else "metric")
    if hasattr(config, "y") and config.y:
        data_types.extend(["metric"] * len(config.y))
    if "time" in viz_type or "timeseries" in viz_type:
        data_types.append("time_series")

    return ChartCapabilities(
        supports_interaction=supports_interaction,
        supports_real_time=supports_real_time,
        supports_drill_down=supports_drill_down,
        supports_export=True,  # All charts can be exported
        optimal_formats=optimal_formats,
        data_types=list(set(data_types)),
    )
|
||||
|
||||
|
||||
def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics:
    """Produce a semantic summary (insight, story, suggested actions) of a chart.

    The viz_type comes from the saved *chart* when available, otherwise it
    is inferred from *config*.  Anomalies and statistics are left empty —
    filling them would require analyzing the actual data.
    """
    if chart:
        viz_type = getattr(chart, "viz_type", "unknown")
    else:
        # No saved chart: infer the viz_type from the config.
        chart_type = getattr(config, "chart_type", "unknown")
        if chart_type == "xy":
            viz_type_map = {
                "line": "echarts_timeseries_line",
                "bar": "echarts_timeseries_bar",
                "area": "echarts_area",
                "scatter": "echarts_timeseries_scatter",
            }
            viz_type = viz_type_map.get(
                getattr(config, "kind", "line"), "echarts_timeseries_line"
            )
        elif chart_type == "table":
            viz_type = "table"
        else:
            viz_type = "unknown"

    # Canned one-line insights per chart family.
    insights_map = {
        "echarts_timeseries_line": "Shows trends and changes over time",
        "echarts_timeseries_bar": "Compares values across categories or time periods",
        "table": "Displays detailed data in tabular format",
        "pie": "Shows proportional relationships within a dataset",
        "echarts_area": "Emphasizes cumulative totals and part-to-whole relationships",
    }
    primary_insight = insights_map.get(
        viz_type, f"Visualizes data using {viz_type} format"
    )

    # Name the columns involved (x first, then the y columns).
    columns = []
    if hasattr(config, "x") and config.x:
        columns.append(config.x.name)
    if hasattr(config, "y") and config.y:
        columns.extend([col.name for col in config.y])

    if columns:
        ellipsis = "..." if len(columns) > 3 else ""
        data_story = (
            f"This {viz_type} chart analyzes {', '.join(columns[:3])}{ellipsis}"
        )
    else:
        data_story = "This chart provides insights into the selected dataset"

    recommended_actions = [
        "Review data patterns and trends",
        "Consider filtering or drilling down for more detail",
        "Export chart for reporting or sharing",
    ]
    if viz_type in ["echarts_timeseries_line", "echarts_timeseries_bar"]:
        recommended_actions.append("Analyze seasonal patterns or cyclical trends")

    return ChartSemantics(
        primary_insight=primary_insight,
        data_story=data_story,
        recommended_actions=recommended_actions,
        anomalies=[],  # Would need actual data analysis to populate
        statistical_summary={},  # Would need actual data analysis to populate
    )
|
||||
427
superset/mcp_service/chart/preview_utils.py
Normal file
427
superset/mcp_service/chart/preview_utils.py
Normal file
@@ -0,0 +1,427 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Preview utilities for chart generation without saving.
|
||||
|
||||
This module provides utilities for generating chart previews
|
||||
from form data without requiring a saved chart object.
|
||||
"""
|
||||
|
||||
import logging
import math
from typing import Any, Dict, List

from superset.charts.data.api import ChartDataCommand
from superset.charts.data.query_context import QueryContext
from superset.mcp_service.schemas.chart_schemas import (
    ASCIIPreview,
    ChartError,
    TablePreview,
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def generate_preview_from_form_data(
    form_data: Dict[str, Any], dataset_id: int, preview_format: str
) -> Any:
    """
    Generate a preview from form data without a saved chart.

    Executes the chart query directly against the dataset and renders the
    result rows in the requested textual format.

    Args:
        form_data: Chart configuration form data
        dataset_id: Dataset ID
        preview_format: Preview format ("ascii" or "table")

    Returns:
        ASCIIPreview / TablePreview on success, or ChartError on failure.
    """
    try:
        # Imported lazily to avoid circular imports at module load time.
        from superset.connectors.sqla.models import SqlaTable
        from superset.extensions import db

        dataset = db.session.query(SqlaTable).get(dataset_id)
        if not dataset:
            return ChartError(
                error=f"Dataset {dataset_id} not found", error_type="DatasetNotFound"
            )

        # Build a minimal query context payload from the form data.
        query_context = {
            "datasource": {"id": dataset_id, "type": "table"},
            "queries": [
                {
                    "columns": form_data.get("columns", []),
                    "metrics": form_data.get("metrics", []),
                    "orderby": form_data.get("orderby", []),
                    "row_limit": form_data.get("row_limit", 100),
                    "filters": form_data.get("adhoc_filters", []),
                    "time_range": form_data.get("time_range", "No filter"),
                }
            ],
        }

        # Execute the query through Superset's chart data command.
        query_context_obj = QueryContext(**query_context)
        command = ChartDataCommand(query_context_obj)
        result = command.run()

        if not result or not result.get("queries"):
            return ChartError(
                error="No data returned from query", error_type="EmptyResult"
            )

        query_result = result["queries"][0]
        data = query_result.get("data", [])

        # Render the preview in the requested format.
        if preview_format == "ascii":
            return _generate_ascii_preview_from_data(data, form_data)
        elif preview_format == "table":
            return _generate_table_preview_from_data(data, form_data)
        else:
            return ChartError(
                error=f"Unsupported preview format: {preview_format}",
                error_type="UnsupportedFormat",
            )

    except Exception as e:
        # logger.exception records the full traceback; lazy %-style args avoid
        # building the message when the log level is disabled (unlike the
        # previous logger.error(f"...") which also dropped the traceback).
        logger.exception("Preview generation from form data failed")
        return ChartError(
            error=f"Failed to generate preview: {str(e)}", error_type="PreviewError"
        )
|
||||
|
||||
|
||||
def _generate_ascii_preview_from_data(
    data: List[Dict[str, Any]], form_data: Dict[str, Any]
) -> ASCIIPreview:
    """Render raw row data as an ASCII-art preview matching the viz type."""
    viz_type = form_data.get("viz_type", "table")

    # Dispatch to the renderer for the visualization family; any viz type we
    # do not recognize falls back to a plain text table.
    renderers = {
        "bar": _generate_safe_ascii_bar_chart,
        "dist_bar": _generate_safe_ascii_bar_chart,
        "column": _generate_safe_ascii_bar_chart,
        "line": _generate_safe_ascii_line_chart,
        "area": _generate_safe_ascii_line_chart,
        "pie": _generate_safe_ascii_pie_chart,
    }
    renderer = renderers.get(viz_type, _generate_safe_ascii_table)

    return ASCIIPreview(
        ascii_content=renderer(data), width=80, height=20, supports_color=False
    )
|
||||
|
||||
|
||||
def _calculate_column_widths(
|
||||
display_columns: List[str], data: List[Dict[str, Any]]
|
||||
) -> Dict[str, int]:
|
||||
"""Calculate optimal width for each column."""
|
||||
column_widths = {}
|
||||
for col in display_columns:
|
||||
# Start with column name length
|
||||
max_width = len(str(col))
|
||||
|
||||
# Check data values to determine width
|
||||
for row in data[:20]: # Sample first 20 rows
|
||||
val = row.get(col, "")
|
||||
if isinstance(val, float):
|
||||
val_str = f"{val:.2f}"
|
||||
elif isinstance(val, int):
|
||||
val_str = str(val)
|
||||
else:
|
||||
val_str = str(val)
|
||||
max_width = max(max_width, len(val_str))
|
||||
|
||||
# Set reasonable bounds
|
||||
column_widths[col] = min(max(max_width, 8), 25)
|
||||
return column_widths
|
||||
|
||||
|
||||
def _format_value(val: Any, width: int) -> str:
|
||||
"""Format a value based on its type."""
|
||||
if isinstance(val, float):
|
||||
if abs(val) >= 1000000:
|
||||
val_str = f"{val:.2e}" # Scientific notation for large numbers
|
||||
elif abs(val) >= 1000:
|
||||
val_str = f"{val:,.2f}" # Thousands separator
|
||||
else:
|
||||
val_str = f"{val:.2f}"
|
||||
elif isinstance(val, int):
|
||||
if abs(val) >= 1000:
|
||||
val_str = f"{val:,}" # Thousands separator
|
||||
else:
|
||||
val_str = str(val)
|
||||
elif val is None:
|
||||
val_str = "NULL"
|
||||
else:
|
||||
val_str = str(val)
|
||||
|
||||
# Truncate if too long
|
||||
if len(val_str) > width:
|
||||
val_str = val_str[: width - 2] + ".."
|
||||
return val_str
|
||||
|
||||
|
||||
def _generate_table_preview_from_data(
    data: List[Dict[str, Any]], form_data: Dict[str, Any]
) -> TablePreview:
    """Render raw row data as an aligned plain-text table preview."""
    if not data:
        return TablePreview(
            table_data="No data available", row_count=0, supports_sorting=False
        )

    columns = list(data[0].keys()) if data else []

    # Cap the number of columns so the table stays readable in a terminal.
    max_columns = 8
    display_columns = columns[:max_columns]
    column_widths = _calculate_column_widths(display_columns, data)

    lines = ["Table Preview", "=" * 80]

    # Header row and matching separator, padded to the per-column widths.
    header_cells = []
    separator_cells = []
    for col in display_columns:
        width = column_widths[col]
        title = str(col)
        if len(title) > width:
            title = title[: width - 2] + ".."
        header_cells.append(f"{title:<{width}}")
        separator_cells.append("-" * width)
    lines.append(" | ".join(header_cells))
    lines.append("-+-".join(separator_cells))

    # Body rows, truncated to a terminal-friendly count.
    rows_shown = min(len(data), 15)
    for row in data[:rows_shown]:
        cells = []
        for col in display_columns:
            width = column_widths[col]
            cells.append(f"{_format_value(row.get(col, ''), width):<{width}}")
        lines.append(" | ".join(cells))

    # Trailing summary of anything that was elided.
    if len(data) > rows_shown:
        lines.append(f"... and {len(data) - rows_shown} more rows")
    if len(columns) > max_columns:
        lines.append(f"... and {len(columns) - max_columns} more columns")

    lines.append("")
    lines.append(f"Total: {len(data)} rows × {len(columns)} columns")

    return TablePreview(
        table_data="\n".join(lines), row_count=len(data), supports_sorting=True
    )
|
||||
|
||||
|
||||
def _generate_safe_ascii_bar_chart(data: List[Dict[str, Any]]) -> str:
    """Render rows as a horizontal ASCII bar chart (first 10 rows)."""
    if not data:
        return "No data available for bar chart"

    lines = ["ASCII Bar Chart", "=" * 50]

    # For each row, take the first valid numeric cell as the bar value and
    # the first string cell as the bar label.
    labels: List[str] = []
    values: List[float] = []
    for row in data[:10]:
        label = None
        value = None
        for cell in row.values():
            if value is None and isinstance(cell, (int, float)) and not _is_nan(cell):
                value = cell
            elif label is None and isinstance(cell, str):
                label = cell
        if value is not None:
            values.append(value)
            labels.append(label or f"Item {len(values)}")

    if not values:
        return "No numeric data found for bar chart"

    peak = max(values)
    if peak == 0:
        return "All values are zero"

    # Scale every bar relative to the largest value (30 characters max).
    for label, value in zip(labels, values, strict=False):
        bar = "█" * int((value / peak) * 30)
        lines.append(f"{label[:10]:>10} |{bar:<30} {value:.2f}")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_safe_ascii_line_chart(data: List[Dict[str, Any]]) -> str:
    """Render rows as a sparkline-style ASCII line chart."""
    if not data:
        return "No data available for line chart"

    series = _extract_numeric_values_safe(data)
    if not series:
        return "No valid numeric data found for line chart"

    # Header, value range, then the sparkline itself.
    return "\n".join(
        [
            "ASCII Line Chart",
            "=" * 50,
            _format_range_display(series),
            _generate_sparkline_safe(series),
        ]
    )
|
||||
|
||||
|
||||
def _extract_numeric_values_safe(data: List[Dict[str, Any]]) -> List[float]:
    """Return the first valid numeric cell of each row (first 20 rows)."""
    extracted: List[float] = []
    for row in data[:20]:
        # Take the first non-NaN numeric cell in the row, then move on.
        for cell in row.values():
            if isinstance(cell, (int, float)) and not _is_nan(cell):
                extracted.append(cell)
                break
    return extracted
|
||||
|
||||
|
||||
def _format_range_display(values: List[float]) -> str:
    """Return a human-readable "Range: lo to hi" line for *values*."""
    lo, hi = min(values), max(values)
    # NaN endpoints make the range meaningless; report that explicitly.
    if _is_nan(lo) or _is_nan(hi):
        return "Range: Unable to calculate"
    return f"Range: {lo:.2f} to {hi:.2f}"
|
||||
|
||||
|
||||
def _generate_sparkline_safe(values: List[float]) -> str:
|
||||
"""Generate sparkline from values."""
|
||||
if not values:
|
||||
return ""
|
||||
|
||||
min_val = min(values)
|
||||
|
||||
if (max_val := max(values)) != min_val:
|
||||
sparkline = ""
|
||||
for val in values:
|
||||
normalized = (val - min_val) / (max_val - min_val)
|
||||
if normalized < 0.2:
|
||||
sparkline += "▁"
|
||||
elif normalized < 0.4:
|
||||
sparkline += "▂"
|
||||
elif normalized < 0.6:
|
||||
sparkline += "▄"
|
||||
elif normalized < 0.8:
|
||||
sparkline += "▆"
|
||||
else:
|
||||
sparkline += "█"
|
||||
return sparkline
|
||||
else:
|
||||
return "─" * len(values) # Flat line if all values are same
|
||||
|
||||
|
||||
def _generate_safe_ascii_pie_chart(data: List[Dict[str, Any]]) -> str:
    """Render rows as a textual pie chart: one percentage bar per slice."""
    if not data:
        return "No data available for pie chart"

    lines = ["ASCII Pie Chart", "=" * 50]

    # One slice per row (max 8): the first valid numeric cell is the slice
    # value, the first string cell is its label. Non-positive values are
    # dropped since they cannot form a pie slice.
    labels: List[str] = []
    values: List[float] = []
    for row in data[:8]:
        label = None
        value = None
        for cell in row.values():
            if value is None and isinstance(cell, (int, float)) and not _is_nan(cell):
                value = cell
            elif label is None and isinstance(cell, str):
                label = cell
        if value is not None and value > 0:
            values.append(value)
            labels.append(label or f"Slice {len(values)}")

    if not values:
        return "No valid data for pie chart"

    total = sum(values)
    if total == 0:
        return "Total is zero"

    # Each slice gets a dot bar proportional to its share of the total.
    for label, value in zip(labels, values, strict=False):
        share = (value / total) * 100
        dots = "●" * int(share / 3)
        lines.append(f"{label[:15]:>15}: {dots} {share:.1f}%")

    return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_safe_ascii_table(data: List[Dict[str, Any]]) -> str:
|
||||
"""Generate ASCII table with safe formatting."""
|
||||
if not data:
|
||||
return "No data available"
|
||||
|
||||
lines = ["Data Table", "=" * 50]
|
||||
|
||||
# Get columns
|
||||
columns = list(data[0].keys()) if data else []
|
||||
|
||||
# Format header
|
||||
header = " | ".join(str(col)[:10] for col in columns[:5])
|
||||
lines.append(header)
|
||||
lines.append("-" * len(header))
|
||||
|
||||
# Format rows
|
||||
for row in data[:10]:
|
||||
row_str = " | ".join(str(row.get(col, ""))[:10] for col in columns[:5])
|
||||
lines.append(row_str)
|
||||
|
||||
if len(data) > 10:
|
||||
lines.append(f"... {len(data) - 10} more rows")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _is_nan(value: Any) -> bool:
|
||||
"""Check if a value is NaN."""
|
||||
try:
|
||||
import math
|
||||
|
||||
return math.isnan(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return False
|
||||
19
superset/mcp_service/chart/tool/__init__.py
Normal file
19
superset/mcp_service/chart/tool/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""Chart MCP tools: list, inspect, create, update, and preview charts."""

from .generate_chart import generate_chart
from .get_chart_available_filters import get_chart_available_filters
from .get_chart_data import get_chart_data
from .get_chart_info import get_chart_info
from .get_chart_preview import get_chart_preview
from .list_charts import list_charts
from .update_chart import update_chart
from .update_chart_preview import update_chart_preview

__all__ = [
    "list_charts",
    "get_chart_info",
    "get_chart_available_filters",
    "generate_chart",
    "update_chart",
    "update_chart_preview",
    "get_chart_preview",
    "get_chart_data",
]
|
||||
387
superset/mcp_service/chart/tool/generate_chart.py
Normal file
387
superset/mcp_service/chart/tool/generate_chart.py
Normal file
@@ -0,0 +1,387 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
MCP tool: generate_chart (simplified schema)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
generate_chart_name,
|
||||
map_config_to_form_data,
|
||||
)
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AccessibilityMetadata,
|
||||
GenerateChartRequest,
|
||||
PerformanceMetadata,
|
||||
URLPreview,
|
||||
)
|
||||
from superset.mcp_service.url_utils import (
|
||||
get_chart_screenshot_url,
|
||||
get_superset_base_url,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def generate_chart(request: GenerateChartRequest) -> Dict[str, Any]:  # noqa: C901
    """
    Create and SAVE a new chart in Superset with enhanced preview capabilities.

    This tool creates a permanent chart in Superset and generates preview images
    that can be embedded in LLM conversations. The chart appears in saved charts
    and can be added to dashboards.

    Enhanced features:
    - Generates chart preview images for embedding in LLM conversations
    - Returns explore URL for interactive editing after creation
    - Supports multiple preview formats (url, ascii, table, base64)
    - Rich semantic analysis and capabilities metadata
    - Enhanced form_data caching for better explore experience

    Default behavior (save_chart=True):
    - Permanently saves chart in Superset
    - Generates preview images
    - Returns chart ID and metadata
    - Provides explore URL for further editing

    Optional behavior (save_chart=False):
    - Creates temporary visualization only
    - Caches configuration server-side
    - Returns preview + explore link
    - No permanent chart saved

    Args:
        request: Chart creation request with dataset_id, config, save_chart flag,
            and preview options

    Returns:
        Response with saved chart info, preview images, and explore URL
    """
    # Wall-clock start; reported back as query_duration_ms in every exit path.
    start_time = time.time()

    try:
        # Validate chart configuration first. Imported lazily — presumably to
        # avoid circular imports at module load time (TODO confirm).
        from superset.mcp_service.chart.validation_utils import validate_chart_config

        is_valid, validation_error = validate_chart_config(
            request.config, request.dataset_id
        )
        if not is_valid:
            execution_time = int((time.time() - start_time) * 1000)
            assert validation_error is not None  # Type narrowing for mypy
            return {
                "chart": None,
                "error": validation_error.model_dump(),
                "performance": {
                    "query_duration_ms": execution_time,
                    "cache_status": "error",
                    "optimization_suggestions": [],
                },
                "success": False,
                "schema_version": "2.0",
                "api_version": "v1",
            }

        # Map the simplified config to Superset's form_data format
        form_data = map_config_to_form_data(request.config)

        # These stay None/unset in the branches that do not apply (e.g. no
        # chart object in preview-only mode, no form_data_key for saved charts).
        chart = None
        chart_id = None
        explore_url = None
        form_data_key = None

        # Save chart by default (unless save_chart=False)
        if request.save_chart:
            from superset.commands.chart.create import CreateChartCommand

            # Generate a chart name
            chart_name = generate_chart_name(request.config)

            # Find the dataset to get its numeric ID
            from superset.daos.dataset import DatasetDAO

            dataset = None
            # dataset_id may be a numeric ID (int or digit string) or a UUID.
            if isinstance(request.dataset_id, int) or (
                isinstance(request.dataset_id, str) and request.dataset_id.isdigit()
            ):
                dataset_id = (
                    int(request.dataset_id)
                    if isinstance(request.dataset_id, str)
                    else request.dataset_id
                )
                dataset = DatasetDAO.find_by_id(dataset_id)
            else:
                # Try UUID lookup using DAO flexible method
                dataset = DatasetDAO.find_by_id(request.dataset_id, id_column="uuid")

            if not dataset:
                from superset.mcp_service.schemas.error_schemas import (
                    ChartGenerationError,
                )

                execution_time = int((time.time() - start_time) * 1000)
                error = ChartGenerationError(
                    error_type="dataset_not_found",
                    message=f"Dataset not found: {request.dataset_id}",
                    details=(
                        f"No dataset found with identifier '{request.dataset_id}'. "
                        f"This could be an invalid ID/UUID or a permissions issue."
                    ),
                    suggestions=[
                        "Verify the dataset ID or UUID is correct",
                        "Check that you have access to this dataset",
                        "Use the list_datasets tool to find available datasets",
                        "If using UUID, ensure it's the correct format",
                    ],
                    error_code="DATASET_NOT_FOUND",
                )
                return {
                    "chart": None,
                    "error": error.model_dump(),
                    "performance": {
                        "query_duration_ms": execution_time,
                        "cache_status": "error",
                        "optimization_suggestions": [],
                    },
                    "success": False,
                    "schema_version": "2.0",
                    "api_version": "v1",
                }

            # Create the chart using Superset's command
            command = CreateChartCommand(
                {
                    "slice_name": chart_name,
                    "viz_type": form_data["viz_type"],
                    "datasource_id": dataset.id,
                    "datasource_type": "table",
                    "params": json.dumps(form_data),
                }
            )

            chart = command.run()
            chart_id = chart.id
            # Update explore URL to use saved chart
            explore_url = f"{get_superset_base_url()}/explore/?slice_id={chart.id}"
        else:
            # Generate explore link with cached form_data for preview-only mode
            from superset.mcp_service.chart.chart_utils import generate_explore_link

            explore_url = generate_explore_link(request.dataset_id, form_data)

            # Extract form_data_key from the explore URL
            if "form_data_key=" in explore_url:
                form_data_key = explore_url.split("form_data_key=")[1].split("&")[0]

        # Generate semantic analysis (chart may be None in preview-only mode)
        capabilities = analyze_chart_capabilities(chart, request.config)
        semantics = analyze_chart_semantics(chart, request.config)

        # Create performance metadata
        execution_time = int((time.time() - start_time) * 1000)
        performance = PerformanceMetadata(
            query_duration_ms=execution_time,
            cache_status="miss",
            optimization_suggestions=[],
        )

        # Create accessibility metadata
        chart_name = (
            chart.slice_name
            if chart and hasattr(chart, "slice_name")
            else generate_chart_name(request.config)
        )
        accessibility = AccessibilityMetadata(
            color_blind_safe=True,  # Would need actual analysis
            alt_text=f"Chart showing {chart_name}",
            high_contrast_available=False,
        )

        # Generate previews if requested
        previews = {}
        if request.generate_preview:
            try:
                for format_type in request.preview_formats:
                    # Skip base64 format - we never return base64
                    if format_type == "base64":
                        logger.info("Skipping base64 format - not supported")
                        continue

                    if chart_id:
                        # For saved charts, use the existing preview generation
                        from superset.mcp_service.chart.tool.get_chart_preview import (
                            _get_chart_preview_internal,
                            GetChartPreviewRequest,
                        )

                        preview_request = GetChartPreviewRequest(
                            identifier=str(chart_id), format=format_type
                        )
                        preview_result = _get_chart_preview_internal(preview_request)

                        if hasattr(preview_result, "content"):
                            previews[format_type] = preview_result.content
                    else:
                        # For preview-only mode (save_chart=false)
                        if format_type == "url" and form_data_key:
                            # Generate screenshot URL using centralized helper
                            from superset.mcp_service.url_utils import (
                                get_explore_screenshot_url,
                            )

                            preview_url = get_explore_screenshot_url(form_data_key)
                            previews[format_type] = URLPreview(
                                preview_url=preview_url,
                                width=800,
                                height=600,
                                supports_interaction=False,
                            )
                        elif format_type in ["ascii", "table"]:
                            # Generate preview from form data without saved chart
                            from superset.mcp_service.chart.preview_utils import (
                                generate_preview_from_form_data,
                            )

                            # NOTE(review): int(request.dataset_id) raises
                            # ValueError for UUID identifiers — confirm this
                            # branch is only reached with numeric dataset ids
                            # (the outer except would swallow it otherwise).
                            preview_result = generate_preview_from_form_data(
                                form_data=form_data,
                                dataset_id=int(request.dataset_id),
                                preview_format=format_type,
                            )

                            # Only keep successful previews; errors carry .error
                            if not hasattr(preview_result, "error"):
                                previews[format_type] = preview_result

            except Exception as e:
                # Log warning but don't fail the entire request
                logger.warning(f"Preview generation failed: {e}")

        # Return enhanced data while maintaining backward compatibility
        result = {
            "chart": {
                "id": chart.id if chart else None,
                "slice_name": chart.slice_name
                if chart
                else generate_chart_name(request.config),
                "viz_type": chart.viz_type if chart else form_data.get("viz_type"),
                "url": explore_url,
                "uuid": str(chart.uuid) if chart and chart.uuid else None,
                "saved": request.save_chart,
            }
            if request.save_chart
            else {
                "id": None,
                "slice_name": generate_chart_name(request.config),
                "viz_type": form_data.get("viz_type"),
                "url": explore_url,
                "uuid": None,
                "saved": False,
            },
            "error": None,
            # Enhanced fields for better LLM integration
            "previews": previews,
            "capabilities": capabilities.model_dump() if capabilities else None,
            "semantics": semantics.model_dump() if semantics else None,
            "explore_url": explore_url,
            "form_data_key": form_data_key,
            "api_endpoints": {
                "data": f"{get_superset_base_url()}/api/v1/chart/{chart.id}/data/"
                if chart
                else None,
                "preview": get_chart_screenshot_url(chart.id) if chart else None,
                "export": f"{get_superset_base_url()}/api/v1/chart/{chart.id}/export/"
                if chart
                else None,
            }
            if chart
            else {},
            "performance": performance.model_dump() if performance else None,
            "accessibility": accessibility.model_dump() if accessibility else None,
            "success": True,
            "schema_version": "2.0",
            "api_version": "v1",
        }
        return result

    except Exception as e:
        from superset.mcp_service.schemas.error_schemas import ChartGenerationError

        execution_time = int((time.time() - start_time) * 1000)

        # Analyze exception to provide better error context
        error_type = "chart_creation_error"
        suggestions = [
            "Check that all column names are spelled correctly",
            "Verify the dataset has data",
            "Try with a simpler chart configuration",
            "Check server logs for detailed error information",
        ]

        # Enhance error message based on exception type. These are keyword
        # heuristics on the exception text, not exact type checks.
        if "permission" in str(e).lower() or "access" in str(e).lower():
            error_type = "permission_error"
            suggestions = [
                "Check that you have access to the dataset",
                "Verify your user permissions in Superset",
                "Contact your administrator for dataset access",
            ]
        elif "sql" in str(e).lower() or "query" in str(e).lower():
            error_type = "query_execution_error"
            suggestions = [
                "Check that column names exist in the dataset",
                "Verify filter values are valid for their column types",
                "Try a simpler query first",
                "Check the dataset's underlying data source",
            ]
        elif "timeout" in str(e).lower():
            error_type = "query_timeout_error"
            suggestions = [
                "Try reducing the data range or adding filters",
                "Consider using a smaller sample of data",
                "Check if the database is responding slowly",
                "Contact your administrator about query performance",
            ]

        error = ChartGenerationError(
            error_type=error_type,
            message="Chart creation failed",
            details=f"An error occurred while creating the chart: {str(e)}",
            suggestions=suggestions,
            error_code="CHART_CREATION_FAILED",
        )

        return {
            "chart": None,
            "error": error.model_dump(),
            "performance": {
                "query_duration_ms": execution_time,
                "cache_status": "error",
                "optimization_suggestions": [],
            },
            "success": False,
            "schema_version": "2.0",
            "api_version": "v1",
        }
|
||||
@@ -0,0 +1,50 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: get_chart_available_filters
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetAvailableFiltersTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
ChartAvailableFiltersResponse,
|
||||
GetChartAvailableFiltersRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_chart_available_filters(
    request: GetChartAvailableFiltersRequest,
) -> ChartAvailableFiltersResponse:
    """
    Return available chart filter fields, types, and supported operators (MCP tool).
    """
    from superset.daos.chart import ChartDAO

    # Delegate to the generic filter-introspection tool, bound to charts.
    filters_tool = ModelGetAvailableFiltersTool(
        dao_class=ChartDAO,
        output_schema=ChartAvailableFiltersResponse,
        logger=logger,
    )
    return filters_tool.run()
|
||||
288
superset/mcp_service/chart/tool/get_chart_data.py
Normal file
288
superset/mcp_service/chart/tool/get_chart_data.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: get_chart_data
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.cache_utils import get_cache_status_from_result
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
ChartData,
|
||||
ChartError,
|
||||
DataColumn,
|
||||
GetChartDataRequest,
|
||||
PerformanceMetadata,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
def get_chart_data(request: GetChartDataRequest) -> ChartData | ChartError: # noqa: C901
|
||||
"""
|
||||
Get the underlying data for a chart with advanced cache control.
|
||||
|
||||
This tool returns the actual data behind a chart, making it easy for LLM clients
|
||||
to understand and describe the chart contents without needing image rendering.
|
||||
|
||||
Supports lookup by:
|
||||
- Numeric ID (e.g., 123)
|
||||
- UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||
|
||||
Cache Control Features:
|
||||
- use_cache: Whether to use Superset's query result cache
|
||||
- force_refresh: Force refresh cached data
|
||||
- cache_timeout: Override default cache timeout for this query
|
||||
|
||||
Returns chart data in a structured format with summary and detailed cache status.
|
||||
"""
|
||||
try:
|
||||
from superset.daos.chart import ChartDAO
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
# Find the chart
|
||||
chart = None
|
||||
if isinstance(request.identifier, int) or (
|
||||
isinstance(request.identifier, str) and request.identifier.isdigit()
|
||||
):
|
||||
chart_id = (
|
||||
int(request.identifier)
|
||||
if isinstance(request.identifier, str)
|
||||
else request.identifier
|
||||
)
|
||||
chart = ChartDAO.find_by_id(chart_id)
|
||||
else:
|
||||
# Try UUID lookup using DAO flexible method
|
||||
chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")
|
||||
|
||||
if not chart:
|
||||
return ChartError(
|
||||
error=f"No chart found with identifier: {request.identifier}",
|
||||
error_type="NotFound",
|
||||
)
|
||||
|
||||
logger.info(f"Getting data for chart {chart.id}: {chart.slice_name}")
|
||||
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Get chart data using the existing API
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
|
||||
# Parse the form_data to get query context
|
||||
form_data = utils_json.loads(chart.params) if chart.params else {}
|
||||
|
||||
# Create a proper QueryContext using the factory with cache control
|
||||
factory = QueryContextFactory()
|
||||
query_context = factory.create(
|
||||
datasource={"id": chart.datasource_id, "type": chart.datasource_type},
|
||||
queries=[
|
||||
{
|
||||
"filters": form_data.get("filters", []),
|
||||
"columns": form_data.get("groupby", []),
|
||||
"metrics": form_data.get("metrics", []),
|
||||
"row_limit": request.limit or 100,
|
||||
"order_desc": True,
|
||||
# Apply cache control from request
|
||||
"cache_timeout": request.cache_timeout,
|
||||
}
|
||||
],
|
||||
form_data=form_data,
|
||||
# Use cache unless force_refresh is True
|
||||
force=request.force_refresh,
|
||||
)
|
||||
|
||||
# Execute the query
|
||||
command = ChartDataCommand(query_context)
|
||||
result = command.run()
|
||||
|
||||
# Handle empty query results for certain chart types
|
||||
if not result or ("queries" not in result) or len(result["queries"]) == 0:
|
||||
return ChartError(
|
||||
error=f"No query results returned for chart {chart.id}. "
|
||||
f"This may occur with chart types like big_number.",
|
||||
error_type="EmptyQuery",
|
||||
)
|
||||
|
||||
# Extract data from result (we've already validated it exists above)
|
||||
query_result = result["queries"][0]
|
||||
data = query_result.get("data", [])
|
||||
raw_columns = query_result.get("colnames", [])
|
||||
|
||||
# Check if we have data to work with
|
||||
if not data:
|
||||
return ChartError(
|
||||
error=f"No data available for chart {chart.id}", error_type="NoData"
|
||||
)
|
||||
|
||||
# Create rich column metadata
|
||||
columns = []
|
||||
for col_name in raw_columns:
|
||||
# Sample some values for metadata
|
||||
sample_values = [
|
||||
row.get(col_name)
|
||||
for row in data[:3]
|
||||
if row.get(col_name) is not None
|
||||
]
|
||||
|
||||
# Infer data type
|
||||
data_type = "string"
|
||||
if sample_values:
|
||||
if all(isinstance(v, (int, float)) for v in sample_values):
|
||||
data_type = "numeric"
|
||||
elif all(isinstance(v, bool) for v in sample_values):
|
||||
data_type = "boolean"
|
||||
|
||||
columns.append(
|
||||
DataColumn(
|
||||
name=col_name,
|
||||
display_name=col_name.replace("_", " ").title(),
|
||||
data_type=data_type,
|
||||
sample_values=sample_values[:3],
|
||||
null_count=sum(1 for row in data if row.get(col_name) is None),
|
||||
unique_count=len({str(row.get(col_name)) for row in data}),
|
||||
)
|
||||
)
|
||||
|
||||
# Cache status information using utility function
|
||||
cache_status = get_cache_status_from_result(
|
||||
query_result, force_refresh=request.force_refresh
|
||||
)
|
||||
|
||||
# Generate insights and recommendations
|
||||
insights = []
|
||||
if len(data) > 100:
|
||||
insights.append(
|
||||
"Large dataset - consider filtering for better performance"
|
||||
)
|
||||
if len(raw_columns) > 10:
|
||||
insights.append("Many columns available - focus on key metrics")
|
||||
|
||||
# Add cache-specific insights
|
||||
if cache_status.cache_hit:
|
||||
if (
|
||||
cache_status.cache_age_seconds
|
||||
and cache_status.cache_age_seconds > 3600
|
||||
):
|
||||
hours_old = cache_status.cache_age_seconds // 3600
|
||||
insights.append(
|
||||
f"Data is from cache ({hours_old}h old) - "
|
||||
"consider refreshing for latest data"
|
||||
)
|
||||
else:
|
||||
insights.append("Data served from cache for fast response")
|
||||
else:
|
||||
insights.append("Fresh data retrieved from database")
|
||||
|
||||
recommended_visualizations = []
|
||||
if any(
|
||||
"time" in col.lower() or "date" in col.lower() for col in raw_columns
|
||||
):
|
||||
recommended_visualizations.extend(["line chart", "time series"])
|
||||
if len(raw_columns) <= 3:
|
||||
recommended_visualizations.extend(["bar chart", "scatter plot"])
|
||||
|
||||
# Performance metadata with cache awareness
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
performance_status = (
|
||||
"cache_hit" if cache_status.cache_hit else "fresh_query"
|
||||
)
|
||||
optimization_suggestions = []
|
||||
|
||||
if not cache_status.cache_hit and execution_time > 5000:
|
||||
optimization_suggestions.append(
|
||||
"Consider using cache for this slow query"
|
||||
)
|
||||
elif (
|
||||
cache_status.cache_hit
|
||||
and cache_status.cache_age_seconds
|
||||
and cache_status.cache_age_seconds > 86400
|
||||
):
|
||||
optimization_suggestions.append("Cache is old - consider refreshing")
|
||||
|
||||
performance = PerformanceMetadata(
|
||||
query_duration_ms=execution_time,
|
||||
cache_status=performance_status,
|
||||
optimization_suggestions=optimization_suggestions,
|
||||
)
|
||||
|
||||
# Generate comprehensive summary with cache info
|
||||
cache_info = ""
|
||||
if cache_status.cache_hit:
|
||||
age_info = (
|
||||
f" (cached {cache_status.cache_age_seconds // 60}m ago)"
|
||||
if cache_status.cache_age_seconds
|
||||
else " (cached)"
|
||||
)
|
||||
cache_info = age_info
|
||||
|
||||
summary_parts = [
|
||||
f"Chart '{chart.slice_name}' ({chart.viz_type})",
|
||||
f"Contains {len(data)} rows across {len(raw_columns)} columns"
|
||||
f"{cache_info}",
|
||||
]
|
||||
|
||||
if data and len(data) > 0:
|
||||
summary_parts.append(
|
||||
f"Sample data includes: {', '.join(raw_columns[:3])}"
|
||||
)
|
||||
|
||||
summary = ". ".join(summary_parts)
|
||||
|
||||
return ChartData(
|
||||
chart_id=chart.id,
|
||||
chart_name=chart.slice_name or f"Chart {chart.id}",
|
||||
chart_type=chart.viz_type or "unknown",
|
||||
columns=columns,
|
||||
data=data[: request.limit] if request.limit else data,
|
||||
row_count=len(data),
|
||||
total_rows=query_result.get("rowcount"),
|
||||
summary=summary,
|
||||
insights=insights,
|
||||
data_quality={
|
||||
"completeness": 1.0
|
||||
- (
|
||||
sum(col.null_count for col in columns)
|
||||
/ max(len(data) * len(columns), 1)
|
||||
)
|
||||
},
|
||||
recommended_visualizations=recommended_visualizations,
|
||||
data_freshness=None, # Add missing field
|
||||
performance=performance,
|
||||
cache_status=cache_status,
|
||||
)
|
||||
|
||||
except Exception as data_error:
|
||||
logger.error(f"Data retrieval error for chart {chart.id}: {data_error}")
|
||||
return ChartError(
|
||||
error=f"Error retrieving chart data: {str(data_error)}",
|
||||
error_type="DataError",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in get_chart_data: {e}")
|
||||
return ChartError(
|
||||
error=f"Failed to get chart data: {str(e)}", error_type="InternalError"
|
||||
)
|
||||
62
superset/mcp_service/chart/tool/get_chart_info.py
Normal file
62
superset/mcp_service/chart/tool/get_chart_info.py
Normal file
@@ -0,0 +1,62 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: get_chart_info
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetInfoTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import ChartError, ChartInfo
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
GetChartInfoRequest,
|
||||
serialize_chart_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
|
||||
@mcp_auth_hook
|
||||
def get_chart_info(request: GetChartInfoRequest) -> ChartInfo | ChartError:
|
||||
"""
|
||||
Get detailed information about a specific chart with metadata cache control.
|
||||
|
||||
Supports lookup by:
|
||||
- Numeric ID (e.g., 123)
|
||||
- UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")
|
||||
|
||||
Metadata Cache Control:
|
||||
- use_cache: Whether to use metadata cache for faster responses
|
||||
- refresh_metadata: Force refresh of metadata cache for fresh data
|
||||
|
||||
Returns a ChartInfo model or ChartError on error.
|
||||
"""
|
||||
from superset.daos.chart import ChartDAO
|
||||
|
||||
tool = ModelGetInfoTool(
|
||||
dao_class=ChartDAO,
|
||||
output_schema=ChartInfo,
|
||||
error_schema=ChartError,
|
||||
serializer=serialize_chart_object,
|
||||
supports_slug=False, # Charts don't have slugs
|
||||
logger=logger,
|
||||
)
|
||||
return tool.run(request.identifier)
|
||||
772
superset/mcp_service/chart/tool/get_chart_preview.py
Normal file
772
superset/mcp_service/chart/tool/get_chart_preview.py
Normal file
@@ -0,0 +1,772 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: get_chart_preview
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, List, Protocol
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AccessibilityMetadata,
|
||||
ASCIIPreview,
|
||||
Base64Preview,
|
||||
ChartError,
|
||||
ChartPreview,
|
||||
GetChartPreviewRequest,
|
||||
InteractivePreview,
|
||||
PerformanceMetadata,
|
||||
TablePreview,
|
||||
URLPreview,
|
||||
VegaLitePreview,
|
||||
)
|
||||
from superset.mcp_service.url_utils import get_superset_base_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChartLike(Protocol):
|
||||
"""Protocol for chart-like objects with required attributes for preview."""
|
||||
|
||||
id: int
|
||||
slice_name: str | None
|
||||
viz_type: str | None
|
||||
datasource_id: int
|
||||
datasource_type: str
|
||||
params: str | None
|
||||
digest: str
|
||||
uuid: Any
|
||||
|
||||
|
||||
class PreviewFormatStrategy:
|
||||
"""Base class for preview format strategies."""
|
||||
|
||||
def __init__(self, chart: ChartLike, request: GetChartPreviewRequest) -> None:
|
||||
self.chart = chart
|
||||
self.request = request
|
||||
|
||||
def generate(self) -> ChartPreview | ChartError:
|
||||
"""Generate preview in the specific format."""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class URLPreviewStrategy(PreviewFormatStrategy):
|
||||
"""Generate URL-based image preview."""
|
||||
|
||||
def generate(self) -> URLPreview | ChartError:
|
||||
try:
|
||||
from flask import g
|
||||
|
||||
from superset.mcp_service.pooled_screenshot import PooledChartScreenshot
|
||||
from superset.utils.urls import get_url_path
|
||||
|
||||
chart_url = get_url_path("Superset.slice", slice_id=self.chart.id)
|
||||
screenshot = PooledChartScreenshot(chart_url, self.chart.digest)
|
||||
|
||||
window_size = (self.request.width or 800, self.request.height or 600)
|
||||
image_data = screenshot.get_screenshot(user=g.user, window_size=window_size)
|
||||
|
||||
if image_data:
|
||||
# Use the MCP service screenshot URL via centralized helper
|
||||
from superset.mcp_service.url_utils import get_chart_screenshot_url
|
||||
|
||||
preview_url = get_chart_screenshot_url(self.chart.id)
|
||||
|
||||
return URLPreview(
|
||||
preview_url=preview_url,
|
||||
width=self.request.width or 800,
|
||||
height=self.request.height or 600,
|
||||
)
|
||||
else:
|
||||
return ChartError(
|
||||
error=f"Could not generate screenshot for chart {self.chart.id}",
|
||||
error_type="ScreenshotError",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"URL preview generation failed: {e}")
|
||||
return ChartError(
|
||||
error=f"Failed to generate URL preview: {str(e)}", error_type="URLError"
|
||||
)
|
||||
|
||||
|
||||
# Base64 preview support removed - we never return base64 data
|
||||
|
||||
|
||||
class ASCIIPreviewStrategy(PreviewFormatStrategy):
|
||||
"""Generate ASCII art preview."""
|
||||
|
||||
def generate(self) -> ASCIIPreview | ChartError:
|
||||
try:
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
form_data = utils_json.loads(self.chart.params) if self.chart.params else {}
|
||||
|
||||
logger.info(f"Chart form_data keys: {list(form_data.keys())}")
|
||||
logger.info(f"Chart viz_type: {self.chart.viz_type}")
|
||||
|
||||
# Build query for chart data
|
||||
x_axis_config = form_data.get("x_axis")
|
||||
groupby_columns = form_data.get("groupby", [])
|
||||
metrics = form_data.get("metrics", [])
|
||||
|
||||
columns = groupby_columns.copy()
|
||||
if x_axis_config and isinstance(x_axis_config, str):
|
||||
columns.append(x_axis_config)
|
||||
elif x_axis_config and isinstance(x_axis_config, dict):
|
||||
if "column_name" in x_axis_config:
|
||||
columns.append(x_axis_config["column_name"])
|
||||
|
||||
factory = QueryContextFactory()
|
||||
query_context = factory.create(
|
||||
datasource={
|
||||
"id": self.chart.datasource_id,
|
||||
"type": self.chart.datasource_type,
|
||||
},
|
||||
queries=[
|
||||
{
|
||||
"filters": form_data.get("filters", []),
|
||||
"columns": columns,
|
||||
"metrics": metrics,
|
||||
"row_limit": 50,
|
||||
"order_desc": True,
|
||||
}
|
||||
],
|
||||
form_data=form_data,
|
||||
force=False,
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
result = command.run()
|
||||
|
||||
data = []
|
||||
if result and "queries" in result and len(result["queries"]) > 0:
|
||||
data = result["queries"][0].get("data", [])
|
||||
|
||||
ascii_chart = generate_ascii_chart(
|
||||
data,
|
||||
self.chart.viz_type or "table",
|
||||
self.request.ascii_width or 80,
|
||||
self.request.ascii_height or 20,
|
||||
)
|
||||
|
||||
return ASCIIPreview(
|
||||
ascii_content=ascii_chart,
|
||||
width=self.request.ascii_width or 80,
|
||||
height=self.request.ascii_height or 20,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"ASCII preview generation failed: {e}")
|
||||
return ChartError(
|
||||
error=f"Failed to generate ASCII preview: {str(e)}",
|
||||
error_type="ASCIIError",
|
||||
)
|
||||
|
||||
|
||||
class TablePreviewStrategy(PreviewFormatStrategy):
|
||||
"""Generate table preview of chart data."""
|
||||
|
||||
def generate(self) -> TablePreview | ChartError:
|
||||
try:
|
||||
from superset.commands.chart.data.get_data_command import ChartDataCommand
|
||||
from superset.common.query_context_factory import QueryContextFactory
|
||||
from superset.utils import json as utils_json
|
||||
|
||||
form_data = utils_json.loads(self.chart.params) if self.chart.params else {}
|
||||
|
||||
factory = QueryContextFactory()
|
||||
query_context = factory.create(
|
||||
datasource={
|
||||
"id": self.chart.datasource_id,
|
||||
"type": self.chart.datasource_type,
|
||||
},
|
||||
queries=[
|
||||
{
|
||||
"filters": form_data.get("filters", []),
|
||||
"columns": form_data.get("groupby", []),
|
||||
"metrics": form_data.get("metrics", []),
|
||||
"row_limit": 20,
|
||||
"order_desc": True,
|
||||
}
|
||||
],
|
||||
form_data=form_data,
|
||||
force=False,
|
||||
)
|
||||
|
||||
command = ChartDataCommand(query_context)
|
||||
result = command.run()
|
||||
|
||||
data = []
|
||||
if result and "queries" in result and len(result["queries"]) > 0:
|
||||
data = result["queries"][0].get("data", [])
|
||||
|
||||
table_data = _generate_ascii_table(data, 120)
|
||||
|
||||
return TablePreview(
|
||||
table_data=table_data,
|
||||
row_count=len(data),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Table preview generation failed: {e}")
|
||||
return ChartError(
|
||||
error=f"Failed to generate table preview: {str(e)}",
|
||||
error_type="TableError",
|
||||
)
|
||||
|
||||
|
||||
class PreviewFormatGenerator:
|
||||
"""Factory for generating different preview formats."""
|
||||
|
||||
STRATEGIES = {
|
||||
"url": URLPreviewStrategy,
|
||||
"ascii": ASCIIPreviewStrategy,
|
||||
"table": TablePreviewStrategy,
|
||||
}
|
||||
|
||||
def __init__(self, chart: ChartLike, request: GetChartPreviewRequest) -> None:
|
||||
self.chart = chart
|
||||
self.request = request
|
||||
|
||||
def generate(
|
||||
self,
|
||||
) -> (
|
||||
URLPreview
|
||||
| InteractivePreview
|
||||
| ASCIIPreview
|
||||
| VegaLitePreview
|
||||
| TablePreview
|
||||
| Base64Preview
|
||||
| ChartError
|
||||
):
|
||||
"""Generate preview using the appropriate strategy."""
|
||||
strategy_class = self.STRATEGIES.get(self.request.format)
|
||||
|
||||
if not strategy_class:
|
||||
return ChartError(
|
||||
error=f"Unsupported preview format: {self.request.format}",
|
||||
error_type="UnsupportedFormat",
|
||||
)
|
||||
|
||||
strategy = strategy_class(self.chart, self.request)
|
||||
return strategy.generate()
|
||||
|
||||
|
||||
def generate_ascii_chart(
|
||||
data: List[Any], chart_type: str, width: int = 80, height: int = 20
|
||||
) -> str:
|
||||
"""Generate ASCII art chart from data."""
|
||||
if not data or len(data) == 0:
|
||||
return "No data available for ASCII chart"
|
||||
|
||||
try:
|
||||
logger.info(
|
||||
f"generate_ascii_chart: chart_type={chart_type}, data_rows={len(data)}"
|
||||
)
|
||||
|
||||
# Generate appropriate ASCII chart based on type
|
||||
if chart_type in ["bar", "column", "echarts_timeseries_bar"]:
|
||||
logger.info("Generating bar chart")
|
||||
return _generate_ascii_bar_chart(data, width, height)
|
||||
elif chart_type in ["line", "echarts_timeseries_line"]:
|
||||
logger.info("Generating line chart")
|
||||
return _generate_ascii_line_chart(data, width, height)
|
||||
elif chart_type in ["scatter", "echarts_timeseries_scatter"]:
|
||||
logger.info("Generating scatter chart")
|
||||
return _generate_ascii_scatter_chart(data, width, height)
|
||||
else:
|
||||
# Default to table format for unsupported chart types
|
||||
logger.info(f"Unsupported chart type '{chart_type}', falling back to table")
|
||||
return _generate_ascii_table(data, width)
|
||||
except Exception as e:
|
||||
logger.error(f"ASCII chart generation failed: {e}")
|
||||
import traceback
|
||||
|
||||
logger.error(f"Traceback: {traceback.format_exc()}")
|
||||
return f"ASCII chart generation failed: {str(e)}"
|
||||
|
||||
|
||||
def _generate_ascii_bar_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate ASCII bar chart."""
|
||||
if not data:
|
||||
return "No data for bar chart"
|
||||
|
||||
lines = []
|
||||
lines.append("ASCII Bar Chart")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Extract numeric values for bars
|
||||
values = []
|
||||
labels = []
|
||||
|
||||
for row in data[:10]: # Limit to 10 bars for readability
|
||||
if isinstance(row, dict):
|
||||
# Find numeric and string values
|
||||
numeric_val = None
|
||||
label_val = None
|
||||
|
||||
for _key, val in row.items():
|
||||
if isinstance(val, (int, float)) and numeric_val is None:
|
||||
numeric_val = val
|
||||
elif isinstance(val, str) and label_val is None:
|
||||
label_val = val
|
||||
|
||||
if numeric_val is not None:
|
||||
values.append(numeric_val)
|
||||
labels.append(label_val or f"Item {len(values)}")
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for bar chart"
|
||||
|
||||
# Normalize values to fit in chart
|
||||
max_val = max(values) if values else 1
|
||||
|
||||
for label, value in zip(labels, values, strict=False):
|
||||
# Calculate bar length (max 30 chars)
|
||||
bar_length = int((value / max_val) * 30) if max_val > 0 else 0
|
||||
bar = "█" * bar_length
|
||||
lines.append(f"{label[:10]:>10} |{bar:<30} {value}")
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _generate_ascii_line_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate ASCII line chart."""
|
||||
if not data:
|
||||
return "No data for line chart"
|
||||
|
||||
lines = []
|
||||
lines.append("ASCII Line Chart")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Extract values for plotting
|
||||
values = _extract_numeric_values(data)
|
||||
|
||||
if not values:
|
||||
return "No numeric data found for line chart"
|
||||
|
||||
# Generate sparkline
|
||||
sparkline_data = _create_sparkline(values)
|
||||
lines.extend(sparkline_data)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_numeric_values(data: List[Any]) -> List[float]:
|
||||
"""Extract numeric values from data for line chart."""
|
||||
values = []
|
||||
for row in data[:20]: # Limit points
|
||||
if isinstance(row, dict):
|
||||
for _key, val in row.items():
|
||||
if isinstance(val, (int, float)):
|
||||
values.append(val)
|
||||
break
|
||||
return values
|
||||
|
||||
|
||||
def _create_sparkline(values: List[float]) -> List[str]:
|
||||
"""Create sparkline visualization from values."""
|
||||
if len(values) <= 1:
|
||||
return []
|
||||
|
||||
max_val = max(values)
|
||||
min_val = min(values)
|
||||
range_val = max_val - min_val if max_val != min_val else 1
|
||||
|
||||
sparkline = ""
|
||||
for val in values:
|
||||
normalized = (val - min_val) / range_val
|
||||
if normalized < 0.2:
|
||||
sparkline += "▁"
|
||||
elif normalized < 0.4:
|
||||
sparkline += "▂"
|
||||
elif normalized < 0.6:
|
||||
sparkline += "▄"
|
||||
elif normalized < 0.8:
|
||||
sparkline += "▆"
|
||||
else:
|
||||
sparkline += "█"
|
||||
|
||||
# Safe formatting to avoid NaN display
|
||||
if _is_nan_value(min_val) or _is_nan_value(max_val):
|
||||
return ["Range: Unable to calculate from data", sparkline]
|
||||
else:
|
||||
return [f"Range: {min_val:.2f} to {max_val:.2f}", sparkline]
|
||||
|
||||
|
||||
def _is_nan_value(value: Any) -> bool:
|
||||
"""Check if a value is NaN or invalid."""
|
||||
try:
|
||||
import math
|
||||
|
||||
return math.isnan(float(value))
|
||||
except (ValueError, TypeError):
|
||||
return True
|
||||
|
||||
|
||||
def _generate_ascii_scatter_chart(data: List[Any], width: int, height: int) -> str:
|
||||
"""Generate ASCII scatter plot."""
|
||||
if not data:
|
||||
return "No data for scatter chart"
|
||||
|
||||
lines = []
|
||||
lines.append("ASCII Scatter Plot")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Extract data points
|
||||
x_values, y_values, x_column, y_column = _extract_scatter_data(data)
|
||||
|
||||
# Debug info
|
||||
lines.extend(_create_debug_info(x_values, y_values, x_column, y_column, data))
|
||||
|
||||
# Check if we have enough data
|
||||
if len(x_values) < 2:
|
||||
return _generate_ascii_table(data, width)
|
||||
|
||||
# Add axis info
|
||||
lines.extend(_create_axis_info(x_values, y_values, x_column, y_column))
|
||||
|
||||
# Create and render grid
|
||||
grid = _create_scatter_grid(x_values, y_values, width, height)
|
||||
lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height))
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _extract_scatter_data(
|
||||
data: List[Any],
|
||||
) -> tuple[List[float], List[float], str | None, str | None]:
|
||||
"""Extract X,Y data from scatter chart data."""
|
||||
x_values = []
|
||||
y_values = []
|
||||
x_column = None
|
||||
y_column = None
|
||||
numeric_columns = []
|
||||
|
||||
if data and isinstance(data[0], dict):
|
||||
# Find the first two numeric columns
|
||||
for key, val in data[0].items():
|
||||
if isinstance(val, (int, float)) and not (
|
||||
isinstance(val, float) and (val != val)
|
||||
): # Exclude NaN
|
||||
numeric_columns.append(key)
|
||||
|
||||
if len(numeric_columns) >= 2:
|
||||
x_column = numeric_columns[0]
|
||||
y_column = numeric_columns[1]
|
||||
|
||||
# Extract X,Y pairs
|
||||
for row in data[:50]: # Limit for ASCII display
|
||||
if isinstance(row, dict):
|
||||
x_val = row.get(x_column)
|
||||
y_val = row.get(y_column)
|
||||
# Check for valid numbers (not NaN)
|
||||
if (
|
||||
isinstance(x_val, (int, float))
|
||||
and isinstance(y_val, (int, float))
|
||||
and not (
|
||||
isinstance(x_val, float) and (x_val != x_val)
|
||||
) # Not NaN
|
||||
and not (isinstance(y_val, float) and (y_val != y_val))
|
||||
): # Not NaN
|
||||
x_values.append(x_val)
|
||||
y_values.append(y_val)
|
||||
|
||||
return x_values, y_values, x_column, y_column
|
||||
|
||||
|
||||
def _create_debug_info(
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
x_column: str | None,
|
||||
y_column: str | None,
|
||||
data: List[Any],
|
||||
) -> List[str]:
|
||||
"""Create debug information lines for scatter chart."""
|
||||
numeric_columns = []
|
||||
if data and isinstance(data[0], dict):
|
||||
for key, val in data[0].items():
|
||||
if isinstance(val, (int, float)) and not (
|
||||
isinstance(val, float) and (val != val)
|
||||
):
|
||||
numeric_columns.append(key)
|
||||
|
||||
return [
|
||||
f"DEBUG: Found {len(numeric_columns)} numeric columns: {numeric_columns}",
|
||||
f"DEBUG: X column: {x_column}, Y column: {y_column}",
|
||||
f"DEBUG: Valid X,Y pairs: {len(x_values)}",
|
||||
]
|
||||
|
||||
|
||||
def _create_axis_info(
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
x_column: str | None,
|
||||
y_column: str | None,
|
||||
) -> List[str]:
|
||||
"""Create axis information lines."""
|
||||
return [
|
||||
f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})",
|
||||
f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})",
|
||||
f"Showing {len(x_values)} data points",
|
||||
"",
|
||||
]
|
||||
|
||||
|
||||
def _create_scatter_grid(
|
||||
x_values: List[float], y_values: List[float], width: int, height: int
|
||||
) -> List[List[str]]:
|
||||
"""Create and populate the scatter plot grid."""
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
# Normalize values to fit in grid
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
x_range = x_max - x_min if x_max != x_min else 1
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Create grid
|
||||
grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)]
|
||||
|
||||
# Plot points
|
||||
for x, y in zip(x_values, y_values, strict=False):
|
||||
try:
|
||||
grid_x = int(((x - x_min) / x_range) * (plot_width - 1))
|
||||
grid_y = int(((y - y_min) / y_range) * (plot_height - 1))
|
||||
grid_y = plot_height - 1 - grid_y # Flip Y axis for display
|
||||
|
||||
if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height:
|
||||
if grid[grid_y][grid_x] == " ":
|
||||
grid[grid_y][grid_x] = "•"
|
||||
else:
|
||||
grid[grid_y][grid_x] = "█" # Multiple points
|
||||
except (ValueError, OverflowError):
|
||||
# Skip points that can't be converted to integers (NaN, inf, etc.)
|
||||
continue
|
||||
|
||||
return grid
|
||||
|
||||
|
||||
def _render_scatter_grid(
|
||||
grid: List[List[str]],
|
||||
x_values: List[float],
|
||||
y_values: List[float],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> List[str]:
|
||||
"""Render the scatter plot grid with axes and labels."""
|
||||
lines = []
|
||||
plot_width = min(40, width - 10)
|
||||
plot_height = min(15, height - 8)
|
||||
|
||||
x_min, x_max = min(x_values), max(x_values)
|
||||
y_min, y_max = min(y_values), max(y_values)
|
||||
y_range = y_max - y_min if y_max != y_min else 1
|
||||
|
||||
# Add Y-axis labels and plot
|
||||
for i, row in enumerate(grid):
|
||||
y_val = y_max - (i / (plot_height - 1)) * y_range if plot_height > 1 else y_max
|
||||
y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
|
||||
lines.append(f"{y_label:>6} |{''.join(row)}")
|
||||
|
||||
# Add X-axis
|
||||
x_axis_line = " " * 7 + "+" + "-" * plot_width
|
||||
lines.append(x_axis_line)
|
||||
|
||||
# Add X-axis labels
|
||||
x_left_label = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
|
||||
x_right_label = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
|
||||
x_labels = (
|
||||
" " * 8
|
||||
+ x_left_label
|
||||
+ " " * (plot_width - len(x_left_label) - len(x_right_label))
|
||||
+ x_right_label
|
||||
)
|
||||
lines.append(x_labels)
|
||||
|
||||
return lines
|
||||
|
||||
|
||||
def _generate_ascii_table(data: List[Any], width: int) -> str:
|
||||
"""Generate ASCII table from data."""
|
||||
if not data:
|
||||
return "No data for table"
|
||||
|
||||
lines = []
|
||||
lines.append("Data Table")
|
||||
lines.append("=" * min(width, 50))
|
||||
|
||||
# Get column headers from first row
|
||||
if isinstance(data[0], dict):
|
||||
headers = list(data[0].keys())[:5] # Limit columns
|
||||
|
||||
# Header row
|
||||
header_row = " | ".join(f"{h[:10]:>10}" for h in headers)
|
||||
lines.append(header_row)
|
||||
lines.append("-" * len(header_row))
|
||||
|
||||
# Data rows (limit to first 10)
|
||||
for row in data[:10]:
|
||||
values = []
|
||||
for header in headers:
|
||||
val = row.get(header, "")
|
||||
if isinstance(val, float):
|
||||
values.append(f"{val:.2f}")
|
||||
else:
|
||||
values.append(str(val)[:10])
|
||||
|
||||
data_row = " | ".join(f"{v:>10}" for v in values)
|
||||
lines.append(data_row)
|
||||
|
||||
return "\n".join(lines)
|
||||
|
||||
|
||||
def _get_chart_preview_internal(
    request: GetChartPreviewRequest,
) -> ChartPreview | ChartError:
    """
    Get a visual preview of a chart with URLs for LLM embedding.

    This tool generates or retrieves URLs for chart images that can be
    displayed directly in LLM clients. The URLs point to Superset's
    screenshot endpoints for proper image serving.

    Supports lookup by:
    - Numeric ID (e.g., 123)
    - UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")

    Returns a ChartPreview with Superset URLs for the chart image or
    ChartError on error.
    """
    try:
        from superset.daos.chart import ChartDAO

        # Find the chart: numeric identifiers (int or digit-only string)
        # resolve by primary key, anything else is treated as a UUID.
        chart = None
        if isinstance(request.identifier, int) or (
            isinstance(request.identifier, str) and request.identifier.isdigit()
        ):
            chart_id = (
                int(request.identifier)
                if isinstance(request.identifier, str)
                else request.identifier
            )
            chart = ChartDAO.find_by_id(chart_id)
        else:
            # Try UUID lookup using DAO flexible method
            chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")

        if not chart:
            return ChartError(
                error=f"No chart found with identifier: {request.identifier}",
                error_type="NotFound",
            )

        logger.info(
            f"Generating preview for chart {chart.id} in {request.format} format: "
            f"{chart.slice_name}"
        )

        import time

        # Timed from here, so the reported duration covers only preview
        # generation, not the chart lookup above.
        start_time = time.time()

        # Handle different preview formats using strategy pattern
        preview_generator = PreviewFormatGenerator(chart, request)
        content = preview_generator.generate()

        # The generator signals failure by returning a ChartError directly.
        if isinstance(content, ChartError):
            return content

        # Create performance and accessibility metadata
        execution_time = int((time.time() - start_time) * 1000)  # milliseconds
        performance = PerformanceMetadata(
            query_duration_ms=execution_time,
            cache_status="miss",
            optimization_suggestions=[],
        )

        accessibility = AccessibilityMetadata(
            color_blind_safe=True,
            alt_text=f"Preview of {chart.slice_name or f'Chart {chart.id}'}",
            high_contrast_available=False,
        )

        # Create backward-compatible response with enhanced metadata
        result = ChartPreview(
            chart_id=chart.id,
            chart_name=chart.slice_name or f"Chart {chart.id}",
            chart_type=chart.viz_type or "unknown",
            explore_url=f"{get_superset_base_url()}/explore/?slice_id={chart.id}",
            content=content,
            chart_description=(
                f"Preview of {chart.viz_type or 'chart'}: "
                f"{chart.slice_name or f'Chart {chart.id}'}"
            ),
            accessibility=accessibility,
            performance=performance,
        )

        # Add format-specific fields for backward compatibility: older
        # clients read flat attributes instead of the `content` object.
        if isinstance(content, URLPreview):
            result.format = "url"
            result.preview_url = content.preview_url
            result.width = content.width
            result.height = content.height
        elif isinstance(content, ASCIIPreview):
            result.format = "ascii"
            result.ascii_chart = content.ascii_content
            result.width = content.width
            result.height = content.height
        elif isinstance(content, TablePreview):
            result.format = "table"
            result.table_data = content.table_data
        # Base64 preview support removed

        return result

    except Exception as e:
        # Catch-all boundary: tool calls must return a structured error,
        # never raise to the MCP client.
        logger.error(f"Error in get_chart_preview: {e}")
        return ChartError(
            error=f"Failed to get chart preview: {str(e)}", error_type="InternalError"
        )
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_chart_preview(request: GetChartPreviewRequest) -> ChartPreview | ChartError:
    """
    Get a visual preview of a chart with URLs for LLM embedding.

    This tool generates or retrieves URLs for chart images that can be
    displayed directly in LLM clients. The URLs point to Superset's
    screenshot endpoints for proper image serving.

    Supports lookup by:
    - Numeric ID (e.g., 123)
    - UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")

    Returns a ChartPreview with Superset URLs for the chart image or
    ChartError on error.
    """
    # Thin wrapper: the implementation lives in _get_chart_preview_internal
    # so other tools can reuse it without going through the MCP tool
    # registration and auth decorators a second time.
    return _get_chart_preview_internal(request)
|
||||
105
superset/mcp_service/chart/tool/list_charts.py
Normal file
105
superset/mcp_service/chart/tool/list_charts.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: list_charts (advanced filtering with metadata cache control)
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelListTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import ChartInfo, ChartList
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
ChartFilter,
|
||||
ListChartsRequest,
|
||||
serialize_chart_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Columns returned by list_charts when the caller does not request a
# specific column selection (order is preserved in the response).
DEFAULT_CHART_COLUMNS = [
    "id",
    "slice_name",
    "viz_type",
    "uuid",
    "datasource_name",
    "description",
    "changed_by_name",
    "created_by_name",
    "changed_on",
    "created_on",
]
|
||||
|
||||
# Columns accepted as `order_column` for chart listings; mirrored in the
# list_charts docstring so LLM clients see the same list.
SORTABLE_CHART_COLUMNS = [
    "id",
    "slice_name",
    "viz_type",
    "datasource_name",
    "description",
    "changed_on",
    "created_on",
]
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def list_charts(request: ListChartsRequest) -> ChartList:
    """
    List charts with advanced filtering, search, and metadata cache control.

    Uses a clear request object schema to avoid validation ambiguity with
    arrays/strings. All parameters are properly typed and have sensible defaults.

    Search columns: slice_name, description
    Sortable columns for order_column: id, slice_name, viz_type, datasource_name,
    description, changed_on, created_on

    Metadata Cache Control:
    - use_cache: Whether to use metadata cache for faster responses
    - refresh_metadata: Force refresh of metadata cache for fresh data

    When refresh_metadata=True, the tool will fetch fresh metadata from the database
    which is useful when database schema has changed.
    """

    from superset.daos.chart import ChartDAO

    # Generic list helper: wires the chart DAO to the MCP output schemas and
    # handles filtering, search, ordering and pagination uniformly.
    tool = ModelListTool(
        dao_class=ChartDAO,
        output_schema=ChartInfo,
        item_serializer=lambda obj, cols: serialize_chart_object(obj) if obj else None,  # type: ignore[arg-type]
        filter_type=ChartFilter,
        default_columns=DEFAULT_CHART_COLUMNS,
        search_columns=[
            "slice_name",
            "description",
        ],
        list_field_name="charts",
        output_list_schema=ChartList,
        logger=logger,
    )
    return tool.run(
        filters=request.filters,
        search=request.search,
        select_columns=request.select_columns,
        order_column=request.order_column,
        order_direction=request.order_direction,
        # The request API is 1-based; ModelListTool pages are 0-based.
        page=max(request.page - 1, 0),
        page_size=request.page_size,
    )
|
||||
217
superset/mcp_service/chart/tool/update_chart.py
Normal file
217
superset/mcp_service/chart/tool/update_chart.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: update_chart
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
generate_chart_name,
|
||||
map_config_to_form_data,
|
||||
)
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AccessibilityMetadata,
|
||||
PerformanceMetadata,
|
||||
UpdateChartRequest,
|
||||
)
|
||||
from superset.mcp_service.url_utils import (
|
||||
get_chart_screenshot_url,
|
||||
get_superset_base_url,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def update_chart(request: UpdateChartRequest) -> Dict[str, Any]:
    """
    Update an existing saved chart with new configuration.

    This tool modifies a permanently saved chart in Superset with new settings,
    filters, or visualization parameters. The chart must already exist (created
    via generate_chart with save_chart=True).

    Use this tool when:
    - User wants to modify an existing saved chart
    - Updating chart title, filters, or visualization settings
    - Changing chart type or data columns for a saved chart

    Args:
        request: Chart update request with chart identifier, new config,
            and optional preview generation

    Returns:
        Response with updated chart info, preview images, and explore URL
    """
    start_time = time.time()

    try:
        # Find the existing chart
        from superset.daos.chart import ChartDAO

        # Numeric identifiers (int or digit-only string) resolve by primary
        # key; anything else is treated as a UUID.
        chart = None
        if isinstance(request.identifier, int) or (
            isinstance(request.identifier, str) and request.identifier.isdigit()
        ):
            chart_id = (
                int(request.identifier)
                if isinstance(request.identifier, str)
                else request.identifier
            )
            chart = ChartDAO.find_by_id(chart_id)
        else:
            # Try UUID lookup using DAO flexible method
            chart = ChartDAO.find_by_id(request.identifier, id_column="uuid")

        if not chart:
            return {
                "chart": None,
                "error": f"No chart found with identifier: {request.identifier}",
                "success": False,
                "schema_version": "2.0",
                "api_version": "v1",
            }

        # Map the new config to form_data format
        new_form_data = map_config_to_form_data(request.config)

        # Update chart using Superset's command
        from superset.commands.chart.update import UpdateChartCommand

        # Generate new chart name if provided, otherwise keep existing
        chart_name = (
            request.chart_name
            if request.chart_name
            else chart.slice_name or generate_chart_name(request.config)
        )

        update_payload = {
            "slice_name": chart_name,
            "viz_type": new_form_data["viz_type"],
            "params": json.dumps(new_form_data),
        }

        command = UpdateChartCommand(chart.id, update_payload)
        updated_chart = command.run()

        # Generate semantic analysis
        capabilities = analyze_chart_capabilities(updated_chart, request.config)
        semantics = analyze_chart_semantics(updated_chart, request.config)

        # Create performance metadata
        execution_time = int((time.time() - start_time) * 1000)  # milliseconds
        performance = PerformanceMetadata(
            query_duration_ms=execution_time,
            cache_status="miss",
            optimization_suggestions=[],
        )

        # Create accessibility metadata
        # NOTE(review): chart_name is deliberately reassigned here from the
        # post-update chart; the pre-update value above was only needed for
        # the command payload — confirm this is intentional.
        chart_name = (
            updated_chart.slice_name
            if updated_chart and hasattr(updated_chart, "slice_name")
            else generate_chart_name(request.config)
        )
        accessibility = AccessibilityMetadata(
            color_blind_safe=True,  # Would need actual analysis
            alt_text=f"Updated chart showing {chart_name}",
            high_contrast_available=False,
        )

        # Generate previews if requested
        previews = {}
        if request.generate_preview:
            try:
                from superset.mcp_service.chart.tool.get_chart_preview import (
                    _get_chart_preview_internal,
                    GetChartPreviewRequest,
                )

                for format_type in request.preview_formats:
                    preview_request = GetChartPreviewRequest(
                        identifier=str(updated_chart.id), format=format_type
                    )
                    preview_result = _get_chart_preview_internal(preview_request)

                    # ChartError results lack `content` and are skipped.
                    if hasattr(preview_result, "content"):
                        previews[format_type] = preview_result.content

            except Exception as e:
                # Log warning but don't fail the entire request
                logger.warning(f"Preview generation failed: {e}")

        # Return enhanced data
        result = {
            "chart": {
                "id": updated_chart.id,
                "slice_name": updated_chart.slice_name,
                "viz_type": updated_chart.viz_type,
                "url": (
                    f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}"
                ),
                "uuid": str(updated_chart.uuid) if updated_chart.uuid else None,
                "updated": True,
            },
            "error": None,
            # Enhanced fields for better LLM integration
            "previews": previews,
            "capabilities": capabilities.model_dump() if capabilities else None,
            "semantics": semantics.model_dump() if semantics else None,
            "explore_url": (
                f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}"
            ),
            "api_endpoints": {
                "data": (
                    f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/data/"
                ),
                "preview": get_chart_screenshot_url(updated_chart.id),
                "export": (
                    f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/export/"
                ),
            },
            "performance": performance.model_dump() if performance else None,
            "accessibility": accessibility.model_dump() if accessibility else None,
            "success": True,
            "schema_version": "2.0",
            "api_version": "v1",
        }
        return result

    except Exception as e:
        # Structured failure response; MCP tools must not raise to the client.
        execution_time = int((time.time() - start_time) * 1000)
        return {
            "chart": None,
            "error": f"Chart update failed: {str(e)}",
            "performance": {
                "query_duration_ms": execution_time,
                "cache_status": "error",
                "optimization_suggestions": [],
            },
            "success": False,
            "schema_version": "2.0",
            "api_version": "v1",
        }
|
||||
170
superset/mcp_service/chart/tool/update_chart_preview.py
Normal file
170
superset/mcp_service/chart/tool/update_chart_preview.py
Normal file
@@ -0,0 +1,170 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: update_chart_preview
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
analyze_chart_capabilities,
|
||||
analyze_chart_semantics,
|
||||
generate_chart_name,
|
||||
generate_explore_link,
|
||||
map_config_to_form_data,
|
||||
)
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AccessibilityMetadata,
|
||||
PerformanceMetadata,
|
||||
UpdateChartPreviewRequest,
|
||||
URLPreview,
|
||||
)
|
||||
from superset.mcp_service.url_utils import get_mcp_service_url
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def update_chart_preview(request: UpdateChartPreviewRequest) -> Dict[str, Any]:
    """
    Update a cached chart preview with new configuration without saving.

    This tool modifies the cached form_data for a chart preview (created via
    generate_chart with save_chart=False) and returns a new preview with updated
    configuration. The original form_data_key is invalidated and a new one is
    created.

    Use this tool when:
    - User wants to modify a chart preview before deciding to save
    - Iterating on chart design without creating permanent charts
    - Testing different visualization configurations
    - Updating filters, chart type, or data columns for previews

    Args:
        request: Chart preview update request with form_data_key, dataset_id,
            new config, and optional preview generation

    Returns:
        Response with new form_data_key, preview images, and explore URL
    """
    start_time = time.time()

    try:
        # Map the new config to form_data format
        new_form_data = map_config_to_form_data(request.config)

        # Generate new explore link with updated form_data
        explore_url = generate_explore_link(request.dataset_id, new_form_data)

        # Extract new form_data_key from the explore URL query string;
        # stays None when the URL carries no key.
        new_form_data_key = None
        if "form_data_key=" in explore_url:
            new_form_data_key = explore_url.split("form_data_key=")[1].split("&")[0]

        # Generate semantic analysis (no saved chart exists, hence None)
        capabilities = analyze_chart_capabilities(None, request.config)
        semantics = analyze_chart_semantics(None, request.config)

        # Create performance metadata
        execution_time = int((time.time() - start_time) * 1000)  # milliseconds
        performance = PerformanceMetadata(
            query_duration_ms=execution_time,
            cache_status="miss",
            optimization_suggestions=[],
        )

        # Create accessibility metadata
        chart_name = generate_chart_name(request.config)
        accessibility = AccessibilityMetadata(
            color_blind_safe=True,  # Would need actual analysis
            alt_text=f"Updated chart preview showing {chart_name}",
            high_contrast_available=False,
        )

        # Generate previews if requested; a form_data key is required to
        # build the screenshot URL.
        previews = {}
        if request.generate_preview and new_form_data_key:
            try:
                for format_type in request.preview_formats:
                    if format_type == "url":
                        # Generate screenshot URL using new form_data key
                        mcp_base = get_mcp_service_url()
                        preview_url = (
                            f"{mcp_base}/screenshot/explore/{new_form_data_key}.png"
                        )

                        previews[format_type] = URLPreview(
                            preview_url=preview_url,
                            width=800,
                            height=600,
                            supports_interaction=False,
                        )
                    # Other formats would need form_data execution
                    # which is more complex for preview-only mode

            except Exception as e:
                # Log warning but don't fail the entire request
                logger.warning(f"Preview generation failed: {e}")

        # Return enhanced data
        result = {
            "chart": {
                "id": None,
                "slice_name": chart_name,
                "viz_type": new_form_data.get("viz_type"),
                "url": explore_url,
                "uuid": None,
                "saved": False,
                "updated": True,
            },
            "error": None,
            # Enhanced fields for better LLM integration
            "previews": previews,
            "capabilities": capabilities.model_dump() if capabilities else None,
            "semantics": semantics.model_dump() if semantics else None,
            "explore_url": explore_url,
            "form_data_key": new_form_data_key,
            "previous_form_data_key": request.form_data_key,  # For reference
            "api_endpoints": {},  # No API endpoints for unsaved charts
            "performance": performance.model_dump() if performance else None,
            "accessibility": accessibility.model_dump() if accessibility else None,
            "success": True,
            "schema_version": "2.0",
            "api_version": "v1",
        }
        return result

    except Exception as e:
        # Structured failure response; MCP tools must not raise to the client.
        execution_time = int((time.time() - start_time) * 1000)
        return {
            "chart": None,
            "error": f"Chart preview update failed: {str(e)}",
            "performance": {
                "query_duration_ms": execution_time,
                "cache_status": "error",
                "optimization_suggestions": [],
            },
            "success": False,
            "schema_version": "2.0",
            "api_version": "v1",
        }
|
||||
463
superset/mcp_service/chart/validation_utils.py
Normal file
463
superset/mcp_service/chart/validation_utils.py
Normal file
@@ -0,0 +1,463 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Chart generation validation utilities with enhanced error context
|
||||
"""
|
||||
|
||||
import difflib
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
from superset.mcp_service.schemas.error_schemas import (
|
||||
ChartGenerationError,
|
||||
ColumnSuggestion,
|
||||
DatasetContext,
|
||||
ValidationError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def validate_chart_config(
    config: Union[TableChartConfig, XYChartConfig], dataset_id: Union[int, str]
) -> Tuple[bool, Optional[ChartGenerationError]]:
    """
    Validate chart configuration against dataset schema with enhanced error reporting.

    Args:
        config: Chart configuration to validate
        dataset_id: Dataset ID to validate against

    Returns:
        Tuple of (is_valid, error_details)
    """
    try:
        # Get dataset context (None means the dataset could not be resolved)
        dataset_context = get_dataset_context(dataset_id)
        if not dataset_context:
            return False, ChartGenerationError(
                error_type="dataset_not_found",
                message=f"Dataset not found: {dataset_id}",
                details=(
                    f"No dataset found with identifier '{dataset_id}'. "
                    f"Please verify the dataset ID or UUID is correct."
                ),
                suggestions=[
                    "Check that the dataset ID is correct",
                    "Verify you have access to this dataset",
                    "Use the list_datasets tool to find available datasets",
                ],
                error_code="DATASET_NOT_FOUND",
            )

        # Validate columns — each config type declares a different layout,
        # so the column checks are dispatched per config class.
        validation_errors = []

        if isinstance(config, TableChartConfig):
            validation_errors.extend(
                validate_table_columns(config.columns, dataset_context)
            )
            if config.filters:
                validation_errors.extend(
                    validate_filter_columns(config.filters, dataset_context)
                )
        elif isinstance(config, XYChartConfig):
            # Validate X-axis column
            validation_errors.extend(
                validate_single_column(config.x, dataset_context, "x_axis")
            )
            # Validate Y-axis columns
            for i, y_col in enumerate(config.y):
                validation_errors.extend(
                    validate_single_column(y_col, dataset_context, f"y_axis[{i}]")
                )
            # Validate group by column if specified
            if config.group_by:
                validation_errors.extend(
                    validate_single_column(config.group_by, dataset_context, "group_by")
                )
            # Validate filter columns if specified
            if config.filters:
                validation_errors.extend(
                    validate_filter_columns(config.filters, dataset_context)
                )

        # If there are validation errors, return detailed error (all field
        # problems are collected and reported together, not one at a time).
        if validation_errors:
            return False, ChartGenerationError(
                error_type="validation_error",
                message="Chart configuration validation failed",
                details=(
                    "One or more fields in the chart configuration are invalid. "
                    "See validation_errors for specific issues."
                ),
                validation_errors=validation_errors,
                dataset_context=dataset_context,
                suggestions=generate_overall_suggestions(
                    validation_errors, dataset_context
                ),
                error_code="VALIDATION_FAILED",
            )

        return True, None

    except Exception as e:
        # System-level failure (unexpected DAO/schema errors): reported as a
        # distinct error type instead of raising to the caller.
        logger.exception(f"Validation error for dataset {dataset_id}")
        return False, ChartGenerationError(
            error_type="validation_system_error",
            message="Chart validation system error",
            details=f"An internal error occurred during validation: {str(e)}",
            suggestions=[
                "Try again with a simpler configuration",
                "Check that all required fields are provided",
                "Contact support if the issue persists",
            ],
            error_code="VALIDATION_SYSTEM_ERROR",
        )
|
||||
|
||||
|
||||
def get_dataset_context(dataset_id: Union[int, str]) -> Optional[DatasetContext]:
    """Get dataset context information for validation.

    Resolves the dataset by numeric ID or UUID and snapshots its column and
    metric metadata into a DatasetContext. Returns None when the dataset
    cannot be found or any lookup error occurs (errors are logged).
    """
    try:
        from superset.daos.dataset import DatasetDAO

        # Find dataset: digit-only strings are treated as numeric IDs,
        # everything else as a UUID.
        dataset = None
        if isinstance(dataset_id, int) or (
            isinstance(dataset_id, str) and dataset_id.isdigit()
        ):
            numeric_id = int(dataset_id) if isinstance(dataset_id, str) else dataset_id
            dataset = DatasetDAO.find_by_id(numeric_id)
        else:
            # Try UUID lookup
            dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")

        if not dataset:
            return None

        # Get columns with metadata. getattr defaults guard against model
        # variants that lack the optional attributes.
        available_columns = []
        if hasattr(dataset, "columns") and dataset.columns:
            for col in dataset.columns:
                available_columns.append(
                    {
                        "name": col.column_name,
                        "type": col.type or "unknown",
                        "description": col.description,
                        "is_dttm": getattr(col, "is_dttm", False),
                        "python_date_format": getattr(col, "python_date_format", None),
                        "verbose_name": getattr(col, "verbose_name", None),
                    }
                )

        # Get metrics
        available_metrics = []
        if hasattr(dataset, "metrics") and dataset.metrics:
            for metric in dataset.metrics:
                available_metrics.append(
                    {
                        "name": metric.metric_name,
                        "expression": metric.expression,
                        "description": metric.description,
                        "verbose_name": getattr(metric, "verbose_name", None),
                        "metric_type": getattr(metric, "metric_type", None),
                    }
                )

        return DatasetContext(
            id=dataset.id,
            table_name=dataset.table_name,
            schema=getattr(dataset, "schema", None),
            database_name=dataset.database.database_name
            if dataset.database
            else "unknown",
            available_columns=available_columns,
            available_metrics=available_metrics,
        )

    except Exception:
        # Best-effort helper: callers treat None as "dataset unavailable".
        logger.exception(f"Error getting dataset context for {dataset_id}")
        return None
|
||||
|
||||
|
||||
def validate_table_columns(
    columns: List[ColumnRef], dataset_context: DatasetContext
) -> List[ValidationError]:
    """Validate every column of a table chart against the dataset schema."""
    collected: List[ValidationError] = []
    for index, column in enumerate(columns):
        collected += validate_single_column(
            column, dataset_context, f"columns[{index}]"
        )
    return collected
|
||||
|
||||
|
||||
def validate_single_column(
    column: ColumnRef, dataset_context: DatasetContext, field_name: str
) -> List[ValidationError]:
    """Validate one column reference against the dataset's columns and metrics.

    A reference is valid when its name matches either a physical column or a
    saved metric.  Known physical columns that carry an aggregate are
    additionally checked for aggregate/type compatibility.
    """
    column_names = [col["name"] for col in dataset_context.available_columns]
    metric_names = [
        metric["name"] for metric in dataset_context.available_metrics
    ]

    if column.name not in column_names + metric_names:
        # Unknown name: return a single error carrying fuzzy-match suggestions.
        return [
            ValidationError(
                field=field_name,
                provided_value=column.name,
                error_type="column_not_found",
                message=(
                    f"Column '{column.name}' not found in dataset "
                    f"'{dataset_context.table_name}'"
                ),
                suggestions=get_column_suggestions(column.name, dataset_context),
            )
        ]

    errors: List[ValidationError] = []
    if column.aggregate:
        # The compatibility check only applies to physical columns; metrics
        # have no type metadata here, so a metric match yields no check.
        matching = [
            col
            for col in dataset_context.available_columns
            if col["name"] == column.name
        ]
        if matching:
            errors.extend(
                validate_aggregate_compatibility(column, matching[0], field_name)
            )
    return errors
|
||||
|
||||
|
||||
def validate_filter_columns(
    filters: List[FilterConfig], dataset_context: DatasetContext
) -> List[ValidationError]:
    """Check that every filter refers to a real column of the dataset."""
    known_columns = {col["name"] for col in dataset_context.available_columns}

    errors: List[ValidationError] = []
    for index, filter_config in enumerate(filters):
        if filter_config.column in known_columns:
            continue
        errors.append(
            ValidationError(
                field=f"filters[{index}].column",
                provided_value=filter_config.column,
                error_type="filter_column_not_found",
                message=(
                    f"Filter column '{filter_config.column}' not found in dataset"
                ),
                suggestions=get_column_suggestions(
                    filter_config.column, dataset_context
                ),
            )
        )
    return errors
|
||||
|
||||
|
||||
def validate_aggregate_compatibility(
    column: ColumnRef, column_info: Dict[str, Any], field_name: str
) -> List[ValidationError]:
    """Check that the requested aggregate makes sense for the column's type.

    Text-like columns reject numeric-only aggregates entirely; datetime
    columns accept only MIN/MAX from the numeric set.  COUNT and
    COUNT_DISTINCT are always accepted.
    """
    numeric_aggregates = {"SUM", "AVG", "MIN", "MAX", "STDDEV", "VAR"}

    column_type = column_info.get("type", "").lower()
    aggregate = column.aggregate.upper() if column.aggregate else "COUNT"

    # Substring-based type detection over common SQL type names.
    numeric_indicators = (
        "int",
        "float",
        "decimal",
        "number",
        "numeric",
        "double",
        "precision",
        "bigint",
        "smallint",
        "real",
        "money",
        "serial",
        "auto_increment",
    )
    datetime_indicators = ("date", "time", "timestamp", "datetime", "interval")

    is_numeric = any(token in column_type for token in numeric_indicators)
    is_datetime = any(token in column_type for token in datetime_indicators)

    # Non-empty hints mean the combination is invalid.
    hints: List[str] = []
    if not is_numeric and not is_datetime and aggregate in numeric_aggregates:
        # Text-like column with a numeric-only aggregate.
        hints = [
            "Use COUNT or COUNT_DISTINCT for text columns",
            "SUM/AVG only work with numeric data types",
        ]
    elif is_datetime and aggregate in numeric_aggregates - {"MIN", "MAX"}:
        hints = [
            "Use MIN, MAX, COUNT, or COUNT_DISTINCT for datetime columns",
            "SUM/AVG are not meaningful for dates and times",
        ]

    if not hints:
        return []

    return [
        ValidationError(
            field=field_name,
            provided_value=aggregate,
            error_type="aggregate_type_mismatch",
            message=(
                f"Aggregate '{aggregate}' is not compatible with column type "
                f"'{column_type}'"
            ),
            suggestions=[
                ColumnSuggestion(
                    name=hint,
                    type="aggregate_function",
                    similarity_score=1.0,
                    description=f"Compatible aggregate for {column_type} columns",
                )
                for hint in hints
            ],
        )
    ]
|
||||
|
||||
|
||||
def get_column_suggestions(
    invalid_column: str, dataset_context: DatasetContext, max_suggestions: int = 5
) -> List[ColumnSuggestion]:
    """Suggest close column/metric names for a misspelled reference.

    Runs difflib fuzzy matching over every column and metric name and
    returns suggestions ordered by descending similarity.
    """
    # Candidate pool: physical columns first, then saved metrics.
    column_items = [
        {
            "name": col["name"],
            "type": col["type"],
            "description": col.get("description")
            or f"Column of type {col['type']}",
            "category": "column",
        }
        for col in dataset_context.available_columns
    ]
    metric_items = [
        {
            "name": metric["name"],
            "type": "metric",
            "description": metric.get("description")
            or f"Calculated metric: {metric.get('expression', 'N/A')}",
            "category": "metric",
        }
        for metric in dataset_context.available_metrics
    ]
    all_items = column_items + metric_items

    # Low cutoff (0.3) keeps the matcher forgiving of typos.
    matches = difflib.get_close_matches(
        invalid_column,
        [item["name"] for item in all_items],
        n=max_suggestions,
        cutoff=0.3,
    )

    suggestions: List[ColumnSuggestion] = []
    for match in matches:
        # First item wins on duplicate names (columns precede metrics).
        matched_item = next(item for item in all_items if item["name"] == match)
        score = difflib.SequenceMatcher(
            None, invalid_column.lower(), match.lower()
        ).ratio()
        suggestions.append(
            ColumnSuggestion(
                name=match,
                type=matched_item["type"],
                similarity_score=score,
                description=matched_item["description"],
            )
        )

    # Best matches first.
    return sorted(suggestions, key=lambda s: s.similarity_score, reverse=True)
|
||||
|
||||
|
||||
def generate_overall_suggestions(
    validation_errors: List[ValidationError], dataset_context: DatasetContext
) -> List[str]:
    """Build high-level remediation hints from the collected validation errors."""
    suggestions: List[str] = []
    error_types = {error.error_type for error in validation_errors}

    if error_types & {"column_not_found", "filter_column_not_found"}:
        # Show at most 10 column names inline to keep the hint readable.
        preview = [col["name"] for col in dataset_context.available_columns[:10]]
        suggestions.append(
            f"Available columns in '{dataset_context.table_name}': "
            f"{', '.join(preview)}"
        )
        if len(dataset_context.available_columns) > 10:
            suggestions.append("Use get_dataset_info tool to see all available columns")

    if "aggregate_type_mismatch" in error_types:
        suggestions.append(
            "Check that aggregate functions match column data types "
            "(SUM/AVG for numbers, COUNT for text)"
        )

    if dataset_context.available_metrics:
        metric_preview = [
            metric["name"] for metric in dataset_context.available_metrics[:5]
        ]
        suggestions.append(f"Available metrics: {', '.join(metric_preview)}")

    # Always point at the filter-discovery tool as a final hint.
    suggestions.append(
        "Use get_dataset_available_filters tool to see valid filter operators "
        "for each column"
    )

    return suggestions
|
||||
33
superset/mcp_service/commands/create_form_data.py
Normal file
33
superset/mcp_service/commands/create_form_data.py
Normal file
@@ -0,0 +1,33 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP-specific form data command that extends the base CreateFormDataCommand
|
||||
"""
|
||||
|
||||
from superset.commands.explore.form_data.create import CreateFormDataCommand
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
|
||||
class MCPCreateFormDataCommand(CreateFormDataCommand):
    """Form-data creation command adapted for the MCP service.

    The base command derives its cache key from the Flask session, which does
    not exist in an MCP request context; this subclass keys the form data on
    the authenticated user's id instead.
    """

    def _get_session_id(self) -> str:
        """Return the current user's id (stringified) in place of a session id."""
        user_id = get_user_id()
        return str(user_id)
|
||||
82
superset/mcp_service/config.py
Normal file
82
superset/mcp_service/config.py
Normal file
@@ -0,0 +1,82 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""MCP Service Configuration."""
|
||||
|
||||
import logging
|
||||
from typing import Any, Optional
|
||||
|
||||
from flask import Flask
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]:
    """Build the default bearer-token auth provider from ``app.config``.

    Args:
        app: Flask application whose config supplies the ``MCP_*`` settings.

    Returns:
        A ``BearerAuthProvider``, or ``None`` when auth is disabled, no JWT
        key material is configured, or the provider cannot be constructed.
    """
    if not app.config.get("MCP_AUTH_ENABLED", False):
        return None

    jwks_uri = app.config.get("MCP_JWKS_URI")
    public_key = app.config.get("MCP_JWT_PUBLIC_KEY")

    # Auth was requested but there is no way to verify tokens: fall back to
    # "no auth" with a warning rather than failing app startup.
    if not (jwks_uri or public_key):
        logger.warning("MCP_AUTH_ENABLED is True but no JWT keys configured")
        return None

    try:
        # Imported lazily so the fastmcp dependency is only needed when auth
        # is actually enabled.
        from fastmcp.server.auth.providers.bearer import BearerAuthProvider

        return BearerAuthProvider(
            jwks_uri=jwks_uri,
            public_key=public_key,
            issuer=app.config.get("MCP_JWT_ISSUER"),
            audience=app.config.get("MCP_JWT_AUDIENCE"),
            algorithm=app.config.get("MCP_JWT_ALGORITHM", "RS256"),
            required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []),
        )
    except Exception:
        # logger.exception records the traceback; the previous eager
        # f-string logger.error call discarded it.
        logger.exception("Failed to create MCP auth provider")
        return None
|
||||
|
||||
|
||||
def default_user_resolver(access_token: Any) -> Optional[str]:
    """Resolve a Superset username from a decoded JWT access token.

    Preference order: the token's ``subject`` attribute, then ``client_id``,
    then the ``sub`` / ``email`` / ``username`` claims of a dict payload.
    Returns ``None`` when no identity can be extracted.
    """
    for attr in ("subject", "client_id"):
        if hasattr(access_token, attr):
            return getattr(access_token, attr)

    payload = getattr(access_token, "payload", None)
    if isinstance(payload, dict):
        return payload.get("sub") or payload.get("email") or payload.get("username")

    return None
|
||||
|
||||
|
||||
# Defaults for the MCP service's Flask config keys; individual keys can be
# overridden via the application's config.
DEFAULT_CONFIG = {
    # Master switch for JWT bearer-token auth on the MCP endpoint.
    "MCP_AUTH_ENABLED": False,
    # Callable(app) -> auth provider or None (see create_default_mcp_auth_factory).
    "MCP_AUTH_FACTORY": create_default_mcp_auth_factory,
    # Callable(access_token) -> username or None (see default_user_resolver).
    "MCP_USER_RESOLVER": default_user_resolver,
    # JWT verification material: a JWKS endpoint or a static public key.
    "MCP_JWKS_URI": None,
    "MCP_JWT_PUBLIC_KEY": None,
    # Optional issuer/audience claims enforced during token validation.
    "MCP_JWT_ISSUER": None,
    "MCP_JWT_AUDIENCE": None,
    "MCP_JWT_ALGORITHM": "RS256",
    # Scopes every token must carry; empty list means no scope requirement.
    "MCP_REQUIRED_SCOPES": [],
    # NOTE(review): presumably the identity used when auth is disabled —
    # confirm against the consumer of this key.
    "MCP_ADMIN_USERNAME": "admin",
}
|
||||
13
superset/mcp_service/dashboard/tool/__init__.py
Normal file
13
superset/mcp_service/dashboard/tool/__init__.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from .add_chart_to_existing_dashboard import add_chart_to_existing_dashboard
|
||||
from .generate_dashboard import generate_dashboard
|
||||
from .get_dashboard_available_filters import get_dashboard_available_filters
|
||||
from .get_dashboard_info import get_dashboard_info
|
||||
from .list_dashboards import list_dashboards
|
||||
|
||||
__all__ = [
|
||||
"list_dashboards",
|
||||
"get_dashboard_info",
|
||||
"get_dashboard_available_filters",
|
||||
"generate_dashboard",
|
||||
"add_chart_to_existing_dashboard",
|
||||
]
|
||||
@@ -0,0 +1,281 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: add_chart_to_existing_dashboard
|
||||
|
||||
This tool adds a chart to an existing dashboard with automatic layout positioning.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.dashboard_schemas import DashboardInfo
|
||||
from superset.utils import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AddChartToDashboardRequest(BaseModel):
    """Request schema for adding a chart to an existing dashboard."""

    # Dashboard that will receive the chart; must already exist.
    dashboard_id: int = Field(
        ..., description="ID of the dashboard to add the chart to"
    )
    # Chart to place; must already exist and not be on the dashboard yet.
    chart_id: int = Field(..., description="ID of the chart to add to the dashboard")
    # NOTE(review): target_tab is accepted but not referenced by the tool
    # body in this module — confirm whether tab placement is implemented.
    target_tab: Optional[str] = Field(
        None, description="Target tab name (if dashboard has tabs)"
    )
|
||||
|
||||
|
||||
class AddChartToDashboardResponse(BaseModel):
    """Response schema for adding chart to dashboard."""

    # Updated dashboard summary; None when the operation failed.
    dashboard: Optional[DashboardInfo] = Field(
        None, description="The updated dashboard info, if successful"
    )
    # Relative URL of the updated dashboard; None on failure.
    dashboard_url: Optional[str] = Field(
        None, description="URL to view the updated dashboard"
    )
    # Grid placement chosen for the new chart ({"h", "w", "x", "y"}).
    position: Optional[Dict[str, Any]] = Field(
        None, description="Position information for the added chart"
    )
    # Human-readable failure reason; None on success.
    error: Optional[str] = Field(None, description="Error message, if operation failed")
|
||||
|
||||
|
||||
def _find_next_position(layout: Dict[str, Any]) -> tuple[int, int]:
|
||||
"""
|
||||
Find the next available position in the dashboard layout.
|
||||
|
||||
Returns:
|
||||
Tuple of (x, y) coordinates for the new chart
|
||||
"""
|
||||
# Standard chart dimensions
|
||||
chart_height = 16
|
||||
grid_columns = 48
|
||||
|
||||
# Find the maximum Y position of existing charts
|
||||
max_y = 0
|
||||
for key, item in layout.items():
|
||||
if key.endswith("_POSITION") and isinstance(item, dict):
|
||||
y_pos = item.get("y", 0)
|
||||
height = item.get("h", chart_height)
|
||||
max_y = max(max_y, y_pos + height)
|
||||
|
||||
# Try to place in a 2-column layout
|
||||
# Check if there's space in the left column at max_y
|
||||
left_occupied = False
|
||||
right_occupied = False
|
||||
|
||||
for key, item in layout.items():
|
||||
if key.endswith("_POSITION") and isinstance(item, dict):
|
||||
y_pos = item.get("y", 0)
|
||||
x_pos = item.get("x", 0)
|
||||
height = item.get("h", chart_height)
|
||||
|
||||
# Check if this chart occupies the position we're considering
|
||||
if y_pos <= max_y < y_pos + height:
|
||||
if x_pos < grid_columns // 2: # Left column
|
||||
left_occupied = True
|
||||
else: # Right column
|
||||
right_occupied = True
|
||||
|
||||
# Choose position based on availability
|
||||
if not left_occupied:
|
||||
return (0, max_y) # Left column
|
||||
elif not right_occupied:
|
||||
return (24, max_y) # Right column
|
||||
else:
|
||||
return (0, max_y + chart_height) # New row, left column
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def add_chart_to_existing_dashboard(
    request: AddChartToDashboardRequest,
) -> AddChartToDashboardResponse:
    """
    Add a chart to an existing dashboard.

    This tool adds the specified chart to an existing dashboard, automatically
    positioning it in the layout. The chart will be placed in the next
    available position using a 2-column grid layout.

    Args:
        request: AddChartToDashboardRequest with dashboard_id, chart_id, and
            optional target_tab

    Returns:
        AddChartToDashboardResponse with updated dashboard info and position
    """
    try:
        # Imported lazily to avoid import cycles at module load time.
        from superset.commands.dashboard.update import UpdateDashboardCommand
        from superset.daos.chart import ChartDAO
        from superset.daos.dashboard import DashboardDAO

        # Validate dashboard exists
        dashboard = DashboardDAO.find_by_id(request.dashboard_id)
        if not dashboard:
            return AddChartToDashboardResponse(
                dashboard=None,
                dashboard_url=None,
                position=None,
                error=f"Dashboard with ID {request.dashboard_id} not found",
            )

        # Validate chart exists
        chart = ChartDAO.find_by_id(request.chart_id)
        if not chart:
            return AddChartToDashboardResponse(
                dashboard=None,
                dashboard_url=None,
                position=None,
                error=f"Chart with ID {request.chart_id} not found",
            )

        # Reject duplicates: adding the same chart twice would corrupt the
        # layout.  (Loop variable renamed from "slice", which shadowed the
        # builtin.)
        current_chart_ids = [
            existing_slice.id for existing_slice in dashboard.slices
        ]
        if request.chart_id in current_chart_ids:
            return AddChartToDashboardResponse(
                dashboard=None,
                dashboard_url=None,
                position=None,
                error=(
                    f"Chart {request.chart_id} is already in dashboard "
                    f"{request.dashboard_id}"
                ),
            )

        # Parse current layout; fall back to an empty layout on malformed JSON.
        try:
            current_layout = json.loads(dashboard.position_json or "{}")
        except (json.JSONDecodeError, TypeError):
            current_layout = {}

        # Find position for new chart
        x_pos, y_pos = _find_next_position(current_layout)

        # Create chart component key
        chart_key = f"CHART-{request.chart_id}"
        chart_width = 24
        chart_height = 16

        # Add chart to layout
        current_layout[chart_key] = {
            "children": [],
            "id": chart_key,
            "meta": {
                "chartId": request.chart_id,
                "height": chart_height,
                "sliceName": chart.slice_name or f"Chart {request.chart_id}",
                "uuid": str(chart.uuid) if chart.uuid else f"chart-{request.chart_id}",
                "width": chart_width,
            },
            "parents": ["ROOT_ID"],
            "type": "CHART",
        }

        # Add position information (read back later by _find_next_position)
        position_info = {"h": chart_height, "w": chart_width, "x": x_pos, "y": y_pos}
        current_layout[f"{chart_key}_POSITION"] = position_info

        # Register the chart as a child of the root container, creating the
        # root when the dashboard previously had an empty layout.
        if "ROOT_ID" in current_layout:
            if "children" not in current_layout["ROOT_ID"]:
                current_layout["ROOT_ID"]["children"] = []
            current_layout["ROOT_ID"]["children"].append(chart_key)
        else:
            current_layout["ROOT_ID"] = {
                "children": [chart_key],
                "id": "ROOT_ID",
                "type": "ROOT",
            }

        # Prepare update data
        update_data = {
            "position_json": json.dumps(current_layout),
            "slices": current_chart_ids + [request.chart_id],  # Add new chart ID
        }

        # Update the dashboard
        command = UpdateDashboardCommand(request.dashboard_id, update_data)
        updated_dashboard = command.run()

        # Convert to response format
        from superset.mcp_service.schemas.dashboard_schemas import (
            serialize_tag_object,
            serialize_user_object,
        )

        dashboard_info = DashboardInfo(
            id=updated_dashboard.id,
            dashboard_title=updated_dashboard.dashboard_title,
            slug=updated_dashboard.slug,
            description=updated_dashboard.description,
            published=updated_dashboard.published,
            created_on=updated_dashboard.created_on,
            changed_on=updated_dashboard.changed_on,
            created_by=updated_dashboard.created_by.username
            if updated_dashboard.created_by
            else None,
            changed_by=updated_dashboard.changed_by.username
            if updated_dashboard.changed_by
            else None,
            uuid=str(updated_dashboard.uuid) if updated_dashboard.uuid else None,
            url=f"/superset/dashboard/{updated_dashboard.id}/",
            chart_count=len(updated_dashboard.slices),
            owners=[
                serialize_user_object(owner)
                for owner in getattr(updated_dashboard, "owners", [])
                if serialize_user_object(owner) is not None
            ],
            tags=[
                serialize_tag_object(tag)
                for tag in getattr(updated_dashboard, "tags", [])
                if serialize_tag_object(tag) is not None
            ],
            roles=[],
            charts=[],
        )

        dashboard_url = f"/superset/dashboard/{updated_dashboard.id}/"

        # Lazy %-style args: formatting only happens when INFO is enabled.
        logger.info(
            "Added chart %s to dashboard %s at position (%s, %s)",
            request.chart_id,
            request.dashboard_id,
            x_pos,
            y_pos,
        )

        return AddChartToDashboardResponse(
            dashboard=dashboard_info,
            dashboard_url=dashboard_url,
            position=position_info,
            error=None,
        )

    except Exception as e:
        # logger.exception records the traceback automatically (the previous
        # eager f-string logger.error call required exc_info=True).
        logger.exception("Error adding chart to dashboard: %s", e)
        return AddChartToDashboardResponse(
            dashboard=None,
            dashboard_url=None,
            position=None,
            error=f"Failed to add chart to dashboard: {str(e)}",
        )
|
||||
235
superset/mcp_service/dashboard/tool/generate_dashboard.py
Normal file
235
superset/mcp_service/dashboard/tool/generate_dashboard.py
Normal file
@@ -0,0 +1,235 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: generate_dashboard
|
||||
|
||||
This tool creates a new dashboard with specified charts and layout configuration.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.dashboard_schemas import DashboardInfo
|
||||
from superset.utils import json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class GenerateDashboardRequest(BaseModel):
    """Request schema for generating a dashboard."""

    # Charts to place on the new dashboard; at least one is required and
    # every ID must resolve to an existing chart.
    chart_ids: List[int] = Field(
        ..., description="List of chart IDs to include in the dashboard", min_length=1
    )
    # Title shown on the new dashboard.
    dashboard_title: str = Field(..., description="Title for the new dashboard")
    # Optional free-text description; only set on the dashboard when provided.
    description: Optional[str] = Field(
        None, description="Description for the dashboard"
    )
    # Whether the dashboard is created in the published state.
    published: bool = Field(
        default=True, description="Whether to publish the dashboard"
    )
|
||||
|
||||
|
||||
class GenerateDashboardResponse(BaseModel):
    """Response schema for dashboard generation."""

    # Created dashboard summary; None when creation failed.
    dashboard: Optional[DashboardInfo] = Field(
        None, description="The created dashboard info, if successful"
    )
    # Relative URL of the created dashboard; None on failure.
    dashboard_url: Optional[str] = Field(None, description="URL to view the dashboard")
    # Human-readable failure reason; None on success.
    error: Optional[str] = Field(None, description="Error message, if creation failed")
|
||||
|
||||
|
||||
def _create_dashboard_layout(chart_ids: List[int]) -> Dict[str, Any]:
|
||||
"""
|
||||
Create a simple dashboard layout with charts arranged in a grid.
|
||||
|
||||
This creates a basic 2-column layout where charts are arranged
|
||||
vertically in alternating columns.
|
||||
"""
|
||||
layout = {}
|
||||
|
||||
# Grid configuration
|
||||
chart_width = 24 # Half width for 2-column layout
|
||||
chart_height = 16 # Standard chart height
|
||||
|
||||
for i, chart_id in enumerate(chart_ids):
|
||||
# Alternate between left (x=0) and right (x=24) columns
|
||||
x_position = 0 if i % 2 == 0 else 24
|
||||
# Stack charts vertically in each column
|
||||
y_position = (i // 2) * chart_height
|
||||
|
||||
# Create chart component in layout
|
||||
chart_key = f"CHART-{chart_id}"
|
||||
layout[chart_key] = {
|
||||
"children": [],
|
||||
"id": chart_key,
|
||||
"meta": {
|
||||
"chartId": chart_id,
|
||||
"height": chart_height,
|
||||
"sliceName": f"Chart {chart_id}",
|
||||
"uuid": f"chart-{chart_id}",
|
||||
"width": chart_width,
|
||||
},
|
||||
"parents": ["ROOT_ID"],
|
||||
"type": "CHART",
|
||||
}
|
||||
|
||||
# Add position information to the chart meta
|
||||
chart_meta = layout[chart_key]["meta"]
|
||||
if isinstance(chart_meta, dict):
|
||||
chart_meta.update(
|
||||
{
|
||||
"h": chart_height,
|
||||
"w": chart_width,
|
||||
"x": x_position,
|
||||
"y": y_position,
|
||||
}
|
||||
)
|
||||
|
||||
# Add root layout container
|
||||
layout["ROOT_ID"] = {
|
||||
"children": [f"CHART-{chart_id}" for chart_id in chart_ids],
|
||||
"id": "ROOT_ID",
|
||||
"type": "ROOT",
|
||||
}
|
||||
|
||||
return layout
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def generate_dashboard(request: GenerateDashboardRequest) -> GenerateDashboardResponse:
    """
    Generate a new dashboard with the specified charts.

    This tool creates a dashboard with the provided charts arranged in a
    simple 2-column grid layout. All charts must exist and be accessible
    to the current user.

    Args:
        request: GenerateDashboardRequest with chart_ids, title, and options

    Returns:
        GenerateDashboardResponse with the created dashboard info and URL
    """
    try:
        # Imported lazily to avoid import cycles at module load time.
        from superset.commands.dashboard.create import CreateDashboardCommand
        from superset.daos.chart import ChartDAO

        # Validate every chart ID before creating anything.
        missing_charts = [
            chart_id
            for chart_id in request.chart_ids
            if not ChartDAO.find_by_id(chart_id)
        ]
        if missing_charts:
            return GenerateDashboardResponse(
                dashboard=None,
                dashboard_url=None,
                error=f"Charts not found: {missing_charts}",
            )

        # Create dashboard layout
        layout = _create_dashboard_layout(request.chart_ids)

        # Prepare dashboard data
        dashboard_data = {
            "dashboard_title": request.dashboard_title,
            "slug": None,  # Let Superset auto-generate slug
            "css": "",
            "json_metadata": json.dumps(
                {
                    "filter_scopes": {},
                    "expanded_slices": {},
                    "refresh_frequency": 0,
                    "timed_refresh_immune_slices": [],
                    "color_scheme": None,
                    "label_colors": {},
                    "shared_label_colors": {},
                    "color_scheme_domain": [],
                    "cross_filters_enabled": False,
                }
            ),
            "position_json": json.dumps(layout),
            "published": request.published,
            "slices": request.chart_ids,
        }

        if request.description:
            dashboard_data["description"] = request.description

        # Create the dashboard using Superset's command pattern
        command = CreateDashboardCommand(dashboard_data)
        dashboard = command.run()

        # Convert to our response format
        from superset.mcp_service.schemas.dashboard_schemas import (
            serialize_tag_object,
            serialize_user_object,
        )

        dashboard_info = DashboardInfo(
            id=dashboard.id,
            dashboard_title=dashboard.dashboard_title,
            slug=dashboard.slug,
            description=dashboard.description,
            published=dashboard.published,
            created_on=dashboard.created_on,
            changed_on=dashboard.changed_on,
            created_by=dashboard.created_by.username if dashboard.created_by else None,
            changed_by=dashboard.changed_by.username if dashboard.changed_by else None,
            uuid=str(dashboard.uuid) if dashboard.uuid else None,
            url=f"/superset/dashboard/{dashboard.id}/",
            chart_count=len(request.chart_ids),
            owners=[
                serialize_user_object(owner)
                for owner in getattr(dashboard, "owners", [])
                if serialize_user_object(owner) is not None
            ],
            tags=[
                serialize_tag_object(tag)
                for tag in getattr(dashboard, "tags", [])
                if serialize_tag_object(tag) is not None
            ],
            roles=[],  # Dashboard roles not typically set at creation
            charts=[],  # Chart details not needed in response
        )

        dashboard_url = f"/superset/dashboard/{dashboard.id}/"

        # Lazy %-style args: formatting only happens when INFO is enabled.
        logger.info(
            "Created dashboard %s with %s charts",
            dashboard.id,
            len(request.chart_ids),
        )

        return GenerateDashboardResponse(
            dashboard=dashboard_info, dashboard_url=dashboard_url, error=None
        )

    except Exception as e:
        # logger.exception records the traceback automatically (the previous
        # eager f-string logger.error call required exc_info=True).
        logger.exception("Error creating dashboard: %s", e)
        return GenerateDashboardResponse(
            dashboard=None,
            dashboard_url=None,
            error=f"Failed to create dashboard: {str(e)}",
        )
|
||||
@@ -0,0 +1,37 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# ... existing license ...
|
||||
"""
|
||||
Get available filters FastMCP tool
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetAvailableFiltersTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.dashboard_schemas import (
|
||||
DashboardAvailableFilters,
|
||||
GetDashboardAvailableFiltersRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_dashboard_available_filters(
    request: GetDashboardAvailableFiltersRequest,
) -> DashboardAvailableFilters:
    """
    Describe the filters that can be applied when listing dashboards.

    Delegates to the generic ModelGetAvailableFiltersTool, which inspects
    the DashboardDAO for its filterable columns and supported operators.

    Returns:
        DashboardAvailableFilters
    """
    # Imported lazily so the DAO is only pulled in when the tool executes.
    from superset.daos.dashboard import DashboardDAO

    filters_tool = ModelGetAvailableFiltersTool(
        dao_class=DashboardDAO,
        output_schema=DashboardAvailableFilters,
        logger=logger,
    )
    return filters_tool.run()
|
||||
126
superset/mcp_service/dashboard/tool/get_dashboard_info.py
Normal file
126
superset/mcp_service/dashboard/tool/get_dashboard_info.py
Normal file
@@ -0,0 +1,126 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get dashboard info FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for getting detailed information
|
||||
about a specific dashboard.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetInfoTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import DashboardError, DashboardInfo
|
||||
from superset.mcp_service.schemas.chart_schemas import serialize_chart_object
|
||||
from superset.mcp_service.schemas.dashboard_schemas import (
|
||||
GetDashboardInfoRequest,
|
||||
)
|
||||
from superset.mcp_service.schemas.system_schemas import (
|
||||
RoleInfo,
|
||||
TagInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def dashboard_serializer(dashboard: Any) -> DashboardInfo:
    """Convert a Dashboard model object into a DashboardInfo schema.

    Related collections (owners, tags, roles, charts) are validated into
    their own schema types; absent/None collections become empty lists.
    The title falls back to "Untitled" and the slug to "" when unset.
    """
    # Normalize relationship collections once so the comprehensions below
    # never iterate None.
    slices = dashboard.slices or []
    owners = dashboard.owners or []
    tags = dashboard.tags or []
    roles = dashboard.roles or []
    creator = dashboard.created_by
    modifier = dashboard.changed_by

    return DashboardInfo(
        id=dashboard.id,
        dashboard_title=dashboard.dashboard_title or "Untitled",
        slug=dashboard.slug or "",
        description=dashboard.description,
        css=dashboard.css,
        certified_by=dashboard.certified_by,
        certification_details=dashboard.certification_details,
        json_metadata=dashboard.json_metadata,
        position_json=dashboard.position_json,
        published=dashboard.published,
        is_managed_externally=dashboard.is_managed_externally,
        external_url=dashboard.external_url,
        created_on=dashboard.created_on,
        changed_on=dashboard.changed_on,
        # Usernames only when the related user rows are present.
        created_by=getattr(creator, "username", None) if creator else None,
        changed_by=getattr(modifier, "username", None) if modifier else None,
        uuid=str(dashboard.uuid) if dashboard.uuid else None,
        url=dashboard.url,
        thumbnail_url=dashboard.thumbnail_url,
        created_on_humanized=dashboard.created_on_humanized,
        changed_on_humanized=dashboard.changed_on_humanized,
        chart_count=len(slices),
        owners=[
            UserInfo.model_validate(owner, from_attributes=True)
            for owner in owners
        ],
        tags=[
            TagInfo.model_validate(tag, from_attributes=True)
            for tag in tags
        ],
        roles=[
            RoleInfo.model_validate(role, from_attributes=True)
            for role in roles
        ],
        charts=[serialize_chart_object(chart) for chart in slices],
    )
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_dashboard_info(
    request: GetDashboardInfoRequest,
) -> DashboardInfo | DashboardError:
    """
    Fetch detailed information about a single dashboard.

    The identifier may be:
    - a numeric ID (e.g., 123)
    - a UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")
    - a slug string (e.g., "my-dashboard")

    Metadata cache control:
    - use_cache: serve from the metadata cache when possible
    - refresh_metadata: force a refresh of the metadata cache for fresh data

    Returns a DashboardInfo model, or DashboardError when the lookup fails.
    """
    from superset.daos.dashboard import DashboardDAO

    info_tool = ModelGetInfoTool(
        dao_class=DashboardDAO,
        output_schema=DashboardInfo,
        error_schema=DashboardError,
        serializer=dashboard_serializer,
        supports_slug=True,  # dashboards can also be addressed by slug
        logger=logger,
    )
    return info_tool.run(request.identifier)
|
||||
147
superset/mcp_service/dashboard/tool/list_dashboards.py
Normal file
147
superset/mcp_service/dashboard/tool/list_dashboards.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List dashboards FastMCP tool (Advanced with metadata cache control)
|
||||
|
||||
This module contains the FastMCP tool for listing dashboards using
|
||||
advanced filtering with clear, unambiguous request schema and metadata cache control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelListTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import (
|
||||
DashboardFilter,
|
||||
DashboardInfo,
|
||||
DashboardList,
|
||||
)
|
||||
from superset.mcp_service.schemas.chart_schemas import serialize_chart_object
|
||||
from superset.mcp_service.schemas.dashboard_schemas import (
|
||||
ListDashboardsRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def serialize_dashboard_object(dashboard: Any) -> DashboardInfo:
    """Serialize a dashboard object into a DashboardInfo schema, defensively.

    Every attribute is read with getattr() so partially loaded objects
    (e.g. rows fetched with a restricted column list) serialize without
    raising; missing attributes become None or empty values.
    """
    # The slices attribute may be absent OR present-but-None; normalize to a
    # list once so chart_count and charts stay consistent. The previous
    # len(getattr(dashboard, "slices", [])) raised TypeError when the
    # attribute existed but was None (the charts field already guarded
    # against that case, chart_count did not).
    slices = getattr(dashboard, "slices", None) or []
    raw_uuid = getattr(dashboard, "uuid", None)

    return DashboardInfo(
        id=getattr(dashboard, "id", None),
        dashboard_title=getattr(dashboard, "dashboard_title", None),
        slug=getattr(dashboard, "slug", None),
        url=getattr(dashboard, "url", None),
        published=getattr(dashboard, "published", None),
        changed_by_name=getattr(dashboard, "changed_by_name", None),
        changed_on=getattr(dashboard, "changed_on", None),
        changed_on_humanized=getattr(dashboard, "changed_on_humanized", None),
        created_by_name=getattr(dashboard, "created_by_name", None),
        created_on=getattr(dashboard, "created_on", None),
        created_on_humanized=getattr(dashboard, "created_on_humanized", None),
        description=getattr(dashboard, "description", None),
        css=getattr(dashboard, "css", None),
        certified_by=getattr(dashboard, "certified_by", None),
        certification_details=getattr(dashboard, "certification_details", None),
        json_metadata=getattr(dashboard, "json_metadata", None),
        position_json=getattr(dashboard, "position_json", None),
        is_managed_externally=getattr(dashboard, "is_managed_externally", None),
        external_url=getattr(dashboard, "external_url", None),
        uuid=str(raw_uuid) if raw_uuid else None,
        thumbnail_url=getattr(dashboard, "thumbnail_url", None),
        chart_count=len(slices),
        owners=getattr(dashboard, "owners", []),
        tags=getattr(dashboard, "tags", []),
        roles=getattr(dashboard, "roles", []),
        charts=[serialize_chart_object(chart) for chart in slices],
    )
|
||||
|
||||
|
||||
# Columns loaded by default when a list request does not specify
# select_columns explicitly.
DEFAULT_DASHBOARD_COLUMNS = [
    "id",
    "dashboard_title",
    "slug",
    "uuid",
    "published",
    "changed_on",
    "created_on",
]

# Columns accepted as order_column in list requests (documented in the
# list_dashboards tool docstring).
SORTABLE_DASHBOARD_COLUMNS = [
    "id",
    "dashboard_title",
    "slug",
    "published",
    "changed_on",
    "created_on",
]
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def list_dashboards(request: ListDashboardsRequest) -> DashboardList:
    """
    List dashboards with advanced filtering, search, and metadata cache control.

    A structured request object is used (rather than loose parameters) to
    avoid validation ambiguity between arrays and strings; every field is
    typed and carries a sensible default.

    Sortable columns for order_column: id, dashboard_title, slug, published,
    changed_on, created_on

    Metadata cache control:
    - use_cache: serve from the metadata cache when possible
    - refresh_metadata: force a refresh of the metadata cache, useful when
      the database schema has changed
    """
    from superset.daos.dashboard import DashboardDAO

    # Free-text search is applied against these columns.
    searchable_columns = [
        "dashboard_title",
        "slug",
        "uuid",
    ]

    list_tool = ModelListTool(
        dao_class=DashboardDAO,
        output_schema=DashboardInfo,
        item_serializer=lambda obj, cols: serialize_dashboard_object(obj),
        filter_type=DashboardFilter,
        default_columns=DEFAULT_DASHBOARD_COLUMNS,
        search_columns=searchable_columns,
        list_field_name="dashboards",
        output_list_schema=DashboardList,
        logger=logger,
    )

    # request.page is treated as 1-based here; ModelListTool pages are 0-based.
    zero_based_page = max(request.page - 1, 0)

    return list_tool.run(
        filters=request.filters,
        search=request.search,
        select_columns=request.select_columns,
        order_column=request.order_column,
        order_direction=request.order_direction,
        page=zero_based_page,
        page_size=request.page_size,
    )
|
||||
9
superset/mcp_service/dataset/tool/__init__.py
Normal file
9
superset/mcp_service/dataset/tool/__init__.py
Normal file
@@ -0,0 +1,9 @@
|
||||
"""Dataset FastMCP tools: listing, detail lookup, and filter discovery."""

from .get_dataset_available_filters import get_dataset_available_filters
from .get_dataset_info import get_dataset_info
from .list_datasets import list_datasets

# Explicit public API of this package.
__all__ = [
    "list_datasets",
    "get_dataset_info",
    "get_dataset_available_filters",
]
|
||||
@@ -0,0 +1,51 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""
|
||||
Get available dataset filters FastMCP tool
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetAvailableFiltersTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.dataset_schemas import (
|
||||
DatasetAvailableFilters,
|
||||
GetDatasetAvailableFiltersRequest,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_dataset_available_filters(
    request: GetDatasetAvailableFiltersRequest,
) -> DatasetAvailableFilters:
    """
    Dynamically describe which dataset columns can be filtered on and the
    operators each one supports. Custom fields supported: tags, owner.

    Returns:
        DatasetAvailableFilters with column_operators populated.
    """
    # Imported lazily so the DAO is only pulled in when the tool executes.
    from superset.daos.dataset import DatasetDAO

    filters_tool = ModelGetAvailableFiltersTool(
        dao_class=DatasetDAO,
        output_schema=DatasetAvailableFilters,
        logger=logger,
    )
    return filters_tool.run()
|
||||
69
superset/mcp_service/dataset/tool/get_dataset_info.py
Normal file
69
superset/mcp_service/dataset/tool/get_dataset_info.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get dataset info FastMCP tool
|
||||
|
||||
This module contains the FastMCP tool for getting detailed information
|
||||
about a specific dataset.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelGetInfoTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import DatasetError, DatasetInfo
|
||||
from superset.mcp_service.schemas.dataset_schemas import (
|
||||
GetDatasetInfoRequest,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_dataset_info(request: GetDatasetInfoRequest) -> DatasetInfo | DatasetError:
    """
    Fetch detailed information about a single dataset.

    The identifier may be:
    - a numeric ID (e.g., 123)
    - a UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890")

    Metadata cache control:
    - use_cache: serve from the metadata cache when possible
    - refresh_metadata: force a refresh of cached column and metric metadata,
      useful after the underlying table schema has changed

    Returns a DatasetInfo model, or DatasetError when the lookup fails.
    """
    from superset.daos.dataset import DatasetDAO

    info_tool = ModelGetInfoTool(
        dao_class=DatasetDAO,
        output_schema=DatasetInfo,
        error_schema=DatasetError,
        serializer=serialize_dataset_object,
        supports_slug=False,  # datasets are not addressable by slug
        logger=logger,
    )
    return info_tool.run(request.identifier)
|
||||
111
superset/mcp_service/dataset/tool/list_datasets.py
Normal file
111
superset/mcp_service/dataset/tool/list_datasets.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
List datasets FastMCP tool (Advanced with metadata cache control)
|
||||
|
||||
This module contains the FastMCP tool for listing datasets using
|
||||
advanced filtering with clear, unambiguous request schema and metadata cache control.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import ModelListTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas import DatasetInfo, DatasetList
|
||||
from superset.mcp_service.schemas.dataset_schemas import (
|
||||
DatasetFilter,
|
||||
ListDatasetsRequest,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Columns loaded by default when a list request does not specify
# select_columns explicitly.
DEFAULT_DATASET_COLUMNS = [
    "id",
    "table_name",
    "schema",
    "uuid",
    "database_name",
    "changed_by_name",
    "changed_on",
    "created_by_name",
    "created_on",
    "metrics",
    "columns",
]

# Columns accepted as order_column in list requests (documented in the
# list_datasets tool docstring).
SORTABLE_DATASET_COLUMNS = [
    "id",
    "table_name",
    "schema",
    "changed_on",
    "created_on",
]
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def list_datasets(request: ListDatasetsRequest) -> DatasetList:
    """
    List datasets with advanced filtering, search, and metadata cache control.

    When the caller specifies no ordering, results default to most recently
    updated first (changed_on, descending).

    A structured request object is used to avoid validation ambiguity between
    arrays and strings; every field is typed and carries a sensible default.

    Sortable columns for order_column: id, table_name, schema, changed_on,
    created_on

    Metadata cache control:
    - use_cache: serve from the metadata cache when possible
    - refresh_metadata: force a refresh of cached column and metric metadata,
      useful after the underlying table schema has changed
    """
    from superset.daos.dataset import DatasetDAO

    list_tool = ModelListTool(
        dao_class=DatasetDAO,
        output_schema=DatasetInfo,
        item_serializer=lambda obj, cols: serialize_dataset_object(obj),
        filter_type=DatasetFilter,
        default_columns=DEFAULT_DATASET_COLUMNS,
        search_columns=["schema", "sql", "table_name", "uuid"],
        list_field_name="datasets",
        output_list_schema=DatasetList,
        logger=logger,
    )

    # Default ordering: most recently changed first.
    sort_column = request.order_column or "changed_on"
    sort_direction = request.order_direction or "desc"

    # request.page is treated as 1-based here; ModelListTool pages are 0-based.
    return list_tool.run(
        filters=request.filters,
        search=request.search,
        select_columns=request.select_columns,
        order_column=sort_column,
        order_direction=sort_direction,
        page=max(request.page - 1, 0),
        page_size=request.page_size,
    )
|
||||
490
superset/mcp_service/generic_tools.py
Normal file
490
superset/mcp_service/generic_tools.py
Normal file
@@ -0,0 +1,490 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Protocol, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Type variables shared by the generic model tools below.
T = TypeVar("T")  # For model objects
S = TypeVar("S", bound=BaseModel)  # For Pydantic schemas
F = TypeVar("F", bound=BaseModel)  # For filter types
|
||||
|
||||
|
||||
class DAO(Protocol):
    """Structural protocol for the Data Access Objects used by model tools.

    Any class providing these classmethods satisfies the protocol; no
    inheritance is required (typing.Protocol is duck-typed).
    """

    # The model class this DAO operates on (used e.g. for slug fallback
    # queries in ModelGetInfoTool).
    model_cls: Type[Any]

    @classmethod
    def list(
        cls,
        column_operators: Optional[List[Any]] = None,
        order_column: str = "changed_on",
        order_direction: str = "desc",
        page: int = 0,
        page_size: int = 100,
        search: Optional[str] = None,
        search_columns: Optional[List[str]] = None,
        custom_filters: Optional[Dict[str, Any]] = None,
        columns: Optional[List[str]] = None,
    ) -> tuple[List[Any], int]:
        """Return one page of matching models plus the total match count."""
        ...

    @classmethod
    def find_by_id(cls, id: int) -> T | None:
        """Return the model identified by ``id``, or None if not found."""
        ...

    @classmethod
    def get_filterable_columns_and_operators(cls) -> Dict[str, Any]:
        """Describe which columns can be filtered and with which operators."""
        ...
|
||||
|
||||
|
||||
class ModelListTool:
    """
    Generic tool for listing model objects with filtering, search, pagination,
    and column selection.

    - Paging is 0-based: page=0 is the first page (to match backend and API
      conventions).
    - total_pages is 0 when page_size <= 0; otherwise
      ceil(total_count / page_size).
    - has_previous is True only when page > 0; has_next is True when more
      results exist after the current page.
    - columns_requested / columns_loaded report what columns were asked for
      and what was actually loaded, for LLM/OpenAPI friendliness.
    - Returns a strongly-typed Pydantic list schema (output_list_schema) with
      all pagination metadata.
    - Handles both object-based and JSON string filters.
    - Designed for use by LLM agents and API clients.
    """

    def __init__(
        self,
        dao_class: Any,
        output_schema: Type[S],
        item_serializer: Callable[[T, List[str]], S | None],
        filter_type: Type[F],
        default_columns: List[str],
        search_columns: List[str],
        list_field_name: str,
        output_list_schema: Type[BaseModel],
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Args:
            dao_class: DAO providing a ``list`` classmethod.
            output_schema: Pydantic schema of a single serialized item.
            item_serializer: Maps (model, loaded columns) to a schema
                instance, or None to drop the item from the response.
            filter_type: Pydantic type describing one filter entry.
            default_columns: Columns loaded when the caller selects none.
            search_columns: Columns the free-text ``search`` applies to.
            list_field_name: Name of the list field on the response schema.
            output_list_schema: Pydantic schema of the full list response.
            logger: Optional logger; defaults to this module's logger.
        """
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.item_serializer = item_serializer
        self.filter_type = filter_type
        self.default_columns = default_columns
        self.search_columns = search_columns
        self.list_field_name = list_field_name
        self.output_list_schema = output_list_schema
        self.logger = logger or logging.getLogger(__name__)

    def run(
        self,
        filters: Optional[Any] = None,
        search: Optional[str] = None,
        select_columns: Optional[Any] = None,
        order_column: Optional[str] = None,
        order_direction: Optional[Literal["asc", "desc"]] = "asc",
        page: int = 0,
        page_size: int = 100,
    ) -> Any:
        """Execute the listing and return a populated output_list_schema.

        DAO exceptions propagate to the caller; nothing is swallowed here.
        """
        # Filters may arrive as a JSON string (e.g. from a test); normalize
        # to the parsed structure the DAO expects.
        if isinstance(filters, str):
            from superset.utils import json

            filters = json.loads(filters)

        # Normalize select_columns (comma-separated string or list) and
        # remember exactly what the caller asked for.
        if select_columns:
            if isinstance(select_columns, str):
                select_columns = [
                    col.strip() for col in select_columns.split(",") if col.strip()
                ]
            columns_to_load = select_columns
            columns_requested = select_columns
        else:
            columns_to_load = self.default_columns
            columns_requested = self.default_columns

        # Query the DAO: one page of items plus the total match count.
        items: List[Any]
        items, total_count = self.dao_class.list(
            column_operators=filters,
            order_column=order_column or "changed_on",
            order_direction=order_direction or "desc",
            page=page,
            page_size=page_size,
            search=search,
            search_columns=self.search_columns,
            columns=columns_to_load,
        )

        # Serialize each item, dropping any the serializer rejects (None).
        item_objs = []
        for item in items:
            obj = self.item_serializer(item, columns_to_load)
            if obj is not None:
                item_objs.append(obj)

        total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0

        from superset.mcp_service.schemas import PaginationInfo

        pagination_info = PaginationInfo(
            page=page,
            page_size=page_size,
            total_count=total_count,
            total_pages=total_pages,
            has_next=page < total_pages - 1,
            has_previous=page > 0,
        )

        # Pagination fields are duplicated top-level and inside `pagination`
        # for consumer convenience. (A previously defined inner helper
        # `get_keys` was dead code and has been removed.)
        response_kwargs = {
            self.list_field_name: item_objs,
            "count": len(item_objs),
            "total_count": total_count,
            "page": page,
            "page_size": page_size,
            "total_pages": total_pages,
            "has_previous": page > 0,
            "has_next": page < total_pages - 1,
            "columns_requested": columns_requested,
            "columns_loaded": columns_to_load,
            "filters_applied": filters if isinstance(filters, list) else [],
            "pagination": pagination_info,
            "timestamp": datetime.now(timezone.utc),
        }
        response = self.output_list_schema(**response_kwargs)
        # Lazy %-formatting: the message is only built if INFO is enabled.
        self.logger.info(
            "Successfully retrieved %d %s", len(item_objs), self.list_field_name
        )
        return response
|
||||
|
||||
|
||||
class ModelGetInfoTool:
    """
    Enhanced tool for retrieving a single model object by ID, UUID, or slug.

    For datasets and charts: supports ID and UUID.
    For dashboards (supports_slug=True): supports ID, UUID, and slug.

    Lookup precedence in _find_object: integer ID first, then UUID, then
    (optionally) slug. Uses the appropriate DAO method for each form.
    """

    def __init__(
        self,
        dao_class: Any,
        output_schema: Type[BaseModel],
        error_schema: Type[BaseModel],
        serializer: Callable[[T], BaseModel],
        supports_slug: bool = False,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        # dao_class must provide find_by_id (see the DAO protocol); the
        # serializer maps the found model to the output schema.
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.error_schema = error_schema
        self.serializer = serializer
        self.supports_slug = supports_slug
        self.logger = logger or logging.getLogger(__name__)

    def _is_uuid(self, value: str) -> bool:
        """Check if a string is a valid UUID."""
        import uuid

        try:
            uuid.UUID(value)
            return True
        except ValueError:
            return False

    def _find_object(self, identifier: int | str) -> Any:
        """Find object by identifier using appropriate method.

        Order matters: numeric IDs win over UUIDs, which win over slugs.
        Returns None when nothing matches or the identifier form is invalid.
        """
        # If it's an integer or string that can be converted to int, use find_by_id
        if isinstance(identifier, int):
            return self.dao_class.find_by_id(identifier)

        try:
            # Try to convert string to int
            id_val = int(identifier)
            return self.dao_class.find_by_id(id_val)
        except ValueError:
            pass

        # Check if it's a UUID
        if self._is_uuid(identifier):
            # Use the new flexible find_by_id with uuid column
            import uuid

            uuid_obj = uuid.UUID(identifier)
            return self.dao_class.find_by_id(uuid_obj, id_column="uuid")

        # For dashboards, also check slug
        if self.supports_slug:
            # Try to find by slug using the new flexible method
            result = self.dao_class.find_by_id(identifier, id_column="slug")
            if result:
                return result

            # Fallback to the existing id_or_slug_filter for complex cases
            from superset.extensions import db
            from superset.models.dashboard import id_or_slug_filter

            model_class = self.dao_class.model_cls
            return (
                db.session.query(model_class)
                .filter(id_or_slug_filter(identifier))
                .one_or_none()
            )

        # If we get here, it's an invalid identifier
        return None

    def run(self, identifier: int | str) -> Any:
        """Look up *identifier* and serialize it.

        Returns an error_schema instance when the object is not found;
        unexpected exceptions are logged with traceback and re-raised.
        """
        try:
            obj = self._find_object(identifier)
            if obj is None:
                error_data = self.error_schema(
                    error=(
                        f"{self.output_schema.__name__} with identifier "
                        f"'{identifier}' not found"
                    ),
                    error_type="not_found",
                    timestamp=datetime.now(timezone.utc),
                )
                self.logger.warning(
                    f"{self.output_schema.__name__} {identifier} error: "
                    "not_found - not found"
                )
                return error_data
            response = self.serializer(obj)
            self.logger.info(
                f"{self.output_schema.__name__} response created successfully for "
                f"identifier {identifier}"
            )
            return response
        except Exception as context_error:
            error_msg = f"Error in ModelGetInfoTool: {str(context_error)}"
            self.logger.error(error_msg, exc_info=True)
            raise
|
||||
|
||||
|
||||
class InstanceInfoTool:
    """
    Configurable tool that assembles a statistics snapshot of a Superset instance.

    Combines basic entity counts, time-windowed activity metrics, and arbitrary
    caller-supplied metric calculators into a single Pydantic-typed response.
    Individual metric failures are logged and degraded to defaults (or omitted)
    rather than aborting the whole run.
    """

    def __init__(
        self,
        dao_classes: Dict[str, Any],
        output_schema: Type[BaseModel],
        metric_calculators: Dict[str, Callable[..., Any]],
        time_windows: Optional[Dict[str, int]] = None,
        logger: Optional[logging.Logger] = None,
    ) -> None:
        """
        Initialize the instance info tool.

        Args:
            dao_classes: Mapping of entity name -> DAO class used for counting.
            output_schema: Pydantic schema the final response is validated against.
            metric_calculators: Mapping of metric name -> callable producing it.
            time_windows: Mapping of window name -> look-back period in days.
            logger: Optional logger; defaults to this module's logger.
        """
        default_windows = {"recent": 7, "monthly": 30, "quarterly": 90}
        self.dao_classes = dao_classes
        self.output_schema = output_schema
        self.metric_calculators = metric_calculators
        self.time_windows = time_windows or default_windows
        self.logger = logger or logging.getLogger(__name__)

    def _calculate_basic_counts(self) -> Dict[str, int]:
        """Count every configured entity, falling back to 0 on DAO errors."""
        totals: Dict[str, int] = {}
        for name, dao in self.dao_classes.items():
            try:
                total = dao.count()
            except Exception as exc:
                self.logger.warning(f"Failed to count {name}: {exc}")
                total = 0
            totals[f"total_{name}"] = total
        return totals

    def _calculate_time_based_metrics(
        self, base_counts: Dict[str, int]
    ) -> Dict[str, Dict[str, int]]:
        """Count entities created/modified within each configured time window."""
        from datetime import datetime, timedelta, timezone

        from superset.daos.base import ColumnOperator, ColumnOperatorEnum

        def _count_since(dao: Any, column: str, since: Any) -> int:
            # Single ">= cutoff" count against the given timestamp column.
            return dao.count(
                column_operators=[
                    ColumnOperator(
                        col=column,
                        opr=ColumnOperatorEnum.gte,
                        value=since,
                    )
                ]
            )

        now = datetime.now(timezone.utc)
        results: Dict[str, Dict[str, int]] = {}

        for window_name, days in self.time_windows.items():
            cutoff = now - timedelta(days=days)
            metrics: Dict[str, int] = {}

            for name, dao in self.dao_classes.items():
                # Entities without creation timestamps cannot be windowed.
                if not hasattr(dao.model_cls, "created_on"):
                    continue
                try:
                    metrics[f"{name}_created"] = _count_since(
                        dao, "created_on", cutoff
                    )
                    # Only entities that track modification time get this metric.
                    if hasattr(dao.model_cls, "changed_on"):
                        metrics[f"{name}_modified"] = _count_since(
                            dao, "changed_on", cutoff
                        )
                except Exception as exc:
                    # On any failure both metrics for this entity default to 0.
                    self.logger.warning(
                        f"Failed to calculate {window_name} metrics for "
                        f"{name}: {exc}"
                    )
                    metrics[f"{name}_created"] = 0
                    metrics[f"{name}_modified"] = 0

            results[window_name] = metrics

        return results

    def _calculate_custom_metrics(
        self, base_counts: Dict[str, int], time_metrics: Dict[str, Dict[str, int]]
    ) -> Dict[str, Any]:
        """Run caller-supplied calculators; skip any that fail or return None."""
        computed: Dict[str, Any] = {}
        for metric_name, calculator in self.metric_calculators.items():
            try:
                value = calculator(
                    base_counts=base_counts,
                    time_metrics=time_metrics,
                    dao_classes=self.dao_classes,
                )
            except Exception as exc:
                # Omit failed metrics entirely so schema validation is not tripped.
                self.logger.warning(f"Failed to calculate {metric_name}: {exc}")
                continue
            if value is not None:
                computed[metric_name] = value
        return computed

    def run(self) -> BaseModel:
        """Build and return the full instance-info response."""
        try:
            counts = self._calculate_basic_counts()
            windows = self._calculate_time_based_metrics(counts)
            extras = self._calculate_custom_metrics(counts, windows)

            from datetime import datetime, timezone

            payload: Dict[str, Any] = {**counts, **windows, **extras}
            # Timestamp goes in last so it always reflects "now".
            payload["timestamp"] = datetime.now(timezone.utc)

            response = self.output_schema(**payload)
            self.logger.info("Successfully generated instance information")
            return response
        except Exception as e:
            self.logger.error(f"Error in InstanceInfoTool: {e}", exc_info=True)
            raise
|
||||
|
||||
|
||||
class ModelGetAvailableFiltersTool:
    """
    Generic tool that reports which columns of a model can be filtered on and
    with which operators. Backs get_dataset_available_filters,
    get_chart_available_filters, get_dashboard_available_filters, etc.
    """

    def __init__(
        self,
        dao_class: Any,
        output_schema: Type[BaseModel],
        logger: Optional[logging.Logger] = None,
    ) -> None:
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.logger = logger or logging.getLogger(__name__)

    def run(self) -> BaseModel:
        """Return the filterable columns/operators wrapped in the output schema."""
        try:
            # Normalize to a plain dict in case the DAO returns a mapping subclass.
            operators_by_column = dict(
                self.dao_class.get_filterable_columns_and_operators()
            )
            result = self.output_schema(column_operators=operators_by_column)
            self.logger.info(
                f"Successfully retrieved available filters for "
                f"{self.dao_class.__name__}"
            )
            return result
        except Exception as e:
            self.logger.error(
                f"Error in ModelGetAvailableFiltersTool: {e}", exc_info=True
            )
            raise
|
||||
488
superset/mcp_service/mcp_app.py
Normal file
488
superset/mcp_service/mcp_app.py
Normal file
@@ -0,0 +1,488 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
FastMCP app instance and initialization for Superset MCP service.
|
||||
This file provides the global FastMCP instance (mcp) and a function to initialize
|
||||
middleware. All tool modules should import mcp from here and use @mcp.tool and
|
||||
@mcp_auth_hook decorators.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
from fastmcp import FastMCP
|
||||
from fastmcp.server.auth.providers.bearer import BearerAuthProvider
|
||||
from starlette.exceptions import HTTPException
|
||||
from starlette.responses import Response
|
||||
|
||||
from superset.mcp_service.middleware import LoggingMiddleware, PrivateToolMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)

# Simple in-memory cache for screenshots (chart_id -> (timestamp, image_data))
# NOTE(review): process-local only — entries are not shared across workers.
# Expired entries are pruned lazily when serve_chart_screenshot stores a new
# screenshot, so between writes the cache may hold stale/expired entries.
_screenshot_cache: Dict[str, Tuple[float, bytes]] = {}
SCREENSHOT_CACHE_TTL = 300  # 5 minutes cache
|
||||
|
||||
|
||||
def _create_auth_provider() -> Optional[BearerAuthProvider]:
    """
    Build a BearerAuthProvider via the configured factory function.

    Follows the app.config["MCP_AUTH_FACTORY"](app) pattern as suggested by
    @dpgaspar. Returns None when no factory is configured or when provider
    creation fails for any reason (the failure is logged).
    """
    try:
        from superset import app as superset_app
        from superset.mcp_service.config import DEFAULT_CONFIG

        # Backfill any MCP config defaults that the deployment did not override.
        for config_key, default_value in DEFAULT_CONFIG.items():
            superset_app.config.setdefault(config_key, default_value)

        # Invoke the factory using the app.config pattern.
        factory = superset_app.config.get("MCP_AUTH_FACTORY")
        if factory and callable(factory):
            return factory(superset_app)

        return None
    except Exception as e:
        logger.error(f"Failed to create auth provider: {e}")
        return None
|
||||
|
||||
|
||||
# Create MCP instance without auth initially - auth will be configured in
# init_fastmcp_server()
# NOTE(review): the `instructions` text below is served to MCP clients verbatim;
# keep it in sync with the tool modules imported further down this file.
mcp = FastMCP(
    "Superset MCP Server",
    auth=None,  # Will be set later via factory
    instructions="""
You are connected to the Apache Superset MCP (Model Context Protocol) service.
This service provides programmatic access to Superset dashboards, charts, datasets,
and instance metadata via a set of high-level tools.

Available tools include:
- list_dashboards: Dashboard listing with advanced filters (use 'filters' for
  advanced queries, 1-based pagination)
- get_dashboard_info: Get detailed information about a dashboard by its integer ID
- get_superset_instance_info: Get high-level statistics and metadata about the
  Superset instance (no arguments)
- get_dashboard_available_filters: List all available dashboard filter fields and
  operators
- list_datasets: DatasetInfo listing with advanced filters (use 'filters' for
  advanced queries, 1-based pagination)
- get_dataset_info: Get detailed information about a dataset by its integer ID
- get_dataset_available_filters: List all available dataset filter fields and
  operators
- list_charts: Chart listing with advanced filters (use 'filters' for advanced
  queries, 1-based pagination)
- get_chart_info: Get detailed information about a chart by its integer ID
- get_chart_preview: Get a visual preview of a chart with image URL for display
- get_chart_data: Get the underlying data for a chart in text-friendly format
- get_chart_available_filters: List all available chart filter fields and operators
- generate_explore_link: Generate a pre-configured explore URL with specified
  dataset, metrics, dimensions, and filters for direct navigation


General usage tips:
- For listing tools, 'page' is 1-based (first page is 1)
- Use 'filters' to narrow down results (see get_dashboard_available_filters,
  get_dataset_available_filters, get_chart_available_filters for supported fields
  and operators)
- Use get_dashboard_info, get_dataset_info, get_chart_info with a valid ID from
  the listing tools
- For instance-wide stats, call get_superset_instance_info with no arguments
- All tools return structured, Pydantic-typed responses

If you are unsure which tool to use, start with list_dashboards or
get_superset_instance_info for a summary of the Superset instance.
""",
)
|
||||
|
||||
# Import all tool modules to ensure registration (must be after mcp is defined)
|
||||
# These imports register the tools with the mcp instance
|
||||
import superset.mcp_service.chart.tool # noqa: F401, E402
|
||||
import superset.mcp_service.dashboard.tool # noqa: F401, E402
|
||||
import superset.mcp_service.dataset.tool # noqa: F401, E402
|
||||
import superset.mcp_service.sql_lab.tool # noqa: F401, E402
|
||||
import superset.mcp_service.system.tool # noqa: F401, E402
|
||||
|
||||
|
||||
# Add custom route for serving screenshot images
async def serve_chart_screenshot(chart_id: str) -> Any:  # noqa: C901
    """
    Serve chart screenshot images directly as PNG files.
    This endpoint provides public access to chart screenshots without authentication.

    Flow: in-memory cache check -> Flask app context + mock user setup ->
    chart lookup (numeric ID or UUID) -> pooled WebDriver screenshot ->
    cache store + lazy expiry sweep -> PNG Response. Failures are mapped to
    starlette HTTPException (404 for unknown charts, 500 otherwise).
    """

    # Check cache first
    # NOTE(review): the cache hit path runs before any user/database setup, so
    # a cached image is served with no per-request checks at all (consistent
    # with the "public access" statement above).
    current_time = time.time()
    cache_key = f"chart_{chart_id}"

    if cache_key in _screenshot_cache:
        timestamp, cached_data = _screenshot_cache[cache_key]
        if current_time - timestamp < SCREENSHOT_CACHE_TTL:
            logger.info(f"Serving cached screenshot for chart {chart_id}")
            return Response(
                content=cached_data,
                media_type="image/png",
                headers={
                    "Cache-Control": "public, max-age=300",  # 5 min cache
                    "Content-Disposition": f"inline; filename=chart_{chart_id}.png",
                    "X-Cache": "HIT",
                },
            )

    try:
        from flask import g

        from superset import app as superset_app
        from superset.daos.chart import ChartDAO
        from superset.mcp_service.pooled_screenshot import PooledChartScreenshot
        from superset.utils.urls import get_url_path

        # Set up Flask app context for database access
        with superset_app.app_context():
            # Create a mock user context - you might need to adjust this
            from flask_appbuilder.security.sqla.models import User

            from superset.extensions import db

            # Get username from config, fallback to "admin"
            # NOTE(review): every screenshot is rendered as this single
            # configured user, regardless of who requested it — confirm that is
            # acceptable for row-level security.
            username = superset_app.config.get("MCP_ADMIN_USERNAME", "admin")
            mock_user = db.session.query(User).filter_by(username=username).first()
            if mock_user:
                g.user = mock_user
            else:
                logger.warning(f"User '{username}' not found, screenshot may fail")

            # Find the chart: all-digit identifiers are treated as primary keys,
            # anything else is tried as a UUID.
            chart = None
            try:
                if chart_id.isdigit():
                    chart = ChartDAO.find_by_id(int(chart_id))
                else:
                    # Try UUID lookup using DAO flexible method
                    chart = ChartDAO.find_by_id(chart_id, id_column="uuid")
            except Exception as e:
                logger.error(f"Error looking up chart {chart_id}: {e}")
                raise HTTPException(
                    status_code=500,
                    detail=(
                        f"Database error while looking up chart {chart_id}: {str(e)}"
                    ),
                ) from e

            if not chart:
                logger.warning(f"Chart {chart_id} not found in database")
                raise HTTPException(
                    status_code=404,
                    detail=(
                        f"Chart with ID '{chart_id}' not found. "
                        f"Please verify the chart ID exists."
                    ),
                )

            logger.info(f"Serving screenshot for chart {chart.id}: {chart.slice_name}")

            # Create chart URL for screenshot
            chart_url = get_url_path("Superset.slice", slice_id=chart.id)

            # Create screenshot object
            screenshot = PooledChartScreenshot(chart_url, chart.digest)

            # Generate screenshot (800x600 default)
            window_size = (800, 600)
            try:
                image_data = screenshot.get_screenshot(
                    user=g.user, window_size=window_size
                )
            except Exception as e:
                logger.error(f"Screenshot generation failed for chart {chart_id}: {e}")
                raise HTTPException(
                    status_code=500,
                    detail=(
                        f"Failed to generate screenshot for chart {chart_id}. "
                        f"Error: {str(e)}"
                    ),
                ) from e

            if image_data:
                # Cache the screenshot
                _screenshot_cache[cache_key] = (current_time, image_data)

                # Clean up old cache entries (simple cleanup)
                # Expiry is swept only here, on successful generation, so the
                # cache can hold expired entries between requests.
                keys_to_remove = []
                for key, (ts, _) in _screenshot_cache.items():
                    if current_time - ts > SCREENSHOT_CACHE_TTL:
                        keys_to_remove.append(key)
                for key in keys_to_remove:
                    del _screenshot_cache[key]

                logger.info(f"Generated and cached screenshot for chart {chart_id}")

                # Return the PNG image directly
                return Response(
                    content=image_data,
                    media_type="image/png",
                    headers={
                        "Cache-Control": "public, max-age=300",  # 5 min cache
                        "Content-Disposition": f"inline; filename=chart_{chart.id}.png",
                        "X-Cache": "MISS",
                    },
                )
            else:
                logger.error(f"Screenshot returned None for chart {chart_id}")
                raise HTTPException(
                    status_code=500,
                    detail=(
                        f"Screenshot generation returned empty result for "
                        f"chart {chart_id}. The chart may have rendering issues."
                    ),
                )

    except HTTPException:
        # Preserve the status code chosen above (404/500) instead of wrapping it.
        raise
    except Exception as e:
        logger.error(f"Error serving screenshot for chart {chart_id}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
def _get_form_data_from_cache(form_data_key: str) -> str:
    """
    Fetch the cached explore form data for *form_data_key*.

    Raises:
        HTTPException: 404 when the key is missing or expired; 500 when the
            lookup itself fails.
    """
    from superset.commands.explore.form_data.get import GetFormDataCommand
    from superset.commands.explore.form_data.parameters import (
        CommandParameters as FormDataCommandParameters,
    )

    try:
        command = GetFormDataCommand(FormDataCommandParameters(key=form_data_key))
        cached_json = command.run()
        if cached_json:
            return cached_json
        # A missing/expired key is a client problem, not a server error.
        logger.warning(f"Form data key not found in cache: {form_data_key}")
        raise HTTPException(
            status_code=404,
            detail=(
                f"Form data key '{form_data_key}' not found or expired. "
                f"Please generate a new explore link."
            ),
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Failed to retrieve form data for key {form_data_key}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Error retrieving form data for key '{form_data_key}': {str(e)}",
        ) from e
|
||||
|
||||
|
||||
def _parse_datasource_from_form_data(form_data_json: str) -> tuple[str, str]:
    """
    Parse datasource id/type out of an explore form-data JSON blob.

    Args:
        form_data_json: Raw JSON string previously stored in the form-data cache.

    Returns:
        A ``(datasource_id, datasource_type)`` pair of strings. Falls back to
        ``("", "table")`` when the payload cannot be parsed or carries no
        datasource information.
    """
    from superset.utils import json

    try:
        form_data = json.loads(form_data_json)
        datasource = form_data.get("datasource", "")
        if datasource and "__" in datasource:
            # Canonical encoding is "<id>__<type>", e.g. "12__table".
            datasource_id, datasource_type = datasource.split("__", 1)
        else:
            # Fall back to the explicit fields; these may be ints in raw JSON.
            datasource_id = form_data.get("datasource_id", "")
            datasource_type = form_data.get("datasource_type", "table")
        # Coerce to str so the return value always matches the annotation
        # (previously an int datasource_id from the JSON fields leaked through).
        # Falsy values (missing/0) normalize to "" so callers' truthiness
        # checks keep working.
        return (str(datasource_id) if datasource_id else "", str(datasource_type))
    except Exception as e:
        logger.warning(f"Could not parse form data to get datasource info: {e}")
        return "", "table"
|
||||
|
||||
|
||||
async def serve_explore_screenshot(form_data_key: str) -> Any:
    """
    Serve explore screenshot images from form_data_key.

    Flow: Flask app context + mock user setup -> form-data lookup (404 if the
    key expired) -> explore URL construction -> pooled WebDriver screenshot ->
    PNG Response. Unlike serve_chart_screenshot, results are NOT cached
    in-process; only the HTTP Cache-Control header (1 hour) applies.

    Args:
        form_data_key: The form data key for the explore view

    Returns:
        StreamingResponse with PNG image data
    """
    try:
        from flask import g

        from superset import app as superset_app
        from superset.utils.urls import get_url_path

        # Set up Flask app context for entire screenshot process
        with superset_app.app_context():
            # Create a mock user context - you might need to adjust this
            from flask_appbuilder.security.sqla.models import User

            from superset.extensions import db

            # Get username from config, fallback to "admin"
            # NOTE(review): renders as this single configured user for every
            # request — confirm this is acceptable for row-level security.
            username = superset_app.config.get("MCP_ADMIN_USERNAME", "admin")
            mock_user = db.session.query(User).filter_by(username=username).first()
            if mock_user:
                g.user = mock_user
            else:
                logger.warning(f"User '{username}' not found, screenshot may fail")

            # Look up the form data from the cache
            # (raises HTTPException 404/500 on missing key or lookup failure)
            form_data_json = _get_form_data_from_cache(form_data_key)

            # Parse form data to get datasource info
            datasource_id, datasource_type = _parse_datasource_from_form_data(
                form_data_json
            )

            # Create explore URL with all necessary parameters
            explore_url = get_url_path("Superset.explore")
            url_params = [f"form_data_key={form_data_key}"]

            # Add datasource parameters if available
            if datasource_id:
                url_params.append(f"datasource_id={datasource_id}")
                url_params.append(f"datasource_type={datasource_type}")

            explore_url += "?" + "&".join(url_params)

            logger.info(f"Generating screenshot for explore URL: {explore_url}")
            logger.info(
                f"Form data retrieved: "
                f"{form_data_json[:200] if form_data_json else 'None'}..."
            )  # Log first 200 chars

            # Use pooled screenshot for better performance
            import hashlib

            from superset.mcp_service.pooled_screenshot import PooledExploreScreenshot

            # Digest derived from the cache key; used by the screenshot layer
            # to identify this render.
            digest = hashlib.sha256(form_data_key.encode()).hexdigest()
            screenshot = PooledExploreScreenshot(explore_url, digest)

            # Generate screenshot with higher resolution
            window_size = (1600, 1200)  # Doubled resolution from 800x600
            try:
                image_data = screenshot.get_screenshot(
                    user=g.user, window_size=window_size
                )
            except Exception as e:
                logger.error(
                    f"Screenshot generation failed for explore view "
                    f"{form_data_key}: {e}"
                )
                raise HTTPException(
                    status_code=500,
                    detail=(
                        f"Failed to generate screenshot for explore view. "
                        f"Error: {str(e)}"
                    ),
                ) from e

            if image_data:
                # Return the PNG image directly
                return Response(
                    content=image_data,
                    media_type="image/png",
                    headers={
                        "Cache-Control": "public, max-age=3600",  # Cache for 1 hour
                        "Content-Disposition": (
                            f"inline; filename=explore_{form_data_key}.png"
                        ),
                    },
                )
            else:
                logger.error(
                    f"Screenshot returned None for explore view {form_data_key}"
                )
                raise HTTPException(
                    status_code=500,
                    detail=(
                        "Screenshot generation returned empty result for explore view. "
                        "The view may have rendering issues or invalid parameters."
                    ),
                )

    except HTTPException:
        # Re-raise HTTP exceptions
        raise
    except Exception as e:
        logger.error(f"Error serving screenshot for form_data_key {form_data_key}: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# Route registration for chart screenshots (decorator form).
@mcp.custom_route("/screenshot/chart/{chart_id}.png", methods=["GET"])
async def serve_chart_screenshot_endpoint(request: Any) -> Any:
    """HTTP endpoint: delegate /screenshot/chart/{chart_id}.png to the handler."""
    requested_chart_id = request.path_params["chart_id"]
    return await serve_chart_screenshot(requested_chart_id)
|
||||
|
||||
|
||||
@mcp.custom_route("/screenshot/explore/{form_data_key}.png", methods=["GET"])
async def serve_explore_screenshot_endpoint(request: Any) -> Any:
    """HTTP endpoint: delegate /screenshot/explore/{form_data_key}.png to the handler."""
    requested_key = request.path_params["form_data_key"]
    return await serve_explore_screenshot(requested_key)
|
||||
|
||||
|
||||
def init_fastmcp_server(enable_auth_configuration: bool = True) -> FastMCP:
    """
    Initialize and configure the FastMCP server with all middleware.
    This should be called before running the server to ensure middleware is registered.

    Safe to call more than once: middleware is registered only on the first
    call (previously, every call stacked another LoggingMiddleware and
    PrivateToolMiddleware instance onto the server).

    Args:
        enable_auth_configuration: If True, configure auth using the factory pattern

    Returns:
        The configured module-level FastMCP instance.
    """
    logger.setLevel(logging.DEBUG)

    # Configure authentication using factory pattern
    if enable_auth_configuration:
        try:
            auth_provider = _create_auth_provider()
            if auth_provider:
                logger.info("Configuring MCP authentication using factory pattern")
                # Set the auth provider on the mcp instance
                mcp.auth = auth_provider
                logger.info(
                    f"Authentication configured: {type(auth_provider).__name__}"
                )
            else:
                logger.info(
                    "No authentication configured - MCP service will run without auth"
                )
        except Exception as e:
            # Auth failures are non-fatal: the server still starts, just open.
            logger.error(f"Auth configuration failed: {e}")
            logger.info("MCP service will run without authentication")

    # Add middleware exactly once, even if this initializer is re-invoked.
    if not getattr(init_fastmcp_server, "_middleware_registered", False):
        mcp.add_middleware(LoggingMiddleware())
        mcp.add_middleware(PrivateToolMiddleware())
        init_fastmcp_server._middleware_registered = True  # type: ignore[attr-defined]

    logger.info("MCP Server initialized with modular tools structure")
    return mcp
|
||||
43
superset/mcp_service/mcp_config_template.py
Normal file
43
superset/mcp_service/mcp_config_template.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP Service Configuration Template
|
||||
|
||||
Copy this to `mcp_config.py` and customize:
|
||||
cp mcp_config_template.py mcp_config.py
|
||||
"""
|
||||
|
||||
# Enable authentication
MCP_AUTH_ENABLED = True

# JWT configuration
# NOTE: these are placeholders — point them at your identity provider before
# deploying. JWKS is fetched from MCP_JWKS_URI to validate bearer tokens.
MCP_JWKS_URI = "https://your-auth-provider.com/.well-known/jwks.json"
MCP_JWT_ISSUER = "https://your-auth-provider.com/"
MCP_JWT_AUDIENCE = "superset-mcp-server"
MCP_JWT_ALGORITHM = "RS256"
MCP_REQUIRED_SCOPES = ["superset:read", "superset:query"]

# Alternative: Use RSA public key instead of JWKS
# MCP_JWT_PUBLIC_KEY = """-----BEGIN PUBLIC KEY-----
# Your RSA public key here
# -----END PUBLIC KEY-----"""

# Optional: Custom auth factory for advanced use cases
# def MCP_AUTH_FACTORY(app):
#     # Custom logic here
#     return BearerAuthProvider(...)
87
superset/mcp_service/middleware.py
Normal file
87
superset/mcp_service/middleware.py
Normal file
@@ -0,0 +1,87 @@
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
from fastmcp.exceptions import ToolError
|
||||
from fastmcp.server.middleware import Middleware, MiddlewareContext
|
||||
|
||||
from superset.extensions import event_logger
|
||||
from superset.utils.core import get_user_id
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class LoggingMiddleware(Middleware):
    """
    Middleware that logs every MCP message (request and response) using Superset's
    event logger. This matches the core Superset audit log system (Action Log UI,
    logs table, custom loggers). Also attempts to log dashboard_id, chart_id
    (slice_id), and dataset_id if present in tool params.

    The event is emitted BEFORE the handler runs, so it records the attempt —
    failures and duration are not captured (hence duration_ms=None below).
    """

    async def on_message(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        # Extract agent_id and user_id
        agent_id = None
        user_id = None
        dashboard_id = None
        slice_id = None
        dataset_id = None
        # `params` may be absent or None on some message types; normalize to {}.
        params = getattr(context.message, "params", {}) or {}
        # agent_id: prefer request metadata, fall back to the session object.
        if hasattr(context, "metadata") and context.metadata:
            agent_id = context.metadata.get("agent_id")
        if not agent_id and hasattr(context, "session") and context.session:
            agent_id = getattr(context.session, "agent_id", None)
        try:
            user_id = get_user_id()
        except Exception:
            # NOTE(review): presumably get_user_id raises outside a Flask
            # request context — confirm; anonymous calls log user_id=None.
            user_id = None
        # Try to extract IDs from params
        if isinstance(params, dict):
            dashboard_id = params.get("dashboard_id")
            # Chart ID may be under 'chart_id' or 'slice_id'
            slice_id = params.get("chart_id") or params.get("slice_id")
            dataset_id = params.get("dataset_id")
        # Log to Superset's event logger (DB, Action Log UI, or custom)
        event_logger.log(
            user_id=user_id,
            action="mcp_tool_call",
            dashboard_id=dashboard_id,
            duration_ms=None,  # logged pre-dispatch; duration unknown here
            slice_id=slice_id,
            referrer=None,
            curated_payload={
                "tool": getattr(context.message, "name", None),
                "agent_id": agent_id,
                "params": params,
                "method": context.method,
                "dashboard_id": dashboard_id,
                "slice_id": slice_id,
                "dataset_id": dataset_id,
            },
        )
        # (Optional) also log to standard logger for debugging
        logger.info(
            f"MCP tool call: tool={getattr(context.message, 'name', None)}, "
            f"agent_id={agent_id}, user_id={user_id}, method={context.method}, "
            f"dashboard_id={dashboard_id}, slice_id={slice_id}, dataset_id={dataset_id}"
        )
        return await call_next(context)
|
||||
|
||||
|
||||
class PrivateToolMiddleware(Middleware):
    """Middleware that rejects tool calls whose tool is tagged 'private'."""

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        tool_name = context.message.name
        tool = await context.fastmcp_context.fastmcp.get_tool(tool_name)
        # Tools without a tags attribute are treated as public.
        tool_tags = getattr(tool, "tags", set())
        if "private" in tool_tags:
            raise ToolError(f"Access denied to private tool: {tool_name}")
        return await call_next(context)
|
||||
448
superset/mcp_service/pooled_screenshot.py
Normal file
448
superset/mcp_service/pooled_screenshot.py
Normal file
@@ -0,0 +1,448 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pooled screenshot implementations for improved performance
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from flask import current_app
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
from selenium.common.exceptions import TimeoutException, WebDriverException
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.support import expected_conditions
|
||||
from selenium.webdriver.support.ui import WebDriverWait
|
||||
|
||||
from superset.extensions import machine_auth_provider_factory
|
||||
from superset.mcp_service.webdriver_pool import get_webdriver_pool
|
||||
from superset.utils.screenshots import BaseScreenshot, WindowSize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PooledBaseScreenshot(BaseScreenshot):
    """
    Enhanced screenshot class that uses WebDriver pooling for better performance.

    Key improvements:
    - Reuses WebDriver instances from a connection pool
    - Faster screenshot generation (no browser startup/shutdown per request)
    - Better resource management and automatic cleanup
    - Health checking and recovery of WebDriver instances
    """

    def get_screenshot(
        self, user: User, window_size: WindowSize | None = None
    ) -> bytes | None:
        """
        Generate screenshot using pooled WebDriver for improved performance.

        Args:
            user: User context for authentication
            window_size: Optional window size override

        Returns:
            Screenshot as PNG bytes or None if failed

        Raises:
            Exception: any navigation/capture error is logged with its
                traceback and re-raised to the caller.
        """
        window_size = window_size or self.window_size
        pool = get_webdriver_pool()

        # Borrow a driver from the pool; the context manager returns it to
        # the pool (or disposes of it) when the block exits.
        with pool.get_driver(window_size, user.id) as driver:
            try:
                # Authenticate the driver for this user
                machine_auth_provider_factory.instance.authenticate_webdriver(
                    driver, user
                )

                # Navigate to the URL
                driver.get(self.url)

                # Take screenshot using the specific implementation
                return self._take_screenshot(driver, user)

            except Exception:
                # logger.exception keeps the full traceback; the previous
                # logger.error(f"...{e}") form discarded it.
                logger.exception("Error taking screenshot with pooled driver")
                raise

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """
        Override this method in subclasses to implement specific screenshot logic.

        Args:
            driver: Authenticated WebDriver instance
            user: User context

        Returns:
            Screenshot as PNG bytes
        """
        raise NotImplementedError("Subclasses must implement _take_screenshot")
|
||||
|
||||
|
||||
class PooledChartScreenshot(PooledBaseScreenshot):
    """Pooled version of chart screenshot generation"""

    # Used by the thumbnail cache machinery to namespace keys.
    thumbnail_type: str = "chart"
    # CSS class of the element that is captured.
    element: str = "chart-container"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        """Build a chart screenshot task.

        Args:
            url: Explore/chart URL; rewritten into standalone mode below.
            digest: Cache digest for the screenshot, or None.
            window_size: Browser window size; defaults to (800, 600).
            thumb_size: Thumbnail size; defaults to (400, 300).
        """
        from superset.utils.urls import modify_url_query
        from superset.utils.webdriver import ChartStandaloneMode

        # Chart reports are in standalone="true" mode
        url = modify_url_query(
            url,
            standalone=ChartStandaloneMode.HIDE_NAV.value,
        )
        super().__init__(url, digest)
        self.window_size = window_size or (800, 600)
        self.thumb_size = thumb_size or (400, 300)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """Take screenshot of chart with standard Superset chart handling.

        Sequence: fixed head-start sleep -> locate chart element -> wait for
        chart containers to be visible -> wait for loading spinners to clear
        -> animation sleep -> optional error scrub -> element capture.
        Timeouts and WebDriver failures are logged and re-raised.
        """
        try:
            # Wait for page to load
            selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
            logger.debug(f"Sleeping for {selenium_headstart} seconds")
            time.sleep(selenium_headstart)

            # Wait for chart container
            wait = WebDriverWait(driver, current_app.config["SCREENSHOT_LOCATE_WAIT"])
            element = wait.until(
                expected_conditions.presence_of_element_located(
                    (By.CLASS_NAME, self.element)
                )
            )

            # Wait for chart containers to render
            wait.until(
                expected_conditions.visibility_of_all_elements_located(
                    (By.CLASS_NAME, "chart-container")
                )
            )

            # Wait for loading to complete (no elements with class "loading")
            WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Wait for animations
            animation_wait = current_app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"]
            logger.debug(f"Wait {animation_wait} seconds for chart animation")
            time.sleep(animation_wait)

            # Handle unexpected errors if configured; errors are only logged,
            # the capture still proceeds.
            if current_app.config.get("SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"):
                from superset.utils.webdriver import WebDriverSelenium

                unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver)
                if unexpected_errors:
                    logger.warning(
                        f"{len(unexpected_errors)} errors found in screenshot. "
                        f"URL: {self.url}. Errors: {unexpected_errors}"
                    )

            # Take screenshot of just the located element, not the full page.
            logger.debug(f"Taking PNG screenshot as user {user.username}")
            return element.screenshot_as_png

        except TimeoutException:
            logger.exception(f"Timeout taking chart screenshot for URL: {self.url}")
            raise
        except WebDriverException:
            logger.exception(f"WebDriver error taking screenshot for URL: {self.url}")
            raise
|
||||
|
||||
|
||||
class PooledExploreScreenshot(PooledBaseScreenshot):
    """
    Pooled version of explore screenshot with UI hiding functionality.

    This class provides the same clean chart-only screenshots as the previous
    implementation but with improved performance through WebDriver pooling.
    """

    thumbnail_type: str = "explore"
    element: str = "chart-container"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        """Build an explore-page screenshot task.

        Unlike the chart/dashboard variants, the URL is NOT rewritten into
        standalone mode; UI chrome is hidden via JavaScript instead.

        Args:
            url: Explore page URL.
            digest: Cache digest for the screenshot, or None.
            window_size: Browser window size; defaults to (1600, 1200).
            thumb_size: Thumbnail size; defaults to (800, 600).
        """
        super().__init__(url, digest)
        self.window_size = window_size or (1600, 1200)
        self.thumb_size = thumb_size or (800, 600)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """
        Take screenshot of explore page with UI elements hidden for clean chart display.

        This implementation:
        1. Waits for chart to load
        2. Hides navigation, panels, and headers using JavaScript
        3. Takes screenshot of just the chart area
        4. Includes fallback strategies if chart container not found
        """
        try:
            # Give explore page time to initialize (fixed 3s; not configurable
            # like the chart/dashboard head-start)
            logger.debug("Waiting for explore page to initialize")
            time.sleep(3)

            # Wait for chart container to load and be visible (hard-coded 45s)
            wait = WebDriverWait(driver, 45)
            chart_element = wait.until(
                expected_conditions.visibility_of_element_located(
                    (By.CLASS_NAME, "chart-container")
                )
            )

            # Wait for any loading indicators to disappear (hard-coded 30s)
            WebDriverWait(driver, 30).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Hide overlapping UI elements using JavaScript
            self._hide_ui_elements(driver)

            # Wait for UI hiding animations to complete
            time.sleep(2)

            # Take screenshot of just the chart area
            img = chart_element.screenshot_as_png
            logger.info("Successfully captured chart-container screenshot")
            return img

        except TimeoutException as e:
            # Chart container never appeared; degrade gracefully instead of failing.
            logger.warning(f"Chart container not found, trying fallbacks: {e}")
            return self._fallback_screenshot(driver)
        except Exception as e:
            logger.error(f"Error taking explore screenshot: {e}")
            raise

    def _hide_ui_elements(self, driver: Any) -> None:
        """Hide UI elements to show only the chart.

        Runs a best-effort JavaScript snippet that display:none's the data
        source panel, controls panel, navigation and chart headers, and
        expands the chart area. Script failures are logged, never raised.
        """
        hide_script = """
        // Hide the specific left panels by their exact class names
        var dataSourcePanel = document.querySelector(
            '.explore-column.data-source-selection'
        );
        if (dataSourcePanel) {
            dataSourcePanel.style.display = 'none';
            console.log('Hidden data source panel');
        }

        var controlsPanel = document.querySelector(
            '.col-sm-3.explore-column.controls-column'
        );
        if (controlsPanel) {
            controlsPanel.style.display = 'none';
            console.log('Hidden controls panel');
        }

        // Also try alternate selectors for the controls
        var controlsAlt = document.querySelector('.explore-column.controls-column');
        if (controlsAlt) {
            controlsAlt.style.display = 'none';
            console.log('Hidden controls panel (alt)');
        }

        // Hide the main navigation header
        var mainHeader = document.querySelector('header.top#main-menu');
        if (mainHeader) {
            mainHeader.style.display = 'none';
            console.log('Hidden main navigation header');
        }

        // Hide the chart header with actions (title input and save button)
        var chartHeader = document.querySelector('.header-with-actions');
        if (chartHeader) {
            chartHeader.style.display = 'none';
            console.log('Hidden chart header with actions');
        }

        // Make sure the main chart area expands to use available space
        var mainContent = document.querySelector('.main-explore-content');
        if (mainContent) {
            mainContent.style.width = '100%';
            // Change from col-sm-7 to col-sm-12
            mainContent.className = 'main-explore-content col-sm-12';
        }

        // Ensure chart container is visible
        var chartContainer = document.querySelector('.chart-container');
        if (chartContainer) {
            chartContainer.style.position = 'relative';
            chartContainer.style.zIndex = '999';
            chartContainer.style.width = '100%';
        }
        """

        try:
            driver.execute_script(hide_script)
            logger.debug("Successfully executed UI hiding script")
        except WebDriverException as e:
            # Best-effort: a failed hide still allows a (noisier) screenshot.
            logger.warning(f"Failed to execute UI hiding script: {e}")

    def _fallback_screenshot(self, driver: Any) -> bytes | None:
        """Fallback screenshot strategies if chart container not found.

        Order: slice_container element -> any chart-related selector ->
        full-page capture. Returns early on the first strategy that succeeds.
        """
        try:
            # Fallback 1: Try slice container
            logger.debug("Trying slice_container fallback")
            slice_element = WebDriverWait(driver, 10).until(
                expected_conditions.visibility_of_element_located(
                    (By.CLASS_NAME, "slice_container")
                )
            )
            img = slice_element.screenshot_as_png
            logger.info("Successfully captured slice_container screenshot")
            return img

        except TimeoutException:
            try:
                # Fallback 2: Try any chart-related container
                logger.debug("Trying chart-related container fallback")
                chart_containers = driver.find_elements(
                    By.CSS_SELECTOR,
                    ".chart, .slice-container, .chart-content, "
                    "[data-test='chart-container']",
                )
                if chart_containers:
                    img = chart_containers[0].screenshot_as_png
                    logger.info(
                        "Successfully captured chart-related container screenshot"
                    )
                    return img
                # Empty match list falls through to the full-page fallback.

            except Exception as e2:
                logger.warning(f"Chart-related container fallback failed: {e2}")

            # Fallback 3: Full page screenshot
            logger.warning("Using full page screenshot as final fallback")
            img = driver.get_screenshot_as_png()
            logger.info("Using full page screenshot as fallback")
            return img
|
||||
|
||||
|
||||
class PooledDashboardScreenshot(PooledBaseScreenshot):
    """Pooled version of dashboard screenshot generation"""

    thumbnail_type: str = "dashboard"
    # Dashboards render inside a "standalone" wrapper element in report mode.
    element: str = "standalone"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        """Build a dashboard screenshot task.

        Args:
            url: Dashboard URL; rewritten into standalone/report mode below.
            digest: Cache digest for the screenshot, or None.
            window_size: Browser window size; defaults to (1600, 1200).
            thumb_size: Thumbnail size; defaults to (800, 600).
        """
        from superset.utils.urls import modify_url_query
        from superset.utils.webdriver import DashboardStandaloneMode

        # Dashboard screenshots should always capture in standalone
        url = modify_url_query(
            url,
            standalone=DashboardStandaloneMode.REPORT.value,
        )
        super().__init__(url, digest)
        self.window_size = window_size or (1600, 1200)
        self.thumb_size = thumb_size or (800, 600)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """Take screenshot of dashboard with standard Superset dashboard handling.

        Mirrors the chart flow but with an extra fallback wait on the
        grid-container so dashboards with no charts can still be captured.
        """
        try:
            # Wait for page to load
            selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
            time.sleep(selenium_headstart)

            # Wait for dashboard element
            wait = WebDriverWait(driver, current_app.config["SCREENSHOT_LOCATE_WAIT"])
            element = wait.until(
                expected_conditions.presence_of_element_located(
                    (By.CLASS_NAME, self.element)
                )
            )

            # Wait for chart containers to render
            try:
                wait.until(
                    expected_conditions.visibility_of_all_elements_located(
                        (By.CLASS_NAME, "chart-container")
                    )
                )
            except TimeoutException:
                # Fallback for empty dashboards
                try:
                    # NOTE(review): timeout of 0 means a single immediate
                    # check with no polling window — confirm this is intended
                    # rather than a configurable wait.
                    WebDriverWait(driver, 0).until(
                        expected_conditions.visibility_of_all_elements_located(
                            (By.CLASS_NAME, "grid-container")
                        )
                    )
                except TimeoutException:
                    logger.exception(f"Dashboard failed to load at URL: {self.url}")
                    raise

            # Wait for loading to complete
            WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Wait for animations
            animation_wait = current_app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"]
            time.sleep(animation_wait)

            # Handle unexpected errors if configured; log-only, capture proceeds.
            if current_app.config.get("SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"):
                from superset.utils.webdriver import WebDriverSelenium

                unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver)
                if unexpected_errors:
                    logger.warning(
                        f"{len(unexpected_errors)} errors found in dashboard "
                        f"screenshot. URL: {self.url}. Errors: {unexpected_errors}"
                    )

            # Take screenshot
            logger.debug(f"Taking PNG dashboard screenshot as user {user.username}")
            return element.screenshot_as_png

        except TimeoutException:
            logger.exception(f"Timeout taking dashboard screenshot for URL: {self.url}")
            raise
        except WebDriverException:
            logger.exception(
                f"WebDriver error taking dashboard screenshot for URL: {self.url}"
            )
            raise
|
||||
26
superset/mcp_service/run_proxy.sh
Executable file
26
superset/mcp_service/run_proxy.sh
Executable file
@@ -0,0 +1,26 @@
|
||||
#!/bin/bash
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

## use in claude like this
# "Superset MCP Proxy": {
#   "command": "~/github/superset/superset/mcp_service/run_proxy.sh",
#   "args": [],
#   "env": {}
# },

# Fail fast on errors, unset variables, and pipeline failures.
set -euo pipefail

# Get the directory where this script is located
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"

# Get the project root (two levels up from mcp_service)
PROJECT_ROOT="$(cd "$SCRIPT_DIR/../.." && pwd)"

# Use python from the virtual environment if it exists, otherwise use system python
if [ -f "$PROJECT_ROOT/venv/bin/python" ]; then
    PYTHON_PATH="$PROJECT_ROOT/venv/bin/python"
elif [ -f "$PROJECT_ROOT/.venv/bin/python" ]; then
    PYTHON_PATH="$PROJECT_ROOT/.venv/bin/python"
else
    PYTHON_PATH="python3"
fi

# Run the proxy script; exec replaces this shell so signals from the MCP
# client (e.g. SIGTERM on shutdown) reach the Python process directly.
exec "$PYTHON_PATH" "$SCRIPT_DIR/simple_proxy.py"
|
||||
88
superset/mcp_service/schemas/__init__.py
Normal file
88
superset/mcp_service/schemas/__init__.py
Normal file
@@ -0,0 +1,88 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP Service Schemas Package
|
||||
|
||||
This package contains Pydantic schemas for the MCP service responses.
|
||||
"""
|
||||
|
||||
from .chart_schemas import (
|
||||
ChartAvailableFiltersResponse,
|
||||
ChartError,
|
||||
ChartFilter,
|
||||
ChartInfo,
|
||||
ChartList,
|
||||
serialize_chart_object,
|
||||
)
|
||||
from .dashboard_schemas import (
|
||||
DashboardAvailableFilters,
|
||||
DashboardError,
|
||||
DashboardFilter,
|
||||
DashboardInfo,
|
||||
DashboardList,
|
||||
)
|
||||
from .dataset_schemas import (
|
||||
DatasetAvailableFilters,
|
||||
DatasetError,
|
||||
DatasetFilter,
|
||||
DatasetInfo,
|
||||
DatasetList,
|
||||
serialize_dataset_object,
|
||||
)
|
||||
from .system_schemas import (
|
||||
DashboardBreakdown,
|
||||
DatabaseBreakdown,
|
||||
InstanceInfo,
|
||||
InstanceSummary,
|
||||
PaginationInfo,
|
||||
PopularContent,
|
||||
RecentActivity,
|
||||
RoleInfo,
|
||||
TagInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
# Public API of the schemas package; kept alphabetically sorted (classes
# first, helper functions last) for predictable star-imports.
__all__ = [
    "ChartAvailableFiltersResponse",
    "ChartError",
    "ChartFilter",
    "ChartInfo",
    "ChartList",
    "DashboardAvailableFilters",
    "DashboardBreakdown",
    "DashboardError",
    "DashboardFilter",
    "DashboardInfo",
    "DashboardList",
    "DatabaseBreakdown",
    "DatasetAvailableFilters",
    "DatasetError",
    "DatasetFilter",
    "DatasetInfo",
    "DatasetList",
    "InstanceInfo",
    "InstanceSummary",
    "PaginationInfo",
    "PopularContent",
    "RecentActivity",
    "RoleInfo",
    "TagInfo",
    "UserInfo",
    "serialize_chart_object",
    "serialize_dataset_object",
]
|
||||
145
superset/mcp_service/schemas/cache_schemas.py
Normal file
145
superset/mcp_service/schemas/cache_schemas.py
Normal file
@@ -0,0 +1,145 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Cache control schemas for MCP tools.
|
||||
|
||||
These schemas provide cache control parameters that leverage Superset's
|
||||
existing cache infrastructure including query result cache, metadata cache,
|
||||
form data cache, and dashboard cache.
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CacheControlMixin(BaseModel):
    """
    Mixin for cache control parameters that can be added to any request schema.

    Leverages Superset's existing cache layers:
    - Query Result Cache: Caches actual query results from customer databases
    - Metadata Cache: Caches table schemas, column info, etc.
    - Form Data Cache: Caches chart configurations for explore URLs
    - Dashboard Cache: Caches rendered dashboard components
    """

    # Serve from cache when available; False bypasses all cache layers.
    use_cache: bool = Field(
        default=True,
        description=(
            "Whether to use Superset's cache layers. When True, will serve from "
            "cache if available (query results, metadata, form data). When False, "
            "will bypass cache and fetch fresh data."
        ),
    )

    # Invalidate then repopulate; takes precedence over use_cache=True.
    force_refresh: bool = Field(
        default=False,
        description=(
            "Whether to force refresh cached data. When True, will invalidate "
            "existing cache entries and fetch fresh data, then update the cache. "
            "Overrides use_cache=True if both are specified."
        ),
    )
|
||||
|
||||
|
||||
class QueryCacheControl(CacheControlMixin):
    """
    Cache control specifically for data queries.

    Used by tools that execute SQL queries against customer databases
    like get_chart_data, chart previews, and chart creation.
    """

    # Per-request TTL override, in seconds; None defers to dataset/global
    # settings, 0 disables caching for this query only.
    cache_timeout: Optional[int] = Field(
        default=None,
        description=(
            "Override the default cache timeout in seconds for this query. "
            "If not specified, uses dataset-level or global cache settings. "
            "Set to 0 to disable caching for this specific query."
        ),
    )
|
||||
|
||||
|
||||
class MetadataCacheControl(CacheControlMixin):
    """
    Cache control for metadata operations.

    Used by tools that fetch database metadata like table schemas,
    column information, metrics, and dataset listings.
    """

    # Force a metadata re-fetch after upstream schema changes.
    refresh_metadata: bool = Field(
        default=False,
        description=(
            "Whether to refresh metadata cache for datasets, tables, and columns. "
            "Useful when database schema has changed and you need fresh metadata."
        ),
    )
|
||||
|
||||
|
||||
class FormDataCacheControl(CacheControlMixin):
    """
    Cache control for form data and chart configurations.

    Used by tools that work with chart configurations and explore URLs
    like generate_explore_link and chart preview updates.
    """

    # False produces ephemeral configurations that are never persisted.
    cache_form_data: bool = Field(
        default=True,
        description=(
            "Whether to cache the form data configuration for future use. "
            "When False, generates temporary configurations that are not cached."
        ),
    )
|
||||
|
||||
|
||||
class CacheStatus(BaseModel):
    """
    Information about cache usage in tool responses.

    Provides transparency about whether data was served from cache
    or freshly fetched, helping users understand data freshness.
    """

    # Required field: every response must state hit vs miss.
    cache_hit: bool = Field(
        description="Whether the data was served from cache (True) or "
        "freshly fetched (False)"
    )

    cache_type: Optional[str] = Field(
        default=None,
        description=(
            "Type of cache used: 'query', 'metadata', 'form_data', 'dashboard', "
            "or 'none' if no cache was used"
        ),
    )

    cache_age_seconds: Optional[int] = Field(
        default=None, description="Age of cached data in seconds, if served from cache"
    )

    # Truncation of long keys is the producer's responsibility, not enforced here.
    cache_key: Optional[str] = Field(
        default=None,
        description="Cache key used (for debugging), truncated if too long",
    )

    refreshed: bool = Field(
        default=False, description="Whether cache was refreshed as part of this request"
    )
|
||||
741
superset/mcp_service/schemas/chart_schemas.py
Normal file
741
superset/mcp_service/schemas/chart_schemas.py
Normal file
@@ -0,0 +1,741 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for chart-related responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.schemas.cache_schemas import (
|
||||
CacheStatus,
|
||||
FormDataCacheControl,
|
||||
MetadataCacheControl,
|
||||
QueryCacheControl,
|
||||
)
|
||||
from superset.mcp_service.schemas.system_schemas import (
|
||||
PaginationInfo,
|
||||
TagInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
|
||||
class ChartLike(Protocol):
    """Protocol for chart-like objects with expected attributes.

    Structural type matched by ORM Slice objects (and test doubles); callers
    such as serialize_chart_object read these attributes via getattr with
    defaults, so partial implementations are tolerated at runtime.
    """

    id: int
    slice_name: Optional[str]
    viz_type: Optional[str]
    datasource_name: Optional[str]
    datasource_type: Optional[str]
    url: Optional[str]
    description: Optional[str]
    cache_timeout: Optional[int]
    form_data: Optional[Dict[str, Any]]
    query_context: Optional[Any]
    changed_by: Optional[Any]  # User object
    changed_by_name: Optional[str]
    changed_on: Optional[str | datetime]
    changed_on_humanized: Optional[str]
    created_by: Optional[Any]  # User object
    created_by_name: Optional[str]
    created_on: Optional[str | datetime]
    created_on_humanized: Optional[str]
    uuid: Optional[str]
    tags: Optional[List[Any]]
    owners: Optional[List[Any]]
|
||||
|
||||
|
||||
class ChartInfo(BaseModel):
    """Full chart model with all possible attributes.

    Unlike ChartLike (the input protocol), user fields here are flattened to
    strings; timestamps accept either ISO strings or datetime objects.
    """

    id: int = Field(..., description="Chart ID")
    slice_name: str = Field(..., description="Chart name")
    viz_type: Optional[str] = Field(None, description="Visualization type")
    datasource_name: Optional[str] = Field(None, description="Datasource name")
    datasource_type: Optional[str] = Field(None, description="Datasource type")
    url: Optional[str] = Field(None, description="Chart URL")
    description: Optional[str] = Field(None, description="Chart description")
    cache_timeout: Optional[int] = Field(None, description="Cache timeout")
    form_data: Optional[Dict[str, Any]] = Field(None, description="Chart form data")
    query_context: Optional[Any] = Field(None, description="Query context")
    changed_by: Optional[str] = Field(None, description="Last modifier (username)")
    changed_by_name: Optional[str] = Field(
        None, description="Last modifier (display name)"
    )
    changed_on: Optional[str | datetime] = Field(
        None, description="Last modification timestamp"
    )
    changed_on_humanized: Optional[str] = Field(
        None, description="Humanized modification time"
    )
    # NOTE(review): there is no created_by_name counterpart to
    # changed_by_name — confirm the asymmetry is intentional.
    created_by: Optional[str] = Field(None, description="Chart creator (username)")
    created_on: Optional[str | datetime] = Field(None, description="Creation timestamp")
    created_on_humanized: Optional[str] = Field(
        None, description="Humanized creation time"
    )
    uuid: Optional[str] = Field(None, description="Chart UUID")
    tags: List[TagInfo] = Field(default_factory=list, description="Chart tags")
    owners: List[UserInfo] = Field(default_factory=list, description="Chart owners")
    # NOTE(review): ser_json_timedelta governs timedelta serialization, but
    # this model declares datetime fields, not timedeltas — confirm the
    # intended option.
    model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class GetChartAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_chart_available_filters tool.

    Currently has no parameters but provides consistent API for future extensibility.
    """

    # extra="forbid" rejects unknown keys so future parameter typos fail loudly.
    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )
|
||||
|
||||
|
||||
class ChartAvailableFiltersResponse(BaseModel):
    """Response schema listing the filter operators supported per chart column."""

    # Keyed by column name; value shape is operator metadata (left as Any).
    column_operators: Dict[str, Any] = Field(
        ..., description="Available filter operators and metadata for each column"
    )
|
||||
|
||||
|
||||
class ChartError(BaseModel):
    """Structured error payload returned by chart tools instead of raising."""

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    timestamp: Optional[str | datetime] = Field(None, description="Error timestamp")
    # NOTE(review): ser_json_timedelta targets timedelta fields; timestamp is
    # a datetime — confirm the intended serialization option.
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ChartCapabilities(BaseModel):
    """Describes what the chart can do for LLM understanding."""

    supports_interaction: bool = Field(description="Chart supports user interaction")
    supports_real_time: bool = Field(description="Chart supports live data updates")
    supports_drill_down: bool = Field(
        description="Chart supports drill-down navigation"
    )
    supports_export: bool = Field(description="Chart can be exported to other formats")
    optimal_formats: List[str] = Field(description="Recommended preview formats")
    data_types: List[str] = Field(
        description="Types of data shown (time_series, categorical, etc)"
    )
|
||||
|
||||
|
||||
class ChartSemantics(BaseModel):
    """Semantic information for LLM reasoning.

    All fields are required — producers must supply every element of the
    narrative, even if empty lists/strings.
    """

    primary_insight: str = Field(
        description="Main insight or pattern the chart reveals"
    )
    data_story: str = Field(description="Narrative description of what the data shows")
    recommended_actions: List[str] = Field(
        description="Suggested next steps based on data"
    )
    anomalies: List[str] = Field(description="Notable outliers or unusual patterns")
    statistical_summary: Dict[str, Any] = Field(
        description="Key statistics (mean, median, trends)"
    )
|
||||
|
||||
|
||||
class PerformanceMetadata(BaseModel):
    """Performance information for LLM cost understanding."""

    query_duration_ms: int = Field(description="Query execution time")
    estimated_cost: Optional[str] = Field(None, description="Resource cost estimate")
    cache_status: str = Field(description="Cache hit/miss status")
    optimization_suggestions: List[str] = Field(
        default_factory=list, description="Performance improvement tips"
    )
|
||||
|
||||
|
||||
class AccessibilityMetadata(BaseModel):
    """Accessibility information for inclusive visualization."""

    color_blind_safe: bool = Field(description="Uses colorblind-safe palette")
    alt_text: str = Field(description="Screen reader description")
    high_contrast_available: bool = Field(description="High contrast version available")
|
||||
|
||||
|
||||
class VersionedResponse(BaseModel):
    """Base class for versioned API responses."""

    # NOTE: several response models in this module redeclare these two fields
    # instead of subclassing VersionedResponse.
    schema_version: str = Field("2.0", description="Response schema version")
    api_version: str = Field("v1", description="MCP API version")
|
||||
|
||||
|
||||
class GetChartInfoRequest(BaseModel):
    """Request schema for get_chart_info with support for ID or UUID."""

    # Accepts either the numeric primary key or the chart's UUID string.
    identifier: Annotated[
        int | str,
        Field(description="Chart identifier - can be numeric ID or UUID string"),
    ]
|
||||
|
||||
|
||||
def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
    """Serialize a chart-like ORM object into a :class:`ChartInfo`.

    Returns ``None`` for a falsy input. The ``url`` field is populated with
    the MCP service screenshot URL rather than the chart's native URL.
    Missing attributes default to ``None``.
    """
    if not chart:
        return None

    # Imported lazily, matching the original placement (keeps module import
    # light and avoids import-time coupling with url_utils).
    from superset.mcp_service.url_utils import get_chart_screenshot_url

    def attr(name: str, default: Any = None) -> Any:
        # Small shorthand over getattr for the many optional attributes below.
        return getattr(chart, name, default)

    def user_label(name_field: str, obj_field: str) -> Any:
        # Prefer the denormalized "<x>_name" string; otherwise fall back to
        # str() of the related user object when one is present.
        label = attr(name_field)
        if label:
            return label
        user = attr(obj_field)
        return str(user) if user else None

    pk = attr("id")
    screenshot = get_chart_screenshot_url(pk) if pk else None
    raw_uuid = attr("uuid")
    tag_objs = attr("tags") or []
    owner_objs = attr("owners") or []

    return ChartInfo(
        id=pk,
        slice_name=attr("slice_name"),
        viz_type=attr("viz_type"),
        datasource_name=attr("datasource_name"),
        datasource_type=attr("datasource_type"),
        url=screenshot,
        description=attr("description"),
        cache_timeout=attr("cache_timeout"),
        form_data=attr("form_data"),
        query_context=attr("query_context"),
        changed_by=user_label("changed_by_name", "changed_by"),
        changed_by_name=attr("changed_by_name"),
        changed_on=attr("changed_on"),
        changed_on_humanized=attr("changed_on_humanized"),
        created_by=user_label("created_by_name", "created_by"),
        created_on=attr("created_on"),
        created_on_humanized=attr("created_on_humanized"),
        uuid=str(raw_uuid) if raw_uuid else None,
        tags=[TagInfo.model_validate(t, from_attributes=True) for t in tag_objs],
        owners=[UserInfo.model_validate(o, from_attributes=True) for o in owner_objs],
    )
|
||||
|
||||
|
||||
class GenerateChartResponse(BaseModel):
    """Comprehensive chart creation response with rich metadata."""

    # Core chart information
    chart: Optional[ChartInfo] = Field(None, description="Complete chart metadata")

    # Multiple preview formats available
    previews: Dict[str, ChartPreviewContent] = Field(
        default_factory=dict,
        description="Available preview formats keyed by format type",
    )

    # LLM-friendly capabilities
    capabilities: Optional[ChartCapabilities] = Field(
        None, description="Chart interaction capabilities"
    )
    semantics: Optional[ChartSemantics] = Field(
        None, description="Semantic chart understanding"
    )

    # Navigation and context
    explore_url: Optional[str] = Field(None, description="Edit chart in Superset")
    embed_code: Optional[str] = Field(None, description="HTML embed snippet")
    api_endpoints: Dict[str, str] = Field(
        default_factory=dict, description="Related API endpoints for data/updates"
    )

    # Performance and accessibility
    performance: Optional[PerformanceMetadata] = Field(
        None, description="Performance metrics"
    )
    accessibility: Optional[AccessibilityMetadata] = Field(
        None, description="Accessibility info"
    )

    # Success/error handling
    success: bool = Field(True, description="Whether chart creation succeeded")
    error: Optional[ChartError] = Field(
        None, description="Error details if creation failed"
    )
    warnings: List[str] = Field(default_factory=list, description="Non-fatal warnings")

    # Versioning fields (duplicated from VersionedResponse rather than
    # inherited — this class extends BaseModel directly)
    schema_version: str = Field("2.0", description="Response schema version")
    api_version: str = Field("v1", description="MCP API version")
|
||||
|
||||
|
||||
class ChartFilter(ColumnOperator):
    """
    Filter object for chart listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Closed set of filterable columns; extend the Literal to allow more.
    col: Literal[
        "slice_name",
        "viz_type",
        "datasource_name",
    ] = Field(
        ...,
        description="Column to filter on. See get_chart_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use. See get_chart_available_filters for "
        "allowed values.",
    )
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )
|
||||
|
||||
|
||||
class ChartList(BaseModel):
    """Paginated chart listing response."""

    charts: List[ChartInfo]
    count: int  # items in this page
    total_count: int  # total matching items across all pages
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    columns_requested: Optional[List[str]] = None
    columns_loaded: Optional[List[str]] = None
    filters_applied: List[ChartFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    pagination: Optional[PaginationInfo] = None
    timestamp: Optional[datetime] = None
    # NOTE(review): ser_json_timedelta only affects timedelta serialization
    # and no timedelta field exists here — confirm this config is intentional.
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
# --- Simplified schemas for generate_chart tool ---
|
||||
|
||||
|
||||
# Common pieces
|
||||
class ColumnRef(BaseModel):
    """Reference to a dataset column, optionally aggregated."""

    name: str = Field(..., description="Column name")
    label: Optional[str] = Field(None, description="Display label for the column")
    dtype: Optional[str] = Field(None, description="Data type hint")
    aggregate: Optional[str] = Field(
        None, description="SQL aggregation function (SUM, COUNT, AVG, MIN, MAX, etc.)"
    )
|
||||
|
||||
|
||||
class AxisConfig(BaseModel):
    """Axis display options for XY charts."""

    title: Optional[str] = Field(None, description="Axis title")
    scale: Optional[Literal["linear", "log"]] = Field(
        "linear", description="Axis scale type"
    )
    format: Optional[str] = Field(None, description="Format string (e.g. '$,.2f')")
|
||||
|
||||
|
||||
class LegendConfig(BaseModel):
    """Legend display options for charts."""

    show: bool = Field(True, description="Whether to show legend")
    position: Optional[Literal["top", "bottom", "left", "right"]] = Field(
        "right", description="Legend position"
    )
|
||||
|
||||
|
||||
class FilterConfig(BaseModel):
    """Simple comparison filter used inside chart configurations."""

    column: str = Field(..., description="Column to filter on")
    op: Literal["=", ">", "<", ">=", "<=", "!="] = Field(
        ..., description="Filter operator"
    )
    value: str | int | float | bool = Field(..., description="Filter value")
|
||||
|
||||
|
||||
# Actual chart types
|
||||
class TableChartConfig(BaseModel):
    """Configuration for a table chart."""

    # `chart_type` doubles as the union tag for ChartConfig.
    chart_type: Literal["table"] = Field("table", description="Chart type")
    columns: List[ColumnRef] = Field(..., description="Columns to display")
    filters: Optional[List[FilterConfig]] = Field(None, description="Filters to apply")
    sort_by: Optional[List[str]] = Field(None, description="Columns to sort by")
|
||||
|
||||
|
||||
class XYChartConfig(BaseModel):
    """Configuration for XY-plane charts (line, bar, area, scatter)."""

    # `chart_type` doubles as the union tag for ChartConfig.
    chart_type: Literal["xy"] = Field("xy", description="Chart type")
    x: ColumnRef = Field(..., description="X-axis column")
    y: List[ColumnRef] = Field(..., description="Y-axis columns")
    kind: Literal["line", "bar", "area", "scatter"] = Field(
        "line", description="Chart visualization type"
    )
    group_by: Optional[ColumnRef] = Field(None, description="Column to group by")
    x_axis: Optional[AxisConfig] = Field(None, description="X-axis configuration")
    y_axis: Optional[AxisConfig] = Field(None, description="Y-axis configuration")
    legend: Optional[LegendConfig] = Field(None, description="Legend configuration")
    filters: Optional[List[FilterConfig]] = Field(None, description="Filters to apply")
|
||||
|
||||
|
||||
# Union entry point for chart configuration. Members are distinguished by
# their `chart_type` Literal field; NOTE(review): unlike ChartPreviewContent,
# no explicit Field(discriminator=...) is declared here — confirm pydantic's
# default union resolution is intended.
ChartConfig = TableChartConfig | XYChartConfig
|
||||
|
||||
|
||||
class ListChartsRequest(MetadataCacheControl):
    """Request schema for list_charts with clear, unambiguous types."""

    # Structured, column-based filtering; mutually exclusive with `search`
    # (enforced by validate_search_and_filters below).
    filters: Annotated[
        List[ChartFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "slice_name",
                "viz_type",
                "datasource_name",
                "description",
                "changed_by_name",
                "created_by_name",
                "changed_on",
                "created_on",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    # Free-text search; mutually exclusive with `filters`.
    search: Annotated[
        Optional[str],
        Field(
            default=None,
            description="Text search string to match against chart fields. Cannot be "
            "used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        Optional[str], Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="asc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=100, description="Number of items per page")
    ]

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListChartsRequest":
        """Prevent using both search and filters simultaneously to avoid query
        conflicts."""
        if self.search and self.filters:
            raise ValueError(
                "Cannot use both 'search' and 'filters' parameters simultaneously. "
                "Use either 'search' for text-based searching across multiple fields, "
                "or 'filters' for precise column-based filtering, but not both."
            )
        return self
|
||||
|
||||
|
||||
# The tool input models
|
||||
class GenerateChartRequest(QueryCacheControl):
    """Input for the generate_chart tool."""

    dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="Chart configuration")
    save_chart: bool = Field(
        default=True,
        description="Whether to permanently save the chart in Superset",
    )
    generate_preview: bool = Field(
        default=True,
        description="Whether to generate a preview image",
    )
    # Lambda default_factory avoids a shared mutable default list.
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        default_factory=lambda: ["url"],
        description="List of preview formats to generate",
    )
|
||||
|
||||
|
||||
class GenerateExploreLinkRequest(FormDataCacheControl):
    """Input for generating an explore link from a chart configuration."""

    dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="Chart configuration")
|
||||
|
||||
|
||||
class UpdateChartRequest(QueryCacheControl):
    """Input for updating a saved chart's configuration."""

    identifier: int | str = Field(..., description="Chart identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="New chart configuration")
    chart_name: Optional[str] = Field(
        None,
        description="New chart name (optional, will auto-generate if not provided)",
    )
    generate_preview: bool = Field(
        default=True,
        description="Whether to generate a preview after updating",
    )
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        default_factory=lambda: ["url"],
        description="List of preview formats to generate",
    )
|
||||
|
||||
|
||||
class UpdateChartPreviewRequest(FormDataCacheControl):
    """Input for updating an unsaved (form-data-backed) chart preview."""

    form_data_key: str = Field(..., description="Existing form_data_key to update")
    dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="New chart configuration")
    generate_preview: bool = Field(
        default=True,
        description="Whether to generate a preview after updating",
    )
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        default_factory=lambda: ["url"],
        description="List of preview formats to generate",
    )
|
||||
|
||||
|
||||
class GetChartDataRequest(QueryCacheControl):
    """Request for chart data with cache control."""

    identifier: int | str = Field(description="Chart identifier (ID, UUID)")
    limit: Optional[int] = Field(
        default=100, description="Maximum number of data rows to return"
    )
    format: Literal["json", "csv", "excel"] = Field(
        default="json", description="Data export format"
    )
|
||||
|
||||
|
||||
class DataColumn(BaseModel):
    """Enhanced column metadata with semantic information."""

    name: str = Field(..., description="Column name")
    display_name: str = Field(..., description="Human-readable column name")
    data_type: str = Field(..., description="Inferred data type")
    sample_values: List[Any] = Field(description="Representative sample values")
    null_count: int = Field(description="Number of null values")
    unique_count: int = Field(description="Number of unique values")
    statistics: Optional[Dict[str, Any]] = Field(
        None, description="Column statistics if numeric"
    )
    semantic_type: Optional[str] = Field(
        None, description="Semantic type (currency, percentage, etc)"
    )
|
||||
|
||||
|
||||
class ChartData(BaseModel):
    """Rich chart data response with statistical insights."""

    # Basic information
    chart_id: int
    chart_name: str
    chart_type: str

    # Enhanced data description
    columns: List[DataColumn] = Field(description="Rich column metadata")
    data: List[Dict[str, Any]] = Field(description="Actual data rows")

    # Data insights
    row_count: int = Field(description="Rows returned")
    # NOTE(review): Optional without a default is still *required* in
    # pydantic v2 — confirm callers always supply these two fields.
    total_rows: Optional[int] = Field(description="Total available rows")
    data_freshness: Optional[datetime] = Field(description="When data was last updated")

    # LLM-friendly summaries
    summary: str = Field(description="Human-readable data summary")
    insights: List[str] = Field(description="Key patterns discovered in the data")
    data_quality: Dict[str, Any] = Field(description="Data quality assessment")
    recommended_visualizations: List[str] = Field(
        description="Suggested chart types for this data"
    )

    # Performance and metadata
    performance: PerformanceMetadata = Field(description="Query performance metrics")
    cache_status: Optional[CacheStatus] = Field(
        None, description="Cache usage information"
    )

    # Versioning fields (duplicated from VersionedResponse rather than inherited)
    schema_version: str = Field("2.0", description="Response schema version")
    api_version: str = Field("v1", description="MCP API version")
|
||||
|
||||
|
||||
class GetChartPreviewRequest(QueryCacheControl):
    """Request for chart preview with cache control."""

    identifier: int | str = Field(description="Chart identifier (ID, UUID)")
    # NOTE(review): this Literal omits 'interactive' and 'vega_lite', which
    # the preview_formats fields elsewhere in this module allow — confirm
    # the narrower set is intentional for this endpoint.
    format: Literal["url", "ascii", "table", "base64"] = Field(
        default="url",
        description=(
            "Preview format: 'url' for image URL, 'ascii' for text art, "
            "'table' for data table, 'base64' for embedded image"
        ),
    )
    width: Optional[int] = Field(
        default=800,
        description="Preview image width in pixels (for url/base64 formats)",
    )
    height: Optional[int] = Field(
        default=600,
        description="Preview image height in pixels (for url/base64 formats)",
    )
    ascii_width: Optional[int] = Field(
        default=80, description="ASCII chart width in characters (for ascii format)"
    )
    ascii_height: Optional[int] = Field(
        default=20, description="ASCII chart height in lines (for ascii format)"
    )
|
||||
|
||||
|
||||
# Discriminated union preview formats for type safety
|
||||
class URLPreview(BaseModel):
    """URL-based image preview format."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["url"] = "url"
    preview_url: str = Field(..., description="Direct image URL")
    width: int = Field(..., description="Image width in pixels")
    height: int = Field(..., description="Image height in pixels")
    supports_interaction: bool = Field(
        False, description="Static image, no interaction"
    )
|
||||
|
||||
|
||||
class InteractivePreview(BaseModel):
    """Interactive HTML preview with JavaScript controls."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["interactive"] = "interactive"
    html_content: str = Field(..., description="Embeddable HTML with Plotly/D3")
    preview_url: str = Field(..., description="Iframe-compatible URL")
    width: int = Field(..., description="Viewport width")
    height: int = Field(..., description="Viewport height")
    supports_pan: bool = Field(True, description="Supports pan interaction")
    supports_zoom: bool = Field(True, description="Supports zoom interaction")
    supports_hover: bool = Field(True, description="Supports hover details")
|
||||
|
||||
|
||||
class ASCIIPreview(BaseModel):
    """ASCII art text representation."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["ascii"] = "ascii"
    ascii_content: str = Field(..., description="Unicode art representation")
    width: int = Field(..., description="Character width")
    height: int = Field(..., description="Line height")
    supports_color: bool = Field(False, description="Uses ANSI color codes")
|
||||
|
||||
|
||||
class VegaLitePreview(BaseModel):
    """Vega-Lite grammar of graphics specification."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["vega_lite"] = "vega_lite"
    specification: Dict[str, Any] = Field(..., description="Vega-Lite JSON spec")
    data_url: Optional[str] = Field(None, description="External data URL")
    supports_streaming: bool = Field(False, description="Supports live data updates")
|
||||
|
||||
|
||||
class TablePreview(BaseModel):
    """Tabular data preview format."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["table"] = "table"
    table_data: str = Field(..., description="Formatted table content")
    row_count: int = Field(..., description="Number of rows displayed")
    supports_sorting: bool = Field(False, description="Table supports sorting")
|
||||
|
||||
|
||||
class Base64Preview(BaseModel):
    """Base64 encoded image for embedding."""

    # `type` is the discriminator used by ChartPreviewContent.
    type: Literal["base64"] = "base64"
    base64_image: str = Field(..., description="Base64 encoded PNG data")
    width: int = Field(..., description="Image width in pixels")
    height: int = Field(..., description="Image height in pixels")
    mime_type: str = Field("image/png", description="Image MIME type")
|
||||
|
||||
|
||||
# Discriminated union of preview payloads using the modern `|` syntax; the
# `type` field on each member selects the concrete model during validation.
ChartPreviewContent = Annotated[
    URLPreview
    | InteractivePreview
    | ASCIIPreview
    | VegaLitePreview
    | TablePreview
    | Base64Preview,
    Field(discriminator="type"),
]
|
||||
|
||||
|
||||
class ChartPreview(BaseModel):
    """Enhanced chart preview with discriminated union content."""

    chart_id: int
    chart_name: str
    chart_type: str = Field(description="Type of chart visualization")
    explore_url: str = Field(description="URL to open chart in Superset for editing")

    # Type-safe preview content
    content: ChartPreviewContent = Field(
        description="Preview content in requested format"
    )

    # Rich metadata
    chart_description: str = Field(
        description="Human-readable description of the chart"
    )
    accessibility: AccessibilityMetadata = Field(
        description="Accessibility information"
    )
    performance: PerformanceMetadata = Field(description="Performance metrics")

    # Backward compatibility fields (populated based on content type)
    format: Optional[str] = Field(
        None, description="Format of the preview (url, ascii, table, base64)"
    )
    preview_url: Optional[str] = Field(None, description="Image URL for 'url' format")
    ascii_chart: Optional[str] = Field(
        None, description="ASCII art chart for 'ascii' format"
    )
    table_data: Optional[str] = Field(
        None, description="Formatted table for 'table' format"
    )
    base64_image: Optional[str] = Field(
        None, description="Base64 encoded PNG for 'base64' format"
    )
    width: Optional[int] = Field(
        None, description="Width (pixels for images, characters for ASCII)"
    )
    height: Optional[int] = Field(
        None, description="Height (pixels for images, lines for ASCII)"
    )

    # Versioning fields (duplicated from VersionedResponse rather than inherited)
    schema_version: str = Field("2.0", description="Response schema version")
    api_version: str = Field("v1", description="MCP API version")
|
||||
344
superset/mcp_service/schemas/dashboard_schemas.py
Normal file
344
superset/mcp_service/schemas/dashboard_schemas.py
Normal file
@@ -0,0 +1,344 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for dashboard-related responses
|
||||
|
||||
This module contains Pydantic models for serializing dashboard data
|
||||
in a consistent and type-safe manner.
|
||||
|
||||
Example usage:
|
||||
# For detailed dashboard info
|
||||
dashboard_info = DashboardInfo(
|
||||
id=1,
|
||||
dashboard_title="Sales Dashboard",
|
||||
published=True,
|
||||
owners=[UserInfo(id=1, username="admin")],
|
||||
charts=[ChartInfo(id=1, slice_name="Sales Chart")]
|
||||
)
|
||||
|
||||
# For dashboard list responses
|
||||
dashboard_list = DashboardList(
|
||||
dashboards=[
|
||||
DashboardInfo(
|
||||
id=1,
|
||||
dashboard_title="Sales Dashboard",
|
||||
published=True,
|
||||
tags=[TagInfo(id=1, name="sales")]
|
||||
)
|
||||
],
|
||||
count=1,
|
||||
total_count=1,
|
||||
page=0,
|
||||
page_size=10,
|
||||
total_pages=1,
|
||||
has_next=False,
|
||||
has_previous=False,
|
||||
columns_requested=["id", "dashboard_title"],
|
||||
columns_loaded=["id", "dashboard_title", "published"],
|
||||
filters_applied={"published": True},
|
||||
pagination=PaginationInfo(
|
||||
page=0,
|
||||
page_size=10,
|
||||
total_count=1,
|
||||
total_pages=1,
|
||||
has_next=False,
|
||||
has_previous=False
|
||||
),
|
||||
timestamp=datetime.now(timezone.utc)
|
||||
)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.schemas.cache_schemas import MetadataCacheControl
|
||||
from superset.mcp_service.schemas.chart_schemas import ChartInfo
|
||||
from superset.mcp_service.schemas.system_schemas import (
|
||||
PaginationInfo,
|
||||
RoleInfo,
|
||||
TagInfo,
|
||||
UserInfo,
|
||||
)
|
||||
|
||||
|
||||
class DashboardError(BaseModel):
    """Error response for dashboard operations"""

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    timestamp: Optional[str | datetime] = Field(None, description="Error timestamp")

    # NOTE(review): ser_json_timedelta only affects timedelta serialization
    # and no timedelta field exists here — confirm this config is intentional.
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
def serialize_user_object(user: Any) -> Optional[UserInfo]:
    """Build a ``UserInfo`` schema from a user-like object.

    Falsy input yields ``None``; missing attributes default to ``None``.
    """
    if not user:
        return None

    field_names = ("id", "username", "first_name", "last_name", "email", "active")
    return UserInfo(**{name: getattr(user, name, None) for name in field_names})
|
||||
|
||||
|
||||
def serialize_tag_object(tag: Any) -> Optional[TagInfo]:
    """Build a ``TagInfo`` schema from a tag-like object.

    Falsy input yields ``None``; missing attributes default to ``None``.
    """
    if not tag:
        return None

    field_names = ("id", "name", "type", "description")
    return TagInfo(**{name: getattr(tag, name, None) for name in field_names})
|
||||
|
||||
|
||||
def serialize_role_object(role: Any) -> Optional[RoleInfo]:
    """Build a ``RoleInfo`` schema from a role-like object.

    Falsy input yields ``None``. Permissions are flattened to their names;
    ``None`` means the object exposed no ``permissions`` attribute at all,
    which is distinct from an empty permission list.
    """
    if not role:
        return None

    if hasattr(role, "permissions"):
        permission_names = [perm.name for perm in role.permissions]
    else:
        permission_names = None

    return RoleInfo(
        id=getattr(role, "id", None),
        name=getattr(role, "name", None),
        permissions=permission_names,
    )
|
||||
|
||||
|
||||
def serialize_chart_object(chart: Any) -> Optional[ChartInfo]:
    """Build a ``ChartInfo`` schema from a chart-like object.

    Falsy input yields ``None``; missing attributes default to ``None``.
    """
    if not chart:
        return None

    simple_fields = (
        "id",
        "slice_name",
        "viz_type",
        "datasource_name",
        "datasource_type",
        "url",
        "description",
        "cache_timeout",
        "form_data",
        "query_context",
        "created_on",
        "changed_on",
    )
    payload = {name: getattr(chart, name, None) for name in simple_fields}
    # NOTE(review): user relations are serialized as UserInfo objects here,
    # whereas chart_schemas.serialize_chart_object passes plain strings for
    # the same ChartInfo fields — confirm ChartInfo accepts both forms.
    payload["created_by"] = serialize_user_object(getattr(chart, "created_by", None))
    payload["changed_by"] = serialize_user_object(getattr(chart, "changed_by", None))
    return ChartInfo(**payload)
|
||||
|
||||
|
||||
class DashboardAvailableFilters(BaseModel):
    """Mapping of filterable dashboard columns to their supported operators."""

    column_operators: Dict[str, Any] = Field(
        ..., description="Available filter operators and metadata for each column"
    )
|
||||
|
||||
|
||||
class DashboardFilter(ColumnOperator):
    """
    Filter object for dashboard listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Closed set of filterable columns; extend the Literal to allow more.
    col: Literal[
        "dashboard_title",
        "published",
        "favorite",
    ] = Field(
        ...,
        description="Column to filter on. See get_dashboard_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use. See get_dashboard_available_filters for "
        "allowed values.",
    )
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )
|
||||
|
||||
|
||||
class ListDashboardsRequest(MetadataCacheControl):
    """Request schema for list_dashboards with clear, unambiguous types."""

    # Structured, column-based filtering; mutually exclusive with `search`
    # (enforced by validate_search_and_filters below).
    filters: Annotated[
        List[DashboardFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' properties. "
            "Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "dashboard_title",
                "slug",
                "published",
                "changed_on",
                "created_on",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns "
            "if not specified.",
        ),
    ]
    # Free-text search; mutually exclusive with `filters`.
    search: Annotated[
        Optional[str],
        Field(
            default=None,
            description="Text search string to match against dashboard fields. "
            "Cannot be used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        Optional[str], Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="asc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=100, description="Number of items per page")
    ]

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListDashboardsRequest":
        """Prevent using both search and filters simultaneously to avoid query
        conflicts."""
        if self.search and self.filters:
            raise ValueError(
                "Cannot use both 'search' and 'filters' parameters simultaneously. "
                "Use either 'search' for text-based searching across multiple fields, "
                "or 'filters' for precise column-based filtering, but not both."
            )
        return self
|
||||
|
||||
|
||||
class GetDashboardInfoRequest(MetadataCacheControl):
    """Request schema for get_dashboard_info with support for ID, UUID, or slug."""

    # Accepts the numeric primary key, the UUID string, or the URL slug.
    identifier: Annotated[
        int | str,
        Field(
            description="Dashboard identifier - can be numeric ID, UUID string, or slug"
        ),
    ]
|
||||
|
||||
|
||||
class DashboardInfo(BaseModel):
    """Serialized view of a single Superset dashboard as returned by the MCP
    dashboard tools.

    Most fields are optional because list endpoints may load only a subset of
    columns (see ``select_columns`` on the list request schema).
    """

    id: int = Field(..., description="Dashboard ID")
    dashboard_title: str = Field(..., description="Dashboard title")
    slug: Optional[str] = Field(None, description="Dashboard slug")
    description: Optional[str] = Field(None, description="Dashboard description")
    css: Optional[str] = Field(None, description="Custom CSS for the dashboard")
    certified_by: Optional[str] = Field(None, description="Who certified the dashboard")
    certification_details: Optional[str] = Field(
        None, description="Certification details"
    )
    # Raw JSON strings as stored on the dashboard model; not parsed here.
    json_metadata: Optional[str] = Field(
        None, description="Dashboard metadata (JSON string)"
    )
    position_json: Optional[str] = Field(
        None, description="Chart positions (JSON string)"
    )
    published: Optional[bool] = Field(
        None, description="Whether the dashboard is published"
    )
    is_managed_externally: Optional[bool] = Field(
        None, description="Whether managed externally"
    )
    external_url: Optional[str] = Field(None, description="External URL")
    # Timestamps may arrive either pre-formatted (str) or as datetime objects.
    created_on: Optional[str | datetime] = Field(None, description="Creation timestamp")
    changed_on: Optional[str | datetime] = Field(
        None, description="Last modification timestamp"
    )
    created_by: Optional[str] = Field(None, description="Dashboard creator (username)")
    changed_by: Optional[str] = Field(None, description="Last modifier (username)")
    uuid: Optional[str] = Field(
        None, description="Dashboard UUID (converted to string)"
    )
    url: Optional[str] = Field(None, description="Dashboard URL")
    thumbnail_url: Optional[str] = Field(None, description="Thumbnail URL")
    created_on_humanized: Optional[str] = Field(
        None, description="Humanized creation time"
    )
    changed_on_humanized: Optional[str] = Field(
        None, description="Humanized modification time"
    )
    chart_count: int = Field(0, description="Number of charts in the dashboard")
    owners: List[UserInfo] = Field(default_factory=list, description="Dashboard owners")
    tags: List[TagInfo] = Field(default_factory=list, description="Dashboard tags")
    roles: List[RoleInfo] = Field(default_factory=list, description="Dashboard roles")
    charts: List[ChartInfo] = Field(
        default_factory=list, description="Dashboard charts"
    )
    # from_attributes=True allows validating directly from ORM objects.
    model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class DashboardList(BaseModel):
    """Paginated envelope for dashboard listing results."""

    dashboards: List[DashboardInfo]
    # NOTE(review): ``count`` appears to be the number of items in this page and
    # ``total_count`` the number across all pages — confirm against the tool
    # implementation that populates this model.
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    # Echo of the column projection: what the caller asked for vs. what was
    # actually loaded.
    columns_requested: Optional[List[str]] = None
    columns_loaded: Optional[List[str]] = None
    filters_applied: List[DashboardFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    pagination: Optional[PaginationInfo] = None
    timestamp: Optional[datetime] = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class GetDashboardAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_dashboard_available_filters tool.

    Takes no parameters; it exists so the tool has a typed request object
    that can gain fields later without an interface break.
    """

    # extra="forbid" rejects unknown keys so typos surface as errors.
    model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
|
||||
338
superset/mcp_service/schemas/dataset_schemas.py
Normal file
338
superset/mcp_service/schemas/dataset_schemas.py
Normal file
@@ -0,0 +1,338 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for dataset-related responses
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt
|
||||
|
||||
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
||||
from superset.mcp_service.schemas.cache_schemas import MetadataCacheControl
|
||||
from superset.mcp_service.schemas.system_schemas import (
|
||||
PaginationInfo,
|
||||
TagInfo,
|
||||
UserInfo,
|
||||
)
|
||||
from superset.utils import json
|
||||
|
||||
|
||||
class GetDatasetAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_dataset_available_filters tool.

    Currently has no parameters but provides consistent API for future extensibility.
    """

    # extra="forbid" rejects unknown keys so typos surface as errors.
    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )
|
||||
|
||||
|
||||
class DatasetAvailableFilters(BaseModel):
    """Response schema listing which operators each dataset column supports."""

    column_operators: Dict[str, List[str]] = Field(
        ...,
        description="Available filter operators for each column: mapping from column "
        "name to list of supported operators",
    )
|
||||
|
||||
|
||||
class DatasetFilter(ColumnOperator):
    """
    Filter object for dataset listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Literal narrows the base ColumnOperator's column to the filterable
    # subset exposed for datasets.
    col: Literal[
        "table_name",
        "schema",
        "owner",
        "favorite",
    ] = Field(
        ...,
        description="Column to filter on. See get_dataset_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use. See get_dataset_available_filters for "
        "allowed values.",
    )
    # Scalars for single-value operators; lists for set operators such as IN.
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )
|
||||
|
||||
|
||||
class TableColumnInfo(BaseModel):
    """Serialized physical/virtual column of a dataset."""

    column_name: str = Field(..., description="Column name")
    verbose_name: Optional[str] = Field(None, description="Verbose name")
    type: Optional[str] = Field(None, description="Column type")
    is_dttm: Optional[bool] = Field(None, description="Is datetime column")
    groupby: Optional[bool] = Field(None, description="Is groupable")
    filterable: Optional[bool] = Field(None, description="Is filterable")
    description: Optional[str] = Field(None, description="Column description")
|
||||
|
||||
|
||||
class SqlMetricInfo(BaseModel):
    """Serialized SQL metric defined on a dataset."""

    metric_name: str = Field(..., description="Metric name")
    verbose_name: Optional[str] = Field(None, description="Verbose name")
    expression: Optional[str] = Field(None, description="SQL expression")
    description: Optional[str] = Field(None, description="Metric description")
    d3format: Optional[str] = Field(None, description="D3 format string")
|
||||
|
||||
|
||||
class DatasetInfo(BaseModel):
    """Serialized view of a Superset dataset as returned by the MCP dataset
    tools.

    All fields are optional because list endpoints may load only a subset of
    columns (see ``select_columns`` on ``ListDatasetsRequest``).
    """

    id: Optional[int] = Field(None, description="Dataset ID")
    table_name: Optional[str] = Field(None, description="Table name")
    # NOTE(review): field name ``schema`` shadows pydantic's own BaseModel
    # ``schema`` attribute namespace — confirm no deprecation warning is
    # emitted by the pydantic version in use.
    schema: Optional[str] = Field(None, description="Schema name")
    database_name: Optional[str] = Field(None, description="Database name")
    description: Optional[str] = Field(None, description="Dataset description")
    changed_by: Optional[str] = Field(None, description="Last modifier (username)")
    # Timestamps may arrive either pre-formatted (str) or as datetime objects.
    changed_on: Optional[str | datetime] = Field(
        None, description="Last modification timestamp"
    )
    changed_on_humanized: Optional[str] = Field(
        None, description="Humanized modification time"
    )
    created_by: Optional[str] = Field(None, description="Dataset creator (username)")
    created_on: Optional[str | datetime] = Field(None, description="Creation timestamp")
    created_on_humanized: Optional[str] = Field(
        None, description="Humanized creation time"
    )
    tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags")
    owners: List[UserInfo] = Field(
        default_factory=list, description="DatasetInfo owners"
    )
    is_virtual: Optional[bool] = Field(
        None, description="Whether the dataset is virtual (uses SQL)"
    )
    database_id: Optional[int] = Field(None, description="Database ID")
    uuid: Optional[str] = Field(None, description="Dataset UUID")
    schema_perm: Optional[str] = Field(None, description="Schema permission string")
    url: Optional[str] = Field(None, description="Dataset URL")
    sql: Optional[str] = Field(None, description="SQL for virtual datasets")
    main_dttm_col: Optional[str] = Field(None, description="Main datetime column")
    offset: Optional[int] = Field(None, description="Offset")
    cache_timeout: Optional[int] = Field(None, description="Cache timeout")
    params: Optional[Dict[str, Any]] = Field(None, description="Extra params")
    template_params: Optional[Dict[str, Any]] = Field(
        None, description="Template params"
    )
    extra: Optional[Dict[str, Any]] = Field(None, description="Extra metadata")
    columns: List[TableColumnInfo] = Field(
        default_factory=list, description="Columns in the dataset"
    )
    metrics: List[SqlMetricInfo] = Field(
        default_factory=list, description="Metrics in the dataset"
    )
    is_favorite: Optional[bool] = Field(
        None, description="Whether this dataset is favorited by the current user"
    )
    # from_attributes=True allows validating directly from ORM objects.
    model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class DatasetList(BaseModel):
    """Paginated envelope for dataset listing results."""

    datasets: List[DatasetInfo]
    # NOTE(review): ``count`` appears to be the number of items in this page and
    # ``total_count`` the number across all pages — confirm against the tool
    # implementation that populates this model.
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    # Echo of the column projection: what the caller asked for vs. what was
    # actually loaded.
    columns_requested: Optional[List[str]] = None
    columns_loaded: Optional[List[str]] = None
    filters_applied: List[DatasetFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    pagination: Optional[PaginationInfo] = None
    timestamp: Optional[datetime] = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class ListDatasetsRequest(MetadataCacheControl):
    """Request schema for list_datasets with clear, unambiguous types."""

    # Structured column filters; mutually exclusive with 'search' (enforced by
    # validate_search_and_filters below).
    filters: Annotated[
        List[DatasetFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    # Column projection; the default covers the common listing columns.
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "table_name",
                "schema",
                "database_name",
                "changed_by_name",
                "changed_on",
                "created_by_name",
                "created_on",
                "metrics",
                "columns",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        Optional[str],
        Field(
            default=None,
            description="Text search string to match against dataset fields. Cannot "
            "be used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        Optional[str], Field(default=None, description="Column to order results by")
    ]
    # NOTE(review): default direction is "desc" here but "asc" for the
    # dashboards listing — confirm the asymmetry is intentional.
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="desc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    # 1-based pagination; PositiveInt rejects zero and negative values.
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=100, description="Number of items per page")
    ]

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListDatasetsRequest":
        """Prevent using both search and filters simultaneously to avoid query
        conflicts."""
        if self.search and self.filters:
            raise ValueError(
                "Cannot use both 'search' and 'filters' parameters simultaneously. "
                "Use either 'search' for text-based searching across multiple fields, "
                "or 'filters' for precise column-based filtering, but not both."
            )
        return self
|
||||
|
||||
|
||||
class DatasetError(BaseModel):
    """Error payload returned by dataset tools on failure."""

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    timestamp: Optional[str | datetime] = Field(None, description="Error timestamp")
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
|
||||
|
||||
class GetDatasetInfoRequest(MetadataCacheControl):
    """Request schema for get_dataset_info with support for ID or UUID."""

    # Polymorphic lookup key: an int is treated as the numeric primary key,
    # a string as a UUID (resolution happens in the tool).
    identifier: Annotated[
        int | str,
        Field(description="Dataset identifier - can be numeric ID or UUID string"),
    ]
|
||||
|
||||
|
||||
def serialize_dataset_object(dataset: Any) -> Optional[DatasetInfo]:
    """Convert a dataset ORM object into a ``DatasetInfo`` schema.

    Every attribute is read with ``getattr(..., default)`` so partially loaded
    ORM rows (e.g. from a column-projected query) can still be serialized.

    Args:
        dataset: Dataset ORM object, or a falsy value.

    Returns:
        A populated ``DatasetInfo``, or ``None`` when ``dataset`` is falsy.
    """
    if not dataset:
        return None

    # ``params`` may be stored as a JSON string on the model; decode it
    # best-effort and fall back to None on malformed content rather than
    # failing the whole serialization. Catch only decoding-related errors
    # (json decode errors are ValueError subclasses) instead of a bare
    # ``except Exception`` that would also hide programming errors.
    params = getattr(dataset, "params", None)
    if isinstance(params, str):
        try:
            params = json.loads(params)
        except (TypeError, ValueError):
            params = None

    columns = [
        TableColumnInfo(
            column_name=getattr(col, "column_name", None),
            verbose_name=getattr(col, "verbose_name", None),
            type=getattr(col, "type", None),
            is_dttm=getattr(col, "is_dttm", None),
            groupby=getattr(col, "groupby", None),
            filterable=getattr(col, "filterable", None),
            description=getattr(col, "description", None),
        )
        for col in getattr(dataset, "columns", [])
    ]
    metrics = [
        SqlMetricInfo(
            metric_name=getattr(metric, "metric_name", None),
            verbose_name=getattr(metric, "verbose_name", None),
            expression=getattr(metric, "expression", None),
            description=getattr(metric, "description", None),
            d3format=getattr(metric, "d3format", None),
        )
        for metric in getattr(dataset, "metrics", [])
    ]
    return DatasetInfo(
        id=getattr(dataset, "id", None),
        table_name=getattr(dataset, "table_name", None),
        schema=getattr(dataset, "schema", None),
        database_name=getattr(dataset.database, "database_name", None)
        if getattr(dataset, "database", None)
        else None,
        description=getattr(dataset, "description", None),
        # Prefer the denormalized *_name attribute; fall back to stringifying
        # the related user object when only the relationship is loaded.
        changed_by=getattr(dataset, "changed_by_name", None)
        or (str(dataset.changed_by) if getattr(dataset, "changed_by", None) else None),
        changed_on=getattr(dataset, "changed_on", None),
        changed_on_humanized=getattr(dataset, "changed_on_humanized", None),
        created_by=getattr(dataset, "created_by_name", None)
        or (str(dataset.created_by) if getattr(dataset, "created_by", None) else None),
        created_on=getattr(dataset, "created_on", None),
        created_on_humanized=getattr(dataset, "created_on_humanized", None),
        tags=[
            TagInfo.model_validate(tag, from_attributes=True)
            for tag in getattr(dataset, "tags", [])
        ]
        if getattr(dataset, "tags", None)
        else [],
        owners=[
            UserInfo.model_validate(owner, from_attributes=True)
            for owner in getattr(dataset, "owners", [])
        ]
        if getattr(dataset, "owners", None)
        else [],
        is_virtual=getattr(dataset, "is_virtual", None),
        database_id=getattr(dataset, "database_id", None),
        # UUID objects are not JSON-serializable as-is; normalize to str.
        uuid=str(getattr(dataset, "uuid", ""))
        if getattr(dataset, "uuid", None)
        else None,
        schema_perm=getattr(dataset, "schema_perm", None),
        url=getattr(dataset, "url", None),
        sql=getattr(dataset, "sql", None),
        main_dttm_col=getattr(dataset, "main_dttm_col", None),
        offset=getattr(dataset, "offset", None),
        cache_timeout=getattr(dataset, "cache_timeout", None),
        params=params,
        template_params=getattr(dataset, "template_params", None),
        extra=getattr(dataset, "extra", None),
        columns=columns,
        metrics=metrics,
        is_favorite=getattr(dataset, "is_favorite", None),
    )
|
||||
105
superset/mcp_service/schemas/error_schemas.py
Normal file
105
superset/mcp_service/schemas/error_schemas.py
Normal file
@@ -0,0 +1,105 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Enhanced error schemas for MCP chart generation with contextual information
|
||||
"""
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ColumnSuggestion(BaseModel):
    """Suggested column with context"""

    name: str = Field(..., description="Column name")
    type: str = Field(..., description="Column data type")
    # Fuzzy-match score used to rank did-you-mean suggestions.
    similarity_score: float = Field(..., description="Similarity score (0-1)")
    description: Optional[str] = Field(None, description="Column description")
|
||||
|
||||
|
||||
# NOTE(review): this class name shadows pydantic's own ValidationError —
# callers importing both must alias one of them.
class ValidationError(BaseModel):
    """Individual validation error with context"""

    field: str = Field(..., description="Field that failed validation")
    provided_value: Any = Field(..., description="Value that was provided")
    error_type: str = Field(..., description="Type of validation error")
    message: str = Field(..., description="Human-readable error message")
    suggestions: List[ColumnSuggestion] = Field(
        default_factory=list, description="Suggested alternatives"
    )
|
||||
|
||||
|
||||
class DatasetContext(BaseModel):
    """Dataset information for error context"""

    id: int = Field(..., description="Dataset ID")
    table_name: str = Field(..., description="Table name")
    schema: Optional[str] = Field(None, description="Schema name")
    database_name: str = Field(..., description="Database name")
    # Loosely-typed dicts so any column/metric metadata can be attached
    # without a schema change.
    available_columns: List[Dict[str, Any]] = Field(
        default_factory=list, description="Available columns with metadata"
    )
    available_metrics: List[Dict[str, Any]] = Field(
        default_factory=list, description="Available metrics with metadata"
    )
|
||||
|
||||
|
||||
class ChartGenerationError(BaseModel):
    """Enhanced error response for chart generation failures"""

    error_type: str = Field(
        ..., description="Type of error (validation, execution, etc.)"
    )
    message: str = Field(..., description="High-level error message")
    details: str = Field(..., description="Detailed error explanation")
    validation_errors: List[ValidationError] = Field(
        default_factory=list, description="Specific field validation errors"
    )
    dataset_context: Optional[DatasetContext] = Field(
        None, description="Dataset information for context"
    )
    query_info: Optional[Dict[str, Any]] = Field(
        None, description="Query execution details"
    )
    # Human-readable remediation hints surfaced directly to the caller.
    suggestions: List[str] = Field(
        default_factory=list, description="Actionable suggestions to fix the error"
    )
    help_url: Optional[str] = Field(
        None, description="URL to documentation for this error type"
    )
    error_code: Optional[str] = Field(
        None, description="Unique error code for support reference"
    )
|
||||
|
||||
|
||||
class ChartGenerationResponse(BaseModel):
    """Enhanced chart generation response with detailed error handling"""

    # Exactly one of ``chart`` / ``error`` is expected to be populated,
    # keyed off ``success`` (not enforced by a validator here).
    success: bool = Field(..., description="Whether chart generation succeeded")
    chart: Optional[Dict[str, Any]] = Field(
        None, description="Chart information if successful"
    )
    error: Optional[ChartGenerationError] = Field(
        None, description="Error details if failed"
    )
    performance: Optional[Dict[str, Any]] = Field(
        None, description="Performance metadata"
    )
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="API version")
|
||||
143
superset/mcp_service/schemas/system_schemas.py
Normal file
143
superset/mcp_service/schemas/system_schemas.py
Normal file
@@ -0,0 +1,143 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Pydantic schemas for system-level (instance/info) responses
|
||||
|
||||
This module contains Pydantic models for serializing Superset instance metadata and
|
||||
system-level info.
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class GetSupersetInstanceInfoRequest(BaseModel):
    """
    Request schema for get_superset_instance_info tool.

    Currently has no parameters but provides consistent API for future extensibility.
    """

    # extra="forbid" rejects unknown keys so typos surface as errors.
    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )
|
||||
|
||||
|
||||
class InstanceSummary(BaseModel):
    """Aggregate object counts for a Superset instance."""

    total_dashboards: int = Field(..., description="Total number of dashboards")
    total_charts: int = Field(..., description="Total number of charts")
    total_datasets: int = Field(..., description="Total number of datasets")
    total_databases: int = Field(..., description="Total number of databases")
    total_users: int = Field(..., description="Total number of users")
    total_roles: int = Field(..., description="Total number of roles")
    total_tags: int = Field(..., description="Total number of tags")
    avg_charts_per_dashboard: float = Field(
        ..., description="Average number of charts per dashboard"
    )
|
||||
|
||||
|
||||
class RecentActivity(BaseModel):
    """Creation/modification counts over recent time windows."""

    dashboards_created_last_30_days: int = Field(
        ..., description="Dashboards created in the last 30 days"
    )
    charts_created_last_30_days: int = Field(
        ..., description="Charts created in the last 30 days"
    )
    datasets_created_last_30_days: int = Field(
        ..., description="Datasets created in the last 30 days"
    )
    dashboards_modified_last_7_days: int = Field(
        ..., description="Dashboards modified in the last 7 days"
    )
    charts_modified_last_7_days: int = Field(
        ..., description="Charts modified in the last 7 days"
    )
    datasets_modified_last_7_days: int = Field(
        ..., description="Datasets modified in the last 7 days"
    )
|
||||
|
||||
|
||||
class DashboardBreakdown(BaseModel):
    """Dashboard counts broken down by publication/certification status."""

    published: int = Field(..., description="Number of published dashboards")
    unpublished: int = Field(..., description="Number of unpublished dashboards")
    certified: int = Field(..., description="Number of certified dashboards")
    with_charts: int = Field(..., description="Number of dashboards with charts")
    without_charts: int = Field(..., description="Number of dashboards without charts")
|
||||
|
||||
|
||||
class DatabaseBreakdown(BaseModel):
    """Database counts keyed by engine/backend type."""

    by_type: Dict[str, int] = Field(..., description="Breakdown of databases by type")
|
||||
|
||||
|
||||
class PopularContent(BaseModel):
    """Most-used tags and most-active creators on the instance."""

    top_tags: List[str] = Field(..., description="Most popular tags")
    top_creators: List[str] = Field(..., description="Most active creators")
|
||||
|
||||
|
||||
class InstanceInfo(BaseModel):
    """Top-level response for get_superset_instance_info, composing all the
    breakdown sections."""

    instance_summary: InstanceSummary = Field(
        ..., description="Instance summary information"
    )
    recent_activity: RecentActivity = Field(
        ..., description="Recent activity information"
    )
    dashboard_breakdown: DashboardBreakdown = Field(
        ..., description="Dashboard breakdown information"
    )
    database_breakdown: DatabaseBreakdown = Field(
        ..., description="Database breakdown by type"
    )
    popular_content: PopularContent = Field(
        ..., description="Popular content information"
    )
    timestamp: datetime = Field(..., description="Response timestamp")
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
    """Lightweight user reference embedded in other schemas (owners, creators)."""

    id: Optional[int] = None
    username: Optional[str] = None
    first_name: Optional[str] = None
    last_name: Optional[str] = None
    email: Optional[str] = None
    active: Optional[bool] = None
|
||||
|
||||
|
||||
class TagInfo(BaseModel):
    """Lightweight tag reference embedded in other schemas."""

    id: Optional[int] = None
    name: Optional[str] = None
    type: Optional[str] = None
    description: Optional[str] = None
|
||||
|
||||
|
||||
class RoleInfo(BaseModel):
    """Lightweight role reference embedded in other schemas."""

    id: Optional[int] = None
    name: Optional[str] = None
    permissions: Optional[List[str]] = None
|
||||
|
||||
|
||||
class PaginationInfo(BaseModel):
    """Pagination metadata shared by all paginated list responses."""

    page: int = Field(..., description="Current page number")
    page_size: int = Field(..., description="Number of items per page")
    total_count: int = Field(..., description="Total number of items")
    total_pages: int = Field(..., description="Total number of pages")
    has_next: bool = Field(..., description="Whether there is a next page")
    has_previous: bool = Field(..., description="Whether there is a previous page")
    model_config = ConfigDict(ser_json_timedelta="iso8601")
|
||||
67
superset/mcp_service/server.py
Normal file
67
superset/mcp_service/server.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP server for Apache Superset
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from superset.mcp_service.mcp_app import init_fastmcp_server, mcp
|
||||
|
||||
|
||||
def configure_logging(debug: bool = False) -> None:
    """Enable verbose SQL logging for the MCP service when requested.

    Logging is only configured when *debug* is true or the SQLALCHEMY_DEBUG
    environment variable is set; otherwise the call is a no-op.
    """
    sql_debug = debug or os.environ.get("SQLALCHEMY_DEBUG")
    if not sql_debug:
        return
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    )
    sqlalchemy_loggers = (
        "sqlalchemy.engine",
        "sqlalchemy.pool",
        "sqlalchemy.dialects",
    )
    for name in sqlalchemy_loggers:
        logging.getLogger(name).setLevel(logging.INFO)
    print("🔍 SQL Debug logging enabled")
|
||||
|
||||
|
||||
def run_server(host: str = "127.0.0.1", port: int = 5008, debug: bool = False) -> None:
    """
    Run the MCP service server with FastMCP endpoints.
    Only supports HTTP (streamable-http) transport.
    """
    configure_logging(debug)
    print("Creating MCP app...")
    # Registers middleware and tools on the shared FastMCP instance.
    init_fastmcp_server()

    # Per-port environment marker guarding against a double start within the
    # same process environment.
    env_key = f"FASTMCP_RUNNING_{port}"
    if os.environ.get(env_key):
        print(f"FastMCP already running on {host}:{port}")
        return
    os.environ[env_key] = "1"
    try:
        print(f"Starting FastMCP on {host}:{port}")
        mcp.run(transport="streamable-http", host=host, port=port)
    except Exception as e:
        print(f"FastMCP failed: {e}")
        # Clear the marker so a retry in this process is possible.
        os.environ.pop(env_key, None)
|
||||
|
||||
|
||||
# Allow running the server directly as a script (defaults: 127.0.0.1:5008).
if __name__ == "__main__":
    run_server()
|
||||
74
superset/mcp_service/simple_proxy.py
Executable file
74
superset/mcp_service/simple_proxy.py
Executable file
@@ -0,0 +1,74 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple MCP proxy server that connects to FastMCP server on localhost:5009
|
||||
"""
|
||||
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
import time
|
||||
from typing import Any, Optional
|
||||
|
||||
from fastmcp import FastMCP
|
||||
|
||||
# Configure logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Global proxy instance for cleanup; populated by main() and inspected by
# signal_handler() during shutdown.
proxy: Optional[FastMCP] = None
|
||||
|
||||
|
||||
def signal_handler(signum: int, frame: Any) -> None:
    """Log the received shutdown signal, give an active proxy a brief moment
    to settle, and terminate the process with exit status 0."""
    logger.info(f"Received signal {signum}, shutting down gracefully...")
    proxy_active = bool(proxy)
    if proxy_active:
        try:
            # Short pause so in-flight proxy work can wind down before exit.
            time.sleep(0.1)
        except Exception as e:
            logger.warning(f"Error during proxy cleanup: {e}")
    sys.exit(0)
|
||||
|
||||
|
||||
def main() -> None:
    """Run the stdio MCP proxy, forwarding to the local FastMCP HTTP server.

    Acts as the top-level process boundary: broad exception handling here is
    deliberate, mapping failure modes to exit codes (0 on interrupt, 1 on
    import or unexpected errors).
    """
    global proxy

    try:
        # NOTE(review): FastMCP is already imported at module level, so this
        # local import (and the ImportError handler below) appears redundant —
        # a missing fastmcp would fail at module import time. Confirm intent.
        from fastmcp import FastMCP

        # Set up signal handlers for graceful shutdown
        signal.signal(signal.SIGINT, signal_handler)
        signal.signal(signal.SIGTERM, signal_handler)

        logger.info("Starting MCP proxy server...")

        # Create a proxy to the remote FastMCP server
        proxy = FastMCP.as_proxy(
            "http://localhost:5008/mcp/", name="Superset MCP Proxy"
        )

        logger.info("Proxy created successfully, starting...")

        # Run the proxy (this will block until interrupted)
        proxy.run()

    except KeyboardInterrupt:
        logger.info("Received keyboard interrupt, shutting down...")
        sys.exit(0)
    except ImportError as e:
        logger.error(f"Failed to import FastMCP: {e}")
        logger.error("Please install fastmcp: pip install fastmcp")
        sys.exit(1)
    except Exception as e:
        logger.error(f"Unexpected error: {e}")
        sys.exit(1)
    finally:
        logger.info("Proxy server stopped")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
16
superset/mcp_service/sql_lab/__init__.py
Normal file
16
superset/mcp_service/sql_lab/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
22
superset/mcp_service/sql_lab/tool/__init__.py
Normal file
22
superset/mcp_service/sql_lab/tool/__init__.py
Normal file
@@ -0,0 +1,22 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from .open_sql_lab_with_context import open_sql_lab_with_context
|
||||
|
||||
__all__ = [
|
||||
"open_sql_lab_with_context",
|
||||
]
|
||||
154
superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py
Normal file
154
superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
MCP tool: open_sql_lab_with_context
|
||||
|
||||
This tool generates a URL to open SQL Lab with pre-populated context including
|
||||
database connection, schema, dataset context, and SQL query.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class OpenSqlLabRequest(BaseModel):
    """Request schema for opening SQL Lab with context.

    NOTE(review): ``schema_name`` is exposed to clients under the alias
    ``schema`` — presumably because ``schema`` collides with a ``BaseModel``
    attribute; confirm against the pydantic version in use.
    """

    # Required: which database connection SQL Lab should use.
    database_connection_id: int = Field(
        ..., description="Database connection ID to use in SQL Lab"
    )
    # Optional default schema; (de)serialized as "schema" via the alias.
    schema_name: Optional[str] = Field(
        None, description="Default schema to select in SQL Lab", alias="schema"
    )
    # Optional table name used to seed a starter query when no SQL is given.
    dataset_in_context: Optional[str] = Field(
        None, description="Dataset name/table to provide as context"
    )
    sql: Optional[str] = Field(
        None, description="SQL query to pre-populate in the editor"
    )
    title: Optional[str] = Field(None, description="Title for the SQL Lab tab/query")
|
||||
|
||||
|
||||
class SqlLabResponse(BaseModel):
    """Response schema for SQL Lab URL generation.

    NOTE(review): ``schema_name`` carries the alias ``schema``, yet callers in
    this module construct the model with the field name (``schema_name=...``).
    Under pydantic v2 defaults (``populate_by_name`` disabled) a field-name
    keyword may be ignored in favor of the alias, silently dropping the value —
    confirm the model config / pydantic version.
    """

    url: str = Field(..., description="URL to open SQL Lab with context")
    database_id: int = Field(..., description="Database ID used")
    schema_name: Optional[str] = Field(
        None, description="Schema selected", alias="schema"
    )
    title: Optional[str] = Field(None, description="Query title")
    error: Optional[str] = Field(None, description="Error message if failed")
|
||||
|
||||
|
||||
def _build_context_sql(
    dataset: str, database_name: str, schema_name: Optional[str]
) -> str:
    """Build a commented starter query for the given dataset context.

    Returns SQL of the form ``-- Context ...`` comments followed by a
    ``SELECT * FROM <table> LIMIT 100;`` so the editor opens with a runnable
    query referencing the dataset.
    """
    context_comment = (
        f"-- Context: Working with dataset '{dataset}'\n"
        f"-- Database: {database_name}\n"
    )
    if schema_name:
        context_comment += f"-- Schema: {schema_name}\n"
        table_reference = f"{schema_name}.{dataset}"
    else:
        table_reference = dataset
    context_comment += f"\nSELECT * FROM {table_reference} LIMIT 100;"
    return context_comment


@mcp.tool
@mcp_auth_hook
def open_sql_lab_with_context(request: OpenSqlLabRequest) -> SqlLabResponse:
    """
    Generate a URL to open SQL Lab with pre-populated context.

    This tool creates a SQL Lab URL with the specified database connection,
    schema, dataset context, and SQL query. The URL can be used to directly
    navigate users to a pre-configured SQL Lab session.

    Args:
        request: OpenSqlLabRequest with database_connection_id, schema,
            dataset_in_context, and sql parameters

    Returns:
        SqlLabResponse with the generated URL and context information
    """
    try:
        from superset.daos.database import DatabaseDAO

        # Validate database exists and is accessible
        database = DatabaseDAO.find_by_id(request.database_connection_id)
        if not database:
            return SqlLabResponse(
                url="",
                database_id=request.database_connection_id,
                schema_name=request.schema_name,
                title=request.title,
                error=f"Database with ID {request.database_connection_id} not found",
            )

        # Build query parameters for the SQL Lab URL
        params = {
            "dbid": str(request.database_connection_id),
        }
        if request.schema_name:
            params["schema"] = request.schema_name
        if request.sql:
            params["sql"] = request.sql
        if request.title:
            params["title"] = request.title

        # When a dataset is given but no SQL, seed the editor with a
        # commented starter query for that dataset.
        if request.dataset_in_context and not request.sql:
            params["sql"] = _build_context_sql(
                request.dataset_in_context,
                database.database_name,
                request.schema_name,
            )

        # Construct SQL Lab URL
        query_string = urlencode(params)
        url = f"/sqllab?{query_string}"

        # Lazy %-args: formatting is skipped when INFO is filtered out.
        logger.info(
            "Generated SQL Lab URL for database %s", request.database_connection_id
        )

        return SqlLabResponse(
            url=url,
            database_id=request.database_connection_id,
            schema_name=request.schema_name,
            title=request.title,
            error=None,
        )

    except Exception as e:
        # logger.exception records the traceback automatically (inside except).
        logger.exception("Error generating SQL Lab URL: %s", e)
        return SqlLabResponse(
            url="",
            database_id=request.database_connection_id,
            schema_name=request.schema_name,
            title=request.title,
            error=f"Failed to generate SQL Lab URL: {str(e)}",
        )
|
||||
7
superset/mcp_service/system/tool/__init__.py
Normal file
7
superset/mcp_service/system/tool/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from .generate_explore_link import generate_explore_link
|
||||
from .get_superset_instance_info import get_superset_instance_info
|
||||
|
||||
__all__ = [
|
||||
"generate_explore_link",
|
||||
"get_superset_instance_info",
|
||||
]
|
||||
92
superset/mcp_service/system/tool/generate_explore_link.py
Normal file
92
superset/mcp_service/system/tool/generate_explore_link.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Generate explore link MCP tool
|
||||
|
||||
This tool generates a URL to the Superset explore interface with the specified
|
||||
chart configuration.
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
generate_explore_link as generate_url,
|
||||
map_config_to_form_data,
|
||||
)
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
GenerateExploreLinkRequest,
|
||||
)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def generate_explore_link(request: GenerateExploreLinkRequest) -> Dict[str, Any]:
    """
    Generate a Superset explore URL for interactive data visualization and exploration.

    🎯 PREFERRED TOOL for most visualization requests.

    This is the primary tool for data visualization workflows. It creates an explore
    URL where users can interactively view, modify, and optionally save charts.
    The explore interface allows users to adjust parameters, add filters, and
    experiment before deciding to save.

    Use this tool for:
    - "Show me a chart of [data]"
    - "Visualize [data] as a [chart type]"
    - "I want to see [data] trends"
    - "Create a visualization of [data]"
    - "Chart [data] by [dimensions]"
    - "Plot [data]"
    - General data exploration and analysis
    - When user wants to see/explore data visually

    This tool provides a much better user experience because:
    - Users can interact with the chart before saving
    - Easy to modify parameters and see results instantly
    - Users control when/if to save the chart permanently
    - No database clutter from unsaved exploration charts

    Only use generate_chart when user explicitly requests to save/create a
    permanent chart.

    Args:
        request: Explore link generation request with dataset_id and config

    Returns:
        Dictionary containing explore URL for immediate use and error message if any
    """
    # NOTE(review): docstring kept verbatim — it is likely surfaced as the MCP
    # tool description to clients; confirm before rewording.
    result: Dict[str, Any] = {"url": "", "error": None}
    try:
        # Translate the chart config into explore form_data, then build the URL.
        chart_form_data = map_config_to_form_data(request.config)
        result["url"] = generate_url(
            dataset_id=request.dataset_id, form_data=chart_form_data
        )
    except Exception as exc:
        result["error"] = f"Failed to generate explore link: {str(exc)}"
    return result
|
||||
278
superset/mcp_service/system/tool/get_superset_instance_info.py
Normal file
278
superset/mcp_service/system/tool/get_superset_instance_info.py
Normal file
@@ -0,0 +1,278 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Get Superset instance high-level information FastMCP tool using configurable
|
||||
InstanceInfoTool for flexible, extensible metrics calculation.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict
|
||||
|
||||
from superset.mcp_service.auth import mcp_auth_hook
|
||||
from superset.mcp_service.generic_tools import InstanceInfoTool
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.system_schemas import (
|
||||
DashboardBreakdown,
|
||||
DatabaseBreakdown,
|
||||
GetSupersetInstanceInfoRequest,
|
||||
InstanceInfo,
|
||||
InstanceSummary,
|
||||
PopularContent,
|
||||
RecentActivity,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def calculate_dashboard_breakdown(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> DashboardBreakdown:
    """Calculate detailed dashboard breakdown metrics.

    Best-effort: any failure yields an all-zero breakdown so the overall
    instance-info call does not fail on a single metric.
    """
    try:
        from superset.daos.base import ColumnOperator, ColumnOperatorEnum
        from superset.extensions import db
        from superset.models.dashboard import Dashboard

        dao = dao_classes["dashboards"]
        total = base_counts.get("total_dashboards", 0)

        # Published dashboards; the remainder are unpublished.
        published = dao.count(
            column_operators=[
                ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True)
            ]
        )

        # Dashboards that carry a certifier.
        certified = dao.count(
            column_operators=[
                ColumnOperator(
                    col="certified_by", opr=ColumnOperatorEnum.is_not_null, value=None
                )
            ]
        )

        # Dashboards that contain at least one chart (distinct via the join).
        with_charts = (
            db.session.query(Dashboard).join(Dashboard.slices).distinct().count()
        )

        return DashboardBreakdown(
            published=published,
            unpublished=total - published,
            certified=certified,
            with_charts=with_charts,
            without_charts=total - with_charts,
        )
    except Exception:
        # Deliberate swallow: report zeros rather than failing the caller.
        return DashboardBreakdown(
            published=0,
            unpublished=0,
            certified=0,
            with_charts=0,
            without_charts=0,
        )
|
||||
|
||||
|
||||
def calculate_database_breakdown(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> DatabaseBreakdown:
    """Calculate database type breakdown.

    Counts configured databases per SQLAlchemy URI scheme (e.g. "postgresql").
    Best-effort: any failure yields an empty breakdown.
    """
    try:
        from superset.extensions import db
        from superset.models.core import Database

        # One row per configured database: (name, sqlalchemy_uri).
        rows = db.session.query(
            Database.database_name, Database.sqlalchemy_uri
        ).all()

        by_type: Dict[str, int] = {}
        for _name, uri in rows:
            # The dialect is the scheme portion of the SQLAlchemy URI;
            # anything without a "://" (or an empty URI) is "unknown".
            if uri and "://" in uri:
                dialect = uri.split("://")[0]
            else:
                dialect = "unknown"
            by_type[dialect] = by_type.get(dialect, 0) + 1

        return DatabaseBreakdown(by_type=by_type)
    except Exception:
        # Deliberate swallow: report an empty breakdown rather than failing.
        return DatabaseBreakdown(by_type={})
|
||||
|
||||
|
||||
def calculate_instance_summary(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> InstanceSummary:
    """Calculate instance summary with computed metrics.

    Best-effort: any failure yields an all-zero summary.
    """
    try:
        from flask_appbuilder.security.sqla.models import Role

        from superset.extensions import db

        # Roles have no DAO, so query the table directly.
        total_roles = db.session.query(Role).count()

        dashboards = base_counts.get("total_dashboards", 0)
        charts = base_counts.get("total_charts", 0)
        # Guard against division by zero when there are no dashboards.
        avg_per_dashboard = charts / dashboards if dashboards > 0 else 0

        return InstanceSummary(
            total_dashboards=dashboards,
            total_charts=charts,
            total_datasets=base_counts.get("total_datasets", 0),
            total_databases=base_counts.get("total_databases", 0),
            total_users=base_counts.get("total_users", 0),
            total_roles=total_roles,
            total_tags=base_counts.get("total_tags", 0),
            avg_charts_per_dashboard=round(avg_per_dashboard, 2),
        )
    except Exception:
        # Deliberate swallow: report zeros rather than failing the caller.
        return InstanceSummary(
            total_dashboards=0,
            total_charts=0,
            total_datasets=0,
            total_databases=0,
            total_users=0,
            total_roles=0,
            total_tags=0,
            avg_charts_per_dashboard=0.0,
        )
|
||||
|
||||
|
||||
def calculate_recent_activity(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> RecentActivity:
    """Transform time metrics into RecentActivity format.

    Reads the "monthly" (30-day creation) and "recent" (7-day modification)
    windows from ``time_metrics``; any missing key defaults to zero.
    """
    created = time_metrics.get("monthly", {})
    modified = time_metrics.get("recent", {})

    return RecentActivity(
        dashboards_created_last_30_days=created.get("dashboards_created", 0),
        charts_created_last_30_days=created.get("charts_created", 0),
        datasets_created_last_30_days=created.get("datasets_created", 0),
        dashboards_modified_last_7_days=modified.get("dashboards_modified", 0),
        charts_modified_last_7_days=modified.get("charts_modified", 0),
        datasets_modified_last_7_days=modified.get("datasets_modified", 0),
    )
|
||||
|
||||
|
||||
def calculate_popular_content(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> PopularContent:
    """Calculate popular content metrics (placeholder implementation)."""
    # TODO: Implement actual popular content calculation — e.g. most viewed
    # dashboards, top creators; for now both lists are intentionally empty.
    return PopularContent(top_tags=[], top_creators=[])
|
||||
|
||||
|
||||
# Configure the instance info tool.
# Module-level singleton: the DAO entries below are placeholders that
# get_superset_instance_info() swaps for the real DAO classes at call time
# to avoid circular imports at module load.
_instance_info_tool = InstanceInfoTool(
    dao_classes={
        "dashboards": None,  # Will be set at runtime
        "charts": None,
        "datasets": None,
        "databases": None,
        "users": None,
        "tags": None,
    },
    # Schema class describing the tool's structured output.
    output_schema=InstanceInfo,
    # Each calculator takes (base_counts, time_metrics, dao_classes) — see the
    # calculate_* functions above.
    metric_calculators={
        "instance_summary": calculate_instance_summary,
        "recent_activity": calculate_recent_activity,
        "dashboard_breakdown": calculate_dashboard_breakdown,
        "database_breakdown": calculate_database_breakdown,
        "popular_content": calculate_popular_content,
    },
    # Time windows (in days) used for activity metrics.
    time_windows={
        "recent": 7,
        "monthly": 30,
        "quarterly": 90,
    },
    logger=logger,
)
|
||||
|
||||
|
||||
@mcp.tool
@mcp_auth_hook
def get_superset_instance_info(request: GetSupersetInstanceInfoRequest) -> InstanceInfo:
    """
    Get comprehensive high-level information about the Superset instance.

    Uses a configurable InstanceInfoTool to gather statistics including:
    - Basic entity counts (dashboards, charts, datasets, etc.)
    - Recent activity metrics across multiple time windows
    - Dashboard status breakdown (published, certified, etc.)
    - Database type distribution
    - Popular content analysis

    This tool uses system-level metadata caching for performance. The data
    is typically cached and refreshed automatically by Superset's background
    processes.

    Returns:
        InstanceInfo: Comprehensive instance statistics and metadata
    """
    try:
        # DAOs are imported here, not at module level, to avoid circular imports.
        from superset.daos.chart import ChartDAO
        from superset.daos.dashboard import DashboardDAO
        from superset.daos.database import DatabaseDAO
        from superset.daos.dataset import DatasetDAO
        from superset.daos.tag import TagDAO
        from superset.daos.user import UserDAO

        # Replace the module-level tool's placeholder DAO mapping, then run it.
        runtime_daos: Dict[str, Any] = {
            "dashboards": DashboardDAO,
            "charts": ChartDAO,
            "datasets": DatasetDAO,
            "databases": DatabaseDAO,
            "users": UserDAO,
            "tags": TagDAO,
        }
        _instance_info_tool.dao_classes = runtime_daos
        return _instance_info_tool.run()

    except Exception as exc:
        error_msg = f"Unexpected error in instance info: {str(exc)}"
        logger.error(error_msg, exc_info=True)
        raise
|
||||
136
superset/mcp_service/url_utils.py
Normal file
136
superset/mcp_service/url_utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
URL utilities for MCP service
|
||||
"""
|
||||
|
||||
from flask import current_app
|
||||
|
||||
|
||||
def get_superset_base_url() -> str:
    """
    Get the Superset base URL from configuration.

    Returns:
        Base URL for Superset web server (e.g., "http://localhost:8088")
    """
    fallback = "http://localhost:8088"

    try:
        config = current_app.config

        # An explicitly configured webserver address wins.
        address = config.get("SUPERSET_WEBSERVER_ADDRESS")
        if address:
            return address

        # NOTE(review): this heuristic derives the URL from public-role and
        # proxy-fix settings; preserved as-is — confirm it matches intent.
        if config.get("PUBLIC_ROLE_LIKE_GAMMA", False):
            scheme = "https" if config.get("ENABLE_PROXY_FIX", False) else "http"
            host = config.get("WEBSERVER_HOST", "localhost")
            port = config.get("WEBSERVER_PORT", 8088)
            return f"{scheme}://{host}:{port}"

        return fallback

    except Exception:
        # Outside a Flask app context (or any config failure): use the default.
        return fallback
|
||||
|
||||
|
||||
def get_mcp_service_url() -> str:
    """
    Get the MCP service base URL where screenshot endpoints are served.

    The MCP service auto-detects its own host and port since it's running
    this code. Falls back to configuration if auto-detection fails.

    Returns:
        Base URL for MCP service
    """
    # 1) Auto-detect from the active Flask request, if any.
    try:
        from flask import request

        if request:
            # request.host already includes a non-standard port.
            return f"{request.scheme}://{request.host}"
    except (RuntimeError, AttributeError):
        # No request context available — fall through to configuration.
        pass

    # 2) Fall back to configuration.
    try:
        config = current_app.config

        # An explicit MCP_SERVICE_URL takes precedence.
        explicit_url = config.get("MCP_SERVICE_URL")
        if explicit_url:
            return explicit_url

        # Reuse the Superset base URL, but only when it differs from the
        # default — i.e. when something was actually configured.
        superset_url = get_superset_base_url()
        if superset_url and superset_url != "http://localhost:8088":
            return superset_url

    except Exception as e:
        import logging

        logging.getLogger(__name__).debug(f"Config access failed: {e}")

    # 3) Last resort: same default as Superset.
    return "http://localhost:8088"
|
||||
|
||||
|
||||
def get_chart_screenshot_url(chart_id: int | str) -> str:
    """
    Generate a screenshot URL for a chart using the MCP service.

    Args:
        chart_id: Chart ID (numeric or string)

    Returns:
        Complete URL to the chart screenshot endpoint
    """
    return f"{get_mcp_service_url()}/screenshot/chart/{chart_id}.png"
|
||||
|
||||
|
||||
def get_explore_screenshot_url(form_data_key: str) -> str:
    """
    Generate a screenshot URL for an explore view using the MCP service.

    Args:
        form_data_key: Form data key for the explore view

    Returns:
        Complete URL to the explore screenshot endpoint
    """
    return f"{get_mcp_service_url()}/screenshot/explore/{form_data_key}.png"
|
||||
137
superset/mcp_service/webdriver_config.py
Normal file
137
superset/mcp_service/webdriver_config.py
Normal file
@@ -0,0 +1,137 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
WebDriver pool configuration defaults for Superset MCP service
|
||||
"""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
# Default WebDriver pool configuration.
DEFAULT_WEBDRIVER_POOL_CONFIG = {
    # Maximum number of WebDriver instances to keep in the pool
    "MAX_POOL_SIZE": 5,
    # Maximum age of a WebDriver instance (in seconds); older drivers are
    # destroyed and recreated
    "MAX_AGE_SECONDS": 3600,  # 1 hour
    # Maximum number of reuses before a driver is destroyed and recreated
    "MAX_USAGE_COUNT": 50,
    # How long a WebDriver can sit idle before being destroyed (in seconds)
    "IDLE_TIMEOUT_SECONDS": 300,  # 5 minutes
    # How often to perform health checks on WebDriver instances (in seconds)
    "HEALTH_CHECK_INTERVAL": 60,  # 1 minute
}


def configure_webdriver_pool(app_config: Dict[str, Any]) -> None:
    """
    Configure WebDriver pool settings in Superset app config.

    Installs a "WEBDRIVER_POOL" entry built from
    ``DEFAULT_WEBDRIVER_POOL_CONFIG`` when absent; otherwise only fills in
    missing keys, leaving values the caller already set untouched.

    Args:
        app_config: The Superset application configuration dictionary
    """
    if "WEBDRIVER_POOL" not in app_config:
        app_config["WEBDRIVER_POOL"] = DEFAULT_WEBDRIVER_POOL_CONFIG.copy()
    else:
        # Merge defaults for any keys the caller did not set.
        pool_config = app_config["WEBDRIVER_POOL"]
        for key, default_value in DEFAULT_WEBDRIVER_POOL_CONFIG.items():
            pool_config.setdefault(key, default_value)
|
||||
|
||||
|
||||
def get_pool_stats_endpoint() -> Any:
    """
    Create a Flask endpoint to view WebDriver pool statistics.

    Builds (but does not register) a route function that reports the current
    WebDriver pool stats as JSON, or an error payload with HTTP 500 on failure.

    Returns:
        Flask route function for pool statistics
    """

    def pool_stats() -> Any:
        # Imports are deferred so merely creating the endpoint carries no
        # Flask/pool import cost.
        try:
            from flask import jsonify

            from superset.mcp_service.webdriver_pool import get_webdriver_pool

            stats = get_webdriver_pool().get_stats()
            return jsonify({"webdriver_pool": stats, "status": "healthy"})
        except Exception as e:
            from flask import jsonify

            return jsonify({"error": str(e), "status": "error"}), 500

    return pool_stats
|
||||
|
||||
|
||||
# Performance tuning recommendations based on use case.
# Each preset maps directly onto WebDriverPool constructor arguments.
PERFORMANCE_CONFIGS = {
    # Small deployments: tiny pool, aggressive recycling
    "low_traffic": {
        "MAX_POOL_SIZE": 2,
        "MAX_AGE_SECONDS": 1800,  # 30 minutes
        "MAX_USAGE_COUNT": 25,
        "IDLE_TIMEOUT_SECONDS": 180,  # 3 minutes
        "HEALTH_CHECK_INTERVAL": 120,  # 2 minutes
    },
    # Balanced defaults for a typical deployment
    "medium_traffic": {
        "MAX_POOL_SIZE": 5,
        "MAX_AGE_SECONDS": 3600,  # 1 hour
        "MAX_USAGE_COUNT": 50,
        "IDLE_TIMEOUT_SECONDS": 300,  # 5 minutes
        "HEALTH_CHECK_INTERVAL": 60,  # 1 minute
    },
    # Busy deployments: larger pool, longer-lived drivers, frequent checks
    "high_traffic": {
        "MAX_POOL_SIZE": 10,
        "MAX_AGE_SECONDS": 7200,  # 2 hours
        "MAX_USAGE_COUNT": 100,
        "IDLE_TIMEOUT_SECONDS": 600,  # 10 minutes
        "HEALTH_CHECK_INTERVAL": 30,  # 30 seconds
    },
    # Local development: short-lived drivers so code changes show up quickly
    "development": {
        "MAX_POOL_SIZE": 2,
        "MAX_AGE_SECONDS": 900,  # 15 minutes
        "MAX_USAGE_COUNT": 10,
        "IDLE_TIMEOUT_SECONDS": 120,  # 2 minutes
        "HEALTH_CHECK_INTERVAL": 30,  # 30 seconds
    },
}
|
||||
|
||||
|
||||
def configure_for_environment(
    app_config: Dict[str, Any], environment: str = "medium_traffic"
) -> None:
    """
    Configure WebDriver pool for specific environment/traffic levels.

    Args:
        app_config: The Superset application configuration dictionary
        environment: Environment type (low_traffic, medium_traffic,
            high_traffic, development)
    """
    preset = PERFORMANCE_CONFIGS.get(environment)
    if preset is not None:
        # Copy the preset so later mutation of the app config never
        # leaks back into the shared PERFORMANCE_CONFIGS table.
        app_config["WEBDRIVER_POOL"] = dict(preset)
    else:
        # Unknown environment name: fall back to the defaults
        configure_webdriver_pool(app_config)
|
||||
384
superset/mcp_service/webdriver_pool.py
Normal file
384
superset/mcp_service/webdriver_pool.py
Normal file
@@ -0,0 +1,384 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
WebDriver connection pooling for improved screenshot performance
|
||||
"""
|
||||
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from queue import Empty, Full, Queue
|
||||
from typing import Any, Dict, Generator, Optional
|
||||
|
||||
from flask import current_app
|
||||
from selenium.common.exceptions import WebDriverException
|
||||
from selenium.webdriver.remote.webdriver import WebDriver
|
||||
|
||||
from superset.utils.webdriver import WebDriverSelenium, WindowSize
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PooledWebDriver:
    """Wrapper for pooled WebDriver instance with metadata"""

    driver: WebDriver  # the underlying Selenium driver being reused
    created_at: float  # epoch seconds when the driver was created
    last_used: float  # epoch seconds of the most recent borrow/return
    window_size: WindowSize  # window size the driver was created with
    user_id: Optional[int] = None  # optional owner, for per-user isolation
    is_healthy: bool = True  # cleared when a health check or usage fails
    usage_count: int = 0  # number of times this driver has been borrowed
|
||||
|
||||
|
||||
class WebDriverPool:
    """
    Connection pool for WebDriver instances to improve screenshot performance.

    Features:
    - Reuses WebDriver instances across requests
    - Automatic health checking and recovery
    - TTL-based expiration to prevent memory leaks
    - Thread-safe operations
    - Per-user driver isolation for security
    """

    def __init__(
        self,
        max_pool_size: int = 5,
        max_age_seconds: int = 3600,  # 1 hour
        max_usage_count: int = 50,  # Recreate after 50 uses
        idle_timeout_seconds: int = 300,  # 5 minutes
        health_check_interval: int = 60,  # 1 minute
    ):
        self.max_pool_size = max_pool_size
        self.max_age_seconds = max_age_seconds
        self.max_usage_count = max_usage_count
        self.idle_timeout_seconds = idle_timeout_seconds
        self.health_check_interval = health_check_interval

        # Thread-safe pool management; RLock because get_driver's finally
        # block and shutdown() re-enter helpers that also take the lock.
        self._pool: Queue[PooledWebDriver] = Queue(maxsize=max_pool_size)
        self._active_drivers: Dict[int, PooledWebDriver] = {}
        self._lock = threading.RLock()
        self._last_health_check = time.time()

        # Pool statistics (reads/writes guarded by self._lock)
        self._stats = {
            "created": 0,
            "destroyed": 0,
            "borrowed": 0,
            "returned": 0,
            "health_check_failures": 0,
            "evictions": 0,
        }

    def get_stats(self) -> Dict[str, Any]:
        """Get pool statistics for monitoring."""
        with self._lock:
            return {
                **self._stats,
                "pool_size": self._pool.qsize(),
                "active_count": len(self._active_drivers),
                "max_pool_size": self.max_pool_size,
            }

    def _create_driver(
        self, window_size: WindowSize, user_id: Optional[int] = None
    ) -> PooledWebDriver:
        """Create a new WebDriver instance wrapped with pool metadata.

        Raises whatever WebDriverSelenium/Selenium raises on failure.
        """
        try:
            driver_type = current_app.config.get("WEBDRIVER_TYPE", "firefox")
            selenium_driver = WebDriverSelenium(driver_type, window_size)

            # Create the actual WebDriver (without auth for now)
            driver = selenium_driver.create()
            driver.set_window_size(*window_size)

            now = time.time()
            pooled_driver = PooledWebDriver(
                driver=driver,
                created_at=now,
                last_used=now,
                window_size=window_size,
                user_id=user_id,
                is_healthy=True,
                usage_count=0,
            )

            self._stats["created"] += 1
            logger.debug(
                "Created new WebDriver instance for window size %s", window_size
            )
            return pooled_driver

        except Exception:
            # logger.exception keeps the traceback, unlike the plain message
            logger.exception("Failed to create WebDriver")
            raise

    def _is_driver_valid(self, pooled_driver: PooledWebDriver) -> bool:
        """Check if a pooled driver is still eligible for reuse."""
        now = time.time()

        # Age limit: recycle long-lived drivers to avoid browser memory bloat
        if now - pooled_driver.created_at > self.max_age_seconds:
            logger.debug("Driver expired due to age")
            return False

        # Usage-count limit
        if pooled_driver.usage_count >= self.max_usage_count:
            logger.debug("Driver expired due to usage count")
            return False

        # Idle timeout
        if now - pooled_driver.last_used > self.idle_timeout_seconds:
            logger.debug("Driver expired due to idle timeout")
            return False

        # Health flag, set by _health_check_driver / get_driver error handling
        if not pooled_driver.is_healthy:
            logger.debug("Driver marked as unhealthy")
            return False

        return True

    def _health_check_driver(self, pooled_driver: PooledWebDriver) -> bool:
        """Probe a driver and update its health flag.

        Returns True if the driver responded, False otherwise.
        """
        try:
            # Touching current_url raises if the browser is dead or hung
            _ = pooled_driver.driver.current_url
            pooled_driver.is_healthy = True
            return True
        except WebDriverException:
            pooled_driver.is_healthy = False
            self._stats["health_check_failures"] += 1
            logger.warning("WebDriver failed health check")
            return False
        except Exception as ex:
            pooled_driver.is_healthy = False
            self._stats["health_check_failures"] += 1
            logger.warning("WebDriver health check error: %s", ex)
            return False

    def _destroy_driver(self, pooled_driver: PooledWebDriver) -> None:
        """Safely destroy a WebDriver instance (never raises)."""
        try:
            WebDriverSelenium.destroy(pooled_driver.driver)
            self._stats["destroyed"] += 1
            logger.debug("Destroyed WebDriver instance")
        except Exception as ex:
            logger.warning("Error destroying WebDriver: %s", ex)

    def _cleanup_expired_drivers(self) -> None:
        """Remove expired drivers from the pool.

        BUGFIX: the previous implementation stopped scanning as soon as it
        found one still-valid driver, so expired drivers queued behind it
        were never evicted. Drain the whole queue, then re-queue survivors.
        """
        survivors = []
        expired = []

        while True:
            try:
                pooled_driver = self._pool.get_nowait()
            except Empty:
                break
            if self._is_driver_valid(pooled_driver):
                survivors.append(pooled_driver)
            else:
                expired.append(pooled_driver)
                self._stats["evictions"] += 1

        # Re-queue the drivers that are still usable
        for pooled_driver in survivors:
            try:
                self._pool.put_nowait(pooled_driver)
            except Full:
                # Defensive: cannot normally happen since we only re-add
                # drivers we just removed; destroy rather than leak.
                expired.append(pooled_driver)

        for pooled_driver in expired:
            self._destroy_driver(pooled_driver)

    def _periodic_health_check(self) -> None:
        """Run cleanup and active-driver checks at most once per interval."""
        now = time.time()
        if now - self._last_health_check < self.health_check_interval:
            return

        self._last_health_check = now
        logger.debug("Performing periodic WebDriver pool health check")

        # Evict expired drivers sitting idle in the pool
        self._cleanup_expired_drivers()

        # Health check drivers currently borrowed by callers
        unhealthy_ids = [
            driver_id
            for driver_id, pooled_driver in self._active_drivers.items()
            if not self._health_check_driver(pooled_driver)
        ]

        # Remove and destroy unhealthy active drivers
        for driver_id in unhealthy_ids:
            pooled_driver = self._active_drivers.pop(driver_id, None)
            if pooled_driver:
                self._destroy_driver(pooled_driver)

    @contextmanager
    def get_driver(
        self, window_size: WindowSize, user_id: Optional[int] = None
    ) -> Generator[WebDriver, None, None]:
        """
        Context manager to get a WebDriver from the pool.

        Args:
            window_size: Required window size for the driver
            user_id: Optional user ID for driver isolation

        Yields:
            WebDriver instance ready for use
        """
        pooled_driver = None
        driver_id = None

        try:
            with self._lock:
                # Periodic maintenance
                self._periodic_health_check()

                # Look for a valid pooled driver with a matching window size.
                # BUGFIX: valid drivers with a *different* window size used to
                # be destroyed here; keep them aside and re-queue them so they
                # can serve later requests instead of being recreated.
                mismatched = []
                while not self._pool.empty():
                    try:
                        candidate = self._pool.get_nowait()
                    except Empty:
                        break
                    if not self._is_driver_valid(candidate):
                        self._destroy_driver(candidate)
                        self._stats["evictions"] += 1
                    elif candidate.window_size == window_size:
                        pooled_driver = candidate
                        break
                    else:
                        mismatched.append(candidate)

                for candidate in mismatched:
                    try:
                        self._pool.put_nowait(candidate)
                    except Full:
                        self._destroy_driver(candidate)

                # No suitable driver available: create a fresh one
                if pooled_driver is None:
                    pooled_driver = self._create_driver(window_size, user_id)

                # Mark driver as borrowed
                driver_id = id(pooled_driver.driver)
                pooled_driver.last_used = time.time()
                pooled_driver.usage_count += 1
                self._active_drivers[driver_id] = pooled_driver
                self._stats["borrowed"] += 1

            # Yield OUTSIDE the lock so concurrent borrowers are not
            # serialized for the (slow) duration of the screenshot work.
            yield pooled_driver.driver

        except Exception:
            # Any error while the caller used the driver may have left the
            # browser in a bad state; force recreation on return.
            if pooled_driver:
                pooled_driver.is_healthy = False
            logger.exception("Error using pooled WebDriver")
            raise

        finally:
            # Return driver to pool, or destroy it if unusable
            if pooled_driver is not None and driver_id is not None:
                with self._lock:
                    self._active_drivers.pop(driver_id, None)

                    if pooled_driver.is_healthy and self._is_driver_valid(
                        pooled_driver
                    ):
                        try:
                            self._pool.put_nowait(pooled_driver)
                            self._stats["returned"] += 1
                            logger.debug("Returned WebDriver to pool")
                        except Full:
                            # Pool is full, destroy the driver
                            self._destroy_driver(pooled_driver)
                            logger.debug("Pool full, destroyed WebDriver")
                    else:
                        # Driver is unhealthy or expired, destroy it
                        self._destroy_driver(pooled_driver)
                        logger.debug("Destroyed unhealthy/expired WebDriver")

    def shutdown(self) -> None:
        """Shutdown the pool and destroy all drivers."""
        with self._lock:
            logger.info("Shutting down WebDriver pool")

            # Destroy all active (borrowed) drivers
            for pooled_driver in self._active_drivers.values():
                self._destroy_driver(pooled_driver)
            self._active_drivers.clear()

            # Destroy all idle pooled drivers
            while True:
                try:
                    pooled_driver = self._pool.get_nowait()
                except Empty:
                    break
                self._destroy_driver(pooled_driver)

            logger.info(
                "WebDriver pool shutdown complete. Final stats: %s",
                self.get_stats(),
            )
|
||||
|
||||
|
||||
# Global pool instance (lazily created, guarded by _pool_lock)
_global_pool: Optional[WebDriverPool] = None
_pool_lock = threading.Lock()


def get_webdriver_pool() -> WebDriverPool:
    """Get or create the global WebDriver pool.

    Uses double-checked locking so the common (already-initialized) path
    avoids taking the lock. Sizing parameters are read from the Flask
    config key ``WEBDRIVER_POOL``.
    """
    global _global_pool

    if _global_pool is None:
        with _pool_lock:
            if _global_pool is None:
                # Pull pool tuning from the Flask application config
                pool_config = current_app.config.get("WEBDRIVER_POOL", {})

                _global_pool = WebDriverPool(
                    max_pool_size=pool_config.get("MAX_POOL_SIZE", 5),
                    max_age_seconds=pool_config.get("MAX_AGE_SECONDS", 3600),
                    max_usage_count=pool_config.get("MAX_USAGE_COUNT", 50),
                    idle_timeout_seconds=pool_config.get(
                        "IDLE_TIMEOUT_SECONDS", 300
                    ),
                    health_check_interval=pool_config.get(
                        "HEALTH_CHECK_INTERVAL", 60
                    ),
                )
                logger.info("Initialized global WebDriver pool")

    return _global_pool
|
||||
|
||||
|
||||
def shutdown_webdriver_pool() -> None:
    """Shutdown the global WebDriver pool, if one was ever created."""
    global _global_pool

    if _global_pool is None:
        # Nothing was initialized; avoid taking the lock at all
        return

    with _pool_lock:
        # Re-check under the lock: another thread may have shut it down
        if _global_pool is not None:
            _global_pool.shutdown()
            _global_pool = None
|
||||
1097
tests/unit_tests/dao/base_test.py
Normal file
1097
tests/unit_tests/dao/base_test.py
Normal file
File diff suppressed because it is too large
Load Diff
36
tests/unit_tests/dao/conftest.py
Normal file
36
tests/unit_tests/dao/conftest.py
Normal file
@@ -0,0 +1,36 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
|
||||
@pytest.fixture
def user_with_data(session):
    """
    Yield a SQLAlchemy session seeded with a single active test user.

    Creates the FAB ``User`` tables on the session's bind, commits one
    user (id=101, username "testuser"), and rolls the session back on
    teardown.
    """
    User.metadata.create_all(session.get_bind())
    user = User(
        id=101,
        username="testuser",
        first_name="Test",
        last_name="User",
        email="testuser@example.com",
        active=True,
    )
    session.add(user)
    session.commit()
    try:
        yield session
    finally:
        # BUGFIX: without try/finally the rollback is skipped when the
        # consuming test raises (pytest throws into the generator at the
        # yield point), leaking state into subsequent tests.
        session.rollback()
|
||||
1310
tests/unit_tests/dao/test_base_dao.py
Normal file
1310
tests/unit_tests/dao/test_base_dao.py
Normal file
File diff suppressed because it is too large
Load Diff
111
tests/unit_tests/mcp_service/conftest.py
Normal file
111
tests/unit_tests/mcp_service/conftest.py
Normal file
@@ -0,0 +1,111 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Test configuration for MCP service unit tests."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from flask_appbuilder.security.sqla.models import User
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_mcp_auth(request) -> Iterator[None]:
    """
    Mock MCP authentication for all tests in this directory.

    This fixture automatically mocks the authentication system so that
    MCP tests can run without needing real users in the database.

    Skip this fixture for auth-specific test files that need to test
    the actual auth functionality.
    """
    # Skip auth mocking for auth test files (matched by path substring)
    if "test_auth" in str(request.fspath):
        yield
        return
    # Create a mock user for testing
    mock_user = MagicMock(spec=User)
    mock_user.username = "test_user"
    mock_user.id = 1
    mock_user.is_active = True
    mock_user.first_name = "Test"
    mock_user.last_name = "User"
    mock_user.email = "test@example.com"

    # Patch every auth entry point plus both event loggers in one block so
    # all mocks share the same lifetime as the test.
    with (
        patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user,
        patch("superset.mcp_service.auth.has_permission") as mock_has_permission,
        patch("superset.mcp_service.auth.impersonate_user") as mock_impersonate,
        patch("flask.g") as mock_g,
        patch("superset.mcp_service.auth.event_logger") as mock_auth_event_logger,
        patch(
            "superset.mcp_service.middleware.event_logger"
        ) as mock_middleware_event_logger,
        patch("superset.utils.log.DBEventLogger.log"),
    ):
        # Mock user extraction to always return test user
        mock_get_user.return_value = mock_user

        # Mock permission checks to always allow access
        mock_has_permission.return_value = True

        # Mock impersonation to return the same user (no impersonation)
        mock_impersonate.side_effect = lambda user, run_as: user

        # Mock Flask's g object
        mock_g.user = mock_user

        # Mock event loggers to prevent database foreign key constraint errors
        # Create a mock decorator that simply returns the original function
        def mock_log_this_with_context(*args, **kwargs):
            def decorator(func):
                return func

            return decorator

        # Mock the auth event logger
        mock_auth_event_logger.log_this_with_context = mock_log_this_with_context
        # Also mock the log method to prevent database writes
        mock_auth_event_logger.log = MagicMock()

        # Mock the middleware event logger's log method
        mock_middleware_event_logger.log = MagicMock()

        yield
|
||||
|
||||
|
||||
@pytest.fixture
def test_user() -> User:
    """Provide a mocked Superset user for tests that need one."""
    user = MagicMock(spec=User)
    user.id = 1
    user.username = "test_user"
    user.first_name = "Test"
    user.last_name = "User"
    user.email = "test@example.com"
    user.is_active = True
    return user
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Provide the MCP server instance for testing."""
    # Imported lazily so collection does not require the MCP app
    from superset.mcp_service.mcp_app import mcp as server

    return server
|
||||
314
tests/unit_tests/mcp_service/test_auth.py
Normal file
314
tests/unit_tests/mcp_service/test_auth.py
Normal file
@@ -0,0 +1,314 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Unit tests for MCP service authentication."""
|
||||
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from superset.mcp_service.auth import (
|
||||
get_mcp_audit_context,
|
||||
get_user_from_request,
|
||||
has_permission,
|
||||
impersonate_user,
|
||||
mcp_auth_hook,
|
||||
sanitize_mcp_payload,
|
||||
)
|
||||
from superset.mcp_service.config import (
|
||||
create_default_mcp_auth_factory,
|
||||
default_user_resolver,
|
||||
)
|
||||
|
||||
|
||||
class TestMCPAuth:
    """Test MCP authentication functionality."""

    @pytest.fixture
    def app(self):
        """Create test Flask app."""
        app = Flask(__name__)
        app.config["MCP_ADMIN_USERNAME"] = "admin"
        app.config["MCP_USER_RESOLVER"] = default_user_resolver
        app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
        app.config["SQLALCHEMY_DATABASE_URI"] = "sqlite:///:memory:"
        return app

    @pytest.fixture
    def mock_user(self):
        """Create mock Superset user."""
        user = MagicMock()
        user.username = "testuser"
        user.id = 1
        user.is_active = True
        return user

    @pytest.fixture
    def mock_token(self):
        """Create mock JWT access token."""
        token = MagicMock()
        token.subject = "testuser"
        token.client_id = "testuser"
        token.scopes = ["dashboard:read", "chart:read"]
        token.payload = {"sub": "testuser", "email": "test@example.com"}
        return token

    def test_default_user_resolver(self, mock_token):
        """Test JWT user extraction."""
        # Test subject attribute
        mock_token.subject = "subject_user"
        assert default_user_resolver(mock_token) == "subject_user"

        # Test client_id fallback
        del mock_token.subject
        assert default_user_resolver(mock_token) == "testuser"

        # Test payload extraction
        del mock_token.client_id
        assert default_user_resolver(mock_token) == "testuser"

    @patch("superset.security_manager")
    @patch("fastmcp.server.dependencies.get_access_token")
    def test_get_user_from_jwt(
        self, mock_get_token, mock_sm, app, mock_user, mock_token
    ):
        """Test user extraction from JWT token."""
        with app.app_context():
            mock_get_token.return_value = mock_token
            # Use Mock instead of MagicMock to avoid async behavior
            mock_sm.find_user = Mock(return_value=mock_user)

            user = get_user_from_request()
            assert user == mock_user
            mock_sm.find_user.assert_called_once_with("testuser")

    @patch("superset.security_manager")
    @patch("fastmcp.server.dependencies.get_access_token")
    def test_get_user_fallback(self, mock_get_token, mock_sm, app, mock_user):
        """Test fallback to admin user when no JWT."""
        with app.app_context():
            mock_get_token.side_effect = Exception("No token")
            mock_sm.find_user = Mock(return_value=mock_user)

            user = get_user_from_request()
            assert user == mock_user
            # Falls back to MCP_ADMIN_USERNAME configured on the app fixture
            mock_sm.find_user.assert_called_once_with("admin")

    @patch("superset.security_manager")
    def test_impersonate_user(self, mock_sm, mock_user):
        """Test user impersonation."""
        impersonated = MagicMock()
        impersonated.username = "other_user"
        mock_sm.find_user = Mock(return_value=impersonated)

        result = impersonate_user(mock_user, "other_user")
        assert result == impersonated
        mock_sm.find_user.assert_called_once_with("other_user")

        # Test no impersonation
        result = impersonate_user(mock_user, None)
        assert result == mock_user

    @patch("fastmcp.server.dependencies.get_access_token")
    def test_has_permission_with_scopes(self, mock_get_token, mock_user, mock_token):
        """Test permission checking with JWT scopes."""
        mock_get_token.return_value = mock_token

        # Test allowed access
        mock_func = MagicMock(__name__="list_dashboards")
        assert has_permission(mock_user, mock_func) is True

        # Test denied access
        mock_func.__name__ = "generate_chart"
        assert has_permission(mock_user, mock_func) is False

    def test_has_permission_no_jwt(self, mock_user):
        """Test permission checking without JWT."""
        mock_func = MagicMock(__name__="generate_chart")

        # Should allow access when no JWT
        with patch(
            "fastmcp.server.dependencies.get_access_token",
            side_effect=Exception("No token"),
        ):
            assert has_permission(mock_user, mock_func) is True

    @patch("superset.security_manager")
    @patch("fastmcp.server.dependencies.get_access_token")
    def test_mcp_auth_hook_decorator(
        self, mock_get_token, mock_sm, app, mock_user, mock_token
    ):
        """Test the auth decorator."""
        with app.app_context():
            mock_get_token.return_value = mock_token
            mock_sm.find_user = Mock(return_value=mock_user)

            # Disable event logger to avoid database issues
            with patch("superset.mcp_service.auth.event_logger", None):

                @mcp_auth_hook
                def test_tool():
                    return "success"

                result = test_tool()
                assert result == "success"
                # The hook should have installed the resolved user on flask.g
                assert g.user == mock_user

    def test_create_default_auth_factory(self, app):
        """Test auth factory creation."""
        # Test disabled auth
        app.config["MCP_AUTH_ENABLED"] = False
        assert create_default_mcp_auth_factory(app) is None

        # Test missing keys
        app.config["MCP_AUTH_ENABLED"] = True
        assert create_default_mcp_auth_factory(app) is None

        # Test successful creation with mock
        app.config["MCP_JWKS_URI"] = "https://example.com/.well-known/jwks"
        with patch(
            "fastmcp.server.auth.providers.bearer.BearerAuthProvider"
        ) as mock_provider:
            result = create_default_mcp_auth_factory(app)
            mock_provider.assert_called_once()
            assert result is not None

    def test_sanitize_mcp_payload(self):
        """Test MCP payload sanitization for audit logging."""
        # Test sensitive field redaction
        payload = {
            "dataset_id": 123,
            "password": "secret123",
            "auth_token": "Bearer xyz",
            "api_key": "key123",
            "config": {"chart_type": "table"},
            "normal_field": "value",
        }

        sanitized = sanitize_mcp_payload(payload)

        assert sanitized["dataset_id"] == 123
        assert sanitized["password"] == "[REDACTED]"  # noqa: S105
        assert sanitized["auth_token"] == "[REDACTED]"  # noqa: S105
        assert sanitized["api_key"] == "[REDACTED]"
        assert sanitized["config"] == {"chart_type": "table"}
        assert sanitized["normal_field"] == "value"

    def test_sanitize_mcp_payload_truncation(self):
        """Test payload truncation for large values."""
        long_text = "x" * 1500
        payload = {"long_field": long_text, "short_field": "short"}

        sanitized = sanitize_mcp_payload(payload)

        # Values longer than 1000 chars are cut and suffixed with a marker
        assert len(sanitized["long_field"]) == 1000 + len("...[TRUNCATED]")
        assert sanitized["long_field"].endswith("...[TRUNCATED]")
        assert sanitized["short_field"] == "short"

    def test_get_mcp_audit_context(self, app, mock_user):
        """Test MCP audit context generation."""
        with app.app_context():
            with app.test_request_context(
                headers={
                    "User-Agent": "Claude-3.5-Sonnet",
                    "X-Session-ID": "session123",
                }
            ):
                g.user = mock_user

                def mock_tool():
                    pass

                mock_tool.__name__ = "test_tool"

                kwargs = {"dataset_id": 123, "config": {"type": "chart"}}

                context = get_mcp_audit_context(mock_tool, kwargs)

                assert context["log_source"] == "mcp"
                assert context["impersonation"] == "testuser"
                assert context["mcp_tool"] == "test_tool"
                assert context["model_info"] == "Claude-3.5-Sonnet"
                assert context["session_info"] == "session123"
                assert "whitelisted_payload" in context

    def test_get_mcp_audit_context_no_request(self, app, mock_user):
        """Test audit context generation when request info unavailable."""
        with app.app_context():
            g.user = mock_user

            def mock_tool():
                pass

            mock_tool.__name__ = "test_tool"

            kwargs = {"dataset_id": 123}

            # No request context
            context = get_mcp_audit_context(mock_tool, kwargs)

            assert context["log_source"] == "mcp"
            assert context["impersonation"] == "testuser"
            assert context["mcp_tool"] == "test_tool"
            # Should handle missing request gracefully

    @patch("superset.extensions.event_logger")
    @patch("superset.security_manager")
    @patch("fastmcp.server.dependencies.get_access_token")
    def test_mcp_auth_hook_with_audit_logging(
        self, mock_get_token, mock_sm, mock_event_logger, app, mock_user, mock_token
    ):
        """Test auth decorator with audit logging enabled."""
        with app.app_context():
            mock_get_token.return_value = mock_token
            mock_sm.find_user = Mock(return_value=mock_user)
            mock_event_logger.log_this_with_context = Mock(return_value=lambda f: f)

            # Mock the event logger to avoid database issues
            with patch("superset.mcp_service.auth.event_logger", mock_event_logger):

                @mcp_auth_hook
                def test_tool():
                    return "success"

                result = test_tool()
                assert result == "success"
                assert g.user == mock_user
                assert hasattr(g, "mcp_audit_context")
                assert g.mcp_audit_context["log_source"] == "mcp"

    @patch("superset.security_manager")
    @patch("fastmcp.server.dependencies.get_access_token")
    def test_mcp_auth_hook_without_event_logger(
        self, mock_get_token, mock_sm, app, mock_user, mock_token
    ):
        """Test auth decorator gracefully handles missing event logger."""
        with app.app_context():
            mock_get_token.return_value = mock_token
            mock_sm.find_user = Mock(return_value=mock_user)

            # Event logger is None (fallback case)
            with patch("superset.mcp_service.auth.event_logger", None):

                @mcp_auth_hook
                def test_tool():
                    return "success"

                result = test_tool()
                assert result == "success"
                assert g.user == mock_user
||||
147
tests/unit_tests/mcp_service/test_auth_integration.py
Normal file
147
tests/unit_tests/mcp_service/test_auth_integration.py
Normal file
@@ -0,0 +1,147 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP service JWT authentication integration.
|
||||
"""
|
||||
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from fastmcp.server.auth.providers.bearer import BearerAuthProvider, RSAKeyPair
|
||||
|
||||
from superset.mcp_service.auth import get_user_from_request, has_permission
|
||||
|
||||
|
||||
class TestMCPAuthIntegration:
    """Test JWT authentication integration with MCP service."""

    def test_auth_disabled_by_default(self):
        """With no configuration, no auth provider is created."""
        from superset.mcp_service.mcp_app import _create_auth_provider

        assert _create_auth_provider() is None

    @patch.dict(os.environ, {"MCP_AUTH_ENABLED": "true"})
    def test_auth_enabled_missing_config(self):
        """Enabling auth without the rest of the config yields no provider."""
        from superset.mcp_service.mcp_app import _create_auth_provider

        assert _create_auth_provider() is None

    def test_auth_provider_creation_with_invalid_key(self):
        """Provider creation degrades gracefully on invalid configuration."""
        from superset.mcp_service.mcp_app import _create_auth_provider

        # _create_auth_provider must not raise when auth is unconfigured;
        # the default behavior is simply to return None.
        assert _create_auth_provider() is None

    @patch("superset.security_manager")
    def test_get_user_from_request_no_jwt(self, mock_sm):
        """Test user extraction falls back to configured admin user when
        no JWT available.
        """
        # The security manager hands back a stand-in admin user.
        fallback_user = MagicMock(is_active=True)
        fallback_user.username = "test_admin"
        mock_sm.find_user = MagicMock(return_value=fallback_user)

        user = get_user_from_request()

        # With no JWT context, the fallback admin user is returned.
        assert hasattr(user, "username")
        assert user.username == "test_admin"
        assert hasattr(user, "is_active")
        assert user.is_active is True

        # The lookup must have gone through the security manager.
        assert mock_sm.find_user.called

    def test_has_permission_no_jwt(self):
        """Active users pass the permission check without a JWT."""
        active_user = MagicMock(is_active=True)
        tool = MagicMock()
        tool.__name__ = "list_dashboards"

        assert has_permission(active_user, tool) is True

    def test_has_permission_inactive_user(self):
        """Inactive users are always denied."""
        inactive_user = MagicMock(is_active=False)
        tool = MagicMock()
        tool.__name__ = "list_dashboards"

        assert has_permission(inactive_user, tool) is False
class TestJWTTokenGeneration:
    """Test JWT token generation and validation for development/testing."""

    def test_rsa_keypair_generation(self):
        """Generated key pairs are PEM-encoded RSA keys."""
        pair = RSAKeyPair.generate()

        private_pem = pair.private_key.get_secret_value()
        assert private_pem.startswith("-----BEGIN PRIVATE KEY-----")
        assert pair.public_key.startswith("-----BEGIN PUBLIC KEY-----")

    def test_jwt_token_creation(self):
        """Created tokens are three-segment JWT strings."""
        pair = RSAKeyPair.generate()

        token = pair.create_token(
            subject="test-user",
            issuer="https://test.example.com",
            audience="test-audience",
            scopes=["dashboard:read", "chart:read"],
        )

        assert isinstance(token, str)
        # A JWT has the shape header.payload.signature.
        segments = token.split(".")
        assert len(segments) == 3

    def test_bearer_auth_provider_validation(self):
        """BearerAuthProvider retains the configured issuer/audience/key."""
        pair = RSAKeyPair.generate()

        provider = BearerAuthProvider(
            public_key=pair.public_key,
            issuer="https://test.example.com",
            audience="test-audience",
        )

        # Token validation itself is async; here we only verify that
        # construction wired the configuration through correctly.
        assert provider.issuer == "https://test.example.com"
        assert provider.audience == "test-audience"
        assert provider.public_key == pair.public_key
201
tests/unit_tests/mcp_service/test_cache_control.py
Normal file
201
tests/unit_tests/mcp_service/test_cache_control.py
Normal file
@@ -0,0 +1,201 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for MCP cache control functionality."""
|
||||
|
||||
from superset.mcp_service.cache_utils import (
|
||||
get_cache_status_from_result,
|
||||
should_use_metadata_cache,
|
||||
)
|
||||
from superset.mcp_service.schemas.cache_schemas import (
|
||||
CacheControlMixin,
|
||||
CacheStatus,
|
||||
FormDataCacheControl,
|
||||
MetadataCacheControl,
|
||||
QueryCacheControl,
|
||||
)
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
GetChartDataRequest,
|
||||
ListChartsRequest,
|
||||
)
|
||||
|
||||
|
||||
class TestCacheSchemas:
    """Test cache control schemas."""

    def test_cache_control_mixin_defaults(self):
        """The mixin defaults to cached, non-forced behavior."""
        defaults = CacheControlMixin()
        assert defaults.use_cache is True
        assert defaults.force_refresh is False

    def test_query_cache_control(self):
        """Query cache control accepts explicit overrides."""
        control = QueryCacheControl(
            use_cache=False,
            force_refresh=True,
            cache_timeout=3600,
        )
        assert control.use_cache is False
        assert control.force_refresh is True
        assert control.cache_timeout == 3600

    def test_metadata_cache_control(self):
        """Metadata cache control layers on top of the mixin defaults."""
        control = MetadataCacheControl(refresh_metadata=True)
        assert control.refresh_metadata is True
        assert control.use_cache is True  # inherited mixin default

    def test_form_data_cache_control(self):
        """Form-data caching can be disabled."""
        control = FormDataCacheControl(cache_form_data=False)
        assert control.cache_form_data is False

    def test_cache_status_schema(self):
        """CacheStatus round-trips every field."""
        status = CacheStatus(
            cache_hit=True,
            cache_type="query",
            cache_age_seconds=300,
            refreshed=False,
        )
        assert status.cache_hit is True
        assert status.cache_type == "query"
        assert status.cache_age_seconds == 300
        assert status.refreshed is False
class TestChartSchemasCacheControl:
    """Test cache control integration in chart schemas."""

    def test_get_chart_data_request_cache_control(self):
        """GetChartDataRequest carries all cache-control fields."""
        req = GetChartDataRequest(
            identifier=123,
            limit=50,
            use_cache=False,
            force_refresh=True,
            cache_timeout=1800,
        )
        assert req.identifier == 123
        assert req.limit == 50
        assert req.use_cache is False
        assert req.force_refresh is True
        assert req.cache_timeout == 1800

    def test_list_charts_request_metadata_cache(self):
        """ListChartsRequest carries metadata cache-control fields."""
        req = ListChartsRequest(refresh_metadata=True, use_cache=False)
        assert req.refresh_metadata is True
        assert req.use_cache is False
class TestCacheUtils:
    """Test cache utility functions."""

    def test_get_cache_status_from_result_hit(self):
        """A cached query entry maps to a query-cache hit."""
        payload = {
            "queries": [
                {
                    "is_cached": True,
                    "cache_dttm": "2023-01-01T12:00:00Z",
                    "data": [],
                }
            ]
        }

        status = get_cache_status_from_result(payload, force_refresh=False)
        assert status.cache_hit is True
        assert status.cache_type == "query"
        assert status.refreshed is False

    def test_get_cache_status_from_result_miss(self):
        """An uncached entry maps to a miss, marked refreshed when forced."""
        payload = {"queries": [{"is_cached": False, "data": []}]}

        status = get_cache_status_from_result(payload, force_refresh=True)
        assert status.cache_hit is False
        assert status.cache_type == "none"
        assert status.refreshed is True

    def test_get_cache_status_direct_result(self):
        """Results without a queries wrapper are also understood."""
        status = get_cache_status_from_result({"is_cached": True, "data": []})
        assert status.cache_hit is True
        assert status.cache_type == "query"

    def test_should_use_metadata_cache(self):
        """Metadata cache is used only when enabled and no refresh requested."""
        assert should_use_metadata_cache(use_cache=True, refresh_metadata=False) is True

        # Explicitly disabled.
        assert (
            should_use_metadata_cache(use_cache=False, refresh_metadata=False) is False
        )

        # A refresh request always bypasses the cache.
        assert should_use_metadata_cache(use_cache=True, refresh_metadata=True) is False
class TestCacheControlIntegration:
    """Test cache control integration with MCP tools."""

    def test_cache_control_in_existing_tools(self):
        """The chart tools wired for cache control are importable."""
        from superset.mcp_service.chart.tool.get_chart_data import get_chart_data
        from superset.mcp_service.chart.tool.list_charts import list_charts

        assert get_chart_data is not None
        assert list_charts is not None

    def test_cache_control_request_validation(self):
        """Cache-control parameters validate and default correctly."""
        # cache_timeout=0 is a legal way to disable caching for one query.
        zero_timeout = GetChartDataRequest(identifier="123", cache_timeout=0)
        assert zero_timeout.cache_timeout == 0

        # Without overrides, the mixin defaults apply.
        defaults = GetChartDataRequest(identifier=456)
        assert defaults.use_cache is True
        assert defaults.force_refresh is False
        assert defaults.cache_timeout is None
804
tests/unit_tests/mcp_service/test_chart_tools.py
Normal file
804
tests/unit_tests/mcp_service/test_chart_tools.py
Normal file
@@ -0,0 +1,804 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP chart tools (list_charts, get_chart_info,
|
||||
get_chart_available_filters, generate_chart)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import fastmcp.exceptions
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
|
||||
# Updated imports for new simplified schemas
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AxisConfig,
|
||||
ChartInfo,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
GenerateChartRequest,
|
||||
LegendConfig,
|
||||
ListChartsRequest,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the module-level MCP app instance to the tests below."""
    return mcp
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_basic(mock_list, mcp_server):
    """A single mocked chart round-trips through list_charts with its UUID."""
    attrs = {
        "id": 1,
        "slice_name": "Test Chart",
        "viz_type": "bar",
        "datasource_name": "test_ds",
        "datasource_type": "table",
        "url": "/chart/1",
        "description": "desc",
        "cache_timeout": 60,
        "form_data": {},
        "query_context": {},
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": "1 day ago",
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": "2 days ago",
        "tags": [],
        "owners": [],
    }
    chart = Mock()
    for field, value in attrs.items():
        setattr(chart, field, value)
    chart.uuid = "test-chart-uuid-1"
    # The row-like ._mapping deliberately omits uuid, matching the DAO shape.
    chart._mapping = dict(attrs)
    mock_list.return_value = ([chart], 1)

    async with Client(mcp_server) as client:
        payload = ListChartsRequest(page=1, page_size=10).model_dump()
        result = await client.call_tool("list_charts", {"request": payload})
        charts = result.data.charts
        assert len(charts) == 1
        assert charts[0].slice_name == "Test Chart"
        assert charts[0].uuid == "test-chart-uuid-1"
        assert charts[0].viz_type == "bar"

        # Charts have no slug, so uuid must be part of the default columns.
        assert "uuid" in result.data.columns_requested
        assert "uuid" in result.data.columns_loaded
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_with_search(mock_list, mcp_server):
    """A search term is forwarded to the DAO with the right search columns."""
    attrs = {
        "id": 1,
        "slice_name": "search_chart",
        "viz_type": "bar",
        "datasource_name": "test_ds",
        "datasource_type": "table",
        "url": "/chart/1",
        "description": "desc",
        "cache_timeout": 60,
        "form_data": {},
        "query_context": {},
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": "1 day ago",
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": "2 days ago",
        "tags": [],
        "owners": [],
    }
    chart = Mock()
    for field, value in attrs.items():
        setattr(chart, field, value)
    chart._mapping = dict(attrs)
    mock_list.return_value = ([chart], 1)

    async with Client(mcp_server) as client:
        payload = ListChartsRequest(
            search="search_chart", page=1, page_size=10
        ).model_dump()
        result = await client.call_tool("list_charts", {"request": payload})
        charts = result.data.charts
        assert len(charts) == 1
        assert charts[0].slice_name == "search_chart"

        # The DAO must have been asked to search name and description.
        _, kwargs = mock_list.call_args
        assert kwargs["search"] == "search_chart"
        assert "slice_name" in kwargs["search_columns"]
        assert "description" in kwargs["search_columns"]
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_with_filters(mock_list, mcp_server):
    """Filtered listing with an empty DAO result yields an empty response."""
    mock_list.return_value = ([], 0)

    payload = ListChartsRequest(
        filters=[
            {"col": "slice_name", "opr": "sw", "value": "Sales"},
            {"col": "viz_type", "opr": "eq", "value": "bar"},
        ],
        select_columns=["id", "slice_name"],
        order_column="changed_on",
        order_direction="desc",
        page=1,
        page_size=50,
    ).model_dump()

    async with Client(mcp_server) as client:
        result = await client.call_tool("list_charts", {"request": payload})
        assert result.data.count == 0
        assert result.data.charts == []
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_api_error(mock_list, mcp_server):
    """DAO failures propagate to the client as a ToolError."""
    mock_list.side_effect = fastmcp.exceptions.ToolError("API request failed")

    async with Client(mcp_server) as client:
        with pytest.raises(fastmcp.exceptions.ToolError) as excinfo:  # noqa: PT012
            payload = ListChartsRequest(page=1, page_size=10).model_dump()
            await client.call_tool("list_charts", {"request": payload})
        assert "API request failed" in str(excinfo.value)
@patch("superset.daos.chart.ChartDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_chart_info_success(mock_info, mcp_server):
    """get_chart_info returns serialized metadata for an existing chart."""
    fields = {
        "id": 1,
        "slice_name": "Test Chart",
        "viz_type": "bar",
        "datasource_name": "test_ds",
        "datasource_type": "table",
        "url": "/chart/1",
        "description": "desc",
        "cache_timeout": 60,
        "form_data": {},
        "query_context": {},
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": "1 day ago",
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": "2 days ago",
        "tags": [],
        "owners": [],
    }
    chart = Mock()
    for field, value in fields.items():
        setattr(chart, field, value)
    chart.to_model = lambda: ChartInfo(**fields)
    mock_info.return_value = chart  # the DAO returns the chart object itself

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_chart_info", {"request": {"identifier": 1}}
        )
        assert result.data["slice_name"] == "Test Chart"
@patch("superset.daos.chart.ChartDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_chart_info_not_found(mock_info, mcp_server):
    """An unknown chart id yields a structured not_found error."""
    mock_info.return_value = None  # the DAO returns None for a missing chart

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_chart_info", {"request": {"identifier": 999}}
        )
        assert result.data["error_type"] == "not_found"
@pytest.mark.xfail(
    reason="MCP protocol bug: dict fields named column_operators are deserialized as "
    "custom types (Column_Operators). TODO: revisit after protocol fix."
)
@pytest.mark.asyncio
async def test_get_chart_available_filters_success(mcp_server):
    """The available-filters tool exposes a dict of per-column operators."""
    async with Client(mcp_server) as client:
        result = await client.call_tool("get_chart_available_filters", {"request": {}})
        data = result.data
        assert hasattr(data, "column_operators")
        assert isinstance(data.column_operators, dict)
@pytest.mark.asyncio
async def test_get_chart_available_filters_exception_handling(mcp_server):
    """The tool completes without raising under normal operation."""
    async with Client(mcp_server) as client:
        result = await client.call_tool("get_chart_available_filters", {"request": {}})
        assert hasattr(result.data, "column_operators")
def _mock_chart(id=1, viz_type="table", form_data=None):
|
||||
from unittest.mock import Mock
|
||||
|
||||
chart = Mock()
|
||||
chart.id = id
|
||||
chart.slice_name = "Test Chart"
|
||||
chart.viz_type = viz_type
|
||||
chart.datasource_name = "test_ds"
|
||||
chart.datasource_type = "table"
|
||||
chart.url = f"/chart/{id}"
|
||||
chart.description = "desc"
|
||||
chart.cache_timeout = 60
|
||||
chart.form_data = form_data or {}
|
||||
chart.query_context = {}
|
||||
chart.changed_by_name = "admin"
|
||||
chart.changed_on = None
|
||||
chart.changed_on_humanized = "1 day ago"
|
||||
chart.created_by_name = "admin"
|
||||
chart.created_on = None
|
||||
chart.created_on_humanized = "2 days ago"
|
||||
chart.tags = []
|
||||
chart.owners = []
|
||||
chart.to_model = lambda: ChartInfo(
|
||||
id=chart.id,
|
||||
slice_name=chart.slice_name,
|
||||
viz_type=chart.viz_type,
|
||||
datasource_name=chart.datasource_name,
|
||||
datasource_type=chart.datasource_type,
|
||||
url=chart.url,
|
||||
description=chart.description,
|
||||
cache_timeout=chart.cache_timeout,
|
||||
form_data=chart.form_data,
|
||||
query_context=chart.query_context,
|
||||
changed_by_name=chart.changed_by_name,
|
||||
changed_on=chart.changed_on,
|
||||
changed_on_humanized=chart.changed_on_humanized,
|
||||
created_by_name=chart.created_by_name,
|
||||
created_on=chart.created_on,
|
||||
created_on_humanized=chart.created_on_humanized,
|
||||
tags=chart.tags,
|
||||
owners=chart.owners,
|
||||
)
|
||||
return chart
|
||||
|
||||
|
||||
def _mock_dataset(id: int = 1) -> Mock:
|
||||
"""Create a mock dataset object with all required attributes for validation."""
|
||||
dataset = Mock()
|
||||
dataset.id = id
|
||||
dataset.table_name = "test_table"
|
||||
dataset.schema = "public"
|
||||
|
||||
# Create mock database
|
||||
mock_database = Mock()
|
||||
mock_database.database_name = "test_db"
|
||||
dataset.database = mock_database
|
||||
|
||||
# Create mock columns with all required attributes
|
||||
mock_columns = []
|
||||
column_names = ["region", "sales", "date", "year", "category", "quantity"]
|
||||
for name in column_names:
|
||||
mock_col = Mock()
|
||||
mock_col.column_name = name
|
||||
mock_col.type = "varchar" if name in ["region", "category"] else "integer"
|
||||
mock_col.description = f"Mock {name} column"
|
||||
mock_col.is_dttm = name == "date"
|
||||
mock_col.python_date_format = None
|
||||
mock_col.verbose_name = None
|
||||
mock_columns.append(mock_col)
|
||||
|
||||
dataset.columns = mock_columns
|
||||
dataset.metrics = [] # No metrics for simplicity
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_table_success(mock_run, mock_find_dataset, mcp_server):
    """A table chart request produces a chart payload with no error."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=101, viz_type="table")

    table_config = TableChartConfig(
        chart_type="table",
        columns=[
            ColumnRef(name="region", label="Region"),
            ColumnRef(name="sales", label="Sales"),
        ],
        filters=[FilterConfig(column="year", op="=", value=2024)],
        sort_by=["sales"],
    )
    payload = GenerateChartRequest(dataset_id="1", config=table_config).model_dump()

    async with Client(mcp_server) as client:
        resp = await client.call_tool("generate_chart", {"request": payload})
        assert resp.data["chart"] is not None
        assert resp.data["chart"]["viz_type"] == "table"
        assert resp.data["error"] is None
        mock_run.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_xy_line_success(mock_run, mock_find_dataset, mcp_server):
    """A line chart request produces an echarts_timeseries_line chart."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=102, viz_type="echarts_timeseries_line")

    line_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="date", label="Date"),
        y=[ColumnRef(name="sales", label="Sales")],
        kind="line",
        group_by=ColumnRef(name="region", label="Region"),
        x_axis=AxisConfig(title="Date", scale="linear"),
        y_axis=AxisConfig(title="Sales", format="$,.2f"),
        legend=LegendConfig(show=True, position="right"),
        filters=[FilterConfig(column="year", op="=", value=2024)],
    )
    payload = GenerateChartRequest(dataset_id="1", config=line_config).model_dump()

    async with Client(mcp_server) as client:
        resp = await client.call_tool("generate_chart", {"request": payload})
        assert resp.data["chart"] is not None
        assert resp.data["chart"]["viz_type"] == "echarts_timeseries_line"
        assert resp.data["error"] is None
        mock_run.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_xy_bar_success(mock_run, mock_find_dataset, mcp_server):
    """A bar chart request produces an echarts_timeseries_bar chart."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=103, viz_type="echarts_timeseries_bar")

    bar_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="region", label="Region"),
        y=[ColumnRef(name="sales", label="Sales")],
        kind="bar",
        x_axis=AxisConfig(title="Region"),
        y_axis=AxisConfig(title="Sales", format="$,.2f"),
    )
    payload = GenerateChartRequest(dataset_id="1", config=bar_config).model_dump()

    async with Client(mcp_server) as client:
        resp = await client.call_tool("generate_chart", {"request": payload})
        assert resp.data["chart"] is not None
        assert resp.data["chart"]["viz_type"] == "echarts_timeseries_bar"
        assert resp.data["error"] is None
        mock_run.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_xy_area_success(mock_run, mock_find_dataset, mcp_server):
    """An area chart request produces an echarts_area chart."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=104, viz_type="echarts_area")

    area_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="date", label="Date"),
        y=[ColumnRef(name="sales", label="Sales")],
        kind="area",
        group_by=ColumnRef(name="region", label="Region"),
        x_axis=AxisConfig(title="Date", scale="linear"),
        y_axis=AxisConfig(title="Sales", format="$,.2f"),
        legend=LegendConfig(show=True, position="right"),
        filters=[FilterConfig(column="year", op="=", value=2024)],
    )
    payload = GenerateChartRequest(dataset_id="1", config=area_config).model_dump()

    async with Client(mcp_server) as client:
        resp = await client.call_tool("generate_chart", {"request": payload})
        assert resp.data["chart"] is not None
        assert resp.data["chart"]["viz_type"] == "echarts_area"
        assert resp.data["error"] is None
        mock_run.assert_called_once()
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_error(mock_run, mock_find_dataset, mcp_server):
    """A failing create command surfaces as an error payload, not an exception."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.side_effect = Exception("Chart creation failed")

    chart_config = TableChartConfig(
        chart_type="table", columns=[ColumnRef(name="region", label="Region")]
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_chart", {"request": payload})
        # The tool reports the failure inside its response body rather than
        # raising, so the assertions inspect the returned data.
        assert result.data is not None
        assert result.data.get("success") is False
        assert "Chart creation failed" in str(result.data.get("error", ""))
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_table_minimal(mock_run, mock_find_dataset, mcp_server):
    """A table chart with only the required fields is created successfully."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=105, viz_type="table")

    chart_config = TableChartConfig(
        chart_type="table", columns=[ColumnRef(name="region", label="Region")]
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "table"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_xy_minimal(mock_run, mock_find_dataset, mcp_server):
    """A line chart with only x/y/kind is created successfully."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=106, viz_type="echarts_timeseries_line")

    # Only required XY fields; columns exist in the mocked dataset.
    chart_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="date"),
        y=[ColumnRef(name="sales")],
        kind="line",
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "echarts_timeseries_line"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_with_simple_metrics(
    mock_run, mock_find_dataset, mcp_server
):
    """Multiple plain-column y metrics are accepted and the chart is created."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=107, viz_type="echarts_timeseries_bar")

    # Two bare columns on the y axis: both should be translated into
    # metric objects without requiring explicit aggregates.
    chart_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="region"),
        y=[
            ColumnRef(name="quantity"),
            ColumnRef(name="sales"),
        ],
        kind="bar",
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "echarts_timeseries_bar"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_with_sql_aggregators(
    mock_run, mock_find_dataset, mcp_server
):
    """Explicit SQL aggregates (SUM/COUNT/AVG) on y columns are accepted."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=108, viz_type="echarts_timeseries_line")

    chart_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="date"),
        y=[
            ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ColumnRef(name="quantity", aggregate="COUNT", label="Quantity Count"),
            ColumnRef(name="sales", aggregate="AVG", label="Average Sales"),
        ],
        kind="line",
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "echarts_timeseries_line"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_comprehensive_metrics(
    mock_run, mock_find_dataset, mcp_server
):
    """A kitchen-sink request mixing bare columns and SQL aggregates succeeds."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=109, viz_type="echarts_timeseries_bar")

    # Mix every supported metric flavor: a bare column, repeated columns with
    # different aggregates, plus grouping, axis config, legend and filters.
    chart_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="region"),
        y=[
            ColumnRef(name="quantity"),
            ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ColumnRef(name="sales", aggregate="AVG", label="Average Sales"),
            ColumnRef(name="quantity", aggregate="COUNT", label="Quantity Count"),
            ColumnRef(name="sales", aggregate="MAX", label="Max Sales"),
        ],
        kind="bar",
        group_by=ColumnRef(name="category"),
        x_axis=AxisConfig(title="Region", format="string"),
        y_axis=AxisConfig(title="Values", format="$,.2f"),
        legend=LegendConfig(show=True, position="top"),
        filters=[
            FilterConfig(column="year", op="=", value=2024),
            FilterConfig(column="category", op="!=", value="cancelled"),
        ],
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "echarts_timeseries_bar"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.daos.dataset.DatasetDAO.find_by_id")
@patch("superset.commands.chart.create.CreateChartCommand.run")
@pytest.mark.asyncio
async def test_generate_chart_xy_scatter_success(
    mock_run, mock_find_dataset, mcp_server
):
    """A scatter chart with grouping, axis formats and a filter is created."""
    mock_find_dataset.return_value = _mock_dataset(id=1)
    mock_run.return_value = _mock_chart(id=110, viz_type="echarts_timeseries_scatter")

    chart_config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="date"),
        y=[ColumnRef(name="sales")],
        kind="scatter",
        group_by=ColumnRef(name="category"),
        x_axis=AxisConfig(title="Date", format="smart_date"),
        y_axis=AxisConfig(title="Sales", format="SMART_NUMBER"),
        legend=LegendConfig(show=True, position="top"),
        filters=[FilterConfig(column="year", op="=", value=2024)],
    )
    payload = GenerateChartRequest(dataset_id="1", config=chart_config).model_dump()

    async with Client(mcp_server) as client:
        response = await client.call_tool("generate_chart", {"request": payload})
        data = response.data
        assert data["chart"] is not None
        assert data["chart"]["viz_type"] == "echarts_timeseries_scatter"
        assert data["error"] is None
    mock_run.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.mcp_service.generic_tools.ModelGetInfoTool._find_object")
@pytest.mark.asyncio
async def test_get_chart_info_by_uuid(mock_find_object, mcp_server):
    """Test getting chart info using a UUID identifier.

    The object lookup is mocked out; this verifies that a UUID-shaped
    identifier is accepted by the tool and that the found chart is
    serialized back in the response.
    """
    # Build a chart mock with every attribute the serializer reads.
    chart = Mock()
    chart_attrs = {
        "id": 1,
        "slice_name": "Test Chart UUID",
        "viz_type": "bar",
        "datasource_name": "test_ds",
        "datasource_type": "table",
        "url": "/chart/1",
        "description": "desc",
        "cache_timeout": 60,
        "form_data": {},
        "query_context": {},
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": "1 day ago",
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": "2 days ago",
        "tags": [],
        "owners": [],
    }
    for attr, value in chart_attrs.items():
        setattr(chart, attr, value)

    mock_find_object.return_value = chart
    async with Client(mcp_server) as client:
        # Fix: use a valid RFC 4122 (hex-only) UUID. The previous value
        # contained non-hex characters ("g"), so it would never parse as a
        # UUID and the identifier would not exercise the UUID lookup path.
        uuid_str = "b2c3d4e5-f6a7-4901-bcde-fa2345678901"
        result = await client.call_tool(
            "get_chart_info", {"request": {"identifier": uuid_str}}
        )
        assert result.data["slice_name"] == "Test Chart UUID"
|
||||
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.list")
@pytest.mark.asyncio
async def test_list_charts_custom_uuid_columns(mock_list, mcp_server):
    """Test that custom column selection includes UUID when explicitly requested."""
    # Define the attributes once and mirror them into both the mock's
    # attribute namespace and its SQLAlchemy-row-style ``_mapping``.
    chart_attrs = {
        "id": 1,
        "slice_name": "Custom Columns Chart",
        "viz_type": "bar",
        "datasource_name": "test_ds",
        "datasource_type": "table",
        "url": "/chart/1",
        "description": "desc",
        "cache_timeout": 60,
        "form_data": {},
        "query_context": {},
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": "1 day ago",
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": "2 days ago",
        "uuid": "test-custom-chart-uuid",
        "tags": [],
        "owners": [],
    }
    chart = Mock()
    for attr, value in chart_attrs.items():
        setattr(chart, attr, value)
    chart._mapping = dict(chart_attrs)

    mock_list.return_value = ([chart], 1)
    async with Client(mcp_server) as client:
        request = ListChartsRequest(
            select_columns=["id", "slice_name", "uuid"], page=1, page_size=10
        )
        result = await client.call_tool(
            "list_charts", {"request": request.model_dump()}
        )
        charts = result.data.charts
        assert len(charts) == 1
        assert charts[0].uuid == "test-custom-chart-uuid"

        # The explicitly requested uuid column must be reported back.
        assert "uuid" in result.data.columns_requested
        assert "uuid" in result.data.columns_loaded
|
||||
341
tests/unit_tests/mcp_service/test_chart_utils.py
Normal file
341
tests/unit_tests/mcp_service/test_chart_utils.py
Normal file
@@ -0,0 +1,341 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""Tests for chart utilities module"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset.mcp_service.chart.chart_utils import (
|
||||
create_metric_object,
|
||||
generate_chart_name,
|
||||
generate_explore_link,
|
||||
map_config_to_form_data,
|
||||
map_filter_operator,
|
||||
map_table_config,
|
||||
map_xy_config,
|
||||
)
|
||||
from superset.mcp_service.schemas.chart_schemas import (
|
||||
AxisConfig,
|
||||
ColumnRef,
|
||||
FilterConfig,
|
||||
LegendConfig,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
|
||||
|
||||
class TestCreateMetricObject:
    """Unit tests for create_metric_object."""

    def test_create_metric_object_with_aggregate(self) -> None:
        """An explicit aggregate and label are carried into the metric."""
        metric = create_metric_object(
            ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
        )

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "revenue"
        assert metric["label"] == "Total Revenue"
        assert metric["optionName"] == "metric_revenue"
        assert metric["expressionType"] == "SIMPLE"

    def test_create_metric_object_default_aggregate(self) -> None:
        """Without an aggregate, SUM is assumed and a label is synthesized."""
        metric = create_metric_object(ColumnRef(name="orders"))

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "orders"
        assert metric["label"] == "SUM(orders)"
        assert metric["optionName"] == "metric_orders"
|
||||
|
||||
|
||||
class TestMapFilterOperator:
    """Unit tests for map_filter_operator."""

    def test_map_filter_operators(self) -> None:
        """Known operators map to their form-data equivalents."""
        expected = {
            "=": "==",
            ">": ">",
            "<": "<",
            ">=": ">=",
            "<=": "<=",
            "!=": "!=",
        }
        for op, mapped in expected.items():
            assert map_filter_operator(op) == mapped

    def test_map_filter_operator_unknown(self) -> None:
        """Unrecognized operators pass through unchanged."""
        assert map_filter_operator("UNKNOWN") == "UNKNOWN"
|
||||
|
||||
|
||||
class TestMapTableConfig:
    """Unit tests for map_table_config."""

    def test_map_table_config_basic(self) -> None:
        """Aggregated columns become metrics in aggregate query mode."""
        table_config = TableChartConfig(
            columns=[
                ColumnRef(name="product", aggregate="COUNT"),
                ColumnRef(name="revenue", aggregate="SUM"),
            ]
        )

        form_data = map_table_config(table_config)

        assert form_data["viz_type"] == "table"
        assert form_data["query_mode"] == "aggregate"
        # Aggregated columns must land in metrics, never in all_columns.
        assert "all_columns" not in form_data
        assert len(form_data["metrics"]) == 2
        assert form_data["metrics"][0]["aggregate"] == "COUNT"
        assert form_data["metrics"][1]["aggregate"] == "SUM"

    def test_map_table_config_raw_columns(self) -> None:
        """Columns without aggregates produce a raw-mode column listing."""
        table_config = TableChartConfig(
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="category"),
            ]
        )

        form_data = map_table_config(table_config)

        assert form_data["viz_type"] == "table"
        assert form_data["query_mode"] == "raw"
        # Raw columns go to all_columns and no metrics are generated.
        assert form_data["all_columns"] == ["product", "category"]
        assert "metrics" not in form_data

    def test_map_table_config_with_filters(self) -> None:
        """Filters are converted into SIMPLE adhoc filter objects."""
        table_config = TableChartConfig(
            columns=[ColumnRef(name="product")],
            filters=[FilterConfig(column="status", op="=", value="active")],
        )

        form_data = map_table_config(table_config)

        assert "adhoc_filters" in form_data
        assert len(form_data["adhoc_filters"]) == 1
        adhoc = form_data["adhoc_filters"][0]
        assert adhoc["subject"] == "status"
        assert adhoc["operator"] == "=="
        assert adhoc["comparator"] == "active"
        assert adhoc["expressionType"] == "SIMPLE"

    def test_map_table_config_with_sort(self) -> None:
        """sort_by columns are forwarded as order_by_cols."""
        table_config = TableChartConfig(
            columns=[ColumnRef(name="product")], sort_by=["product", "revenue"]
        )

        form_data = map_table_config(table_config)
        assert form_data["order_by_cols"] == ["product", "revenue"]
|
||||
|
||||
|
||||
class TestMapXYConfig:
    """Unit tests for map_xy_config."""

    def test_map_xy_config_line_chart(self) -> None:
        """A line config maps to the timeseries-line viz with one metric."""
        xy_config = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
        )

        form_data = map_xy_config(xy_config)

        assert form_data["viz_type"] == "echarts_timeseries_line"
        assert form_data["x_axis"] == "date"
        assert len(form_data["metrics"]) == 1
        assert form_data["metrics"][0]["aggregate"] == "SUM"

    def test_map_xy_config_with_groupby(self) -> None:
        """group_by becomes a single-element groupby list."""
        xy_config = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="bar",
            group_by=ColumnRef(name="region"),
        )

        form_data = map_xy_config(xy_config)

        assert form_data["viz_type"] == "echarts_timeseries_bar"
        assert form_data["groupby"] == ["region"]

    def test_map_xy_config_with_axes(self) -> None:
        """Axis titles, formats and scale are copied into form data."""
        xy_config = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="area",
            x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
            y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
        )

        form_data = map_xy_config(xy_config)

        assert form_data["viz_type"] == "echarts_area"
        assert form_data["x_axis_title"] == "Date"
        assert form_data["x_axis_format"] == "%Y-%m-%d"
        assert form_data["y_axis_title"] == "Revenue"
        assert form_data["y_axis_format"] == "$,.2f"
        assert form_data["y_axis_scale"] == "log"

    def test_map_xy_config_with_legend(self) -> None:
        """Legend visibility and placement are translated correctly."""
        xy_config = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="scatter",
            legend=LegendConfig(show=False, position="top"),
        )

        form_data = map_xy_config(xy_config)

        assert form_data["viz_type"] == "echarts_timeseries_scatter"
        assert form_data["show_legend"] is False
        assert form_data["legend_orientation"] == "top"
|
||||
|
||||
|
||||
class TestMapConfigToFormData:
    """Unit tests for the map_config_to_form_data dispatcher."""

    def test_map_table_config_type(self) -> None:
        """Table configs are dispatched to the table mapper."""
        form_data = map_config_to_form_data(
            TableChartConfig(columns=[ColumnRef(name="test")])
        )
        assert form_data["viz_type"] == "table"

    def test_map_xy_config_type(self) -> None:
        """XY configs are dispatched to the XY mapper."""
        form_data = map_config_to_form_data(
            XYChartConfig(
                x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line"
            )
        )
        assert form_data["viz_type"] == "echarts_timeseries_line"

    def test_map_unsupported_config_type(self) -> None:
        """Anything that is not a known config type raises ValueError."""
        with pytest.raises(ValueError, match="Unsupported config type"):
            map_config_to_form_data("invalid_config")  # type: ignore
|
||||
|
||||
|
||||
class TestGenerateChartName:
    """Unit tests for generate_chart_name."""

    def test_generate_table_chart_name(self) -> None:
        """Table names list the selected columns."""
        table_config = TableChartConfig(
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="revenue"),
            ]
        )

        assert generate_chart_name(table_config) == "Table Chart - product, revenue"

    def test_generate_xy_chart_name(self) -> None:
        """XY names combine kind, x column and all y columns."""
        xy_config = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
            kind="line",
        )

        assert generate_chart_name(xy_config) == "Line Chart - date vs revenue, orders"

    def test_generate_chart_name_unsupported(self) -> None:
        """Unknown config types fall back to a generic name."""
        assert generate_chart_name("invalid_config") == "Chart"  # type: ignore
|
||||
|
||||
|
||||
class TestGenerateExploreLink:
    """Unit tests for generate_explore_link."""

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
        """The link is rooted at the configured Superset base URL."""
        mock_get_base_url.return_value = "https://superset.example.com"

        link = generate_explore_link(
            "123", {"viz_type": "table", "metrics": ["count"]}
        )

        assert link.startswith("https://superset.example.com")
        assert "/explore/?" in link
        assert "datasource_id=123" in link

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
        """When the dataset is missing, a plain datasource URL is returned."""
        mock_get_base_url.return_value = "http://localhost:8088"

        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            link = generate_explore_link("999", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:8088/explore/?datasource_type=table&datasource_id=999"
        )

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    @patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
    def test_generate_explore_link_with_form_data_key(
        self, mock_command, mock_get_base_url
    ) -> None:
        """When the dataset exists, the link carries a stored form_data_key."""
        mock_get_base_url.return_value = "http://localhost:8088"
        mock_command.return_value.run.return_value = "test_form_data_key"

        # A minimal stand-in dataset object with just an id attribute.
        fake_dataset = type("Dataset", (), {"id": 123})()
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id", return_value=fake_dataset
        ):
            link = generate_explore_link(123, {"viz_type": "table"})

        assert (
            link == "http://localhost:8088/explore/?form_data_key=test_form_data_key"
        )
        mock_command.assert_called_once()

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
        """Failures while building the form_data_key degrade to the basic URL."""
        mock_get_base_url.return_value = "http://localhost:8088"

        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id",
            side_effect=Exception("DB error"),
        ):
            link = generate_explore_link("123", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:8088/explore/?datasource_type=table&datasource_id=123"
        )
|
||||
394
tests/unit_tests/mcp_service/test_dashboard_generation.py
Normal file
394
tests/unit_tests/mcp_service/test_dashboard_generation.py
Normal file
@@ -0,0 +1,394 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for dashboard generation MCP tools
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.utils import json
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the shared MCP application instance to the tests."""
    return mcp
|
||||
|
||||
|
||||
def _mock_chart(id: int = 1, slice_name: str = "Test Chart") -> Mock:
|
||||
"""Create a mock chart object."""
|
||||
chart = Mock()
|
||||
chart.id = id
|
||||
chart.slice_name = slice_name
|
||||
chart.uuid = f"chart-uuid-{id}"
|
||||
return chart
|
||||
|
||||
|
||||
def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock:
|
||||
"""Create a mock dashboard object."""
|
||||
dashboard = Mock()
|
||||
dashboard.id = id
|
||||
dashboard.dashboard_title = title
|
||||
dashboard.slug = f"test-dashboard-{id}"
|
||||
dashboard.description = "Test dashboard description"
|
||||
dashboard.published = True
|
||||
dashboard.created_on = "2024-01-01"
|
||||
dashboard.changed_on = "2024-01-01"
|
||||
dashboard.created_by = Mock()
|
||||
dashboard.created_by.username = "test_user"
|
||||
dashboard.changed_by = Mock()
|
||||
dashboard.changed_by.username = "test_user"
|
||||
dashboard.uuid = f"dashboard-uuid-{id}"
|
||||
dashboard.slices = []
|
||||
dashboard.owners = [] # Add missing owners attribute
|
||||
dashboard.tags = [] # Add missing tags attribute
|
||||
return dashboard
|
||||
|
||||
|
||||
class TestGenerateDashboard:
|
||||
"""Tests for generate_dashboard MCP tool."""
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_basic(
|
||||
self, mock_find_chart, mock_create_command, mcp_server
|
||||
):
|
||||
"""Test basic dashboard generation with valid charts."""
|
||||
# Mock charts exist
|
||||
mock_find_chart.side_effect = [
|
||||
_mock_chart(id=1, slice_name="Sales Chart"),
|
||||
_mock_chart(id=2, slice_name="Revenue Chart"),
|
||||
]
|
||||
|
||||
# Mock dashboard creation
|
||||
mock_dashboard = _mock_dashboard(id=10, title="Analytics Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
|
||||
request = {
|
||||
"chart_ids": [1, 2],
|
||||
"dashboard_title": "Analytics Dashboard",
|
||||
"description": "Dashboard for analytics",
|
||||
"published": True,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.data.error is None
|
||||
assert result.data.dashboard is not None
|
||||
assert result.data.dashboard.id == 10
|
||||
assert result.data.dashboard.dashboard_title == "Analytics Dashboard"
|
||||
assert result.data.dashboard.chart_count == 2
|
||||
assert "/superset/dashboard/10/" in result.data.dashboard_url
|
||||
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_missing_charts(self, mock_find_chart, mcp_server):
|
||||
"""Test error handling when some charts don't exist."""
|
||||
# First chart exists, second doesn't
|
||||
mock_find_chart.side_effect = [
|
||||
_mock_chart(id=1),
|
||||
None, # Chart 2 doesn't exist
|
||||
]
|
||||
|
||||
request = {"chart_ids": [1, 2], "dashboard_title": "Test Dashboard"}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.data.error is not None
|
||||
assert "Charts not found: [2]" in result.data.error
|
||||
assert result.data.dashboard is None
|
||||
assert result.data.dashboard_url is None
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
|
||||
@patch("superset.daos.chart.ChartDAO.find_by_id")
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_dashboard_single_chart(
|
||||
self, mock_find_chart, mock_create_command, mcp_server
|
||||
):
|
||||
"""Test dashboard generation with a single chart."""
|
||||
mock_find_chart.return_value = _mock_chart(id=5, slice_name="Single Chart")
|
||||
|
||||
mock_dashboard = _mock_dashboard(id=20, title="Single Chart Dashboard")
|
||||
mock_create_command.return_value.run.return_value = mock_dashboard
|
||||
|
||||
request = {
|
||||
"chart_ids": [5],
|
||||
"dashboard_title": "Single Chart Dashboard",
|
||||
"published": False,
|
||||
}
|
||||
|
||||
async with Client(mcp_server) as client:
|
||||
result = await client.call_tool("generate_dashboard", {"request": request})
|
||||
|
||||
assert result.data.error is None
|
||||
assert result.data.dashboard.chart_count == 1
|
||||
assert result.data.dashboard.published is True # From mock
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.daos.chart.ChartDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_dashboard_many_charts(
    self, mock_find_chart, mock_create_command, mcp_server
):
    """Six charts should all be placed in the generated grid layout."""
    chart_ids = list(range(1, 7))
    mock_find_chart.side_effect = [
        _mock_chart(id=cid, slice_name=f"Chart {cid}") for cid in chart_ids
    ]
    mock_create_command.return_value.run.return_value = _mock_dashboard(
        id=30, title="Multi Chart Dashboard"
    )

    payload = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"}

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_dashboard", {"request": payload})

    assert result.data.error is None
    assert result.data.dashboard.chart_count == 6

    # The create command receives a position_json describing the full grid.
    mock_create_command.assert_called_once()
    command_payload = mock_create_command.call_args[0][0]
    layout = json.loads(command_payload["position_json"])

    assert "ROOT_ID" in layout
    assert len(layout["ROOT_ID"]["children"]) == 6

    # Every chart node carries its geometry in its meta section.
    for cid in chart_ids:
        node_key = f"CHART-{cid}"
        assert node_key in layout
        assert "meta" in layout[node_key]
        meta = layout[node_key]["meta"]
        for geometry_key in ("h", "w", "x", "y"):  # height, width, x, y
            assert geometry_key in meta
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.daos.chart.ChartDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_dashboard_creation_failure(
    self, mock_find_chart, mock_create_command, mcp_server
):
    """A failing create command surfaces as an error payload, not a crash."""
    mock_find_chart.return_value = _mock_chart(id=1)
    mock_create_command.return_value.run.side_effect = Exception("Creation failed")

    payload = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"}

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_dashboard", {"request": payload})

    data = result.data
    assert data.error is not None
    assert "Failed to create dashboard" in data.error
    assert data.dashboard is None
|
||||
|
||||
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.daos.chart.ChartDAO.find_by_id")
@pytest.mark.asyncio
async def test_generate_dashboard_minimal_request(
    self, mock_find_chart, mock_create_command, mcp_server
):
    """Only chart_ids and a title are required; defaults fill in the rest."""
    mock_find_chart.return_value = _mock_chart(id=3)
    mock_create_command.return_value.run.return_value = _mock_dashboard(
        id=40, title="Minimal Dashboard"
    )

    # No description; `published` is omitted so it should default to True.
    payload = {"chart_ids": [3], "dashboard_title": "Minimal Dashboard"}

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_dashboard", {"request": payload})

    assert result.data.error is None
    assert result.data.dashboard.dashboard_title == "Minimal Dashboard"

    command_payload = mock_create_command.call_args[0][0]
    assert command_payload["published"] is True  # default value applied
    # A description must not be forwarded when the caller never set one.
    assert command_payload.get("description") is None
|
||||
|
||||
|
||||
class TestAddChartToExistingDashboard:
    """Tests for the add_chart_to_existing_dashboard MCP tool."""

    @patch("superset.commands.dashboard.update.UpdateDashboardCommand")
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.daos.chart.ChartDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_to_dashboard_basic(
        self, mock_find_chart, mock_find_dashboard, mock_update_command, mcp_server
    ):
        """Adding a chart appends it below the existing two-column layout."""
        existing_layout = {
            "ROOT_ID": {
                "children": ["CHART-10", "CHART-20"],
                "id": "ROOT_ID",
                "type": "ROOT",
            },
            "CHART-10": {"id": "CHART-10", "type": "CHART", "parents": ["ROOT_ID"]},
            "CHART-10_POSITION": {"h": 16, "w": 24, "x": 0, "y": 0},
            "CHART-20": {"id": "CHART-20", "type": "CHART", "parents": ["ROOT_ID"]},
            "CHART-20_POSITION": {"h": 16, "w": 24, "x": 24, "y": 0},
        }
        dashboard = _mock_dashboard(id=1, title="Existing Dashboard")
        dashboard.slices = [Mock(id=10), Mock(id=20)]  # charts already placed
        dashboard.position_json = json.dumps(existing_layout)
        mock_find_dashboard.return_value = dashboard

        mock_find_chart.return_value = _mock_chart(id=30, slice_name="New Chart")

        # The update command returns the dashboard with the third chart attached.
        refreshed = _mock_dashboard(id=1, title="Existing Dashboard")
        refreshed.slices = [Mock(id=10), Mock(id=20), Mock(id=30)]
        mock_update_command.return_value.run.return_value = refreshed

        payload = {"dashboard_id": 1, "chart_id": 30}

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard", {"request": payload}
            )

        data = result.data
        assert data.error is None
        assert data.dashboard is not None
        assert data.dashboard.chart_count == 3
        # The new chart is dropped into the next free row: (x=0, y=16).
        assert data.position is not None
        assert data.position.get("x") == 0
        assert data.position.get("y") == 16
        assert "/superset/dashboard/1/" in data.dashboard_url

    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_dashboard_not_found(self, mock_find_dashboard, mcp_server):
        """A missing dashboard id yields a not-found error message."""
        mock_find_dashboard.return_value = None

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard",
                {"request": {"dashboard_id": 999, "chart_id": 1}},
            )

        assert result.data.error is not None
        assert "Dashboard with ID 999 not found" in result.data.error

    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.daos.chart.ChartDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_chart_not_found(
        self, mock_find_chart, mock_find_dashboard, mcp_server
    ):
        """A missing chart id yields a not-found error message."""
        mock_find_dashboard.return_value = _mock_dashboard()
        mock_find_chart.return_value = None

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard",
                {"request": {"dashboard_id": 1, "chart_id": 999}},
            )

        assert result.data.error is not None
        assert "Chart with ID 999 not found" in result.data.error

    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.daos.chart.ChartDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_already_in_dashboard(
        self, mock_find_chart, mock_find_dashboard, mcp_server
    ):
        """Re-adding a chart that is already present is rejected."""
        dashboard = _mock_dashboard()
        dashboard.slices = [Mock(id=5)]  # chart 5 already present
        mock_find_dashboard.return_value = dashboard
        mock_find_chart.return_value = _mock_chart(id=5)

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard",
                {"request": {"dashboard_id": 1, "chart_id": 5}},
            )

        assert result.data.error is not None
        assert "Chart 5 is already in dashboard 1" in result.data.error

    @patch("superset.commands.dashboard.update.UpdateDashboardCommand")
    @patch("superset.daos.dashboard.DashboardDAO.find_by_id")
    @patch("superset.daos.chart.ChartDAO.find_by_id")
    @pytest.mark.asyncio
    async def test_add_chart_empty_dashboard(
        self, mock_find_chart, mock_find_dashboard, mock_update_command, mcp_server
    ):
        """Adding to a dashboard with no layout creates ROOT_ID and starts at (0, 0)."""
        dashboard = _mock_dashboard(id=2)
        dashboard.slices = []
        dashboard.position_json = "{}"  # no existing layout
        mock_find_dashboard.return_value = dashboard

        mock_find_chart.return_value = _mock_chart(id=15)

        refreshed = _mock_dashboard(id=2)
        refreshed.slices = [Mock(id=15)]
        mock_update_command.return_value.run.return_value = refreshed

        async with Client(mcp_server) as client:
            result = await client.call_tool(
                "add_chart_to_existing_dashboard",
                {"request": {"dashboard_id": 2, "chart_id": 15}},
            )

        assert result.data.error is None
        assert result.data.position.get("x") == 0
        assert result.data.position.get("y") == 0

        # The update command must have been handed a layout with ROOT_ID created.
        layout = json.loads(mock_update_command.call_args[0][1]["position_json"])
        assert "ROOT_ID" in layout
        assert layout["ROOT_ID"]["children"] == ["CHART-15"]
|
||||
523
tests/unit_tests/mcp_service/test_dashboard_tools.py
Normal file
523
tests/unit_tests/mcp_service/test_dashboard_tools.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
"""
|
||||
Unit tests for MCP dashboard tools (list_dashboards, get_dashboard_info,
|
||||
get_dashboard_available_filters)
|
||||
"""
|
||||
|
||||
import logging
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastmcp import Client
|
||||
from fastmcp.exceptions import ToolError
|
||||
|
||||
from superset.mcp_service.mcp_app import mcp
|
||||
from superset.mcp_service.schemas.dashboard_schemas import (
|
||||
ListDashboardsRequest,
|
||||
)
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture
def mcp_server():
    """Expose the module-level FastMCP app to each test."""
    return mcp
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_basic(mock_list, mcp_server):
    """A single mocked dashboard is listed with uuid/slug in the default columns."""
    attrs = {
        "id": 1,
        "dashboard_title": "Test Dashboard",
        "slug": "test-dashboard",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "slices": [],
        "description": None,
        "css": None,
        "certified_by": None,
        "certification_details": None,
        "json_metadata": None,
        "position_json": None,
        "is_managed_externally": False,
        "external_url": None,
        "uuid": "test-dashboard-uuid-1",
        "thumbnail_url": None,
        "roles": [],
        "charts": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)
    # Row-style mapping mirroring what a SQLAlchemy result row would expose.
    row_keys = (
        "id", "dashboard_title", "slug", "url", "published",
        "changed_by_name", "changed_on", "changed_on_humanized",
        "created_by_name", "created_on", "created_on_humanized",
        "tags", "owners",
    )
    dashboard._mapping = {key: attrs[key] for key in row_keys}
    dashboard._mapping["charts"] = []
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(page=1, page_size=10)
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )

    listed = result.data.dashboards
    assert len(listed) == 1
    assert listed[0].dashboard_title == "Test Dashboard"
    assert listed[0].uuid == "test-dashboard-uuid-1"
    assert listed[0].slug == "test-dashboard"
    assert listed[0].published is True

    # uuid and slug must be part of the default column set.
    for column in ("uuid", "slug"):
        assert column in result.data.columns_requested
        assert column in result.data.columns_loaded
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_filters(mock_list, mcp_server):
    """Filters, column selection, ordering and paging are accepted together."""
    attrs = {
        "id": 1,
        "dashboard_title": "Filtered Dashboard",
        "slug": "filtered-dashboard",
        "url": "/dashboard/2",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "slices": [],
        "description": None,
        "css": None,
        "certified_by": None,
        "certification_details": None,
        "json_metadata": None,
        "position_json": None,
        "is_managed_externally": False,
        "external_url": None,
        "uuid": None,
        "thumbnail_url": None,
        "roles": [],
        "charts": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)
    # Row-style mapping mirroring what a SQLAlchemy result row would expose.
    row_keys = (
        "id", "dashboard_title", "slug", "url", "published",
        "changed_by_name", "changed_on", "changed_on_humanized",
        "created_by_name", "created_on", "created_on_humanized",
        "tags", "owners",
    )
    dashboard._mapping = {key: attrs[key] for key in row_keys}
    dashboard._mapping["charts"] = []
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            filters=[
                {"col": "dashboard_title", "opr": "sw", "value": "Sales"},
                {"col": "published", "opr": "eq", "value": True},
            ],
            select_columns=["id", "dashboard_title"],
            order_column="changed_on",
            order_direction="desc",
            page=1,
            page_size=50,
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
    assert result.data.count == 1
    assert result.data.dashboards[0].dashboard_title == "Filtered Dashboard"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_string_filters(mock_list, mcp_server):
    """A JSON string is not an accepted filters type; the schema rejects it."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:  # noqa: F841
        raw_filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]'
        # Validation happens at the Pydantic schema level, before any tool call.
        with pytest.raises(ValueError, match="validation error"):
            ListDashboardsRequest(filters=raw_filters)  # noqa: F841
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_api_error(mock_list, mcp_server):
    """A ToolError raised by the DAO propagates to the MCP client."""
    mock_list.side_effect = ToolError("API request failed")
    async with Client(mcp_server) as client:
        with pytest.raises(ToolError) as excinfo:  # noqa: PT012
            request = ListDashboardsRequest()
            await client.call_tool("list_dashboards", {"request": request.model_dump()})
        assert "API request failed" in str(excinfo.value)
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_search(mock_list, mcp_server):
    """Free-text search is forwarded to the DAO with the searchable columns."""
    attrs = {
        "id": 1,
        "dashboard_title": "search_dashboard",
        "slug": "search-dashboard",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "slices": [],
        "description": None,
        "css": None,
        "certified_by": None,
        "certification_details": None,
        "json_metadata": None,
        "position_json": None,
        "is_managed_externally": False,
        "external_url": None,
        "uuid": None,
        "thumbnail_url": None,
        "roles": [],
        "charts": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)
    # Row-style mapping mirroring what a SQLAlchemy result row would expose.
    row_keys = (
        "id", "dashboard_title", "slug", "url", "published",
        "changed_by_name", "changed_on", "changed_on_humanized",
        "created_by_name", "created_on", "created_on_humanized",
        "tags", "owners",
    )
    dashboard._mapping = {key: attrs[key] for key in row_keys}
    dashboard._mapping["charts"] = []
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(search="search_dashboard")
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
    assert result.data.count == 1
    assert result.data.dashboards[0].dashboard_title == "search_dashboard"

    # The DAO receives both the search term and the columns to search.
    _, dao_kwargs = mock_list.call_args
    assert dao_kwargs["search"] == "search_dashboard"
    assert "dashboard_title" in dao_kwargs["search_columns"]
    assert "slug" in dao_kwargs["search_columns"]
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_simple_filters(mock_list, mcp_server):
    """Plain equality filters round-trip through the tool without error."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            filters=[
                {"col": "dashboard_title", "opr": "eq", "value": "Sales"},
                {"col": "published", "opr": "eq", "value": True},
            ]
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
    assert hasattr(result.data, "count")
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_success(mock_info, mcp_server):
    """Looking a dashboard up by id returns its serialized details."""
    attrs = {
        "id": 1,
        "dashboard_title": "Test Dashboard",
        "slug": "test-dashboard",
        "description": "Test description",
        "css": None,
        "certified_by": None,
        "certification_details": None,
        "json_metadata": None,
        "position_json": None,
        "published": True,
        "is_managed_externally": False,
        "external_url": None,
        "created_on": None,
        "changed_on": None,
        "created_by": None,
        "changed_by": None,
        "uuid": None,
        "url": "/dashboard/1",
        "thumbnail_url": None,
        "created_on_humanized": None,
        "changed_on_humanized": None,
        "slices": [],
        "owners": [],
        "tags": [],
        "roles": [],
        "charts": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)
    # NOTE(review): changed_by_name / created_by_name are never assigned above,
    # so the two mapping values below are auto-created child Mocks — confirm the
    # info path does not serialize them.
    dashboard._mapping = {
        "id": dashboard.id,
        "dashboard_title": dashboard.dashboard_title,
        "slug": dashboard.slug,
        "url": dashboard.url,
        "published": dashboard.published,
        "changed_by_name": dashboard.changed_by_name,
        "changed_on": dashboard.changed_on,
        "changed_on_humanized": dashboard.changed_on_humanized,
        "created_by_name": dashboard.created_by_name,
        "created_on": dashboard.created_on,
        "created_on_humanized": dashboard.created_on_humanized,
        "tags": dashboard.tags,
        "owners": dashboard.owners,
        "charts": [],
    }
    mock_info.return_value = dashboard  # DAO finds the dashboard

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": 1}}
        )
    assert result.data["dashboard_title"] == "Test Dashboard"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_not_found(mock_info, mcp_server):
    """An unknown id maps to a structured not_found error."""
    mock_info.return_value = None  # DAO lookup misses
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": 999}}
        )
    assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_get_dashboard_info_access_denied(mock_info, mcp_server):
    """Access denial is indistinguishable from a miss: the DAO returns None."""
    mock_info.return_value = None
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": 1}}
        )
    assert result.data["error_type"] == "not_found"
|
||||
|
||||
|
||||
@pytest.mark.xfail(
    reason="MCP protocol bug: dict fields named column_operators are deserialized as "
    "custom types (Column_Operators). TODO: revisit after protocol fix."
)
@pytest.mark.asyncio
async def test_get_dashboard_available_filters_success(mcp_server):
    """column_operators should come back as a plain dict once the protocol is fixed."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_available_filters", {"request": {}}
        )
    assert hasattr(result.data, "column_operators")
    assert isinstance(result.data.column_operators, dict)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_get_dashboard_available_filters_exception_handling(mcp_server):
    """The filters endpoint responds without raising in normal operation."""
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_available_filters", {"request": {}}
        )
    assert hasattr(result.data, "column_operators")
|
||||
|
||||
|
||||
@patch("superset.mcp_service.generic_tools.ModelGetInfoTool._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
    """A UUID-shaped identifier resolves through _find_object."""
    attrs = {
        "id": 1,
        "dashboard_title": "Test Dashboard UUID",
        "slug": "test-dashboard-uuid",
        "description": "Test description",
        "css": "",
        "certified_by": None,
        "certification_details": None,
        "json_metadata": "{}",
        "position_json": "{}",
        "published": True,
        "is_managed_externally": False,
        "external_url": None,
        "created_on": None,
        "changed_on": None,
        "created_by": None,
        "changed_by": None,
        # NOTE(review): contains 'g'/'h', so this is not a valid hex UUID; it only
        # needs to look identifier-like here because _find_object is patched.
        "uuid": "c3d4e5f6-g7h8-9012-cdef-gh3456789012",
        "url": "/dashboard/1",
        "thumbnail_url": None,
        "created_on_humanized": "2 days ago",
        "changed_on_humanized": "1 day ago",
        "slices": [],
        "owners": [],
        "tags": [],
        "roles": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info",
            {"request": {"identifier": "c3d4e5f6-g7h8-9012-cdef-gh3456789012"}},
        )
    assert result.data["dashboard_title"] == "Test Dashboard UUID"
|
||||
|
||||
|
||||
@patch("superset.mcp_service.generic_tools.ModelGetInfoTool._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
    """A slug identifier resolves through _find_object."""
    attrs = {
        "id": 2,
        "dashboard_title": "Test Dashboard Slug",
        "slug": "test-dashboard-slug",
        "description": "Test description",
        "css": "",
        "certified_by": None,
        "certification_details": None,
        "json_metadata": "{}",
        "position_json": "{}",
        "published": True,
        "is_managed_externally": False,
        "external_url": None,
        "created_on": None,
        "changed_on": None,
        "created_by": None,
        "changed_by": None,
        "uuid": "d4e5f6g7-h8i9-0123-defg-hi4567890123",
        "url": "/dashboard/2",
        "thumbnail_url": None,
        "created_on_humanized": "2 days ago",
        "changed_on_humanized": "1 day ago",
        "slices": [],
        "owners": [],
        "tags": [],
        "roles": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": "test-dashboard-slug"}}
        )
    assert result.data["dashboard_title"] == "Test Dashboard Slug"
|
||||
|
||||
|
||||
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
    """Explicitly selecting uuid/slug keeps them in the loaded column set."""
    attrs = {
        "id": 1,
        "dashboard_title": "Custom Columns Dashboard",
        "slug": "custom-dashboard",
        "uuid": "test-custom-uuid-123",
        "url": "/dashboard/1",
        "published": True,
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "slices": [],
        "description": None,
        "css": None,
        "certified_by": None,
        "certification_details": None,
        "json_metadata": None,
        "position_json": None,
        "is_managed_externally": False,
        "external_url": None,
        "thumbnail_url": None,
        "roles": [],
        "charts": [],
    }
    dashboard = Mock()
    for attr_name, attr_value in attrs.items():
        setattr(dashboard, attr_name, attr_value)
    # Row-style mapping; this variant also carries uuid among the row columns.
    row_keys = (
        "id", "dashboard_title", "slug", "uuid", "url", "published",
        "changed_by_name", "changed_on", "changed_on_humanized",
        "created_by_name", "created_on", "created_on_humanized",
        "tags", "owners",
    )
    dashboard._mapping = {key: attrs[key] for key in row_keys}
    dashboard._mapping["charts"] = []
    mock_list.return_value = ([dashboard], 1)

    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            select_columns=["id", "dashboard_title", "uuid", "slug"],
            page=1,
            page_size=10,
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
    listed = result.data.dashboards
    assert len(listed) == 1
    assert listed[0].uuid == "test-custom-uuid-123"
    assert listed[0].slug == "custom-dashboard"

    # uuid and slug were explicitly requested and must be reported back.
    for column in ("uuid", "slug"):
        assert column in result.data.columns_requested
        assert column in result.data.columns_loaded
|
||||
1158
tests/unit_tests/mcp_service/test_dataset_tools.py
Normal file
1158
tests/unit_tests/mcp_service/test_dataset_tools.py
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user