From fee4e7d8e2920d8a753d9cf897c6c7a9e2628c8e Mon Sep 17 00:00:00 2001 From: Amin Ghadersohi Date: Sat, 1 Nov 2025 02:33:21 +1100 Subject: [PATCH] feat(mcp): MCP service implementation (PRs 3-9 consolidated) (#35877) --- .devcontainer/README.md | 11 - .devcontainer/default/devcontainer.json | 19 + .devcontainer/devcontainer-base.json | 39 + .devcontainer/setup-dev.sh | 80 +- .devcontainer/start-superset.sh | 85 +- .devcontainer/with-mcp/devcontainer.json | 29 + docker/docker-bootstrap.sh | 4 + superset/mcp_service/CLAUDE.md | 523 +++++ superset/mcp_service/app.py | 87 +- superset/mcp_service/auth.py | 106 +- superset/mcp_service/chart/chart_utils.py | 484 ++++ superset/mcp_service/chart/preview_utils.py | 561 +++++ .../mcp_service/chart/prompts/__init__.py | 21 + .../chart/prompts/create_chart_guided.py | 195 ++ .../mcp_service/chart/resources/__init__.py | 21 + .../chart/resources/chart_configs.py | 362 +++ superset/mcp_service/chart/schemas.py | 891 ++++++- superset/mcp_service/chart/tool/__init__.py | 12 + .../mcp_service/chart/tool/generate_chart.py | 462 ++++ .../chart/tool/get_chart_available_filters.py | 50 + .../mcp_service/chart/tool/get_chart_data.py | 656 ++++++ .../mcp_service/chart/tool/get_chart_info.py | 8 + .../chart/tool/get_chart_preview.py | 2082 +++++++++++++++++ .../mcp_service/chart/tool/list_charts.py | 11 +- .../mcp_service/chart/tool/update_chart.py | 224 ++ .../chart/tool/update_chart_preview.py | 168 ++ .../mcp_service/chart/validation/__init__.py | 24 + .../chart/validation/dataset_validator.py | 329 +++ .../mcp_service/chart/validation/pipeline.py | 325 +++ .../chart/validation/runtime/__init__.py | 185 ++ .../runtime/cardinality_validator.py | 195 ++ .../runtime/chart_type_suggester.py | 437 ++++ .../validation/runtime/format_validator.py | 225 ++ .../chart/validation/schema_validator.py | 307 +++ .../{system => commands}/__init__.py | 0 .../mcp_service/commands/create_form_data.py | 33 + superset/mcp_service/common/error_schemas.py | 
103 + .../mcp_service/dashboard/prompts/__init__.py | 21 + .../dashboard/resources/__init__.py | 21 + superset/mcp_service/dashboard/schemas.py | 470 ++++ .../mcp_service/dashboard/tool/__init__.py | 30 + .../tool/add_chart_to_existing_dashboard.py | 282 +++ .../dashboard/tool/generate_dashboard.py | 241 ++ .../tool/get_dashboard_available_filters.py | 49 + .../dashboard/tool/get_dashboard_info.py | 101 + .../dashboard/tool/list_dashboards.py | 99 + .../mcp_service/dataset/prompts/__init__.py | 21 + .../mcp_service/dataset/resources/__init__.py | 21 + superset/mcp_service/dataset/schemas.py | 349 +++ superset/mcp_service/dataset/tool/__init__.py | 26 + .../tool/get_dataset_available_filters.py | 49 + .../dataset/tool/get_dataset_info.py | 111 + .../mcp_service/dataset/tool/list_datasets.py | 149 ++ superset/mcp_service/explore/__init__.py | 18 + superset/mcp_service/explore/tool/__init__.py | 29 + .../explore/tool/generate_explore_link.py | 111 + superset/mcp_service/flask_singleton.py | 16 +- superset/mcp_service/mcp_config.py | 81 +- superset/mcp_service/mcp_core.py | 209 +- superset/mcp_service/middleware.py | 749 +++++- superset/mcp_service/screenshot/__init__.py | 35 + .../screenshot/pooled_screenshot.py | 483 ++++ .../screenshot/webdriver_config.py | 139 ++ .../mcp_service/screenshot/webdriver_pool.py | 433 ++++ superset/mcp_service/server.py | 2 + superset/mcp_service/sql_lab/__init__.py | 16 + .../mcp_service/sql_lab/execute_sql_core.py | 221 ++ superset/mcp_service/sql_lab/schemas.py | 109 + superset/mcp_service/sql_lab/sql_lab_utils.py | 243 ++ superset/mcp_service/sql_lab/tool/__init__.py | 30 + .../mcp_service/sql_lab/tool/execute_sql.py | 94 + .../sql_lab/tool/open_sql_lab_with_context.py | 118 + .../mcp_service/system/prompts/__init__.py | 21 + .../mcp_service/system/prompts/quickstart.py | 94 + .../mcp_service/system/resources/__init__.py | 21 + .../system/resources/instance_metadata.py | 107 + superset/mcp_service/system/schemas.py | 113 +- 
superset/mcp_service/system/tool/__init__.py | 8 + .../system/tool/get_superset_instance_info.py | 268 +++ .../mcp_service/system/tool/health_check.py | 5 +- superset/mcp_service/utils/cache_utils.py | 143 ++ superset/mcp_service/utils/error_builder.py | 369 +++ .../mcp_service/utils/permissions_utils.py | 313 +++ superset/mcp_service/utils/retry_utils.py | 341 +++ superset/mcp_service/utils/url_utils.py | 128 + .../mcp_service/chart/test_chart_schemas.py | 160 ++ .../mcp_service/chart/test_chart_utils.py | 465 ++++ .../chart/tool/test_generate_chart.py | 268 +++ .../chart/tool/test_get_chart_preview.py | 290 +++ .../chart/tool/test_update_chart.py | 385 +++ .../chart/tool/test_update_chart_preview.py | 474 ++++ tests/unit_tests/mcp_service/conftest.py | 23 + .../mcp_service/dashboard/__init__.py | 16 + .../mcp_service/dashboard/tool/__init__.py | 16 + .../tool/test_dashboard_generation.py | 450 ++++ .../dashboard/tool/test_dashboard_tools.py | 573 +++++ .../mcp_service/dataset/__init__.py | 16 + .../mcp_service/dataset/tool/__init__.py | 16 + .../dataset/tool/test_dataset_tools.py | 1231 ++++++++++ .../mcp_service/explore/__init__.py | 16 + .../mcp_service/explore/tool/__init__.py | 16 + .../tool/test_generate_explore_link.py | 580 +++++ .../mcp_service/sql_lab/__init__.py | 16 + .../sql_lab/test_execute_sql_helper.py | 64 + .../mcp_service/sql_lab/tool/__init__.py | 16 + .../sql_lab/tool/test_execute_sql.py | 497 ++++ 106 files changed, 21826 insertions(+), 223 deletions(-) create mode 100644 .devcontainer/default/devcontainer.json create mode 100644 .devcontainer/devcontainer-base.json create mode 100644 .devcontainer/with-mcp/devcontainer.json create mode 100644 superset/mcp_service/CLAUDE.md create mode 100644 superset/mcp_service/chart/chart_utils.py create mode 100644 superset/mcp_service/chart/preview_utils.py create mode 100644 superset/mcp_service/chart/prompts/__init__.py create mode 100644 superset/mcp_service/chart/prompts/create_chart_guided.py create 
mode 100644 superset/mcp_service/chart/resources/__init__.py create mode 100644 superset/mcp_service/chart/resources/chart_configs.py create mode 100644 superset/mcp_service/chart/tool/generate_chart.py create mode 100644 superset/mcp_service/chart/tool/get_chart_available_filters.py create mode 100644 superset/mcp_service/chart/tool/get_chart_data.py create mode 100644 superset/mcp_service/chart/tool/get_chart_preview.py create mode 100644 superset/mcp_service/chart/tool/update_chart.py create mode 100644 superset/mcp_service/chart/tool/update_chart_preview.py create mode 100644 superset/mcp_service/chart/validation/__init__.py create mode 100644 superset/mcp_service/chart/validation/dataset_validator.py create mode 100644 superset/mcp_service/chart/validation/pipeline.py create mode 100644 superset/mcp_service/chart/validation/runtime/__init__.py create mode 100644 superset/mcp_service/chart/validation/runtime/cardinality_validator.py create mode 100644 superset/mcp_service/chart/validation/runtime/chart_type_suggester.py create mode 100644 superset/mcp_service/chart/validation/runtime/format_validator.py create mode 100644 superset/mcp_service/chart/validation/schema_validator.py rename superset/mcp_service/{system => commands}/__init__.py (100%) create mode 100644 superset/mcp_service/commands/create_form_data.py create mode 100644 superset/mcp_service/common/error_schemas.py create mode 100644 superset/mcp_service/dashboard/prompts/__init__.py create mode 100644 superset/mcp_service/dashboard/resources/__init__.py create mode 100644 superset/mcp_service/dashboard/schemas.py create mode 100644 superset/mcp_service/dashboard/tool/__init__.py create mode 100644 superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py create mode 100644 superset/mcp_service/dashboard/tool/generate_dashboard.py create mode 100644 superset/mcp_service/dashboard/tool/get_dashboard_available_filters.py create mode 100644 
superset/mcp_service/dashboard/tool/get_dashboard_info.py create mode 100644 superset/mcp_service/dashboard/tool/list_dashboards.py create mode 100644 superset/mcp_service/dataset/prompts/__init__.py create mode 100644 superset/mcp_service/dataset/resources/__init__.py create mode 100644 superset/mcp_service/dataset/schemas.py create mode 100644 superset/mcp_service/dataset/tool/__init__.py create mode 100644 superset/mcp_service/dataset/tool/get_dataset_available_filters.py create mode 100644 superset/mcp_service/dataset/tool/get_dataset_info.py create mode 100644 superset/mcp_service/dataset/tool/list_datasets.py create mode 100644 superset/mcp_service/explore/__init__.py create mode 100644 superset/mcp_service/explore/tool/__init__.py create mode 100644 superset/mcp_service/explore/tool/generate_explore_link.py create mode 100644 superset/mcp_service/screenshot/__init__.py create mode 100644 superset/mcp_service/screenshot/pooled_screenshot.py create mode 100644 superset/mcp_service/screenshot/webdriver_config.py create mode 100644 superset/mcp_service/screenshot/webdriver_pool.py create mode 100644 superset/mcp_service/sql_lab/__init__.py create mode 100644 superset/mcp_service/sql_lab/execute_sql_core.py create mode 100644 superset/mcp_service/sql_lab/schemas.py create mode 100644 superset/mcp_service/sql_lab/sql_lab_utils.py create mode 100644 superset/mcp_service/sql_lab/tool/__init__.py create mode 100644 superset/mcp_service/sql_lab/tool/execute_sql.py create mode 100644 superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py create mode 100644 superset/mcp_service/system/prompts/__init__.py create mode 100644 superset/mcp_service/system/prompts/quickstart.py create mode 100644 superset/mcp_service/system/resources/__init__.py create mode 100644 superset/mcp_service/system/resources/instance_metadata.py create mode 100644 superset/mcp_service/system/tool/get_superset_instance_info.py create mode 100644 superset/mcp_service/utils/cache_utils.py 
create mode 100644 superset/mcp_service/utils/error_builder.py create mode 100644 superset/mcp_service/utils/permissions_utils.py create mode 100644 superset/mcp_service/utils/retry_utils.py create mode 100644 superset/mcp_service/utils/url_utils.py create mode 100644 tests/unit_tests/mcp_service/chart/test_chart_schemas.py create mode 100644 tests/unit_tests/mcp_service/chart/test_chart_utils.py create mode 100644 tests/unit_tests/mcp_service/chart/tool/test_generate_chart.py create mode 100644 tests/unit_tests/mcp_service/chart/tool/test_get_chart_preview.py create mode 100644 tests/unit_tests/mcp_service/chart/tool/test_update_chart.py create mode 100644 tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py create mode 100644 tests/unit_tests/mcp_service/conftest.py create mode 100644 tests/unit_tests/mcp_service/dashboard/__init__.py create mode 100644 tests/unit_tests/mcp_service/dashboard/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py create mode 100644 tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_tools.py create mode 100644 tests/unit_tests/mcp_service/dataset/__init__.py create mode 100644 tests/unit_tests/mcp_service/dataset/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py create mode 100644 tests/unit_tests/mcp_service/explore/__init__.py create mode 100644 tests/unit_tests/mcp_service/explore/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py create mode 100644 tests/unit_tests/mcp_service/sql_lab/__init__.py create mode 100644 tests/unit_tests/mcp_service/sql_lab/test_execute_sql_helper.py create mode 100644 tests/unit_tests/mcp_service/sql_lab/tool/__init__.py create mode 100644 tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py diff --git a/.devcontainer/README.md b/.devcontainer/README.md index e5dda78fe30..6b24183edc5 100644 --- 
a/.devcontainer/README.md +++ b/.devcontainer/README.md @@ -3,14 +3,3 @@ For complete documentation on using GitHub Codespaces with Apache Superset, please see: **[Setting up a Development Environment - GitHub Codespaces](https://superset.apache.org/docs/contributing/development#github-codespaces-cloud-development)** - -## Pre-installed Development Environment - -When you create a new Codespace from this repository, it automatically: - -1. **Creates a Python virtual environment** using `uv venv` -2. **Installs all development dependencies** via `uv pip install -r requirements/development.txt` -3. **Sets up pre-commit hooks** with `pre-commit install` -4. **Activates the virtual environment** automatically in all terminals - -The virtual environment is located at `/workspaces/{repository-name}/.venv` and is automatically activated through environment variables set in the devcontainer configuration. diff --git a/.devcontainer/default/devcontainer.json b/.devcontainer/default/devcontainer.json new file mode 100644 index 00000000000..d0988367947 --- /dev/null +++ b/.devcontainer/default/devcontainer.json @@ -0,0 +1,19 @@ +{ + // Extend the base configuration + "extends": "../devcontainer-base.json", + + "name": "Apache Superset Development (Default)", + + // Forward ports for development + "forwardPorts": [9001], + "portsAttributes": { + "9001": { + "label": "Superset (via Webpack Dev Server)", + "onAutoForward": "notify", + "visibility": "public" + } + }, + + // Auto-start Superset on Codespace resume + "postStartCommand": ".devcontainer/start-superset.sh" +} diff --git a/.devcontainer/devcontainer-base.json b/.devcontainer/devcontainer-base.json new file mode 100644 index 00000000000..59ed6ee1d2f --- /dev/null +++ b/.devcontainer/devcontainer-base.json @@ -0,0 +1,39 @@ +{ + "name": "Apache Superset Development", + // Keep this in sync with the base image in Dockerfile (ARG PY_VER) + // Using the same base as Dockerfile, but non-slim for dev tools + "image": 
"python:3.11.13-bookworm", + + "features": { + "ghcr.io/devcontainers/features/docker-in-docker:2": { + "moby": true, + "dockerDashComposeVersion": "v2" + }, + "ghcr.io/devcontainers/features/node:1": { + "version": "20" + }, + "ghcr.io/devcontainers/features/git:1": {}, + "ghcr.io/devcontainers/features/common-utils:2": { + "configureZshAsDefaultShell": true + }, + "ghcr.io/devcontainers/features/sshd:1": { + "version": "latest" + } + }, + + // Run commands after container is created + "postCreateCommand": "chmod +x .devcontainer/setup-dev.sh && .devcontainer/setup-dev.sh", + + // VS Code customizations + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance", + "charliermarsh.ruff", + "dbaeumer.vscode-eslint", + "esbenp.prettier-vscode" + ] + } + } +} diff --git a/.devcontainer/setup-dev.sh b/.devcontainer/setup-dev.sh index 91482551bee..f8521189009 100755 --- a/.devcontainer/setup-dev.sh +++ b/.devcontainer/setup-dev.sh @@ -3,76 +3,30 @@ echo "๐Ÿ”ง Setting up Superset development environment..." -# System dependencies and uv are now pre-installed in the Docker image -# This speeds up Codespace creation significantly! +# The universal image has most tools, just need Superset-specific libs +echo "๐Ÿ“ฆ Installing Superset-specific dependencies..." +sudo apt-get update +sudo apt-get install -y \ + libsasl2-dev \ + libldap2-dev \ + libpq-dev \ + tmux \ + gh -# Create virtual environment using uv -echo "๐Ÿ Creating Python virtual environment..." -if ! uv venv; then - echo "โŒ Failed to create virtual environment" - exit 1 -fi +# Install uv for fast Python package management +echo "๐Ÿ“ฆ Installing uv..." +curl -LsSf https://astral.sh/uv/install.sh | sh -# Install Python dependencies -echo "๐Ÿ“ฆ Installing Python dependencies..." -if ! 
uv pip install -r requirements/development.txt; then - echo "โŒ Failed to install Python dependencies" - echo "๐Ÿ’ก You may need to run this manually after the Codespace starts" - exit 1 -fi - -# Install pre-commit hooks -echo "๐Ÿช Installing pre-commit hooks..." -if source .venv/bin/activate && pre-commit install; then - echo "โœ… Pre-commit hooks installed" -else - echo "โš ๏ธ Pre-commit hooks installation failed (non-critical)" -fi +# Add cargo/bin to PATH for uv +echo 'export PATH="$HOME/.cargo/bin:$PATH"' >> ~/.bashrc +echo 'export PATH="$HOME/.cargo/bin:$PATH"' >> ~/.zshrc # Install Claude Code CLI via npm echo "๐Ÿค– Installing Claude Code..." -if npm install -g @anthropic-ai/claude-code; then - echo "โœ… Claude Code installed" -else - echo "โš ๏ธ Claude Code installation failed (non-critical)" -fi +npm install -g @anthropic-ai/claude-code # Make the start script executable chmod +x .devcontainer/start-superset.sh -# Add bashrc additions for automatic venv activation -echo "๐Ÿ”ง Setting up automatic environment activation..." -if [ -f ~/.bashrc ]; then - # Check if we've already added our additions - if ! grep -q "Superset Codespaces environment setup" ~/.bashrc; then - echo "" >> ~/.bashrc - cat .devcontainer/bashrc-additions >> ~/.bashrc - echo "โœ… Added automatic venv activation to ~/.bashrc" - else - echo "โœ… Bashrc additions already present" - fi -else - # Create bashrc if it doesn't exist - cat .devcontainer/bashrc-additions > ~/.bashrc - echo "โœ… Created ~/.bashrc with automatic venv activation" -fi - -# Also add to zshrc since that's the default shell -if [ -f ~/.zshrc ] || [ -n "$ZSH_VERSION" ]; then - if ! grep -q "Superset Codespaces environment setup" ~/.zshrc; then - echo "" >> ~/.zshrc - cat .devcontainer/bashrc-additions >> ~/.zshrc - echo "โœ… Added automatic venv activation to ~/.zshrc" - fi -fi - echo "โœ… Development environment setup complete!" 
-echo "" -echo "๐Ÿ“ The virtual environment will be automatically activated in new terminals" -echo "" -echo "๐Ÿ”„ To activate in this terminal, run:" -echo " source ~/.bashrc" -echo "" -echo "๐Ÿš€ To start Superset:" -echo " start-superset" -echo "" +echo "๐Ÿš€ Run '.devcontainer/start-superset.sh' to start Superset" diff --git a/.devcontainer/start-superset.sh b/.devcontainer/start-superset.sh index 6ba990cae10..b480b04aacb 100755 --- a/.devcontainer/start-superset.sh +++ b/.devcontainer/start-superset.sh @@ -1,14 +1,14 @@ #!/bin/bash # Startup script for Superset in Codespaces -# Log to a file for debugging -LOG_FILE="/tmp/superset-startup.log" -echo "[$(date)] Starting Superset startup script" >> "$LOG_FILE" -echo "[$(date)] User: $(whoami), PWD: $(pwd)" >> "$LOG_FILE" - echo "๐Ÿš€ Starting Superset in Codespaces..." echo "๐ŸŒ Frontend will be available at port 9001" +# Check if MCP is enabled +if [ "$ENABLE_MCP" = "true" ]; then + echo "๐Ÿค– MCP Service will be available at port 5008" +fi + # Find the workspace directory (Codespaces clones as 'superset', not 'superset-2') WORKSPACE_DIR=$(find /workspaces -maxdepth 1 -name "superset*" -type d | head -1) if [ -n "$WORKSPACE_DIR" ]; then @@ -18,71 +18,32 @@ else echo "๐Ÿ“ Using current directory: $(pwd)" fi -# Wait for Docker to be available -echo "โณ Waiting for Docker to start..." -echo "[$(date)] Waiting for Docker..." >> "$LOG_FILE" -max_attempts=30 -attempt=0 -while ! docker info > /dev/null 2>&1; do - if [ $attempt -eq $max_attempts ]; then - echo "โŒ Docker failed to start after $max_attempts attempts" - echo "[$(date)] Docker failed to start after $max_attempts attempts" >> "$LOG_FILE" - echo "๐Ÿ”„ Please restart the Codespace or run this script manually later" - exit 1 - fi - echo " Attempt $((attempt + 1))/$max_attempts..." - echo "[$(date)] Docker check attempt $((attempt + 1))/$max_attempts" >> "$LOG_FILE" - sleep 2 - attempt=$((attempt + 1)) -done -echo "โœ… Docker is ready!" 
-echo "[$(date)] Docker is ready" >> "$LOG_FILE" - -# Check if Superset containers are already running -if docker ps | grep -q "superset"; then - echo "โœ… Superset containers are already running!" - echo "" - echo "๐ŸŒ To access Superset:" - echo " 1. Click the 'Ports' tab at the bottom of VS Code" - echo " 2. Find port 9001 and click the globe icon to open" - echo " 3. Wait 10-20 minutes for initial startup" - echo "" - echo "๐Ÿ“ Login credentials: admin/admin" - exit 0 +# Check if docker is running +if ! docker info > /dev/null 2>&1; then + echo "โณ Waiting for Docker to start..." + sleep 5 fi # Clean up any existing containers echo "๐Ÿงน Cleaning up existing containers..." -docker-compose -f docker-compose-light.yml down +docker-compose -f docker-compose-light.yml --profile mcp down # Start services -echo "๐Ÿ—๏ธ Starting Superset in background (daemon mode)..." +echo "๐Ÿ—๏ธ Building and starting services..." echo "" +echo "๐Ÿ“ Once started, login with:" +echo " Username: admin" +echo " Password: admin" +echo "" +echo "๐Ÿ“‹ Running in foreground with live logs (Ctrl+C to stop)..." -# Start in detached mode -docker-compose -f docker-compose-light.yml up -d - -echo "" -echo "โœ… Docker Compose started successfully!" 
-echo "" -echo "๐Ÿ“‹ Important information:" -echo "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”" -echo "โฑ๏ธ Initial startup takes 10-20 minutes" -echo "๐ŸŒ Check the 'Ports' tab for your Superset URL (port 9001)" -echo "๐Ÿ‘ค Login: admin / admin" -echo "" -echo "๐Ÿ“Š Useful commands:" -echo " docker-compose -f docker-compose-light.yml logs -f # Follow logs" -echo " docker-compose -f docker-compose-light.yml ps # Check status" -echo " docker-compose -f docker-compose-light.yml down # Stop services" -echo "โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”โ”" -echo "" -echo "๐Ÿ’ค Keeping terminal open for 60 seconds to test persistence..." -sleep 60 -echo "โœ… Test complete - check if this terminal is still visible!" - -# Show final status -docker-compose -f docker-compose-light.yml ps +# Run docker-compose and capture exit code +if [ "$ENABLE_MCP" = "true" ]; then + echo "๐Ÿค– Starting with MCP Service enabled..." + docker-compose -f docker-compose-light.yml --profile mcp up +else + docker-compose -f docker-compose-light.yml up +fi EXIT_CODE=$? 
# If it failed, provide helpful instructions diff --git a/.devcontainer/with-mcp/devcontainer.json b/.devcontainer/with-mcp/devcontainer.json new file mode 100644 index 00000000000..c3f8b654ebc --- /dev/null +++ b/.devcontainer/with-mcp/devcontainer.json @@ -0,0 +1,29 @@ +{ + // Extend the base configuration + "extends": "../devcontainer-base.json", + + "name": "Apache Superset Development with MCP", + + // Forward ports for development + "forwardPorts": [9001, 5008], + "portsAttributes": { + "9001": { + "label": "Superset (via Webpack Dev Server)", + "onAutoForward": "notify", + "visibility": "public" + }, + "5008": { + "label": "MCP Service (Model Context Protocol)", + "onAutoForward": "notify", + "visibility": "private" + } + }, + + // Auto-start Superset with MCP on Codespace resume + "postStartCommand": "ENABLE_MCP=true .devcontainer/start-superset.sh", + + // Environment variables + "containerEnv": { + "ENABLE_MCP": "true" + } +} diff --git a/docker/docker-bootstrap.sh b/docker/docker-bootstrap.sh index 9d18b66626c..d8524d928f8 100755 --- a/docker/docker-bootstrap.sh +++ b/docker/docker-bootstrap.sh @@ -86,6 +86,10 @@ case "${1}" in echo "Starting web app..." /usr/bin/run-server.sh ;; + mcp) + echo "Starting MCP service..." + superset mcp run --host 0.0.0.0 --port ${MCP_PORT:-5008} --debug + ;; *) echo "Unknown Operation!!!" ;; diff --git a/superset/mcp_service/CLAUDE.md b/superset/mcp_service/CLAUDE.md new file mode 100644 index 00000000000..94b33a1bce4 --- /dev/null +++ b/superset/mcp_service/CLAUDE.md @@ -0,0 +1,523 @@ +# MCP Service - LLM Agent Guide + +This guide helps LLM agents understand the Superset MCP (Model Context Protocol) service architecture and development conventions. + +## โš ๏ธ CRITICAL: Apache License Headers + +**EVERY Python file in the MCP service MUST have the Apache Software Foundation license header.** + +This includes: +- All `.py` files (tool files, schemas, __init__.py files, etc.) 
+- **NEVER remove existing license headers during refactoring or edits** +- **ALWAYS add license headers when creating new files** +- **ALWAYS verify license headers are present after editing files** + +If you see a file without a license header, ADD IT IMMEDIATELY. If you accidentally remove one during editing, ADD IT BACK. + +Use this exact template at the top of EVERY Python file: + +```python +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +``` + +**Note**: LLM instruction files like `CLAUDE.md`, `AGENTS.md`, etc. are excluded from this requirement (listed in `.rat-excludes`) to avoid token overhead, but ALL other Python files require it. + +## Architecture Overview + +The MCP service provides programmatic access to Superset via the Model Context Protocol, allowing AI assistants to interact with dashboards, charts, datasets, SQL Lab, and instance metadata. 
+ +### Key Components + +``` +superset/mcp_service/ +โ”œโ”€โ”€ app.py # FastMCP app factory and tool registration +โ”œโ”€โ”€ auth.py # Authentication and authorization +โ”œโ”€โ”€ mcp_config.py # Default configuration +โ”œโ”€โ”€ mcp_core.py # Reusable core classes for tools +โ”œโ”€โ”€ flask_singleton.py # Flask app singleton for MCP context +โ”œโ”€โ”€ chart/ # Chart-related tools +โ”‚ โ”œโ”€โ”€ schemas.py # Pydantic schemas for chart responses +โ”‚ โ””โ”€โ”€ tool/ # Chart tool implementations +โ”‚ โ”œโ”€โ”€ __init__.py # Tool exports +โ”‚ โ”œโ”€โ”€ list_charts.py +โ”‚ โ””โ”€โ”€ get_chart_info.py +โ”œโ”€โ”€ dashboard/ # Dashboard-related tools +โ”‚ โ”œโ”€โ”€ schemas.py +โ”‚ โ””โ”€โ”€ tool/ +โ”œโ”€โ”€ dataset/ # Dataset-related tools +โ”‚ โ”œโ”€โ”€ schemas.py +โ”‚ โ””โ”€โ”€ tool/ +โ””โ”€โ”€ system/ # System/instance tools + โ”œโ”€โ”€ schemas.py + โ””โ”€โ”€ tool/ +``` + +## Critical Convention: Tool, Prompt, and Resource Registration + +**IMPORTANT**: When creating new MCP tools, prompts, or resources, you MUST add their imports to `app.py` for auto-registration. Do NOT add them to `server.py` - that approach doesn't work properly. + +### How to Add a New Tool + +1. **Create the tool file** in the appropriate directory (e.g., `chart/tool/my_new_tool.py`) +2. **Decorate with `@mcp.tool`** to register it with FastMCP +3. 
**Add import to `app.py`** at the bottom of the file where other tools are imported (around line 210-242) + +**Example**: +```python +# superset/mcp_service/chart/tool/my_new_tool.py +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook + +@mcp.tool +@mcp_auth_hook +def my_new_tool(param: str) -> dict: + """Tool description for LLMs.""" + return {"result": "success"} +``` + +**Then add to app.py**: +```python +# superset/mcp_service/app.py (at the bottom, around line 207-224) +from superset.mcp_service.chart.tool import ( # noqa: F401, E402 + get_chart_info, + list_charts, + my_new_tool, # ADD YOUR TOOL HERE +) +``` + +**Why this matters**: Tools use `@mcp.tool` decorators and register automatically on import. The import MUST be in `app.py` at the bottom of the file (after the `mcp` instance is created). If you don't import the tool in `app.py`, it won't be available to MCP clients. DO NOT add imports to `server.py` - that file is for running the server only. + +### How to Add a New Prompt + +1. **Create the prompt file** in the appropriate directory (e.g., `chart/prompts/my_new_prompt.py`) +2. **Decorate with `@mcp.prompt`** to register it with FastMCP +3. **Add import to module's `__init__.py`** (e.g., `chart/prompts/__init__.py`) +4. **Ensure module is imported in `app.py`** (around line 244-253) + +**Example**: +```python +# superset/mcp_service/chart/prompts/my_new_prompt.py +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook + +@mcp.prompt("my_new_prompt") +@mcp_auth_hook +async def my_new_prompt_handler(ctx: Context) -> str: + """Interactive prompt for doing something.""" + return "Prompt instructions here..." +``` + +**Then add to `chart/prompts/__init__.py`**: +```python +# superset/mcp_service/chart/prompts/__init__.py +from . import create_chart_guided # existing +from . 
import my_new_prompt # ADD YOUR PROMPT HERE +``` + +**Verify module import exists in `app.py`** (around line 248): +```python +# superset/mcp_service/app.py +from superset.mcp_service.chart import prompts as chart_prompts # This imports all prompts +``` + +### How to Add a New Resource + +1. **Create the resource file** in the appropriate directory (e.g., `chart/resources/my_new_resource.py`) +2. **Decorate with `@mcp.resource`** to register it with FastMCP +3. **Add import to module's `__init__.py`** (e.g., `chart/resources/__init__.py`) +4. **Ensure module is imported in `app.py`** (around line 244-253) + +**Example**: +```python +# superset/mcp_service/chart/resources/my_new_resource.py +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook + +@mcp.resource("superset://chart/my_resource") +@mcp_auth_hook +def get_my_resource() -> str: + """Resource description for LLMs.""" + return "Resource data here..." +``` + +**Then add to `chart/resources/__init__.py`**: +```python +# superset/mcp_service/chart/resources/__init__.py +from . import chart_configs # existing +from . import my_new_resource # ADD YOUR RESOURCE HERE +``` + +**Verify module import exists in `app.py`** (around line 249): +```python +# superset/mcp_service/app.py +from superset.mcp_service.chart import resources as chart_resources # This imports all resources +``` + +**Why this matters**: Prompts and resources work similarly to tools - they use decorators and register on import. The module-level imports (`chart/prompts/__init__.py`, `chart/resources/__init__.py`) ensure individual files are imported when the module is imported. The `app.py` imports ensure the modules are loaded when the MCP service starts. + +## Tool Development Patterns + +### 1. 
Use Core Classes for Reusability + +The `mcp_core.py` module provides reusable patterns: + +- **`ModelListCore`**: For listing resources (dashboards, charts, datasets) +- **`ModelGetInfoCore`**: For getting resource details by ID/UUID +- **`ModelGetAvailableFiltersCore`**: For retrieving filterable columns + +**Example**: +```python +from superset.mcp_service.mcp_core import ModelListCore +from superset.daos.dashboard import DashboardDAO +from superset.mcp_service.dashboard.schemas import DashboardList + +list_core = ModelListCore( + dao_class=DashboardDAO, + output_schema=DashboardList, + logger=logger, +) + +@mcp.tool +@mcp_auth_hook +def list_dashboards(filters: List[DashboardFilter], page: int = 1) -> DashboardList: + return list_core.run_tool(filters=filters, page=page, page_size=10) +``` + +### 2. Always Use Authentication + +**Every tool must use `@mcp_auth_hook`** to ensure: +- User authentication from JWT or configured admin user +- Permission checking via JWT scopes +- Audit logging of tool access + +```python +from superset.mcp_service.auth import mcp_auth_hook + +@mcp.tool +@mcp_auth_hook # REQUIRED +def my_tool() -> dict: + # g.user is set by mcp_auth_hook + return {"user": g.user.username} +``` + +### 3. Use Pydantic Schemas + +**All tool inputs and outputs must be Pydantic models** for: +- Automatic validation +- LLM-friendly schema generation +- Type safety + +**Convention**: Place schemas in `{module}/schemas.py` + +```python +from pydantic import BaseModel, Field + +class MyToolRequest(BaseModel): + param: str = Field(..., description="Parameter description for LLMs") + +class MyToolResponse(BaseModel): + result: str = Field(..., description="Result description") + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Response timestamp" + ) +``` + +### 4. 
Follow the DAO Pattern + +**Use Superset's DAO (Data Access Object) layer** instead of direct database queries: + +```python +from superset.daos.dashboard import DashboardDAO + +# GOOD: Use DAO +dashboard = DashboardDAO.find_by_id(dashboard_id) + +# BAD: Don't query directly +dashboard = db.session.query(Dashboard).filter_by(id=dashboard_id).first() +``` + +### 5. Python Type Hints (Python 3.10+ Style) + +**CRITICAL**: Always use modern Python 3.10+ union syntax for type hints. + +```python +# GOOD - Modern Python 3.10+ syntax +from typing import List, Dict, Any +from pydantic import BaseModel, Field + +class MySchema(BaseModel): + name: str | None = Field(None, description="Optional name") + tags: List[str] = Field(default_factory=list) + metadata: Dict[str, Any] = Field(default_factory=dict) + +def my_function( + id: int, + filters: List[str] | None = None, + options: Dict[str, Any] | None = None +) -> MySchema | None: + pass + +# BAD - Old-style Optional (DO NOT USE) +from typing import Optional, List, Dict, Any + +class MySchema(BaseModel): + name: Optional[str] = Field(None, description="Optional name") # Wrong! + +def my_function( + id: int, + filters: Optional[List[str]] = None, # Wrong! + options: Optional[Dict[str, Any]] = None # Wrong! +) -> Optional[MySchema]: # Wrong! + pass +``` + +**Key rules:** +- Use `T | None` instead of `Optional[T]` +- Do NOT import `Optional` from typing +- Still import `List`, `Dict`, `Any`, etc. from typing (for now) +- All new code must follow this pattern + +### 6. 
Error Handling + +**Use consistent error schemas**: + +```python +class MyError(BaseModel): + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: datetime = Field( + default_factory=lambda: datetime.now(timezone.utc), + description="Error timestamp" + ) + +@mcp.tool +@mcp_auth_hook +def my_tool(id: int) -> MyResponse: + try: + result = process_data(id) + return MyResponse(data=result) + except NotFound: + raise ValueError(f"Resource {id} not found") +``` + +## Testing Conventions + +### Unit Tests + +Place unit tests in `tests/unit_tests/mcp_service/{module}/tool/test_{tool_name}.py` + +**Test structure**: +```python +from unittest.mock import MagicMock, patch +import pytest + +class TestMyTool: + @pytest.fixture + def mock_dao(self): + """Create mock DAO for testing.""" + dao = MagicMock() + dao.find_by_id.return_value = create_mock_object() + return dao + + @patch("superset.mcp_service.chart.tool.my_tool.ChartDAO") + def test_my_tool_success(self, mock_dao_class, mock_dao): + """Test successful tool execution.""" + mock_dao_class.return_value = mock_dao + + result = my_tool(id=1) + + assert result.data is not None + mock_dao.find_by_id.assert_called_once_with(1) +``` + +### Integration Tests + +Use Flask test client for integration tests: + +```python +def test_tool_with_flask_context(app): + """Test tool with full Flask app context.""" + with app.app_context(): + result = my_tool(id=1) + assert result is not None +``` + +## Common Pitfalls to Avoid + +### 1. โŒ Forgetting Tool Import in app.py +**Problem**: Tool exists but isn't available to MCP clients. +**Solution**: Always add tool import to `app.py` (at the bottom) after creating it. Never add to `server.py`. + +### 2. โŒ Adding Tool Imports to server.py +**Problem**: Tools won't register properly, causing runtime errors. +**Solution**: Tool imports must be in `app.py` at the bottom of the file, not in `server.py`. 
The `server.py` file is only for running the server.
+
+### 3. ❌ Missing @mcp_auth_hook Decorator
+**Problem**: Tool bypasses authentication and authorization.
+**Solution**: Always use `@mcp_auth_hook` on every tool.
+
+### 4. ❌ Using `Optional` Instead of Union Syntax
+**Problem**: Old-style Optional[T] is not Python 3.10+ style.
+**Solution**: Use `T | None` instead of `Optional[T]` for all type hints.
+```python
+# GOOD - Modern Python 3.10+ syntax
+def my_function(param: str | None = None) -> int | None:
+    pass
+
+# BAD - Old-style Optional
+from typing import Optional
+def my_function(param: Optional[str] = None) -> Optional[int]:
+    pass
+```
+
+### 5. ❌ Using `Any` Types in Schemas
+**Problem**: Defeats Pydantic validation and gives LLM clients no type information.
+**Solution**: Use proper Pydantic types with Field descriptions.
+
+### 6. ❌ Direct Database Queries
+**Problem**: Bypasses Superset's security and caching layers.
+**Solution**: Use DAO classes (ChartDAO, DashboardDAO, etc.).
+
+### 7. ❌ Not Using Core Classes
+**Problem**: Duplicating list/get_info/filter logic across tools.
+**Solution**: Use ModelListCore, ModelGetInfoCore, ModelGetAvailableFiltersCore.
+
+### 8. ❌ Missing Apache License Headers
+**Problem**: CI fails on license check.
+**Solution**: Add Apache license header to all new .py files. Use this exact template at the top of every new Python file:
+
+```python
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. 
You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+```
+
+**Note**: LLM instruction files like `CLAUDE.md`, `AGENTS.md`, etc. are excluded from this requirement (listed in `.rat-excludes`) to avoid token overhead.
+
+### 9. ❌ Using `@mcp.tool()` with Empty Parentheses
+**Problem**: Inconsistent decorator style.
+**Solution**: Use `@mcp.tool` without parentheses unless passing arguments.
+```python
+# GOOD
+@mcp.tool
+def my_tool():
+    pass
+
+# BAD
+@mcp.tool()
+def my_tool():
+    pass
+```
+
+### 10. ❌ Circular Imports
+**Problem**: Importing anything other than `mcp` from `app.py` creates a circular dependency.
+**Solution**: Import only `mcp` (at module level) from `app.py` in tool files:
+```python
+# GOOD
+from superset.mcp_service.app import mcp
+
+@mcp.tool
+def my_tool():
+    pass
+
+# BAD - causes circular import
+from superset.mcp_service.app import mcp, some_other_function
+```
+
+## Configuration
+
+Default configuration is in `mcp_config.py`. Users can override in `superset_config.py`:
+
+```python
+# superset_config.py
+MCP_ADMIN_USERNAME = "your_admin"
+MCP_AUTH_ENABLED = True
+MCP_JWT_PUBLIC_KEY = "your_public_key"
+```
+
+## Tool Discovery
+
+MCP clients discover tools via:
+1. **Tool listing**: All tools with `@mcp.tool` are automatically listed
+2. **Schema introspection**: Pydantic schemas generate JSON Schema for LLMs
+3. 
**Instructions**: `DEFAULT_INSTRUCTIONS` in `app.py` documents available tools + +## Resources for Learning + +- **MCP Specification**: https://modelcontextprotocol.io/ +- **FastMCP Documentation**: https://github.com/jlowin/fastmcp +- **Superset DAO Patterns**: See `superset/daos/` for examples +- **Pydantic Documentation**: https://docs.pydantic.dev/ + +## Quick Checklist for New Tools + +- [ ] Created tool file in `{module}/tool/{tool_name}.py` +- [ ] Added `@mcp.tool` decorator +- [ ] Added `@mcp_auth_hook` decorator +- [ ] Created Pydantic request/response schemas in `{module}/schemas.py` +- [ ] Used DAO classes instead of direct queries +- [ ] Added tool import to `app.py` (around line 210-242) +- [ ] Added Apache license header to new files +- [ ] Created unit tests in `tests/unit_tests/mcp_service/{module}/tool/test_{tool_name}.py` +- [ ] Updated `DEFAULT_INSTRUCTIONS` in `app.py` if adding new capability +- [ ] Tested locally with MCP client (e.g., Claude Desktop) + +## Quick Checklist for New Prompts + +- [ ] Created prompt file in `{module}/prompts/{prompt_name}.py` +- [ ] Added `@mcp.prompt("prompt_name")` decorator +- [ ] Added `@mcp_auth_hook` decorator +- [ ] Made function async: `async def prompt_handler(ctx: Context) -> str` +- [ ] Added import to `{module}/prompts/__init__.py` +- [ ] Verified module import exists in `app.py` (around line 244-253) +- [ ] Added Apache license header to new file +- [ ] Updated `DEFAULT_INSTRUCTIONS` in `app.py` to list the new prompt +- [ ] Tested locally with MCP client (e.g., Claude Desktop) + +## Quick Checklist for New Resources + +- [ ] Created resource file in `{module}/resources/{resource_name}.py` +- [ ] Added `@mcp.resource("superset://{path}")` decorator with unique URI +- [ ] Added `@mcp_auth_hook` decorator +- [ ] Implemented resource data retrieval logic +- [ ] Added import to `{module}/resources/__init__.py` +- [ ] Verified module import exists in `app.py` (around line 244-253) +- [ ] Added Apache 
license header to new file +- [ ] Updated `DEFAULT_INSTRUCTIONS` in `app.py` to list the new resource +- [ ] Tested locally with MCP client (e.g., Claude Desktop) + +## Getting Help + +- Check existing tool implementations for patterns (chart/tool/, dashboard/tool/) +- Review core classes in `mcp_core.py` for reusable functionality +- See `CLAUDE.md` in project root for general Superset development guidelines +- Consult Superset documentation: https://superset.apache.org/docs/ diff --git a/superset/mcp_service/app.py b/superset/mcp_service/app.py index 57f55d59102..a091b9943a6 100644 --- a/superset/mcp_service/app.py +++ b/superset/mcp_service/app.py @@ -77,6 +77,40 @@ Available Prompts: - superset_quickstart: Interactive guide for getting started with the MCP service - create_chart_guided: Step-by-step chart creation wizard +Common Chart Types (viz_type) and Behaviors: + +Interactive Charts (support sorting, filtering, drill-down): +- table: Standard table view with sorting and filtering +- pivot_table_v2: Pivot table with grouping and aggregations +- echarts_timeseries_line: Time series line chart +- echarts_timeseries_bar: Time series bar chart +- echarts_timeseries_area: Time series area chart +- echarts_timeseries_scatter: Time series scatter plot +- mixed_timeseries: Combined line/bar time series + +Common Visualization Types: +- big_number: Single metric display +- big_number_total: Total value display +- pie: Pie chart for proportions +- echarts_timeseries: Generic time series chart +- funnel: Funnel chart for conversion analysis +- gauge_chart: Gauge/speedometer visualization +- heatmap_v2: Heat map for correlation analysis +- sankey_v2: Sankey diagram for flow visualization +- sunburst_v2: Sunburst chart for hierarchical data +- treemap_v2: Tree map for hierarchical proportions +- word_cloud: Word cloud visualization +- world_map: Geographic world map +- box_plot: Box plot for distribution analysis +- bubble: Bubble chart for 3-dimensional data + +Query 
Examples: +- List all interactive tables: + filters=[{"col": "viz_type", "opr": "in", "value": ["table", "pivot_table_v2"]}] +- List time series charts: + filters=[{"col": "viz_type", "opr": "sw", "value": "echarts_timeseries"}] +- Search by name: search="sales" + General usage tips: - All listing tools use 1-based pagination (first page is 1) - Use 'filters' parameter for advanced queries (see *_available_filters tools) @@ -202,13 +236,56 @@ def create_mcp_app( # Create default MCP instance for backward compatibility # Tool modules can import this and use @mcp.tool decorators -mcp = create_mcp_app() +mcp = create_mcp_app(stateless_http=True) -from superset.mcp_service.chart.tool import ( # noqa: F401, E402 - get_chart_info, - list_charts, +# Import all MCP tools to register them with the mcp instance +# NOTE: Always add new tool imports here when creating new MCP tools. +# Tools use @mcp.tool decorators and register automatically on import. +# Import prompts and resources to register them with the mcp instance +# NOTE: Always add new prompt/resource imports here when creating new prompts/resources. +# Prompts use @mcp.prompt decorators and resources use @mcp.resource decorators. +# They register automatically on import, similar to tools. 
+from superset.mcp_service.chart import ( # noqa: F401, E402 + prompts as chart_prompts, + resources as chart_resources, +) +from superset.mcp_service.chart.tool import ( # noqa: F401, E402 + generate_chart, + get_chart_available_filters, + get_chart_data, + get_chart_info, + get_chart_preview, + list_charts, + update_chart, + update_chart_preview, +) +from superset.mcp_service.dashboard.tool import ( # noqa: F401, E402 + add_chart_to_existing_dashboard, + generate_dashboard, + get_dashboard_available_filters, + get_dashboard_info, + list_dashboards, +) +from superset.mcp_service.dataset.tool import ( # noqa: F401, E402 + get_dataset_available_filters, + get_dataset_info, + list_datasets, +) +from superset.mcp_service.explore.tool import ( # noqa: F401, E402 + generate_explore_link, +) +from superset.mcp_service.sql_lab.tool import ( # noqa: F401, E402 + execute_sql, + open_sql_lab_with_context, +) +from superset.mcp_service.system import ( # noqa: F401, E402 + prompts as system_prompts, + resources as system_resources, +) +from superset.mcp_service.system.tool import ( # noqa: F401, E402 + get_superset_instance_info, + health_check, ) -from superset.mcp_service.system.tool import health_check # noqa: F401, E402 def init_fastmcp_server( diff --git a/superset/mcp_service/auth.py b/superset/mcp_service/auth.py index 9b1fc1de9ba..1a1e68054f4 100644 --- a/superset/mcp_service/auth.py +++ b/superset/mcp_service/auth.py @@ -28,11 +28,14 @@ Future enhancements (to be added in separate PRs): """ import logging -from typing import Any, Callable, TypeVar +from typing import Any, Callable, TYPE_CHECKING, TypeVar from flask import g from flask_appbuilder.security.sqla.models import User +if TYPE_CHECKING: + from superset.connectors.sqla.models import SqlaTable + # Type variable for decorated functions F = TypeVar("F", bound=Callable[..., Any]) @@ -45,15 +48,14 @@ def get_user_from_request() -> User: TODO (future PR): Add JWT token extraction and validation. 
TODO (future PR): Add user impersonation support. + TODO (future PR): Add fallback user configuration. - For now, this uses MCP_DEV_USERNAME from configuration for development. - - Raises: - ValueError: If MCP_DEV_USERNAME is not configured or user doesn't exist + For now, this returns the admin user for development. """ from flask import current_app + from sqlalchemy.orm import joinedload - from superset import security_manager + from superset.extensions import db # TODO: Extract from JWT token once authentication is implemented # For now, use MCP_DEV_USERNAME from configuration @@ -62,14 +64,57 @@ def get_user_from_request() -> User: if not username: raise ValueError("Username not configured") - user = security_manager.find_user(username) + # Query user directly with eager loading to ensure fresh session-bound object + # Do NOT use security_manager.find_user() as it may return cached/detached user + user = ( + db.session.query(User) + .options(joinedload(User.roles), joinedload(User.groups)) + .filter(User.username == username) + .first() + ) if not user: - raise ValueError(f"User '{username}' not found") + raise ValueError( + f"User '{username}' not found. " + f"Please create admin user with: superset fab create-admin" + ) return user +def has_dataset_access(dataset: "SqlaTable") -> bool: + """ + Validate user has access to the dataset. + + This function checks if the current user (from Flask g.user context) + has permission to access the given dataset using Superset's security manager. + + Args: + dataset: The SqlaTable dataset to check access for + + Returns: + True if user has access, False otherwise + + Security Note: + This should be called after mcp_auth_hook has set g.user. + Returns False on any error to fail securely. 
+ """ + try: + from superset import security_manager + + # Check if user has read access to the dataset + if hasattr(g, "user") and g.user: + # Use Superset's security manager to check dataset access + return security_manager.can_access_datasource(datasource=dataset) + + # If no user context, deny access + return False + + except Exception as e: + logger.warning("Error checking dataset access: %s", e) + return False # Deny access on error + + def mcp_auth_hook(tool_func: F) -> F: """ Authentication and authorization decorator for MCP tools. @@ -87,17 +132,48 @@ def mcp_auth_hook(tool_func: F) -> F: @functools.wraps(tool_func) def wrapper(*args: Any, **kwargs: Any) -> Any: - # Get user and set Flask context + from superset.extensions import db + + # Get user and set Flask context OUTSIDE try block user = get_user_from_request() + + # Force load relationships NOW while session is definitely active + _ = user.roles + if hasattr(user, "groups"): + _ = user.groups + g.user = user - # TODO: Add permission checks here in future PR - # TODO: Add audit logging here in future PR + try: + # TODO: Add permission checks here in future PR + # TODO: Add audit logging here in future PR - logger.debug( - "MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__ - ) + logger.debug( + "MCP tool call: user=%s, tool=%s", user.username, tool_func.__name__ + ) - return tool_func(*args, **kwargs) + result = tool_func(*args, **kwargs) + + return result + + except Exception: + # On error, rollback and cleanup session + # pylint: disable=consider-using-transaction + try: + db.session.rollback() + db.session.remove() + except Exception as e: + logger.warning("Error cleaning up session after exception: %s", e) + raise + + finally: + # Only rollback if session is still active (no exception occurred) + # Do NOT call remove() on success to avoid detaching user + try: + if db.session.is_active: + # pylint: disable=consider-using-transaction + db.session.rollback() + except Exception as 
e: + logger.warning("Error in finally block: %s", e) return wrapper # type: ignore[return-value] diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py new file mode 100644 index 00000000000..633e3c6a9c6 --- /dev/null +++ b/superset/mcp_service/chart/chart_utils.py @@ -0,0 +1,484 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Shared chart utilities for MCP tools + +This module contains shared logic for chart configuration mapping and explore link +generation that can be used by both generate_chart and generate_explore_link tools. 
+""" + +from typing import Any, Dict + +from superset.mcp_service.chart.schemas import ( + ChartCapabilities, + ChartSemantics, + ColumnRef, + TableChartConfig, + XYChartConfig, +) +from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.utils import json + + +def generate_explore_link(dataset_id: int | str, form_data: Dict[str, Any]) -> str: + """Generate an explore link for the given dataset and form data.""" + base_url = get_superset_base_url() + numeric_dataset_id = None + + try: + from superset.commands.explore.form_data.parameters import CommandParameters + + # Find the dataset to get its numeric ID + from superset.daos.dataset import DatasetDAO + from superset.mcp_service.commands.create_form_data import ( + MCPCreateFormDataCommand, + ) + from superset.utils.core import DatasourceType + + dataset = None + + if isinstance(dataset_id, int) or ( + isinstance(dataset_id, str) and dataset_id.isdigit() + ): + numeric_dataset_id = ( + int(dataset_id) if isinstance(dataset_id, str) else dataset_id + ) + dataset = DatasetDAO.find_by_id(numeric_dataset_id) + else: + # Try UUID lookup using DAO flexible method + dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid") + if dataset: + numeric_dataset_id = dataset.id + + if not dataset or numeric_dataset_id is None: + # Fallback to basic explore URL + return ( + f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}" + ) + + # Add datasource to form_data + form_data_with_datasource = { + **form_data, + "datasource": f"{numeric_dataset_id}__table", + } + + # Try to create form_data in cache using MCP-specific CreateFormDataCommand + cmd_params = CommandParameters( + datasource_type=DatasourceType.TABLE, + datasource_id=numeric_dataset_id, + chart_id=0, # 0 for new charts + tab_id=None, + form_data=json.dumps(form_data_with_datasource), + ) + + # Create the form_data cache entry and get the key + form_data_key = MCPCreateFormDataCommand(cmd_params).run() + + # Return URL 
with just the form_data_key + return f"{base_url}/explore/?form_data_key={form_data_key}" + + except Exception: + # Fallback to basic explore URL with numeric ID if available + if numeric_dataset_id is not None: + return ( + f"{base_url}/explore/?datasource_type=table" + f"&datasource_id={numeric_dataset_id}" + ) + else: + return ( + f"{base_url}/explore/?datasource_type=table&datasource_id={dataset_id}" + ) + + +def map_config_to_form_data( + config: TableChartConfig | XYChartConfig, +) -> Dict[str, Any]: + """Map chart config to Superset form_data.""" + if isinstance(config, TableChartConfig): + return map_table_config(config) + elif isinstance(config, XYChartConfig): + return map_xy_config(config) + else: + raise ValueError(f"Unsupported config type: {type(config)}") + + +def map_table_config(config: TableChartConfig) -> Dict[str, Any]: + """Map table chart config to form_data with defensive validation.""" + # Early validation to prevent empty charts + if not config.columns: + raise ValueError("Table chart must have at least one column") + + # Separate columns with aggregates from raw columns + raw_columns = [] + aggregated_metrics = [] + + for col in config.columns: + if col.aggregate: + # Column has aggregation - treat as metric + aggregated_metrics.append(create_metric_object(col)) + else: + # No aggregation - treat as raw column + raw_columns.append(col.name) + + # Final validation - ensure we have some data to display + if not raw_columns and not aggregated_metrics: + raise ValueError("Table chart configuration resulted in no displayable columns") + + form_data: Dict[str, Any] = { + "viz_type": "table", + } + + # Handle raw columns (no aggregation) + if raw_columns and not aggregated_metrics: + # Pure raw columns - show individual rows + form_data.update( + { + "all_columns": raw_columns, + "query_mode": "raw", + "include_time": False, + "order_desc": True, + "row_limit": 1000, # Reasonable limit for raw data + } + ) + + # Handle aggregated columns only + 
elif aggregated_metrics and not raw_columns: + # Pure aggregation - show totals + form_data.update( + { + "metrics": aggregated_metrics, + "query_mode": "aggregate", + } + ) + + # Handle mixed columns (raw + aggregated) + elif raw_columns and aggregated_metrics: + # Mixed mode - group by raw columns, aggregate metrics + form_data.update( + { + "all_columns": raw_columns, + "metrics": aggregated_metrics, + "groupby": raw_columns, + "query_mode": "aggregate", + } + ) + + if config.filters: + form_data["adhoc_filters"] = [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "subject": filter_config.column, + "operator": map_filter_operator(filter_config.op), + "comparator": filter_config.value, + } + for filter_config in config.filters + if filter_config is not None + ] + + if config.sort_by: + form_data["order_by_cols"] = config.sort_by + + return form_data + + +def create_metric_object(col: ColumnRef) -> Dict[str, Any]: + """Create a metric object for a column with enhanced validation.""" + # Ensure aggregate is valid - default to SUM if not specified or invalid + valid_aggregates = { + "SUM", + "COUNT", + "AVG", + "MIN", + "MAX", + "COUNT_DISTINCT", + "STDDEV", + "VAR", + "MEDIAN", + "PERCENTILE", + } + aggregate = col.aggregate or "SUM" + + # Validate aggregate function (final safety check) + if aggregate.upper() not in valid_aggregates: + aggregate = "SUM" # Safe fallback + + return { + "aggregate": aggregate.upper(), + "column": { + "column_name": col.name, + }, + "expressionType": "SIMPLE", + "label": col.label or f"{aggregate.upper()}({col.name})", + "optionName": f"metric_{col.name}", + "sqlExpression": None, + "hasCustomLabel": bool(col.label), + "datasourceWarning": False, + } + + +def add_axis_config(form_data: Dict[str, Any], config: XYChartConfig) -> None: + """Add axis configurations to form_data.""" + if config.x_axis: + if config.x_axis.title: + form_data["x_axis_title"] = config.x_axis.title + if config.x_axis.format: + 
form_data["x_axis_format"] = config.x_axis.format
+
+    if config.y_axis:
+        if config.y_axis.title:
+            form_data["y_axis_title"] = config.y_axis.title
+        if config.y_axis.format:
+            form_data["y_axis_format"] = config.y_axis.format
+        if config.y_axis.scale == "log":
+            form_data["y_axis_scale"] = "log"
+
+
+def add_legend_config(form_data: Dict[str, Any], config: XYChartConfig) -> None:
+    """Add legend configuration to form_data.
+
+    Only deviations from Superset defaults are written: legends are shown by
+    default, so only an explicit ``show=False`` is recorded.
+    """
+    if config.legend:
+        if not config.legend.show:
+            form_data["show_legend"] = False
+        if config.legend.position:
+            form_data["legend_orientation"] = config.legend.position
+
+
+# Single source of truth for mapping XY chart "kind" values to Superset
+# viz_type identifiers. This dict was previously duplicated in three places
+# (map_xy_config, analyze_chart_capabilities, analyze_chart_semantics),
+# which risked the copies drifting apart.
+_KIND_TO_VIZ_TYPE: Dict[str, str] = {
+    "line": "echarts_timeseries_line",
+    "bar": "echarts_timeseries_bar",
+    "area": "echarts_area",
+    "scatter": "echarts_timeseries_scatter",
+}
+
+
+def map_xy_config(config: XYChartConfig) -> Dict[str, Any]:
+    """Map XY chart config to form_data with defensive validation.
+
+    Args:
+        config: Validated XY chart configuration.
+
+    Returns:
+        A Superset ``form_data`` dict for the corresponding ECharts viz type.
+
+    Raises:
+        ValueError: If the config has no Y-axis metrics or a Y column name
+            is empty/blank.
+    """
+    # Early validation to prevent empty charts
+    if not config.y:
+        raise ValueError("XY chart must have at least one Y-axis metric")
+
+    # Convert Y columns to metrics with validation
+    metrics = []
+    for col in config.y:
+        if not col.name.strip():  # Validate column name is not empty
+            raise ValueError("Y-axis column name cannot be empty")
+        metrics.append(create_metric_object(col))
+
+    # Final validation - ensure we have metrics to display
+    if not metrics:
+        raise ValueError("XY chart configuration resulted in no displayable metrics")
+
+    form_data: Dict[str, Any] = {
+        "viz_type": _KIND_TO_VIZ_TYPE.get(config.kind, "echarts_timeseries_line"),
+        "x_axis": config.x.name,
+        "metrics": metrics,
+    }
+
+    # For time series charts the x_axis field already tells Superset which
+    # column to use for time grouping, so the x_axis column must NOT be
+    # repeated in groupby: doing so causes "Duplicate column/metric labels"
+    # errors in Superset. Only an explicit group_by column that differs
+    # from the x_axis is added.
+    groupby_columns = []
+    if config.group_by and config.group_by.name != config.x.name:
+        groupby_columns.append(config.group_by.name)
+
+    # Don't set an empty groupby - let Superset handle x_axis grouping
+    # automatically when no extra dimension is requested.
+    if groupby_columns:
+        form_data["groupby"] = groupby_columns
+
+    # Add filters if specified
+    if config.filters:
+        form_data["adhoc_filters"] = [
+            {
+                "clause": "WHERE",
+                "expressionType": "SIMPLE",
+                "subject": filter_config.column,
+                "operator": map_filter_operator(filter_config.op),
+                "comparator": filter_config.value,
+            }
+            for filter_config in config.filters
+            if filter_config is not None
+        ]
+
+    # Add axis and legend configurations
+    add_axis_config(form_data, config)
+    add_legend_config(form_data, config)
+
+    return form_data
+
+
+def map_filter_operator(op: str) -> str:
+    """Map a filter operator to Superset's adhoc-filter format.
+
+    Only "=" differs ("==" in Superset); unknown operators pass through
+    unchanged so richer operators (e.g. IN, LIKE) are not rejected here.
+    """
+    operator_map = {
+        "=": "==",
+        ">": ">",
+        "<": "<",
+        ">=": ">=",
+        "<=": "<=",
+        "!=": "!=",
+    }
+    return operator_map.get(op, op)
+
+
+def generate_chart_name(config: TableChartConfig | XYChartConfig) -> str:
+    """Generate a human-readable chart name based on the configuration."""
+    if isinstance(config, TableChartConfig):
+        return f"Table Chart - {', '.join(col.name for col in config.columns)}"
+    elif isinstance(config, XYChartConfig):
+        chart_type = config.kind.capitalize()
+        x_col = config.x.name
+        y_cols = ", ".join(col.name for col in config.y)
+        return f"{chart_type} Chart - {x_col} vs {y_cols}"
+    else:
+        return "Chart"
+
+
+def _resolve_viz_type(chart: Any | None, config: Any) -> str:
+    """Resolve the Superset viz_type for a saved chart or an unsaved config.
+
+    A saved chart carries its viz_type directly; otherwise the config's
+    chart_type/kind fields are mapped. Shared by the analyze_* helpers below,
+    which previously carried identical copies of this logic.
+    """
+    if chart:
+        return getattr(chart, "viz_type", "unknown")
+    chart_type = getattr(config, "chart_type", "unknown")
+    if chart_type == "xy":
+        kind = getattr(config, "kind", "line")
+        return _KIND_TO_VIZ_TYPE.get(kind, "echarts_timeseries_line")
+    if chart_type == "table":
+        return "table"
+    return "unknown"
+
+
+def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilities:
+    """Analyze chart capabilities based on type and configuration."""
+    viz_type = _resolve_viz_type(chart, config)
+
+    # Determine interaction capabilities based on chart type
+    interactive_types = [
+        "echarts_timeseries_line",
+        "echarts_timeseries_bar",
+        "echarts_area",
+        "echarts_timeseries_scatter",
+        "deck_scatter",
+        "deck_hex",
+    ]
+
+    supports_interaction = viz_type in interactive_types
+    supports_drill_down = viz_type in ["table", "pivot_table_v2"]
+    supports_real_time = viz_type in [
+        "echarts_timeseries_line",
+        "echarts_timeseries_bar",
+    ]
+
+    # Determine optimal formats
+    optimal_formats = ["url"]  # Always include static image
+    if supports_interaction:
+        optimal_formats.extend(["interactive", "vega_lite"])
+    optimal_formats.extend(["ascii", "table"])
+
+    # Classify data types
+    data_types = []
+    if hasattr(config, "x") and config.x:
+        data_types.append("categorical" if not config.x.aggregate else "metric")
+    if hasattr(config, "y") and config.y:
+        data_types.extend(["metric"] * len(config.y))
+    if "time" in viz_type or "timeseries" in viz_type:
+        data_types.append("time_series")
+
+    return ChartCapabilities(
+        supports_interaction=supports_interaction,
+        supports_real_time=supports_real_time,
+        supports_drill_down=supports_drill_down,
+        supports_export=True,  # All charts can be exported
+        optimal_formats=optimal_formats,
+        data_types=list(set(data_types)),
+    )
+
+
+def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics:
+    """Generate semantic understanding of the chart."""
+    viz_type = _resolve_viz_type(chart, config)
+
+    # Generate primary insight based on chart type
+    insights_map = {
+        "echarts_timeseries_line": "Shows trends and changes over time",
+        "echarts_timeseries_bar": "Compares values across categories or time periods",
+        "table": "Displays detailed data in tabular format",
+        "pie": "Shows proportional relationships within a dataset",
+        "echarts_area": "Emphasizes cumulative totals and part-to-whole relationships",
+    }
+
+    primary_insight = insights_map.get(
+        viz_type, f"Visualizes data using {viz_type} format"
+    )
+
+    # Generate data story from the configured columns (at most 3 are named)
+    columns = []
+    if hasattr(config, "x") and config.x:
+        columns.append(config.x.name)
+    if hasattr(config, "y") and config.y:
+        columns.extend([col.name for col in config.y])
+
+    if columns:
+        ellipsis = "..." if len(columns) > 3 else ""
+        data_story = (
+            f"This {viz_type} chart analyzes {', '.join(columns[:3])}{ellipsis}"
+        )
+    else:
+        data_story = "This chart provides insights into the selected dataset"
+
+    # Generate recommended actions
+    recommended_actions = [
+        "Review data patterns and trends",
+        "Consider filtering or drilling down for more detail",
+        "Export chart for reporting or sharing",
+    ]
+
+    if viz_type in ["echarts_timeseries_line", "echarts_timeseries_bar"]:
+        recommended_actions.append("Analyze seasonal patterns or cyclical trends")
+
+    return ChartSemantics(
+        primary_insight=primary_insight,
+        data_story=data_story,
+        recommended_actions=recommended_actions,
+        anomalies=[],  # Would need actual data analysis to populate
+        statistical_summary={},  # Would need actual data analysis to populate
+    )
diff --git a/superset/mcp_service/chart/preview_utils.py b/superset/mcp_service/chart/preview_utils.py
new file mode 100644
index 00000000000..3db475c0da1
--- /dev/null
+++ b/superset/mcp_service/chart/preview_utils.py
@@ -0,0 +1,561 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Preview utilities for chart generation without saving.
+
+This module provides utilities for generating chart previews
+from form data without requiring a saved chart object.
+"""
+
+import logging
+import math
+from typing import Any, Dict, List
+
+from superset.commands.chart.data.get_data_command import ChartDataCommand
+from superset.mcp_service.chart.schemas import (
+    ASCIIPreview,
+    ChartError,
+    TablePreview,
+    VegaLitePreview,
+)
+
+logger = logging.getLogger(__name__)
+
+
+def generate_preview_from_form_data(
+    form_data: Dict[str, Any], dataset_id: int, preview_format: str
+) -> Any:
+    """
+    Generate preview from form data without a saved chart.
+
+    Args:
+        form_data: Chart configuration form data
+        dataset_id: Dataset ID
+        preview_format: Preview format (ascii, table, vega_lite)
+
+    Returns:
+        Preview object or ChartError
+    """
+    try:
+        # Execute query to get data. Imports are deferred so this module can
+        # be imported without a fully initialized Superset app context.
+        from superset.connectors.sqla.models import SqlaTable
+        from superset.extensions import db
+
+        dataset = db.session.query(SqlaTable).get(dataset_id)
+        if not dataset:
+            return ChartError(
+                error=f"Dataset {dataset_id} not found", error_type="DatasetNotFound"
+            )
+
+        # Create query context from form data using factory
+        from superset.common.query_context_factory import QueryContextFactory
+
+        factory = QueryContextFactory()
+        query_context_obj = factory.create(
+            datasource={"id": dataset_id, "type": "table"},
+            queries=[
+                {
+                    "columns": form_data.get("columns", []),
+                    "metrics": form_data.get("metrics", []),
+                    "orderby": form_data.get("orderby", []),
+                    "row_limit": form_data.get("row_limit", 100),
+                    "filters": form_data.get("adhoc_filters", []),
+                    "time_range": form_data.get("time_range", "No filter"),
+                }
+            ],
+            form_data=form_data,
+        )
+
+        # Execute query
+        command = ChartDataCommand(query_context_obj)
+        result = command.run()
+
+        if not result or not result.get("queries"):
+            return ChartError(
+                error="No data returned from query", error_type="EmptyResult"
+            )
+
+        query_result = result["queries"][0]
+        data = query_result.get("data", [])
+
+        # Generate preview based on format
+        if preview_format == "ascii":
+            return _generate_ascii_preview_from_data(data, form_data)
+        elif preview_format == "table":
+            return _generate_table_preview_from_data(data, form_data)
+        elif preview_format == "vega_lite":
+            return _generate_vega_lite_preview_from_data(data, form_data)
+        else:
+            return ChartError(
+                error=f"Unsupported preview format: {preview_format}",
+                error_type="UnsupportedFormat",
+            )
+
+    except Exception as e:
+        logger.error("Preview generation from form data failed: %s", e)
+        return ChartError(
+            error=f"Failed to generate preview: {str(e)}", error_type="PreviewError"
+        )
+
+
+def _generate_ascii_preview_from_data(
+    data: List[Dict[str, Any]], form_data: Dict[str, Any]
+) -> ASCIIPreview:
+    """Generate ASCII preview from raw data, dispatching on viz_type."""
+    viz_type = form_data.get("viz_type", "table")
+
+    # Handle different chart types; anything unrecognized falls back to a table
+    if viz_type in ["bar", "dist_bar", "column"]:
+        content = _generate_safe_ascii_bar_chart(data)
+    elif viz_type in ["line", "area"]:
+        content = _generate_safe_ascii_line_chart(data)
+    elif viz_type == "pie":
+        content = _generate_safe_ascii_pie_chart(data)
+    else:
+        content = _generate_safe_ascii_table(data)
+
+    return ASCIIPreview(
+        ascii_content=content, width=80, height=20, supports_color=False
+    )
+
+
+def _calculate_column_widths(
+    display_columns: List[str], data: List[Dict[str, Any]]
+) -> Dict[str, int]:
+    """Calculate optimal width for each column, clamped to [8, 25]."""
+    column_widths = {}
+    for col in display_columns:
+        # Start with column name length
+        max_width = len(str(col))
+
+        # Check data values to determine width
+        for row in data[:20]:  # Sample first 20 rows
+            val = row.get(col, "")
+            # Floats are measured at the 2-decimal rendering used below;
+            # everything else at its plain str() length.
+            if isinstance(val, float):
+                val_str = f"{val:.2f}"
+            else:
+                val_str = str(val)
+            max_width = max(max_width, len(val_str))
+
+        # Set reasonable bounds
+        column_widths[col] = min(max(max_width, 8), 25)
+    return column_widths
+
+
+def _format_value(val: Any, width: int) -> str:
+    """Format a value based on its type, truncating to fit *width*."""
+    if isinstance(val, float):
+        if abs(val) >= 1000000:
+            val_str = f"{val:.2e}"  # Scientific notation for large numbers
+        elif abs(val) >= 1000:
+            val_str = f"{val:,.2f}"  # Thousands separator
+        else:
+            val_str = f"{val:.2f}"
+    elif isinstance(val, int):
+        if abs(val) >= 1000:
+            val_str = f"{val:,}"  # Thousands separator
+        else:
+            val_str = str(val)
+    elif val is None:
+        val_str = "NULL"
+    else:
+        val_str = str(val)
+
+    # Truncate if too long
+    if len(val_str) > width:
+        val_str = val_str[: width - 2] + ".."
+    return val_str
+
+
+def _generate_table_preview_from_data(
+    data: List[Dict[str, Any]], form_data: Dict[str, Any]
+) -> TablePreview:
+    """Generate table preview from raw data with improved formatting."""
+    if not data:
+        return TablePreview(
+            table_data="No data available", row_count=0, supports_sorting=False
+        )
+
+    # Get columns
+    columns = list(data[0].keys()) if data else []
+
+    # Determine optimal column widths and how many columns to show
+    max_columns = 8  # Show more columns than before
+    display_columns = columns[:max_columns]
+
+    # Calculate optimal width for each column
+    column_widths = _calculate_column_widths(display_columns, data)
+
+    # Format table with proper alignment
+    lines = ["Table Preview", "=" * 80]
+
+    # Header with dynamic width
+    header_parts = []
+    separator_parts = []
+    for col in display_columns:
+        width = column_widths[col]
+        col_name = str(col)
+        if len(col_name) > width:
+            col_name = col_name[: width - 2] + ".."
+        header_parts.append(f"{col_name:<{width}}")
+        separator_parts.append("-" * width)
+
+    lines.append(" | ".join(header_parts))
+    lines.append("-+-".join(separator_parts))
+
+    # Data rows with proper formatting
+    rows_shown = min(len(data), 15)  # Show more rows
+    for row in data[:rows_shown]:
+        row_parts = []
+        for col in display_columns:
+            width = column_widths[col]
+            val = row.get(col, "")
+            val_str = _format_value(val, width)
+            row_parts.append(f"{val_str:<{width}}")
+        lines.append(" | ".join(row_parts))
+
+    # Summary information
+    if len(data) > rows_shown:
+        lines.append(f"... and {len(data) - rows_shown} more rows")
+
+    if len(columns) > max_columns:
+        lines.append(f"... and {len(columns) - max_columns} more columns")
+
+    lines.append("")
+    # NOTE: the multiplication sign was previously mojibake ("ร—"); restored
+    # to the intended U+00D7 character.
+    lines.append(f"Total: {len(data)} rows × {len(columns)} columns")
+
+    return TablePreview(
+        table_data="\n".join(lines), row_count=len(data), supports_sorting=True
+    )
+
+
+def _generate_safe_ascii_bar_chart(data: List[Dict[str, Any]]) -> str:
+    """Generate ASCII bar chart with proper error handling."""
+    if not data:
+        return "No data available for bar chart"
+
+    lines = ["ASCII Bar Chart", "=" * 50]
+
+    # Extract values safely: first numeric cell per row becomes the value,
+    # first string cell becomes the label.
+    values = []
+    labels = []
+
+    for row in data[:10]:
+        label = None
+        value = None
+
+        for _, val in row.items():
+            if isinstance(val, (int, float)) and not _is_nan(val) and value is None:
+                value = val
+            elif isinstance(val, str) and label is None:
+                label = val
+
+        if value is not None:
+            values.append(value)
+            labels.append(label or f"Item {len(values)}")
+
+    if not values:
+        return "No numeric data found for bar chart"
+
+    # Generate bars scaled to the maximum value
+    max_val = max(values)
+    if max_val == 0:
+        return "All values are zero"
+
+    for label, value in zip(labels, values, strict=False):
+        bar_length = int((value / max_val) * 30)
+        bar = "█" * bar_length
+        lines.append(f"{label[:10]:>10} |{bar:<30} {value:.2f}")
+
+    return "\n".join(lines)
+
+
+def _generate_safe_ascii_line_chart(data: List[Dict[str, Any]]) -> str:
+    """Generate ASCII line chart with proper NaN handling."""
+    if not data:
+        return "No data available for line chart"
+
+    lines = ["ASCII Line Chart", "=" * 50]
+    values = _extract_numeric_values_safe(data)
+
+    if not values:
+        return "No valid numeric data found for line chart"
+
+    range_str = _format_range_display(values)
+    lines.append(range_str)
+
+    sparkline = _generate_sparkline_safe(values)
+    lines.append(sparkline)
+
+    return "\n".join(lines)
+
+
+def _extract_numeric_values_safe(data: List[Dict[str, Any]]) -> List[float]:
+    """Extract the first numeric, non-NaN value from each of the first 20 rows."""
+    values = []
+    for row in data[:20]:
+        for _, val in row.items():
+            if isinstance(val, (int, float)) and not _is_nan(val):
+                values.append(val)
+                break
+    return values
+
+
+def _format_range_display(values: List[float]) -> str:
+    """Format range display safely."""
+    min_val = min(values)
+    max_val = max(values)
+
+    if _is_nan(min_val) or _is_nan(max_val):
+        return "Range: Unable to calculate"
+    else:
+        return f"Range: {min_val:.2f} to {max_val:.2f}"
+
+
+def _generate_sparkline_safe(values: List[float]) -> str:
+    """Generate a unicode block-element sparkline from values."""
+    if not values:
+        return ""
+
+    min_val = min(values)
+
+    if (max_val := max(values)) != min_val:
+        sparkline = ""
+        for val in values:
+            normalized = (val - min_val) / (max_val - min_val)
+            if normalized < 0.2:
+                sparkline += "▁"
+            elif normalized < 0.4:
+                sparkline += "▂"
+            elif normalized < 0.6:
+                sparkline += "▄"
+            elif normalized < 0.8:
+                sparkline += "▆"
+            else:
+                sparkline += "█"
+        return sparkline
+    else:
+        return "─" * len(values)  # Flat line if all values are same
+
+
+def _generate_safe_ascii_pie_chart(data: List[Dict[str, Any]]) -> str:
+    """Generate ASCII pie chart representation."""
+    if not data:
+        return "No data available for pie chart"
+
+    lines = ["ASCII Pie Chart", "=" * 50]
+
+    # Extract values and labels
+    values = []
+    labels = []
+
+    for row in data[:8]:  # Limit slices
+        label = None
+        value = None
+
+        for _, val in row.items():
+            if isinstance(val, (int, float)) and not _is_nan(val) and value is None:
+                value = val
+            elif isinstance(val, str) and label is None:
+                label = val
+
+        # Only positive values make sense as pie slices
+        if value is not None and value > 0:
+            values.append(value)
+            labels.append(label or f"Slice {len(values)}")
+
+    if not values:
+        return "No valid data for pie chart"
+
+    # Calculate percentages
+    total = sum(values)
+    if total == 0:
+        return "Total is zero"
+
+    for label, value in zip(labels, values, strict=False):
+        percentage = (value / total) * 100
+        bar_length = int(percentage / 3)  # Scale to fit
+        bar = "●" * bar_length
+        lines.append(f"{label[:15]:>15}: {bar} {percentage:.1f}%")
+
+    return "\n".join(lines)
+
+
+def _generate_safe_ascii_table(data: List[Dict[str, Any]]) -> str:
+    """Generate ASCII table with safe formatting (max 5 columns, 10 rows)."""
+    if not data:
+        return "No data available"
+
+    lines = ["Data Table", "=" * 50]
+
+    # Get columns
+    columns = list(data[0].keys()) if data else []
+
+    # Format header
+    header = " | ".join(str(col)[:10] for col in columns[:5])
+    lines.append(header)
+    lines.append("-" * len(header))
+
+    # Format rows
+    for row in data[:10]:
+        row_str = " | ".join(str(row.get(col, ""))[:10] for col in columns[:5])
+        lines.append(row_str)
+
+    if len(data) > 10:
+        lines.append(f"... {len(data) - 10} more rows")
+
+    return "\n".join(lines)
+
+
+def _is_nan(value: Any) -> bool:
+    """Check if a value is NaN (non-convertible values count as not NaN)."""
+    try:
+        return math.isnan(float(value))
+    except (ValueError, TypeError):
+        return False
+
+
+def _generate_vega_lite_preview_from_data(  # noqa: C901
+    data: List[Dict[str, Any]], form_data: Dict[str, Any]
+) -> VegaLitePreview:
+    """Generate Vega-Lite preview from raw data and form_data."""
+    viz_type = form_data.get("viz_type", "table")
+
+    # Map Superset viz types to Vega-Lite marks
+    viz_to_mark = {
+        "echarts_timeseries_line": "line",
+        "echarts_timeseries_bar": "bar",
+        "echarts_area": "area",
+        "echarts_timeseries_scatter": "point",
+        "bar": "bar",
+        "line": "line",
+        "area": "area",
+        "scatter": "point",
+        "pie": "arc",
+        "table": "text",
+    }
+
+    mark = viz_to_mark.get(viz_type, "bar")
+
+    # Basic Vega-Lite spec
+    spec = {
+        "$schema": "https://vega.github.io/schema/vega-lite/v5.json",
+        "data": {"values": data},
+        "mark": mark,
+    }
+
+    # Get x_axis and metrics from form_data
+    x_axis = form_data.get("x_axis")
+    metrics = form_data.get("metrics", [])
+    groupby = form_data.get("groupby", [])
+
+    # Build encoding based on available fields
+    encoding = {}
+
+    # Handle X-axis
+    if x_axis and x_axis in (data[0] if data else {}):
+        # Detect field type from data
+        field_type = "nominal"  # default
+        if data and len(data) > 0:
+            sample_val = data[0].get(x_axis)
+            if isinstance(sample_val, str):
+                # Heuristic: date/time-like strings contain -, / or :
+                if any(char in str(sample_val) for char in ["-", "/", ":"]):
+                    field_type = "temporal"
+                else:
+                    field_type = "nominal"
+            elif isinstance(sample_val, (int, float)):
+                field_type = "quantitative"
+
+        encoding["x"] = {
+            "field": x_axis,
+            "type": field_type,
+            "title": x_axis,
+        }
+
+    # Handle Y-axis (metrics)
+    if metrics and data:
+        # Find the first metric column in the data
+        metric_col = None
+        for col in data[0].keys():
+            # Check if this is a metric column (usually has aggregation in name)
+            if any(
+                agg in str(col).upper()
+                for agg in ["SUM", "AVG", "COUNT", "MIN", "MAX", "TOTAL"]
+            ):
+                metric_col = col
+                break
+            # Or check if it's numeric
+            elif isinstance(data[0].get(col), (int, float)):
+                metric_col = col
+                break
+
+        if metric_col:
+            encoding["y"] = {
+                "field": metric_col,
+                "type": "quantitative",
+                "title": metric_col,
+            }
+
+    # Handle color encoding for groupby
+    if groupby and len(groupby) > 0 and groupby[0] in (data[0] if data else {}):
+        encoding["color"] = {
+            "field": groupby[0],
+            "type": "nominal",
+            "title": groupby[0],
+        }
+
+    # Special handling for pie charts
+    if mark == "arc" and data:
+        # For pie charts, we need theta encoding
+        if "y" in encoding:
+            encoding["theta"] = encoding.pop("y")
+            encoding["theta"]["stack"] = True
+        if "x" in encoding:
+            # Use x as color for pie
+            encoding["color"] = {
+                "field": encoding["x"]["field"],
+                "type": "nominal",
+            }
+            del encoding["x"]
+
+    # Add encoding to spec
+    if encoding:
+        spec["encoding"] = encoding
+
+    # Add responsive sizing - Vega-Lite supports "container" as a special width value
+    spec["width"] = "container"
+    spec["height"] = 400  # type: ignore
+
+    # Add interactivity. The spec declares the v5 schema, where the v4-style
+    # top-level "selection" block (and "single" selection type) was removed;
+    # v5 uses "params" with a point selection instead.
+    if mark in ["line", "point", "bar", "area"]:
+        spec["params"] = [
+            {
+                "name": "highlight",
+                "select": {"type": "point", "on": "mouseover"},
+            }
+        ]
+
+    return VegaLitePreview(
+        specification=spec,
+        data_url=None,
+        supports_streaming=False,
+    )
diff --git a/superset/mcp_service/chart/prompts/__init__.py b/superset/mcp_service/chart/prompts/__init__.py
new file mode 100644
index 00000000000..ce7fa27ec7f
--- /dev/null
+++ b/superset/mcp_service/chart/prompts/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Chart prompts for Superset MCP service""" + +# Import to register prompts when module is imported +from . import create_chart_guided # noqa: F401 diff --git a/superset/mcp_service/chart/prompts/create_chart_guided.py b/superset/mcp_service/chart/prompts/create_chart_guided.py new file mode 100644 index 00000000000..7213b856f4c --- /dev/null +++ b/superset/mcp_service/chart/prompts/create_chart_guided.py @@ -0,0 +1,195 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+
+"""
+Chart prompts for visualization guidance
+"""
+
+import logging
+
+from superset.mcp_service.app import mcp
+from superset.mcp_service.auth import mcp_auth_hook
+
+logger = logging.getLogger(__name__)
+
+
+@mcp.prompt("create_chart_guided")
+@mcp_auth_hook
+async def create_chart_guided_prompt(
+    chart_type: str = "auto", business_goal: str = "exploration"
+) -> str:
+    """
+    AI-powered chart creation guide following Anthropic's agent design principles.
+
+    This prompt implements:
+    - Transparency: Clear reasoning at each step
+    - Proactive Intelligence: Suggests insights before being asked
+    - Context Awareness: Maintains conversational flow
+    - Business Focus: Translates data into actionable insights
+    - Validation: Verifies choices before proceeding
+    - Natural Interaction: Conversational, not configuration-driven
+
+    Args:
+        chart_type: Preferred chart type (auto, line, bar, pie, table, scatter, area)
+        business_goal: Purpose (exploration, reporting, monitoring, presentation)
+    """
+
+    # Enhanced chart intelligence with business context.
+    # Lookup table keyed by chart_type; unknown values fall back to "auto"
+    # below, so no KeyError is possible.
+    chart_intelligence = {
+        "line": {
+            "description": "Time series visualization for trend analysis",
+            "best_for": "Tracking performance over time, identifying patterns",
+            "business_value": "Reveals growth trends, seasonality, and patterns",
+            "data_requirements": "Temporal column + continuous metrics",
+        },
+        "bar": {
+            "description": "Category comparison visualization",
+            "best_for": "Ranking, comparisons, and performance by category",
+            "business_value": "Identifies top performers, bottlenecks, and gaps",
+            "data_requirements": "Categorical dimensions + aggregatable metrics",
+        },
+        "scatter": {
+            "description": "Correlation and relationship analysis",
+            "best_for": "Finding relationships, outlier detection, clustering",
+            "business_value": "Uncovers hidden correlations and identifies anomalies",
+            "data_requirements": "Two continuous variables, optional grouping",
+        },
+        "table": {
+            "description": "Detailed data exploration and exact values",
+            "best_for": "Detailed analysis, data validation, precise values",
+            "business_value": "Provides granular insights and detailed reporting",
+            "data_requirements": "Any combination of dimensions and metrics",
+        },
+        "area": {
+            "description": "Volume and composition over time",
+            "best_for": "Showing cumulative effects, stacked comparisons",
+            "business_value": "Visualizes contribution and total volume trends",
+            "data_requirements": "Temporal dimension + stackable metrics",
+        },
+        "auto": {
+            "description": "AI-powered visualization recommendation",
+            "best_for": "When you're not sure what chart type to use",
+            "business_value": "Optimizes chart choice based on data characteristics",
+            "data_requirements": "I'll analyze your data and recommend the best type",
+        },
+    }
+
+    # Business context intelligence.
+    # Lookup table keyed by business_goal; unknown values fall back to
+    # "exploration" below.
+    goal_intelligence = {
+        "exploration": {
+            "approach": "Interactive discovery and pattern finding",
+            "features": "Filters, drill-downs, multiple perspectives",
+            "outcome": "Uncover hidden insights and generate hypotheses",
+        },
+        "reporting": {
+            "approach": "Clear, professional, and consistent presentation",
+            "features": "Clean design, appropriate aggregation, clear labels",
+            "outcome": "Reliable, repeatable business reporting",
+        },
+        "monitoring": {
+            "approach": "Real-time tracking with clear thresholds",
+            "features": "Alert conditions, trend indicators, key metrics",
+            "outcome": "Proactive issue detection and performance tracking",
+        },
+        "presentation": {
+            "approach": "Compelling visual storytelling",
+            "features": "Engaging colors, clear messaging, audience-appropriate detail",
+            "outcome": "Persuasive data-driven presentations for stakeholders",
+        },
+    }
+
+    # Resolve the caller's selections, defaulting gracefully for unknown input.
+    selected_chart = chart_intelligence.get(chart_type, chart_intelligence["auto"])
+    selected_goal = goal_intelligence.get(
+        business_goal, goal_intelligence["exploration"]
+    )
+
+    # The returned markdown is the prompt text surfaced to the MCP client.
+    return f"""🎯 **AI-Powered Chart Creation Assistant**
+
+I'm your intelligent data visualization partner! Let me help you create charts.
+
+**Your Visualization Goal:**
+📊 **Chart Focus**: {chart_type.title()} - {selected_chart["description"]}
+🎯 **Business Purpose**: {business_goal.title()} - {selected_goal["approach"]}
+💡 **Expected Value**: {selected_chart["business_value"]}
+
+---
+
+## 🚀 My Intelligent Approach
+
+### **Phase 1: Data Intelligence** 📊
+I'll automatically analyze your dataset to understand:
+- **Data characteristics** (types, distributions, quality)
+- **Business relationships** (correlations, hierarchies, trends)
+- **Visualization opportunities** (what stories your data can tell)
+- **Performance considerations** (size, complexity, aggregation needs)
+
+*Why this matters: The right chart depends on your data's unique characteristics*
+
+### **Phase 2: Smart Recommendations** 🧠
+Based on your data analysis, I'll:
+- **Recommend optimal chart types** with confidence scores and reasoning
+- **Suggest meaningful metrics** that align with your business goal
+- **Identify interesting patterns** you might want to highlight
+- **Propose filters** to focus on what matters most
+
+*Why this matters: I'll spot opportunities you might miss and save you time*
+
+### **Phase 3: Intelligent Configuration** ⚙️
+I'll configure your chart with:
+- **Business-appropriate aggregations** (daily, weekly, monthly for time series)
+- **Meaningful labels and formatting** (currency, percentages, readable names)
+- **Performance optimizations** (appropriate limits, caching strategies)
+- **Visual best practices** (colors, scales, legends that enhance understanding)
+
+*Why this matters: Proper configuration makes charts both beautiful and actionable*
+
+### **Phase 4: Validation & Refinement** 🎯
+Before finalizing, I'll:
+- **Verify the chart answers your business question**
+- **Check data quality and completeness**
+- **Suggest improvements** based on visualization best practices
+- **Provide preview** so you can see exactly what you're getting
+
+*Why this matters: Great charts require iteration and validation*
+
+---
+
+## 🎬 Let's Begin Your Data Story
+
+I'm ready to be your proactive data exploration partner. Here's how we can start:
+
+**Option 1: Quick Start** ⚡
+Tell me: *"What business question are you trying to answer?"*
+(e.g., "How are our sales trending?" or "Which products perform best?")
+
+**Option 2: Dataset Exploration** 🔍
+I can show you available datasets: `list_datasets`
+Or explore a specific one: `get_dataset_info [dataset_id]`
+
+**Option 3: Visual Inspiration** 🎨
+Browse pre-built chart configurations: `superset://chart/configs` resource
+Perfect for when you want to see examples of great charts!
+
+**Option 4: Autonomous Discovery** 🤖
+Just point me to a dataset and say *"Find something interesting"*
+I'll explore autonomously and surface the most compelling insights!
+
+---
+
+💡 **Pro Tip**: Great charts combine business intuition with data analysis!
+
+**What's your data challenge today?** 🚀"""
diff --git a/superset/mcp_service/chart/resources/__init__.py b/superset/mcp_service/chart/resources/__init__.py
new file mode 100644
index 00000000000..4b886b25762
--- /dev/null
+++ b/superset/mcp_service/chart/resources/__init__.py
@@ -0,0 +1,21 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Chart resources for Superset MCP service""" + +# Import to register resources when module is imported +from . import chart_configs # noqa: F401 diff --git a/superset/mcp_service/chart/resources/chart_configs.py b/superset/mcp_service/chart/resources/chart_configs.py new file mode 100644 index 00000000000..49eb109ae80 --- /dev/null +++ b/superset/mcp_service/chart/resources/chart_configs.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Chart resources for valid ChartConfig examples and templates +""" + +import logging + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook + +logger = logging.getLogger(__name__) + + +@mcp.resource("superset://chart/configs") +@mcp_auth_hook +def get_chart_configs_resource() -> str: + """ + Provide valid ChartConfig examples that match the exact schema. 
+ + This resource gives LLMs access to: + - Valid ChartConfig examples for XYChartConfig and TableChartConfig + - Working configurations that pass schema validation + - Examples with proper ColumnRef, FilterConfig, AxisConfig, and LegendConfig + - Best practices for each chart type configuration + """ + + # Valid XYChartConfig examples - these match the exact schema + xy_chart_configs = { + "line_chart": { + "description": "Basic line chart for time series analysis", + "config": { + "chart_type": "xy", + "kind": "line", + "x": {"name": "created_on", "label": "Date Created"}, + "y": [ + { + "name": "count_metric", + "aggregate": "COUNT", + "label": "Total Count", + } + ], + }, + "use_cases": [ + "Time series trends", + "Historical analysis", + "Growth tracking", + ], + }, + "bar_chart": { + "description": "Bar chart for category comparison", + "config": { + "chart_type": "xy", + "kind": "bar", + "x": {"name": "category", "label": "Category"}, + "y": [{"name": "sales", "aggregate": "SUM", "label": "Total Sales"}], + "x_axis": {"title": "Product Categories", "scale": "linear"}, + "y_axis": { + "title": "Revenue ($)", + "format": "$,.0f", + "scale": "linear", + }, + }, + "use_cases": ["Category comparison", "Rankings", "Performance metrics"], + }, + "multi_metric_line": { + "description": "Multi-metric line chart with grouping", + "config": { + "chart_type": "xy", + "kind": "line", + "x": {"name": "date_column", "label": "Date"}, + "y": [ + {"name": "revenue", "aggregate": "SUM", "label": "Revenue"}, + { + "name": "users", + "aggregate": "COUNT_DISTINCT", + "label": "Unique Users", + }, + ], + "group_by": {"name": "region", "label": "Region"}, + "legend": {"show": True, "position": "right"}, + "filters": [{"column": "status", "op": "=", "value": "active"}], + }, + "use_cases": [ + "Multi-dimensional analysis", + "Regional comparisons", + "KPI tracking", + ], + }, + "scatter_plot": { + "description": "Scatter plot for correlation analysis", + "config": { + "chart_type": 
"xy", + "kind": "scatter", + "x": { + "name": "advertising_spend", + "aggregate": "AVG", + "label": "Avg Ad Spend", + }, + "y": [ + { + "name": "conversion_rate", + "aggregate": "AVG", + "label": "Avg Conversion Rate", + } + ], + "group_by": {"name": "campaign_type", "label": "Campaign Type"}, + "x_axis": {"title": "Average Advertising Spend", "format": "$,.0f"}, + "y_axis": {"title": "Conversion Rate", "format": ".2%"}, + }, + "use_cases": [ + "Correlation analysis", + "Outlier detection", + "Performance relationships", + ], + }, + "area_chart": { + "description": "Area chart for volume visualization", + "config": { + "chart_type": "xy", + "kind": "area", + "x": {"name": "month", "label": "Month"}, + "y": [ + {"name": "signups", "aggregate": "SUM", "label": "Monthly Signups"} + ], + "filters": [ + {"column": "year", "op": ">=", "value": 2023}, + {"column": "active", "op": "=", "value": True}, + ], + }, + "use_cases": ["Volume trends", "Cumulative metrics", "Stacked comparisons"], + }, + } + + # Valid TableChartConfig examples - these match the exact schema + table_chart_configs = { + "basic_table": { + "description": "Basic data table with multiple columns", + "config": { + "chart_type": "table", + "columns": [ + {"name": "name", "label": "Customer Name"}, + {"name": "email", "label": "Email Address"}, + {"name": "orders", "aggregate": "COUNT", "label": "Total Orders"}, + {"name": "revenue", "aggregate": "SUM", "label": "Total Revenue"}, + ], + "sort_by": ["Total Revenue"], + }, + "use_cases": [ + "Detailed data views", + "Customer lists", + "Transaction records", + ], + }, + "aggregated_table": { + "description": "Table with aggregated metrics and filters", + "config": { + "chart_type": "table", + "columns": [ + {"name": "region", "label": "Sales Region"}, + { + "name": "sales_amount", + "aggregate": "SUM", + "label": "Total Sales", + }, + { + "name": "sales_amount", + "aggregate": "AVG", + "label": "Average Sale", + }, + { + "name": "customer_id", + "aggregate": 
"COUNT_DISTINCT", + "label": "Unique Customers", + }, + ], + "filters": [ + {"column": "sale_date", "op": ">=", "value": "2024-01-01"}, + {"column": "status", "op": "!=", "value": "cancelled"}, + ], + "sort_by": ["Total Sales", "Sales Region"], + }, + "use_cases": ["Summary reports", "Regional analysis", "Performance tables"], + }, + } + + # Schema reference for developers + schema_reference = { + "ChartConfig": { + "description": "Union type - XYChartConfig or TableChartConfig by type", + "discriminator": "chart_type", + "types": ["xy", "table"], + }, + "XYChartConfig": { + "required_fields": ["chart_type", "x", "y"], + "optional_fields": [ + "kind", + "group_by", + "x_axis", + "y_axis", + "legend", + "filters", + ], + "chart_type": "xy", + "kind_options": ["line", "bar", "area", "scatter"], + "validation_rules": [ + "All column labels must be unique across x, y, and group_by", + "Y-axis must have at least one column", + "Column names must match pattern: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", + ], + }, + "TableChartConfig": { + "required_fields": ["chart_type", "columns"], + "optional_fields": ["filters", "sort_by"], + "chart_type": "table", + "validation_rules": [ + "Must have at least one column", + "All column labels must be unique", + "Column names must match pattern: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", + ], + }, + "ColumnRef": { + "required_fields": ["name"], + "optional_fields": ["label", "dtype", "aggregate"], + "aggregate_options": [ + "SUM", + "COUNT", + "AVG", + "MIN", + "MAX", + "COUNT_DISTINCT", + "STDDEV", + "VAR", + "MEDIAN", + "PERCENTILE", + ], + "validation_rules": [ + "Name cannot be empty and must follow pattern", + "Labels are HTML-escaped to prevent XSS", + "Aggregates are validated against allowed functions", + ], + }, + "FilterConfig": { + "required_fields": ["column", "op", "value"], + "operator_options": ["=", ">", "<", ">=", "<=", "!="], + "value_types": ["string", "number", "boolean"], + "validation_rules": [ + "Column names are 
sanitized to prevent injection", + "Values are checked for malicious patterns", + "String values are HTML-escaped", + ], + }, + "AxisConfig": { + "optional_fields": ["title", "scale", "format"], + "scale_options": ["linear", "log"], + "format_examples": ["$,.2f", ".2%", ",.0f", ".1f"], + }, + "LegendConfig": { + "optional_fields": ["show", "position"], + "show_default": True, + "position_options": ["top", "bottom", "left", "right"], + "position_default": "right", + }, + } + + # Best practices for each configuration type + best_practices = { + "xy_charts": [ + "Use descriptive labels for axes and metrics", + "Choose appropriate aggregation functions for your data", + "Limit the number of Y-axis metrics (3-5 maximum)", + "Use filters to focus on relevant data", + "Configure axis formatting for better readability", + "Consider grouping when comparing categories", + "Use chart kinds: line for trends, bar for comparisons, scatter plots", + ], + "table_charts": [ + "Include essential columns only to avoid clutter", + "Use meaningful column labels", + "Apply sorting to highlight important data", + "Use filters to limit result sets", + "Mix dimensions and aggregated metrics appropriately", + "Ensure unique labels to avoid conflicts", + "Consider performance with large datasets", + ], + "general": [ + "Always specify chart_type as the first field", + "Use consistent naming conventions for columns", + "Validate column names exist in your dataset", + "Test configurations with actual data", + "Consider caching for frequently accessed charts", + "Apply security best practices - avoid user input in column names", + ], + } + + # Common patterns and examples + common_patterns = { + "time_series": { + "description": "Standard time-based analysis", + "x_column_types": ["date", "datetime", "timestamp"], + "recommended_aggregations": ["SUM", "COUNT", "AVG"], + "best_chart_types": ["line", "area", "bar"], + }, + "categorical_analysis": { + "description": "Comparing discrete 
categories", + "x_column_types": ["string", "category", "enum"], + "recommended_aggregations": ["SUM", "COUNT", "COUNT_DISTINCT", "AVG"], + "best_chart_types": ["bar", "table"], + }, + "correlation_analysis": { + "description": "Finding relationships between variables", + "requirements": ["Two numerical metrics"], + "recommended_aggregations": ["AVG", "SUM", "MEDIAN"], + "best_chart_types": ["scatter"], + }, + } + + resource_data = { + "xy_chart_configs": xy_chart_configs, + "table_chart_configs": table_chart_configs, + "schema_reference": schema_reference, + "best_practices": best_practices, + "common_patterns": common_patterns, + "metadata": { + "version": "1.0", + "schema_version": "ChartConfig v1.0", + "last_updated": "2025-08-07", + "usage_notes": [ + "All examples are valid ChartConfig objects that pass validation", + "Copy these configurations directly into generate_chart requests", + "Modify column names and labels to match your actual dataset", + "Test configurations with get_dataset_info to verify columns", + "All examples follow security best practices and input validation", + ], + "validation_info": [ + "Column names must match: ^[a-zA-Z0-9_][a-zA-Z0-9_\\s\\-\\.]*$", + "Labels are automatically HTML-escaped for security", + "Filter values are sanitized to prevent injection attacks", + "All field lengths are validated against schema limits", + "Duplicate labels are automatically detected and rejected", + ], + }, + } + + from superset.utils import json + + return json.dumps(resource_data, indent=2) diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index a2afd40a447..b7ccbb363ba 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -21,6 +21,8 @@ Pydantic schemas for chart-related responses from __future__ import annotations +import html +import re from datetime import datetime, timezone from typing import Annotated, Any, Dict, List, Literal, Protocol @@ -28,12 +30,19 @@ 
from pydantic import ( BaseModel, ConfigDict, Field, + field_validator, model_validator, PositiveInt, ) from superset.daos.base import ColumnOperator, ColumnOperatorEnum -from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.common.cache_schemas import ( + CacheStatus, + FormDataCacheControl, + MetadataCacheControl, + QueryCacheControl, +) +from superset.mcp_service.common.error_schemas import ChartGenerationError from superset.mcp_service.system.schemas import ( PaginationInfo, TagInfo, @@ -101,6 +110,25 @@ class ChartInfo(BaseModel): model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601") +class GetChartAvailableFiltersRequest(BaseModel): + """ + Request schema for get_chart_available_filters tool. + + Currently has no parameters but provides consistent API for future extensibility. + """ + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + +class ChartAvailableFiltersResponse(BaseModel): + column_operators: Dict[str, Any] = Field( + ..., description="Available filter operators and metadata for each column" + ) + + class ChartError(BaseModel): error: str = Field(..., description="Error message") error_type: str = Field(..., description="Type of error") @@ -111,6 +139,63 @@ class ChartError(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") +class ChartCapabilities(BaseModel): + """Describes what the chart can do for LLM understanding.""" + + supports_interaction: bool = Field(description="Chart supports user interaction") + supports_real_time: bool = Field(description="Chart supports live data updates") + supports_drill_down: bool = Field( + description="Chart supports drill-down navigation" + ) + supports_export: bool = Field(description="Chart can be exported to other formats") + optimal_formats: List[str] = Field(description="Recommended preview formats") + data_types: List[str] = Field( + description="Types of data shown (time_series, 
categorical, etc)" + ) + + +class ChartSemantics(BaseModel): + """Semantic information for LLM reasoning.""" + + primary_insight: str = Field( + description="Main insight or pattern the chart reveals" + ) + data_story: str = Field(description="Narrative description of what the data shows") + recommended_actions: List[str] = Field( + description="Suggested next steps based on data" + ) + anomalies: List[str] = Field(description="Notable outliers or unusual patterns") + statistical_summary: Dict[str, Any] = Field( + description="Key statistics (mean, median, trends)" + ) + + +class PerformanceMetadata(BaseModel): + """Performance information for LLM cost understanding.""" + + query_duration_ms: int = Field(description="Query execution time") + estimated_cost: str | None = Field(None, description="Resource cost estimate") + cache_status: str = Field(description="Cache hit/miss status") + optimization_suggestions: List[str] = Field( + default_factory=list, description="Performance improvement tips" + ) + + +class AccessibilityMetadata(BaseModel): + """Accessibility information for inclusive visualization.""" + + color_blind_safe: bool = Field(description="Uses colorblind-safe palette") + alt_text: str = Field(description="Screen reader description") + high_contrast_available: bool = Field(description="High contrast version available") + + +class VersionedResponse(BaseModel): + """Base class for versioned API responses.""" + + schema_version: str = Field("2.0", description="Response schema version") + api_version: str = Field("v1", description="MCP API version") + + class GetChartInfoRequest(BaseModel): """Request schema for get_chart_info with support for ID or UUID.""" @@ -124,12 +209,13 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: if not chart: return None - # TODO (Phase 3): Generate MCP service screenshot URL - # For now, use chart's native URL instead of screenshot URL - # Screenshot functionality will be added in Phase 3 PR + # 
Generate MCP service screenshot URL instead of chart's native URL + from superset.mcp_service.utils.url_utils import get_chart_screenshot_url chart_id = getattr(chart, "id", None) - chart_url = getattr(chart, "url", None) + screenshot_url = None + if chart_id: + screenshot_url = get_chart_screenshot_url(chart_id) return ChartInfo( id=chart_id, @@ -137,7 +223,7 @@ def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None: viz_type=getattr(chart, "viz_type", None), datasource_name=getattr(chart, "datasource_name", None), datasource_type=getattr(chart, "datasource_type", None), - url=chart_url, + url=screenshot_url, description=getattr(chart, "description", None), cache_timeout=getattr(chart, "cache_timeout", None), form_data=getattr(chart, "form_data", None), @@ -214,6 +300,409 @@ class ChartList(BaseModel): model_config = ConfigDict(ser_json_timedelta="iso8601") +# --- Simplified schemas for generate_chart tool --- + + +# Common pieces +class ColumnRef(BaseModel): + name: str = Field( + ..., + description="Column name", + min_length=1, + max_length=255, + pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$", + ) + label: str | None = Field( + None, description="Display label for the column", max_length=500 + ) + dtype: str | None = Field(None, description="Data type hint") + aggregate: ( + Literal[ + "SUM", + "COUNT", + "AVG", + "MIN", + "MAX", + "COUNT_DISTINCT", + "STDDEV", + "VAR", + "MEDIAN", + "PERCENTILE", + ] + | None + ) = Field( + None, + description="SQL aggregation function. Only these validated functions are " + "supported to prevent SQL errors.", + ) + + @field_validator("name") + @classmethod + def sanitize_name(cls, v: str) -> str: + """Sanitize column name to prevent XSS and SQL injection.""" + if not v or not v.strip(): + raise ValueError("Column name cannot be empty") + + # Length check first to prevent ReDoS attacks + if len(v) > 255: + raise ValueError( + f"Column name too long ({len(v)} characters). 
" + f"Maximum allowed length is 255 characters." + ) + + # Remove HTML tags and decode entities + sanitized = html.escape(v.strip()) + + # Check for dangerous HTML tags using substring checks (safe) + dangerous_tags = ["", " str | None: + """Sanitize display label to prevent XSS attacks.""" + if v is None: + return v + + # Strip whitespace + v = v.strip() + if not v: + return None + + # Length check first to prevent ReDoS attacks + if len(v) > 500: + raise ValueError( + f"Label too long ({len(v)} characters). " + f"Maximum allowed length is 500 characters." + ) + + # Check for dangerous HTML tags and JavaScript protocols using substring checks + # This avoids ReDoS vulnerabilities from regex patterns + dangerous_tags = [ + "", + "", + "", + "", + "", "<", ">=", "<=", "!="] = Field( + ..., description="Filter operator" + ) + value: str | int | float | bool = Field(..., description="Filter value") + + @field_validator("column") + @classmethod + def sanitize_column(cls, v: str) -> str: + """Sanitize filter column name to prevent injection attacks.""" + if not v or not v.strip(): + raise ValueError("Filter column name cannot be empty") + + # Length check first to prevent ReDoS attacks + if len(v) > 255: + raise ValueError( + f"Filter column name too long ({len(v)} characters). " + f"Maximum allowed length is 255 characters." 
+ ) + + # Remove HTML tags and decode entities + sanitized = html.escape(v.strip()) + + # Check for dangerous HTML tags using substring checks (safe) + dangerous_tags = [""] + v_lower = v.lower() + for tag in dangerous_tags: + if tag in v_lower: + raise ValueError( + "Filter column contains potentially malicious script content" + ) + + # Check URL schemes with word boundaries + if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE): + raise ValueError("Filter column contains potentially malicious URL scheme") + + return sanitized + + @staticmethod + def _validate_string_value(v: str) -> None: + """Validate string filter value for security issues.""" + # Check for dangerous HTML tags and SQL procedures + dangerous_substrings = [ + "", + " str | int | float | bool: + """Sanitize filter value to prevent XSS and SQL injection attacks.""" + if isinstance(v, str): + v = v.strip() + + # Length check FIRST to prevent ReDoS attacks + if len(v) > 1000: + raise ValueError( + f"Filter value too long ({len(v)} characters). " + f"Maximum allowed length is 1000 characters." + ) + + # Validate security + cls._validate_string_value(v) + + # Filter dangerous Unicode characters + v = re.sub( + r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v + ) + + # HTML escape the cleaned content + return html.escape(v) + + return v # Return non-string values as-is + + +# Actual chart types +class TableChartConfig(BaseModel): + chart_type: Literal["table"] = Field("table", description="Chart type") + columns: List[ColumnRef] = Field( + ..., + min_length=1, + description=( + "Columns to display. Must have at least one column. 
Each column must have " + "a unique label " + "(either explicitly set via 'label' field or auto-generated " + "from name/aggregate)" + ), + ) + filters: List[FilterConfig] | None = Field(None, description="Filters to apply") + sort_by: List[str] | None = Field(None, description="Columns to sort by") + + @model_validator(mode="after") + def validate_unique_column_labels(self) -> "TableChartConfig": + """Ensure all column labels are unique.""" + labels_seen = set() + duplicates = [] + + for i, col in enumerate(self.columns): + # Generate the label that will be used (same logic as create_metric_object) + if col.aggregate: + label = col.label or f"{col.aggregate}({col.name})" + else: + label = col.label or col.name + + if label in labels_seen: + duplicates.append(f"columns[{i}]: '{label}'") + else: + labels_seen.add(label) + + if duplicates: + raise ValueError( + f"Duplicate column/metric labels: {', '.join(duplicates)}. " + f"Please make sure all columns and metrics have a unique label. " + f"Use the 'label' field to provide custom names for columns." + ) + + return self + + +class XYChartConfig(BaseModel): + chart_type: Literal["xy"] = Field("xy", description="Chart type") + x: ColumnRef = Field(..., description="X-axis column") + y: List[ColumnRef] = Field( + ..., + min_length=1, + description="Y-axis columns (metrics). Must have at least one Y-axis column. 
" + "Each column must have a unique label " + "that doesn't conflict with x-axis or group_by labels", + ) + kind: Literal["line", "bar", "area", "scatter"] = Field( + "line", description="Chart visualization type" + ) + group_by: ColumnRef | None = Field(None, description="Column to group by") + x_axis: AxisConfig | None = Field(None, description="X-axis configuration") + y_axis: AxisConfig | None = Field(None, description="Y-axis configuration") + legend: LegendConfig | None = Field(None, description="Legend configuration") + filters: List[FilterConfig] | None = Field(None, description="Filters to apply") + + @model_validator(mode="after") + def validate_unique_column_labels(self) -> "XYChartConfig": + """Ensure all column labels are unique across x, y, and group_by.""" + labels_seen = {} # label -> field_name for error reporting + duplicates = [] + + # Check X-axis label + x_label = self.x.label or self.x.name + labels_seen[x_label] = "x" + + # Check Y-axis labels + for i, col in enumerate(self.y): + if col.aggregate: + label = col.label or f"{col.aggregate}({col.name})" + else: + label = col.label or col.name + + if label in labels_seen: + duplicates.append( + f"y[{i}]: '{label}' (conflicts with {labels_seen[label]})" + ) + else: + labels_seen[label] = f"y[{i}]" + + # Check group_by label if present + if self.group_by: + group_label = self.group_by.label or self.group_by.name + if group_label in labels_seen: + duplicates.append( + f"group_by: '{group_label}' " + f"(conflicts with {labels_seen[group_label]})" + ) + + if duplicates: + raise ValueError( + f"Duplicate column/metric labels: {', '.join(duplicates)}. " + f"Please make sure all columns and metrics have a unique label. " + f"Use the 'label' field to provide custom names for columns." 
+ ) + + return self + + +# Discriminated union entry point with custom error handling +ChartConfig = Annotated[ + XYChartConfig | TableChartConfig, + Field( + discriminator="chart_type", + description="Chart configuration - specify chart_type as 'xy' or 'table'", + ), +] + + class ListChartsRequest(MetadataCacheControl): """Request schema for list_charts with clear, unambiguous types.""" @@ -281,3 +770,393 @@ class ListChartsRequest(MetadataCacheControl): "or 'filters' for precise column-based filtering, but not both." ) return self + + +# The tool input models +class GenerateChartRequest(QueryCacheControl): + dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") + config: ChartConfig = Field(..., description="Chart configuration") + save_chart: bool = Field( + default=True, + description="Whether to permanently save the chart in Superset", + ) + generate_preview: bool = Field( + default=True, + description="Whether to generate a preview image", + ) + preview_formats: List[ + Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"] + ] = Field( + default_factory=lambda: ["url"], + description="List of preview formats to generate", + ) + + @model_validator(mode="after") + def validate_cache_timeout(self) -> "GenerateChartRequest": + """Validate cache timeout is non-negative.""" + if ( + hasattr(self, "cache_timeout") + and self.cache_timeout is not None + and self.cache_timeout < 0 + ): + raise ValueError( + "cache_timeout must be non-negative (0 or positive integer)" + ) + return self + + +class GenerateExploreLinkRequest(FormDataCacheControl): + dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") + config: ChartConfig = Field(..., description="Chart configuration") + + +class UpdateChartRequest(QueryCacheControl): + identifier: int | str = Field(..., description="Chart identifier (ID, UUID)") + config: ChartConfig = Field(..., description="New chart configuration") + chart_name: str | None = 
Field( + None, + description="New chart name (optional, will auto-generate if not provided)", + max_length=255, + ) + generate_preview: bool = Field( + default=True, + description="Whether to generate a preview after updating", + ) + preview_formats: List[ + Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"] + ] = Field( + default_factory=lambda: ["url"], + description="List of preview formats to generate", + ) + + @field_validator("chart_name") + @classmethod + def sanitize_chart_name(cls, v: str | None) -> str | None: + """Sanitize chart name to prevent XSS attacks.""" + if v is None: + return v + + # Strip whitespace + v = v.strip() + if not v: + return None + + # Length check first to prevent ReDoS attacks + if len(v) > 255: + raise ValueError( + f"Chart name too long ({len(v)} characters). " + f"Maximum allowed length is 255 characters." + ) + + # Check for dangerous HTML tags using substring checks (safe) + dangerous_tags = [ + "", + "", + "", + "", + " GenerateChartResponse: + """Create and save a chart in Superset. 
+ + IMPORTANT BEHAVIOR: + - Charts ARE saved by default (save_chart=True) + - Set save_chart=False for temporary preview only + - LLM clients MUST display returned chart URL to users + - Embed preview_url as image: ![Chart Preview](preview_url) + + VALIDATION: + - 5-layer pipeline: Schema, business logic, dataset, Superset compatibility, runtime + - XSS/SQL injection prevention + - Column existence validation with fuzzy match suggestions + - Aggregate function type compatibility checking + + Returns: + - Chart ID and metadata (if saved) + - Preview URL and explore URL + - Detailed validation errors with suggestions + """ + start_time = time.time() + await ctx.info( + "Starting chart generation: dataset_id=%s, chart_type=%s, " + "save_chart=%s, preview_formats=%s" + % ( + request.dataset_id, + request.config.chart_type, + request.save_chart, + request.preview_formats, + ) + ) + await ctx.debug( + "Chart configuration details: config=%s" % (request.config.model_dump(),) + ) + + try: + # Run comprehensive validation pipeline + await ctx.report_progress(1, 5, "Running validation pipeline") + await ctx.debug( + "Validating chart request: dataset_id=%s" % (request.dataset_id,) + ) + from superset.mcp_service.chart.validation import ValidationPipeline + + is_valid, parsed_request, validation_error = ( + ValidationPipeline.validate_request(request.model_dump()) + ) + if is_valid and parsed_request is not None: + # Use the validated request going forward + request = parsed_request + if not is_valid: + execution_time = int((time.time() - start_time) * 1000) + assert validation_error is not None # Type narrowing for mypy + await ctx.error( + "Chart validation failed: error=%s" % (validation_error.model_dump(),) + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": validation_error.model_dump(), + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + 
"schema_version": "2.0", + "api_version": "v1", + } + ) + + # Map the simplified config to Superset's form_data format + form_data = map_config_to_form_data(request.config) + + chart = None + chart_id = None + explore_url = None + form_data_key = None + + # Save chart by default (unless save_chart=False) + if request.save_chart: + await ctx.report_progress(2, 5, "Creating chart in database") + from superset.commands.chart.create import CreateChartCommand + + # Generate a chart name + chart_name = generate_chart_name(request.config) + await ctx.debug("Generated chart name: chart_name=%s" % (chart_name,)) + + # Find the dataset to get its numeric ID + from superset.daos.dataset import DatasetDAO + + await ctx.debug("Looking up dataset: dataset_id=%s" % (request.dataset_id,)) + dataset = None + if isinstance(request.dataset_id, int) or ( + isinstance(request.dataset_id, str) and request.dataset_id.isdigit() + ): + dataset_id = ( + int(request.dataset_id) + if isinstance(request.dataset_id, str) + else request.dataset_id + ) + dataset = DatasetDAO.find_by_id(dataset_id) + # SECURITY FIX: Also validate permissions for numeric ID access + if dataset and not has_dataset_access(dataset): + logger.warning( + "User %s attempted to access dataset %s without permission", + ctx.user.username if hasattr(ctx, "user") else "unknown", + dataset_id, + ) + dataset = None # Treat as not found + else: + # SECURITY FIX: Try UUID lookup with permission validation + dataset = DatasetDAO.find_by_id(request.dataset_id, id_column="uuid") + # Validate permissions for UUID-based access + if dataset and not has_dataset_access(dataset): + logger.warning( + "User %s attempted access dataset %s via UUID", + ctx.user.username if hasattr(ctx, "user") else "unknown", + request.dataset_id, + ) + dataset = None # Treat as not found + + if not dataset: + await ctx.error( + "Dataset not found: dataset_id=%s" % (request.dataset_id,) + ) + from superset.mcp_service.common.error_schemas import ( + 
ChartGenerationError, + ) + + execution_time = int((time.time() - start_time) * 1000) + error = ChartGenerationError( + error_type="dataset_not_found", + message=f"Dataset not found: {request.dataset_id}", + details=( + f"No dataset found with identifier '{request.dataset_id}'. " + f"This could be an invalid ID/UUID or a permissions issue." + ), + suggestions=[ + "Verify the dataset ID or UUID is correct", + "Check that you have access to this dataset", + "Use the list_datasets tool to find available datasets", + "If using UUID, ensure it's the correct format", + ], + error_code="DATASET_NOT_FOUND", + ) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": error.model_dump(), + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + + try: + command = CreateChartCommand( + { + "slice_name": chart_name, + "viz_type": form_data["viz_type"], + "datasource_id": dataset.id, + "datasource_type": "table", + "params": json.dumps(form_data), + } + ) + + chart = command.run() + chart_id = chart.id + + # Ensure chart was created successfully before committing + if not chart or not chart.id: + raise Exception("Chart creation failed - no chart ID returned") + + await ctx.info( + "Chart created successfully: chart_id=%s, chart_name=%s" + % ( + chart.id, + chart.slice_name, + ) + ) + + except Exception as e: + logger.error("Chart creation failed: %s", e) + await ctx.error("Chart creation failed: error=%s" % (str(e),)) + raise + # Update explore URL to use saved chart + explore_url = f"{get_superset_base_url()}/explore/?slice_id={chart.id}" + else: + await ctx.report_progress(2, 5, "Generating temporary chart preview") + # Generate explore link with cached form_data for preview-only mode + from superset.mcp_service.chart.chart_utils import generate_explore_link + + explore_url = 
generate_explore_link(request.dataset_id, form_data) + await ctx.debug("Generated explore link: explore_url=%s" % (explore_url,)) + + # Extract form_data_key from the explore URL + if explore_url and "form_data_key=" in explore_url: + form_data_key = explore_url.split("form_data_key=")[1].split("&")[0] + + # Generate semantic analysis + capabilities = analyze_chart_capabilities(chart, request.config) + semantics = analyze_chart_semantics(chart, request.config) + + # Create performance metadata + execution_time = int((time.time() - start_time) * 1000) + performance = PerformanceMetadata( + query_duration_ms=execution_time, + cache_status="miss", + optimization_suggestions=[], + ) + + # Create accessibility metadata + chart_name = ( + chart.slice_name + if chart and hasattr(chart, "slice_name") + else generate_chart_name(request.config) + ) + accessibility = AccessibilityMetadata( + color_blind_safe=True, # Would need actual analysis + alt_text=f"Chart showing {chart_name}", + high_contrast_available=False, + ) + + # Generate previews if requested + await ctx.report_progress(3, 5, "Generating chart previews") + previews = {} + if request.generate_preview: + await ctx.debug( + "Generating previews: formats=%s" % (str(request.preview_formats),) + ) + try: + for format_type in request.preview_formats: + await ctx.debug( + "Processing preview format: format=%s" % (format_type,) + ) + # Skip base64 format - we never return base64 + if format_type == "base64": + logger.info("Skipping base64 format - not supported") + continue + + if chart_id: + # For saved charts, use the existing preview generation + from superset.mcp_service.chart.tool.get_chart_preview import ( + _get_chart_preview_internal, + GetChartPreviewRequest, + ) + + preview_request = GetChartPreviewRequest( + identifier=str(chart_id), format=format_type + ) + preview_result = await _get_chart_preview_internal( + preview_request, ctx + ) + + if hasattr(preview_result, "content"): + previews[format_type] = 
preview_result.content + else: + # For preview-only mode (save_chart=false) + if format_type == "url" and form_data_key: + # Generate screenshot URL using centralized helper + from superset.mcp_service.utils.url_utils import ( + get_explore_screenshot_url, + ) + + preview_url = get_explore_screenshot_url(form_data_key) + previews[format_type] = URLPreview( + preview_url=preview_url, + width=800, + height=600, + supports_interaction=False, + ) + elif format_type in ["ascii", "table", "vega_lite"]: + # Generate preview from form data without saved chart + from superset.mcp_service.chart.preview_utils import ( + generate_preview_from_form_data, + ) + + # Convert dataset_id to int only if it's numeric + if ( + isinstance(request.dataset_id, str) + and request.dataset_id.isdigit() + ): + dataset_id_for_preview = int(request.dataset_id) + elif isinstance(request.dataset_id, int): + dataset_id_for_preview = request.dataset_id + else: + # Skip preview generation for non-numeric dataset IDs + logger.warning( + "Cannot generate preview for non-numeric " + ) + continue + + preview_result = generate_preview_from_form_data( + form_data=form_data, + dataset_id=dataset_id_for_preview, + preview_format=format_type, + ) + + if not hasattr(preview_result, "error"): + previews[format_type] = preview_result + + except Exception as e: + # Log warning but don't fail the entire request + await ctx.warning("Preview generation failed: error=%s" % (str(e),)) + logger.warning("Preview generation failed: %s", e) + + # Return enhanced data while maintaining backward compatibility + await ctx.report_progress(4, 5, "Building response") + result = { + "chart": { + "id": chart.id if chart else None, + "slice_name": chart.slice_name + if chart + else generate_chart_name(request.config), + "viz_type": chart.viz_type if chart else form_data.get("viz_type"), + "url": explore_url, + "uuid": str(chart.uuid) if chart and chart.uuid else None, + "saved": request.save_chart, + } + if request.save_chart and 
chart + else None, + "error": None, + # Enhanced fields for better LLM integration + "previews": previews, + "capabilities": capabilities.model_dump() if capabilities else None, + "semantics": semantics.model_dump() if semantics else None, + "explore_url": explore_url, + "form_data_key": form_data_key, + "api_endpoints": { + "data": f"{get_superset_base_url()}/api/v1/chart/{chart.id}/data/" + if chart + else None, + "preview": get_chart_screenshot_url(chart.id) if chart else None, + "export": f"{get_superset_base_url()}/api/v1/chart/{chart.id}/export/" + if chart + else None, + } + if chart + else {}, + "performance": performance.model_dump() if performance else None, + "accessibility": accessibility.model_dump() if accessibility else None, + "success": True, + "schema_version": "2.0", + "api_version": "v1", + } + await ctx.report_progress(5, 5, "Chart generation completed") + await ctx.info( + "Chart generation completed successfully: chart_id=%s, execution_time_ms=%s" + % ( + chart.id if chart else None, + int((time.time() - start_time) * 1000), + ) + ) + return GenerateChartResponse.model_validate(result) + + except Exception as e: + await ctx.error( + "Chart generation failed: error=%s, execution_time_ms=%s" + % ( + str(e), + int((time.time() - start_time) * 1000), + ) + ) + from superset.mcp_service.utils.error_builder import ChartErrorBuilder + + logger.exception("Chart generation failed: %s", str(e)) + + # Extract chart_type from different sources for better error context + chart_type = "unknown" + try: + if hasattr(request, "config") and hasattr(request.config, "chart_type"): + chart_type = request.config.chart_type + except Exception as extract_error: + # Ignore errors when extracting chart type for error context + logger.debug("Could not extract chart type: %s", extract_error) + + execution_time = int((time.time() - start_time) * 1000) + + # Build standardized error response + error = ChartErrorBuilder.build_error( + error_type="chart_generation_error", + 
template_key="generation_failed", + template_vars={ + "reason": str(e), + "dataset_id": str(request.dataset_id), + "chart_type": chart_type, + }, + error_code="CHART_GENERATION_FAILED", + ) + + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": error.model_dump(), + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) diff --git a/superset/mcp_service/chart/tool/get_chart_available_filters.py b/superset/mcp_service/chart/tool/get_chart_available_filters.py new file mode 100644 index 00000000000..47f6713e876 --- /dev/null +++ b/superset/mcp_service/chart/tool/get_chart_available_filters.py @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +MCP tool: get_chart_available_filters +""" + +import logging + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.schemas import ( + ChartAvailableFiltersResponse, + GetChartAvailableFiltersRequest, +) +from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +def get_chart_available_filters( + request: GetChartAvailableFiltersRequest, ctx: Context +) -> ChartAvailableFiltersResponse: + """Get available chart filter fields and operators.""" + from superset.daos.chart import ChartDAO + + tool = ModelGetAvailableFiltersCore( + dao_class=ChartDAO, + output_schema=ChartAvailableFiltersResponse, + logger=logger, + ) + return tool.run_tool() diff --git a/superset/mcp_service/chart/tool/get_chart_data.py b/superset/mcp_service/chart/tool/get_chart_data.py new file mode 100644 index 00000000000..e47478b8346 --- /dev/null +++ b/superset/mcp_service/chart/tool/get_chart_data.py @@ -0,0 +1,656 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +MCP tool: get_chart_data +""" + +import logging +from typing import Any, Dict, List, TYPE_CHECKING + +from fastmcp import Context + +if TYPE_CHECKING: + from superset.models.slice import Slice + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.schemas import ( + ChartData, + ChartError, + DataColumn, + GetChartDataRequest, + PerformanceMetadata, +) +from superset.mcp_service.utils.cache_utils import get_cache_status_from_result + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def get_chart_data( # noqa: C901 + request: GetChartDataRequest, ctx: Context +) -> ChartData | ChartError: + """Get chart data by ID or UUID. + + Returns the actual data behind a chart for LLM analysis without image rendering. + + Supports: + - Numeric ID or UUID lookup + - Multiple formats: json, csv, excel + - Cache control: use_cache, force_refresh, cache_timeout + + Returns underlying data in requested format with cache status. 
+ """ + await ctx.info( + "Starting chart data retrieval: identifier=%s, format=%s, limit=%s" + % ( + request.identifier, + request.format, + request.limit, + ) + ) + await ctx.debug( + "Cache settings: use_cache=%s, force_refresh=%s, cache_timeout=%s" + % ( + request.use_cache, + request.force_refresh, + request.cache_timeout, + ) + ) + + try: + await ctx.report_progress(1, 4, "Looking up chart") + from superset.daos.chart import ChartDAO + from superset.utils import json as utils_json + + # Find the chart + chart = None + if isinstance(request.identifier, int) or ( + isinstance(request.identifier, str) and request.identifier.isdigit() + ): + chart_id = ( + int(request.identifier) + if isinstance(request.identifier, str) + else request.identifier + ) + await ctx.debug( + "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) + ) + chart = ChartDAO.find_by_id(chart_id) + else: + await ctx.debug( + "Performing UUID-based chart lookup: uuid=%s" % (request.identifier,) + ) + # Try UUID lookup using DAO flexible method + chart = ChartDAO.find_by_id(request.identifier, id_column="uuid") + + if not chart: + await ctx.error("Chart not found: identifier=%s" % (request.identifier,)) + return ChartError( + error=f"No chart found with identifier: {request.identifier}", + error_type="NotFound", + ) + + await ctx.info( + "Chart found successfully: chart_id=%s, chart_name=%s, viz_type=%s" + % ( + chart.id, + chart.slice_name, + chart.viz_type, + ) + ) + logger.info("Getting data for chart %s: %s", chart.id, chart.slice_name) + + import time + + start_time = time.time() + + try: + await ctx.report_progress(2, 4, "Preparing data query") + # Get chart data using the existing API + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + + # Parse the form_data to get query context + form_data = utils_json.loads(chart.params) if chart.params else {} + await ctx.debug( + "Chart form 
data parsed: has_filters=%s, has_groupby=%s, has_metrics=%s" + % ( + bool(form_data.get("filters")), + bool(form_data.get("groupby")), + bool(form_data.get("metrics")), + ) + ) + + # Create a proper QueryContext using the factory with cache control + factory = QueryContextFactory() + query_context = factory.create( + datasource={"id": chart.datasource_id, "type": chart.datasource_type}, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": form_data.get("groupby", []), + "metrics": form_data.get("metrics", []), + "row_limit": request.limit or 100, + "order_desc": True, + # Apply cache control from request + "cache_timeout": request.cache_timeout, + } + ], + form_data=form_data, + # Use cache unless force_refresh is True + force=request.force_refresh, + ) + + await ctx.report_progress(3, 4, "Executing data query") + await ctx.debug( + "Query execution parameters: datasource_id=%s, datasource_type=%s, " + "row_limit=%s, force_refresh=%s" + % ( + chart.datasource_id, + chart.datasource_type, + request.limit or 100, + request.force_refresh, + ) + ) + + # Execute the query + command = ChartDataCommand(query_context) + result = command.run() + + # Handle empty query results for certain chart types + if not result or ("queries" not in result) or len(result["queries"]) == 0: + await ctx.warning( + "Empty query results: chart_id=%s, chart_type=%s" + % (chart.id, chart.viz_type) + ) + return ChartError( + error=f"No query results returned for chart {chart.id}. 
" + f"This may occur with chart types like big_number.", + error_type="EmptyQuery", + ) + + # Extract data from result (we've already validated it exists above) + query_result = result["queries"][0] + data = query_result.get("data", []) + raw_columns = query_result.get("colnames", []) + + await ctx.debug( + "Query results received: row_count=%s, column_count=%s, " + "has_cache_key=%s" + % ( + len(data), + len(raw_columns), + bool(query_result.get("cache_key")), + ) + ) + + # Check if we have data to work with + if not data: + await ctx.warning("No data in query results: chart_id=%s" % (chart.id,)) + return ChartError( + error=f"No data available for chart {chart.id}", error_type="NoData" + ) + + # Create rich column metadata + columns = [] + for col_name in raw_columns: + # Sample some values for metadata + sample_values = [ + row.get(col_name) + for row in data[:3] + if row.get(col_name) is not None + ] + + # Infer data type + data_type = "string" + if sample_values: + if all(isinstance(v, (int, float)) for v in sample_values): + data_type = "numeric" + elif all(isinstance(v, bool) for v in sample_values): + data_type = "boolean" + + columns.append( + DataColumn( + name=col_name, + display_name=col_name.replace("_", " ").title(), + data_type=data_type, + sample_values=sample_values[:3], + null_count=sum(1 for row in data if row.get(col_name) is None), + unique_count=len({str(row.get(col_name)) for row in data}), + ) + ) + + # Cache status information using utility function + cache_status = get_cache_status_from_result( + query_result, force_refresh=request.force_refresh + ) + + # Generate insights and recommendations + insights = [] + if len(data) > 100: + insights.append( + "Large dataset - consider filtering for better performance" + ) + if len(raw_columns) > 10: + insights.append("Many columns available - focus on key metrics") + + # Add cache-specific insights + if cache_status.cache_hit: + if ( + cache_status.cache_age_seconds + and 
cache_status.cache_age_seconds > 3600 + ): + hours_old = cache_status.cache_age_seconds // 3600 + insights.append( + f"Data is from cache ({hours_old}h old) - " + "consider refreshing for latest data" + ) + else: + insights.append("Data served from cache for fast response") + else: + insights.append("Fresh data retrieved from database") + + recommended_visualizations = [] + if any( + "time" in col.lower() or "date" in col.lower() for col in raw_columns + ): + recommended_visualizations.extend(["line chart", "time series"]) + if len(raw_columns) <= 3: + recommended_visualizations.extend(["bar chart", "scatter plot"]) + + # Performance metadata with cache awareness + execution_time = int((time.time() - start_time) * 1000) + performance_status = ( + "cache_hit" if cache_status.cache_hit else "fresh_query" + ) + optimization_suggestions = [] + + if not cache_status.cache_hit and execution_time > 5000: + optimization_suggestions.append( + "Consider using cache for this slow query" + ) + elif ( + cache_status.cache_hit + and cache_status.cache_age_seconds + and cache_status.cache_age_seconds > 86400 + ): + optimization_suggestions.append("Cache is old - consider refreshing") + + performance = PerformanceMetadata( + query_duration_ms=execution_time, + cache_status=performance_status, + optimization_suggestions=optimization_suggestions, + ) + + # Generate comprehensive summary with cache info + cache_info = "" + if cache_status.cache_hit: + age_info = ( + f" (cached {cache_status.cache_age_seconds // 60}m ago)" + if cache_status.cache_age_seconds + else " (cached)" + ) + cache_info = age_info + + summary_parts = [ + f"Chart '{chart.slice_name}' ({chart.viz_type})", + f"Contains {len(data)} rows across {len(raw_columns)} columns" + f"{cache_info}", + ] + + if data and len(data) > 0: + summary_parts.append( + f"Sample data includes: {', '.join(raw_columns[:3])}" + ) + + summary = ". 
".join(summary_parts) + + # Handle different export formats + if request.format == "csv": + return _export_data_as_csv( + chart, + data[: request.limit] if request.limit else data, + raw_columns, + cache_status, + performance, + ) + elif request.format == "excel": + return _export_data_as_excel( + chart, + data[: request.limit] if request.limit else data, + raw_columns, + cache_status, + performance, + ) + + await ctx.report_progress(4, 4, "Building response") + + # Calculate data quality metrics + data_completeness = 1.0 - ( + sum(col.null_count for col in columns) + / max(len(data) * len(columns), 1) + ) + + await ctx.info( + "Chart data retrieval completed successfully: chart_id=%s, " + "rows_returned=%s, columns_returned=%s, execution_time_ms=%s, " + "cache_hit=%s, data_completeness=%s" + % ( + chart.id, + len(data), + len(raw_columns), + execution_time, + cache_status.cache_hit, + round(data_completeness, 3), + ) + ) + + # Default JSON format + return ChartData( + chart_id=chart.id, + chart_name=chart.slice_name or f"Chart {chart.id}", + chart_type=chart.viz_type or "unknown", + columns=columns, + data=data[: request.limit] if request.limit else data, + row_count=len(data), + total_rows=query_result.get("rowcount"), + summary=summary, + insights=insights, + data_quality={"completeness": data_completeness}, + recommended_visualizations=recommended_visualizations, + data_freshness=None, # Add missing field + performance=performance, + cache_status=cache_status, + ) + + except Exception as data_error: + await ctx.error( + "Data retrieval failed: chart_id=%s, error=%s, error_type=%s" + % ( + chart.id, + str(data_error), + type(data_error).__name__, + ) + ) + logger.error("Data retrieval error for chart %s: %s", chart.id, data_error) + return ChartError( + error=f"Error retrieving chart data: {str(data_error)}", + error_type="DataError", + ) + + except Exception as e: + await ctx.error( + "Chart data retrieval failed: identifier=%s, error=%s, error_type=%s" + % ( + 
request.identifier, + str(e), + type(e).__name__, + ) + ) + logger.error("Error in get_chart_data: %s", e) + return ChartError( + error=f"Failed to get chart data: {str(e)}", error_type="InternalError" + ) + + +def _export_data_as_csv( + chart: "Slice", + data: List[Dict[str, Any]], + columns: List[str], + cache_status: Any, + performance: Any, +) -> "ChartData": + """Export chart data as CSV format.""" + import csv + import io + + # Create CSV content + output = io.StringIO() + + if data and columns: + writer = csv.DictWriter(output, fieldnames=columns) + writer.writeheader() + + # Write data rows + for row in data: + # Ensure all values are properly formatted for CSV + csv_row = {} + for col in columns: + value = row.get(col, "") + # Handle None values and convert to string + if value is None: + csv_row[col] = "" + elif isinstance(value, (list, dict)): + csv_row[col] = str(value) + else: + csv_row[col] = value + writer.writerow(csv_row) + + csv_content = output.getvalue() + + # Return as ChartData with CSV content in a special field + from superset.mcp_service.chart.schemas import ChartData + + return ChartData( + chart_id=chart.id, + chart_name=chart.slice_name or f"Chart {chart.id}", + chart_type=chart.viz_type or "unknown", + columns=[], # Not needed for CSV export + data=[], # CSV content is in csv_data field + row_count=len(data), + total_rows=len(data), + summary=f"CSV export of chart '{chart.slice_name}' with {len(data)} rows", + insights=[f"Data exported as CSV format ({len(csv_content)} characters)"], + data_quality={}, + recommended_visualizations=[], + data_freshness=None, + performance=performance, + cache_status=cache_status, + # Store CSV content in data field as string for the response + csv_data=csv_content, + format="csv", + ) + + +def _export_data_as_excel( + chart: "Slice", + data: List[Dict[str, Any]], + columns: List[str], + cache_status: Any, + performance: Any, +) -> "ChartData | ChartError": + """Export chart data as Excel format.""" + 
try: + excel_b64 = _create_excel_with_openpyxl(chart, data, columns) + return _create_excel_chart_data( + chart, data, excel_b64, performance, cache_status + ) + except ImportError: + return _try_xlsxwriter_fallback(chart, data, columns, cache_status, performance) + + +def _create_excel_with_openpyxl( + chart: "Slice", data: List[Dict[str, Any]], columns: List[str] +) -> str: + """Create Excel file using openpyxl.""" + import base64 + import io + + from openpyxl import Workbook + + wb = Workbook() + ws = wb.active + ws.title = chart.slice_name[:31] if chart.slice_name else "Chart Data" + + if data and columns: + _write_excel_headers(ws, columns) + _write_excel_data(ws, data, columns) + + output = io.BytesIO() + wb.save(output) + output.seek(0) + return base64.b64encode(output.read()).decode() + + +def _write_excel_headers(ws: Any, columns: List[str]) -> None: + """Write headers to Excel worksheet.""" + for idx, col in enumerate(columns, 1): + ws.cell(row=1, column=idx, value=col) + + +def _write_excel_data(ws: Any, data: List[Dict[str, Any]], columns: List[str]) -> None: + """Write data to Excel worksheet.""" + for row_idx, row in enumerate(data, 2): + for col_idx, col in enumerate(columns, 1): + value = row.get(col, "") + if value is None: + value = "" + elif isinstance(value, (list, dict)): + value = str(value) + ws.cell(row=row_idx, column=col_idx, value=value) + + +def _try_xlsxwriter_fallback( + chart: "Slice", + data: List[Dict[str, Any]], + columns: List[str], + cache_status: Any, + performance: Any, +) -> "ChartData | ChartError": + """Try xlsxwriter as fallback for Excel export.""" + try: + excel_b64 = _create_excel_with_xlsxwriter(chart, data, columns) + return _create_excel_chart_data_xlsxwriter( + chart, data, excel_b64, performance, cache_status + ) + except ImportError: + from superset.mcp_service.chart.schemas import ChartError + + return ChartError( + error="Excel export requires openpyxl or xlsxwriter package", + error_type="ExportError", + ) + + 
+def _create_excel_with_xlsxwriter( + chart: "Slice", data: List[Dict[str, Any]], columns: List[str] +) -> str: + """Create Excel file using xlsxwriter.""" + import base64 + import io + + import xlsxwriter + + output = io.BytesIO() + workbook = xlsxwriter.Workbook(output, {"in_memory": True}) + sheet_name = chart.slice_name[:31] if chart.slice_name else "Chart Data" + worksheet = workbook.add_worksheet(sheet_name) + + if data and columns: + _write_xlsxwriter_data(worksheet, data, columns) + + workbook.close() + output.seek(0) + return base64.b64encode(output.read()).decode() + + +def _write_xlsxwriter_data( + worksheet: Any, data: List[Dict[str, Any]], columns: List[str] +) -> None: + """Write data to xlsxwriter worksheet.""" + # Write headers + for idx, col in enumerate(columns): + worksheet.write(0, idx, col) + + # Write data + for row_idx, row in enumerate(data): + for col_idx, col in enumerate(columns): + value = row.get(col, "") + if value is None: + value = "" + elif isinstance(value, (list, dict)): + value = str(value) + worksheet.write(row_idx + 1, col_idx, value) + + +def _create_excel_chart_data( + chart: "Slice", + data: List[Dict[str, Any]], + excel_b64: str, + performance: Any, + cache_status: Any, +) -> "ChartData": + """Create ChartData response for Excel export (openpyxl).""" + from superset.mcp_service.chart.schemas import ChartData + + chart_name = chart.slice_name or f"Chart {chart.id}" + summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows" + + return ChartData( + chart_id=chart.id, + chart_name=chart_name, + chart_type=chart.viz_type or "unknown", + columns=[], + data=[], + row_count=len(data), + total_rows=len(data), + summary=summary, + insights=["Data exported as Excel format (base64 encoded)"], + data_quality={}, + recommended_visualizations=[], + data_freshness=None, + performance=performance, + cache_status=cache_status, + excel_data=excel_b64, + format="excel", + ) + + +def _create_excel_chart_data_xlsxwriter( + 
chart: "Slice", + data: List[Dict[str, Any]], + excel_b64: str, + performance: Any, + cache_status: Any, +) -> "ChartData": + """Create ChartData response for Excel export (xlsxwriter).""" + from superset.mcp_service.chart.schemas import ChartData + + chart_name = chart.slice_name or f"Chart {chart.id}" + summary = f"Excel export of chart '{chart.slice_name}' with {len(data)} rows" + + return ChartData( + chart_id=chart.id, + chart_name=chart_name, + chart_type=chart.viz_type or "unknown", + columns=[], + data=[], + row_count=len(data), + total_rows=len(data), + summary=summary, + insights=["Data exported as Excel format (base64 encoded, xlsxwriter)"], + data_quality={}, + recommended_visualizations=[], + data_freshness=None, + performance=performance, + cache_status=cache_status, + excel_data=excel_b64, + format="excel", + ) diff --git a/superset/mcp_service/chart/tool/get_chart_info.py b/superset/mcp_service/chart/tool/get_chart_info.py index 86f8170debe..9c59131dee7 100644 --- a/superset/mcp_service/chart/tool/get_chart_info.py +++ b/superset/mcp_service/chart/tool/get_chart_info.py @@ -43,6 +43,14 @@ async def get_chart_info( ) -> ChartInfo | ChartError: """Get chart metadata by ID or UUID. + IMPORTANT FOR LLM CLIENTS: + - ALWAYS display the chart URL when returned + - URL field contains chart's screenshot URL for preview + + Supports: + - Numeric ID (e.g., 123) + - UUID string (e.g., "a1b2c3d4-...") + Returns chart details including name, type, and URL. """ from superset.daos.chart import ChartDAO diff --git a/superset/mcp_service/chart/tool/get_chart_preview.py b/superset/mcp_service/chart/tool/get_chart_preview.py new file mode 100644 index 00000000000..24f1ff70b35 --- /dev/null +++ b/superset/mcp_service/chart/tool/get_chart_preview.py @@ -0,0 +1,2082 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: get_chart_preview +""" + +import logging +from typing import Any, Dict, List, Protocol + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.schemas import ( + AccessibilityMetadata, + ASCIIPreview, + ChartError, + ChartPreview, + GetChartPreviewRequest, + InteractivePreview, + PerformanceMetadata, + TablePreview, + URLPreview, + VegaLitePreview, +) +from superset.mcp_service.utils.url_utils import get_superset_base_url + +logger = logging.getLogger(__name__) + + +class ChartLike(Protocol): + """Protocol for chart-like objects with required attributes for preview.""" + + id: int + slice_name: str | None + viz_type: str | None + datasource_id: int + datasource_type: str + params: str | None + digest: str + uuid: Any + + +class PreviewFormatStrategy: + """Base class for preview format strategies.""" + + def __init__(self, chart: ChartLike, request: GetChartPreviewRequest) -> None: + self.chart = chart + self.request = request + + def generate(self) -> ChartPreview | ChartError: + """Generate preview in the specific format.""" + raise NotImplementedError + + +class URLPreviewStrategy(PreviewFormatStrategy): + """Generate URL-based image preview.""" + + def 
generate(self) -> URLPreview | ChartError: + try: + from flask import g + + from superset.mcp_service.screenshot.pooled_screenshot import ( + PooledChartScreenshot, + ) + from superset.mcp_service.utils.url_utils import get_superset_base_url + + # Check if chart.id is None + if self.chart.id is None: + return ChartError( + error="Chart has no ID - cannot generate URL preview", + error_type="InvalidChart", + ) + + # Use configured Superset base URL instead of Flask's url_for + # which may not respect SUPERSET_WEBSERVER_ADDRESS + base_url = get_superset_base_url() + chart_url = f"{base_url}/superset/slice/{self.chart.id}/" + screenshot = PooledChartScreenshot(chart_url, self.chart.digest) + + window_size = (self.request.width or 800, self.request.height or 600) + image_data = screenshot.get_screenshot(user=g.user, window_size=window_size) + + if image_data: + # Use the MCP service screenshot URL via centralized helper + from superset.mcp_service.utils.url_utils import ( + get_chart_screenshot_url, + ) + + preview_url = get_chart_screenshot_url(self.chart.id) + + return URLPreview( + preview_url=preview_url, + width=self.request.width or 800, + height=self.request.height or 600, + ) + else: + return ChartError( + error=f"Could not generate screenshot for chart {self.chart.id}", + error_type="ScreenshotError", + ) + except Exception as e: + logger.error("URL preview generation failed: %s", e) + return ChartError( + error=f"Failed to generate URL preview: {str(e)}", error_type="URLError" + ) + + +# Base64 preview support removed - we never return base64 data + + +class ASCIIPreviewStrategy(PreviewFormatStrategy): + """Generate ASCII art preview.""" + + def generate(self) -> ASCIIPreview | ChartError: + try: + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + from superset.utils import json as utils_json + + form_data = utils_json.loads(self.chart.params) if 
self.chart.params else {} + + logger.info("Chart form_data keys: %s", list(form_data.keys())) + logger.info("Chart viz_type: %s", self.chart.viz_type) + logger.info("Chart datasource_id: %s", self.chart.datasource_id) + logger.info("Chart datasource_type: %s", self.chart.datasource_type) + + # Check if datasource_id is None + if self.chart.datasource_id is None: + return ChartError( + error="Chart has no datasource_id - cannot generate preview", + error_type="InvalidChart", + ) + + # Build query for chart data + x_axis_config = form_data.get("x_axis") + groupby_columns = form_data.get("groupby", []) + metrics = form_data.get("metrics", []) + + columns = groupby_columns.copy() + if x_axis_config and isinstance(x_axis_config, str): + columns.append(x_axis_config) + elif x_axis_config and isinstance(x_axis_config, dict): + if "column_name" in x_axis_config: + columns.append(x_axis_config["column_name"]) + + factory = QueryContextFactory() + query_context = factory.create( + datasource={ + "id": self.chart.datasource_id, + "type": self.chart.datasource_type, + }, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": columns, + "metrics": metrics, + "row_limit": 50, + "order_desc": True, + } + ], + form_data=form_data, + force=False, + ) + + command = ChartDataCommand(query_context) + result = command.run() + + data = [] + if result and "queries" in result and len(result["queries"]) > 0: + data = result["queries"][0].get("data", []) + + ascii_chart = generate_ascii_chart( + data, + self.chart.viz_type or "table", + self.request.ascii_width or 80, + self.request.ascii_height or 20, + ) + + return ASCIIPreview( + ascii_content=ascii_chart, + width=self.request.ascii_width or 80, + height=self.request.ascii_height or 20, + ) + + except Exception as e: + logger.error("ASCII preview generation failed: %s", e) + return ChartError( + error=f"Failed to generate ASCII preview: {str(e)}", + error_type="ASCIIError", + ) + + +class 
TablePreviewStrategy(PreviewFormatStrategy): + """Generate table preview of chart data.""" + + def generate(self) -> TablePreview | ChartError: + try: + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + from superset.utils import json as utils_json + + form_data = utils_json.loads(self.chart.params) if self.chart.params else {} + + # Check if datasource_id is None + if self.chart.datasource_id is None: + return ChartError( + error="Chart has no datasource_id - cannot generate table preview", + error_type="InvalidChart", + ) + + factory = QueryContextFactory() + query_context = factory.create( + datasource={ + "id": self.chart.datasource_id, + "type": self.chart.datasource_type, + }, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": form_data.get("groupby", []), + "metrics": form_data.get("metrics", []), + "row_limit": 20, + "order_desc": True, + } + ], + form_data=form_data, + force=False, + ) + + command = ChartDataCommand(query_context) + result = command.run() + + data = [] + if result and "queries" in result and len(result["queries"]) > 0: + data = result["queries"][0].get("data", []) + + table_data = _generate_ascii_table(data, 120) + + return TablePreview( + table_data=table_data, + row_count=len(data), + ) + + except Exception as e: + logger.error("Table preview generation failed: %s", e) + return ChartError( + error=f"Failed to generate table preview: {str(e)}", + error_type="TableError", + ) + + +class VegaLitePreviewStrategy(PreviewFormatStrategy): + """Generate Vega-Lite specification for interactive chart preview.""" + + def _get_form_data(self) -> Dict[str, Any] | None: + """Extract form_data from chart params.""" + try: + if hasattr(self.chart, "params") and self.chart.params: + from superset.utils import json as utils_json + + return utils_json.loads(self.chart.params) + return None + except Exception: + return None + + def 
generate(self) -> VegaLitePreview | ChartError: + """Generate Vega-Lite JSON specification from chart data.""" + try: + # Get chart data directly using the same logic as get_chart_data tool + # but without calling the MCP tool wrapper + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.common.query_context_factory import QueryContextFactory + from superset.daos.chart import ChartDAO + from superset.utils import json as utils_json + + # Get the chart object if we don't have form_data access + if not hasattr(self.chart, "params") or not self.chart.params: + # Fetch full chart details + chart_obj = None + if self.chart.id is None: + return ChartError( + error="Chart has no ID - cannot generate Vega-Lite preview", + error_type="InvalidChart", + ) + + if isinstance(self.chart.id, int): + chart_obj = ChartDAO.find_by_id(self.chart.id) + else: + chart_obj = ChartDAO.find_by_id(self.chart.id, id_column="uuid") + + if not chart_obj: + return ChartError( + error=f"Chart {self.chart.id} not found for data retrieval", + error_type="ChartNotFound", + ) + + form_data = ( + utils_json.loads(chart_obj.params) if chart_obj.params else {} + ) + else: + form_data = ( + utils_json.loads(self.chart.params) if self.chart.params else {} + ) + + # Create query context for data retrieval + factory = QueryContextFactory() + query_context = factory.create( + datasource={ + "id": self.chart.datasource_id, + "type": self.chart.datasource_type, + }, + queries=[ + { + "filters": form_data.get("filters", []), + "columns": form_data.get("groupby", []), + "metrics": form_data.get("metrics", []), + "row_limit": 1000, # More data for visualization + "order_desc": True, + } + ], + form_data=form_data, + force=self.request.force_refresh, + ) + + # Execute the query + command = ChartDataCommand(query_context) + result = command.run() + + # Extract data from result + chart_data = [] + if result and "queries" in result and len(result["queries"]) > 0: + chart_data = 
result["queries"][0].get("data", []) + + if not chart_data or not isinstance(chart_data, list): + return ChartError( + error="No data available for Vega-Lite visualization", + error_type="NoDataError", + ) + + # Convert Superset chart type to Vega-Lite specification + vega_spec = self._create_vega_lite_spec(chart_data) + + return VegaLitePreview( + type="vega_lite", + specification=vega_spec, + supports_streaming=False, + ) + + except Exception as e: + logger.exception( + "Error generating Vega-Lite preview for chart %s", self.chart.id + ) + return ChartError( + error=f"Failed to generate Vega-Lite preview: {str(e)}", + error_type="VegaLiteGenerationError", + ) + + def _create_vega_lite_spec(self, data: List[Any]) -> Dict[str, Any]: + """Create Vega-Lite specification from chart data.""" + if not data: + return {"data": {"values": []}, "mark": "point"} + + # Get data fields and analyze types + first_row = data[0] if data else {} + fields = list(first_row.keys()) if first_row else [] + field_types = self._analyze_field_types(data, fields) + + # Determine chart type based on Superset viz_type + viz_type = getattr(self.chart, "viz_type", "table") or "table" + + # Basic Vega-Lite specification + spec = { + "$schema": "https://vega.github.io/schema/vega-lite/v5.json", + "description": ( + f"Chart preview for " + f"{getattr(self.chart, 'slice_name', 'Untitled Chart')}" + ), + "data": {"values": data}, + "width": self.request.width or 400, + "height": self.request.height or 300, + } + + # Configure visualization based on chart type + chart_spec = self._get_chart_spec_for_type(viz_type, fields, field_types) + spec.update(chart_spec) + + return spec + + def _get_chart_spec_for_type( + self, viz_type: str, fields: List[str], field_types: Dict[str, str] + ) -> Dict[str, Any]: + """Get chart specification based on visualization type.""" + chart_type_mapping = { + "line": [ + "echarts_timeseries_line", + "echarts_timeseries", + "echarts_timeseries_smooth", + 
def _analyze_field_types(
    self, data: List[Any], fields: List[str]
) -> Dict[str, str]:
    """Analyze data fields to determine appropriate Vega-Lite types.

    Samples up to 10 rows per field; any per-field failure degrades to the
    safe "nominal" type rather than aborting the whole analysis.
    """
    field_types: Dict[str, str] = {}
    if not data or not fields:
        return field_types

    try:
        sample_size = min(10, len(data))
        for field in fields:
            try:
                sample_values = self._get_sample_values(data, field, sample_size)
                if not sample_values:
                    field_types[field] = "nominal"
                    continue
                field_types[field] = self._determine_field_type(sample_values)
            except Exception as e:
                logger.warning("Error analyzing field '%s': %s", field, e)
                field_types[field] = "nominal"  # Safe default
    except Exception as e:
        logger.warning("Error in field type analysis: %s", e)
        # Fall back to nominal for every field.
        return {field: "nominal" for field in fields}

    return field_types

def _get_sample_values(
    self, data: List[Any], field: str, sample_size: int
) -> List[Any]:
    """Get non-null sample values for a field from the first rows of data."""
    sample_values = []
    for row in data[:sample_size]:
        if isinstance(row, dict) and field in row:
            val = row[field]
            if val is not None:
                sample_values.append(val)
    return sample_values

def _determine_field_type(self, sample_values: List[Any]) -> str:
    """Determine the Vega-Lite field type from sampled values."""
    # Any date-looking string makes the field temporal.
    if any(
        isinstance(val, str) and self._looks_like_date(val)
        for val in sample_values
    ):
        return "temporal"
    # All-numeric (bools excluded: bool is an int subclass) -> quantitative.
    if all(
        isinstance(val, (int, float)) and not isinstance(val, bool)
        for val in sample_values
    ):
        return "quantitative"
    # Everything else (including low-cardinality sets) stays nominal.
    return "nominal"

def _looks_like_date(self, value: str) -> bool:
    """Heuristically detect date/time-like strings.

    BUGFIX: the previous indicator list treated any string containing a
    hyphen, slash or colon — or substrings such as "may" — as a date, so
    values like "category-a" or "maybe" were misclassified as temporal.
    The check now requires an ISO-like date/time shape or a standalone
    month/weekday word.
    """
    import re

    if not isinstance(value, str):
        return False
    text = value.strip().lower()
    if not text:
        return False

    # ISO-ish dates, optionally with a time part: 2024-01-31[T12:30[:45]]
    if re.fullmatch(
        r"\d{4}[-/]\d{1,2}[-/]\d{1,2}([ t]\d{1,2}:\d{2}(:\d{2})?)?", text
    ):
        return True
    # Bare times such as 12:30 or 12:30:45.
    if re.fullmatch(r"\d{1,2}:\d{2}(:\d{2})?", text):
        return True

    # Standalone month or weekday names (abbreviated or full).
    words = (
        "jan", "january", "feb", "february", "mar", "march", "apr", "april",
        "may", "jun", "june", "jul", "july", "aug", "august", "sep", "sept",
        "september", "oct", "october", "nov", "november", "dec", "december",
        "monday", "tuesday", "wednesday", "thursday", "friday", "saturday",
        "sunday",
    )
    return bool(re.search(r"\b(%s)\b" % "|".join(words), text))

def _line_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create line chart specification.

    Prefers the chart's original ``x_axis`` from form_data; falls back to
    the first aggregated field.
    """
    field_types = field_types or {}

    # Original x/y field mappings from the chart's form_data, if available.
    form_data = self._get_form_data()
    x_field = form_data.get("x_axis") if form_data else None
    if not x_field:
        x_field = fields[0] if fields else "x"

    # Y uses the aggregated field name from the result data.
    y_field = fields[1] if len(fields) > 1 else fields[0] if fields else "y"

    x_type = field_types.get(x_field, "nominal")
    # Known x_axis fields with date-like names are likely temporal.
    if form_data and x_field == form_data.get("x_axis"):
        if any(
            kw in x_field.lower() for kw in ["date", "time", "year", "month", "day"]
        ):
            x_type = "temporal"

    y_type = field_types.get(y_field, "quantitative")

    return {
        "mark": {"type": "line", "point": True, "tooltip": True},
        "encoding": {
            "x": {"field": x_field, "type": x_type, "title": x_field},
            "y": {"field": y_field, "type": y_type, "title": y_field},
            "tooltip": [
                {"field": f, "type": field_types.get(f, "nominal")}
                for f in fields[:5]
            ],
        },
    }
def _bar_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create bar chart specification."""
    types = field_types or {}
    if fields:
        x_axis = fields[0]
        y_axis = fields[1] if len(fields) > 1 else fields[0]
    else:
        x_axis, y_axis = "x", "y"

    tooltips = [
        {"field": name, "type": types.get(name, "nominal")} for name in fields[:5]
    ]
    return {
        "mark": {"type": "bar", "tooltip": True},
        "encoding": {
            "x": {
                "field": x_axis,
                "type": types.get(x_axis, "nominal"),
                "title": x_axis,
            },
            "y": {
                "field": y_axis,
                "type": types.get(y_axis, "quantitative"),
                "title": y_axis,
            },
            "tooltip": tooltips,
        },
    }

def _area_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create area chart specification."""
    types = field_types or {}
    if fields:
        x_axis = fields[0]
        y_axis = fields[1] if len(fields) > 1 else fields[0]
    else:
        x_axis, y_axis = "x", "y"

    # Date/time-looking x fields default to temporal when no type is known.
    fallback_x = (
        "temporal"
        if "date" in x_axis.lower() or "time" in x_axis.lower()
        else "nominal"
    )
    tooltips = [
        {"field": name, "type": types.get(name, "nominal")} for name in fields[:5]
    ]
    return {
        "mark": {"type": "area", "tooltip": True},
        "encoding": {
            "x": {
                "field": x_axis,
                "type": types.get(x_axis, fallback_x),
                "title": x_axis,
            },
            "y": {
                "field": y_axis,
                "type": types.get(y_axis, "quantitative"),
                "title": y_axis,
            },
            "tooltip": tooltips,
        },
    }
def _scatter_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create scatter plot specification."""
    types = field_types or {}
    if fields:
        x_axis = fields[0]
        y_axis = fields[1] if len(fields) > 1 else fields[0]
    else:
        x_axis, y_axis = "x", "y"

    return {
        "mark": {"type": "circle", "size": 100, "tooltip": True},
        "encoding": {
            "x": {
                "field": x_axis,
                "type": types.get(x_axis, "quantitative"),
                "title": x_axis,
            },
            "y": {
                "field": y_axis,
                "type": types.get(y_axis, "quantitative"),
                "title": y_axis,
            },
            "tooltip": [
                {"field": name, "type": types.get(name, "nominal")}
                for name in fields[:5]
            ],
        },
    }

def _table_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create table-like visualization (using text marks)."""
    types = field_types or {}
    # Tabular data is approximated with a simple dot plot along one axis.
    y_axis = fields[0] if fields else "index"

    return {
        "mark": {"type": "circle", "size": 50},
        "encoding": {
            "y": {
                "field": y_axis,
                "type": types.get(y_axis, "nominal"),
                "title": y_axis,
            },
            "tooltip": [
                {"field": name, "type": types.get(name, "nominal")}
                for name in fields[:10]
            ],
        },
    }

def _pie_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create pie chart specification using arc marks."""
    types = field_types or {}
    if fields:
        category = fields[0]
        value = fields[1] if len(fields) > 1 else fields[0]
    else:
        category, value = "category", "value"

    return {
        "mark": {"type": "arc", "tooltip": True},
        "encoding": {
            "theta": {
                "field": value,
                "type": types.get(value, "quantitative"),
            },
            "color": {
                "field": category,
                "type": types.get(category, "nominal"),
                "title": category,
            },
            "tooltip": [
                {"field": name, "type": types.get(name, "nominal")}
                for name in fields[:5]
            ],
        },
    }
def _big_number_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create big number visualization using text mark."""
    types = field_types or {}
    value = fields[0] if fields else "value"

    return {
        "mark": {
            "type": "text",
            "fontSize": 48,
            "fontWeight": "bold",
            "align": "center",
            "baseline": "middle",
            "tooltip": True,
        },
        "encoding": {
            "text": {
                "field": value,
                "type": types.get(value, "quantitative"),
                "format": ",.0f",
            },
            "tooltip": [
                {"field": name, "type": types.get(name, "nominal")}
                for name in fields[:3]
            ],
        },
    }

def _histogram_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create histogram using bar marks with binned data."""
    bucket_field = fields[0] if fields else "value"

    return {
        "mark": {"type": "bar", "tooltip": True},
        "encoding": {
            "x": {
                "field": bucket_field,
                "type": "quantitative",
                "bin": {"maxbins": 20},
                "title": bucket_field,
            },
            "y": {"aggregate": "count", "type": "quantitative", "title": "Count"},
            "tooltip": [{"field": name, "type": "nominal"} for name in fields[:3]],
        },
    }

def _box_plot_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create box plot approximation using error bars."""
    if fields:
        group_field = fields[0]
        value_field = fields[1] if len(fields) > 1 else fields[0]
    else:
        group_field, value_field = "category", "value"

    return {
        "mark": {"type": "boxplot", "tooltip": True},
        "encoding": {
            "x": {"field": group_field, "type": "nominal", "title": group_field},
            "y": {
                "field": value_field,
                "type": "quantitative",
                "title": value_field,
            },
            "tooltip": [{"field": name, "type": "nominal"} for name in fields[:5]],
        },
    }
def _heatmap_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create heatmap using rect marks."""
    x_axis = fields[0] if fields else "x"
    y_axis = fields[1] if len(fields) > 1 else "y"
    # Third field (or second, or a placeholder) drives the color scale.
    if len(fields) > 2:
        color = fields[2]
    elif len(fields) > 1:
        color = fields[1]
    else:
        color = "value"

    return {
        "mark": {"type": "rect", "tooltip": True},
        "encoding": {
            "x": {"field": x_axis, "type": "nominal", "title": x_axis},
            "y": {"field": y_axis, "type": "nominal", "title": y_axis},
            "color": {
                "field": color,
                "type": "quantitative",
                "scale": {"scheme": "blues"},
                "title": color,
            },
            "tooltip": [{"field": name, "type": "nominal"} for name in fields[:5]],
        },
    }

def _funnel_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create funnel chart using horizontal bars."""
    if fields:
        stage = fields[0]
        value = fields[1] if len(fields) > 1 else fields[0]
    else:
        stage, value = "stage", "value"

    return {
        "mark": {"type": "bar", "tooltip": True},
        "encoding": {
            "y": {
                "field": stage,
                "type": "nominal",
                "sort": "-x",
                "title": stage,
            },
            "x": {
                "field": value,
                "type": "quantitative",
                "title": value,
            },
            "color": {
                "field": value,
                "type": "quantitative",
                "scale": {"scheme": "viridis"},
            },
            "tooltip": [{"field": name, "type": "nominal"} for name in fields[:5]],
        },
    }

def _gauge_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create gauge chart using arc marks."""
    value = fields[0] if fields else "value"

    return {
        "mark": {
            "type": "arc",
            "innerRadius": 50,
            "outerRadius": 80,
            "tooltip": True,
        },
        "encoding": {
            "theta": {
                "field": value,
                "type": "quantitative",
                "scale": {"range": [0, 6.28]},
            },
            "color": {
                "field": value,
                "type": "quantitative",
                "scale": {"scheme": "redyellowgreen"},
            },
            "tooltip": [{"field": name, "type": "nominal"} for name in fields[:3]],
        },
    }
def _mixed_chart_spec(
    self, fields: List[str], field_types: Dict[str, str] | None = None
) -> Dict[str, Any]:
    """Create mixed timeseries using layered marks (line + points)."""
    time_field = fields[0] if fields else "date"
    value_field = fields[1] if len(fields) > 1 else fields[0] if fields else "value"

    line_layer = {
        "mark": {"type": "line", "tooltip": True},
        "encoding": {
            "x": {"field": time_field, "type": "temporal", "title": time_field},
            "y": {
                "field": value_field,
                "type": "quantitative",
                "title": value_field,
            },
            "tooltip": [
                {"field": name, "type": "nominal"} for name in fields[:5]
            ],
        },
    }
    point_layer = {
        "mark": {"type": "point", "filled": True, "size": 50},
        "encoding": {
            "x": {"field": time_field, "type": "temporal"},
            "y": {"field": value_field, "type": "quantitative"},
        },
    }
    return {"layer": [line_layer, point_layer]}


class PreviewFormatGenerator:
    """Factory for generating different preview formats."""

    # Maps the request's ``format`` value to its strategy class.
    STRATEGIES = {
        "url": URLPreviewStrategy,
        "ascii": ASCIIPreviewStrategy,
        "table": TablePreviewStrategy,
        "vega_lite": VegaLitePreviewStrategy,
    }

    def __init__(self, chart: ChartLike, request: GetChartPreviewRequest) -> None:
        self.chart = chart
        self.request = request

    def generate(
        self,
    ) -> (
        URLPreview
        | InteractivePreview
        | ASCIIPreview
        | VegaLitePreview
        | TablePreview
        | ChartError
    ):
        """Generate preview using the appropriate strategy."""
        strategy_cls = self.STRATEGIES.get(self.request.format)
        if strategy_cls is None:
            return ChartError(
                error=f"Unsupported preview format: {self.request.format}",
                error_type="UnsupportedFormat",
            )
        return strategy_cls(self.chart, self.request).generate()
def generate_ascii_chart(
    data: List[Any], chart_type: str, width: int = 80, height: int = 20
) -> str:
    """Generate ASCII art chart from data.

    Dispatches to a bar/line/scatter renderer based on ``chart_type`` and
    falls back to a table rendering for unsupported types. Never raises:
    failures are returned as an error string.
    """
    if not data or len(data) == 0:
        return "No data available for ASCII chart"

    try:
        logger.info(
            "generate_ascii_chart: chart_type=%s, data_rows=%s", chart_type, len(data)
        )

        if chart_type in ["bar", "column", "echarts_timeseries_bar"]:
            logger.info("Generating bar chart")
            return _generate_ascii_bar_chart(data, width, height)
        elif chart_type in ["line", "echarts_timeseries_line"]:
            logger.info("Generating line chart")
            return _generate_ascii_line_chart(data, width, height)
        elif chart_type in ["scatter", "echarts_timeseries_scatter"]:
            logger.info("Generating scatter chart")
            return _generate_ascii_scatter_chart(data, width, height)
        else:
            # Default to table format for unsupported chart types
            logger.info(
                "Unsupported chart type '%s', falling back to table", chart_type
            )
            return _generate_ascii_table(data, width)
    except Exception as e:
        # Idiom fix: logger.exception records the traceback in one call;
        # previously the code imported traceback and logged it separately.
        logger.exception("ASCII chart generation failed: %s", e)
        return f"ASCII chart generation failed: {str(e)}"


def _generate_ascii_bar_chart(data: List[Any], width: int, height: int) -> str:
    """Generate enhanced ASCII bar chart with horizontal and vertical options.

    Picks the first numeric value per row as the bar value and the first
    string value as its label; orientation is chosen by label length and
    bar count.
    """
    if not data:
        return "No data for bar chart"

    values = []
    labels = []

    for row in data[:12]:  # Increased limit for better charts
        if isinstance(row, dict):
            # First numeric value becomes the bar, first string the label.
            numeric_val = None
            label_val = None

            for _key, val in row.items():
                if isinstance(val, (int, float)) and numeric_val is None:
                    numeric_val = val
                elif isinstance(val, str) and label_val is None:
                    label_val = val

            if numeric_val is not None:
                values.append(numeric_val)
                labels.append(label_val or f"Item {len(values)}")

    if not values:
        return "No numeric data found for bar chart"

    # Long labels or many bars render better horizontally.
    avg_label_length = sum(len(str(label)) for label in labels) / len(labels)
    use_horizontal = avg_label_length > 8 or len(values) > 8

    if use_horizontal:
        return _generate_horizontal_bar_chart(values, labels, width)
    return _generate_vertical_bar_chart(values, labels, width, height)
labels.append(label_val or f"Item {len(values)}") + + if not values: + return "No numeric data found for bar chart" + + # Decide between horizontal and vertical based on label lengths + avg_label_length = sum(len(str(label)) for label in labels) / len(labels) + use_horizontal = avg_label_length > 8 or len(values) > 8 + + if use_horizontal: + return _generate_horizontal_bar_chart(values, labels, width) + else: + return _generate_vertical_bar_chart(values, labels, width, height) + + +def _generate_horizontal_bar_chart( + values: List[float], labels: List[str], width: int +) -> str: + """Generate horizontal ASCII bar chart.""" + lines = [] + lines.append("๐Ÿ“Š Horizontal Bar Chart") + lines.append("โ•" * min(width, 60)) + + max_val = max(values) if values else 1 + min_val = min(values) if values else 0 + max_bar_width = min(40, width - 20) # Leave space for labels and values + + # Add scale indicator + lines.append(f"Scale: {min_val:.1f} โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ {max_val:.1f}") + lines.append("") + + for label, value in zip(labels, values, strict=False): + # Calculate bar length + if max_val > min_val: + normalized = (value - min_val) / (max_val - min_val) + bar_length = max(1, int(normalized * max_bar_width)) + else: + bar_length = 1 + + # Create bar with gradient effect + bar = _create_gradient_bar(bar_length, value, max_val) + + # Format value + if abs(value) >= 1000000: + value_str = f"{value / 1000000:.1f}M" + elif abs(value) >= 1000: + value_str = f"{value / 1000:.1f}K" + else: + value_str = f"{value:.1f}" + + # Truncate label if too long + display_label = label[:15] if len(label) > 15 else label + lines.append(f"{display_label:>15} โ–{bar:<{max_bar_width}} {value_str}") + + return "\n".join(lines) + + +def _generate_vertical_bar_chart( # noqa: C901 + values: List[float], labels: List[str], width: int, height: int +) -> str: + """Generate vertical ASCII bar chart.""" + lines = [] + lines.append("๐Ÿ“Š Vertical Bar Chart") + lines.append("โ•" 
* min(width, 60)) + + max_val = max(values) if values else 1 + min_val = min(values) if values else 0 + chart_height = min(15, height - 8) # Leave space for title and labels + + # Create the chart grid + grid = [] + for _ in range(chart_height): + grid.append([" "] * len(values)) + + # Fill the bars + for col, value in enumerate(values): + if max_val > min_val: + normalized = (value - min_val) / (max_val - min_val) + bar_height = max(1, int(normalized * chart_height)) + else: + bar_height = 1 + + # Fill from bottom up + for row_idx in range(chart_height - bar_height, chart_height): + if row_idx < len(grid): + # Use different characters for height effect + if row_idx == chart_height - bar_height: + grid[row_idx][col] = "โ–€" # Top of bar + elif row_idx == chart_height - 1: + grid[row_idx][col] = "โ–ˆ" # Bottom of bar + else: + grid[row_idx][col] = "โ–ˆ" # Middle of bar + + # Add Y-axis scale + for i, row_data in enumerate(grid): + y_val = ( + max_val - (i / (chart_height - 1)) * (max_val - min_val) + if chart_height > 1 + else max_val + ) + if abs(y_val) >= 1000: + y_label = f"{y_val:.0f}" + else: + y_label = f"{y_val:.1f}" + lines.append(f"{y_label:>6} โ”ค " + "".join(f"{cell:^3}" for cell in row_data)) + + # Add X-axis + lines.append(" โ””" + "โ”€โ”€โ”€" * len(values)) + + # Add labels + label_line = " " + for label in labels: + short_label = label[:3] if len(label) > 3 else label + label_line += f"{short_label:^3}" + lines.append(label_line) + + return "\n".join(lines) + + +def _create_gradient_bar(length: int, value: float, max_val: float) -> str: + """Create a gradient bar with visual effects.""" + if length <= 0: + return "" + + # Create gradient effect based on value intensity + intensity = value / max_val if max_val > 0 else 0 + + if intensity > 0.8: + # High values - solid bars + return "โ–ˆ" * length + elif intensity > 0.6: + # Medium-high values - mostly solid with some texture + return "โ–ˆ" * (length - 1) + "โ–‰" if length > 1 else "โ–ˆ" + elif 
def _generate_ascii_line_chart(data: List[Any], width: int, height: int) -> str:
    """Generate enhanced ASCII line chart with trend analysis."""
    if not data:
        return "No data for line chart"

    out = ["📈 Line Chart with Trend Analysis", "═" * min(width, 60)]

    series, point_labels = _extract_time_series_data(data)
    if not series:
        return "No numeric data found for line chart"

    # Full chart for 3+ points; sparkline for tiny series.
    if len(series) >= 3:
        out.extend(_create_enhanced_line_chart(series, point_labels, width, height))
    else:
        out.extend(_create_sparkline(series))

    out.append("")
    out.append("📊 Trend Analysis:")
    out.extend(_analyze_trend(series))

    return "\n".join(out)


def _extract_time_series_data(data: List[Any]) -> tuple[List[float], List[str]]:
    """Extract time series data with labels.

    Per row: the first numeric value is the data point; the first string
    value becomes its label (truncated more aggressively for non-date keys).
    """
    series: List[float] = []
    point_labels: List[str] = []

    for row in data[:20]:  # Limit points for readability
        if not isinstance(row, dict):
            continue

        numeric_val = None
        label_val = None
        for key, val in row.items():
            if isinstance(val, (int, float)) and numeric_val is None:
                numeric_val = val
            elif isinstance(val, str) and label_val is None:
                # Date-ish keys keep more characters than generic strings.
                is_date_key = any(
                    date_word in key.lower()
                    for date_word in ["date", "time", "month", "day", "year"]
                )
                label_val = str(val)[:10] if is_date_key else str(val)[:8]

        if numeric_val is not None:
            series.append(numeric_val)
            point_labels.append(label_val or f"P{len(series)}")

    return series, point_labels


def _create_enhanced_line_chart(
    values: List[float], labels: List[str], width: int, height: int
) -> List[str]:
    """Create an enhanced ASCII line chart with better visualization."""
    rendered: List[str] = []
    chart_width = min(50, width - 15)
    chart_height = min(12, height - 8)

    if len(values) < 2:
        return ["Insufficient data for line chart"]

    low = min(values)
    high = max(values)
    span = high - low if high != low else 1

    # Blank plotting grid.
    grid = [[" " for _ in range(chart_width)] for _ in range(chart_height)]

    # Plot points and connect consecutive ones with segments.
    prev_x, prev_y = None, None
    for i, value in enumerate(values):
        x = int((i / (len(values) - 1)) * (chart_width - 1)) if len(values) > 1 else 0
        y = chart_height - 1 - int(((value - low) / span) * (chart_height - 1))
        x = max(0, min(x, chart_width - 1))
        y = max(0, min(y, chart_height - 1))

        grid[y][x] = "●"
        if prev_x is not None and prev_y is not None:
            _draw_line_segment(grid, prev_x, prev_y, x, y, chart_width, chart_height)
        prev_x, prev_y = x, y

    # Render rows with Y-axis labels.
    for i, row in enumerate(grid):
        if chart_height > 1:
            y_val = high - (i / (chart_height - 1)) * span
        else:
            y_val = high
        y_label = f"{y_val:.0f}" if abs(y_val) >= 1000 else f"{y_val:.1f}"
        rendered.append(f"{y_label:>8} ┤ " + "".join(row))

    # X-axis baseline.
    rendered.append("         └" + "─" * chart_width)

    # Sparse X labels (at most ~6) aligned under their plot positions.
    label_line = "          "
    step = max(1, len(labels) // 6)
    for i in range(0, len(labels), step):
        if len(values) > 1:
            pos = int((i / (len(values) - 1)) * (chart_width - 1))
        else:
            pos = 0
        while len(label_line) - 10 < pos:
            label_line += " "
        label_line += labels[i][:8]
    rendered.append(label_line)

    return rendered
def _draw_line_segment(
    grid: List[List[str]], x1: int, y1: int, x2: int, y2: int, width: int, height: int
) -> None:
    """Draw a line segment between two points using Bresenham-like algorithm.

    Mutates ``grid`` in place; endpoints themselves are not drawn.
    """
    if x1 == x2:  # Vertical run
        lo, hi = sorted((y1, y2))
        for yy in range(lo + 1, hi):
            if 0 <= yy < height and 0 <= x1 < width:
                grid[yy][x1] = "│"
        return

    if y1 == y2:  # Horizontal run
        lo, hi = sorted((x1, x2))
        for xx in range(lo + 1, hi):
            if 0 <= y1 < height and 0 <= xx < width:
                grid[y1][xx] = "─"
        return

    # Diagonal: interpolate; glyph follows the dominant direction.
    steps = max(abs(x2 - x1), abs(y2 - y1))
    glyph = "─" if abs(x2 - x1) > abs(y2 - y1) else "│"
    for step in range(1, steps):
        xx = x1 + int((x2 - x1) * step / steps)
        yy = y1 + int((y2 - y1) * step / steps)
        if 0 <= xx < width and 0 <= yy < height:
            grid[yy][xx] = glyph


def _analyze_trend(values: List[float]) -> List[str]:
    """Analyze trend in the data, returning bullet-point summary lines."""
    if len(values) < 2:
        return ["• Insufficient data for trend analysis"]

    first_val = values[0]
    last_val = values[-1]
    low = min(values)
    high = max(values)
    mean = sum(values) / len(values)

    # Classify the overall direction by first-vs-last ratio.
    if last_val > first_val * 1.1:
        trend_icon, trend_desc = "📈", "Strong upward trend"
    elif last_val > first_val * 1.05:
        trend_icon, trend_desc = "📊", "Moderate upward trend"
    elif last_val < first_val * 0.9:
        trend_icon, trend_desc = "📉", "Strong downward trend"
    elif last_val < first_val * 0.95:
        trend_icon, trend_desc = "📊", "Moderate downward trend"
    else:
        trend_icon, trend_desc = "➡️", "Relatively stable"

    summary = [
        f"• {trend_icon} {trend_desc}",
        f"• Range: {low:.1f} to {high:.1f} (avg: {mean:.1f})",
    ]

    # Volatility needs at least three points to be meaningful.
    if len(values) >= 3:
        deltas = [abs(values[i] - values[i - 1]) for i in range(1, len(values))]
        avg_delta = sum(deltas) / len(deltas)
        volatility = "High" if avg_delta > (high - low) * 0.1 else "Low"
        summary.append(f"• Volatility: {volatility}")

    return summary
def _extract_numeric_values(data: List[Any]) -> List[float]:
    """Extract the first numeric value per row for line charting."""
    out: List[float] = []
    for row in data[:20]:  # Limit points
        if isinstance(row, dict):
            for _key, val in row.items():
                if isinstance(val, (int, float)):
                    out.append(val)
                    break
    return out


def _create_sparkline(values: List[float]) -> List[str]:
    """Create sparkline visualization from values."""
    if len(values) <= 1:
        return []

    high = max(values)
    low = min(values)
    span = high - low if high != low else 1

    spark = ""
    for val in values:
        normalized = (val - low) / span
        # Glyph chosen by which 20% bucket the value falls in.
        for bound, glyph in ((0.2, "▁"), (0.4, "▂"), (0.6, "▄"), (0.8, "▆")):
            if normalized < bound:
                spark += glyph
                break
        else:
            spark += "█"

    # Guard the header line against NaN endpoints.
    if _is_nan_value(low) or _is_nan_value(high):
        return ["Range: Unable to calculate from data", spark]
    return [f"Range: {low:.2f} to {high:.2f}", spark]


def _is_nan_value(value: Any) -> bool:
    """Check if a value is NaN or not convertible to float."""
    try:
        import math

        return math.isnan(float(value))
    except (ValueError, TypeError):
        return True


def _generate_ascii_scatter_chart(data: List[Any], width: int, height: int) -> str:
    """Generate ASCII scatter plot."""
    if not data:
        return "No data for scatter chart"

    out = ["ASCII Scatter Plot", "=" * min(width, 50)]

    xs, ys, x_column, y_column = _extract_scatter_data(data)

    # Diagnostic lines describing what was detected.
    out.extend(_create_debug_info(xs, ys, x_column, y_column, data))

    # Too few valid points: fall back to a table rendering.
    if len(xs) < 2:
        return _generate_ascii_table(data, width)

    out.extend(_create_axis_info(xs, ys, x_column, y_column))

    grid = _create_scatter_grid(xs, ys, width, height)
    out.extend(_render_scatter_grid(grid, xs, ys, width, height))

    return "\n".join(out)
_extract_scatter_data(data) + + # Debug info + lines.extend(_create_debug_info(x_values, y_values, x_column, y_column, data)) + + # Check if we have enough data + if len(x_values) < 2: + return _generate_ascii_table(data, width) + + # Add axis info + lines.extend(_create_axis_info(x_values, y_values, x_column, y_column)) + + # Create and render grid + grid = _create_scatter_grid(x_values, y_values, width, height) + lines.extend(_render_scatter_grid(grid, x_values, y_values, width, height)) + + return "\n".join(lines) + + +def _extract_scatter_data( + data: List[Any], +) -> tuple[List[float], List[float], str | None, str | None]: + """Extract X,Y data from scatter chart data.""" + x_values = [] + y_values = [] + x_column = None + y_column = None + numeric_columns = [] + + if data and isinstance(data[0], dict): + # Find the first two numeric columns + for key, val in data[0].items(): + if isinstance(val, (int, float)) and not ( + isinstance(val, float) and (val != val) + ): # Exclude NaN + numeric_columns.append(key) + + if len(numeric_columns) >= 2: + x_column = numeric_columns[0] + y_column = numeric_columns[1] + + # Extract X,Y pairs + for row in data[:50]: # Limit for ASCII display + if isinstance(row, dict): + x_val = row.get(x_column) + y_val = row.get(y_column) + # Check for valid numbers (not NaN) + if ( + isinstance(x_val, (int, float)) + and isinstance(y_val, (int, float)) + and not ( + isinstance(x_val, float) and (x_val != x_val) + ) # Not NaN + and not (isinstance(y_val, float) and (y_val != y_val)) + ): # Not NaN + x_values.append(x_val) + y_values.append(y_val) + + return x_values, y_values, x_column, y_column + + +def _create_debug_info( + x_values: List[float], + y_values: List[float], + x_column: str | None, + y_column: str | None, + data: List[Any], +) -> List[str]: + """Create debug information lines for scatter chart.""" + numeric_columns = [] + if data and isinstance(data[0], dict): + for key, val in data[0].items(): + if isinstance(val, (int, 
float)) and not ( + isinstance(val, float) and (val != val) + ): + numeric_columns.append(key) + + return [ + f"DEBUG: Found {len(numeric_columns)} numeric columns: {numeric_columns}", + f"DEBUG: X column: {x_column}, Y column: {y_column}", + f"DEBUG: Valid X,Y pairs: {len(x_values)}", + ] + + +def _create_axis_info( + x_values: List[float], + y_values: List[float], + x_column: str | None, + y_column: str | None, +) -> List[str]: + """Create axis information lines.""" + return [ + f"X-axis: {x_column} (range: {min(x_values):.2f} to {max(x_values):.2f})", + f"Y-axis: {y_column} (range: {min(y_values):.2f} to {max(y_values):.2f})", + f"Showing {len(x_values)} data points", + "", + ] + + +def _create_scatter_grid( + x_values: List[float], y_values: List[float], width: int, height: int +) -> List[List[str]]: + """Create and populate the scatter plot grid.""" + plot_width = min(40, width - 10) + plot_height = min(15, height - 8) + + # Normalize values to fit in grid + x_min, x_max = min(x_values), max(x_values) + y_min, y_max = min(y_values), max(y_values) + x_range = x_max - x_min if x_max != x_min else 1 + y_range = y_max - y_min if y_max != y_min else 1 + + # Create grid + grid = [[" " for _ in range(plot_width)] for _ in range(plot_height)] + + # Plot points + for x, y in zip(x_values, y_values, strict=False): + try: + grid_x = int(((x - x_min) / x_range) * (plot_width - 1)) + grid_y = int(((y - y_min) / y_range) * (plot_height - 1)) + grid_y = plot_height - 1 - grid_y # Flip Y axis for display + + if 0 <= grid_x < plot_width and 0 <= grid_y < plot_height: + if grid[grid_y][grid_x] == " ": + grid[grid_y][grid_x] = "โ€ข" + else: + grid[grid_y][grid_x] = "โ–ˆ" # Multiple points + except (ValueError, OverflowError): + # Skip points that can't be converted to integers (NaN, inf, etc.) 
def _render_scatter_grid(
    grid: List[List[str]],
    x_values: List[float],
    y_values: List[float],
    width: int,
    height: int,
) -> List[str]:
    """Render the scatter plot grid with axes and labels."""
    out: List[str] = []
    plot_width = min(40, width - 10)
    plot_height = min(15, height - 8)

    x_min, x_max = min(x_values), max(x_values)
    y_min, y_max = min(y_values), max(y_values)
    y_span = y_max - y_min if y_max != y_min else 1

    # Plot rows, each prefixed with its Y value.
    for i, row in enumerate(grid):
        if plot_height > 1:
            y_val = y_max - (i / (plot_height - 1)) * y_span
        else:
            y_val = y_max
        y_label = f"{y_val:.1f}" if abs(y_val) < 1000 else f"{y_val:.0f}"
        out.append(f"{y_label:>6} |{''.join(row)}")

    # X-axis baseline.
    out.append(" " * 7 + "+" + "-" * plot_width)

    # Min/max X labels at the ends of the axis.
    left = f"{x_min:.1f}" if abs(x_min) < 1000 else f"{x_min:.0f}"
    right = f"{x_max:.1f}" if abs(x_max) < 1000 else f"{x_max:.0f}"
    out.append(
        " " * 8 + left + " " * (plot_width - len(left) - len(right)) + right
    )

    return out


def _generate_ascii_table(data: List[Any], width: int) -> str:
    """Generate enhanced ASCII table with better formatting."""
    if not data:
        return "No data for table"

    out = ["📋 Data Table", "═" * min(width, 70)]

    if isinstance(data[0], dict):
        # Pick the most interesting columns and size them to content.
        all_headers = list(data[0].keys())
        headers = _select_display_columns(all_headers, data, max_cols=6)
        col_widths = _calculate_column_widths(headers, data, width)

        out.append(_create_table_header(headers, col_widths))
        out.append(_create_table_separator(col_widths))

        row_count = min(15, len(data))  # Show more rows
        for i, row in enumerate(data[:row_count]):
            out.append(_format_table_row(row, headers, col_widths))
            # Light separator every 5 rows for readability.
            if i > 0 and (i + 1) % 5 == 0 and i < row_count - 1:
                out.append(_create_light_separator(col_widths))

        out.append(_create_table_separator(col_widths))
        out.append(f"📊 Showing {row_count} of {len(data)} rows")

        # Summaries of numeric columns, when any exist.
        numeric_summaries = _create_numeric_summaries(data, headers)
        if numeric_summaries:
            out.append("")
            out.append("📈 Numeric Summaries:")
            out.extend(numeric_summaries)

    return "\n".join(out)


def _select_display_columns(
    all_headers: List[str], data: List[Any], max_cols: int = 6
) -> List[str]:
    """Select the most interesting columns to display."""
    if len(all_headers) <= max_cols:
        return all_headers

    # Score columns: common business fields and data variety rank higher.
    scores: Dict[str, int] = {}
    for header in all_headers:
        score = 0
        lowered = header.lower()

        if any(word in lowered for word in ["name", "title", "id"]):
            score += 10
        if any(
            word in lowered
            for word in ["amount", "price", "cost", "revenue", "sales"]
        ):
            score += 8
        if any(word in lowered for word in ["date", "time", "created", "updated"]):
            score += 6
        if any(word in lowered for word in ["count", "total", "sum", "avg"]):
            score += 5

        # Variety bonus: more distinct sample values -> more interesting.
        samples = [
            str(row.get(header, "")) for row in data[:10] if isinstance(row, dict)
        ]
        distinct = len(set(samples))
        if distinct > 1:
            score += min(distinct, 5)

        scores[header] = score

    ranked = sorted(all_headers, key=lambda h: scores.get(h, 0), reverse=True)
    return ranked[:max_cols]


def _calculate_column_widths(
    headers: List[str], data: List[Any], total_width: int
) -> List[int]:
    """Calculate optimal column widths based on content."""
    if not headers:
        return []

    # Minimum widths derived from header lengths (clamped to 8..15).
    widths = [max(8, min(len(h) + 2, 15)) for h in headers]

    # Widen to fit sampled cell content (capped at 20).
    for row in data[:10]:
        if isinstance(row, dict):
            for i, header in enumerate(headers):
                val = row.get(header, "")
                if isinstance(val, float):
                    content_len = len(f"{val:.2f}")
                else:
                    content_len = len(str(val))
                widths[i] = max(widths[i], min(content_len + 1, 20))

    # Distribute remaining horizontal space evenly.
    used_width = sum(widths) + len(headers) * 3  # Account for separators
    available_width = min(total_width - 10, 80)  # Leave margins
    if used_width < available_width:
        extra = available_width - used_width
        for i in range(len(widths)):
            widths[i] += extra // len(widths)

    return widths
List[Any], total_width: int +) -> List[int]: + """Calculate optimal column widths based on content.""" + if not headers: + return [] + + # Start with minimum widths based on header lengths + min_widths = [max(8, min(len(h) + 2, 15)) for h in headers] + + # Check actual data to adjust widths + for row in data[:10]: # Sample first 10 rows + if isinstance(row, dict): + for i, header in enumerate(headers): + val = row.get(header, "") + if isinstance(val, float): + content_len = len(f"{val:.2f}") + elif isinstance(val, int): + content_len = len(str(val)) + else: + content_len = len(str(val)) + + min_widths[i] = max(min_widths[i], min(content_len + 1, 20)) + + # Distribute remaining space proportionally + used_width = sum(min_widths) + len(headers) * 3 # Account for separators + available_width = min(total_width - 10, 80) # Leave margins + + if used_width < available_width: + # Distribute extra space + extra_space = available_width - used_width + for i in range(len(min_widths)): + min_widths[i] += extra_space // len(min_widths) + + return min_widths + + +def _create_table_header(headers: List[str], widths: List[int]) -> str: + """Create formatted table header.""" + formatted_headers = [] + for header, width in zip(headers, widths, strict=False): + # Truncate and center header + display_header = header[: width - 2] if len(header) > width - 2 else header + formatted_headers.append(f"{display_header:^{width}}") + + return ( + "โ”Œ" + + "โ”ฌ".join("โ”€" * w for w in widths) + + "โ”\nโ”‚" + + "โ”‚".join(formatted_headers) + + "โ”‚" + ) + + +def _create_table_separator(widths: List[int]) -> str: + """Create table separator line.""" + return "โ”œ" + "โ”ผ".join("โ”€" * w for w in widths) + "โ”ค" + + +def _create_light_separator(widths: List[int]) -> str: + """Create light separator line.""" + return "โ”œ" + "โ”ผ".join("โ”ˆ" * w for w in widths) + "โ”ค" + + +def _format_table_row( + row: Dict[str, Any], headers: List[str], widths: List[int] +) -> str: + """Format a single table 
row.""" + formatted_values = [] + + for header, width in zip(headers, widths, strict=False): + val = row.get(header, "") + + # Format value based on type + if isinstance(val, float): + if abs(val) >= 1000000: + formatted_val = f"{val / 1000000:.1f}M" + elif abs(val) >= 1000: + formatted_val = f"{val / 1000:.1f}K" + else: + formatted_val = f"{val:.2f}" + elif isinstance(val, int): + if abs(val) >= 1000000: + formatted_val = f"{val // 1000000}M" + elif abs(val) >= 1000: + formatted_val = f"{val // 1000}K" + else: + formatted_val = str(val) + else: + formatted_val = str(val) + + # Truncate if too long + if len(formatted_val) > width - 2: + formatted_val = formatted_val[: width - 5] + "..." + + # Align numbers right, text left + if isinstance(val, (int, float)): + formatted_values.append(f"{formatted_val:>{width - 2}}") + else: + formatted_values.append(f"{formatted_val:<{width - 2}}") + + return "โ”‚ " + " โ”‚ ".join(formatted_values) + " โ”‚" + + +def _create_numeric_summaries(data: List[Any], headers: List[str]) -> List[str]: + """Create summaries for numeric columns.""" + summaries = [] + + for header in headers: + # Collect numeric values + values = [] + for row in data: + if isinstance(row, dict): + val = row.get(header) + if isinstance(val, (int, float)): + values.append(val) + + if len(values) >= 2: + min_val = min(values) + max_val = max(values) + avg_val = sum(values) / len(values) + + if abs(avg_val) >= 1000: + avg_str = f"{avg_val / 1000:.1f}K" + else: + avg_str = f"{avg_val:.1f}" + + summaries.append( + f" {header}: avg={avg_str}, range={min_val:.1f}-{max_val:.1f}" + ) + + return summaries + + +async def _get_chart_preview_internal( # noqa: C901 + request: GetChartPreviewRequest, + ctx: Context, +) -> ChartPreview | ChartError: + """ + Get a visual preview of a chart with URLs for LLM embedding. + + This tool generates or retrieves URLs for chart images that can be + displayed directly in LLM clients. 
The URLs point to Superset's + screenshot endpoints for proper image serving. + + Supports lookup by: + - Numeric ID (e.g., 123) + - UUID string (e.g., "a1b2c3d4-e5f6-7890-abcd-ef1234567890") + + Returns a ChartPreview with Superset URLs for the chart image or + ChartError on error. + """ + try: + await ctx.report_progress(1, 3, "Looking up chart") + from superset.daos.chart import ChartDAO + + # Find the chart + chart: Any = None + if isinstance(request.identifier, int) or ( + isinstance(request.identifier, str) and request.identifier.isdigit() + ): + chart_id = ( + int(request.identifier) + if isinstance(request.identifier, str) + else request.identifier + ) + await ctx.debug( + "Performing ID-based chart lookup: chart_id=%s" % (chart_id,) + ) + chart = ChartDAO.find_by_id(chart_id) + else: + await ctx.debug( + "Performing UUID-based chart lookup: uuid=%s" % (request.identifier,) + ) + # Try UUID lookup using DAO flexible method + chart = ChartDAO.find_by_id(request.identifier, id_column="uuid") + + # If not found and looks like a form_data_key, try to create transient chart + if ( + not chart + and isinstance(request.identifier, str) + and len(request.identifier) > 8 + ): + # This might be a form_data_key, try to get form data from cache + from superset.commands.explore.form_data.get import GetFormDataCommand + from superset.commands.explore.form_data.parameters import ( + CommandParameters, + ) + + try: + cmd_params = CommandParameters(key=request.identifier) + cmd = GetFormDataCommand(cmd_params) + form_data_json = cmd.run() + if form_data_json: + from superset.utils import json as utils_json + + form_data = utils_json.loads(form_data_json) + + # Create a transient chart object from form data + class TransientChart: + def __init__(self, form_data: Dict[str, Any]): + self.id = None + self.slice_name = "Unsaved Chart Preview" + self.viz_type = form_data.get("viz_type", "table") + self.datasource_id = None + self.datasource_type = "table" + self.params = 
utils_json.dumps(form_data) + self.form_data = form_data + self.uuid = None + + chart = TransientChart(form_data) + except Exception as e: + # Form data key not found or invalid + logger.debug( + "Failed to get form data for key %s: %s", request.identifier, e + ) + + if not chart: + await ctx.error("Chart not found: identifier=%s" % (request.identifier,)) + return ChartError( + error=f"No chart found with identifier: {request.identifier}", + error_type="NotFound", + ) + + await ctx.info( + "Chart found successfully: chart_id=%s, chart_name=%s, viz_type=%s" + % ( + getattr(chart, "id", None), + getattr(chart, "slice_name", None), + getattr(chart, "viz_type", None), + ) + ) + + # Log all chart attributes for debugging + logger.info( + "Chart object type: %s, id value: %s, id type: %s", + type(chart).__name__, + getattr(chart, "id", "NO_ID"), + type(getattr(chart, "id", None)), + ) + logger.info("Generating preview for chart %s", getattr(chart, "id", "NO_ID")) + logger.info("Chart datasource_id: %s", getattr(chart, "datasource_id", "NONE")) + + import time + + start_time = time.time() + + await ctx.report_progress(2, 3, f"Generating {request.format} preview") + await ctx.debug( + "Preview generation parameters: chart_id=%s, viz_type=%s, " + "datasource_id=%s, width=%s, height=%s" + % ( + chart.id, + chart.viz_type, + chart.datasource_id, + request.width, + request.height, + ) + ) + + # Handle different preview formats using strategy pattern + preview_generator = PreviewFormatGenerator(chart, request) + content = preview_generator.generate() + + if isinstance(content, ChartError): + await ctx.error( + "Preview generation failed: chart_id=%s, format=%s, error=%s, " + "error_type=%s" + % ( + chart.id, + request.format, + content.error, + content.error_type, + ) + ) + return content + + await ctx.report_progress(3, 3, "Building response") + + # Create performance and accessibility metadata + execution_time = int((time.time() - start_time) * 1000) + performance = 
PerformanceMetadata( + query_duration_ms=execution_time, + cache_status="miss", + optimization_suggestions=[], + ) + + accessibility = AccessibilityMetadata( + color_blind_safe=True, + alt_text=f"Preview of {chart.slice_name or f'Chart {chart.id}'}", + high_contrast_available=False, + ) + + await ctx.debug( + "Preview generation completed: execution_time_ms=%s, content_type=%s" + % ( + execution_time, + type(content).__name__, + ) + ) + + # Create backward-compatible response with enhanced metadata + result = ChartPreview( + chart_id=chart.id, + chart_name=chart.slice_name or f"Chart {chart.id}", + chart_type=chart.viz_type or "unknown", + explore_url=f"{get_superset_base_url()}/explore/?slice_id={chart.id}", + content=content, + chart_description=( + f"Preview of {chart.viz_type or 'chart'}: " + f"{chart.slice_name or f'Chart {chart.id}'}" + ), + accessibility=accessibility, + performance=performance, + ) + + # Add format-specific fields for backward compatibility + if isinstance(content, URLPreview): + result.format = "url" + result.preview_url = content.preview_url + result.width = content.width + result.height = content.height + elif isinstance(content, ASCIIPreview): + result.format = "ascii" + result.ascii_chart = content.ascii_content + result.width = content.width + result.height = content.height + elif isinstance(content, TablePreview): + result.format = "table" + result.table_data = content.table_data + # Base64 preview support removed + + return result + + except Exception as e: + await ctx.error( + "Chart preview generation failed: identifier=%s, format=%s, error=%s, " + "error_type=%s" + % ( + request.identifier, + request.format, + str(e), + type(e).__name__, + ) + ) + logger.error("Error in get_chart_preview: %s", e) + return ChartError( + error=f"Failed to get chart preview: {str(e)}", error_type="InternalError" + ) + + +@mcp.tool +@mcp_auth_hook +async def get_chart_preview( + request: GetChartPreviewRequest, ctx: Context +) -> ChartPreview | 
ChartError: + """Get chart preview by ID or UUID. + + Returns preview URL or formatted content (ascii, table, vega_lite). + """ + await ctx.info( + "Starting chart preview generation: identifier=%s, format=%s, width=%s, " + "height=%s" + % ( + request.identifier, + request.format, + request.width, + request.height, + ) + ) + await ctx.debug( + "Cache control settings: use_cache=%s, force_refresh=%s, cache_timeout=%s" + % ( + request.use_cache, + request.force_refresh, + request.cache_timeout, + ) + ) + + try: + result = await _get_chart_preview_internal(request, ctx) + + if isinstance(result, ChartPreview): + await ctx.info( + "Chart preview generated successfully: chart_id=%s, format=%s, " + "has_preview_url=%s" + % ( + getattr(result, "chart_id", None), + result.format, + bool(getattr(result, "preview_url", None)), + ) + ) + else: + await ctx.warning( + "Chart preview generation failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + except Exception as e: + await ctx.error( + "Chart preview generation failed: identifier=%s, error=%s, error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return ChartError( + error=f"Failed to generate chart preview: {str(e)}", + error_type="InternalError", + ) diff --git a/superset/mcp_service/chart/tool/list_charts.py b/superset/mcp_service/chart/tool/list_charts.py index a5519d0231a..ae8d8013ec6 100644 --- a/superset/mcp_service/chart/tool/list_charts.py +++ b/superset/mcp_service/chart/tool/list_charts.py @@ -20,14 +20,19 @@ MCP tool: list_charts (advanced filtering with metadata cache control) """ import logging +from typing import Any, cast, TYPE_CHECKING from fastmcp import Context +if TYPE_CHECKING: + from superset.models.slice import Slice + from superset.mcp_service.app import mcp from superset.mcp_service.auth import mcp_auth_hook from superset.mcp_service.chart.schemas import ( ChartFilter, ChartInfo, + ChartLike, ChartList, ListChartsRequest, 
serialize_chart_object, @@ -89,10 +94,14 @@ async def list_charts(request: ListChartsRequest, ctx: Context) -> ChartList: from superset.daos.chart import ChartDAO + def _serialize_chart(obj: "Slice | None", cols: Any) -> ChartInfo | None: + """Serialize chart object with proper type casting.""" + return serialize_chart_object(cast(ChartLike | None, obj)) + tool = ModelListCore( dao_class=ChartDAO, output_schema=ChartInfo, - item_serializer=lambda obj, cols: serialize_chart_object(obj) if obj else None, # type: ignore[arg-type] + item_serializer=_serialize_chart, filter_type=ChartFilter, default_columns=DEFAULT_CHART_COLUMNS, search_columns=[ diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py new file mode 100644 index 00000000000..9ed3052a5a5 --- /dev/null +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -0,0 +1,224 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +MCP tool: update_chart +""" + +import logging +import time + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.chart_utils import ( + analyze_chart_capabilities, + analyze_chart_semantics, + generate_chart_name, + map_config_to_form_data, +) +from superset.mcp_service.chart.schemas import ( + AccessibilityMetadata, + GenerateChartResponse, + PerformanceMetadata, + UpdateChartRequest, +) +from superset.mcp_service.utils.url_utils import ( + get_chart_screenshot_url, + get_superset_base_url, +) +from superset.utils import json + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def update_chart( + request: UpdateChartRequest, ctx: Context +) -> GenerateChartResponse: + """Update existing chart with new configuration. + + IMPORTANT: + - Chart must already be saved (from generate_chart with save_chart=True) + - LLM clients MUST display updated chart URL to users + - Embed preview_url as image: ![Updated Chart](preview_url) + + Use when: + - Modifying existing saved chart + - Updating title, filters, or visualization settings + - Changing chart type or data columns + + Returns: + - Updated chart info and metadata + - Preview URL and explore URL for further editing + """ + start_time = time.time() + + try: + # Find the existing chart + from superset.daos.chart import ChartDAO + + chart = None + if isinstance(request.identifier, int) or ( + isinstance(request.identifier, str) and request.identifier.isdigit() + ): + chart_id = ( + int(request.identifier) + if isinstance(request.identifier, str) + else request.identifier + ) + chart = ChartDAO.find_by_id(chart_id) + else: + # Try UUID lookup using DAO flexible method + chart = ChartDAO.find_by_id(request.identifier, id_column="uuid") + + if not chart: + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": f"No chart found with identifier: 
{request.identifier}", + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + + # Map the new config to form_data format + new_form_data = map_config_to_form_data(request.config) + + # Update chart using Superset's command + from superset.commands.chart.update import UpdateChartCommand + + # Generate new chart name if provided, otherwise keep existing + chart_name = ( + request.chart_name + if request.chart_name + else chart.slice_name or generate_chart_name(request.config) + ) + + update_payload = { + "slice_name": chart_name, + "viz_type": new_form_data["viz_type"], + "params": json.dumps(new_form_data), + } + + command = UpdateChartCommand(chart.id, update_payload) + updated_chart = command.run() + + # Generate semantic analysis + capabilities = analyze_chart_capabilities(updated_chart, request.config) + semantics = analyze_chart_semantics(updated_chart, request.config) + + # Create performance metadata + execution_time = int((time.time() - start_time) * 1000) + performance = PerformanceMetadata( + query_duration_ms=execution_time, + cache_status="miss", + optimization_suggestions=[], + ) + + # Create accessibility metadata + chart_name = ( + updated_chart.slice_name + if updated_chart and hasattr(updated_chart, "slice_name") + else generate_chart_name(request.config) + ) + accessibility = AccessibilityMetadata( + color_blind_safe=True, # Would need actual analysis + alt_text=f"Updated chart showing {chart_name}", + high_contrast_available=False, + ) + + # Generate previews if requested + previews = {} + if request.generate_preview: + try: + from superset.mcp_service.chart.tool.get_chart_preview import ( + _get_chart_preview_internal, + GetChartPreviewRequest, + ) + + for format_type in request.preview_formats: + preview_request = GetChartPreviewRequest( + identifier=str(updated_chart.id), format=format_type + ) + preview_result = await _get_chart_preview_internal( + preview_request, ctx + ) + + if hasattr(preview_result, "content"): + 
previews[format_type] = preview_result.content + + except Exception as e: + # Log warning but don't fail the entire request + logger.warning("Preview generation failed: %s", e) + + # Return enhanced data + result = { + "chart": { + "id": updated_chart.id, + "slice_name": updated_chart.slice_name, + "viz_type": updated_chart.viz_type, + "url": ( + f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}" + ), + "uuid": str(updated_chart.uuid) if updated_chart.uuid else None, + "updated": True, + }, + "error": None, + # Enhanced fields for better LLM integration + "previews": previews, + "capabilities": capabilities.model_dump() if capabilities else None, + "semantics": semantics.model_dump() if semantics else None, + "explore_url": ( + f"{get_superset_base_url()}/explore/?slice_id={updated_chart.id}" + ), + "api_endpoints": { + "data": ( + f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/data/" + ), + "preview": get_chart_screenshot_url(updated_chart.id), + "export": ( + f"{get_superset_base_url()}/api/v1/chart/{updated_chart.id}/export/" + ), + }, + "performance": performance.model_dump() if performance else None, + "accessibility": accessibility.model_dump() if accessibility else None, + "success": True, + "schema_version": "2.0", + "api_version": "v1", + } + return GenerateChartResponse.model_validate(result) + + except Exception as e: + execution_time = int((time.time() - start_time) * 1000) + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": f"Chart update failed: {str(e)}", + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) diff --git a/superset/mcp_service/chart/tool/update_chart_preview.py b/superset/mcp_service/chart/tool/update_chart_preview.py new file mode 100644 index 00000000000..5b01d1a0b14 --- /dev/null +++ 
b/superset/mcp_service/chart/tool/update_chart_preview.py @@ -0,0 +1,168 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: update_chart_preview +""" + +import logging +import time +from typing import Any, Dict + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.chart_utils import ( + analyze_chart_capabilities, + analyze_chart_semantics, + generate_chart_name, + generate_explore_link, + map_config_to_form_data, +) +from superset.mcp_service.chart.schemas import ( + AccessibilityMetadata, + PerformanceMetadata, + UpdateChartPreviewRequest, + URLPreview, +) +from superset.mcp_service.utils.url_utils import get_mcp_service_url + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +def update_chart_preview( + request: UpdateChartPreviewRequest, ctx: Context +) -> Dict[str, Any]: + """Update cached chart preview without saving. 
+ + IMPORTANT: + - Modifies cached form_data from generate_chart (save_chart=False) + - Original form_data_key is invalidated, new one returned + - LLM clients MUST display explore_url to users + - Embed preview_url as image: ![Chart Preview](preview_url) + + Use when: + - Modifying preview before deciding to save + - Iterating on chart design without creating permanent charts + - Testing different configurations + + Returns new form_data_key, preview images, and explore URL. + """ + start_time = time.time() + + try: + # Map the new config to form_data format + new_form_data = map_config_to_form_data(request.config) + + # Generate new explore link with updated form_data + explore_url = generate_explore_link(request.dataset_id, new_form_data) + + # Extract new form_data_key from the explore URL + new_form_data_key = None + if "form_data_key=" in explore_url: + new_form_data_key = explore_url.split("form_data_key=")[1].split("&")[0] + + # Generate semantic analysis + capabilities = analyze_chart_capabilities(None, request.config) + semantics = analyze_chart_semantics(None, request.config) + + # Create performance metadata + execution_time = int((time.time() - start_time) * 1000) + performance = PerformanceMetadata( + query_duration_ms=execution_time, + cache_status="miss", + optimization_suggestions=[], + ) + + # Create accessibility metadata + chart_name = generate_chart_name(request.config) + accessibility = AccessibilityMetadata( + color_blind_safe=True, # Would need actual analysis + alt_text=f"Updated chart preview showing {chart_name}", + high_contrast_available=False, + ) + + # Generate previews if requested + previews = {} + if request.generate_preview and new_form_data_key: + try: + for format_type in request.preview_formats: + if format_type == "url": + # Generate screenshot URL using new form_data key + mcp_base = get_mcp_service_url() + preview_url = ( + f"{mcp_base}/screenshot/explore/{new_form_data_key}.png" + ) + + previews[format_type] = URLPreview( + 
preview_url=preview_url, + width=800, + height=600, + supports_interaction=False, + ) + # Other formats would need form_data execution + # which is more complex for preview-only mode + + except Exception as e: + # Log warning but don't fail the entire request + logger.warning("Preview generation failed: %s", e) + + # Return enhanced data + result = { + "chart": { + "id": None, + "slice_name": chart_name, + "viz_type": new_form_data.get("viz_type"), + "url": explore_url, + "uuid": None, + "saved": False, + "updated": True, + }, + "error": None, + # Enhanced fields for better LLM integration + "previews": previews, + "capabilities": capabilities.model_dump() if capabilities else None, + "semantics": semantics.model_dump() if semantics else None, + "explore_url": explore_url, + "form_data_key": new_form_data_key, + "previous_form_data_key": request.form_data_key, # For reference + "api_endpoints": {}, # No API endpoints for unsaved charts + "performance": performance.model_dump() if performance else None, + "accessibility": accessibility.model_dump() if accessibility else None, + "success": True, + "schema_version": "2.0", + "api_version": "v1", + } + return result + + except Exception as e: + execution_time = int((time.time() - start_time) * 1000) + return { + "chart": None, + "error": f"Chart preview update failed: {str(e)}", + "performance": { + "query_duration_ms": execution_time, + "cache_status": "error", + "optimization_suggestions": [], + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } diff --git a/superset/mcp_service/chart/validation/__init__.py b/superset/mcp_service/chart/validation/__init__.py new file mode 100644 index 00000000000..9d1f7073b0f --- /dev/null +++ b/superset/mcp_service/chart/validation/__init__.py @@ -0,0 +1,24 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Chart validation module.""" + +from .dataset_validator import DatasetValidator +from .pipeline import ValidationPipeline +from .schema_validator import SchemaValidator + +__all__ = ["ValidationPipeline", "SchemaValidator", "DatasetValidator"] diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py new file mode 100644 index 00000000000..b03d0ffe9c0 --- /dev/null +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -0,0 +1,329 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
"""
Dataset-specific validation for chart configurations.
Checks that every column a chart references actually exists in the dataset.
"""

import difflib
import logging
from typing import Dict, List, Tuple

from superset.mcp_service.chart.schemas import (
    ColumnRef,
    TableChartConfig,
    XYChartConfig,
)
from superset.mcp_service.common.error_schemas import (
    ChartGenerationError,
    ColumnSuggestion,
    DatasetContext,
)

logger = logging.getLogger(__name__)


class DatasetValidator:
    """Validates a chart configuration against the schema of its dataset."""

    @staticmethod
    def validate_against_dataset(
        config: TableChartConfig | XYChartConfig, dataset_id: int | str
    ) -> Tuple[bool, ChartGenerationError | None]:
        """
        Check that every column referenced by ``config`` resolves in the dataset.

        Args:
            config: Chart configuration to validate.
            dataset_id: Numeric ID (or numeric string) or UUID of the dataset.

        Returns:
            Tuple of (is_valid, error); ``error`` is None on success.
        """
        context = DatasetValidator._get_dataset_context(dataset_id)
        if context is None:
            from superset.mcp_service.utils.error_builder import (
                ChartErrorBuilder,
            )

            return False, ChartErrorBuilder.dataset_not_found_error(dataset_id)

        refs = DatasetValidator._extract_column_references(config)

        # Every referenced name must resolve to a column or a metric.
        missing = [
            ref
            for ref in refs
            if not DatasetValidator._column_exists(ref.name, context)
        ]
        if missing:
            # Offer fuzzy-matched alternatives for each unknown name.
            suggestion_lookup = {
                ref.name: DatasetValidator._get_column_suggestions(ref.name, context)
                for ref in missing
            }
            return False, DatasetValidator._build_column_error(
                missing, suggestion_lookup, context
            )

        # Aggregates must be compatible with the column types they target.
        if isinstance(config, (TableChartConfig, XYChartConfig)):
            agg_errors = DatasetValidator._validate_aggregations(refs, context)
            if agg_errors:
                return False, agg_errors[0]

        return True, None

    @staticmethod
    def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None:
        """Load dataset metadata (columns and metrics), or None if unavailable."""
        try:
            from superset.daos.dataset import DatasetDAO

            # Integers (or digit-only strings) are primary keys; anything else
            # is treated as a UUID.
            if isinstance(dataset_id, int) or (
                isinstance(dataset_id, str) and dataset_id.isdigit()
            ):
                dataset = DatasetDAO.find_by_id(int(dataset_id))
            else:
                dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")

            if not dataset:
                return None

            columns = [
                {
                    "name": col.column_name,
                    "type": str(col.type) if col.type else "UNKNOWN",
                    "is_temporal": getattr(col, "is_temporal", False),
                    "is_numeric": getattr(col, "is_numeric", False),
                }
                for col in dataset.columns
            ]
            metrics = [
                {
                    "name": metric.metric_name,
                    "expression": metric.expression,
                    "description": metric.description,
                }
                for metric in dataset.metrics
            ]

            database_name = (
                dataset.database.database_name if dataset.database else None
            )
            return DatasetContext(
                id=dataset.id,
                table_name=dataset.table_name,
                schema=dataset.schema,
                database_name=database_name,
                available_columns=columns,
                available_metrics=metrics,
            )

        except Exception as e:
            logger.error("Error getting dataset context for %s: %s", dataset_id, e)
            return None

    @staticmethod
    def _extract_column_references(
        config: TableChartConfig | XYChartConfig,
    ) -> List[ColumnRef]:
        """Collect every column reference: axes, grouping, and filters."""
        refs: List[ColumnRef] = []

        if isinstance(config, TableChartConfig):
            refs.extend(config.columns)
        elif isinstance(config, XYChartConfig):
            refs.append(config.x)
            refs.extend(config.y)
            if config.group_by:
                refs.append(config.group_by)

        # Filters contribute column references as well.
        for filter_config in getattr(config, "filters", None) or []:
            refs.append(ColumnRef(name=filter_config.column))

        return refs

    @staticmethod
    def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
        """Case-insensitive membership test over columns, then metrics."""
        target = column_name.lower()
        if any(
            col["name"].lower() == target
            for col in dataset_context.available_columns
        ):
            return True
        return any(
            metric["name"].lower() == target
            for metric in dataset_context.available_metrics
        )

    @staticmethod
    def _get_column_suggestions(
        column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
    ) -> List[ColumnSuggestion]:
        """Fuzzy-match ``column_name`` against known columns and metrics."""
        candidates = [
            (col["name"], "column", col.get("type", "UNKNOWN"))
            for col in dataset_context.available_columns
        ] + [
            (metric["name"], "metric", "METRIC")
            for metric in dataset_context.available_metrics
        ]

        # First occurrence wins on case-insensitive collisions, mirroring the
        # candidate order (columns before metrics).
        by_lower: Dict[str, Tuple[str, str, str]] = {}
        for entry in candidates:
            by_lower.setdefault(entry[0].lower(), entry)

        matches = difflib.get_close_matches(
            column_name.lower(),
            [name.lower() for name, _, _ in candidates],
            n=max_suggestions,
            cutoff=0.6,
        )

        suggestions = []
        for match in matches:
            name, kind, data_type = by_lower[match]
            suggestions.append(
                ColumnSuggestion(name=name, type=kind, data_type=data_type)
            )
        return suggestions

    @staticmethod
    def _build_column_error(
        invalid_columns: List[ColumnRef],
        suggestions_map: Dict[str, List[ColumnSuggestion]],
        dataset_context: DatasetContext,
    ) -> ChartGenerationError:
        """Build a descriptive error for one or more unknown columns."""
        from superset.mcp_service.utils.error_builder import (
            ChartErrorBuilder,
        )

        if len(invalid_columns) == 1:
            name = invalid_columns[0].name
            matched = suggestions_map.get(name, [])
            if matched:
                return ChartErrorBuilder.column_not_found_error(
                    name, [s.name for s in matched]
                )
            return ChartErrorBuilder.column_not_found_error(name)

        # Several unknown columns: summarize the first few in the message.
        invalid_names = [col.name for col in invalid_columns]
        shown = ", ".join(invalid_names[:3]) + (
            "..." if len(invalid_names) > 3 else ""
        )
        return ChartErrorBuilder.build_error(
            error_type="multiple_invalid_columns",
            template_key="column_not_found",
            template_vars={
                "column": shown,
                "suggestions": "Use get_dataset_info to see all available columns",
            },
            custom_suggestions=[
                f"Invalid columns: {', '.join(invalid_names)}",
                "Check spelling and case sensitivity",
                "Use get_dataset_info to list available columns",
            ],
            error_code="MULTIPLE_INVALID_COLUMNS",
        )

    @staticmethod
    def _validate_aggregations(
        column_refs: List[ColumnRef], dataset_context: DatasetContext
    ) -> List[ChartGenerationError]:
        """Flag numeric aggregates applied to non-numeric columns."""
        numeric_aggs = ["SUM", "AVG", "MIN", "MAX", "STDDEV", "VAR", "MEDIAN"]
        numeric_types = ["INTEGER", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC"]

        errors: List[ChartGenerationError] = []
        for col_ref in column_refs:
            if not col_ref.aggregate:
                continue

            # Metrics are not in available_columns, so they skip this check.
            col_info = next(
                (
                    col
                    for col in dataset_context.available_columns
                    if col["name"].lower() == col_ref.name.lower()
                ),
                None,
            )
            if col_info is None:
                continue

            is_numeric = col_info.get("is_numeric", False) or (
                col_info.get("type", "").upper() in numeric_types
            )
            if col_ref.aggregate in numeric_aggs and not is_numeric:
                from superset.mcp_service.utils.error_builder import (
                    ChartErrorBuilder,
                )

                errors.append(
                    ChartErrorBuilder.build_error(
                        error_type="invalid_aggregation",
                        template_key="incompatible_configuration",
                        template_vars={
                            "reason": f"Cannot apply {col_ref.aggregate} to "
                            f"non-numeric column "
                            f"'{col_ref.name}' (type:"
                            f" {col_info.get('type', 'UNKNOWN')})",
                            "primary_suggestion": "Use COUNT or COUNT_DISTINCT "
                            "for text columns",
                        },
                        custom_suggestions=[
                            "Remove the aggregate function for raw values",
                            "Use COUNT to count occurrences",
                            "Use COUNT_DISTINCT to count unique values",
                        ],
                        error_code="INVALID_AGGREGATION",
                    )
                )

        return errors
+""" + +import logging +from typing import Any, Dict, List, Tuple + +from superset.mcp_service.chart.schemas import ( + ChartConfig, + GenerateChartRequest, +) +from superset.mcp_service.common.error_schemas import ChartGenerationError + +logger = logging.getLogger(__name__) + + +def _redact_sql_select(error_str: str, error_str_upper: str) -> str: + """Redact SELECT...FROM clause content to prevent data disclosure.""" + if "SELECT" in error_str_upper and "FROM" in error_str_upper: + select_idx = error_str_upper.find("SELECT") + from_idx = error_str_upper.find("FROM", select_idx) + if select_idx != -1 and from_idx != -1: + return error_str[: select_idx + 7] + " [REDACTED] " + error_str[from_idx:] + return error_str + + +def _redact_sql_where(error_str: str, error_str_upper: str) -> str: + """Redact WHERE clause content to prevent data disclosure.""" + if "WHERE" not in error_str_upper: + return error_str + + where_idx = error_str_upper.find("WHERE") + terminators = ["ORDER", "GROUP", "LIMIT", "UNION", "EXCEPT", "INTERSECT"] + term_idx = len(error_str) + for term in terminators: + idx = error_str_upper.find(term, where_idx) + if idx != -1 and idx < term_idx: + term_idx = idx + return error_str[: where_idx + 6] + " [REDACTED]" + error_str[term_idx:] + + +def _get_generic_error_message(error_str: str) -> str | None: + """Return generic message for common error types, or None.""" + error_lower = error_str.lower() + if "permission" in error_lower or "access" in error_lower: + return "Validation failed due to access restrictions" + if "database" in error_lower or "connection" in error_lower: + return "Validation failed due to database connectivity" + if "timeout" in error_lower: + return "Validation timed out" + return None + + +def _sanitize_validation_error(error: Exception) -> str: + """SECURITY FIX: Sanitize validation errors to prevent disclosure.""" + error_str = str(error) + + # SECURITY FIX: Limit length FIRST to prevent ReDoS attacks + if len(error_str) > 200: + 
error_str = error_str[:200] + "...[truncated]" + + # Remove potentially sensitive schema information + import re + + sensitive_patterns = [ + (r'\btable\s+[\'"`]?(\w+)[\'"`]?', "table [REDACTED]"), + (r'\bcolumn\s+[\'"`]?(\w+)[\'"`]?', "column [REDACTED]"), + (r'\bdatabase\s+[\'"`]?(\w+)[\'"`]?', "database [REDACTED]"), + (r'\bschema\s+[\'"`]?(\w+)[\'"`]?', "schema [REDACTED]"), + ] + for pattern, replacement in sensitive_patterns: + error_str = re.sub(pattern, replacement, error_str, flags=re.IGNORECASE) + + # SECURITY FIX: SQL sanitization without ReDoS-vulnerable patterns + error_str_upper = error_str.upper() + error_str = _redact_sql_select(error_str, error_str_upper) + error_str = _redact_sql_where(error_str, error_str_upper) + + # Return generic message for common error types + if generic := _get_generic_error_message(error_str): + return generic + + return error_str + + +class ValidationPipeline: + """ + Main validation orchestrator that runs validations in sequence: + 1. Schema validation (structure and types) + 2. Dataset validation (columns exist) + 3. Runtime validation (performance, compatibility) + """ + + @staticmethod + def validate_request( + request_data: Dict[str, Any], + ) -> Tuple[bool, GenerateChartRequest | None, ChartGenerationError | None]: + """ + Validate a chart generation request through all validation layers. 
+ + Args: + request_data: Raw request data dictionary + + Returns: + Tuple of (is_valid, parsed_request, error) + """ + try: + # Layer 1: Schema validation + from .schema_validator import SchemaValidator + + is_valid, request, error = SchemaValidator.validate_request(request_data) + if not is_valid: + return False, None, error + + # Ensure request is not None + if request is None: + return False, None, error + + # Layer 2: Dataset validation + is_valid, error = ValidationPipeline._validate_dataset( + request.config, request.dataset_id + ) + if not is_valid: + return False, request, error + + # Layer 3: Runtime validation + is_valid, error = ValidationPipeline._validate_runtime( + request.config, request.dataset_id + ) + if not is_valid: + return False, request, error + + return True, request, None + + except Exception as e: + logger.exception("Validation pipeline error") + from superset.mcp_service.utils.error_builder import ( + ChartErrorBuilder, + ) + + # SECURITY FIX: Sanitize validation error to prevent information disclosure + sanitized_reason = _sanitize_validation_error(e) + error = ChartErrorBuilder.build_error( + error_type="validation_system_error", + template_key="validation_error", + template_vars={"reason": sanitized_reason}, + error_code="VALIDATION_PIPELINE_ERROR", + ) + return False, None, error + + @staticmethod + def _validate_dataset( + config: ChartConfig, dataset_id: int | str + ) -> Tuple[bool, ChartGenerationError | None]: + """Validate configuration against dataset schema.""" + try: + from .dataset_validator import DatasetValidator + + return DatasetValidator.validate_against_dataset(config, dataset_id) + except ImportError: + # Skip if dataset validator not available + logger.warning( + "Dataset validator not available, skipping dataset validation" + ) + return True, None + except Exception as e: + logger.warning("Dataset validation failed: %s", e) + # Don't fail on dataset validation errors + return True, None + + @staticmethod + def 
_validate_runtime( + config: ChartConfig, dataset_id: int | str + ) -> Tuple[bool, ChartGenerationError | None]: + """Validate runtime issues (performance, compatibility).""" + try: + from .runtime import RuntimeValidator + + return RuntimeValidator.validate_runtime_issues(config, dataset_id) + except ImportError: + # Skip if runtime validator not available + logger.warning( + "Runtime validator not available, skipping runtime validation" + ) + return True, None + except Exception as e: + logger.warning("Runtime validation failed: %s", e) + # Don't fail on runtime validation errors + return True, None + + @staticmethod + def validate_filters( + filters: List[Any], + ) -> Tuple[bool, ChartGenerationError | None]: + """ + Validate filter logic for contradictions and empty results. + + Args: + filters: List of filter configurations + + Returns: + Tuple of (is_valid, error) + """ + if not filters: + return True, None + + # Check for contradictory filters + if ValidationPipeline._has_contradictory_filters(filters): + from superset.mcp_service.utils.error_builder import ( + ChartErrorBuilder, + ) + + return False, ChartErrorBuilder.build_error( + error_type="contradictory_filters", + template_key="invalid_value", + template_vars={ + "field": "filters", + "value": "contradictory conditions", + "reason": "Filter conditions are logically impossible", + "allowed_values": "non-contradictory conditions", + "specific_suggestion": "Remove conflicting filters", + }, + error_code="CONTRADICTORY_FILTERS", + ) + + # Check for filters likely to return empty + if empty_warnings := ValidationPipeline._check_empty_result_filters(filters): + from superset.mcp_service.utils.error_builder import ( + ChartErrorBuilder, + ) + + return False, ChartErrorBuilder.build_error( + error_type="empty_result_warning", + template_key="empty_result", + template_vars={"reason": "; ".join(empty_warnings)}, + custom_suggestions=[ + "Verify filter values exist in your dataset", + "Check for typos in filter 
values", + "Use broader filter criteria", + ], + error_code="EMPTY_RESULT_WARNING", + ) + + return True, None + + @staticmethod + def _has_contradictory_filters(filters: List[Any]) -> bool: + """Check if filters contain logical contradictions.""" + # Group filters by column + column_filters: Dict[str, List[Any]] = {} + for f in filters: + col = f.column + if col not in column_filters: + column_filters[col] = [] + column_filters[col].append(f) + + # Check for contradictions within same column + for _col, col_filters in column_filters.items(): + # Check for > X AND < Y where X >= Y + gt_values = [f.value for f in col_filters if f.op == ">"] + lt_values = [f.value for f in col_filters if f.op == "<"] + + for gt in gt_values: + for lt in lt_values: + try: + if float(gt) >= float(lt): + return True + except (ValueError, TypeError): + pass + + # Check for = X AND = Y where X != Y + eq_values = [f.value for f in col_filters if f.op == "="] + if len(eq_values) > 1 and len(set(eq_values)) > 1: + return True + + return False + + @staticmethod + def _check_empty_result_filters(filters: List[Any]) -> List[str]: + """Check for filter patterns that commonly result in empty results.""" + warnings = [] + + for f in filters: + col_lower = f.column.lower() + val_str = str(f.value).lower() if f.value is not None else "" + + # Check for common empty result patterns + if f.op == "=" and any( + pattern in val_str + for pattern in ["deleted", "archived", "inactive", "disabled"] + ): + warnings.append( + f"Filter '{f.column} = {f.value}' may return few or no results" + ) + + # Check for future dates + if "date" in col_lower and f.op in [">", ">="]: + try: + if "20" in val_str and int(val_str[:4]) > 2025: + warnings.append( + f"Filter '{f.column} {f.op} {f.value}' uses future date" + ) + except (ValueError, IndexError): + # Ignore invalid date formats + pass + + return warnings diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py 
"""
Runtime validation module for chart configurations.
Validates performance, compatibility, and user experience issues.
"""

import logging
from typing import List, Tuple

from superset.mcp_service.chart.schemas import (
    ChartConfig,
    XYChartConfig,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError

logger = logging.getLogger(__name__)


class RuntimeValidator:
    """Runs the runtime (performance / readability) checks for a chart."""

    @staticmethod
    def validate_runtime_issues(
        config: ChartConfig, dataset_id: int | str
    ) -> Tuple[bool, ChartGenerationError | None]:
        """
        Detect issues that would make the rendered chart slow or unreadable.

        Args:
            config: Chart configuration to validate.
            dataset_id: Dataset identifier.

        Returns:
            Tuple of (is_valid, error); accumulated warnings are folded
            into a single advisory ``error``.
        """
        warnings: List[str] = []
        suggestions: List[str] = []

        # Format and cardinality checks only apply to XY charts.
        if isinstance(config, XYChartConfig):
            warnings.extend(RuntimeValidator._validate_format_compatibility(config))

            card_warnings, card_suggestions = RuntimeValidator._validate_cardinality(
                config, dataset_id
            )
            if card_warnings:
                warnings.extend(card_warnings)
                suggestions.extend(card_suggestions)

        # Chart-type appropriateness applies to every chart kind.
        type_warnings, type_suggestions = RuntimeValidator._validate_chart_type(
            config, dataset_id
        )
        if type_warnings:
            warnings.extend(type_warnings)
            suggestions.extend(type_suggestions)

        if not warnings:
            return True, None

        from superset.mcp_service.utils.error_builder import (
            ChartErrorBuilder,
        )

        # Show at most three warnings in the summary line.
        summary = "; ".join(warnings[:3]) + ("..." if len(warnings) > 3 else "")
        return False, ChartErrorBuilder.build_error(
            error_type="runtime_semantic_warning",
            template_key="performance_warning",
            template_vars={"reason": summary},
            custom_suggestions=suggestions[:5],  # keep the list digestible
            error_code="RUNTIME_SEMANTIC_WARNING",
        )

    @staticmethod
    def _validate_format_compatibility(config: XYChartConfig) -> List[str]:
        """Collect warnings about format/type compatibility (best-effort)."""
        warnings: List[str] = []
        try:
            # Imported lazily to avoid circular imports.
            from .format_validator import FormatTypeValidator

            is_valid, format_warnings = (
                FormatTypeValidator.validate_format_compatibility(config)
            )
            if format_warnings:
                warnings.extend(format_warnings)
        except ImportError:
            logger.warning("Format validator not available")
        except Exception as e:
            logger.warning("Format validation failed: %s", e)
        return warnings

    @staticmethod
    def _validate_cardinality(
        config: XYChartConfig, dataset_id: int | str
    ) -> Tuple[List[str], List[str]]:
        """Check X-axis / group-by cardinality via the cardinality validator."""
        warnings: List[str] = []
        suggestions: List[str] = []
        try:
            # Imported lazily to avoid circular imports.
            from .cardinality_validator import CardinalityValidator

            # The threshold depends on the chart kind when one is declared.
            kind = getattr(config, "kind", "default")
            group_by = config.group_by.name if config.group_by else None
            is_ok, info = CardinalityValidator.check_cardinality(
                dataset_id=dataset_id,
                x_column=config.x.name,
                chart_type=kind,
                group_by_column=group_by,
            )
            if not is_ok and info:
                warnings.extend(info.get("warnings", []))
                suggestions.extend(info.get("suggestions", []))
        except ImportError:
            logger.warning("Cardinality validator not available")
        except Exception as e:
            logger.warning("Cardinality validation failed: %s", e)
        return warnings, suggestions

    @staticmethod
    def _validate_chart_type(
        config: ChartConfig, dataset_id: int | str
    ) -> Tuple[List[str], List[str]]:
        """Check whether the chosen chart type suits the data (best-effort)."""
        warnings: List[str] = []
        suggestions: List[str] = []
        try:
            # Imported lazily to avoid circular imports.
            from .chart_type_suggester import ChartTypeSuggester

            is_appropriate, info = ChartTypeSuggester.analyze_and_suggest(
                config, dataset_id
            )
            if not is_appropriate and info:
                warnings.extend(info.get("issues", []))
                suggestions.extend(info.get("suggestions", []))

                # Surface concrete alternatives when the suggester has any.
                recommended = info.get("recommended_types", [])
                if recommended:
                    recommendations = ", ".join(recommended)
                    suggestions.append(
                        f"Recommended chart types for this data: {recommendations}"
                    )
        except ImportError:
            logger.warning("Chart type suggester not available")
        except Exception as e:
            logger.warning("Chart type validation failed: %s", e)
        return warnings, suggestions
+ +""" +Cardinality validation to prevent unusable visualizations from high-cardinality data. +""" + +import logging +from typing import Any, Dict, List, Tuple + +logger = logging.getLogger(__name__) + + +class CardinalityValidator: + """ + Validates cardinality of dimensions to prevent charts with too many categories + that become unreadable or cause performance issues. + """ + + # Thresholds for different chart types + CARDINALITY_THRESHOLDS = { + "bar": 50, # Bar charts become unreadable with >50 bars + "line": 100, # Line charts can handle more points + "scatter": 500, # Scatter plots can show many points + "area": 30, # Area charts need fewer categories + "table": 1000, # Tables can handle many rows with pagination + "default": 50, # Conservative default + } + + # Known high-cardinality column patterns + HIGH_CARDINALITY_PATTERNS = [ + "id", + "uuid", + "guid", + "email", + "phone", + "address", + "session", + "transaction", + "order_number", + "invoice", + "timestamp", + "datetime", + "created_at", + "updated_at", + ] + + @staticmethod + def check_cardinality( + dataset_id: int | str, + x_column: str, + chart_type: str = "default", + group_by_column: str | None = None, + ) -> Tuple[bool, Dict[str, Any] | None]: + """ + Check cardinality of X-axis and group_by columns. 
+ + Returns: + Tuple of (is_ok, warning_info) + """ + try: + # Quick pattern check first (no DB query needed) + pattern_warnings = CardinalityValidator._check_column_patterns( + x_column, group_by_column + ) + + if pattern_warnings: + return False, { + "warnings": pattern_warnings, + "suggestions": CardinalityValidator._get_suggestions( + x_column, chart_type, pattern_based=True + ), + } + + # For non-pattern columns, we could do actual cardinality check + # but that requires DB access - for now just return OK + # In production, you'd want to cache cardinality stats + + return True, None + + except Exception as e: + logger.warning("Cardinality check failed: %s", e) + # Don't block on validation failures + return True, None + + @staticmethod + def _check_column_patterns( + x_column: str, group_by_column: str | None = None + ) -> List[str]: + """Check for known high-cardinality column patterns.""" + warnings = [] + + x_lower = x_column.lower() + + # Check X-axis column + for pattern in CardinalityValidator.HIGH_CARDINALITY_PATTERNS: + if pattern in x_lower: + warnings.append( + f"Column '{x_column}' appears to be a high-cardinality field " + f"(contains '{pattern}'). This may create an unreadable chart " + f"with too many categories on the X-axis." + ) + break + + # Check group_by column if present + if group_by_column: + group_lower = group_by_column.lower() + for pattern in CardinalityValidator.HIGH_CARDINALITY_PATTERNS: + if pattern in group_lower: + warnings.append( + f"Group by column '{group_by_column}' appears to be a " + f"high-cardinality field (contains '{pattern}'). This may " + f"create too many series to visualize effectively." 
+ ) + break + + return warnings + + @staticmethod + def _get_suggestions( + column: str, chart_type: str, pattern_based: bool = False + ) -> List[str]: + """Get suggestions for handling high cardinality.""" + suggestions = [] + + if pattern_based: + # Suggestions when we detected high-cardinality patterns + if any(p in column.lower() for p in ["id", "uuid", "guid"]): + suggestions.extend( + [ + "Consider using a different column for the X-axis", + f"If you need to analyze by {column}, use filters to limit " + f"the data", + "A table chart might be more appropriate for ID-based data", + ] + ) + elif any(p in column.lower() for p in ["email", "phone", "address"]): + suggestions.extend( + [ + "Consider grouping by a higher-level category (e.g., " + "domain for emails)", + f"Use filters to focus on specific {column} values", + "Aggregate the data before visualization", + ] + ) + elif any( + p in column.lower() for p in ["timestamp", "datetime", "created_at"] + ): + suggestions.extend( + [ + "Consider truncating timestamps to date or hour level", + "Use time-based grouping (daily, weekly, monthly)", + "Apply date range filters to limit the data", + ] + ) + else: + # General high-cardinality suggestions + threshold = CardinalityValidator.CARDINALITY_THRESHOLDS.get(chart_type, 50) + suggestions.extend( + [ + f"This chart type works best with fewer than {threshold} " + f"categories", + "Consider using filters to reduce the number of values", + "Try grouping or categorizing the data at a higher level", + "A table or pivot table might better display high-cardinality data", + ] + ) + + return suggestions + + @staticmethod + def suggest_chart_type(cardinality: int) -> List[str]: + """Suggest appropriate chart types based on cardinality.""" + if cardinality <= 10: + return ["bar", "pie", "donut", "area"] + elif cardinality <= 30: + return ["bar", "line", "area"] + elif cardinality <= 100: + return ["line", "scatter"] + else: + return ["table", "pivot_table", "heatmap"] diff 
"""
Chart type suggestions based on data characteristics and user intent.
"""

import logging
from typing import Any, Dict, List, Tuple

from superset.mcp_service.chart.schemas import (
    ChartConfig,
    ColumnRef,
    TableChartConfig,
    XYChartConfig,
)

logger = logging.getLogger(__name__)


class ChartTypeSuggester:
    """
    Suggests appropriate chart types based on data characteristics
    and identifies potential mismatches between chart type and data.
    """

    @staticmethod
    def analyze_and_suggest(
        config: ChartConfig,
        dataset_id: int | str,  # noqa: ARG002
    ) -> Tuple[bool, Dict[str, Any] | None]:
        """
        Analyze the configuration and, when the chart type looks wrong for the
        data, return the issues plus better-suited alternatives.

        Returns:
            Tuple of (is_appropriate, suggestion_info).
        """
        try:
            if isinstance(config, XYChartConfig):
                return ChartTypeSuggester._analyze_xy_chart(config)
            if isinstance(config, TableChartConfig):
                return ChartTypeSuggester._analyze_table_chart(config)
            return True, None
        except Exception as e:
            # Suggestions are advisory only; never fail the pipeline here.
            logger.warning("Chart type analysis failed: %s", e)
            return True, None

    @staticmethod
    def _analyze_xy_chart(
        config: XYChartConfig,
    ) -> Tuple[bool, Dict[str, Any] | None]:
        """Check whether an XY chart kind fits its axes."""
        x_analysis = ChartTypeSuggester._analyze_x_axis(config.x.name)
        y_analysis = ChartTypeSuggester._analyze_y_axis(config.y)

        # Kind-specific problems first, then generic advice.
        issues, suggestions = ChartTypeSuggester._check_chart_type_issues(
            config, x_analysis, y_analysis
        )
        suggestions = suggestions + ChartTypeSuggester._get_general_suggestions(
            x_analysis, y_analysis
        )

        if not issues:
            return True, None

        return False, {
            "issues": issues,
            "suggestions": suggestions,
            "recommended_types": ChartTypeSuggester._get_recommended_types(
                x_analysis["is_temporal"],
                x_analysis["is_categorical"],
                y_analysis["has_count"],
                y_analysis["num_metrics"],
            ),
        }

    @staticmethod
    def _analyze_x_axis(x_name: str) -> Dict[str, Any]:
        """Classify the X-axis column from its name (temporal/categorical/ID).

        Purely name-based heuristics -- no dataset metadata is consulted.
        """
        lowered = x_name.lower()
        temporal_terms = [
            "date",
            "time",
            "year",
            "month",
            "day",
            "hour",
            "created",
            "updated",
        ]
        categorical_terms = [
            "category",
            "type",
            "status",
            "department",
            "region",
            "country",
            "state",
        ]
        id_terms = ["id", "uuid", "guid", "key"]
        return {
            "is_temporal": any(term in lowered for term in temporal_terms),
            "is_categorical": any(term in lowered for term in categorical_terms),
            "is_id": any(term in lowered for term in id_terms),
            "name": x_name,
        }

    @staticmethod
    def _analyze_y_axis(y_columns: List[ColumnRef]) -> Dict[str, Any]:
        """Summarize the Y-axis metrics (count-style aggregates, how many)."""
        count_aggs = ("COUNT", "COUNT_DISTINCT")
        return {
            "has_count": any(col.aggregate in count_aggs for col in y_columns),
            "num_metrics": len(y_columns),
        }

    @staticmethod
    def _check_chart_type_issues(
        config: XYChartConfig, x_analysis: Dict[str, Any], y_analysis: Dict[str, Any]
    ) -> Tuple[List[str], List[str]]:
        """Dispatch to the check matching ``config.kind`` and return its output."""
        x_is_temporal = x_analysis["is_temporal"]
        x_is_categorical = x_analysis["is_categorical"]
        x_is_id = x_analysis["is_id"]
        num_metrics = y_analysis["num_metrics"]

        if config.kind == "line":
            return ChartTypeSuggester._check_line_chart_issues(
                config, x_is_temporal, x_is_categorical, x_is_id
            )
        if config.kind == "scatter":
            return ChartTypeSuggester._check_scatter_chart_issues(
                config, x_is_categorical, num_metrics
            )
        if config.kind == "area":
            return ChartTypeSuggester._check_area_chart_issues(config, x_is_temporal)
        if config.kind == "bar":
            return ChartTypeSuggester._check_bar_chart_issues(config, x_is_id)
        return [], []
issues.append( + f"Line chart with categorical X-axis '{config.x.name}' may not " + f"show meaningful trends" + ) + suggestions.extend( + [ + "Consider using a bar chart for categorical comparisons", + "Line charts work best with temporal or continuous data", + ] + ) + elif x_is_id: + issues.append( + f"Line chart with ID field '{config.x.name}' on X-axis will not " + f"show meaningful patterns" + ) + suggestions.extend( + [ + "Use a table to display individual records", + "Or aggregate the data by a meaningful dimension", + ] + ) + + return issues, suggestions + + @staticmethod + def _check_scatter_chart_issues( + config: XYChartConfig, x_is_categorical: bool, num_metrics: int + ) -> Tuple[List[str], List[str]]: + """Check scatter chart specific issues.""" + issues = [] + suggestions = [] + + if x_is_categorical: + issues.append( + f"Scatter plot with categorical X-axis '{config.x.name}' may not " + f"effectively show correlations" + ) + suggestions.extend( + [ + "Scatter plots work best with two continuous variables", + "Consider a bar chart for categorical vs numeric data", + ] + ) + if num_metrics > 1: + issues.append("Scatter plots with multiple Y metrics can be confusing") + suggestions.extend( + [ + "Consider using only one Y metric for clarity", + "Or use a line/bar chart to compare multiple metrics", + ] + ) + + return issues, suggestions + + @staticmethod + def _check_area_chart_issues( + config: XYChartConfig, x_is_temporal: bool + ) -> Tuple[List[str], List[str]]: + """Check area chart specific issues.""" + issues = [] + suggestions = [] + + if not x_is_temporal: + issues.append( + f"Area chart with non-temporal X-axis '{config.x.name}' may be " + f"misleading" + ) + suggestions.extend( + [ + "Area charts imply cumulative or part-to-whole relationships over " + "time", + "Consider a stacked bar chart for categorical data", + ] + ) + + # Check for potential negative values + for col in config.y: + if any(term in col.name.lower() for term in ["loss", 
"debt", "negative"]): + issues.append( + f"Area chart with potentially negative values in '{col.name}' " + f"can create visual confusion" + ) + suggestions.extend( + [ + "Use a line chart for data that can go negative", + "Or ensure all values are positive before using area chart", + ] + ) + + return issues, suggestions + + @staticmethod + def _check_bar_chart_issues( + config: XYChartConfig, x_is_id: bool + ) -> Tuple[List[str], List[str]]: + """Check bar chart specific issues.""" + issues = [] + suggestions = [] + + if x_is_id: + issues.append( + f"Bar chart with ID field '{config.x.name}' may create too many bars" + ) + suggestions.extend( + [ + "Consider aggregating by a higher-level category", + "Or use filters to limit the number of bars displayed", + ] + ) + + return issues, suggestions + + @staticmethod + def _get_general_suggestions( + x_analysis: Dict[str, Any], y_analysis: Dict[str, Any] + ) -> List[str]: + """Get general suggestions based on data patterns.""" + suggestions = [] + x_is_temporal = x_analysis["is_temporal"] + x_is_categorical = x_analysis["is_categorical"] + has_count = y_analysis["has_count"] + num_metrics = y_analysis["num_metrics"] + + if has_count and x_is_categorical: + suggestions.append( + "This looks like frequency analysis - bar charts work well for counts " + "by category" + ) + elif x_is_temporal and num_metrics == 1: + suggestions.append( + "Single metric over time - line charts are ideal for showing trends" + ) + elif x_is_temporal and num_metrics > 3: + suggestions.append( + "Many metrics over time - consider focusing on 2-3 key metrics for " + "clarity" + ) + + return suggestions + + @staticmethod + def _analyze_table_chart( + config: TableChartConfig, + ) -> Tuple[bool, Dict[str, Any] | None]: + """Analyze table chart appropriateness.""" + issues = [] + suggestions = [] + + # Count different column types + raw_columns = sum(1 for col in config.columns if not col.aggregate) + metric_columns = sum(1 for col in config.columns 
if col.aggregate) + total_columns = len(config.columns) + + # Check if data might be better visualized + if metric_columns > 0 and raw_columns <= 2: + # Mostly metrics with few dimensions - could be visualized + issues.append( + "Table with mostly aggregated metrics could be visualized as a chart" + ) + suggestions.append("Consider a bar chart to compare metric values visually") + suggestions.append("Or use a line chart if there's a time dimension") + + # Check for ID-heavy tables + id_columns = sum( + 1 + for col in config.columns + if any(i in col.name.lower() for i in ["id", "uuid", "guid", "key"]) + ) + if id_columns > total_columns / 2: + suggestions.append( + "Table appears to be ID-heavy - ensure this is for detailed record " + "inspection" + ) + suggestions.append( + "For analysis, consider aggregating by meaningful dimensions instead" + ) + + # Very wide tables + if total_columns > 10: + issues.append( + f"Table with {total_columns} columns may be difficult to read" + ) + suggestions.append("Consider showing only the most important columns") + suggestions.append("Or break into multiple focused views") + + if issues: + return False, { + "issues": issues, + "suggestions": suggestions, + "recommended_types": ["table", "pivot_table"] + if metric_columns > 0 + else ["table"], + } + + return True, None + + @staticmethod + def _get_recommended_types( + x_is_temporal: bool, x_is_categorical: bool, has_count: bool, num_metrics: int + ) -> List[str]: + """Get recommended chart types based on data characteristics.""" + recommendations = [] + + if x_is_temporal: + recommendations.extend(["line", "area", "bar"]) + if num_metrics == 1: + recommendations.append("scatter") # For trend analysis + elif x_is_categorical: + recommendations.extend(["bar", "table"]) + if has_count and num_metrics == 1: + recommendations.append("pie") # For proportion analysis + else: + # Continuous or unclear X-axis + recommendations.extend(["scatter", "line", "table"]) + + # Always include 
table as fallback + if "table" not in recommendations: + recommendations.append("table") + + return recommendations + + @staticmethod + def get_chart_type_description(chart_type: str) -> str: + """Get a description of when to use each chart type.""" + descriptions = { + "line": "Best for showing trends over time or continuous data", + "bar": "Ideal for comparing values across categories", + "area": "Shows cumulative totals and part-to-whole relationships over time", + "scatter": "Reveals correlations between two continuous variables", + "table": "Displays detailed data or many dimensions at once", + "pie": "Shows proportions of a whole (use sparingly, max 5-7 slices)", + "pivot_table": "Summarizes data across multiple dimensions", + } + return descriptions.get( + chart_type, f"Visualizes data using {chart_type} format" + ) diff --git a/superset/mcp_service/chart/validation/runtime/format_validator.py b/superset/mcp_service/chart/validation/runtime/format_validator.py new file mode 100644 index 00000000000..34971d47907 --- /dev/null +++ b/superset/mcp_service/chart/validation/runtime/format_validator.py @@ -0,0 +1,225 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Format-type compatibility validation to prevent misleading data presentation. +""" + +import logging +import re +from typing import List, Tuple + +from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig + +logger = logging.getLogger(__name__) + + +class FormatTypeValidator: + """ + Validates that format strings are appropriate for the data type and aggregation. + Prevents issues like currency formatting on COUNT data or percentage on absolute + values. + """ + + # Format patterns and their appropriate uses + CURRENCY_PATTERNS = [ + r"\$", # Dollar sign + r"โ‚ฌ", # Euro + r"ยฃ", # Pound + r"ยฅ", # Yen + r"[,.]2f", # Two decimal places (common for currency) + r"\$[,.]", # Dollar with thousands separator + ] + + PERCENTAGE_PATTERNS = [ + r"%", # Percentage sign + r"\.0%", # Percentage with no decimals + r"\.1%", # Percentage with 1 decimal + r"\.2%", # Percentage with 2 decimals + ] + + INTEGER_PATTERNS = [ + r"\.0f", # No decimals + r",d", # Integer with thousands separator + r"[,.]0f", # Integer format variations + ] + + @staticmethod + def validate_format_compatibility( + config: XYChartConfig, + ) -> Tuple[bool, List[str] | None]: + """ + Validate that axis formats are appropriate for the data types. 
+ + Returns: + Tuple of (is_valid, warnings_list) + """ + warnings = [] + + # Validate Y-axis format against metrics + if config.y_axis and config.y_axis.format: + y_warnings = FormatTypeValidator._validate_y_axis_format( + config.y_axis.format, config.y + ) + warnings.extend(y_warnings) + + # Validate X-axis format (usually temporal or categorical) + if config.x_axis and config.x_axis.format: + x_warnings = FormatTypeValidator._validate_x_axis_format( + config.x_axis.format, config.x + ) + warnings.extend(x_warnings) + + return len(warnings) == 0, warnings if warnings else None + + @staticmethod + def _validate_y_axis_format( + format_string: str, y_columns: List[ColumnRef] + ) -> List[str]: + """Validate Y-axis format against the metrics.""" + warnings = [] + + warnings.extend( + FormatTypeValidator._check_currency_format_issues(format_string, y_columns) + ) + warnings.extend( + FormatTypeValidator._check_percentage_format_issues( + format_string, y_columns + ) + ) + warnings.extend( + FormatTypeValidator._check_decimal_format_issues(format_string, y_columns) + ) + + return warnings + + @staticmethod + def _check_currency_format_issues( + format_string: str, y_columns: List[ColumnRef] + ) -> List[str]: + """Check for currency format issues.""" + warnings = [] + if FormatTypeValidator._is_currency_format(format_string): + for col in y_columns: + if col.aggregate in ["COUNT", "COUNT_DISTINCT"]: + warnings.append( + f"Currency format '{format_string}' applied to {col.aggregate} " + f"of '{col.name}'. COUNT operations return whole numbers, not " + f"currency values. Consider using integer format like '," + f"d' instead." 
+ ) + return warnings + + @staticmethod + def _check_percentage_format_issues( + format_string: str, y_columns: List[ColumnRef] + ) -> List[str]: + """Check for percentage format issues.""" + warnings = [] + if FormatTypeValidator._is_percentage_format(format_string): + for col in y_columns: + if col.aggregate in ["SUM", "COUNT", "COUNT_DISTINCT"]: + label = col.label or f"{col.aggregate}({col.name})" + warnings.append( + f"Percentage format '{format_string}' applied to " + f"{col.aggregate} of '{col.name}'. This will multiply values " + f"by 100 and add %. " + f"If '{label}' contains absolute values (not ratios 0-1), " + f"consider using a numeric format instead." + ) + return warnings + + @staticmethod + def _check_decimal_format_issues( + format_string: str, y_columns: List[ColumnRef] + ) -> List[str]: + """Check for decimal format issues.""" + warnings = [] + if "." in format_string and any(char.isdigit() for char in format_string): + decimal_places = FormatTypeValidator._get_decimal_places(format_string) + if decimal_places and decimal_places > 0: + for col in y_columns: + if col.aggregate in ["COUNT", "COUNT_DISTINCT"]: + warnings.append( + f"Decimal format '{format_string}' applied to " + f"{col.aggregate} of '{col.name}'. COUNT operations " + f"always return " + f"integers. Consider using integer format like '," + f"d' or '.0f' instead." + ) + return warnings + + @staticmethod + def _validate_x_axis_format(format_string: str, x_column: ColumnRef) -> List[str]: + """Validate X-axis format appropriateness.""" + warnings = [] + + # Currency format on X-axis is almost always wrong + if FormatTypeValidator._is_currency_format(format_string): + warnings.append( + f"Currency format '{format_string}' applied to X-axis '" + f"{x_column.name}'. " + f"X-axis typically shows categories, time, or dimensions, " + f"not currency. " + f"Consider removing the format or using a date/category format." 
+ ) + + # Percentage format on X-axis is unusual + elif FormatTypeValidator._is_percentage_format(format_string): + warnings.append( + f"Percentage format '{format_string}' applied to X-axis '" + f"{x_column.name}'. " + f"This is unusual for axis labels. Consider if this is intentional." + ) + + return warnings + + @staticmethod + def _is_currency_format(format_string: str) -> bool: + """Check if format string represents currency.""" + return any( + re.search(pattern, format_string, re.IGNORECASE) + for pattern in FormatTypeValidator.CURRENCY_PATTERNS + ) + + @staticmethod + def _is_percentage_format(format_string: str) -> bool: + """Check if format string represents percentage.""" + return any( + re.search(pattern, format_string) + for pattern in FormatTypeValidator.PERCENTAGE_PATTERNS + ) + + @staticmethod + def _get_decimal_places(format_string: str) -> int | None: + """Extract number of decimal places from format string.""" + if match := re.search(r"\.(\d+)f", format_string): + return int(match.group(1)) + return None + + @staticmethod + def suggest_format(column: ColumnRef) -> str: + """Suggest appropriate format based on column and aggregation.""" + if column.aggregate in ["COUNT", "COUNT_DISTINCT"]: + return ",d" # Integer with thousands separator + elif column.aggregate in ["AVG", "STDDEV", "VAR"]: + return ",.2f" # Two decimals for statistical measures + elif column.aggregate in ["SUM", "MIN", "MAX"]: + # Could be currency or regular number, default to flexible + return ",.2f" # Two decimals with thousands separator + else: + return "" # Let Superset decide diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py new file mode 100644 index 00000000000..9ba30c0bd95 --- /dev/null +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -0,0 +1,307 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
# See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Unified schema validation for chart configurations.
Consolidates pre-validation, schema validation, and error enhancement.

The validator runs cheap structural checks ("pre-validation") before handing
the payload to Pydantic, so the caller gets targeted, actionable error
messages instead of Pydantic's generic discriminated-union failures.
"""

import logging
from typing import Any, Dict, Tuple

from pydantic import ValidationError as PydanticValidationError

from superset.mcp_service.chart.schemas import (
    GenerateChartRequest,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError

logger = logging.getLogger(__name__)


class SchemaValidator:
    """Unified schema validator with pre-validation and enhanced error messages.

    All methods are static; the class is a namespace.  Every failure path
    returns a :class:`ChartGenerationError` populated with user-facing
    suggestions rather than raising.
    """

    @staticmethod
    def validate_request(
        request_data: Dict[str, Any],
    ) -> Tuple[bool, GenerateChartRequest | None, ChartGenerationError | None]:
        """
        Validate request data with pre-validation and enhanced error handling.

        Args:
            request_data: Raw (already JSON-decoded) request payload.

        Returns:
            Tuple of (is_valid, parsed_request, error).  Exactly one of
            ``parsed_request`` / ``error`` is non-None.
        """
        # Pre-validate to catch common issues early
        is_valid, error = SchemaValidator._pre_validate(request_data)
        if not is_valid:
            return False, None, error

        # Try Pydantic validation
        try:
            request = GenerateChartRequest(**request_data)
            return True, request, None
        except PydanticValidationError as e:
            # Enhance the error message
            error = SchemaValidator._enhance_validation_error(e, request_data)
            return False, None, error

    @staticmethod
    def _pre_validate(
        data: Dict[str, Any],
    ) -> Tuple[bool, ChartGenerationError | None]:
        """Pre-validate request data before Pydantic processing.

        Checks, in order: payload is a dict, required top-level fields
        exist, ``config`` is a dict, and ``chart_type`` is one of the
        supported values.  Returns on the first failure found.
        """
        if not isinstance(data, dict):
            return False, ChartGenerationError(
                error_type="invalid_request_format",
                message="Request must be a JSON object",
                details="The request body must be a valid JSON object, not a string "
                "or array",
                suggestions=[
                    "Ensure you're sending a JSON object with 'dataset_id' and "
                    "'config' fields",
                    "Check that Content-Type header is set to 'application/json'",
                ],
                error_code="INVALID_REQUEST_FORMAT",
            )

        # Check for required top-level fields
        if "dataset_id" not in data:
            return False, ChartGenerationError(
                error_type="missing_dataset_id",
                message="Missing required field: dataset_id",
                details="The 'dataset_id' field is required to identify which dataset "
                "to use",
                suggestions=[
                    "Add 'dataset_id' field with the ID of your dataset",
                    "Use list_datasets tool to find available dataset IDs",
                    "Example: {'dataset_id': 1, 'config': {...}}",
                ],
                error_code="MISSING_DATASET_ID",
            )

        if "config" not in data:
            return False, ChartGenerationError(
                error_type="missing_config",
                message="Missing required field: config",
                details="The 'config' field is required to specify chart configuration",
                suggestions=[
                    "Add 'config' field with chart type and settings",
                    "Example: {'dataset_id': 1, 'config': {'chart_type': 'xy', ...}}",
                ],
                error_code="MISSING_CONFIG",
            )

        config = data.get("config", {})
        if not isinstance(config, dict):
            return False, ChartGenerationError(
                error_type="invalid_config_format",
                message="Config must be a JSON object",
                details="The 'config' field must be a valid JSON object with chart "
                "settings",
                suggestions=[
                    "Ensure config is an object, not a string or array",
                    "Example: 'config': {'chart_type': 'xy', 'x': {...}, 'y': [...]}",
                ],
                error_code="INVALID_CONFIG_FORMAT",
            )

        # Check chart_type early
        chart_type = config.get("chart_type")
        if not chart_type:
            return False, ChartGenerationError(
                error_type="missing_chart_type",
                message="Missing required field: chart_type",
                details="Chart configuration must specify 'chart_type' as either 'xy' "
                "or 'table'",
                suggestions=[
                    "Add 'chart_type': 'xy' for line/bar/area/scatter charts",
                    "Add 'chart_type': 'table' for table visualizations",
                    "Example: 'config': {'chart_type': 'xy', ...}",
                ],
                error_code="MISSING_CHART_TYPE",
            )

        if chart_type not in ["xy", "table"]:
            return False, ChartGenerationError(
                error_type="invalid_chart_type",
                message=f"Invalid chart_type: '{chart_type}'",
                details=f"Chart type '{chart_type}' is not supported. Must be 'xy' or "
                f"'table'",
                suggestions=[
                    "Use 'chart_type': 'xy' for line, bar, area, or scatter charts",
                    "Use 'chart_type': 'table' for tabular data display",
                    "Check spelling and ensure lowercase",
                ],
                error_code="INVALID_CHART_TYPE",
            )

        # Pre-validate structure based on chart type
        if chart_type == "xy":
            return SchemaValidator._pre_validate_xy_config(config)
        elif chart_type == "table":
            return SchemaValidator._pre_validate_table_config(config)

        # Unreachable: chart_type is constrained to 'xy'/'table' above;
        # kept as a safe default.
        return True, None

    @staticmethod
    def _pre_validate_xy_config(
        config: Dict[str, Any],
    ) -> Tuple[bool, ChartGenerationError | None]:
        """Pre-validate XY chart configuration.

        Requires both 'x' and 'y' keys and that 'y' is a list; deeper field
        validation is left to Pydantic.
        """
        missing_fields = []

        if "x" not in config:
            missing_fields.append("'x' (X-axis column)")
        if "y" not in config:
            missing_fields.append("'y' (Y-axis metrics)")

        if missing_fields:
            return False, ChartGenerationError(
                error_type="missing_xy_fields",
                message=f"XY chart missing required "
                f"fields: {', '.join(missing_fields)}",
                details="XY charts require both X-axis (dimension) and Y-axis ("
                "metrics) specifications",
                suggestions=[
                    "Add 'x' field: {'name': 'column_name'} for X-axis",
                    "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] "
                    "for Y-axis",
                    "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, "
                    "'y': [{'name': 'sales', 'aggregate': 'SUM'}]}",
                ],
                error_code="MISSING_XY_FIELDS",
            )

        # Validate Y is a list
        if not isinstance(config.get("y", []), list):
            return False, ChartGenerationError(
                error_type="invalid_y_format",
                message="Y-axis must be a list of metrics",
                details="The 'y' field must be an array of metric specifications",
                suggestions=[
                    "Wrap Y-axis metric in array: 'y': [{'name': 'column', "
                    "'aggregate': 'SUM'}]",
                    "Multiple metrics supported: 'y': [metric1, metric2, ...]",
                ],
                error_code="INVALID_Y_FORMAT",
            )

        return True, None

    @staticmethod
    def _pre_validate_table_config(
        config: Dict[str, Any],
    ) -> Tuple[bool, ChartGenerationError | None]:
        """Pre-validate table chart configuration.

        Requires a 'columns' key holding a list; element contents are
        validated later by Pydantic.
        """
        if "columns" not in config:
            return False, ChartGenerationError(
                error_type="missing_columns",
                message="Table chart missing required field: columns",
                details="Table charts require a 'columns' array to specify which "
                "columns to display",
                suggestions=[
                    "Add 'columns' field with array of column specifications",
                    "Example: 'columns': [{'name': 'product'}, {'name': 'sales', "
                    "'aggregate': 'SUM'}]",
                    "Each column can have optional 'aggregate' for metrics",
                ],
                error_code="MISSING_COLUMNS",
            )

        if not isinstance(config.get("columns", []), list):
            return False, ChartGenerationError(
                error_type="invalid_columns_format",
                message="Columns must be a list",
                details="The 'columns' field must be an array of column specifications",
                suggestions=[
                    "Ensure columns is an array: 'columns': [...]",
                    "Each column should be an object with 'name' field",
                ],
                error_code="INVALID_COLUMNS_FORMAT",
            )

        return True, None

    @staticmethod
    def _enhance_validation_error(
        error: PydanticValidationError, request_data: Dict[str, Any]
    ) -> ChartGenerationError:
        """Convert Pydantic validation errors to user-friendly messages.

        Discriminated-union failures produce an unhelpful "'table' was
        expected" style message, so those are replaced with a chart-type
        specific explanation; anything else is summarized from the first
        three Pydantic errors.
        """
        errors = error.errors()

        # Check for discriminated union errors (generic "'table' was expected")
        for err in errors:
            if err.get("type") == "union_tag_invalid" or "discriminator" in str(
                err.get("ctx", {})
            ):
                # This is the generic union error - provide better message
                config = request_data.get("config", {})
                chart_type = config.get("chart_type", "unknown")

                if chart_type == "xy":
                    return ChartGenerationError(
                        error_type="xy_validation_error",
                        message="XY chart configuration validation failed",
                        details="The XY chart configuration is missing required "
                        "fields or has invalid structure",
                        suggestions=[
                            "Ensure 'x' field exists with {'name': 'column_name'}",
                            "Ensure 'y' field is an array: [{'name': 'metric', "
                            "'aggregate': 'SUM'}]",
                            "Check that all column names are strings",
                            "Verify aggregate functions are valid: SUM, COUNT, AVG, "
                            "MIN, MAX",
                        ],
                        error_code="XY_VALIDATION_ERROR",
                    )
                elif chart_type == "table":
                    return ChartGenerationError(
                        error_type="table_validation_error",
                        message="Table chart configuration validation failed",
                        details="The table chart configuration is missing required "
                        "fields or has invalid structure",
                        suggestions=[
                            "Ensure 'columns' field is an array of column "
                            "specifications",
                            "Each column needs {'name': 'column_name'}",
                            "Optional: add 'aggregate' for metrics",
                            "Example: 'columns': [{'name': 'product'}, {'name': "
                            "'sales', 'aggregate': 'SUM'}]",
                        ],
                        error_code="TABLE_VALIDATION_ERROR",
                    )

        # Default enhanced error
        error_details = []
        for err in errors[:3]:  # Show first 3 errors
            loc = " -> ".join(str(location) for location in err.get("loc", []))
            msg = err.get("msg", "Validation failed")
            error_details.append(f"{loc}: {msg}")

        return ChartGenerationError(
            error_type="validation_error",
            message="Chart configuration validation failed",
            details="; ".join(error_details),
            suggestions=[
                "Check that all required fields are present",
                "Ensure field types match the schema",
                "Use get_dataset_info to verify column names",
                "Refer to the API documentation for field requirements",
            ],
            error_code="VALIDATION_ERROR",
        )
class MCPCreateFormDataCommand(CreateFormDataCommand):
    """
    CreateFormDataCommand variant for the MCP service.

    The base command keys cached form data on the Flask session; MCP
    requests run outside a browser session, so this subclass keys on the
    authenticated user's id instead.
    """

    def _get_session_id(self) -> str:
        """Return the form-data cache key: the current user id as a string."""
        # The user id is the only stable per-caller identifier available in
        # the MCP execution context.
        current_user_id = get_user_id()
        return str(current_user_id)
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Enhanced error schemas for MCP chart generation with contextual information
"""

from typing import Any, Dict, List

from pydantic import BaseModel, Field


class ColumnSuggestion(BaseModel):
    """Suggested column with context.

    Emitted when a requested column name does not exist, ranked by name
    similarity so clients can offer "did you mean" alternatives.
    """

    name: str = Field(..., description="Column name")
    # NOTE: "type" intentionally shadows the builtin inside the class
    # namespace; harmless for a pydantic field definition.
    type: str = Field(..., description="Column data type")
    similarity_score: float = Field(..., description="Similarity score (0-1)")
    description: str | None = Field(None, description="Column description")


class ValidationError(BaseModel):
    """Individual validation error with context.

    Name intentionally mirrors (and locally shadows) pydantic's own
    ValidationError; this model describes one failed field in a chart
    request, not a pydantic exception.
    """

    field: str = Field(..., description="Field that failed validation")
    provided_value: Any = Field(..., description="Value that was provided")
    error_type: str = Field(..., description="Type of validation error")
    message: str = Field(..., description="Human-readable error message")
    suggestions: List[ColumnSuggestion] = Field(
        default_factory=list, description="Suggested alternatives"
    )


class DatasetContext(BaseModel):
    """Dataset information for error context.

    Attached to errors so clients can show the user what columns/metrics
    actually exist on the dataset they referenced.
    """

    id: int = Field(..., description="Dataset ID")
    table_name: str = Field(..., description="Table name")
    # NOTE(review): field name "schema" shadows BaseModel.schema() in
    # pydantic — confirm no deprecation warning/conflict is raised.
    schema: str | None = Field(None, description="Schema name")
    database_name: str = Field(..., description="Database name")
    available_columns: List[Dict[str, Any]] = Field(
        default_factory=list, description="Available columns with metadata"
    )
    available_metrics: List[Dict[str, Any]] = Field(
        default_factory=list, description="Available metrics with metadata"
    )


class ChartGenerationError(BaseModel):
    """Enhanced error response for chart generation failures.

    Carries a machine-readable code plus human-facing message, details and
    actionable suggestions; optional context blocks describe the dataset
    and the failed query.
    """

    error_type: str = Field(
        ..., description="Type of error (validation, execution, etc.)"
    )
    message: str = Field(..., description="High-level error message")
    details: str = Field(..., description="Detailed error explanation")
    validation_errors: List[ValidationError] = Field(
        default_factory=list, description="Specific field validation errors"
    )
    dataset_context: DatasetContext | None = Field(
        None, description="Dataset information for context"
    )
    query_info: Dict[str, Any] | None = Field(
        None, description="Query execution details"
    )
    suggestions: List[str] = Field(
        default_factory=list, description="Actionable suggestions to fix the error"
    )
    help_url: str | None = Field(
        None, description="URL to documentation for this error type"
    )
    error_code: str | None = Field(
        None, description="Unique error code for support reference"
    )


class ChartGenerationResponse(BaseModel):
    """Enhanced chart generation response with detailed error handling.

    Exactly one of `chart` / `error` is expected to be populated depending
    on `success`.
    """

    success: bool = Field(..., description="Whether chart generation succeeded")
    chart: Dict[str, Any] | None = Field(
        None, description="Chart information if successful"
    )
    error: ChartGenerationError | None = Field(
        None, description="Error details if failed"
    )
    performance: Dict[str, Any] | None = Field(None, description="Performance metadata")
    # Version markers let clients detect payload-shape changes.
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="API version")
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dashboard prompts for Superset MCP service""" + +# Import prompt modules here when created +# from . import create_dashboard_guided # noqa: F401 diff --git a/superset/mcp_service/dashboard/resources/__init__.py b/superset/mcp_service/dashboard/resources/__init__.py new file mode 100644 index 00000000000..e3c6ab17889 --- /dev/null +++ b/superset/mcp_service/dashboard/resources/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Dashboard resources for Superset MCP service""" + +# Import resource modules here when created +# from . import layout_templates # noqa: F401 diff --git a/superset/mcp_service/dashboard/schemas.py b/superset/mcp_service/dashboard/schemas.py new file mode 100644 index 00000000000..fd6bd632714 --- /dev/null +++ b/superset/mcp_service/dashboard/schemas.py @@ -0,0 +1,470 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Pydantic schemas for dashboard-related responses + +This module contains Pydantic models for serializing dashboard data +in a consistent and type-safe manner. 
+ +Example usage: + # For detailed dashboard info + dashboard_info = DashboardInfo( + id=1, + dashboard_title="Sales Dashboard", + published=True, + owners=[UserInfo(id=1, username="admin")], + charts=[ChartInfo(id=1, slice_name="Sales Chart")] + ) + + # For dashboard list responses + dashboard_list = DashboardList( + dashboards=[ + DashboardInfo( + id=1, + dashboard_title="Sales Dashboard", + published=True, + tags=[TagInfo(id=1, name="sales")] + ) + ], + count=1, + total_count=1, + page=0, + page_size=10, + total_pages=1, + has_next=False, + has_previous=False, + columns_requested=["id", "dashboard_title"], + columns_loaded=["id", "dashboard_title", "published"], + filters_applied={"published": True}, + pagination=PaginationInfo( + page=0, + page_size=10, + total_count=1, + total_pages=1, + has_next=False, + has_previous=False + ), + timestamp=datetime.now(timezone.utc) + ) +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Annotated, Any, Dict, List, Literal, TYPE_CHECKING + +from pydantic import BaseModel, ConfigDict, Field, model_validator, PositiveInt + +if TYPE_CHECKING: + from superset.models.dashboard import Dashboard + +from superset.daos.base import ColumnOperator, ColumnOperatorEnum +from superset.mcp_service.chart.schemas import ChartInfo, serialize_chart_object +from superset.mcp_service.common.cache_schemas import MetadataCacheControl +from superset.mcp_service.system.schemas import ( + PaginationInfo, + RoleInfo, + TagInfo, + UserInfo, +) + + +class DashboardError(BaseModel): + """Error response for dashboard operations""" + + error: str = Field(..., description="Error message") + error_type: str = Field(..., description="Type of error") + timestamp: str | datetime | None = Field(None, description="Error timestamp") + + model_config = ConfigDict(ser_json_timedelta="iso8601") + + @classmethod + def create(cls, error: str, error_type: str) -> "DashboardError": + """Create a standardized DashboardError with 
class DashboardError(BaseModel):
    """Error response for dashboard operations."""

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    # Populated automatically by `create()`; pre-formatted strings accepted.
    timestamp: str | datetime | None = Field(None, description="Error timestamp")

    model_config = ConfigDict(ser_json_timedelta="iso8601")

    @classmethod
    def create(cls, error: str, error_type: str) -> "DashboardError":
        """Create a standardized DashboardError stamped with the current time.

        Fix: use a timezone-aware UTC timestamp. The previous naive
        ``datetime.now()`` was ambiguous across server timezones and
        inconsistent with this module's docstring example, which uses
        ``datetime.now(timezone.utc)``.
        """
        from datetime import datetime, timezone

        return cls(
            error=error,
            error_type=error_type,
            timestamp=datetime.now(timezone.utc),
        )


def serialize_user_object(user: Any) -> UserInfo | None:
    """Serialize a user object to UserInfo; returns None for falsy input.

    Attribute access is guarded with ``getattr`` so partially-loaded or
    mock objects do not raise.
    """
    if not user:
        return None

    return UserInfo(
        id=getattr(user, "id", None),
        username=getattr(user, "username", None),
        first_name=getattr(user, "first_name", None),
        last_name=getattr(user, "last_name", None),
        email=getattr(user, "email", None),
        active=getattr(user, "active", None),
    )


def serialize_tag_object(tag: Any) -> TagInfo | None:
    """Serialize a tag object to TagInfo; returns None for falsy input."""
    if not tag:
        return None

    return TagInfo(
        id=getattr(tag, "id", None),
        name=getattr(tag, "name", None),
        type=getattr(tag, "type", None),
        description=getattr(tag, "description", None),
    )


def serialize_role_object(role: Any) -> RoleInfo | None:
    """Serialize a role object to RoleInfo; returns None for falsy input."""
    if not role:
        return None

    return RoleInfo(
        id=getattr(role, "id", None),
        name=getattr(role, "name", None),
        # Flatten permission names; None when the object exposes no
        # ``permissions`` relationship at all.
        permissions=[perm.name for perm in getattr(role, "permissions", [])]
        if hasattr(role, "permissions")
        else None,
    )
class DashboardFilter(ColumnOperator):
    """
    Filter object for dashboard listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    col: Literal[
        "dashboard_title",
        "published",
        "favorite",
    ] = Field(
        ...,
        description="Column to filter on. See get_dashboard_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use. See get_dashboard_available_filters for "
        "allowed values.",
    )
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )


class ListDashboardsRequest(MetadataCacheControl):
    """Request schema for list_dashboards with clear, unambiguous types.

    ``search`` and ``filters`` are mutually exclusive (enforced by the
    ``validate_search_and_filters`` model validator below).
    """

    filters: Annotated[
        List[DashboardFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' properties. "
            "Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "dashboard_title",
                "slug",
                "published",
                "changed_on",
                "created_on",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns "
            "if not specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search string to match against dashboard fields. "
            "Cannot be used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        str | None, Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="asc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    # Pagination is 1-based (PositiveInt rejects page=0).
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=10, description="Number of items per page")
    ]

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListDashboardsRequest":
        """Prevent using both search and filters simultaneously to avoid query
        conflicts."""
        if self.search and self.filters:
            raise ValueError(
                "Cannot use both 'search' and 'filters' parameters simultaneously. "
                "Use either 'search' for text-based searching across multiple fields, "
                "or 'filters' for precise column-based filtering, but not both."
            )
        return self
" + "Use either 'search' for text-based searching across multiple fields, " + "or 'filters' for precise column-based filtering, but not both." + ) + return self + + +class GetDashboardInfoRequest(MetadataCacheControl): + """Request schema for get_dashboard_info with support for ID, UUID, or slug.""" + + identifier: Annotated[ + int | str, + Field( + description="Dashboard identifier - can be numeric ID, UUID string, or slug" + ), + ] + + +class DashboardInfo(BaseModel): + id: int = Field(..., description="Dashboard ID") + dashboard_title: str = Field(..., description="Dashboard title") + slug: str | None = Field(None, description="Dashboard slug") + description: str | None = Field(None, description="Dashboard description") + css: str | None = Field(None, description="Custom CSS for the dashboard") + certified_by: str | None = Field(None, description="Who certified the dashboard") + certification_details: str | None = Field(None, description="Certification details") + json_metadata: str | None = Field( + None, description="Dashboard metadata (JSON string)" + ) + position_json: str | None = Field(None, description="Chart positions (JSON string)") + published: bool | None = Field( + None, description="Whether the dashboard is published" + ) + is_managed_externally: bool | None = Field( + None, description="Whether managed externally" + ) + external_url: str | None = Field(None, description="External URL") + created_on: str | datetime | None = Field(None, description="Creation timestamp") + changed_on: str | datetime | None = Field( + None, description="Last modification timestamp" + ) + created_by: str | None = Field(None, description="Dashboard creator (username)") + changed_by: str | None = Field(None, description="Last modifier (username)") + uuid: str | None = Field(None, description="Dashboard UUID (converted to string)") + url: str | None = Field(None, description="Dashboard URL") + thumbnail_url: str | None = Field(None, description="Thumbnail URL") + 
class DashboardList(BaseModel):
    """Paginated dashboard list response with query metadata."""

    dashboards: List[DashboardInfo]
    # Items on this page vs. total items matching the query.
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    columns_requested: List[str] | None = None
    columns_loaded: List[str] | None = None
    # NOTE(review): description says "filter dicts" but the declared type is
    # List[DashboardFilter] — confirm which is intended.
    filters_applied: List[DashboardFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")


class AddChartToDashboardRequest(BaseModel):
    """Request schema for adding a chart to an existing dashboard."""

    dashboard_id: int = Field(
        ..., description="ID of the dashboard to add the chart to"
    )
    chart_id: int = Field(..., description="ID of the chart to add to the dashboard")
    target_tab: str | None = Field(
        None, description="Target tab name (if dashboard has tabs)"
    )


class AddChartToDashboardResponse(BaseModel):
    """Response schema for adding chart to dashboard.

    Exactly one of ``dashboard`` (success) or ``error`` (failure) is set by
    the tool; this is not enforced by a validator.
    """

    dashboard: DashboardInfo | None = Field(
        None, description="The updated dashboard info, if successful"
    )
    dashboard_url: str | None = Field(
        None, description="URL to view the updated dashboard"
    )
    position: dict[str, Any] | None = Field(
        None, description="Position information for the added chart"
    )
    error: str | None = Field(None, description="Error message, if operation failed")
class GenerateDashboardRequest(BaseModel):
    """Request schema for generating a dashboard."""

    # min_length=1 rejects an empty chart list at validation time.
    chart_ids: List[int] = Field(
        ..., description="List of chart IDs to include in the dashboard", min_length=1
    )
    dashboard_title: str = Field(..., description="Title for the new dashboard")
    description: str | None = Field(None, description="Description for the dashboard")
    published: bool = Field(
        default=True, description="Whether to publish the dashboard"
    )


class GenerateDashboardResponse(BaseModel):
    """Response schema for dashboard generation."""

    dashboard: DashboardInfo | None = Field(
        None, description="The created dashboard info, if successful"
    )
    dashboard_url: str | None = Field(None, description="URL to view the dashboard")
    error: str | None = Field(None, description="Error message, if creation failed")


# TODO (Phase 3+): Add GetDashboardAvailableFiltersRequest for
# get_dashboard_available_filters tool
class DashboardAvailableFilters(BaseModel):
    """Map of filterable dashboard columns to their supported operators."""

    column_operators: Dict[str, Any] = Field(
        ..., description="Available filter operators and metadata for each column"
    )


class GetDashboardAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_dashboard_available_filters tool.

    Takes no parameters; extra="forbid" rejects stray arguments.
    """

    model_config = ConfigDict(extra="forbid", str_strip_whitespace=True)
+ """ + + model_config = ConfigDict(extra="forbid", str_strip_whitespace=True) + + +def dashboard_serializer(dashboard: "Dashboard") -> DashboardInfo: + return DashboardInfo( + id=dashboard.id, + dashboard_title=dashboard.dashboard_title or "Untitled", + slug=dashboard.slug or "", + description=dashboard.description, + css=dashboard.css, + certified_by=dashboard.certified_by, + certification_details=dashboard.certification_details, + json_metadata=dashboard.json_metadata, + position_json=dashboard.position_json, + published=dashboard.published, + is_managed_externally=dashboard.is_managed_externally, + external_url=dashboard.external_url, + created_on=dashboard.created_on, + changed_on=dashboard.changed_on, + created_by=getattr(dashboard.created_by, "username", None) + if dashboard.created_by + else None, + changed_by=getattr(dashboard.changed_by, "username", None) + if dashboard.changed_by + else None, + uuid=str(dashboard.uuid) if dashboard.uuid else None, + url=dashboard.url, + thumbnail_url=dashboard.thumbnail_url, + created_on_humanized=dashboard.created_on_humanized, + changed_on_humanized=dashboard.changed_on_humanized, + chart_count=len(dashboard.slices) if dashboard.slices else 0, + owners=[ + UserInfo.model_validate(owner, from_attributes=True) + for owner in dashboard.owners + ] + if dashboard.owners + else [], + tags=[ + TagInfo.model_validate(tag, from_attributes=True) for tag in dashboard.tags + ] + if dashboard.tags + else [], + roles=[ + RoleInfo.model_validate(role, from_attributes=True) + for role in dashboard.roles + ] + if dashboard.roles + else [], + charts=[serialize_chart_object(chart) for chart in dashboard.slices] + if dashboard.slices + else [], + ) + + +def serialize_dashboard_object(dashboard: Any) -> DashboardInfo: + """Simple dashboard serializer that safely handles object attributes.""" + return DashboardInfo( + id=getattr(dashboard, "id", None), + dashboard_title=getattr(dashboard, "dashboard_title", None), + slug=getattr(dashboard, 
"slug", None), + url=getattr(dashboard, "url", None), + published=getattr(dashboard, "published", None), + changed_by_name=getattr(dashboard, "changed_by_name", None), + changed_on=getattr(dashboard, "changed_on", None), + changed_on_humanized=getattr(dashboard, "changed_on_humanized", None), + created_by_name=getattr(dashboard, "created_by_name", None), + created_on=getattr(dashboard, "created_on", None), + created_on_humanized=getattr(dashboard, "created_on_humanized", None), + description=getattr(dashboard, "description", None), + css=getattr(dashboard, "css", None), + certified_by=getattr(dashboard, "certified_by", None), + certification_details=getattr(dashboard, "certification_details", None), + json_metadata=getattr(dashboard, "json_metadata", None), + position_json=getattr(dashboard, "position_json", None), + is_managed_externally=getattr(dashboard, "is_managed_externally", None), + external_url=getattr(dashboard, "external_url", None), + uuid=str(getattr(dashboard, "uuid", "")) + if getattr(dashboard, "uuid", None) + else None, + thumbnail_url=getattr(dashboard, "thumbnail_url", None), + chart_count=len(getattr(dashboard, "slices", [])), + owners=getattr(dashboard, "owners", []), + tags=getattr(dashboard, "tags", []), + roles=getattr(dashboard, "roles", []), + charts=[ + serialize_chart_object(chart) for chart in getattr(dashboard, "slices", []) + ] + if getattr(dashboard, "slices", None) + else [], + ) diff --git a/superset/mcp_service/dashboard/tool/__init__.py b/superset/mcp_service/dashboard/tool/__init__.py new file mode 100644 index 00000000000..7fa11d67196 --- /dev/null +++ b/superset/mcp_service/dashboard/tool/__init__.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from .add_chart_to_existing_dashboard import add_chart_to_existing_dashboard +from .generate_dashboard import generate_dashboard +from .get_dashboard_available_filters import get_dashboard_available_filters +from .get_dashboard_info import get_dashboard_info +from .list_dashboards import list_dashboards + +__all__ = [ + "list_dashboards", + "get_dashboard_info", + "generate_dashboard", + "add_chart_to_existing_dashboard", + "get_dashboard_available_filters", +] diff --git a/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py new file mode 100644 index 00000000000..3cd5f65a243 --- /dev/null +++ b/superset/mcp_service/dashboard/tool/add_chart_to_existing_dashboard.py @@ -0,0 +1,282 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: add_chart_to_existing_dashboard + +This tool adds a chart to an existing dashboard with automatic layout positioning. +""" + +import logging +from typing import Any, Dict + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dashboard.schemas import ( + AddChartToDashboardRequest, + AddChartToDashboardResponse, + DashboardInfo, +) +from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.utils import json + +logger = logging.getLogger(__name__) + + +def _find_next_row_position(layout: Dict[str, Any]) -> int: + """ + Find the next available row position in the dashboard layout. + + Returns: + Row index for the new chart + """ + # Find existing rows + row_indices = [] + for key in layout.keys(): + if key.startswith("ROW-") and key[4:].isdigit(): + row_indices.append(int(key[4:])) + + # Return next available row index + return max(row_indices) + 1 if row_indices else 0 + + +def _add_chart_to_layout( + layout: Dict[str, Any], chart: Any, chart_id: int, row_index: int +) -> tuple[str, str]: + """ + Add chart and row components to the dashboard layout. 
+ + Returns: + Tuple of (chart_key, row_key) + """ + chart_key = f"CHART-{chart_id}" + row_key = f"ROW-{row_index}" + chart_width = 5 # Balanced width for good proportions + chart_height = 50 # Good height for most chart types + + # Add chart to layout using proper Superset structure + layout[chart_key] = { + "children": [], + "id": chart_key, + "meta": { + "chartId": chart_id, + "height": chart_height, + "sliceName": chart.slice_name or f"Chart {chart_id}", + "uuid": str(chart.uuid) if chart.uuid else f"chart-{chart_id}", + "width": chart_width, + }, + "parents": ["ROOT_ID", "GRID_ID", row_key], + "type": "CHART", + } + + # Create row for the chart + layout[row_key] = { + "children": [chart_key], + "id": row_key, + "meta": {"background": "BACKGROUND_TRANSPARENT"}, + "parents": ["ROOT_ID", "GRID_ID"], + "type": "ROW", + } + + return chart_key, row_key + + +def _ensure_layout_structure(layout: Dict[str, Any], row_key: str) -> None: + """ + Ensure the dashboard layout has proper GRID and ROOT structure. 
+ """ + # Ensure GRID structure exists + if "GRID_ID" not in layout: + layout["GRID_ID"] = { + "children": [], + "id": "GRID_ID", + "parents": ["ROOT_ID"], + "type": "GRID", + } + + # Add row to GRID + if "children" not in layout["GRID_ID"]: + layout["GRID_ID"]["children"] = [] + layout["GRID_ID"]["children"].append(row_key) + + # Update ROOT_ID if it exists, or create it + if "ROOT_ID" in layout: + if "children" not in layout["ROOT_ID"]: + layout["ROOT_ID"]["children"] = [] + if "GRID_ID" not in layout["ROOT_ID"]["children"]: + layout["ROOT_ID"]["children"].append("GRID_ID") + else: + # Create ROOT_ID if it doesn't exist + layout["ROOT_ID"] = { + "children": ["GRID_ID"], + "id": "ROOT_ID", + "type": "ROOT", + } + + # Ensure dashboard version + if "DASHBOARD_VERSION_KEY" not in layout: + layout["DASHBOARD_VERSION_KEY"] = "v2" + + +@mcp.tool +@mcp_auth_hook +def add_chart_to_existing_dashboard( + request: AddChartToDashboardRequest, ctx: Context +) -> AddChartToDashboardResponse: + """ + Add chart to existing dashboard. Auto-positions in 2-column grid. + Returns updated dashboard info. 
+ """ + try: + from superset.commands.dashboard.update import UpdateDashboardCommand + from superset.daos.dashboard import DashboardDAO + + # Validate dashboard exists + dashboard = DashboardDAO.find_by_id(request.dashboard_id) + if not dashboard: + return AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + error=f"Dashboard with ID {request.dashboard_id} not found", + ) + + # Get chart object for SQLAlchemy relationships and validation + from superset import db + from superset.models.slice import Slice + + new_chart = db.session.get(Slice, request.chart_id) + if not new_chart: + return AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + error=f"Chart with ID {request.chart_id} not found", + ) + + # Check if chart is already in dashboard + current_chart_ids = [slice.id for slice in dashboard.slices] + if request.chart_id in current_chart_ids: + return AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + error=( + f"Chart {request.chart_id} is already in dashboard " + f"{request.dashboard_id}" + ), + ) + + # Parse current layout + try: + current_layout = json.loads(dashboard.position_json or "{}") + except (json.JSONDecodeError, TypeError): + current_layout = {} + + # Find position for new chart + row_index = _find_next_row_position(current_layout) + + # Add chart and row to layout + chart_key, row_key = _add_chart_to_layout( + current_layout, new_chart, request.chart_id, row_index + ) + + # Ensure proper layout structure + _ensure_layout_structure(current_layout, row_key) + + # Get chart objects for SQLAlchemy relationships + # Get existing chart objects + existing_chart_objects = dashboard.slices + + # Combine existing and new chart objects (new_chart was retrieved above) + all_chart_objects = list(existing_chart_objects) + [new_chart] + + # Prepare update data + update_data = { + "position_json": json.dumps(current_layout), + "slices": all_chart_objects, # 
Pass ORM objects, not IDs + } + + # Update the dashboard + command = UpdateDashboardCommand(request.dashboard_id, update_data) + updated_dashboard = command.run() + + # Convert to response format + from superset.mcp_service.dashboard.schemas import ( + serialize_tag_object, + serialize_user_object, + ) + + dashboard_info = DashboardInfo( + id=updated_dashboard.id, + dashboard_title=updated_dashboard.dashboard_title, + slug=updated_dashboard.slug, + description=updated_dashboard.description, + published=updated_dashboard.published, + created_on=updated_dashboard.created_on, + changed_on=updated_dashboard.changed_on, + created_by=updated_dashboard.created_by.username + if updated_dashboard.created_by + else None, + changed_by=updated_dashboard.changed_by.username + if updated_dashboard.changed_by + else None, + uuid=str(updated_dashboard.uuid) if updated_dashboard.uuid else None, + url=f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/", + chart_count=len(updated_dashboard.slices), + owners=[ + serialize_user_object(owner) + for owner in getattr(updated_dashboard, "owners", []) + if serialize_user_object(owner) is not None + ], + tags=[ + serialize_tag_object(tag) + for tag in getattr(updated_dashboard, "tags", []) + if serialize_tag_object(tag) is not None + ], + roles=[], + charts=[], + ) + + dashboard_url = ( + f"{get_superset_base_url()}/superset/dashboard/{updated_dashboard.id}/" + ) + + logger.info( + "Added chart %s to dashboard %s ", request.chart_id, request.dashboard_id + ) + + # Return position info for compatibility + position_info = {"row": row_index, "chart_key": chart_key, "row_key": row_key} + + return AddChartToDashboardResponse( + dashboard=dashboard_info, + dashboard_url=dashboard_url, + position=position_info, + error=None, + ) + + except Exception as e: + logger.error("Error adding chart to dashboard: %s", e) + return AddChartToDashboardResponse( + dashboard=None, + dashboard_url=None, + position=None, + error=f"Failed to add 
chart to dashboard: {str(e)}", + ) diff --git a/superset/mcp_service/dashboard/tool/generate_dashboard.py b/superset/mcp_service/dashboard/tool/generate_dashboard.py new file mode 100644 index 00000000000..c0d5fd517fb --- /dev/null +++ b/superset/mcp_service/dashboard/tool/generate_dashboard.py @@ -0,0 +1,241 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP tool: generate_dashboard + +This tool creates a new dashboard with specified charts and layout configuration. +""" + +import logging +from typing import Any, Dict, List + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dashboard.schemas import ( + DashboardInfo, + GenerateDashboardRequest, + GenerateDashboardResponse, +) +from superset.mcp_service.utils.url_utils import get_superset_base_url +from superset.utils import json + +logger = logging.getLogger(__name__) + + +def _create_dashboard_layout(chart_objects: List[Any]) -> Dict[str, Any]: + """ + Create a simple dashboard layout with charts arranged in a grid. + + This creates a basic 2-column layout where charts are arranged + vertically in alternating columns. 
+ + Args: + chart_objects: List of Chart ORM objects (not IDs) + """ + layout: Dict[str, Any] = {} + + # Grid configuration based on real Superset dashboard patterns + # Use 2-chart rows with medium-sized charts (like existing dashboards) + charts_per_row = 2 + chart_width = 5 # Balanced width for good proportions + chart_height = 50 # Good height for most chart types + + # Create rows with charts + row_ids = [] + for i in range(0, len(chart_objects), charts_per_row): + row_index = i // charts_per_row + row_id = f"ROW-{row_index}" + row_ids.append(row_id) + + # Get charts for this row (up to 2 charts like real dashboards) + row_charts = chart_objects[i : i + charts_per_row] + chart_keys = [] + + for chart in row_charts: + chart_key = f"CHART-{chart.id}" + chart_keys.append(chart_key) + + # Create chart component with standard dimensions + layout[chart_key] = { + "children": [], + "id": chart_key, + "meta": { + "chartId": chart.id, + "height": chart_height, + "sliceName": chart.slice_name or f"Chart {chart.id}", + "uuid": str(chart.uuid) if chart.uuid else f"chart-{chart.id}", + "width": chart_width, + }, + "parents": ["ROOT_ID", "GRID_ID", row_id], + "type": "CHART", + } + + # Create row containing the charts + layout[row_id] = { + "children": chart_keys, + "id": row_id, + "meta": {"background": "BACKGROUND_TRANSPARENT"}, + "parents": ["ROOT_ID", "GRID_ID"], + "type": "ROW", + } + + # Add GRID container + layout["GRID_ID"] = { + "children": row_ids, + "id": "GRID_ID", + "parents": ["ROOT_ID"], + "type": "GRID", + } + + # Add root layout container + layout["ROOT_ID"] = { + "children": ["GRID_ID"], + "id": "ROOT_ID", + "type": "ROOT", + } + + # Add dashboard version + layout["DASHBOARD_VERSION_KEY"] = "v2" + + return layout + + +@mcp.tool +@mcp_auth_hook +def generate_dashboard( + request: GenerateDashboardRequest, ctx: Context +) -> GenerateDashboardResponse: + """Create dashboard from chart IDs. 
@mcp.tool
@mcp_auth_hook
def generate_dashboard(
    request: GenerateDashboardRequest, ctx: Context
) -> GenerateDashboardResponse:
    """Create dashboard from chart IDs.

    IMPORTANT:
    - All charts must exist and be accessible to current user
    - Charts arranged automatically in 2-column grid layout

    Returns:
        - Dashboard ID and URL

    Failures (missing charts, command errors) are reported via the
    response's ``error`` field rather than raised.
    """
    try:
        # Get chart objects from IDs (required for SQLAlchemy relationships)
        from superset import db
        from superset.commands.dashboard.create import CreateDashboardCommand
        from superset.models.slice import Slice

        chart_objects = (
            db.session.query(Slice).filter(Slice.id.in_(request.chart_ids)).all()
        )
        found_chart_ids = [chart.id for chart in chart_objects]

        # Check if all requested charts were found
        missing_chart_ids = set(request.chart_ids) - set(found_chart_ids)
        if missing_chart_ids:
            return GenerateDashboardResponse(
                dashboard=None,
                dashboard_url=None,
                error=f"Charts not found: {list(missing_chart_ids)}",
            )

        # Create dashboard layout with chart objects
        layout = _create_dashboard_layout(chart_objects)

        # Prepare dashboard data; json_metadata starts from an empty,
        # cross-filter-disabled baseline configuration.
        dashboard_data = {
            "dashboard_title": request.dashboard_title,
            "slug": None,  # Let Superset auto-generate slug
            "css": "",
            "json_metadata": json.dumps(
                {
                    "filter_scopes": {},
                    "expanded_slices": {},
                    "refresh_frequency": 0,
                    "timed_refresh_immune_slices": [],
                    "color_scheme": None,
                    "label_colors": {},
                    "shared_label_colors": {},
                    "color_scheme_domain": [],
                    "cross_filters_enabled": False,
                    "native_filter_configuration": [],
                    "global_chart_configuration": {
                        "scope": {"rootPath": ["ROOT_ID"], "excluded": []}
                    },
                    "chart_configuration": {},
                }
            ),
            "position_json": json.dumps(layout),
            "published": request.published,
            "slices": chart_objects,  # Pass ORM objects, not IDs
        }

        if request.description:
            dashboard_data["description"] = request.description

        # Create the dashboard using Superset's command pattern
        command = CreateDashboardCommand(dashboard_data)
        dashboard = command.run()

        # Convert to our response format
        # Local import avoids a circular import with the schemas module.
        from superset.mcp_service.dashboard.schemas import (
            serialize_tag_object,
            serialize_user_object,
        )

        dashboard_info = DashboardInfo(
            id=dashboard.id,
            dashboard_title=dashboard.dashboard_title,
            slug=dashboard.slug,
            description=dashboard.description,
            published=dashboard.published,
            created_on=dashboard.created_on,
            changed_on=dashboard.changed_on,
            created_by=dashboard.created_by.username if dashboard.created_by else None,
            changed_by=dashboard.changed_by.username if dashboard.changed_by else None,
            uuid=str(dashboard.uuid) if dashboard.uuid else None,
            url=f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/",
            chart_count=len(request.chart_ids),
            # NOTE(review): serializers are called twice per item
            # (filter + value) — could be hoisted.
            owners=[
                serialize_user_object(owner)
                for owner in getattr(dashboard, "owners", [])
                if serialize_user_object(owner) is not None
            ],
            tags=[
                serialize_tag_object(tag)
                for tag in getattr(dashboard, "tags", [])
                if serialize_tag_object(tag) is not None
            ],
            roles=[],  # Dashboard roles not typically set at creation
            charts=[],  # Chart details not needed in response
        )

        dashboard_url = f"{get_superset_base_url()}/superset/dashboard/{dashboard.id}/"

        logger.info(
            "Created dashboard %s with %s charts", dashboard.id, len(request.chart_ids)
        )

        return GenerateDashboardResponse(
            dashboard=dashboard_info, dashboard_url=dashboard_url, error=None
        )

    except Exception as e:
        # Tool boundary: convert any failure into a structured error response.
        logger.error("Error creating dashboard: %s", e)
        return GenerateDashboardResponse(
            dashboard=None,
            dashboard_url=None,
            error=f"Failed to create dashboard: {str(e)}",
        )
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Get available filters FastMCP tool +""" + +import logging + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dashboard.schemas import ( + DashboardAvailableFilters, + GetDashboardAvailableFiltersRequest, +) +from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def get_dashboard_available_filters( + request: GetDashboardAvailableFiltersRequest, ctx: Context +) -> DashboardAvailableFilters: + """Get available dashboard filter fields and operators.""" + from superset.daos.dashboard import DashboardDAO + + tool = ModelGetAvailableFiltersCore( + dao_class=DashboardDAO, + output_schema=DashboardAvailableFilters, + logger=logger, + ) + return tool.run_tool() diff --git a/superset/mcp_service/dashboard/tool/get_dashboard_info.py b/superset/mcp_service/dashboard/tool/get_dashboard_info.py new file mode 100644 index 00000000000..9a3bfa23b8b --- /dev/null +++ b/superset/mcp_service/dashboard/tool/get_dashboard_info.py @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
@mcp.tool
@mcp_auth_hook
async def get_dashboard_info(
    request: GetDashboardInfoRequest, ctx: Context
) -> DashboardInfo | DashboardError:
    """
    Get dashboard metadata by ID, UUID, or slug.

    Returns title, charts, and layout details.
    """
    await ctx.info(f"Retrieving dashboard information: {request.identifier}")
    await ctx.debug(
        f"Metadata cache settings: use_cache={request.use_cache}, "
        f"refresh_metadata={request.refresh_metadata}, "
        f"force_refresh={request.force_refresh}"
    )

    try:
        from superset.daos.dashboard import DashboardDAO

        # Generic "get info" core, configured for dashboards.
        info_tool = ModelGetInfoCore(
            dao_class=DashboardDAO,
            output_schema=DashboardInfo,
            error_schema=DashboardError,
            serializer=dashboard_serializer,
            supports_slug=True,  # Dashboards support slugs
            logger=logger,
        )

        result = info_tool.run_tool(request.identifier)

        if isinstance(result, DashboardInfo):
            await ctx.info(
                f"Dashboard information retrieved successfully: id={result.id}, "
                f"title={result.dashboard_title}, "
                f"chart_count={result.chart_count}, published={result.published}"
            )
        else:
            await ctx.warning(
                f"Dashboard retrieval failed: error_type={result.error_type}, "
                f"error={result.error}"
            )

        return result

    except Exception as exc:
        await ctx.error(
            f"Dashboard information retrieval failed: "
            f"identifier={request.identifier}, error={exc}, "
            f"error_type={type(exc).__name__}"
        )
        return DashboardError(
            error=f"Failed to get dashboard info: {exc}",
            error_type="InternalError",
            timestamp=datetime.now(timezone.utc),
        )
# Columns returned by default when the caller does not pass select_columns.
DEFAULT_DASHBOARD_COLUMNS = [
    "id",
    "dashboard_title",
    "slug",
    "uuid",
    "published",
    "changed_on",
    "created_on",
]

# Columns accepted by order_column (also listed in the tool docstring below).
SORTABLE_DASHBOARD_COLUMNS = [
    "id",
    "dashboard_title",
    "slug",
    "published",
    "changed_on",
    "created_on",
]


@mcp.tool
@mcp_auth_hook
async def list_dashboards(
    request: ListDashboardsRequest, ctx: Context
) -> DashboardList:
    """List dashboards with filtering and search. Returns dashboard metadata
    including title, slug, and charts.

    Sortable columns for order_column: id, dashboard_title, slug, published,
    changed_on, created_on
    """

    # DAO imported at call time, matching the other mcp_service tools.
    from superset.daos.dashboard import DashboardDAO

    tool = ModelListCore(
        dao_class=DashboardDAO,
        output_schema=DashboardInfo,
        # The `cols` argument is ignored here: every dashboard is serialized
        # with the full serializer regardless of selected columns.
        item_serializer=lambda obj, cols: serialize_dashboard_object(obj),
        filter_type=DashboardFilter,
        default_columns=DEFAULT_DASHBOARD_COLUMNS,
        # Free-text `search` matches against these columns.
        search_columns=[
            "dashboard_title",
            "slug",
            "uuid",
        ],
        list_field_name="dashboards",
        output_list_schema=DashboardList,
        logger=logger,
    )
    return tool.run_tool(
        filters=request.filters,
        search=request.search,
        select_columns=request.select_columns,
        order_column=request.order_column,
        order_direction=request.order_direction,
        # Request pages are 1-based; the core expects 0-based pages.
        page=max(request.page - 1, 0),
        page_size=request.page_size,
    )
import explore_dataset # noqa: F401 diff --git a/superset/mcp_service/dataset/resources/__init__.py b/superset/mcp_service/dataset/resources/__init__.py new file mode 100644 index 00000000000..7928964e41f --- /dev/null +++ b/superset/mcp_service/dataset/resources/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Dataset resources for Superset MCP service""" + +# Import resource modules here when created +# from . import column_suggestions # noqa: F401 diff --git a/superset/mcp_service/dataset/schemas.py b/superset/mcp_service/dataset/schemas.py new file mode 100644 index 00000000000..e7a6adc1300 --- /dev/null +++ b/superset/mcp_service/dataset/schemas.py @@ -0,0 +1,349 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
class GetDatasetAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_dataset_available_filters tool.

    Currently has no parameters but provides consistent API for future extensibility.
    """

    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )


class DatasetAvailableFilters(BaseModel):
    """Response schema: mapping of each filterable column to its operators."""

    column_operators: Dict[str, List[str]] = Field(
        ...,
        description="Available filter operators for each column: mapping from column "
        "name to list of supported operators",
    )


class DatasetFilter(ColumnOperator):
    """
    Filter object for dataset listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    col: Literal[
        "table_name",
        "schema",
        "owner",
        "favorite",
    ] = Field(
        ...,
        description="Column to filter on. See get_dataset_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        ...,
        description="Operator to use. See get_dataset_available_filters for "
        "allowed values.",
    )
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        ..., description="Value to filter by (type depends on col and opr)"
    )


class TableColumnInfo(BaseModel):
    """Lightweight view of a dataset column."""

    column_name: str = Field(..., description="Column name")
    verbose_name: str | None = Field(None, description="Verbose name")
    type: str | None = Field(None, description="Column type")
    is_dttm: bool | None = Field(None, description="Is datetime column")
    groupby: bool | None = Field(None, description="Is groupable")
    filterable: bool | None = Field(None, description="Is filterable")
    description: str | None = Field(None, description="Column description")


class SqlMetricInfo(BaseModel):
    """Lightweight view of a dataset metric (SQL expression)."""

    metric_name: str = Field(..., description="Metric name")
    verbose_name: str | None = Field(None, description="Verbose name")
    expression: str | None = Field(None, description="SQL expression")
    description: str | None = Field(None, description="Metric description")
    d3format: str | None = Field(None, description="D3 format string")


class DatasetInfo(BaseModel):
    """Full dataset metadata returned by the dataset tools."""

    id: int | None = Field(None, description="Dataset ID")
    table_name: str | None = Field(None, description="Table name")
    schema_name: str | None = Field(None, description="Schema name", alias="schema")
    database_name: str | None = Field(None, description="Database name")
    description: str | None = Field(None, description="Dataset description")
    changed_by: str | None = Field(None, description="Last modifier (username)")
    changed_on: str | datetime | None = Field(
        None, description="Last modification timestamp"
    )
    changed_on_humanized: str | None = Field(
        None, description="Humanized modification time"
    )
    created_by: str | None = Field(None, description="Dataset creator (username)")
    created_on: str | datetime | None = Field(None, description="Creation timestamp")
    created_on_humanized: str | None = Field(
        None, description="Humanized creation time"
    )
    tags: List[TagInfo] = Field(default_factory=list, description="Dataset tags")
    owners: List[UserInfo] = Field(
        default_factory=list, description="DatasetInfo owners"
    )
    is_virtual: bool | None = Field(
        None, description="Whether the dataset is virtual (uses SQL)"
    )
    database_id: int | None = Field(None, description="Database ID")
    uuid: str | None = Field(None, description="Dataset UUID")
    schema_perm: str | None = Field(None, description="Schema permission string")
    url: str | None = Field(None, description="Dataset URL")
    sql: str | None = Field(None, description="SQL for virtual datasets")
    main_dttm_col: str | None = Field(None, description="Main datetime column")
    offset: int | None = Field(None, description="Offset")
    cache_timeout: int | None = Field(None, description="Cache timeout")
    params: Dict[str, Any | None] | None = Field(None, description="Extra params")
    template_params: Dict[str, Any | None] | None = Field(
        None, description="Template params"
    )
    extra: Dict[str, Any | None] | None = Field(None, description="Extra metadata")
    columns: List[TableColumnInfo] = Field(
        default_factory=list, description="Columns in the dataset"
    )
    metrics: List[SqlMetricInfo] = Field(
        default_factory=list, description="Metrics in the dataset"
    )
    is_favorite: bool | None = Field(
        None, description="Whether this dataset is favorited by the current user"
    )
    model_config = ConfigDict(
        from_attributes=True,
        ser_json_timedelta="iso8601",
        populate_by_name=True,  # Allow both 'schema' (alias) and 'schema_name' (field)
    )


class DatasetList(BaseModel):
    """Paginated dataset listing response."""

    datasets: List[DatasetInfo]
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    columns_requested: List[str] | None = None
    columns_loaded: List[str] | None = None
    filters_applied: List[DatasetFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")


class ListDatasetsRequest(MetadataCacheControl):
    """Request schema for list_datasets with clear, unambiguous types."""

    filters: Annotated[
        List[DatasetFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "table_name",
                "schema",
                "database_name",
                "changed_by_name",
                "changed_on",
                "created_by_name",
                "created_on",
                "metrics",
                "columns",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search string to match against dataset fields. Cannot "
            "be used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        str | None, Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="desc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=10, description="Number of items per page")
    ]

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListDatasetsRequest":
        """Prevent using both search and filters simultaneously to avoid query
        conflicts."""
        if self.search and self.filters:
            raise ValueError(
                "Cannot use both 'search' and 'filters' parameters simultaneously. "
                "Use either 'search' for text-based searching across multiple fields, "
                "or 'filters' for precise column-based filtering, but not both."
            )
        return self


class DatasetError(BaseModel):
    """Standardized error payload for dataset tools."""

    error: str = Field(..., description="Error message")
    error_type: str = Field(..., description="Type of error")
    timestamp: str | datetime | None = Field(None, description="Error timestamp")
    model_config = ConfigDict(ser_json_timedelta="iso8601")

    @classmethod
    def create(cls, error: str, error_type: str) -> "DatasetError":
        """Create a standardized DatasetError with a timezone-aware timestamp."""
        # Use an aware UTC timestamp for consistency with the dataset tools,
        # which build DatasetError with datetime.now(timezone.utc); a naive
        # datetime.now() here would mix naive and aware timestamps.
        from datetime import datetime, timezone

        return cls(
            error=error,
            error_type=error_type,
            timestamp=datetime.now(timezone.utc),
        )


class GetDatasetInfoRequest(MetadataCacheControl):
    """Request schema for get_dataset_info with support for ID or UUID."""

    identifier: Annotated[
        int | str,
        Field(description="Dataset identifier - can be numeric ID or UUID string"),
    ]


def serialize_dataset_object(dataset: Any) -> DatasetInfo | None:
    """Convert a dataset ORM object into a DatasetInfo schema.

    Returns None for falsy input. Missing attributes degrade to None rather
    than raising, so partially-loaded ORM objects serialize best-effort.
    """
    if not dataset:
        return None
    # `params` may be stored as a JSON string; decode it, tolerating garbage.
    params = getattr(dataset, "params", None)
    if isinstance(params, str):
        try:
            params = json.loads(params)
        except Exception:
            params = None
    columns = [
        TableColumnInfo(
            column_name=getattr(col, "column_name", None),
            verbose_name=getattr(col, "verbose_name", None),
            type=getattr(col, "type", None),
            is_dttm=getattr(col, "is_dttm", None),
            groupby=getattr(col, "groupby", None),
            filterable=getattr(col, "filterable", None),
            description=getattr(col, "description", None),
        )
        for col in getattr(dataset, "columns", [])
    ]
    metrics = [
        SqlMetricInfo(
            metric_name=getattr(metric, "metric_name", None),
            verbose_name=getattr(metric, "verbose_name", None),
            expression=getattr(metric, "expression", None),
            description=getattr(metric, "description", None),
            d3format=getattr(metric, "d3format", None),
        )
        for metric in getattr(dataset, "metrics", [])
    ]
    return DatasetInfo(
        id=getattr(dataset, "id", None),
        table_name=getattr(dataset, "table_name", None),
        schema_name=getattr(dataset, "schema", None),
        database_name=getattr(dataset.database, "database_name", None)
        if getattr(dataset, "database", None)
        else None,
        description=getattr(dataset, "description", None),
        # Prefer the precomputed *_name attributes; fall back to str() of the
        # related user object when only the relationship is loaded.
        changed_by=getattr(dataset, "changed_by_name", None)
        or (str(dataset.changed_by) if getattr(dataset, "changed_by", None) else None),
        changed_on=getattr(dataset, "changed_on", None),
        changed_on_humanized=getattr(dataset, "changed_on_humanized", None),
        created_by=getattr(dataset, "created_by_name", None)
        or (str(dataset.created_by) if getattr(dataset, "created_by", None) else None),
        created_on=getattr(dataset, "created_on", None),
        created_on_humanized=getattr(dataset, "created_on_humanized", None),
        tags=[
            TagInfo.model_validate(tag, from_attributes=True)
            for tag in getattr(dataset, "tags", [])
        ]
        if getattr(dataset, "tags", None)
        else [],
        owners=[
            UserInfo.model_validate(owner, from_attributes=True)
            for owner in getattr(dataset, "owners", [])
        ]
        if getattr(dataset, "owners", None)
        else [],
        is_virtual=getattr(dataset, "is_virtual", None),
        database_id=getattr(dataset, "database_id", None),
        uuid=str(getattr(dataset, "uuid", ""))
        if getattr(dataset, "uuid", None)
        else None,
        schema_perm=getattr(dataset, "schema_perm", None),
        url=getattr(dataset, "url", None),
        sql=getattr(dataset, "sql", None),
        main_dttm_col=getattr(dataset, "main_dttm_col", None),
        offset=getattr(dataset, "offset", None),
        cache_timeout=getattr(dataset, "cache_timeout", None),
        params=params,
        template_params=getattr(dataset, "template_params", None),
        extra=getattr(dataset, "extra", None),
        columns=columns,
        metrics=metrics,
        is_favorite=getattr(dataset, "is_favorite", None),
    )
+ +from .get_dataset_available_filters import get_dataset_available_filters +from .get_dataset_info import get_dataset_info +from .list_datasets import list_datasets + +__all__ = [ + "list_datasets", + "get_dataset_info", + "get_dataset_available_filters", +] diff --git a/superset/mcp_service/dataset/tool/get_dataset_available_filters.py b/superset/mcp_service/dataset/tool/get_dataset_available_filters.py new file mode 100644 index 00000000000..4072f05f3c3 --- /dev/null +++ b/superset/mcp_service/dataset/tool/get_dataset_available_filters.py @@ -0,0 +1,49 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+""" +Get available dataset filters FastMCP tool +""" + +import logging + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dataset.schemas import ( + DatasetAvailableFilters, + GetDatasetAvailableFiltersRequest, +) +from superset.mcp_service.mcp_core import ModelGetAvailableFiltersCore + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def get_dataset_available_filters( + request: GetDatasetAvailableFiltersRequest, ctx: Context +) -> DatasetAvailableFilters: + """Get available dataset filter fields and operators.""" + from superset.daos.dataset import DatasetDAO + + tool = ModelGetAvailableFiltersCore( + dao_class=DatasetDAO, + output_schema=DatasetAvailableFilters, + logger=logger, + ) + return tool.run_tool() diff --git a/superset/mcp_service/dataset/tool/get_dataset_info.py b/superset/mcp_service/dataset/tool/get_dataset_info.py new file mode 100644 index 00000000000..a4b83030e1a --- /dev/null +++ b/superset/mcp_service/dataset/tool/get_dataset_info.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Get dataset info FastMCP tool + +This module contains the FastMCP tool for getting detailed information +about a specific dataset. +""" + +import logging +from datetime import datetime, timezone + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dataset.schemas import ( + DatasetError, + DatasetInfo, + GetDatasetInfoRequest, + serialize_dataset_object, +) +from superset.mcp_service.mcp_core import ModelGetInfoCore + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def get_dataset_info( + request: GetDatasetInfoRequest, ctx: Context +) -> DatasetInfo | DatasetError: + """Get dataset metadata by ID or UUID. + + Returns columns, metrics, and schema details. + """ + await ctx.info( + "Retrieving dataset information: identifier=%s" % (request.identifier,) + ) + await ctx.debug( + "Metadata cache settings: use_cache=%s refresh_metadata=%s force_refresh=%s" + % ( + request.use_cache, + request.refresh_metadata, + request.force_refresh, + ) + ) + + try: + from superset.daos.dataset import DatasetDAO + + tool = ModelGetInfoCore( + dao_class=DatasetDAO, + output_schema=DatasetInfo, + error_schema=DatasetError, + serializer=serialize_dataset_object, + supports_slug=False, # Datasets don't have slugs + logger=logger, + ) + + result = tool.run_tool(request.identifier) + + if isinstance(result, DatasetInfo): + await ctx.info( + "Dataset information retrieved successfully: " + "dataset_id=%s, table_name=%s, columns_count=%s, metrics_count=%s" + % ( + result.id, + result.table_name, + len(result.columns) if result.columns else 0, + len(result.metrics) if result.metrics else 0, + ) + ) + else: + await ctx.warning( + "Dataset retrieval failed: error_type=%s, error=%s" + % (result.error_type, result.error) + ) + + return result + + except Exception as e: + await ctx.error( + "Dataset information retrieval failed: identifier=%s, error=%s, " + 
"error_type=%s" + % ( + request.identifier, + str(e), + type(e).__name__, + ) + ) + return DatasetError( + error=f"Failed to get dataset info: {str(e)}", + error_type="InternalError", + timestamp=datetime.now(timezone.utc), + ) diff --git a/superset/mcp_service/dataset/tool/list_datasets.py b/superset/mcp_service/dataset/tool/list_datasets.py new file mode 100644 index 00000000000..32edd6def36 --- /dev/null +++ b/superset/mcp_service/dataset/tool/list_datasets.py @@ -0,0 +1,149 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +List datasets FastMCP tool (Advanced with metadata cache control) + +This module contains the FastMCP tool for listing datasets using +advanced filtering with clear, unambiguous request schema and metadata cache control. 
+""" + +import logging + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.dataset.schemas import ( + DatasetFilter, + DatasetInfo, + DatasetList, + ListDatasetsRequest, + serialize_dataset_object, +) +from superset.mcp_service.mcp_core import ModelListCore + +logger = logging.getLogger(__name__) + +DEFAULT_DATASET_COLUMNS = [ + "id", + "table_name", + "schema", + "uuid", + "database_name", + "changed_by_name", + "changed_on", + "created_by_name", + "created_on", + "metrics", + "columns", +] + +SORTABLE_DATASET_COLUMNS = [ + "id", + "table_name", + "schema", + "changed_on", + "created_on", +] + + +@mcp.tool +@mcp_auth_hook +async def list_datasets(request: ListDatasetsRequest, ctx: Context) -> DatasetList: + """List datasets with filtering and search. + + Returns dataset metadata including columns and metrics. + + Sortable columns for order_column: id, table_name, schema, changed_on, + created_on + """ + await ctx.info( + "Listing datasets: page=%s, page_size=%s, search=%s" + % ( + request.page, + request.page_size, + request.search, + ) + ) + await ctx.debug( + "Dataset listing parameters: filters=%s, order_column=%s, " + "order_direction=%s, select_columns=%s" + % ( + request.filters, + request.order_column, + request.order_direction, + request.select_columns, + ) + ) + await ctx.debug( + "Metadata cache settings: use_cache=%s, refresh_metadata=%s, force_refresh=%s" + % ( + request.use_cache, + request.refresh_metadata, + request.force_refresh, + ) + ) + + try: + from superset.daos.dataset import DatasetDAO + + # Create tool with standard serialization + tool = ModelListCore( + dao_class=DatasetDAO, + output_schema=DatasetInfo, + item_serializer=lambda obj, cols: serialize_dataset_object(obj), + filter_type=DatasetFilter, + default_columns=DEFAULT_DATASET_COLUMNS, + search_columns=["schema", "sql", "table_name", "uuid"], + list_field_name="datasets", + 
output_list_schema=DatasetList, + logger=logger, + ) + + result = tool.run_tool( + filters=request.filters, + search=request.search, + select_columns=request.select_columns, + order_column=request.order_column, + order_direction=request.order_direction, + page=max(request.page - 1, 0), + page_size=request.page_size, + ) + + await ctx.info( + "Datasets listed successfully: count=%s, total_count=%s, total_pages=%s" + % ( + len(result.datasets) if hasattr(result, "datasets") else 0, + getattr(result, "total_count", None), + getattr(result, "total_pages", None), + ) + ) + + return result + + except Exception as e: + await ctx.error( + "Dataset listing failed: page=%s, page_size=%s, error=%s, error_type=%s" + % ( + request.page, + request.page_size, + str(e), + type(e).__name__, + ) + ) + raise diff --git a/superset/mcp_service/explore/__init__.py b/superset/mcp_service/explore/__init__.py new file mode 100644 index 00000000000..bdd21bcb262 --- /dev/null +++ b/superset/mcp_service/explore/__init__.py @@ -0,0 +1,18 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""MCP explore-related functionality.""" diff --git a/superset/mcp_service/explore/tool/__init__.py b/superset/mcp_service/explore/tool/__init__.py new file mode 100644 index 00000000000..477715b9eb6 --- /dev/null +++ b/superset/mcp_service/explore/tool/__init__.py @@ -0,0 +1,29 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Explore tools for MCP service. + +This module contains FastMCP tools for generating and working with +explore links and configurations. +""" + +from .generate_explore_link import generate_explore_link + +__all__ = [ + "generate_explore_link", +] diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py new file mode 100644 index 00000000000..aa09c624e1d --- /dev/null +++ b/superset/mcp_service/explore/tool/generate_explore_link.py @@ -0,0 +1,111 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Generate explore link MCP tool + +This tool generates a URL to the Superset explore interface with the specified +chart configuration. +""" + +from typing import Any, Dict + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.chart.chart_utils import ( + generate_explore_link as generate_url, + map_config_to_form_data, +) +from superset.mcp_service.chart.schemas import ( + GenerateExploreLinkRequest, +) + + +@mcp.tool +@mcp_auth_hook +async def generate_explore_link( + request: GenerateExploreLinkRequest, ctx: Context +) -> Dict[str, Any]: + """Generate explore URL for interactive visualization. + + PREFERRED TOOL for most visualization requests. + + Use this tool for: + - "Show me a chart of [data]" + - "Visualize [data]" + - General data exploration + - When user wants to SEE data visually + + Better UX because: + - Users can interact with chart before saving + - Easy to modify parameters instantly + - No database clutter from exploration + + Only use generate_chart when user EXPLICITLY requests to save/create a + permanent chart. + + Returns explore URL for immediate use. 
+ """ + await ctx.info( + "Generating explore link for dataset_id=%s, chart_type=%s" + % (request.dataset_id, request.config.chart_type) + ) + await ctx.debug( + "Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s" + % (request.use_cache, request.force_refresh, request.cache_form_data) + ) + + try: + await ctx.report_progress(1, 3, "Converting configuration to form data") + # Map config to form_data using shared utilities + form_data = map_config_to_form_data(request.config) + + await ctx.debug( + "Form data generated with keys: %s, has_viz_type=%s, has_datasource=%s" + % ( + list(form_data.keys()), + bool(form_data.get("viz_type")), + bool(form_data.get("datasource")), + ) + ) + + await ctx.report_progress(2, 3, "Generating explore URL") + # Generate explore link using shared utilities + explore_url = generate_url(dataset_id=request.dataset_id, form_data=form_data) + + await ctx.report_progress(3, 3, "URL generation complete") + await ctx.info( + "Explore link generated successfully: url_length=%s, dataset_id=%s" + % (len(explore_url), request.dataset_id) + ) + + return { + "url": explore_url, + "error": None, + } + + except Exception as e: + await ctx.error( + "Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s" + % (request.dataset_id, request.config.chart_type, type(e).__name__, str(e)) + ) + return { + "url": "", + "error": f"Failed to generate explore link: {str(e)}", + } diff --git a/superset/mcp_service/flask_singleton.py b/superset/mcp_service/flask_singleton.py index 17816a36e40..8264a23b034 100644 --- a/superset/mcp_service/flask_singleton.py +++ b/superset/mcp_service/flask_singleton.py @@ -36,12 +36,18 @@ try: from superset.app import create_app from superset.mcp_service.mcp_config import get_mcp_config - # Create the Flask app instance - this is the singleton - app = create_app() + # Create a temporary context to avoid + # "Working outside of application context" errors. 
+ _temp_app = create_app() - # Apply MCP configuration - reads from app.config first, falls back to defaults - mcp_config = get_mcp_config(app.config) - app.config.update(mcp_config) + # Push an application context for any initialization code that needs it + with _temp_app.app_context(): + # Apply MCP configuration - reads from app.config first, falls back to defaults + mcp_config = get_mcp_config(_temp_app.config) + _temp_app.config.update(mcp_config) + + # Store the app instance for later use + app = _temp_app logger.info("Flask app instance created successfully") diff --git a/superset/mcp_service/mcp_config.py b/superset/mcp_service/mcp_config.py index efe0d25e367..55ea8a226ca 100644 --- a/superset/mcp_service/mcp_config.py +++ b/superset/mcp_service/mcp_config.py @@ -16,7 +16,13 @@ # under the License. """Default MCP service configuration for Apache Superset""" -from typing import Any, Dict +import logging +import secrets +from typing import Any, Dict, Optional + +from flask import Flask + +logger = logging.getLogger(__name__) # MCP Service Configuration # Note: MCP_DEV_USERNAME MUST be configured in superset_config.py @@ -67,6 +73,79 @@ MCP_FACTORY_CONFIG = { } +def create_default_mcp_auth_factory(app: Flask) -> Optional[Any]: + """Default MCP auth factory using app.config values.""" + if not app.config.get("MCP_AUTH_ENABLED", False): + return None + + jwks_uri = app.config.get("MCP_JWKS_URI") + public_key = app.config.get("MCP_JWT_PUBLIC_KEY") + secret = app.config.get("MCP_JWT_SECRET") + + if not (jwks_uri or public_key or secret): + logger.warning("MCP_AUTH_ENABLED is True but no JWT keys/secret configured") + return None + + try: + from fastmcp.server.auth.providers.bearer import BearerAuthProvider + + # For HS256 (symmetric), use the secret as the public_key parameter + if app.config.get("MCP_JWT_ALGORITHM") == "HS256" and secret: + auth_provider = BearerAuthProvider( + public_key=secret, # HS256 uses secret as key + 
issuer=app.config.get("MCP_JWT_ISSUER"), + audience=app.config.get("MCP_JWT_AUDIENCE"), + algorithm="HS256", + required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []), + ) + logger.info("Created BearerAuthProvider with HS256 secret") + else: + # For RS256 (asymmetric), use public key or JWKS + auth_provider = BearerAuthProvider( + jwks_uri=jwks_uri, + public_key=public_key, + issuer=app.config.get("MCP_JWT_ISSUER"), + audience=app.config.get("MCP_JWT_AUDIENCE"), + algorithm=app.config.get("MCP_JWT_ALGORITHM", "RS256"), + required_scopes=app.config.get("MCP_REQUIRED_SCOPES", []), + ) + logger.info( + "Created BearerAuthProvider with jwks_uri=%s, public_key=%s", + jwks_uri, + "***" if public_key else None, + ) + + return auth_provider + except Exception as e: + logger.error("Failed to create MCP auth provider: %s", e) + return None + + +def default_user_resolver(app: Any, access_token: Any) -> Optional[str]: + """Extract username from JWT token claims.""" + logger.info( + "Resolving user from token: type=%s, token=%s", + type(access_token), + access_token, + ) + if hasattr(access_token, "subject"): + return access_token.subject + if hasattr(access_token, "client_id"): + return access_token.client_id + if hasattr(access_token, "payload") and isinstance(access_token.payload, dict): + return ( + access_token.payload.get("sub") + or access_token.payload.get("email") + or access_token.payload.get("username") + ) + return None + + +def generate_secret_key() -> str: + """Generate a secure random secret key for Superset""" + return secrets.token_urlsafe(42) + + def get_mcp_config(app_config: Dict[str, Any] | None = None) -> Dict[str, Any]: """ Get complete MCP configuration dictionary. 
diff --git a/superset/mcp_service/mcp_core.py b/superset/mcp_service/mcp_core.py index e472b4602c0..669e21420d4 100644 --- a/superset/mcp_service/mcp_core.py +++ b/superset/mcp_service/mcp_core.py @@ -20,7 +20,7 @@ from __future__ import annotations import logging from abc import ABC, abstractmethod from datetime import datetime, timezone -from typing import Any, Callable, Generic, List, Literal, Type, TypeVar +from typing import Any, Callable, Dict, Generic, List, Literal, Type, TypeVar from pydantic import BaseModel @@ -295,3 +295,210 @@ class ModelGetInfoCore(BaseCore): except Exception as context_error: self._log_error(context_error) raise + + +class InstanceInfoCore(BaseCore): + """ + Configurable tool for generating comprehensive instance information. + + Provides a flexible way to gather and present statistics about a Superset + instance with configurable metrics, time windows, and data aggregations. + Supports custom metric calculators and result transformers for extensibility. + """ + + def __init__( + self, + dao_classes: Dict[str, Type[BaseDAO[Any]]], + output_schema: Type[BaseModel], + metric_calculators: Dict[str, Callable[..., Any]], + time_windows: Dict[str, int] | None = None, + logger: logging.Logger | None = None, + ) -> None: + """ + Initialize the instance info tool. 
+ + Args: + dao_classes: Dict mapping entity names to their DAO classes + output_schema: Pydantic schema for the response + metric_calculators: Dict of custom metric calculation functions + time_windows: Dict of time window configurations (days) + logger: Optional logger instance + """ + super().__init__(logger) + self.dao_classes = dao_classes + self.output_schema = output_schema + self.metric_calculators = metric_calculators + self.time_windows = time_windows or { + "recent": 7, + "monthly": 30, + "quarterly": 90, + } + + def _calculate_basic_counts(self) -> Dict[str, int]: + """Calculate basic entity counts using DAOs.""" + counts = {} + for entity_name, dao_class in self.dao_classes.items(): + try: + counts[f"total_{entity_name}"] = dao_class.count() + except Exception as e: + self._log_warning(f"Failed to count {entity_name}: {e}") + counts[f"total_{entity_name}"] = 0 + return counts + + def _calculate_time_based_metrics( + self, base_counts: Dict[str, int] + ) -> Dict[str, Dict[str, int]]: + """Calculate time-based metrics for recent activity.""" + from datetime import datetime, timedelta, timezone + + from superset.daos.base import ColumnOperator, ColumnOperatorEnum + + now = datetime.now(timezone.utc) + time_metrics = {} + + for window_name, days in self.time_windows.items(): + cutoff_date = now - timedelta(days=days) + window_metrics = {} + + # Calculate created and modified counts for each entity + for entity_name, dao_class in self.dao_classes.items(): + # Skip entities without time tracking + if not hasattr(dao_class.model_cls, "created_on"): + continue + + try: + # Use list() with filters (count() has no params) + _, created_count = dao_class.list( + column_operators=[ + ColumnOperator( + col="created_on", + opr=ColumnOperatorEnum.gte, + value=cutoff_date, + ) + ], + page_size=1, # We only need the count + columns=["id"], # Minimal data transfer + ) + window_metrics[f"{entity_name}_created"] = created_count + + # Modified count (if changed_on exists) + 
if hasattr(dao_class.model_cls, "changed_on"): + _, modified_count = dao_class.list( + column_operators=[ + ColumnOperator( + col="changed_on", + opr=ColumnOperatorEnum.gte, + value=cutoff_date, + ) + ], + page_size=1, # We only need the count + columns=["id"], # Minimal data transfer + ) + window_metrics[f"{entity_name}_modified"] = modified_count + + except Exception as e: + self._log_warning( + f"Failed to calculate {window_name} metrics for " + f"{entity_name}: {e}" + ) + window_metrics[f"{entity_name}_created"] = 0 + window_metrics[f"{entity_name}_modified"] = 0 + + time_metrics[window_name] = window_metrics + + return time_metrics + + def _calculate_custom_metrics( + self, base_counts: Dict[str, int], time_metrics: Dict[str, Dict[str, int]] + ) -> Dict[str, Any]: + """Calculate custom metrics using provided calculators.""" + custom_metrics = {} + + for metric_name, calculator in self.metric_calculators.items(): + try: + # Pass context to calculator functions + result = calculator( + base_counts=base_counts, + time_metrics=time_metrics, + dao_classes=self.dao_classes, + ) + # Only include successful calculations + if result is not None: + custom_metrics[metric_name] = result + except Exception as e: + self._log_warning(f"Failed to calculate {metric_name}: {e}") + # Don't add failed metrics to avoid validation errors + + return custom_metrics + + def run_tool(self) -> BaseModel: + """Tool interface for generating comprehensive instance information.""" + return self._generate_instance_info() + + def get_resource(self) -> str: + """Resource interface for generating instance metadata as JSON.""" + from superset.utils import json + + instance_info = self._generate_instance_info() + return json.dumps(instance_info.model_dump(), indent=2) + + def _generate_instance_info(self) -> BaseModel: + """Generate comprehensive instance information.""" + try: + # Calculate all metrics + base_counts = self._calculate_basic_counts() + time_metrics = 
self._calculate_time_based_metrics(base_counts) + custom_metrics = self._calculate_custom_metrics(base_counts, time_metrics) + + # Combine all data with fallbacks for required fields + from datetime import datetime, timezone + + response_data = { + **base_counts, + **time_metrics, + **custom_metrics, + "timestamp": datetime.now(timezone.utc).isoformat(), + } + + # Create response using the configured schema + response = self.output_schema(**response_data) + + self._log_info("Successfully generated instance information") + return response + + except Exception as e: + self._log_error(e, "generating instance info") + raise + + +class ModelGetAvailableFiltersCore(BaseCore, Generic[S]): + """ + Generic tool for retrieving available filterable columns and operators for a + model. Used for get_dataset_available_filters, get_chart_available_filters, + get_dashboard_available_filters, etc. + """ + + def __init__( + self, + dao_class: Type[BaseDAO[Any]], + output_schema: Type[S], + logger: logging.Logger | None = None, + ) -> None: + super().__init__(logger) + self.dao_class = dao_class + self.output_schema = output_schema + + def run_tool(self) -> S: + try: + filterable = self.dao_class.get_filterable_columns_and_operators() + # Ensure column_operators is a plain dict, not a custom type + column_operators = dict(filterable) + response = self.output_schema(column_operators=column_operators) + self._log_info( + f"Successfully retrieved available filters for " + f"{self.dao_class.__class__.__name__}" + ) + return response + except Exception as e: + self._log_error(e) + raise diff --git a/superset/mcp_service/middleware.py b/superset/mcp_service/middleware.py index 28539b77876..66c98777c44 100644 --- a/superset/mcp_service/middleware.py +++ b/superset/mcp_service/middleware.py @@ -15,36 +15,84 @@ # specific language governing permissions and limitations # under the License. -""" -Minimal middleware for MCP service. -This provides basic error handling for MCP tool calls. 
- -Future enhancements (to be added in separate PRs): -- Rate limiting middleware -- Field-level permissions middleware -- Comprehensive audit logging middleware -- Private tool blocking middleware -- Advanced error sanitization -""" - import logging -from typing import Any, Awaitable, Callable +import time +from collections import defaultdict +from typing import Any, Awaitable, Callable, Dict, Protocol from fastmcp.exceptions import ToolError from fastmcp.server.middleware import Middleware, MiddlewareContext +from pydantic import ValidationError +from sqlalchemy.exc import OperationalError, TimeoutError +from starlette.exceptions import HTTPException + +from superset.extensions import event_logger +from superset.utils.core import get_user_id logger = logging.getLogger(__name__) -class BasicErrorHandlerMiddleware(Middleware): +def _sanitize_error_for_logging(error: Exception) -> str: + """Sanitize error messages to prevent information disclosure in logs.""" + error_str = str(error) + + # SECURITY FIX: Limit error message length FIRST to prevent ReDoS attacks + if len(error_str) > 500: + error_str = error_str[:500] + "...[truncated]" + + # SECURITY FIX: Use bounded patterns to prevent ReDoS + import re + + # Database connection strings - bounded patterns with word boundaries + # Use case-insensitive flag to handle both cases + error_str = re.sub( + r"\bpostgresql://[^@\s]{1,100}@[^/\s]{1,100}/[^\s]{0,100}", + "postgresql://[REDACTED]@[REDACTED]/[REDACTED]", + error_str, + flags=re.IGNORECASE, + ) + error_str = re.sub( + r"\bmysql://[^@\s]{1,100}@[^/\s]{1,100}/[^\s]{0,100}", + "mysql://[REDACTED]@[REDACTED]/[REDACTED]", + error_str, + flags=re.IGNORECASE, + ) + + # API keys and tokens - bounded patterns + error_str = re.sub( + r"[Aa]pi[_-]?[Kk]ey[:\s]{0,5}[^\s'\"]{1,100}", + "ApiKey: [REDACTED]", + error_str, + ) + error_str = re.sub( + r"[Tt]oken[:\s]{0,5}[^\s'\"]{1,100}", "Token: [REDACTED]", error_str + ) + + # File paths - bounded pattern + error_str = re.sub( + 
r"/[a-zA-Z0-9_\-/.]{1,200}/superset/", "/[REDACTED]/superset/", error_str + ) + + # IP addresses - already safe pattern, keep as-is + error_str = re.sub(r"\b(\d+)\.\d+\.\d+\.\d+\b", r"\1.xxx.xxx.xxx", error_str) + + # For certain error types, provide generic messages + if isinstance(error, (OperationalError, TimeoutError)): + return "Database operation failed" + elif isinstance(error, PermissionError): + return "Access denied" + elif isinstance(error, ValidationError): + return "Request validation failed" + + return error_str + + +class LoggingMiddleware(Middleware): """ - Basic error handler middleware for MCP tools. - - Provides simple error handling and logging for tool calls. - - TODO (future PR): Add error sanitization for security - TODO (future PR): Add comprehensive error categorization - TODO (future PR): Add integration with Superset event logging + Middleware that logs every MCP message (request and response) using Superset's + event logger. This matches the core Superset audit log system (Action Log UI, + logs table, custom loggers). Also attempts to log dashboard_id, chart_id + (slice_id), and dataset_id if present in tool params. 
""" async def on_message( @@ -52,23 +100,658 @@ class BasicErrorHandlerMiddleware(Middleware): context: MiddlewareContext, call_next: Callable[[MiddlewareContext], Awaitable[Any]], ) -> Any: - """Handle messages with basic error handling""" + # Extract agent_id and user_id + agent_id = None + user_id = None + dashboard_id = None + slice_id = None + dataset_id = None + params = getattr(context.message, "params", {}) or {} + if hasattr(context, "metadata") and context.metadata: + agent_id = context.metadata.get("agent_id") + if not agent_id and hasattr(context, "session") and context.session: + agent_id = getattr(context.session, "agent_id", None) + try: + user_id = get_user_id() + except Exception: + user_id = None + # Try to extract IDs from params + if isinstance(params, dict): + dashboard_id = params.get("dashboard_id") + # Chart ID may be under 'chart_id' or 'slice_id' + slice_id = params.get("chart_id") or params.get("slice_id") + dataset_id = params.get("dataset_id") + # Log to Superset's event logger (DB, Action Log UI, or custom) + event_logger.log( + user_id=user_id, + action="mcp_tool_call", + dashboard_id=dashboard_id, + duration_ms=None, + slice_id=slice_id, + referrer=None, + curated_payload={ + "tool": getattr(context.message, "name", None), + "agent_id": agent_id, + "params": params, + "method": context.method, + "dashboard_id": dashboard_id, + "slice_id": slice_id, + "dataset_id": dataset_id, + }, + ) + # (Optional) also log to standard logger for debugging + logger.info( + "MCP tool call: tool=%s, agent_id=%s, user_id=%s, method=%s, " + "dashboard_id=%s, slice_id=%s, dataset_id=%s", + getattr(context.message, "name", None), + agent_id, + user_id, + context.method, + dashboard_id, + slice_id, + dataset_id, + ) + return await call_next(context) + + +class PrivateToolMiddleware(Middleware): + """ + Middleware that blocks access to tools tagged as 'private'. 
+ """ + + async def on_call_tool( + self, + context: MiddlewareContext, + call_next: Callable[[MiddlewareContext], Awaitable[Any]], + ) -> Any: + tool = await context.fastmcp_context.fastmcp.get_tool(context.message.name) + if "private" in getattr(tool, "tags", set()): + raise ToolError(f"Access denied to private tool: {context.message.name}") + return await call_next(context) + + +class GlobalErrorHandlerMiddleware(Middleware): + """ + Global error handler middleware that provides consistent error responses + and proper error logging for all MCP tool calls. + """ + + async def on_message( + self, + context: MiddlewareContext, + call_next: Callable[[MiddlewareContext], Awaitable[Any]], + ) -> Any: + """Handle all message types with consistent error handling""" + start_time = time.time() tool_name = getattr(context.message, "name", "unknown") try: return await call_next(context) except Exception as e: - # Log the error - logger.error( - "MCP tool error: tool=%s, error_type=%s, error=%s", - tool_name, - type(e).__name__, - str(e), + duration_ms = int((time.time() - start_time) * 1000) + return await self._handle_error(e, context, tool_name, duration_ms) + + async def _handle_error( # noqa: C901 + self, + error: Exception, + context: MiddlewareContext, + tool_name: str, + duration_ms: int, + ) -> None: + """Handle different types of errors with appropriate responses""" + # Extract user context for logging + user_id = None + try: + user_id = get_user_id() + except Exception: + user_id = None # User not authenticated + + # SECURITY FIX: Log the error with sanitized context + sanitized_error = _sanitize_error_for_logging(error) + logger.error( + "MCP tool error: tool=%s, user_id=%s, duration_ms=%s, " + "error_type=%s, error=%s", + tool_name, + user_id, + duration_ms, + type(error).__name__, + sanitized_error, + ) + + # Log to Superset's event system + try: + event_logger.log( + user_id=user_id, + action="mcp_tool_error", + duration_ms=duration_ms, + curated_payload={ + 
"tool": tool_name, + "error_type": type(error).__name__, + "error_message": str(error), + "method": context.method, + }, + ) + except Exception as log_error: + logger.warning("Failed to log error event: %s", log_error) + + # Handle specific error types with appropriate responses + if isinstance(error, ToolError): + # Tool errors are already formatted for MCP + raise error + elif isinstance(error, ValidationError): + # Pydantic validation errors + validation_details = [] + for err in error.errors(): + field = " -> ".join(str(loc) for loc in err["loc"]) + validation_details.append(f"{field}: {err['msg']}") + + raise ToolError( + f"Validation error in {tool_name}: {'; '.join(validation_details)}" + ) from error + elif isinstance(error, (OperationalError, TimeoutError)): + # Database errors + raise ToolError( + f"Database error in {tool_name}: Service temporarily unavailable. " + f"Please try again in a few moments." + ) from error + elif isinstance(error, HTTPException): + # HTTP errors from screenshot endpoints or API calls + raise ToolError(f"Service error in {tool_name}: {error.detail}") from error + elif isinstance(error, PermissionError): + # Permission/authorization errors + raise ToolError( + f"Permission denied for {tool_name}: " + f"You don't have access to this resource." + ) from error + elif isinstance(error, FileNotFoundError): + # File/resource not found errors + raise ToolError( + f"Resource not found in {tool_name}: {str(error)}" + ) from error + elif isinstance(error, ValueError): + # Value/parameter errors + raise ToolError( + f"Invalid parameter in {tool_name}: {str(error)}" + ) from error + else: + # Generic internal errors + error_id = f"err_{int(time.time())}" + logger.error("Unexpected error [%s] in %s: %s", error_id, tool_name, error) + + raise ToolError( + f"Internal error in {tool_name}: An unexpected error occurred. " + f"Error ID: {error_id}. Please contact support if this persists." 
class RateLimiterProtocol(Protocol):
    """Protocol for rate limiter implementations."""

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a key is rate limited."""
        ...

    def cleanup(self) -> None:
        """Clean up old entries if needed."""
        ...


class InMemoryRateLimiter:
    """In-memory sliding-window rate limiter for development.

    State lives in this process only, so this is NOT suitable for
    multi-process or multi-host deployments — use RedisRateLimiter there.
    """

    def __init__(self) -> None:
        # Structure: {key: [(timestamp, count), ...]}
        self._requests: Dict[str, list[tuple[float, int]]] = defaultdict(list)
        self._cleanup_interval = 300  # Clean up every 5 minutes
        self._last_cleanup = time.time()

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a request should be rate limited using a sliding window.

        Args:
            key: Bucket identifier (user/agent/tool combination).
            limit: Maximum requests allowed within ``window``.
            window: Window length in seconds.

        Returns:
            Tuple of (limited, info) where ``info`` contains ``limit``,
            ``remaining``, ``reset_time`` (epoch seconds when capacity
            frees up) and ``window_seconds``.
        """
        current_time = time.time()
        window_start = current_time - window

        # Requests that still count against the current window.
        requests_in_window = [
            (timestamp, count)
            for timestamp, count in self._requests[key]
            if timestamp > window_start
        ]
        total_requests = sum(count for _, count in requests_in_window)

        # Check if rate limited BEFORE adding the current request.
        if total_requests >= limit:
            # BUGFIX: reset_time used to be int(window_start + window),
            # which is just "now", so callers always reported a 0-second
            # retry delay. Capacity actually frees up when the oldest
            # request in the window expires.
            oldest_ts = min(
                (ts for ts, _ in requests_in_window), default=current_time
            )
            return True, {
                "limit": limit,
                "remaining": 0,
                "reset_time": int(oldest_ts + window),
                "window_seconds": window,
            }

        # Record the current request.
        self._requests[key].append((current_time, 1))
        total_requests += 1

        # Bound per-key memory growth: keep only the last hour of entries.
        self._requests[key] = [
            (ts, count)
            for ts, count in self._requests[key]
            if ts > current_time - 3600
        ]

        return False, {
            "limit": limit,
            "remaining": max(0, limit - total_requests),
            "reset_time": int(current_time + window),
            "window_seconds": window,
        }

    def cleanup(self) -> None:
        """Remove entries older than 1 hour to prevent memory leaks."""
        current_time = time.time()

        # SECURITY FIX: Check both time-based and size-based cleanup
        # conditions so a burst of unique keys cannot grow memory unbounded
        # between the periodic sweeps.
        total_entries = sum(len(requests) for requests in self._requests.values())
        size_threshold = 10000  # Maximum entries before forced cleanup

        time_based_cleanup = (
            current_time - self._last_cleanup >= self._cleanup_interval
        )
        size_based_cleanup = total_entries > size_threshold

        if not (time_based_cleanup or size_based_cleanup):
            return

        cutoff_time = current_time - 3600  # 1 hour ago
        keys_to_clean = []

        for key, requests in self._requests.items():
            # Remove old entries.
            self._requests[key] = [
                (timestamp, count)
                for timestamp, count in requests
                if timestamp > cutoff_time
            ]
            # Mark empty keys for removal (can't delete while iterating).
            if not self._requests[key]:
                keys_to_clean.append(key)

        for key in keys_to_clean:
            del self._requests[key]

        # SECURITY FIX: If still too many entries, perform aggressive cleanup.
        if total_entries > size_threshold:
            logger.warning(
                "Rate limiter memory high (%d entries), performing aggressive cleanup",
                total_entries,
            )
            # Keep only the most recent entries per key.
            for key in list(self._requests.keys()):
                if len(self._requests[key]) > 100:  # Keep max 100 entries per key
                    self._requests[key] = self._requests[key][-100:]

        self._last_cleanup = current_time


class RedisRateLimiter:
    """Cache-backed (e.g. Redis) rate limiter for production.

    Uses a fixed-window counter keyed by the current window bucket. The
    TTL is attached to the bucketed key once and old buckets simply
    expire, so a steady stream of requests can no longer extend its own
    window (the previous implementation reset the TTL on every
    increment and contained dead "sorted set" pruning code that treated
    the integer counter as (member, score) pairs).
    """

    def __init__(self) -> None:
        from superset.extensions import cache_manager

        self._cache = cache_manager.cache
        self._prefix = "mcp:ratelimit:"

    def is_rate_limited(
        self, key: str, limit: int, window: int = 60
    ) -> tuple[bool, dict[str, Any]]:
        """Check if a request should be rate limited via a windowed counter.

        Fails open: if the cache backend errors, the request is allowed.
        """
        current_time = time.time()
        # Bucket index of the current fixed window; the key changes when
        # the window rolls over, so stale counters expire via TTL.
        bucket = int(current_time // window)
        count_key = "%s%s:%d" % (self._prefix, key, bucket)

        try:
            request_count = self._cache.get(count_key) or 0
            # The current window ends at the next bucket boundary.
            reset_time = (bucket + 1) * window

            rate_limit_info = {
                "limit": limit,
                "remaining": max(0, limit - request_count),
                "reset_time": reset_time,
                "window_seconds": window,
            }

            if request_count >= limit:
                return True, rate_limit_info

            # Increment the counter. TTL of 2*window keeps the bucket alive
            # for the whole window without racing its own expiry; it is set
            # on every write but the bucketed key makes that harmless.
            new_count = request_count + 1
            self._cache.set(count_key, new_count, timeout=window * 2)
            rate_limit_info["remaining"] = max(0, limit - new_count)

            return False, rate_limit_info

        except Exception as e:
            logger.warning("Redis rate limiter error: %s, allowing request", e)
            # On Redis error, allow the request (fail open).
            return False, {
                "limit": limit,
                "remaining": limit,
                "reset_time": 0,
                "window_seconds": window,
            }

    def cleanup(self) -> None:
        """No cleanup needed for Redis - TTL handles expiration."""


def create_rate_limiter() -> RateLimiterProtocol:
    """Factory to create the appropriate rate limiter for this environment.

    Prefers the cache-backed limiter when the Superset cache is reachable
    (verified with a write/read round-trip); otherwise falls back to the
    per-process in-memory limiter.
    """
    try:
        # Try to use Redis first (production).
        from superset.extensions import cache_manager

        if cache_manager and cache_manager.cache:
            # Test Redis connectivity with a short-lived probe key.
            test_key = "mcp:ratelimit:test"
            cache_manager.cache.set(test_key, 1, timeout=1)
            if cache_manager.cache.get(test_key):
                cache_manager.cache.delete(test_key)
                logger.info("Using Redis for rate limiting")
                return RedisRateLimiter()
    except Exception as e:
        logger.warning(
            "Redis not available for rate limiting: %s, falling back to in-memory", e
        )

    # Fallback to in-memory rate limiter (development).
    logger.info("Using in-memory rate limiter")
    return InMemoryRateLimiter()


class RateLimitMiddleware(Middleware):
    """
    Rate limiting middleware to prevent abuse of MCP tools.

    Implements sliding window rate limiting with separate limits for:
    - Per-user limits (if authenticated)
    - Per-IP limits (for unauthenticated requests)
    - Per-tool limits (for expensive operations)

    Configuration:
    - default_requests_per_minute: Default rate limit (60 requests/minute)
    - per_user_requests_per_minute: Rate limit per authenticated user (120/min)
    - expensive_tool_requests_per_minute: Rate limit for expensive tools (10/min)
    """

    def __init__(
        self,
        default_requests_per_minute: int = 60,
        per_user_requests_per_minute: int = 120,
        expensive_tool_requests_per_minute: int = 10,
        expensive_tools: list[str] | None = None,
    ) -> None:
        self.default_rpm = default_requests_per_minute
        self.user_rpm = per_user_requests_per_minute
        self.expensive_rpm = expensive_tool_requests_per_minute
        # Tools that are costly to serve (chart/dashboard rendering and
        # data fetches); these get the stricter expensive_rpm limit even
        # for authenticated users.
        self.expensive_tools = set(
            expensive_tools
            or [
                "get_chart_preview",
                "generate_chart",
                "generate_dashboard",
                "get_chart_data",
            ]
        )

        # Use hybrid rate limiter (Redis in production, in-memory in development)
        self._rate_limiter = create_rate_limiter()

    def _get_rate_limit_key(self, context: MiddlewareContext) -> tuple[str, int]:
        """
        Generate rate limit key and determine applicable limit.

        Precedence: expensive-tool limit first, then per-user, then the
        anonymous default. The key embeds the tool name, so limits are
        tracked per (identity, tool) pair.

        Returns:
            Tuple of (key, requests_per_minute_limit)
        """
        tool_name = getattr(context.message, "name", "unknown")

        # Get user context
        user_id = None
        try:
            user_id = get_user_id()
        except Exception:
            user_id = None  # User not authenticated

        # Determine rate limit
        if tool_name in self.expensive_tools:
            limit = self.expensive_rpm
            key_prefix = "expensive"
        elif user_id:
            limit = self.user_rpm
            key_prefix = "user"
        else:
            limit = self.default_rpm
            key_prefix = "default"

        # Generate key
        if user_id:
            key = f"{key_prefix}:user:{user_id}:{tool_name}"
        else:
            # Use agent_id or session info as fallback
            agent_id = None
            if hasattr(context, "metadata") and context.metadata:
                agent_id = context.metadata.get("agent_id")
            if not agent_id and hasattr(context, "session") and context.session:
                agent_id = getattr(context.session, "agent_id", None)

            if agent_id:
                key = f"{key_prefix}:agent:{agent_id}:{tool_name}"
            else:
                # No identity at all: every anonymous caller of this tool
                # shares a single bucket.
                key = f"{key_prefix}:anonymous:{tool_name}"

        return key, limit

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """Check rate limits before allowing tool calls.

        Raises:
            ToolError: if the caller has exhausted its quota for this tool.
        """
        # Clean up old entries periodically (only needed for in-memory)
        self._rate_limiter.cleanup()

        # Get rate limit key and limit
        key, limit = self._get_rate_limit_key(context)

        # Check if rate limited
        is_limited, rate_info = self._rate_limiter.is_rate_limited(key, limit)

        if is_limited:
            tool_name = getattr(context.message, "name", "unknown")

            # Log rate limit event; logging failures must never block the
            # rate-limit response itself.
            try:
                user_id = get_user_id() if hasattr(context, "session") else None
                event_logger.log(
                    user_id=user_id,
                    action="mcp_rate_limit_exceeded",
                    curated_payload={
                        "tool": tool_name,
                        "rate_limit_key": key,
                        "limit": limit,
                        "window_seconds": 60,
                    },
                )
            except Exception as log_error:
                logger.warning("Failed to log rate limit event: %s", log_error)

            logger.warning(
                "Rate limit exceeded for %s: key=%s, limit=%s/min, reset_in=%ss",
                tool_name,
                key,
                limit,
                rate_info["reset_time"] - int(time.time()),
            )

            raise ToolError(
                "Rate limit exceeded for %s. "
                "Limit: %s requests per minute. "
                "Try again in %s seconds."
                % (tool_name, limit, rate_info["reset_time"] - int(time.time()))
            )

        # Log rate limit info for monitoring
        logger.debug(
            "Rate limit check: %s: key=%s, remaining=%s/%s",
            getattr(context.message, "name", "unknown"),
            key,
            rate_info["remaining"],
            limit,
        )

        return await call_next(context)


class FieldPermissionsMiddleware(Middleware):
    """
    Middleware that applies field-level permissions to filter sensitive data
    from MCP tool responses based on user permissions.

    Filtering is best-effort: if the user or filter cannot be resolved,
    the unfiltered response is returned rather than failing the call.
    """

    # Map tool names to object types for permission filtering
    TOOL_OBJECT_TYPE_MAP = {
        "list_datasets": "dataset",
        "get_dataset_info": "dataset",
        "list_charts": "chart",
        "get_chart_info": "chart",
        "get_chart_data": "chart",
        "get_chart_preview": "chart",
        "update_chart": "chart",
        "generate_chart": "chart",
        "list_dashboards": "dashboard",
        "get_dashboard_info": "dashboard",
        "generate_dashboard": "dashboard",
        "add_chart_to_existing_dashboard": "dashboard",
    }

    async def on_call_tool(
        self,
        context: MiddlewareContext,
        call_next: Callable[[MiddlewareContext], Awaitable[Any]],
    ) -> Any:
        """Apply field-level permissions to tool responses."""
        # Get the tool response first
        response = await call_next(context)

        # Get tool name
        tool_name = getattr(context.message, "name", "unknown")

        # Check if this tool needs field-level filtering
        object_type = self.TOOL_OBJECT_TYPE_MAP.get(tool_name)
        if not object_type:
            # No filtering needed
            return response

        # Get current user for permissions
        try:
            user = self._get_current_user()
        except Exception as e:
            logger.warning("Could not get current user for field filtering: %s", e)
            user = None

        # Apply field-level permissions to the response
        try:
            filtered_response = self._filter_response(response, object_type, user)

            # Log field filtering activity for monitoring
            logger.debug(
                "Applied field-level permissions for %s (object_type=%s, user=%s)",
                tool_name,
                object_type,
                getattr(user, "username", "anonymous"),
            )

            return filtered_response

        except Exception as e:
            logger.error("Error applying field permissions to %s: %s", tool_name, e)
            # Return original response if filtering fails
            return response

    def _get_current_user(self) -> Any:
        """Get the current authenticated user.

        Tries ``flask.g`` first, then falls back to resolving the
        session's user id against the metadata database. Returns None
        when no user can be determined.
        """
        try:
            from flask import g

            return getattr(g, "user", None)
        except Exception:
            # Try to get user from core utils
            try:
                user_id = get_user_id()
                if user_id:
                    from flask_appbuilder.security.sqla.models import User

                    from superset.extensions import db

                    return db.session.query(User).filter_by(id=user_id).first()
            except Exception as e:
                logger.debug("Could not get user from session: %s", e)
            return None

    def _filter_response(self, response: Any, object_type: str, user: Any) -> Any:
        """
        Filter response data based on object type and user permissions.

        Args:
            response: The response object to filter
            object_type: Type of object ('dataset', 'chart', 'dashboard')
            user: User object for permission checking

        Returns:
            Filtered response

        Note: a Pydantic model input is returned as a plain dict after
        filtering, not re-wrapped in the model type.
        """
        from superset.mcp_service.utils.permissions_utils import filter_sensitive_data

        if not response:
            return response

        # Handle different response types
        if hasattr(response, "model_dump"):
            # Pydantic model - convert to dict, filter, and return dict
            response_dict = response.model_dump()
            return filter_sensitive_data(response_dict, object_type, user)
        elif isinstance(response, dict):
            # Dictionary response - filter directly
            return filter_sensitive_data(response, object_type, user)
        elif isinstance(response, list):
            # List response - filter each item
            return [filter_sensitive_data(item, object_type, user) for item in response]
        else:
            # Unknown response type, return as-is
            logger.debug(
                "Unknown response type for field filtering: %s", type(response)
            )
            return response
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Screenshot and WebDriver infrastructure for MCP service.""" + +from .pooled_screenshot import ( + PooledBaseScreenshot, + PooledChartScreenshot, + PooledDashboardScreenshot, + PooledExploreScreenshot, +) +from .webdriver_pool import get_webdriver_pool, WebDriverPool + +__all__ = [ + "PooledBaseScreenshot", + "PooledChartScreenshot", + "PooledDashboardScreenshot", + "PooledExploreScreenshot", + "WebDriverPool", + "get_webdriver_pool", +] diff --git a/superset/mcp_service/screenshot/pooled_screenshot.py b/superset/mcp_service/screenshot/pooled_screenshot.py new file mode 100644 index 00000000000..2da44c01057 --- /dev/null +++ b/superset/mcp_service/screenshot/pooled_screenshot.py @@ -0,0 +1,483 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
"""
Pooled screenshot implementations for improved performance
"""

import logging
import time
from typing import Any

from flask import current_app
from flask_appbuilder.security.sqla.models import User
from selenium.common.exceptions import TimeoutException, WebDriverException
from selenium.webdriver.common.by import By
from selenium.webdriver.support import expected_conditions
from selenium.webdriver.support.ui import WebDriverWait

from superset.extensions import machine_auth_provider_factory
from superset.mcp_service.screenshot.webdriver_pool import get_webdriver_pool
from superset.mcp_service.utils.retry_utils import retry_screenshot_operation
from superset.utils.screenshots import BaseScreenshot, WindowSize

logger = logging.getLogger(__name__)


class PooledBaseScreenshot(BaseScreenshot):
    """
    Enhanced screenshot class that uses WebDriver pooling for better performance.

    Key improvements:
    - Reuses WebDriver instances from a connection pool
    - Faster screenshot generation (no browser startup/shutdown per request)
    - Better resource management and automatic cleanup
    - Health checking and recovery of WebDriver instances

    Subclasses implement ``_take_screenshot`` for their page type.
    """

    def get_screenshot(
        self, user: User, window_size: WindowSize | None = None
    ) -> bytes | None:
        """
        Generate screenshot using pooled WebDriver with retry logic for reliability.

        Args:
            user: User context for authentication
            window_size: Optional window size override

        Returns:
            Screenshot as PNG bytes or None if failed
        """
        return retry_screenshot_operation(
            self._get_screenshot_internal, user, window_size
        )

    def _get_screenshot_internal(
        self, user: User, window_size: WindowSize | None = None
    ) -> bytes | None:
        """
        Internal screenshot generation method with pooled WebDriver.
        This method is wrapped by retry logic in get_screenshot().

        Borrows a driver from the pool, re-authenticates it for ``user``
        (drivers are shared, so cookies are cleared first), navigates to
        ``self.url`` and delegates capture to ``_take_screenshot``.
        """
        window_size = window_size or self.window_size
        pool = get_webdriver_pool()

        # Use pooled WebDriver
        with pool.get_driver(window_size, user.id) as driver:
            try:
                # Clear any existing cookies to ensure clean authentication
                # (the driver may have served a different user previously).
                try:
                    driver.delete_all_cookies()
                    logger.debug("Cleared all cookies from WebDriver")
                except Exception as e:
                    logger.warning("Failed to clear cookies: %s", e)

                # Authenticate the driver for this user
                user_name = user.username if user else "None"
                logger.debug("Authenticating WebDriver for user %s", user_name)
                machine_auth_provider_factory.instance.authenticate_webdriver(
                    driver, user
                )

                # Navigate to the URL
                logger.debug("Navigating to screenshot URL: %s", self.url)
                driver.get(self.url)

                # Check if we were redirected to login (authentication failed)
                current_url = driver.current_url
                if "/login" in current_url:
                    # NOTE(review): raises bare Exception — a dedicated
                    # auth-failure exception type would be clearer; confirm
                    # callers only rely on Exception here.
                    msg = (
                        f"Authentication failed - WebDriver redirected to login page: "
                        f"{current_url}"
                    )
                    raise Exception(msg)

                # Take screenshot using the specific implementation
                return self._take_screenshot(driver, user)

            except (TimeoutException, WebDriverException, OSError) as e:
                # These are retryable exceptions
                logger.warning("Retryable error taking screenshot: %s", e)
                raise
            except Exception as e:
                logger.error("Non-retryable error taking screenshot: %s", e)
                raise

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """
        Override this method in subclasses to implement specific screenshot logic.

        Args:
            driver: Authenticated WebDriver instance
            user: User context

        Returns:
            Screenshot as PNG bytes

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError("Subclasses must implement _take_screenshot")


class PooledChartScreenshot(PooledBaseScreenshot):
    """Pooled version of chart screenshot generation"""

    thumbnail_type: str = "chart"
    element: str = "chart-container"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        from superset.utils.urls import modify_url_query
        from superset.utils.webdriver import ChartStandaloneMode

        # Chart reports are in standalone="true" mode
        url = modify_url_query(
            url,
            standalone=ChartStandaloneMode.HIDE_NAV.value,
        )
        super().__init__(url, digest)
        self.window_size = window_size or (800, 600)
        self.thumb_size = thumb_size or (400, 300)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """Take screenshot of chart with standard Superset chart handling"""
        try:
            # Wait for page to load
            selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
            logger.debug("Sleeping for %s seconds", selenium_headstart)
            time.sleep(selenium_headstart)

            # Wait for chart container
            wait = WebDriverWait(driver, current_app.config["SCREENSHOT_LOCATE_WAIT"])
            element = wait.until(
                expected_conditions.presence_of_element_located(
                    (By.CLASS_NAME, self.element)
                )
            )

            # Wait for chart containers to render
            wait.until(
                expected_conditions.visibility_of_all_elements_located(
                    (By.CLASS_NAME, "chart-container")
                )
            )

            # Wait for loading to complete
            WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Wait for animations
            animation_wait = current_app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"]
            logger.debug("Wait %s seconds for chart animation", animation_wait)
            time.sleep(animation_wait)

            # Handle unexpected errors if configured
            if current_app.config.get("SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"):
                from superset.utils.webdriver import WebDriverSelenium

                unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver)
                if unexpected_errors:
                    logger.warning(
                        "%s errors found in screenshot. ", len(unexpected_errors)
                    )

            # Take screenshot
            logger.debug("Taking PNG screenshot as user %s", user.username)
            return element.screenshot_as_png

        except TimeoutException:
            logger.exception("Timeout taking chart screenshot for URL: %s", self.url)
            raise
        except WebDriverException:
            logger.exception("WebDriver error taking screenshot for URL: %s", self.url)
            raise


class PooledExploreScreenshot(PooledBaseScreenshot):
    """
    Pooled version of explore screenshot with UI hiding functionality.

    This class provides the same clean chart-only screenshots as the previous
    implementation but with improved performance through WebDriver pooling.
    """

    thumbnail_type: str = "explore"
    element: str = "chart-container"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        super().__init__(url, digest)
        self.window_size = window_size or (1600, 1200)
        self.thumb_size = thumb_size or (800, 600)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """
        Take screenshot of explore page with UI elements hidden for clean chart display.

        This implementation:
        1. Waits for chart to load
        2. Hides navigation, panels, and headers using JavaScript
        3. Takes screenshot of just the chart area
        4. Includes fallback strategies if chart container not found
        """
        try:
            # Give explore page time to initialize
            logger.debug("Waiting for explore page to initialize")
            time.sleep(3)

            # Wait for chart container to load and be visible
            wait = WebDriverWait(driver, 45)
            chart_element = wait.until(
                expected_conditions.visibility_of_element_located(
                    (By.CLASS_NAME, "chart-container")
                )
            )

            # Wait for any loading indicators to disappear
            WebDriverWait(driver, 30).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Hide overlapping UI elements using JavaScript
            self._hide_ui_elements(driver)

            # Wait for UI hiding animations to complete
            time.sleep(2)

            # Take screenshot of just the chart area
            img = chart_element.screenshot_as_png
            logger.info("Successfully captured chart-container screenshot")
            return img

        except TimeoutException as e:
            logger.warning("Chart container not found, trying fallbacks: %s", e)
            return self._fallback_screenshot(driver)
        except Exception as e:
            logger.error("Error taking explore screenshot: %s", e)
            raise

    def _hide_ui_elements(self, driver: Any) -> None:
        """Hide UI elements to show only the chart.

        Executed as a best-effort JavaScript snippet in the page; selector
        misses are silent (each element is null-checked) and script errors
        are logged, not raised.
        """
        hide_script = """
        // Hide the specific left panels by their exact class names
        var dataSourcePanel = document.querySelector(
            '.explore-column.data-source-selection'
        );
        if (dataSourcePanel) {
            dataSourcePanel.style.display = 'none';
            console.log('Hidden data source panel');
        }

        var controlsPanel = document.querySelector(
            '.col-sm-3.explore-column.controls-column'
        );
        if (controlsPanel) {
            controlsPanel.style.display = 'none';
            console.log('Hidden controls panel');
        }

        // Also try alternate selectors for the controls
        var controlsAlt = document.querySelector('.explore-column.controls-column');
        if (controlsAlt) {
            controlsAlt.style.display = 'none';
            console.log('Hidden controls panel (alt)');
        }

        // Hide the main navigation header
        var mainHeader = document.querySelector('header.top#main-menu');
        if (mainHeader) {
            mainHeader.style.display = 'none';
            console.log('Hidden main navigation header');
        }

        // Hide the chart header with actions (title input and save button)
        var chartHeader = document.querySelector('.header-with-actions');
        if (chartHeader) {
            chartHeader.style.display = 'none';
            console.log('Hidden chart header with actions');
        }

        // Make sure the main chart area expands to use available space
        var mainContent = document.querySelector('.main-explore-content');
        if (mainContent) {
            mainContent.style.width = '100%';
            // Change from col-sm-7 to col-sm-12
            mainContent.className = 'main-explore-content col-sm-12';
        }

        // Ensure chart container is visible
        var chartContainer = document.querySelector('.chart-container');
        if (chartContainer) {
            chartContainer.style.position = 'relative';
            chartContainer.style.zIndex = '999';
            chartContainer.style.width = '100%';
        }
        """

        try:
            driver.execute_script(hide_script)
            logger.debug("Successfully executed UI hiding script")
        except WebDriverException as e:
            logger.warning("Failed to execute UI hiding script: %s", e)

    def _fallback_screenshot(self, driver: Any) -> bytes | None:
        """Fallback screenshot strategies if chart container not found.

        Tries, in order: the legacy ``slice_container`` element, any
        chart-like container, and finally a full-page screenshot.
        """
        try:
            # Fallback 1: Try slice container
            logger.debug("Trying slice_container fallback")
            slice_element = WebDriverWait(driver, 10).until(
                expected_conditions.visibility_of_element_located(
                    (By.CLASS_NAME, "slice_container")
                )
            )
            img = slice_element.screenshot_as_png
            logger.info("Successfully captured slice_container screenshot")
            return img

        except TimeoutException:
            try:
                # Fallback 2: Try any chart-related container
                logger.debug("Trying chart-related container fallback")
                chart_containers = driver.find_elements(
                    By.CSS_SELECTOR,
                    ".chart, .slice-container, .chart-content, "
                    "[data-test='chart-container']",
                )
                if chart_containers:
                    img = chart_containers[0].screenshot_as_png
                    logger.info(
                        "Successfully captured chart-related container screenshot"
                    )
                    return img

            except Exception as e2:
                logger.warning("Chart-related container fallback failed: %s", e2)

            # Fallback 3: Full page screenshot
            logger.warning("Using full page screenshot as final fallback")
            img = driver.get_screenshot_as_png()
            logger.info("Using full page screenshot as fallback")
            return img


class PooledDashboardScreenshot(PooledBaseScreenshot):
    """Pooled version of dashboard screenshot generation"""

    thumbnail_type: str = "dashboard"
    element: str = "standalone"

    def __init__(
        self,
        url: str,
        digest: str | None,
        window_size: WindowSize | None = None,
        thumb_size: WindowSize | None = None,
    ):
        from superset.utils.urls import modify_url_query
        from superset.utils.webdriver import DashboardStandaloneMode

        # Dashboard screenshots should always capture in standalone
        url = modify_url_query(
            url,
            standalone=DashboardStandaloneMode.REPORT.value,
        )
        super().__init__(url, digest)
        self.window_size = window_size or (1600, 1200)
        self.thumb_size = thumb_size or (800, 600)

    def _take_screenshot(self, driver: Any, user: User) -> bytes | None:
        """Take screenshot of dashboard with standard Superset dashboard handling"""
        try:
            # Wait for page to load
            selenium_headstart = current_app.config["SCREENSHOT_SELENIUM_HEADSTART"]
            time.sleep(selenium_headstart)

            # Wait for dashboard element
            wait = WebDriverWait(driver, current_app.config["SCREENSHOT_LOCATE_WAIT"])
            element = wait.until(
                expected_conditions.presence_of_element_located(
                    (By.CLASS_NAME, self.element)
                )
            )

            # Wait for chart containers to render
            try:
                wait.until(
                    expected_conditions.visibility_of_all_elements_located(
                        (By.CLASS_NAME, "chart-container")
                    )
                )
            except TimeoutException:
                # Fallback for empty dashboards: a zero-second wait just
                # checks whether the grid container is already visible.
                try:
                    WebDriverWait(driver, 0).until(
                        expected_conditions.visibility_of_all_elements_located(
                            (By.CLASS_NAME, "grid-container")
                        )
                    )
                except TimeoutException:
                    logger.exception("Dashboard failed to load at URL: %s", self.url)
                    raise

            # Wait for loading to complete
            WebDriverWait(driver, current_app.config["SCREENSHOT_LOAD_WAIT"]).until_not(
                expected_conditions.presence_of_all_elements_located(
                    (By.CLASS_NAME, "loading")
                )
            )

            # Wait for animations
            animation_wait = current_app.config["SCREENSHOT_SELENIUM_ANIMATION_WAIT"]
            time.sleep(animation_wait)

            # Handle unexpected errors if configured
            if current_app.config.get("SCREENSHOT_REPLACE_UNEXPECTED_ERRORS"):
                from superset.utils.webdriver import WebDriverSelenium

                unexpected_errors = WebDriverSelenium.find_unexpected_errors(driver)
                if unexpected_errors:
                    logger.warning(
                        "%s errors found in dashboard ", len(unexpected_errors)
                    )

            # Take screenshot
            logger.debug("Taking PNG dashboard screenshot as user %s", user.username)
            return element.screenshot_as_png

        except TimeoutException:
            logger.exception(
                "Timeout taking dashboard screenshot for URL: %s", self.url
            )
            raise
        except WebDriverException:
            logger.exception(
                "WebDriver error taking dashboard screenshot for URL: %s", self.url
            )
            raise
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
WebDriver pool configuration defaults for Superset MCP service
"""

from typing import Any, Dict

# Default WebDriver pool configuration
DEFAULT_WEBDRIVER_POOL_CONFIG = {
    # Maximum number of WebDriver instances to keep in the pool
    "MAX_POOL_SIZE": 5,
    # Maximum age of a WebDriver instance (in seconds)
    # After this time, the driver will be destroyed and recreated
    "MAX_AGE_SECONDS": 3600,  # 1 hour
    # Maximum number of times a WebDriver can be reused
    # After this many uses, the driver will be destroyed and recreated
    "MAX_USAGE_COUNT": 50,
    # How long a WebDriver can sit idle before being destroyed (in seconds)
    "IDLE_TIMEOUT_SECONDS": 300,  # 5 minutes
    # How often to perform health checks on WebDriver instances (in seconds)
    "HEALTH_CHECK_INTERVAL": 60,  # 1 minute
}


def configure_webdriver_pool(app_config: Dict[str, Any]) -> None:
    """
    Ensure the Superset app config carries a WebDriver pool configuration.

    If no ``WEBDRIVER_POOL`` entry exists, one is created from the
    defaults; otherwise any missing keys are filled in from the defaults
    while user-provided values are left untouched.

    Args:
        app_config: The Superset application configuration dictionary
    """
    if "WEBDRIVER_POOL" not in app_config:
        app_config["WEBDRIVER_POOL"] = DEFAULT_WEBDRIVER_POOL_CONFIG.copy()
        return

    # Backfill only the keys the operator did not set explicitly.
    pool_config = app_config["WEBDRIVER_POOL"]
    for option, fallback in DEFAULT_WEBDRIVER_POOL_CONFIG.items():
        pool_config.setdefault(option, fallback)


def get_pool_stats_endpoint() -> Any:
    """
    Create a Flask endpoint to view WebDriver pool statistics.

    This function can be called to register a debugging endpoint
    that shows the current state of the WebDriver pool.

    Returns:
        Flask route function for pool statistics
    """

    def pool_stats() -> Any:
        # Imports are deferred so merely building the endpoint never
        # requires Flask or the pool module to be importable.
        from flask import jsonify

        try:
            from superset.mcp_service.screenshot.webdriver_pool import (
                get_webdriver_pool,
            )

            snapshot = get_webdriver_pool().get_stats()
            return jsonify({"webdriver_pool": snapshot, "status": "healthy"})
        except Exception as e:
            return jsonify({"error": str(e), "status": "error"}), 500

    return pool_stats


# Performance tuning recommendations based on use case
PERFORMANCE_CONFIGS = {
    "low_traffic": {
        "MAX_POOL_SIZE": 2,
        "MAX_AGE_SECONDS": 1800,  # 30 minutes
        "MAX_USAGE_COUNT": 25,
        "IDLE_TIMEOUT_SECONDS": 180,  # 3 minutes
        "HEALTH_CHECK_INTERVAL": 120,  # 2 minutes
    },
    "medium_traffic": {
        "MAX_POOL_SIZE": 5,
        "MAX_AGE_SECONDS": 3600,  # 1 hour
        "MAX_USAGE_COUNT": 50,
        "IDLE_TIMEOUT_SECONDS": 300,  # 5 minutes
        "HEALTH_CHECK_INTERVAL": 60,  # 1 minute
    },
    "high_traffic": {
        "MAX_POOL_SIZE": 10,
        "MAX_AGE_SECONDS": 7200,  # 2 hours
        "MAX_USAGE_COUNT": 100,
        "IDLE_TIMEOUT_SECONDS": 600,  # 10 minutes
        "HEALTH_CHECK_INTERVAL": 30,  # 30 seconds
    },
    "development": {
        "MAX_POOL_SIZE": 2,
        "MAX_AGE_SECONDS": 900,  # 15 minutes
        "MAX_USAGE_COUNT": 10,
        "IDLE_TIMEOUT_SECONDS": 120,  # 2 minutes
        "HEALTH_CHECK_INTERVAL": 30,  # 30 seconds
    },
}
def configure_for_environment(
    app_config: Dict[str, Any], environment: str = "medium_traffic"
) -> None:
    """
    Apply a named traffic-profile configuration to the WebDriver pool.

    Looks up ``environment`` in PERFORMANCE_CONFIGS and installs a copy
    of that profile as ``WEBDRIVER_POOL``; an unrecognised profile name
    falls back to the standard defaults via configure_webdriver_pool().

    Args:
        app_config: The Superset application configuration dictionary
        environment: Environment type (low_traffic, medium_traffic,
            high_traffic, development)
    """
    profile = PERFORMANCE_CONFIGS.get(environment)
    if profile is None:
        # Unknown environment name: fall back to default configuration.
        configure_webdriver_pool(app_config)
        return
    # Copy so later mutation of the app config cannot corrupt the shared
    # profile table.
    app_config["WEBDRIVER_POOL"] = dict(profile)
+ +""" +WebDriver connection pooling for improved screenshot performance +""" + +import logging +import signal +import threading +import time +from contextlib import contextmanager +from dataclasses import dataclass +from queue import Empty, Full, Queue +from typing import Any, Dict, Generator + +from flask import current_app +from selenium.common.exceptions import WebDriverException +from selenium.webdriver.remote.webdriver import WebDriver + +from superset.utils.webdriver import WebDriverSelenium, WindowSize + +logger = logging.getLogger(__name__) + + +class WebDriverCreationError(Exception): + """Exception raised when WebDriver creation times out""" + + pass + + +def _timeout_handler(signum: int, frame: Any) -> None: + """Signal handler for WebDriver creation timeout""" + raise WebDriverCreationError("WebDriver creation timed out") + + +@dataclass +class PooledWebDriver: + """Wrapper for pooled WebDriver instance with metadata""" + + driver: WebDriver + created_at: float + last_used: float + window_size: WindowSize + user_id: int | None = None + is_healthy: bool = True + usage_count: int = 0 + + +class WebDriverPool: + """ + Connection pool for WebDriver instances to improve screenshot performance. 
+ + Features: + - Reuses WebDriver instances across requests + - Automatic health checking and recovery + - TTL-based expiration to prevent memory leaks + - Thread-safe operations + - Per-user driver isolation for security + """ + + def __init__( + self, + max_pool_size: int = 5, + max_age_seconds: int = 3600, # 1 hour + max_usage_count: int = 50, # Recreate after 50 uses + idle_timeout_seconds: int = 300, # 5 minutes + health_check_interval: int = 60, # 1 minute + creation_timeout_seconds: int = 30, # SECURITY FIX: Timeout for driver creation + ): + self.max_pool_size = max_pool_size + self.max_age_seconds = max_age_seconds + self.max_usage_count = max_usage_count + self.idle_timeout_seconds = idle_timeout_seconds + self.health_check_interval = health_check_interval + self.creation_timeout_seconds = creation_timeout_seconds + + # Thread-safe pool management + self._pool: Queue[PooledWebDriver] = Queue(maxsize=max_pool_size) + self._active_drivers: Dict[int, PooledWebDriver] = {} + self._lock = threading.RLock() + self._last_health_check = time.time() + + # Pool statistics + self._stats = { + "created": 0, + "destroyed": 0, + "borrowed": 0, + "returned": 0, + "health_check_failures": 0, + "evictions": 0, + } + + def get_stats(self) -> Dict[str, Any]: + """Get pool statistics for monitoring""" + with self._lock: + return { + **self._stats, + "pool_size": self._pool.qsize(), + "active_count": len(self._active_drivers), + "max_pool_size": self.max_pool_size, + } + + def _create_driver( + self, window_size: WindowSize, user_id: int | None = None + ) -> PooledWebDriver: + """Create a new WebDriver instance with timeout protection""" + driver = None + old_handler = None + + try: + # SECURITY FIX: Set up timeout protection for driver creation + old_handler = signal.signal(signal.SIGALRM, _timeout_handler) + signal.alarm(self.creation_timeout_seconds) + + driver_type = current_app.config.get("WEBDRIVER_TYPE", "firefox") + selenium_driver = WebDriverSelenium(driver_type, 
window_size) + + # Create the actual WebDriver with timeout protection + driver = selenium_driver.create() + driver.set_window_size(*window_size) + + # Clear the alarm - creation successful + signal.alarm(0) + + pooled_driver = PooledWebDriver( + driver=driver, + created_at=time.time(), + last_used=time.time(), + window_size=window_size, + user_id=user_id, + is_healthy=True, + usage_count=0, + ) + + self._stats["created"] += 1 + logger.debug( + "Created new WebDriver instance for window size %s", window_size + ) + return pooled_driver + + except WebDriverCreationError: + logger.error( + "WebDriver creation timed out after %s seconds", + self.creation_timeout_seconds, + ) + if driver: + try: + driver.quit() + except Exception: + logger.debug("Failed to cleanup driver during timeout") + raise Exception("WebDriver creation timed out") from None + + except Exception as e: + logger.error("Failed to create WebDriver: %s", e) + if driver: + try: + driver.quit() + except Exception: + logger.debug("Failed to cleanup driver during error") + raise + + finally: + # Restore original signal handler and clear alarm + signal.alarm(0) + if old_handler is not None: + signal.signal(signal.SIGALRM, old_handler) + + def _is_driver_valid(self, pooled_driver: PooledWebDriver) -> bool: + """Check if a pooled driver is still valid for use""" + now = time.time() + + # Check age limit + if now - pooled_driver.created_at > self.max_age_seconds: + logger.debug("Driver expired due to age") + return False + + # Check usage count limit + if pooled_driver.usage_count >= self.max_usage_count: + logger.debug("Driver expired due to usage count") + return False + + # Check idle timeout + if now - pooled_driver.last_used > self.idle_timeout_seconds: + logger.debug("Driver expired due to idle timeout") + return False + + # Check if driver is healthy + if not pooled_driver.is_healthy: + logger.debug("Driver marked as unhealthy") + return False + + return True + + def _health_check_driver(self, 
pooled_driver: PooledWebDriver) -> bool: + """Perform health check on a WebDriver instance""" + try: + # Simple health check - try to get current URL + # This will fail if the driver is dead/hung + _ = pooled_driver.driver.current_url + pooled_driver.is_healthy = True + return True + except WebDriverException: + pooled_driver.is_healthy = False + self._stats["health_check_failures"] += 1 + logger.warning("WebDriver failed health check") + return False + except Exception as e: + pooled_driver.is_healthy = False + self._stats["health_check_failures"] += 1 + logger.warning("WebDriver health check error: %s", e) + return False + + def _destroy_driver(self, pooled_driver: PooledWebDriver) -> None: + """Safely destroy a WebDriver instance""" + try: + WebDriverSelenium.destroy(pooled_driver.driver) + self._stats["destroyed"] += 1 + logger.debug("Destroyed WebDriver instance") + except Exception as e: + logger.warning("Error destroying WebDriver: %s", e) + + def _cleanup_expired_drivers(self) -> None: + """Remove expired drivers from the pool""" + expired_drivers = [] + + # Check pool for expired drivers + while not self._pool.empty(): + try: + pooled_driver = self._pool.get_nowait() + if self._is_driver_valid(pooled_driver): + # Driver is still valid, put it back + self._pool.put_nowait(pooled_driver) + break + else: + # Driver is expired + expired_drivers.append(pooled_driver) + self._stats["evictions"] += 1 + except Empty: + break + except Full: + # Pool is full, stop checking + break + + # Destroy expired drivers + for pooled_driver in expired_drivers: + self._destroy_driver(pooled_driver) + + def _periodic_health_check(self) -> None: + """Perform periodic health checks if needed""" + now = time.time() + if now - self._last_health_check < self.health_check_interval: + return + + self._last_health_check = now + logger.debug("Performing periodic WebDriver pool health check") + + # Cleanup expired drivers + self._cleanup_expired_drivers() + + # Health check active drivers 
+ unhealthy_drivers = [] + for driver_id, pooled_driver in self._active_drivers.items(): + if not self._health_check_driver(pooled_driver): + unhealthy_drivers.append(driver_id) + + # Remove unhealthy active drivers + for driver_id in unhealthy_drivers: + pooled_driver = self._active_drivers.pop(driver_id) + if pooled_driver: + self._destroy_driver(pooled_driver) + + @contextmanager + def get_driver( + self, window_size: WindowSize, user_id: int | None = None + ) -> Generator[WebDriver, None, None]: + """ + Context manager to get a WebDriver from the pool. + + Args: + window_size: Required window size for the driver + user_id: Optional user ID for driver isolation + + Yields: + WebDriver instance ready for use + """ + pooled_driver = None + driver_id = None + + try: + with self._lock: + # Periodic maintenance + self._periodic_health_check() + + # Try to get a driver from the pool + while not self._pool.empty(): + try: + candidate = self._pool.get_nowait() + + # Check if driver is valid and matches requirements + if ( + self._is_driver_valid(candidate) + and candidate.window_size == window_size + ): + # Update user_id for the reused driver + candidate.user_id = user_id + pooled_driver = candidate + break + else: + # Driver is invalid, destroy it + self._destroy_driver(candidate) + self._stats["evictions"] += 1 + except Empty: + break + + # If no suitable driver found, create a new one + if pooled_driver is None: + pooled_driver = self._create_driver(window_size, user_id) + + # Mark driver as in use + driver_id = id(pooled_driver.driver) + pooled_driver.last_used = time.time() + pooled_driver.usage_count += 1 + self._active_drivers[driver_id] = pooled_driver + self._stats["borrowed"] += 1 + + # Yield the driver for use + yield pooled_driver.driver + + except Exception as e: + # Mark driver as unhealthy if an error occurred + if pooled_driver: + pooled_driver.is_healthy = False + logger.error("Error using pooled WebDriver: %s", e) + raise + + finally: + # Return 
driver to pool or destroy if unhealthy + if pooled_driver and driver_id: + with self._lock: + self._active_drivers.pop(driver_id, None) + + if pooled_driver.is_healthy and self._is_driver_valid( + pooled_driver + ): + # Try to return to pool + try: + self._pool.put_nowait(pooled_driver) + self._stats["returned"] += 1 + logger.debug("Returned WebDriver to pool") + except Full: + # Pool is full, destroy the driver + self._destroy_driver(pooled_driver) + logger.debug("Pool full, destroyed WebDriver") + else: + # Driver is unhealthy or expired, destroy it + self._destroy_driver(pooled_driver) + logger.debug("Destroyed unhealthy/expired WebDriver") + + def shutdown(self) -> None: + """Shutdown the pool and destroy all drivers""" + with self._lock: + logger.info("Shutting down WebDriver pool") + + # Destroy all active drivers + for pooled_driver in self._active_drivers.values(): + self._destroy_driver(pooled_driver) + self._active_drivers.clear() + + # Destroy all pooled drivers + while not self._pool.empty(): + try: + pooled_driver = self._pool.get_nowait() + self._destroy_driver(pooled_driver) + except Empty: + break + + logger.info( + "WebDriver pool shutdown complete. 
Final stats: %s", self.get_stats() + ) + + +# Global pool instance +_global_pool: WebDriverPool | None = None +_pool_lock = threading.Lock() + + +def get_webdriver_pool() -> WebDriverPool: + """Get or create the global WebDriver pool""" + global _global_pool + + if _global_pool is None: + with _pool_lock: + if _global_pool is None: + # Get pool configuration from Flask config + config = current_app.config + pool_config = config.get("WEBDRIVER_POOL", {}) + + _global_pool = WebDriverPool( + max_pool_size=pool_config.get("MAX_POOL_SIZE", 5), + max_age_seconds=pool_config.get("MAX_AGE_SECONDS", 3600), + max_usage_count=pool_config.get("MAX_USAGE_COUNT", 50), + idle_timeout_seconds=pool_config.get("IDLE_TIMEOUT_SECONDS", 300), + health_check_interval=pool_config.get("HEALTH_CHECK_INTERVAL", 60), + ) + logger.info("Initialized global WebDriver pool") + + return _global_pool + + +def shutdown_webdriver_pool() -> None: + """Shutdown the global WebDriver pool""" + global _global_pool + + if _global_pool is not None: + with _pool_lock: + if _global_pool is not None: + _global_pool.shutdown() + _global_pool = None diff --git a/superset/mcp_service/server.py b/superset/mcp_service/server.py index 2fcb39722d3..e723ef611a9 100644 --- a/superset/mcp_service/server.py +++ b/superset/mcp_service/server.py @@ -74,6 +74,8 @@ def run_server( configure_logging(debug) + # DO NOT IMPORT TOOLS HERE!! IMPORT THEM IN app.py!!!!! + if use_factory_config: # Use factory configuration for customization logging.info("Creating MCP app from factory configuration...") diff --git a/superset/mcp_service/sql_lab/__init__.py b/superset/mcp_service/sql_lab/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/superset/mcp_service/sql_lab/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/mcp_service/sql_lab/execute_sql_core.py b/superset/mcp_service/sql_lab/execute_sql_core.py new file mode 100644 index 00000000000..263e0123385 --- /dev/null +++ b/superset/mcp_service/sql_lab/execute_sql_core.py @@ -0,0 +1,221 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Generic SQL execution core for MCP service. 
+""" + +import logging +from typing import Any + +from superset.mcp_service.mcp_core import BaseCore +from superset.mcp_service.sql_lab.schemas import ( + ExecuteSqlRequest, + ExecuteSqlResponse, +) + + +class ExecuteSqlCore(BaseCore): + """ + Generic tool for executing SQL queries with security validation. + + This tool provides a high-level interface for SQL execution that can be used + by different MCP tools or other components. It handles: + - Database access validation + - SQL query validation (DML permissions, disallowed functions) + - Parameter substitution + - Query execution with timeout + - Result formatting + + The tool can work in two modes: + 1. Simple mode: Direct SQL execution using sql_lab_utils (default) + 2. Command mode: Using ExecuteSqlCommand for full SQL Lab features + """ + + def __init__( + self, + use_command_mode: bool = False, + logger: logging.Logger | None = None, + ) -> None: + super().__init__(logger) + self.use_command_mode = use_command_mode + + def run_tool(self, request: ExecuteSqlRequest) -> ExecuteSqlResponse: + """ + Execute SQL query and return results. 
+ + Args: + request: ExecuteSqlRequest with database_id, sql, and optional parameters + + Returns: + ExecuteSqlResponse with success status, results, or error information + """ + try: + # Import inside method to avoid initialization issues + from superset.mcp_service.sql_lab.sql_lab_utils import check_database_access + + # Check database access + database = check_database_access(request.database_id) + + if self.use_command_mode: + # Use full SQL Lab command for complex queries + return self._execute_with_command(request, database) + else: + # Use simplified execution for basic queries + return self._execute_simple(request, database) + + except Exception as e: + # Handle errors and return error response with proper error types + self._log_error(e, "executing SQL") + return self._handle_execution_error(e) + + def _execute_simple( + self, + request: ExecuteSqlRequest, + database: Any, + ) -> ExecuteSqlResponse: + """Execute SQL using simplified sql_lab_utils.""" + # Import inside method to avoid initialization issues + from superset.mcp_service.sql_lab.sql_lab_utils import execute_sql_query + + results = execute_sql_query( + database=database, + sql=request.sql, + schema=request.schema_name, + limit=request.limit, + timeout=request.timeout, + parameters=request.parameters, + ) + + return ExecuteSqlResponse( + success=True, + rows=results.get("rows"), + columns=results.get("columns"), + row_count=results.get("row_count"), + affected_rows=results.get("affected_rows"), + query_id=None, # Not available in simple mode + execution_time=results.get("execution_time"), + error=None, + error_type=None, + ) + + def _execute_with_command( + self, + request: ExecuteSqlRequest, + database: Any, + ) -> ExecuteSqlResponse: + """Execute SQL using full SQL Lab command (not implemented yet).""" + # This would use ExecuteSqlCommand for full SQL Lab features + # Including query caching, async execution, complex parsing, etc. 
+ # For now, we'll fall back to simple execution + self._log_info("Command mode not fully implemented, using simple mode") + return self._execute_simple(request, database) + + # Future implementation would look like: + # context = SqlJsonExecutionContext( + # database_id=request.database_id, + # sql=request.sql, + # schema=request.schema_name, + # limit=request.limit, + # # ... other context fields + # ) + # + # command = ExecuteSqlCommand( + # execution_context=context, + # query_dao=QueryDAO(), + # database_dao=DatabaseDAO(), + # # ... other dependencies + # ) + # + # result = command.run() + # return self._format_command_result(result) + + def _handle_execution_error(self, e: Exception) -> ExecuteSqlResponse: + """Map exceptions to error responses.""" + error_type = self._get_error_type(e) + return ExecuteSqlResponse( + success=False, + error=str(e), + error_type=error_type, + rows=None, + columns=None, + row_count=None, + affected_rows=None, + query_id=None, + execution_time=None, + ) + + def _get_error_type(self, e: Exception) -> str: + """Determine error type from exception.""" + # Import inside method to avoid initialization issues + from superset.exceptions import ( + SupersetDisallowedSQLFunctionException, + SupersetDMLNotAllowedException, + SupersetErrorException, + SupersetSecurityException, + SupersetTimeoutException, + ) + + if isinstance(e, SupersetSecurityException): + return "SECURITY_ERROR" + elif isinstance(e, SupersetTimeoutException): + return "TIMEOUT" + elif isinstance(e, SupersetDMLNotAllowedException): + return "DML_NOT_ALLOWED" + elif isinstance(e, SupersetDisallowedSQLFunctionException): + return "DISALLOWED_FUNCTION" + elif isinstance(e, SupersetErrorException): + return self._extract_superset_error_type(e) + else: + return "EXECUTION_ERROR" + + def _extract_superset_error_type(self, e: Exception) -> str: + """Extract error type from SupersetErrorException.""" + if hasattr(e, "error") and hasattr(e.error, "error_type"): + error_type_name 
= e.error.error_type.name + # Map common error type patterns + if "INVALID_PAYLOAD" in error_type_name: + return "INVALID_PAYLOAD_FORMAT_ERROR" + elif "DATABASE_NOT_FOUND" in error_type_name: + return "DATABASE_NOT_FOUND_ERROR" + elif "SECURITY" in error_type_name: + return "SECURITY_ERROR" + elif "TIMEOUT" in error_type_name: + return "TIMEOUT" + elif "DML_NOT_ALLOWED" in error_type_name: + return "DML_NOT_ALLOWED" + else: + return error_type_name + return "EXECUTION_ERROR" + + def _format_command_result( + self, command_result: dict[str, Any] + ) -> ExecuteSqlResponse: + """Format ExecuteSqlCommand result into ExecuteSqlResponse.""" + # This would extract relevant fields from command result + # Placeholder implementation for future use + return ExecuteSqlResponse( + success=command_result.get("success", False), + rows=command_result.get("data"), + columns=command_result.get("columns"), + row_count=command_result.get("row_count"), + affected_rows=command_result.get("affected_rows"), + query_id=command_result.get("query_id"), + execution_time=command_result.get("execution_time"), + error=command_result.get("error"), + error_type=command_result.get("error_type"), + ) diff --git a/superset/mcp_service/sql_lab/schemas.py b/superset/mcp_service/sql_lab/schemas.py new file mode 100644 index 00000000000..fcfe7cb62ba --- /dev/null +++ b/superset/mcp_service/sql_lab/schemas.py @@ -0,0 +1,109 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Schemas for SQL Lab MCP tools.""" + +from typing import Any + +from pydantic import BaseModel, Field, field_validator + + +class ExecuteSqlRequest(BaseModel): + """Request schema for executing SQL queries.""" + + database_id: int = Field( + ..., description="Database connection ID to execute query against" + ) + sql: str = Field(..., description="SQL query to execute") + schema_name: str | None = Field( + None, description="Schema to use for query execution", alias="schema" + ) + limit: int = Field( + default=1000, + description="Maximum number of rows to return", + ge=1, + le=10000, + ) + timeout: int = Field( + default=30, description="Query timeout in seconds", ge=1, le=300 + ) + parameters: dict[str, Any] | None = Field( + None, description="Parameters for query substitution" + ) + + @field_validator("sql") + @classmethod + def sql_not_empty(cls, v: str) -> str: + if not v or not v.strip(): + raise ValueError("SQL query cannot be empty") + return v.strip() + + +class ColumnInfo(BaseModel): + """Column metadata information.""" + + name: str = Field(..., description="Column name") + type: str = Field(..., description="Column data type") + is_nullable: bool | None = Field(None, description="Whether column allows NULL") + + +class ExecuteSqlResponse(BaseModel): + """Response schema for SQL execution results.""" + + success: bool = Field(..., description="Whether query executed successfully") + rows: Any | None = Field( + None, description="Query result rows as list of dictionaries" + ) + columns: list[ColumnInfo] | None = Field( + None, 
description="Column metadata information" + ) + row_count: int | None = Field(None, description="Number of rows returned") + affected_rows: int | None = Field( + None, description="Number of rows affected (for DML queries)" + ) + query_id: str | None = Field(None, description="Query tracking ID") + execution_time: float | None = Field( + None, description="Query execution time in seconds" + ) + error: str | None = Field(None, description="Error message if query failed") + error_type: str | None = Field(None, description="Type of error if failed") + + +class OpenSqlLabRequest(BaseModel): + """Request schema for opening SQL Lab with context.""" + + database_connection_id: int = Field( + ..., description="Database connection ID to use in SQL Lab" + ) + schema_name: str | None = Field( + None, description="Default schema to select in SQL Lab", alias="schema" + ) + dataset_in_context: str | None = Field( + None, description="Dataset name/table to provide as context" + ) + sql: str | None = Field(None, description="SQL query to pre-populate in the editor") + title: str | None = Field(None, description="Title for the SQL Lab tab/query") + + +class SqlLabResponse(BaseModel): + """Response schema for SQL Lab URL generation.""" + + url: str = Field(..., description="URL to open SQL Lab with context") + database_id: int = Field(..., description="Database ID used") + schema_name: str | None = Field(None, description="Schema selected", alias="schema") + title: str | None = Field(None, description="Query title") + error: str | None = Field(None, description="Error message if failed") diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py b/superset/mcp_service/sql_lab/sql_lab_utils.py new file mode 100644 index 00000000000..6844e26a497 --- /dev/null +++ b/superset/mcp_service/sql_lab/sql_lab_utils.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Utility functions for SQL Lab MCP tools. + +This module contains helper functions for SQL execution, validation, +and database access that are shared across SQL Lab tools. +""" + +import logging +from typing import Any + +logger = logging.getLogger(__name__) + + +def check_database_access(database_id: int) -> Any: + """Check if user has access to the database.""" + # Import inside function to avoid initialization issues + from superset import db, security_manager + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetErrorException, SupersetSecurityException + from superset.models.core import Database + + # Use session query to ensure relationships are loaded + database = db.session.query(Database).filter_by(id=database_id).first() + + if not database: + raise SupersetErrorException( + SupersetError( + message=f"Database with ID {database_id} not found", + error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Check database access permissions + if not security_manager.can_access_database(database): + raise SupersetSecurityException( + SupersetError( + message=f"Access denied to database {database.database_name}", + error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR, + 
level=ErrorLevel.ERROR, + ) + ) + + return database + + +def validate_sql_query(sql: str, database: Any) -> None: + """Validate SQL query for security and syntax.""" + # Import inside function to avoid initialization issues + from flask import current_app as app + + from superset.exceptions import ( + SupersetDisallowedSQLFunctionException, + SupersetDMLNotAllowedException, + ) + + # Simplified validation without complex parsing + sql_upper = sql.upper().strip() + + # Check for DML operations if not allowed + dml_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE"] + if any(sql_upper.startswith(keyword) for keyword in dml_keywords): + if not database.allow_dml: + raise SupersetDMLNotAllowedException() + + # Check for disallowed functions from config + disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get( + "sqlite", + set(), # Default to sqlite for now + ) + if disallowed_functions: + sql_lower = sql.lower() + for func in disallowed_functions: + if f"{func.lower()}(" in sql_lower: + raise SupersetDisallowedSQLFunctionException(disallowed_functions) + + +def execute_sql_query( + database: Any, + sql: str, + schema: str | None, + limit: int, + timeout: int, + parameters: dict[str, Any] | None, +) -> dict[str, Any]: + """Execute SQL query and return results.""" + # Import inside function to avoid initialization issues + from superset.utils.dates import now_as_float + + start_time = now_as_float() + + # Apply parameters and validate + sql = _apply_parameters(sql, parameters) + validate_sql_query(sql, database) + + # Apply limit for SELECT queries + rendered_sql = _apply_limit(sql, limit) + + # Execute and get results + results = _execute_query(database, rendered_sql, schema, limit) + + # Calculate execution time + end_time = now_as_float() + results["execution_time"] = end_time - start_time + + return results + + +def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str: + """Apply parameters to SQL 
query.""" + # Import inside function to avoid initialization issues + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.exceptions import SupersetErrorException + + if parameters: + try: + return sql.format(**parameters) + except KeyError as e: + raise SupersetErrorException( + SupersetError( + message=f"Missing parameter: {e}", + error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR, + level=ErrorLevel.ERROR, + ) + ) from e + else: + # Check if SQL contains placeholders when no parameters provided + import re + + placeholders = re.findall(r"{(\w+)}", sql) + if placeholders: + raise SupersetErrorException( + SupersetError( + message=f"Missing parameter: {placeholders[0]}", + error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR, + level=ErrorLevel.ERROR, + ) + ) + return sql + + +def _apply_limit(sql: str, limit: int) -> str: + """Apply limit to SELECT queries if not already present.""" + sql_lower = sql.lower().strip() + if sql_lower.startswith("select") and "limit" not in sql_lower: + return f"{sql.rstrip().rstrip(';')} LIMIT {limit}" + return sql + + +def _execute_query( + database: Any, + sql: str, + schema: str | None, + limit: int, +) -> dict[str, Any]: + """Execute the query and process results.""" + # Import inside function to avoid initialization issues + from superset.utils.core import QuerySource + + results = { + "rows": [], + "columns": [], + "row_count": 0, + "affected_rows": None, + "execution_time": 0.0, + } + + try: + # Execute query with timeout + with database.get_raw_connection( + catalog=None, + schema=schema, + source=QuerySource.SQL_LAB, + ) as conn: + cursor = conn.cursor() + cursor.execute(sql) + + # Process results based on query type + if _is_select_query(sql): + _process_select_results(cursor, results, limit) + else: + _process_dml_results(cursor, conn, results) + + except Exception as e: + logger.error("Error executing SQL: %s", e) + raise + + return results + + +def _is_select_query(sql: 
str) -> bool: + """Check if SQL is a SELECT query.""" + return sql.lower().strip().startswith("select") + + +def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None: + """Process SELECT query results.""" + # Fetch results + data = cursor.fetchmany(limit) + + # Get column metadata + column_info = [] + if cursor.description: + for col in cursor.description: + column_info.append( + { + "name": col[0], + "type": str(col[1]) if col[1] else "unknown", + "is_nullable": col[6] if len(col) > 6 else None, + } + ) + + # Set column info regardless of whether there's data + if column_info: + results["columns"] = column_info + + # Convert rows to dictionaries + column_names = [col["name"] for col in column_info] + results["rows"] = [dict(zip(column_names, row, strict=False)) for row in data] + results["row_count"] = len(data) + + +def _process_dml_results(cursor: Any, conn: Any, results: dict[str, Any]) -> None: + """Process DML query results.""" + results["affected_rows"] = cursor.rowcount + conn.commit() # pylint: disable=consider-using-transaction diff --git a/superset/mcp_service/sql_lab/tool/__init__.py b/superset/mcp_service/sql_lab/tool/__init__.py new file mode 100644 index 00000000000..0fc7a0dd89f --- /dev/null +++ b/superset/mcp_service/sql_lab/tool/__init__.py @@ -0,0 +1,30 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +SQL Lab MCP Tools +""" + +from superset.mcp_service.sql_lab.tool.execute_sql import execute_sql +from superset.mcp_service.sql_lab.tool.open_sql_lab_with_context import ( + open_sql_lab_with_context, +) + +__all__ = [ + "execute_sql", + "open_sql_lab_with_context", +] diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py b/superset/mcp_service/sql_lab/tool/execute_sql.py new file mode 100644 index 00000000000..4fb8914d60e --- /dev/null +++ b/superset/mcp_service/sql_lab/tool/execute_sql.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Execute SQL MCP Tool + +Tool for executing SQL queries against databases with security validation +and timeout protection. 
+""" + +import logging + +from fastmcp import Context + +from superset.mcp_service.app import mcp +from superset.mcp_service.auth import mcp_auth_hook +from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore +from superset.mcp_service.sql_lab.schemas import ( + ExecuteSqlRequest, + ExecuteSqlResponse, +) + +logger = logging.getLogger(__name__) + + +@mcp.tool +@mcp_auth_hook +async def execute_sql(request: ExecuteSqlRequest, ctx: Context) -> ExecuteSqlResponse: + """Execute SQL query against database. + + Returns query results with security validation and timeout protection. + """ + await ctx.info( + "Starting SQL execution: database_id=%s, timeout=%s, limit=%s, schema=%s" + % (request.database_id, request.timeout, request.limit, request.schema_name) + ) + + # Log SQL query details (truncated for security) + sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else request.sql + await ctx.debug( + "SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s" + % ( + sql_preview, + len(request.sql), + bool(request.parameters), + ) + ) + + logger.info("Executing SQL query on database ID: %s", request.database_id) + + try: + # Use the ExecuteSqlCore to handle all the logic + sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger) + result = sql_tool.run_tool(request) + + # Log successful execution + if hasattr(result, "data") and result.data: + row_count = len(result.data) if isinstance(result.data, list) else 1 + await ctx.info( + "SQL execution completed successfully: rows_returned=%s, " + "query_duration_ms=%s" + % ( + row_count, + getattr(result, "query_duration_ms", None), + ) + ) + else: + await ctx.info("SQL execution completed: status=no_data_returned") + + return result + + except Exception as e: + await ctx.error( + "SQL execution failed: error=%s, database_id=%s" + % ( + str(e), + request.database_id, + ) + ) + raise diff --git a/superset/mcp_service/sql_lab/tool/open_sql_lab_with_context.py 
"""
Open SQL Lab with Context MCP Tool

Tool for generating SQL Lab URLs with pre-populated query and context.
"""

import logging
from urllib.parse import urlencode

from fastmcp import Context

from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook
from superset.mcp_service.sql_lab.schemas import (
    OpenSqlLabRequest,
    SqlLabResponse,
)

logger = logging.getLogger(__name__)


@mcp.tool
@mcp_auth_hook
def open_sql_lab_with_context(
    request: OpenSqlLabRequest, ctx: Context
) -> SqlLabResponse:
    """Generate SQL Lab URL with pre-populated query and context.

    Returns URL for direct navigation.  Failures are reported in-band via the
    ``error`` field (with an empty ``url``) rather than raised, so MCP
    clients always receive a structured response.
    """
    try:
        # Local import avoids pulling DAO machinery in at module load time.
        from superset.daos.database import DatabaseDAO

        # Validate database exists and is accessible
        database = DatabaseDAO.find_by_id(request.database_connection_id)
        if not database:
            return SqlLabResponse(
                url="",
                database_id=request.database_connection_id,
                schema_name=request.schema_name,
                title=request.title,
                error=f"Database with ID {request.database_connection_id} not found",
            )

        # Build query parameters for SQL Lab URL
        params = {
            "dbid": str(request.database_connection_id),
        }

        if request.schema_name:
            params["schema"] = request.schema_name

        if request.sql:
            params["sql"] = request.sql

        if request.title:
            params["title"] = request.title

        if request.dataset_in_context:
            # Add dataset context as a comment in the SQL if no SQL provided.
            # NOTE(review): schema/table names are interpolated unquoted into
            # the starter SELECT - assumes valid identifiers; confirm they are
            # validated upstream.
            if not request.sql:
                context_comment = (
                    f"-- Context: Working with dataset '{request.dataset_in_context}'\n"
                    f"-- Database: {database.database_name}\n"
                )
                if request.schema_name:
                    context_comment += f"-- Schema: {request.schema_name}\n"
                    table_reference = (
                        f"{request.schema_name}.{request.dataset_in_context}"
                    )
                else:
                    table_reference = request.dataset_in_context

                context_comment += f"\nSELECT * FROM {table_reference} LIMIT 100;"
                params["sql"] = context_comment

        # Construct SQL Lab URL; urlencode handles escaping of the embedded
        # SQL.  The URL is relative - presumably resolved against the Superset
        # base URL by the client.
        query_string = urlencode(params)
        url = f"/sqllab?{query_string}"

        logger.info(
            "Generated SQL Lab URL for database %s", request.database_connection_id
        )

        return SqlLabResponse(
            url=url,
            database_id=request.database_connection_id,
            schema_name=request.schema_name,
            title=request.title,
            error=None,
        )

    except Exception as e:
        # Best-effort tool: convert unexpected failures into an error payload.
        logger.error("Error generating SQL Lab URL: %s", e)
        return SqlLabResponse(
            url="",
            database_id=request.database_connection_id,
            schema_name=request.schema_name,
            title=request.title,
            error=f"Failed to generate SQL Lab URL: {str(e)}",
        )
b/superset/mcp_service/system/prompts/__init__.py new file mode 100644 index 00000000000..0b78f1fd773 --- /dev/null +++ b/superset/mcp_service/system/prompts/__init__.py @@ -0,0 +1,21 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""System prompts for Superset MCP service""" + +# Import to register prompts when module is imported +from . import quickstart # noqa: F401 diff --git a/superset/mcp_service/system/prompts/quickstart.py b/superset/mcp_service/system/prompts/quickstart.py new file mode 100644 index 00000000000..5eaccfec0a9 --- /dev/null +++ b/superset/mcp_service/system/prompts/quickstart.py @@ -0,0 +1,94 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
"""
System prompts for general Superset guidance
"""

import logging

from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook

logger = logging.getLogger(__name__)


@mcp.prompt("superset_quickstart")
@mcp_auth_hook
async def superset_quickstart_prompt(
    user_type: str = "analyst", focus_area: str = "general"
) -> str:
    """
    Guide new users through their first Superset experience.

    This prompt helps users:
    1. Understand what data is available
    2. Create their first visualization
    3. Build a simple dashboard
    4. Learn key Superset concepts

    Args:
        user_type: Type of user (analyst, executive, developer)
        focus_area: Area of interest (sales, marketing, operations, general)
    """
    # Build personalized prompt based on user type; unrecognized values fall
    # back to the "analyst" / "general" entries via dict.get below.
    intro_messages = {
        "analyst": "I see you're an analyst. Let's explore the data and build some "
        "detailed visualizations.",
        "executive": "Welcome! Let's create a high-level dashboard with key business "
        "metrics.",
        "developer": "Great to have a developer here! Let's explore both the UI and "
        "API capabilities.",
    }

    focus_examples = {
        "sales": "Since you're interested in sales, we'll focus on revenue, customer, "
        "and product metrics.",
        "marketing": "For marketing analytics, we'll look at campaigns, conversions, "
        "and customer acquisition.",
        "operations": "Let's explore operational efficiency, inventory, and process "
        "metrics.",
        "general": "We'll explore various datasets to find what's most relevant to "
        "you.",
    }

    intro = intro_messages.get(user_type, intro_messages["analyst"])
    focus = focus_examples.get(focus_area, focus_examples["general"])

    # NOTE(review): the wrapped lines inside the template below carry a
    # leading space into the rendered text, and the numbered-list bullets
    # appear to contain mojibake where emoji were likely intended - confirm
    # the source encoding before release.
    return f"""Welcome to Apache Superset! I'll guide you through creating your first
 dashboard.

{intro} {focus}

I'll help you through these steps:
1. ๐Ÿ“Š **Explore Available Data** - See what datasets you can work with
2. ๐Ÿ” **Understand Your Data** - Examine columns, metrics, and sample data
3. ๐Ÿ“ˆ **Create Visualizations** - Build charts that tell a story
4. ๐ŸŽฏ **Design a Dashboard** - Combine charts into an interactive dashboard
5. ๐Ÿš€ **Learn Advanced Features** - Discover filters, SQL Lab, and more

To get started, I'll use these tools:
- `get_superset_instance_info` - Overview of your Superset instance
- `list_datasets` - Find available datasets
- `get_dataset_info` - Explore dataset details
- `generate_chart` - Create visualizations
- `generate_dashboard` - Build your dashboard

Let me begin by checking what's available in your Superset instance. I'll first get
an overview, then show you the datasets filtered by your interest in {focus_area}.

Would you like me to start by showing you what data you can work with?"""
"""
System resources for providing Superset configuration and stats
"""

import logging

from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook

logger = logging.getLogger(__name__)


@mcp.resource("superset://instance/metadata")
@mcp_auth_hook
def get_instance_metadata_resource() -> str:
    """
    Provide comprehensive metadata about the Superset instance.

    This resource gives LLMs context about:
    - Available datasets and their popularity
    - Dashboard and chart statistics
    - Database connections
    - Popular queries and usage patterns
    - Available visualization types
    - Feature flags and configuration

    Returns a JSON string; on failure a minimal JSON payload with an "error"
    key and usage tips is returned instead of raising.
    """
    try:
        # Import the shared core and DAOs at runtime
        # Create a shared core instance for the resource
        from typing import Any, cast, Type

        from superset.daos.base import BaseDAO
        from superset.daos.chart import ChartDAO
        from superset.daos.dashboard import DashboardDAO
        from superset.daos.database import DatabaseDAO
        from superset.daos.dataset import DatasetDAO
        from superset.daos.tag import TagDAO
        from superset.daos.user import UserDAO
        from superset.mcp_service.mcp_core import InstanceInfoCore
        from superset.mcp_service.system.schemas import InstanceInfo
        from superset.mcp_service.system.tool.get_superset_instance_info import (
            calculate_dashboard_breakdown,
            calculate_database_breakdown,
            calculate_instance_summary,
            calculate_popular_content,
            calculate_recent_activity,
        )

        # A fresh core is configured per call; the calculators and time
        # windows mirror the get_superset_instance_info tool so the resource
        # and the tool stay consistent.
        instance_info_core = InstanceInfoCore(
            dao_classes={
                "dashboards": cast(Type[BaseDAO[Any]], DashboardDAO),
                "charts": cast(Type[BaseDAO[Any]], ChartDAO),
                "datasets": cast(Type[BaseDAO[Any]], DatasetDAO),
                "databases": cast(Type[BaseDAO[Any]], DatabaseDAO),
                "users": cast(Type[BaseDAO[Any]], UserDAO),
                "tags": cast(Type[BaseDAO[Any]], TagDAO),
            },
            output_schema=InstanceInfo,
            metric_calculators={
                "instance_summary": calculate_instance_summary,
                "recent_activity": calculate_recent_activity,
                "dashboard_breakdown": calculate_dashboard_breakdown,
                "database_breakdown": calculate_database_breakdown,
                "popular_content": calculate_popular_content,
            },
            time_windows={
                "recent": 7,
                "monthly": 30,
                "quarterly": 90,
            },
            logger=logger,
        )

        # Use the shared core's resource method
        return instance_info_core.get_resource()

    except Exception as e:
        logger.error("Error generating instance metadata: %s", e)
        # Return minimal metadata on error so the resource never hard-fails.
        from superset.utils import json

        return json.dumps(
            {
                "error": "Unable to fetch complete metadata",
                "tips": [
                    "Use list_datasets to explore available data",
                    "Use get_superset_instance_info for basic stats",
                ],
            }
        )
""" -from typing import List +from datetime import datetime +from typing import Dict, List from pydantic import BaseModel, ConfigDict, Field @@ -36,12 +40,93 @@ class HealthCheckResponse(BaseModel): version: str python_version: str platform: str - uptime_seconds: float + uptime_seconds: float | None = None + + +class GetSupersetInstanceInfoRequest(BaseModel): + """ + Request schema for get_superset_instance_info tool. + + Currently has no parameters but provides consistent API for future extensibility. + """ + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + +class InstanceSummary(BaseModel): + total_dashboards: int = Field(..., description="Total number of dashboards") + total_charts: int = Field(..., description="Total number of charts") + total_datasets: int = Field(..., description="Total number of datasets") + total_databases: int = Field(..., description="Total number of databases") + total_users: int = Field(..., description="Total number of users") + total_roles: int = Field(..., description="Total number of roles") + total_tags: int = Field(..., description="Total number of tags") + avg_charts_per_dashboard: float = Field( + ..., description="Average number of charts per dashboard" + ) + + +class RecentActivity(BaseModel): + dashboards_created_last_30_days: int = Field( + ..., description="Dashboards created in the last 30 days" + ) + charts_created_last_30_days: int = Field( + ..., description="Charts created in the last 30 days" + ) + datasets_created_last_30_days: int = Field( + ..., description="Datasets created in the last 30 days" + ) + dashboards_modified_last_7_days: int = Field( + ..., description="Dashboards modified in the last 7 days" + ) + charts_modified_last_7_days: int = Field( + ..., description="Charts modified in the last 7 days" + ) + datasets_modified_last_7_days: int = Field( + ..., description="Datasets modified in the last 7 days" + ) + + +class DashboardBreakdown(BaseModel): + published: int = 
Field(..., description="Number of published dashboards") + unpublished: int = Field(..., description="Number of unpublished dashboards") + certified: int = Field(..., description="Number of certified dashboards") + with_charts: int = Field(..., description="Number of dashboards with charts") + without_charts: int = Field(..., description="Number of dashboards without charts") + + +class DatabaseBreakdown(BaseModel): + by_type: Dict[str, int] = Field(..., description="Breakdown of databases by type") + + +class PopularContent(BaseModel): + top_tags: List[str] = Field(..., description="Most popular tags") + top_creators: List[str] = Field(..., description="Most active creators") + + +class InstanceInfo(BaseModel): + instance_summary: InstanceSummary = Field( + ..., description="Instance summary information" + ) + recent_activity: RecentActivity = Field( + ..., description="Recent activity information" + ) + dashboard_breakdown: DashboardBreakdown = Field( + ..., description="Dashboard breakdown information" + ) + database_breakdown: DatabaseBreakdown = Field( + ..., description="Database breakdown by type" + ) + popular_content: PopularContent = Field( + ..., description="Popular content information" + ) + timestamp: datetime = Field(..., description="Response timestamp") class UserInfo(BaseModel): - """User information schema.""" - id: int | None = None username: str | None = None first_name: str | None = None @@ -51,17 +136,19 @@ class UserInfo(BaseModel): class TagInfo(BaseModel): - """Tag information schema.""" - id: int | None = None name: str | None = None type: str | None = None description: str | None = None -class PaginationInfo(BaseModel): - """Pagination metadata.""" +class RoleInfo(BaseModel): + id: int | None = None + name: str | None = None + permissions: List[str] | None = None + +class PaginationInfo(BaseModel): page: int = Field(..., description="Current page number") page_size: int = Field(..., description="Number of items per page") total_count: 
int = Field(..., description="Total number of items") @@ -69,11 +156,3 @@ class PaginationInfo(BaseModel): has_next: bool = Field(..., description="Whether there is a next page") has_previous: bool = Field(..., description="Whether there is a previous page") model_config = ConfigDict(ser_json_timedelta="iso8601") - - -class RoleInfo(BaseModel): - """Role information schema (for future use).""" - - id: int | None = None - name: str | None = None - permissions: List[str] | None = None diff --git a/superset/mcp_service/system/tool/__init__.py b/superset/mcp_service/system/tool/__init__.py index ee6ed6ae14b..3fd92c16654 100644 --- a/superset/mcp_service/system/tool/__init__.py +++ b/superset/mcp_service/system/tool/__init__.py @@ -16,3 +16,11 @@ # under the License. """System tools for MCP service.""" + +from .get_superset_instance_info import get_superset_instance_info +from .health_check import health_check + +__all__ = [ + "health_check", + "get_superset_instance_info", +] diff --git a/superset/mcp_service/system/tool/get_superset_instance_info.py b/superset/mcp_service/system/tool/get_superset_instance_info.py new file mode 100644 index 00000000000..4d64b334d69 --- /dev/null +++ b/superset/mcp_service/system/tool/get_superset_instance_info.py @@ -0,0 +1,268 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
"""
Get Superset instance high-level information FastMCP tool using configurable
InstanceInfoCore for flexible, extensible metrics calculation.
"""

import logging
from typing import Any, Dict

from fastmcp import Context

from superset.mcp_service.app import mcp
from superset.mcp_service.auth import mcp_auth_hook
from superset.mcp_service.mcp_core import InstanceInfoCore
from superset.mcp_service.system.schemas import (
    DashboardBreakdown,
    DatabaseBreakdown,
    GetSupersetInstanceInfoRequest,
    InstanceInfo,
    InstanceSummary,
    PopularContent,
    RecentActivity,
)

logger = logging.getLogger(__name__)


def calculate_dashboard_breakdown(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> DashboardBreakdown:
    """Calculate detailed dashboard breakdown metrics.

    Best-effort: failures are logged and a zeroed breakdown is returned so
    one broken metric never fails the whole instance-info response.
    """
    try:
        from superset.daos.base import ColumnOperator, ColumnOperatorEnum
        from superset.extensions import db
        from superset.models.dashboard import Dashboard

        dashboard_dao = dao_classes["dashboards"]

        # Published vs unpublished
        published_count = dashboard_dao.count(
            column_operators=[
                ColumnOperator(col="published", opr=ColumnOperatorEnum.eq, value=True)
            ]
        )
        unpublished_count = base_counts.get("total_dashboards", 0) - published_count

        # Certified dashboards
        certified_count = dashboard_dao.count(
            column_operators=[
                ColumnOperator(
                    col="certified_by", opr=ColumnOperatorEnum.is_not_null, value=None
                )
            ]
        )

        # Dashboards with/without charts (distinct join to avoid double
        # counting dashboards that hold several charts)
        dashboards_with_charts = (
            db.session.query(Dashboard).join(Dashboard.slices).distinct().count()
        )
        dashboards_without_charts = (
            base_counts.get("total_dashboards", 0) - dashboards_with_charts
        )

        return DashboardBreakdown(
            published=published_count,
            unpublished=unpublished_count,
            certified=certified_count,
            with_charts=dashboards_with_charts,
            without_charts=dashboards_without_charts,
        )
    except Exception:
        # Previously swallowed silently; log so failures are diagnosable
        # while keeping the best-effort zeroed fallback.
        logger.warning(
            "Failed to calculate dashboard breakdown; returning zeros", exc_info=True
        )
        return DashboardBreakdown(
            published=0,
            unpublished=0,
            certified=0,
            with_charts=0,
            without_charts=0,
        )


def calculate_database_breakdown(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> DatabaseBreakdown:
    """Calculate database type breakdown (best-effort, see above)."""
    try:
        from superset.extensions import db
        from superset.models.core import Database

        # Get database types distribution
        db_types = db.session.query(
            Database.database_name, Database.sqlalchemy_uri
        ).all()

        type_counts: Dict[str, int] = {}
        for _name, uri in db_types:
            if uri:
                # Extract database type (dialect) from the SQLAlchemy URI
                db_type = uri.split("://")[0] if "://" in uri else "unknown"
                type_counts[db_type] = type_counts.get(db_type, 0) + 1
            else:
                type_counts["unknown"] = type_counts.get("unknown", 0) + 1

        return DatabaseBreakdown(by_type=type_counts)
    except Exception:
        logger.warning(
            "Failed to calculate database breakdown; returning empty", exc_info=True
        )
        return DatabaseBreakdown(by_type={})


def calculate_instance_summary(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> InstanceSummary:
    """Calculate instance summary with computed metrics (best-effort)."""
    try:
        from flask_appbuilder.security.sqla.models import Role

        from superset.extensions import db

        # Roles have no DAO, so count them directly.
        total_roles = db.session.query(Role).count()

        # Average charts per dashboard; guard against division by zero.
        total_dashboards = base_counts.get("total_dashboards", 0)
        total_charts = base_counts.get("total_charts", 0)
        avg_charts_per_dashboard = (
            (total_charts / total_dashboards) if total_dashboards > 0 else 0
        )

        return InstanceSummary(
            total_dashboards=total_dashboards,
            total_charts=total_charts,
            total_datasets=base_counts.get("total_datasets", 0),
            total_databases=base_counts.get("total_databases", 0),
            total_users=base_counts.get("total_users", 0),
            total_roles=total_roles,
            total_tags=base_counts.get("total_tags", 0),
            avg_charts_per_dashboard=round(avg_charts_per_dashboard, 2),
        )
    except Exception:
        logger.warning(
            "Failed to calculate instance summary; returning zeros", exc_info=True
        )
        return InstanceSummary(
            total_dashboards=0,
            total_charts=0,
            total_datasets=0,
            total_databases=0,
            total_users=0,
            total_roles=0,
            total_tags=0,
            avg_charts_per_dashboard=0.0,
        )


def calculate_recent_activity(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> RecentActivity:
    """Transform time metrics into RecentActivity format.

    Pure reshaping of the precomputed ``time_metrics`` windows ("monthly" =
    30 days, "recent" = 7 days); missing keys default to 0.
    """
    monthly = time_metrics.get("monthly", {})
    recent = time_metrics.get("recent", {})

    return RecentActivity(
        dashboards_created_last_30_days=monthly.get("dashboards_created", 0),
        charts_created_last_30_days=monthly.get("charts_created", 0),
        datasets_created_last_30_days=monthly.get("datasets_created", 0),
        dashboards_modified_last_7_days=recent.get("dashboards_modified", 0),
        charts_modified_last_7_days=recent.get("charts_modified", 0),
        datasets_modified_last_7_days=recent.get("datasets_modified", 0),
    )


def calculate_popular_content(
    base_counts: Dict[str, int],
    time_metrics: Dict[str, Dict[str, int]],
    dao_classes: Dict[str, Any],
) -> PopularContent:
    """Calculate popular content metrics (placeholder implementation)."""
    return PopularContent(
        top_tags=[],
        top_creators=[],
    )


# Configure the instance info core.  DAO classes are filled in lazily inside
# get_superset_instance_info to avoid circular imports at module load time.
# NOTE(review): the shared instance is mutated per call; confirm the service
# never handles requests concurrently across app contexts.
_instance_info_core = InstanceInfoCore(
    dao_classes={
        "dashboards": None,  # type: ignore[dict-item]  # Will be set at runtime
        "charts": None,  # type: ignore[dict-item]
        "datasets": None,  # type: ignore[dict-item]
        "databases": None,  # type: ignore[dict-item]
        "users": None,  # type: ignore[dict-item]
        "tags": None,  # type: ignore[dict-item]
    },
    output_schema=InstanceInfo,
    metric_calculators={
        "instance_summary": calculate_instance_summary,
        "recent_activity": calculate_recent_activity,
        "dashboard_breakdown": calculate_dashboard_breakdown,
        "database_breakdown": calculate_database_breakdown,
        "popular_content": calculate_popular_content,
    },
    time_windows={
        "recent": 7,
        "monthly": 30,
        "quarterly": 90,
    },
    logger=logger,
)


@mcp.tool
@mcp_auth_hook
def get_superset_instance_info(
    request: GetSupersetInstanceInfoRequest, ctx: Context
) -> InstanceInfo:
    """Get Superset instance statistics.

    Returns counts, activity metrics, and database types.  ``request``
    currently carries no parameters and exists for forward compatibility.
    """
    try:
        # Import DAOs at runtime to avoid circular imports
        from superset.daos.chart import ChartDAO
        from superset.daos.dashboard import DashboardDAO
        from superset.daos.database import DatabaseDAO
        from superset.daos.dataset import DatasetDAO
        from superset.daos.tag import TagDAO
        from superset.daos.user import UserDAO

        # Configure DAO classes at runtime
        _instance_info_core.dao_classes = {
            "dashboards": DashboardDAO,
            "charts": ChartDAO,
            "datasets": DatasetDAO,
            "databases": DatabaseDAO,
            "users": UserDAO,
            "tags": TagDAO,
        }

        # Run the configurable core
        return _instance_info_core.run_tool()

    except Exception as e:
        error_msg = f"Unexpected error in instance info: {str(e)}"
        logger.error(error_msg, exc_info=True)
        raise
tool for testing the MCP service. @@ -47,7 +48,6 @@ async def health_check() -> HealthCheckResponse: version="1.0.0", python_version=platform.python_version(), platform=platform.system(), - uptime_seconds=time.time(), # Simple uptime approximation ) logger.info("Health check completed successfully") @@ -63,5 +63,4 @@ async def health_check() -> HealthCheckResponse: version="1.0.0", python_version=platform.python_version(), platform=platform.system(), - uptime_seconds=0.0, ) diff --git a/superset/mcp_service/utils/cache_utils.py b/superset/mcp_service/utils/cache_utils.py new file mode 100644 index 00000000000..1622151f260 --- /dev/null +++ b/superset/mcp_service/utils/cache_utils.py @@ -0,0 +1,143 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Cache utilities for MCP tools. + +This module provides utilities for working with Superset's cache layers +and implementing cache control in MCP tools. 
+""" + +import logging +from typing import Any, Dict + +from superset.mcp_service.common.cache_schemas import CacheStatus + +logger = logging.getLogger(__name__) + + +def get_cache_status_from_result( + result: Dict[str, Any], force_refresh: bool = False +) -> CacheStatus: + """ + Extract cache status information from a Superset query result. + + Args: + result: Query result dictionary from Superset + force_refresh: Whether cache was force refreshed + + Returns: + CacheStatus object with cache usage information + """ + # Handle different result structures + if "queries" in result and len(result["queries"]) > 0: + query_result = result["queries"][0] + else: + query_result = result + + cache_hit = bool(query_result.get("is_cached", False)) + + # Convert cache age to seconds if available + cache_age_seconds = None + if cache_age := query_result.get("cache_dttm"): + try: + from datetime import datetime + + if isinstance(cache_age, str): + cache_dt = datetime.fromisoformat(cache_age.replace("Z", "+00:00")) + cache_age_seconds = int( + (datetime.now(cache_dt.tzinfo) - cache_dt).total_seconds() + ) + elif isinstance(cache_age, datetime): + cache_age_seconds = int( + (datetime.now(cache_age.tzinfo) - cache_age).total_seconds() + ) + except Exception as e: + logger.debug("Could not parse cache age: %s", e) + + return CacheStatus( + cache_hit=cache_hit, + cache_type="query" if cache_hit else "none", + cache_age_seconds=cache_age_seconds, + refreshed=force_refresh, + ) + + +def apply_cache_control_to_query_context( + query_context: Dict[str, Any], + use_cache: bool = True, + force_refresh: bool = False, + cache_timeout: int | None = None, +) -> Dict[str, Any]: + """ + Apply cache control parameters to a query context. 
+ + Args: + query_context: Query context dictionary + use_cache: Whether to use cache + force_refresh: Whether to force refresh + cache_timeout: Cache timeout override + + Returns: + Modified query context with cache control applied + """ + if not use_cache or force_refresh: + query_context["force"] = True + + if cache_timeout is not None: + # Apply to all queries in the context + for query in query_context.get("queries", []): + query["cache_timeout"] = cache_timeout + + return query_context + + +def should_use_metadata_cache( + use_cache: bool = True, + refresh_metadata: bool = False, +) -> bool: + """ + Determine whether to use metadata cache based on cache control parameters. + + Args: + use_cache: Whether to use cache + refresh_metadata: Whether to refresh metadata + + Returns: + True if metadata cache should be used + """ + return use_cache and not refresh_metadata + + +def get_cache_key_info(cache_key: str | None) -> str | None: + """ + Get truncated cache key for debugging purposes. + + Args: + cache_key: Full cache key + + Returns: + Truncated cache key or None + """ + if not cache_key: + return None + + # Truncate long cache keys for readability + if len(cache_key) > 50: + return cache_key[:47] + "..." + + return cache_key diff --git a/superset/mcp_service/utils/error_builder.py b/superset/mcp_service/utils/error_builder.py new file mode 100644 index 00000000000..61df3eb9108 --- /dev/null +++ b/superset/mcp_service/utils/error_builder.py @@ -0,0 +1,369 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unified error builder for chart operations. +Consolidates error handling logic from multiple files. +""" + +import html +import logging +import re +from typing import Any, Dict, List + +from superset.mcp_service.common.error_schemas import ( + ChartGenerationError, + ValidationError, +) + +logger = logging.getLogger(__name__) + + +def _sanitize_user_input(value: Any) -> str: + """Sanitize user input to prevent XSS and injection attacks in error messages.""" + if value is None: + return "None" + + # Convert to string and limit length to prevent ReDoS attacks + str_value = str(value) + if len(str_value) > 200: + str_value = str_value[:200] + "...[truncated]" + + # HTML escape to prevent XSS + str_value = html.escape(str_value) + + # Check for dangerous HTML tags with simple substring checks + html_tags = [""] + str_lower = str_value.lower() + for tag in html_tags: + if tag in str_lower: + str_value = "[FILTERED]" + break + + # Check for dangerous URL schemes using regex with word boundaries + # This ensures we match actual URL schemes, not arbitrary substrings + dangerous_url_patterns = [ + r"\b(javascript|vbscript|data):", # URL schemes + r"on\w+\s*=", # Event handlers + ] + for pattern in dangerous_url_patterns: + if re.search(pattern, str_value, re.IGNORECASE): + str_value = "[FILTERED]" + break + + return str_value + + +def _sanitize_template_vars(vars_dict: Dict[str, Any]) -> Dict[str, Any]: + """Sanitize all variables before template formatting.""" + sanitized = {} + for key, value in vars_dict.items(): + # Only sanitize string-like values that 
could contain user input + if isinstance(value, (str, int, float)) or value is None: + sanitized[key] = _sanitize_user_input(value) + elif isinstance(value, (list, tuple)): + # Sanitize lists of strings + sanitized[key] = ", ".join( + [_sanitize_user_input(item) for item in value[:10]] + ) # Limit list size and convert to string + else: + # For other types, convert to string and sanitize + sanitized[key] = _sanitize_user_input(value) + return sanitized + + +class ChartErrorBuilder: + """Unified error builder for consistent error messages across chart operations.""" + + # Error templates organized by category + TEMPLATES = { + # Validation errors + "missing_field": { + "message": "Missing required field: {field}", + "details": "{field_description}", + "suggestions": [ + "Add the '{field}' field to your configuration", + "Check the API documentation for required fields", + "{specific_suggestion}", + ], + }, + "invalid_type": { + "message": "Invalid type for field '{field}'", + "details": "Expected {expected_type}, got {actual_type}", + "suggestions": [ + "Change '{field}' to be a {expected_type}", + "Example: {example}", + ], + }, + "invalid_value": { + "message": "Invalid value for '{field}'", + "details": "Value '{value}' is not allowed. {reason}", + "suggestions": [ + "Use one of the allowed values: {allowed_values}", + "{specific_suggestion}", + ], + }, + # Dataset errors + "dataset_not_found": { + "message": "Dataset not found: {dataset_id}", + "details": "No dataset found with identifier '{dataset_id}'. 
Please " + "verify the dataset ID or UUID is correct.", + "suggestions": [ + "Check that the dataset ID is correct", + "Verify you have access to this dataset", + "Use the list_datasets tool to find available datasets", + ], + }, + "column_not_found": { + "message": "Column '{column}' not found in dataset", + "details": "The column '{column}' does not exist in the dataset schema", + "suggestions": [ + "Check column name spelling and case sensitivity", + "Use get_dataset_info to see available columns", + "Did you mean: {suggestions}?", + ], + }, + # Runtime errors + "empty_result": { + "message": "Query would return no data", + "details": "{reason}", + "suggestions": [ + "Check your filter conditions", + "Verify the data exists for your criteria", + "Try broader filter values or remove some filters", + ], + }, + "performance_warning": { + "message": "Configuration may cause performance issues", + "details": "{reason}", + "suggestions": [ + "Consider adding filters to limit data", + "Use aggregations to reduce data volume", + "{specific_suggestion}", + ], + }, + # Chart-specific errors + "invalid_chart_type": { + "message": "Invalid chart type: '{chart_type}'", + "details": "Chart type must be either 'xy' or 'table'", + "suggestions": [ + "Use 'chart_type': 'xy' for line, bar, area, or scatter charts", + "Use 'chart_type': 'table' for tabular data display", + ], + }, + "incompatible_configuration": { + "message": "Chart configuration incompatible with data", + "details": "{reason}", + "suggestions": [ + "{primary_suggestion}", + "Consider using a different chart type", + "Modify your data selection or aggregation", + ], + }, + # Chart generation errors + "generation_failed": { + "message": "Chart generation failed: {reason}", + "details": "Failed to create {chart_type} chart for dataset {dataset_id}. 
" + "{reason}", + "suggestions": [ + "Check that the dataset exists and is accessible", + "Verify chart configuration is valid for the selected chart type", + "Ensure all referenced columns exist in the dataset", + "Check Superset logs for detailed error information", + ], + }, + } + + @classmethod + def build_error( + cls, + error_type: str, + template_key: str, + template_vars: Dict[str, Any] | None = None, + custom_suggestions: List[str] | None = None, + error_code: str | None = None, + validation_errors: List[ValidationError] | None = None, + ) -> ChartGenerationError: + """ + Build a standardized error using templates. + + Args: + error_type: Type of error for categorization + template_key: Key to error template + template_vars: Variables to format into template + custom_suggestions: Additional suggestions to append + error_code: Optional error code + validation_errors: Optional list of validation errors + + Returns: + ChartGenerationError with formatted message + """ + template = cls.TEMPLATES.get(template_key, {}) + # SECURITY FIX: Sanitize template variables to prevent XSS/injection + vars_dict = _sanitize_template_vars(template_vars or {}) + + message = cls._format_message(template, vars_dict) + details = cls._format_details(template, vars_dict) + suggestions = cls._format_suggestions(template, vars_dict, custom_suggestions) + error_code = cls._generate_error_code(error_code, template_key) + + return ChartGenerationError( + error_type=error_type, + message=message, + details=details, + suggestions=suggestions, + error_code=error_code, + validation_errors=validation_errors or [], + ) + + @classmethod + def _format_message( + cls, template: Dict[str, Any], vars_dict: Dict[str, Any] + ) -> str: + """Format the error message from template.""" + message_raw = template.get("message", "An error occurred") + message: str = ( + " ".join(message_raw) if isinstance(message_raw, list) else str(message_raw) + ) + if vars_dict: + try: + # SECURITY FIX: vars_dict is 
already sanitized by caller + message = message.format(**vars_dict) + except (KeyError, ValueError, TypeError) as e: + logger.warning("Template formatting failed: %s", e) + # Return safe fallback message + message = "An error occurred during chart operation" + return message + + @classmethod + def _format_details( + cls, template: Dict[str, Any], vars_dict: Dict[str, Any] + ) -> str: + """Format the error details from template.""" + details_raw = template.get("details", "") + details: str = ( + " ".join(details_raw) if isinstance(details_raw, list) else str(details_raw) + ) + if vars_dict and details: + try: + # SECURITY FIX: vars_dict is already sanitized by caller + details = details.format(**vars_dict) + except (KeyError, ValueError, TypeError) as e: + logger.warning("Template formatting failed: %s", e) + # Return safe fallback + details = "Additional error details unavailable" + return details + + @classmethod + def _format_suggestions( + cls, + template: Dict[str, Any], + vars_dict: Dict[str, Any], + custom_suggestions: List[str] | None, + ) -> List[str]: + """Format suggestions from template and add custom ones.""" + suggestions = [] + for suggestion in template.get("suggestions", []): + if vars_dict and "{" in suggestion: + try: + # SECURITY FIX: vars_dict is already sanitized by caller + suggestion = suggestion.format(**vars_dict) + if suggestion and suggestion != "None": + suggestions.append(suggestion) + except (KeyError, ValueError, TypeError): + # Skip malformed suggestions rather than exposing errors + continue + else: + suggestions.append(suggestion) + + if custom_suggestions: + # SECURITY FIX: Sanitize custom suggestions too + sanitized_custom = [ + _sanitize_user_input(s) for s in custom_suggestions[:5] + ] # Limit count + suggestions.extend(sanitized_custom) + + return suggestions[:10] # Limit total suggestions to prevent response bloat + + @classmethod + def _generate_error_code(cls, error_code: str | None, template_key: str) -> str: + """Generate 
error code if not provided.""" + if error_code: + return error_code + return f"CHART_{template_key.upper()}" + + @classmethod + def missing_field_error( + cls, + field: str, + field_description: str, + specific_suggestion: str | None = None, + ) -> ChartGenerationError: + """Build a missing field error.""" + return cls.build_error( + error_type="missing_field", + template_key="missing_field", + template_vars={ + "field": field, + "field_description": field_description, + "specific_suggestion": specific_suggestion + or f"Add '{field}' to your request", + }, + ) + + @classmethod + def invalid_type_error( + cls, + field: str, + expected_type: str, + actual_type: str, + example: str | None = None, + ) -> ChartGenerationError: + """Build an invalid type error.""" + return cls.build_error( + error_type="invalid_type", + template_key="invalid_type", + template_vars={ + "field": field, + "expected_type": expected_type, + "actual_type": actual_type, + "example": example or f"'{field}': <{expected_type}>", + }, + ) + + @classmethod + def column_not_found_error( + cls, column: str, suggestions: List[str] | None = None + ) -> ChartGenerationError: + """Build a column not found error.""" + suggestion_text = ( + ", ".join(suggestions[:3]) if suggestions else "Check available columns" + ) + return cls.build_error( + error_type="column_not_found", + template_key="column_not_found", + template_vars={"column": column, "suggestions": suggestion_text}, + ) + + @classmethod + def dataset_not_found_error(cls, dataset_id: Any) -> ChartGenerationError: + """Build a dataset not found error.""" + return cls.build_error( + error_type="dataset_not_found", + template_key="dataset_not_found", + template_vars={"dataset_id": str(dataset_id)}, + ) diff --git a/superset/mcp_service/utils/permissions_utils.py b/superset/mcp_service/utils/permissions_utils.py new file mode 100644 index 00000000000..a3864206532 --- /dev/null +++ b/superset/mcp_service/utils/permissions_utils.py @@ -0,0 +1,313 @@ +# 
Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Field-level permissions utilities for MCP service. +Provides functionality to filter sensitive data based on user permissions. +""" + +import logging +from typing import Any, List, Optional, Set + +from flask_appbuilder.security.sqla.models import User +from pydantic import BaseModel + +logger = logging.getLogger(__name__) + +# Define sensitive fields by object type +SENSITIVE_FIELDS = { + "dataset": { + "sql", # Raw SQL queries may contain sensitive logic + "extra", # May contain connection strings or credentials + "database_id", # Internal database references + "changed_by_fk", # Internal user references + "created_by_fk", # Internal user references + }, + "chart": { + "query_context", # May contain sensitive filters or parameters + "cache_key", # Internal cache references + "changed_by_fk", # Internal user references + "created_by_fk", # Internal user references + }, + "dashboard": { + "json_metadata", # May contain sensitive configuration + "position_json", # Internal layout data + "css", # May contain sensitive styling info + "changed_by_fk", # Internal user references + "created_by_fk", # Internal user references + }, + "common": { + "uuid", # Internal identifiers (keep for 
some use cases) + "changed_by_fk", # Internal user references + "created_by_fk", # Internal user references + }, +} + +# Permissions required to access sensitive fields +SENSITIVE_FIELD_PERMISSIONS = { + "sql": "can_sql_json", # SQL Lab permissions + "extra": "can_this_form_get", # Advanced form permissions + "database_id": "can_this_form_get", # Database access permissions + "query_context": "can_explore_json", # Explore permissions + "cache_key": "can_warm_up_cache", # Cache management permissions + "json_metadata": "can_this_form_get", # Advanced dashboard permissions + "position_json": "can_this_form_get", # Dashboard edit permissions + "css": "can_this_form_get", # Dashboard styling permissions +} + + +def get_current_user() -> Optional[User]: + """Get the current authenticated user.""" + try: + from flask import g + + return getattr(g, "user", None) + except Exception: + return None + + +def user_has_permission( + user: Optional[User], permission: str, resource: Optional[str] = None +) -> bool: + """ + Check if user has a specific permission. + + Args: + user: User object or None + permission: Permission name (e.g., 'can_sql_json') + resource: Resource name (e.g., 'Superset', 'Chart', etc.) 
+ + Returns: + True if user has permission, False otherwise + """ + if not user: + return False + + try: + # Check if user is admin (has all permissions) + if hasattr(user, "roles"): + for role in user.roles: + if role.name in ("Admin", "admin"): + return True + + # Check specific permission + from superset import security_manager + + if resource: + return security_manager.has_access(permission, resource, user) + else: + # Check if user has permission on any resource + for pvm in user.get_permissions(): + if pvm.permission.name == permission: + return True + + return False + except Exception as e: + logger.warning( + "Error checking permission %s for user %s: %s", permission, user, e + ) + return False + + +def get_allowed_fields( + object_type: str, + user: Optional[User] = None, + requested_fields: Optional[List[str]] = None, +) -> Set[str]: + """ + Get the set of fields that the user is allowed to access for a given object type. + + Args: + object_type: Type of object ('dataset', 'chart', 'dashboard') + user: User object (if None, will try to get current user) + requested_fields: List of fields requested (if None, all allowed fields) + + Returns: + Set of allowed field names + """ + if not user: + user = get_current_user() + + # Get sensitive fields for this object type + sensitive_fields = SENSITIVE_FIELDS.get(object_type, set()) + sensitive_fields.update(SENSITIVE_FIELDS.get("common", set())) + + # If no user, only allow non-sensitive fields + if not user: + if requested_fields: + return set(requested_fields) - sensitive_fields + else: + # Return empty set - caller should use default safe fields + return set() + + # Check permissions for sensitive fields + allowed_fields = set() + + if requested_fields: + for field in requested_fields: + if field not in sensitive_fields: + # Non-sensitive field, always allowed + allowed_fields.add(field) + else: + # Check if user has permission for this sensitive field + required_permission = 
SENSITIVE_FIELD_PERMISSIONS.get(field) + if required_permission and user_has_permission( + user, required_permission + ): + allowed_fields.add(field) + elif not required_permission: + # No specific permission required, but still sensitive + # Allow for authenticated users (basic sensitivity) + allowed_fields.add(field) + else: + # No specific fields requested, return empty set + # Caller should specify default fields + return set() + + return allowed_fields + + +def filter_sensitive_data( + data: Any, + object_type: str, + user: Optional[User] = None, + allowed_fields: Optional[Set[str]] = None, +) -> Any: + """ + Filter sensitive data from an object based on user permissions. + + Args: + data: Data to filter (dict, Pydantic model, or list) + object_type: Type of object ('dataset', 'chart', 'dashboard') + user: User object (if None, will try to get current user) + allowed_fields: Pre-computed allowed fields (optimization) + + Returns: + Filtered data with sensitive fields removed + """ + if not data: + return data + + if not user: + user = get_current_user() + + # Handle different data types + if isinstance(data, list): + return [ + filter_sensitive_data(item, object_type, user, allowed_fields) + for item in data + ] + + if isinstance(data, BaseModel): + # Convert Pydantic model to dict for filtering + data_dict = data.model_dump() + filtered_dict = filter_sensitive_data( + data_dict, object_type, user, allowed_fields + ) + # Return as dict since we can't easily reconstruct the Pydantic model + return filtered_dict + + if not isinstance(data, dict): + # Not a dict-like object, return as-is + return data + + # Get allowed fields if not provided + if allowed_fields is None: + requested_fields = list(data.keys()) + allowed_fields = get_allowed_fields(object_type, user, requested_fields) + + # Filter the dictionary + filtered_data = {} + for key, value in data.items(): + if key in allowed_fields: + filtered_data[key] = value + else: + # Log when we filter out sensitive 
data + logger.debug( + "Filtered sensitive field '%s' for object type '%s'", key, object_type + ) + + return filtered_data + + +def apply_field_permissions_to_columns( + columns: List[str], object_type: str, user: Optional[User] = None +) -> List[str]: + """ + Filter a list of column names based on field-level permissions. + + Args: + columns: List of column names to filter + object_type: Type of object ('dataset', 'chart', 'dashboard') + user: User object (if None, will try to get current user) + + Returns: + Filtered list of allowed column names + """ + allowed_fields = get_allowed_fields(object_type, user, columns) + return [col for col in columns if col in allowed_fields] + + +class PermissionAwareSerializer: + """ + Wrapper for serializers that automatically applies field-level permissions. + """ + + def __init__(self, object_type: str, base_serializer: Any): + self.object_type = object_type + self.base_serializer = base_serializer + + def serialize( + self, obj: Any, columns: List[str], user: Optional[User] = None + ) -> Any: + """ + Serialize object with field-level permissions applied. 
+ + Args: + obj: Object to serialize + columns: Requested columns + user: User object for permission checking + + Returns: + Serialized object with sensitive fields filtered + """ + # Filter columns based on permissions + allowed_columns = apply_field_permissions_to_columns( + columns, self.object_type, user + ) + + # Use base serializer with filtered columns + serialized = self.base_serializer(obj, allowed_columns) + + # Apply additional filtering to the serialized result + return filter_sensitive_data(serialized, self.object_type, user) + + +# Convenience functions for common object types +def filter_dataset_data(data: Any, user: Optional[User] = None) -> Any: + """Filter sensitive data from dataset objects.""" + return filter_sensitive_data(data, "dataset", user) + + +def filter_chart_data(data: Any, user: Optional[User] = None) -> Any: + """Filter sensitive data from chart objects.""" + return filter_sensitive_data(data, "chart", user) + + +def filter_dashboard_data(data: Any, user: Optional[User] = None) -> Any: + """Filter sensitive data from dashboard objects.""" + return filter_sensitive_data(data, "dashboard", user) diff --git a/superset/mcp_service/utils/retry_utils.py b/superset/mcp_service/utils/retry_utils.py new file mode 100644 index 00000000000..20ef581a0f2 --- /dev/null +++ b/superset/mcp_service/utils/retry_utils.py @@ -0,0 +1,341 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Retry utilities for handling transient failures in MCP service operations.
"""

import asyncio
import functools
import logging
import secrets
import time
from typing import Any, Callable, TypeVar

from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import TimeoutError as SQLAlchemyTimeoutError
from starlette.exceptions import HTTPException

logger = logging.getLogger(__name__)

T = TypeVar("T")

# Shared CSPRNG for jitter; constructed once instead of per call.
_SECURE_RANDOM = secrets.SystemRandom()

# Default retryable exceptions. SQLAlchemy's TimeoutError is imported under
# an alias so it no longer shadows the builtin TimeoutError (the builtin is
# an OSError subclass and is already covered by OSError below).
DEFAULT_RETRYABLE_EXCEPTIONS = (
    OperationalError,  # Database connection issues
    SQLAlchemyTimeoutError,  # Database timeouts
    ConnectionError,  # Network issues
    OSError,  # File system issues (for screenshots)
)


def exponential_backoff(
    attempt: int, base_delay: float = 1.0, max_delay: float = 60.0, jitter: bool = True
) -> float:
    """
    Calculate exponential backoff delay with optional jitter.

    Args:
        attempt: Current attempt number (0-based)
        base_delay: Base delay in seconds
        max_delay: Maximum delay in seconds
        jitter: Whether to add random jitter to avoid thundering herd

    Returns:
        Delay in seconds (never negative)
    """
    delay = min(base_delay * (2**attempt), max_delay)

    if jitter:
        # Add up to +/-25% jitter using a cryptographically secure RNG
        jitter_amount = delay * 0.25
        delay += _SECURE_RANDOM.uniform(-jitter_amount, jitter_amount)

    return max(0, delay)


def retry_on_exception(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exceptions: tuple[type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
    jitter: bool = True,
) -> Callable[[Callable[..., T]], Callable[..., T]]:
    """
    Decorator to retry function calls on specific exceptions with exponential backoff.

    Args:
        max_attempts: Maximum number of attempts (including initial attempt)
        base_delay: Base delay in seconds for exponential backoff
        max_delay: Maximum delay in seconds between retries
        exceptions: Tuple of exception types to retry on
        jitter: Whether to add random jitter to backoff delays

    Returns:
        Decorated function
    """

    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> T:
            last_exception = None

            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:  # noqa: PERF203
                    last_exception = e

                    if attempt == max_attempts - 1:
                        # Last attempt, don't wait
                        break

                    delay = exponential_backoff(attempt, base_delay, max_delay, jitter)

                    logger.warning(
                        "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
                        attempt + 1,
                        max_attempts,
                        func.__name__,
                        e,
                        delay,
                    )

                    time.sleep(delay)
                except Exception as e:
                    # Non-retryable exception, fail immediately
                    logger.error("Non-retryable exception in %s: %s", func.__name__, e)
                    raise

            # All attempts failed
            if last_exception is not None:
                logger.error(
                    "All %s attempts failed for %s: %s",
                    max_attempts,
                    func.__name__,
                    last_exception,
                )
                raise last_exception
            logger.error("All %s attempts failed for %s", max_attempts, func.__name__)
            raise RuntimeError(
                f"All {max_attempts} attempts failed for {func.__name__}"
            )

        return wrapper

    return decorator


def async_retry_on_exception(
    max_attempts: int = 3,
    base_delay: float = 1.0,
    max_delay: float = 60.0,
    exceptions: tuple[type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
    jitter: bool = True,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """
    Async version of retry_on_exception decorator.
    """

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        async def wrapper(*args: Any, **kwargs: Any) -> Any:
            last_exception = None

            for attempt in range(max_attempts):
                try:
                    return await func(*args, **kwargs)
                except exceptions as e:  # noqa: PERF203
                    last_exception = e

                    if attempt == max_attempts - 1:
                        # Last attempt, don't wait
                        break

                    delay = exponential_backoff(attempt, base_delay, max_delay, jitter)

                    logger.warning(
                        "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
                        attempt + 1,
                        max_attempts,
                        func.__name__,
                        e,
                        delay,
                    )

                    await asyncio.sleep(delay)
                except Exception as e:
                    # Non-retryable exception, fail immediately
                    logger.error("Non-retryable exception in %s: %s", func.__name__, e)
                    raise

            # All attempts failed
            if last_exception is not None:
                logger.error(
                    "All %s attempts failed for %s: %s",
                    max_attempts,
                    func.__name__,
                    last_exception,
                )
                raise last_exception
            logger.error("All %s attempts failed for %s", max_attempts, func.__name__)
            raise RuntimeError(
                f"All {max_attempts} attempts failed for {func.__name__}"
            )

        return wrapper

    return decorator


class RetryableOperation:
    """
    Context manager for retryable operations with custom logic.

    Intended usage is an outer loop driven by should_retry(); __exit__
    suppresses retryable exceptions (after sleeping) so the loop can run
    the body again.
    """

    def __init__(
        self,
        operation_name: str,
        max_attempts: int = 3,
        base_delay: float = 1.0,
        max_delay: float = 60.0,
        exceptions: tuple[type[Exception], ...] = DEFAULT_RETRYABLE_EXCEPTIONS,
        jitter: bool = True,
    ) -> None:
        self.operation_name = operation_name
        self.max_attempts = max_attempts
        self.base_delay = base_delay
        self.max_delay = max_delay
        self.exceptions = exceptions
        self.jitter = jitter
        self.current_attempt = 0
        self.last_exception: Exception | None = None

    def __enter__(self) -> "RetryableOperation":
        return self

    def __exit__(
        self,
        exc_type: type[Exception] | None,
        exc_val: Exception | None,
        exc_tb: Any,
    ) -> bool:
        if exc_type is None:
            # No exception, operation succeeded
            return False

        if not issubclass(exc_type, self.exceptions):
            # Non-retryable exception: propagate
            logger.error(
                "Non-retryable exception in %s: %s", self.operation_name, exc_val
            )
            return False

        self.last_exception = exc_val
        self.current_attempt += 1

        if self.current_attempt >= self.max_attempts:
            # Max attempts reached: propagate. BUG FIX: include the final
            # exception in the log (it was previously dropped).
            logger.error(
                "All %s attempts failed for %s: %s",
                self.max_attempts,
                self.operation_name,
                exc_val,
            )
            return False

        # Calculate delay and wait before the caller's next iteration
        delay = exponential_backoff(
            self.current_attempt - 1, self.base_delay, self.max_delay, self.jitter
        )

        logger.warning(
            "Attempt %s/%s failed for %s: %s. Retrying in %.2fs...",
            self.current_attempt,
            self.max_attempts,
            self.operation_name,
            exc_val,
            delay,
        )

        time.sleep(delay)
        return True  # Suppress the exception and continue

    def should_retry(self) -> bool:
        """Check if we should continue retrying"""
        return self.current_attempt < self.max_attempts


# Convenience functions for common operations
def retry_database_operation(
    func: Callable[..., T], *args: Any, max_attempts: int = 3, **kwargs: Any
) -> T:
    """
    Execute a database operation with retry logic.
    """

    @retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OperationalError, SQLAlchemyTimeoutError),
        base_delay=0.5,
        max_delay=30.0,
    )
    def _wrapped() -> T:
        return func(*args, **kwargs)

    return _wrapped()


async def async_retry_database_operation(
    func: Callable[..., Any], *args: Any, max_attempts: int = 3, **kwargs: Any
) -> Any:
    """
    Execute an async database operation with retry logic.
    """

    @async_retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OperationalError, SQLAlchemyTimeoutError),
        base_delay=0.5,
        max_delay=30.0,
    )
    async def _wrapped() -> Any:
        return await func(*args, **kwargs)

    return await _wrapped()


def retry_screenshot_operation(
    func: Callable[..., T],
    *args: Any,
    max_attempts: int = 2,  # Screenshots are expensive, fewer retries
    **kwargs: Any,
) -> T:
    """
    Execute a screenshot operation with retry logic.
    """

    @retry_on_exception(
        max_attempts=max_attempts,
        exceptions=(OSError, ConnectionError, HTTPException),
        base_delay=2.0,  # Longer initial delay for screenshots
        max_delay=30.0,
    )
    def _wrapped() -> T:
        return func(*args, **kwargs)

    return _wrapped()


# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +URL utilities for MCP service +""" + +from flask import current_app + + +def get_superset_base_url() -> str: + """ + Get the Superset base URL from configuration. + + Returns: + Base URL for Superset web server (e.g., "http://localhost:9001") + """ + # Default fallback to localhost:9001 + default_url = "http://localhost:9001" + + try: + # Try to get from configuration + config = current_app.config + + # Check for SUPERSET_WEBSERVER_ADDRESS first + webserver_address = config.get("SUPERSET_WEBSERVER_ADDRESS") + if webserver_address: + return webserver_address + + # Fallback to other potential config keys + public_role_like_gamma = config.get("PUBLIC_ROLE_LIKE_GAMMA", False) + if public_role_like_gamma: + # If public access is enabled, might be on a different host + webserver_protocol = config.get("ENABLE_PROXY_FIX", False) + protocol = "https" if webserver_protocol else "http" + host = config.get("WEBSERVER_HOST", "localhost") + port = config.get("WEBSERVER_PORT", 9001) + return f"{protocol}://{host}:{port}" + + return default_url + + except Exception: + # If we can't access Flask config (e.g., outside app context), + # return default + return default_url + + +def get_mcp_service_url() -> str: + """ + Get the MCP service base URL where screenshot endpoints are served. + + The MCP service auto-detects its own host and port since it's running + this code. Falls back to explicit configuration or default port. 
+ + Returns: + Base URL for MCP service (always independent of Superset URL) + """ + try: + # Try to auto-detect from Flask request context + from flask import request + + if request: + # Get the host and port from the current request + scheme = request.scheme # http or https + host = request.host # includes port if non-standard + return f"{scheme}://{host}" + + except (RuntimeError, AttributeError): + # Not in request context or Flask not available + pass + + try: + # Check for explicit MCP_SERVICE_URL in config + config = current_app.config + mcp_service_url = config.get("MCP_SERVICE_URL") + if mcp_service_url: + return mcp_service_url + + except Exception as e: + # Log and fall back if config access fails + import logging + + logging.getLogger(__name__).debug("Config access failed: %s", e) + + # Always fallback to MCP service default port (never use Superset URL) + return "http://localhost:5008" + + +def get_chart_screenshot_url(chart_id: int | str) -> str: + """ + Generate a screenshot URL for a chart using the MCP service. + + Args: + chart_id: Chart ID (numeric or string) + + Returns: + Complete URL to the chart screenshot endpoint + """ + mcp_base = get_mcp_service_url() + return f"{mcp_base}/screenshot/chart/{chart_id}.png" + + +def get_explore_screenshot_url(form_data_key: str) -> str: + """ + Generate a screenshot URL for an explore view using the MCP service. 
+ + Args: + form_data_key: Form data key for the explore view + + Returns: + Complete URL to the explore screenshot endpoint + """ + mcp_base = get_mcp_service_url() + return f"{mcp_base}/screenshot/explore/{form_data_key}.png" diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py new file mode 100644 index 00000000000..7a297e6e5b7 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -0,0 +1,160 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for MCP chart schema validation. 
class TestTableChartConfig:
    """Validation behaviour of TableChartConfig."""

    def test_duplicate_labels_rejected(self) -> None:
        """Two columns sharing one label must fail validation."""
        duplicated = [
            ColumnRef(name="product_line", label="product_line"),
            ColumnRef(name="sales", aggregate="SUM", label="product_line"),
        ]
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            TableChartConfig(columns=duplicated)

    def test_unique_labels_accepted(self) -> None:
        """Columns with distinct labels validate cleanly."""
        cfg = TableChartConfig(
            columns=[
                ColumnRef(name="product_line", label="Product Line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ]
        )
        assert len(cfg.columns) == 2


class TestXYChartConfig:
    """Validation behaviour of XYChartConfig."""

    def test_different_labels_accepted(self) -> None:
        """x and an aggregated y on the same column produce distinct labels."""
        cfg = XYChartConfig(
            x=ColumnRef(name="product_line"),  # Label: "product_line"
            y=[
                ColumnRef(
                    name="product_line", aggregate="COUNT"
                ),  # Label: "COUNT(product_line)"
            ],
        )
        assert cfg.x.name == "product_line"
        assert cfg.y[0].aggregate == "COUNT"

    def test_explicit_duplicate_label_rejected(self) -> None:
        """A y label explicitly colliding with x must be rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="product_line"),
                y=[ColumnRef(name="sales", label="product_line")],
            )

    def test_duplicate_y_axis_labels_rejected(self) -> None:
        """Two y-series resolving to the same label must be rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="date"),
                y=[
                    ColumnRef(name="sales", aggregate="SUM"),
                    ColumnRef(name="revenue", aggregate="SUM", label="SUM(sales)"),
                ],
            )

    def test_unique_labels_accepted(self) -> None:
        """All-distinct labels validate cleanly."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG", label="Average Profit"),
            ],
        )
        assert len(cfg.y) == 2

    def test_group_by_duplicate_with_x_rejected(self) -> None:
        """A group_by label colliding with x must be rejected."""
        with pytest.raises(ValidationError, match="Duplicate column/metric labels"):
            XYChartConfig(
                x=ColumnRef(name="region"),
                y=[ColumnRef(name="sales", aggregate="SUM")],
                group_by=ColumnRef(name="category", label="region"),
            )

    def test_realistic_chart_configurations(self) -> None:
        """COUNT(product_line) does not collide with bare product_line."""
        cfg = XYChartConfig(
            x=ColumnRef(name="product_line"),
            y=[
                ColumnRef(name="product_line", aggregate="COUNT"),
                ColumnRef(name="sales", aggregate="SUM"),
            ],
        )
        assert cfg.x.name == "product_line"
        assert len(cfg.y) == 2

    def test_time_series_chart_configuration(self) -> None:
        """Typical time-series line config validates (dedupe happens later)."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert cfg.x.name == "order_date"
        assert cfg.kind == "line"

    def test_time_series_with_custom_x_axis_label(self) -> None:
        """A custom x label survives validation."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date", label="Order Date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="price_each", aggregate="AVG", label="Avg Price"),
            ],
            kind="line",
        )
        assert cfg.x.label == "Order Date"

    def test_area_chart_configuration(self) -> None:
        """Area kind is accepted."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            kind="area",
        )
        assert cfg.kind == "area"
class TestCreateMetricObject:
    """Behaviour of create_metric_object."""

    def test_create_metric_object_with_aggregate(self) -> None:
        """An explicit aggregate and label are carried through verbatim."""
        ref = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
        metric = create_metric_object(ref)

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "revenue"
        assert metric["label"] == "Total Revenue"
        assert metric["optionName"] == "metric_revenue"
        assert metric["expressionType"] == "SIMPLE"

    def test_create_metric_object_default_aggregate(self) -> None:
        """Without an aggregate, SUM is assumed and the label is derived."""
        metric = create_metric_object(ColumnRef(name="orders"))

        assert metric["aggregate"] == "SUM"
        assert metric["column"]["column_name"] == "orders"
        assert metric["label"] == "SUM(orders)"
        assert metric["optionName"] == "metric_orders"


class TestMapFilterOperator:
    """Behaviour of map_filter_operator."""

    def test_map_filter_operators(self) -> None:
        """Known operators map to their Superset equivalents."""
        expected = {"=": "==", ">": ">", "<": "<", ">=": ">=", "<=": "<=", "!=": "!="}
        for op, mapped in expected.items():
            assert map_filter_operator(op) == mapped

    def test_map_filter_operator_unknown(self) -> None:
        """Unrecognised operators pass through untouched."""
        assert map_filter_operator("UNKNOWN") == "UNKNOWN"


class TestMapTableConfig:
    """Behaviour of map_table_config."""

    def test_map_table_config_basic(self) -> None:
        """Aggregated columns become metrics under aggregate query mode."""
        cfg = TableChartConfig(
            columns=[
                ColumnRef(name="product", aggregate="COUNT"),
                ColumnRef(name="revenue", aggregate="SUM"),
            ]
        )

        fd = map_table_config(cfg)

        assert fd["viz_type"] == "table"
        assert fd["query_mode"] == "aggregate"
        # Aggregated columns belong in metrics, never in all_columns.
        assert "all_columns" not in fd
        assert len(fd["metrics"]) == 2
        assert fd["metrics"][0]["aggregate"] == "COUNT"
        assert fd["metrics"][1]["aggregate"] == "SUM"

    def test_map_table_config_raw_columns(self) -> None:
        """Un-aggregated columns select raw query mode with all_columns."""
        cfg = TableChartConfig(
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="category"),
            ]
        )

        fd = map_table_config(cfg)

        assert fd["viz_type"] == "table"
        assert fd["query_mode"] == "raw"
        assert fd["all_columns"] == ["product", "category"]
        assert "metrics" not in fd

    def test_map_table_config_with_filters(self) -> None:
        """Filters are translated into simple adhoc filters."""
        cfg = TableChartConfig(
            columns=[ColumnRef(name="product")],
            filters=[FilterConfig(column="status", op="=", value="active")],
        )

        fd = map_table_config(cfg)

        assert "adhoc_filters" in fd
        assert len(fd["adhoc_filters"]) == 1
        adhoc = fd["adhoc_filters"][0]
        assert adhoc["subject"] == "status"
        assert adhoc["operator"] == "=="
        assert adhoc["comparator"] == "active"
        assert adhoc["expressionType"] == "SIMPLE"

    def test_map_table_config_with_sort(self) -> None:
        """sort_by columns land in order_by_cols unchanged."""
        cfg = TableChartConfig(
            columns=[ColumnRef(name="product")], sort_by=["product", "revenue"]
        )

        fd = map_table_config(cfg)
        assert fd["order_by_cols"] == ["product", "revenue"]


class TestMapXYConfig:
    """Behaviour of map_xy_config."""

    def test_map_xy_config_line_chart(self) -> None:
        """Line kind maps to the echarts time-series line viz."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
        )

        fd = map_xy_config(cfg)

        assert fd["viz_type"] == "echarts_timeseries_line"
        assert fd["x_axis"] == "date"
        assert len(fd["metrics"]) == 1
        assert fd["metrics"][0]["aggregate"] == "SUM"

    def test_map_xy_config_with_groupby(self) -> None:
        """group_by becomes the groupby list."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="bar",
            group_by=ColumnRef(name="region"),
        )

        fd = map_xy_config(cfg)

        assert fd["viz_type"] == "echarts_timeseries_bar"
        assert fd["groupby"] == ["region"]

    def test_map_xy_config_with_axes(self) -> None:
        """Axis titles, formats and scale are mapped onto form-data keys."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="area",
            x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
            y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
        )

        fd = map_xy_config(cfg)

        assert fd["viz_type"] == "echarts_area"
        assert fd["x_axis_title"] == "Date"
        assert fd["x_axis_format"] == "%Y-%m-%d"
        assert fd["y_axis_title"] == "Revenue"
        assert fd["y_axis_format"] == "$,.2f"
        assert fd["y_axis_scale"] == "log"

    def test_map_xy_config_with_legend(self) -> None:
        """Legend visibility and position are mapped."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="scatter",
            legend=LegendConfig(show=False, position="top"),
        )

        fd = map_xy_config(cfg)

        assert fd["viz_type"] == "echarts_timeseries_scatter"
        assert fd["show_legend"] is False
        assert fd["legend_orientation"] == "top"
class TestMapConfigToFormData:
    """Behaviour of the map_config_to_form_data dispatcher."""

    def test_map_table_config_type(self) -> None:
        """Table configs dispatch to the table mapper."""
        cfg = TableChartConfig(columns=[ColumnRef(name="test")])
        assert map_config_to_form_data(cfg)["viz_type"] == "table"

    def test_map_xy_config_type(self) -> None:
        """XY configs dispatch to the XY mapper."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"), y=[ColumnRef(name="revenue")], kind="line"
        )
        assert map_config_to_form_data(cfg)["viz_type"] == "echarts_timeseries_line"

    def test_map_unsupported_config_type(self) -> None:
        """Anything else raises a ValueError."""
        with pytest.raises(ValueError, match="Unsupported config type"):
            map_config_to_form_data("invalid_config")  # type: ignore


class TestGenerateChartName:
    """Behaviour of generate_chart_name."""

    def test_generate_table_chart_name(self) -> None:
        """Table names list the columns."""
        cfg = TableChartConfig(
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="revenue"),
            ]
        )
        assert generate_chart_name(cfg) == "Table Chart - product, revenue"

    def test_generate_xy_chart_name(self) -> None:
        """XY names describe kind plus x-vs-y columns."""
        cfg = XYChartConfig(
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
            kind="line",
        )
        assert generate_chart_name(cfg) == "Line Chart - date vs revenue, orders"

    def test_generate_chart_name_unsupported(self) -> None:
        """Unknown config types fall back to a generic name."""
        assert generate_chart_name("invalid_config") == "Chart"  # type: ignore


class TestGenerateExploreLink:
    """Behaviour of generate_explore_link."""

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
        """The configured base URL is honoured in the generated link."""
        from urllib.parse import urlparse

        mock_get_base_url.return_value = "https://superset.example.com"
        fd = {"viz_type": "table", "metrics": ["count"]}

        link = generate_explore_link("123", fd)

        # Parse rather than substring-match to avoid CodeQL warnings.
        parsed = urlparse(link)
        assert parsed.scheme == "https"
        assert parsed.netloc == "superset.example.com"
        assert "/explore/" in parsed.path
        assert "datasource_id=123" in link

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
        """A missing dataset yields the plain datasource-id fallback URL."""
        mock_get_base_url.return_value = "http://localhost:9001"

        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            link = generate_explore_link("999", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
        )

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    @patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
    def test_generate_explore_link_with_form_data_key(
        self, mock_command, mock_get_base_url
    ) -> None:
        """An existing dataset routes through a form_data_key link."""
        mock_get_base_url.return_value = "http://localhost:9001"
        mock_command.return_value.run.return_value = "test_form_data_key"

        fake_dataset = type("Dataset", (), {"id": 123})()
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id", return_value=fake_dataset
        ):
            link = generate_explore_link(123, {"viz_type": "table"})

        assert (
            link == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
        )
        mock_command.assert_called_once()

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
        """DAO failures degrade gracefully to the fallback URL."""
        mock_get_base_url.return_value = "http://localhost:9001"

        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id",
            side_effect=Exception("DB error"),
        ):
            link = generate_explore_link("123", {"viz_type": "table"})

        assert (
            link
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
        )


class TestCriticalBugFixes:
    """Regression tests for critical chart-utility bug fixes."""

    def test_time_series_aggregation_fix(self) -> None:
        """Time-series charts keep their temporal dimension on x_axis."""
        cfg = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
        )

        fd = map_xy_config(cfg)

        assert fd["x_axis"] == "order_date"
        # groupby must not echo x_axis — that caused
        # "Duplicate column/metric labels" errors.
        assert "groupby" not in fd or "order_date" not in fd.get("groupby", [])
        assert fd["viz_type"] == "echarts_timeseries_line"

    def test_time_series_with_explicit_group_by(self) -> None:
        """An explicit group_by distinct from x is preserved alone."""
        cfg = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            group_by=ColumnRef(name="category"),
        )

        fd = map_xy_config(cfg)

        assert fd["x_axis"] == "order_date"
        assert fd.get("groupby") == ["category"]
        assert "order_date" not in fd.get("groupby", [])

    def test_duplicate_label_prevention(self) -> None:
        """group_by identical to x is dropped to avoid duplicate labels."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="price_each", label="Price Each"),  # Custom label
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
            ],
            group_by=ColumnRef(name="price_each"),  # Same column as x_axis
            kind="line",
        )

        fd = map_xy_config(cfg)

        assert fd["x_axis"] == "price_each"
        # groupby is emptied because group_by == x_axis.
        assert "groupby" not in fd or not fd["groupby"]

    def test_metric_object_creation_with_labels(self) -> None:
        """Metric objects carry custom or auto-derived labels correctly."""
        cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG"),  # No custom label
            ],
            kind="bar",
        )

        metrics = map_xy_config(cfg)["metrics"]
        assert len(metrics) == 2

        assert metrics[0]["label"] == "Total Sales"
        assert metrics[0]["aggregate"] == "SUM"
        assert metrics[0]["column"]["column_name"] == "sales"

        assert metrics[1]["label"] == "AVG(profit)"
        assert metrics[1]["aggregate"] == "AVG"
        assert metrics[1]["column"]["column_name"] == "profit"

    def test_chart_type_mapping_comprehensive(self) -> None:
        """Every supported kind maps to the expected Superset viz type."""
        expected_viz = {
            "line": "echarts_timeseries_line",
            "bar": "echarts_timeseries_bar",
            "area": "echarts_area",
            "scatter": "echarts_timeseries_scatter",
        }

        for kind, viz_type in expected_viz.items():
            cfg = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="value", aggregate="SUM")],
                kind=kind,
            )
            assert map_xy_config(cfg)["viz_type"] == viz_type
class TestGenerateChart:
    """Tests for the generate_chart MCP tool's request/response schemas."""

    @pytest.mark.asyncio
    async def test_generate_chart_request_structure(self):
        """Table and XY request payloads are well formed."""
        table_cfg = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="region", label="Region"),
                ColumnRef(name="sales", label="Sales", aggregate="SUM"),
            ],
            filters=[FilterConfig(column="year", op="=", value="2024")],
            sort_by=["sales"],
        )
        table_req = GenerateChartRequest(dataset_id="1", config=table_cfg)
        assert table_req.dataset_id == "1"
        assert table_req.config.chart_type == "table"
        assert len(table_req.config.columns) == 2
        assert table_req.config.columns[0].name == "region"
        assert table_req.config.columns[1].aggregate == "SUM"

        xy_cfg = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind="line",
            group_by=ColumnRef(name="category"),
            x_axis=AxisConfig(title="Date", format="smart_date"),
            y_axis=AxisConfig(title="Value", format="$,.2f"),
            legend=LegendConfig(show=True, position="top"),
        )
        xy_req = GenerateChartRequest(dataset_id="2", config=xy_cfg)
        assert xy_req.config.chart_type == "xy"
        assert xy_req.config.x.name == "date"
        assert xy_req.config.y[0].aggregate == "SUM"
        assert xy_req.config.kind == "line"
        assert xy_req.config.x_axis.title == "Date"
        assert xy_req.config.legend.show is True

    @pytest.mark.asyncio
    async def test_generate_chart_validation_error_handling(self):
        """Validation-error entries carry the expected fields."""
        entry = {
            "field": "x_axis",
            "provided_value": "invalid_col",
            "error_type": "column_not_found",
            "message": "Column 'invalid_col' not found",
        }

        assert entry["field"] == "x_axis"
        assert entry["error_type"] == "column_not_found"

    @pytest.mark.asyncio
    async def test_chart_config_variations(self):
        """All kinds, multiple y-series and every filter operator validate."""
        for kind in ("line", "bar", "area", "scatter"):
            cfg = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="x_col"),
                y=[ColumnRef(name="y_col")],
                kind=kind,
            )
            assert cfg.kind == kind

        multi_y = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM"),
                ColumnRef(name="profit", aggregate="AVG"),
                ColumnRef(name="quantity", aggregate="COUNT"),
            ],
            kind="line",
        )
        assert len(multi_y.y) == 3
        assert multi_y.y[1].aggregate == "AVG"

        operators = ["=", "!=", ">", ">=", "<", "<="]
        for op in operators:
            assert FilterConfig(column="col", op=op, value="val").op == op

    @pytest.mark.asyncio
    async def test_generate_chart_response_structure(self):
        """Document the expected response shape (structural reference only)."""
        _ = {
            "chart": {
                "id": int,
                "slice_name": str,
                "viz_type": str,
                "url": str,
                "uuid": str,
                "saved": bool,
            },
            "error": None,
            "success": bool,
            "schema_version": str,
            "api_version": str,
        }

        # Optional extras present on a successful creation.
        _ = [
            "previews",
            "capabilities",
            "semantics",
            "explore_url",
            "form_data_key",
            "api_endpoints",
            "performance",
            "accessibility",
        ]

        # Structural reference only — integration tests verify the tool
        # actually returns data matching this shape.

    @pytest.mark.asyncio
    async def test_dataset_id_flexibility(self):
        """dataset_id accepts plain numeric strings and UUID strings."""
        requests = [
            GenerateChartRequest(
                dataset_id="123",
                config=TableChartConfig(
                    chart_type="table", columns=[ColumnRef(name="col1")]
                ),
            ),
            GenerateChartRequest(
                dataset_id="uuid-string-here",
                config=TableChartConfig(
                    chart_type="table", columns=[ColumnRef(name="col1")]
                ),
            ),
        ]

        for req in requests:
            assert isinstance(req.dataset_id, str)

    @pytest.mark.asyncio
    async def test_save_chart_flag(self):
        """save_chart defaults to True and accepts explicit False."""
        saved = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
        )
        assert saved.save_chart is True

        preview_only = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
            save_chart=False,
        )
        assert preview_only.save_chart is False

    @pytest.mark.asyncio
    async def test_preview_formats(self):
        """preview_formats keeps the requested set of formats."""
        wanted = ["url", "ascii", "table"]
        req = GenerateChartRequest(
            dataset_id="1",
            config=TableChartConfig(
                chart_type="table", columns=[ColumnRef(name="col1")]
            ),
            generate_preview=True,
            preview_formats=wanted,
        )
        assert req.generate_preview is True
        assert set(req.preview_formats) == set(wanted)

    @pytest.mark.asyncio
    async def test_column_ref_features(self):
        """ColumnRef supports bare columns, labels, and all aggregates."""
        bare = ColumnRef(name="region")
        assert bare.name == "region"
        assert bare.label is None
        assert bare.aggregate is None

        labelled = ColumnRef(name="sales", aggregate="SUM", label="Total Sales")
        assert labelled.name == "sales"
        assert labelled.aggregate == "SUM"
        assert labelled.label == "Total Sales"

        for agg in ("SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"):
            assert ColumnRef(name="value", aggregate=agg).aggregate == agg

    @pytest.mark.asyncio
    async def test_axis_config_options(self):
        """AxisConfig keeps title, format and scale; all formats accepted."""
        axis = AxisConfig(
            title="Sales Amount",
            format="$,.2f",
            scale="linear",
        )
        assert axis.title == "Sales Amount"
        assert axis.format == "$,.2f"
        assert axis.scale == "linear"

        for fmt in ("SMART_NUMBER", "$,.2f", ",.0f", "smart_date", ".3%"):
            assert AxisConfig(format=fmt).format == fmt

    @pytest.mark.asyncio
    async def test_legend_config_options(self):
        """LegendConfig accepts every position and a hidden state."""
        for pos in ("top", "bottom", "left", "right"):
            assert LegendConfig(show=True, position=pos).position == pos

        assert LegendConfig(show=False).show is False
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for get_chart_preview MCP tool +""" + +import pytest + +from superset.mcp_service.chart.schemas import ( + ASCIIPreview, + GetChartPreviewRequest, + TablePreview, + URLPreview, +) + + +class TestGetChartPreview: + """Tests for get_chart_preview MCP tool.""" + + @pytest.mark.asyncio + async def test_get_chart_preview_request_structure(self): + """Test that preview request structures are properly formed.""" + # Numeric ID request + request1 = GetChartPreviewRequest(identifier=123, format="url") + assert request1.identifier == 123 + assert request1.format == "url" + # Default dimensions are set + assert request1.width == 800 + assert request1.height == 600 + + # String ID request + request2 = GetChartPreviewRequest(identifier="456", format="ascii") + assert request2.identifier == "456" + assert request2.format == "ascii" + + # UUID request + request3 = GetChartPreviewRequest( + identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", format="table" + ) + assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert request3.format == "table" + + # Default format + request4 = GetChartPreviewRequest(identifier=789) + assert request4.format == "url" # default + + @pytest.mark.asyncio + async def test_preview_format_types(self): + """Test different preview format types.""" + formats = ["url", "ascii", "table"] + for fmt in formats: + request = GetChartPreviewRequest(identifier=1, format=fmt) + assert request.format == fmt + + @pytest.mark.asyncio + async def test_url_preview_structure(self): + """Test URLPreview response structure.""" 
+ preview = URLPreview( + preview_url="http://localhost:5008/screenshot/chart/123.png", + width=800, + height=600, + supports_interaction=False, + ) + assert preview.type == "url" + assert preview.preview_url == "http://localhost:5008/screenshot/chart/123.png" + assert preview.width == 800 + assert preview.height == 600 + assert preview.supports_interaction is False + + @pytest.mark.asyncio + async def test_ascii_preview_structure(self): + """Test ASCIIPreview response structure.""" + ascii_art = """ +โ”Œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ” +โ”‚ Sales by Region โ”‚ +โ”œโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”ค +โ”‚ North โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ 45% โ”‚ +โ”‚ South โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ 30% โ”‚ +โ”‚ East โ–ˆโ–ˆโ–ˆโ–ˆ 20% โ”‚ +โ”‚ West โ–ˆโ–ˆ 5% โ”‚ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”˜ +""" + preview = ASCIIPreview( + ascii_content=ascii_art.strip(), + width=25, + height=8, + ) + assert preview.type == "ascii" + assert "Sales by Region" in preview.ascii_content + assert preview.width == 25 + assert preview.height == 8 + + @pytest.mark.asyncio + async def test_table_preview_structure(self): + """Test TablePreview response structure.""" + table_content = """ +| Region | Sales | Profit | +|--------|--------|--------| +| North | 45000 | 12000 | +| South | 30000 | 8000 | +| East | 20000 | 5000 | +| West | 5000 | 1000 | +""" + preview = TablePreview( + table_data=table_content.strip(), + row_count=4, + supports_sorting=True, + ) + assert preview.type == "table" + assert "Region" in preview.table_data + assert "North" in preview.table_data + assert preview.row_count == 4 + assert preview.supports_sorting is True + + @pytest.mark.asyncio + async def test_chart_preview_response_structure(self): + """Test the expected response structure for chart preview.""" + # Core fields that should always be present + _ = [ + "chart_id", + "chart_name", + "chart_type", 
+ "explore_url", + "content", # Union of URLPreview | ASCIIPreview | TablePreview + "chart_description", + "accessibility", + "performance", + ] + + # Additional fields that may be present for backward compatibility + _ = [ + "format", + "preview_url", + "ascii_chart", + "table_data", + "width", + "height", + "schema_version", + "api_version", + ] + + # This is a structural test - actual integration tests would verify + # the tool returns data matching this structure + + @pytest.mark.asyncio + async def test_preview_dimensions(self): + """Test preview dimensions in response.""" + # Standard dimensions that may appear in preview responses + standard_sizes = [ + (800, 600), # Default + (1200, 800), # Large + (400, 300), # Small + (1920, 1080), # Full HD + ] + + for width, height in standard_sizes: + # URL preview with dimensions + url_preview = URLPreview( + preview_url="http://example.com/chart.png", + width=width, + height=height, + supports_interaction=False, + ) + assert url_preview.width == width + assert url_preview.height == height + + @pytest.mark.asyncio + async def test_error_response_structures(self): + """Test error response structures.""" + # Error responses typically follow this structure + error_response = { + "error_type": "not_found", + "message": "Chart not found", + "details": "No chart found with ID 999", + "chart_id": 999, + } + assert error_response["error_type"] == "not_found" + assert error_response["chart_id"] == 999 + + # Preview generation error structure + preview_error = { + "error_type": "preview_error", + "message": "Failed to generate preview", + "details": "Screenshot service unavailable", + } + assert preview_error["error_type"] == "preview_error" + + @pytest.mark.asyncio + async def test_accessibility_metadata(self): + """Test accessibility metadata structure.""" + from superset.mcp_service.chart.schemas import AccessibilityMetadata + + metadata = AccessibilityMetadata( + color_blind_safe=True, + alt_text="Bar chart showing sales by 
region", + high_contrast_available=False, + ) + assert metadata.color_blind_safe is True + assert "sales by region" in metadata.alt_text + assert metadata.high_contrast_available is False + + @pytest.mark.asyncio + async def test_performance_metadata(self): + """Test performance metadata structure.""" + from superset.mcp_service.chart.schemas import PerformanceMetadata + + metadata = PerformanceMetadata( + query_duration_ms=150, + cache_status="hit", + optimization_suggestions=["Consider adding an index on date column"], + ) + assert metadata.query_duration_ms == 150 + assert metadata.cache_status == "hit" + assert len(metadata.optimization_suggestions) == 1 + + @pytest.mark.asyncio + async def test_chart_types_support(self): + """Test that various chart types are supported.""" + chart_types = [ + "echarts_timeseries_line", + "echarts_timeseries_bar", + "echarts_area", + "echarts_timeseries_scatter", + "table", + "pie", + "big_number", + "big_number_total", + "pivot_table_v2", + "dist_bar", + "box_plot", + ] + + # All chart types should be previewable + for _chart_type in chart_types: + # This would be tested in integration tests + pass + + @pytest.mark.asyncio + async def test_ascii_art_variations(self): + """Test ASCII art generation for different chart types.""" + # Line chart ASCII + _ = """ +Sales Trend +โ”‚ +โ”‚ โ•ฑโ•ฒ +โ”‚ โ•ฑ โ•ฒ +โ”‚ โ•ฑ โ•ฒ +โ”‚ โ•ฑ โ•ฒ +โ”‚ โ•ฑ โ•ฒ +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + Jan Feb Mar +""" + + # Bar chart ASCII + _ = """ +Sales by Region +โ”‚ +โ”‚ โ–ˆโ–ˆโ–ˆโ–ˆ North +โ”‚ โ–ˆโ–ˆโ–ˆ South +โ”‚ โ–ˆโ–ˆ East +โ”‚ โ–ˆ West +โ””โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ +""" + + # Pie chart ASCII + _ = """ +Market Share + โ•ญโ”€โ”€โ”€โ”€โ”€โ•ฎ + โ•ฑ โ•ฒ +โ”‚ 45% โ”‚ +โ”‚ North โ”‚ +โ•ฐโ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ•ฏ +""" + + # These demonstrate the expected ASCII formats for different chart types diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py 
new file mode 100644 index 00000000000..f21c4ee11e6 --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py @@ -0,0 +1,385 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for update_chart MCP tool +""" + +import pytest + +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + FilterConfig, + LegendConfig, + TableChartConfig, + UpdateChartRequest, + XYChartConfig, +) + + +class TestUpdateChart: + """Tests for update_chart MCP tool.""" + + @pytest.mark.asyncio + async def test_update_chart_request_structure(self): + """Test that chart update request structures are properly formed.""" + # Table chart update with numeric ID + table_config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region", label="Region"), + ColumnRef(name="sales", label="Sales", aggregate="SUM"), + ], + filters=[FilterConfig(column="year", op="=", value="2024")], + sort_by=["sales"], + ) + table_request = UpdateChartRequest(identifier=123, config=table_config) + assert table_request.identifier == 123 + assert table_request.config.chart_type == "table" + assert len(table_request.config.columns) == 2 + assert table_request.config.columns[0].name == "region" + assert 
table_request.config.columns[1].aggregate == "SUM" + + # XY chart update with UUID + xy_config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="value", aggregate="SUM")], + kind="line", + group_by=ColumnRef(name="category"), + x_axis=AxisConfig(title="Date", format="smart_date"), + y_axis=AxisConfig(title="Value", format="$,.2f"), + legend=LegendConfig(show=True, position="top"), + ) + xy_request = UpdateChartRequest( + identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config + ) + assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert xy_request.config.chart_type == "xy" + assert xy_request.config.x.name == "date" + assert xy_request.config.y[0].aggregate == "SUM" + assert xy_request.config.kind == "line" + + @pytest.mark.asyncio + async def test_update_chart_with_chart_name(self): + """Test updating chart with custom chart name.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Without custom name + request1 = UpdateChartRequest(identifier=123, config=config) + assert request1.chart_name is None + + # With custom name + request2 = UpdateChartRequest( + identifier=123, config=config, chart_name="Updated Sales Report" + ) + assert request2.chart_name == "Updated Sales Report" + + @pytest.mark.asyncio + async def test_update_chart_preview_generation(self): + """Test preview generation options in update request.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Default preview generation + request1 = UpdateChartRequest(identifier=123, config=config) + assert request1.generate_preview is True + assert request1.preview_formats == ["url"] + + # Custom preview formats + request2 = UpdateChartRequest( + identifier=123, + config=config, + generate_preview=True, + preview_formats=["url", "ascii", "table"], + ) + assert request2.generate_preview is True + assert set(request2.preview_formats) == {"url", 
"ascii", "table"} + + # Disable preview generation + request3 = UpdateChartRequest( + identifier=123, config=config, generate_preview=False + ) + assert request3.generate_preview is False + + @pytest.mark.asyncio + async def test_update_chart_identifier_types(self): + """Test that identifier can be int or string (UUID).""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Integer ID + request1 = UpdateChartRequest(identifier=123, config=config) + assert request1.identifier == 123 + assert isinstance(request1.identifier, int) + + # String numeric ID + request2 = UpdateChartRequest(identifier="456", config=config) + assert request2.identifier == "456" + assert isinstance(request2.identifier, str) + + # UUID string + request3 = UpdateChartRequest( + identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=config + ) + assert request3.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert isinstance(request3.identifier, str) + + @pytest.mark.asyncio + async def test_update_chart_config_variations(self): + """Test various chart configuration options in updates.""" + # Test all XY chart types + chart_types = ["line", "bar", "area", "scatter"] + for chart_type in chart_types: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x_col"), + y=[ColumnRef(name="y_col")], + kind=chart_type, + ) + request = UpdateChartRequest(identifier=1, config=config) + assert request.config.kind == chart_type + + # Test multiple Y columns + multi_y_config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ + ColumnRef(name="sales", aggregate="SUM"), + ColumnRef(name="profit", aggregate="AVG"), + ColumnRef(name="quantity", aggregate="COUNT"), + ], + kind="line", + ) + request = UpdateChartRequest(identifier=1, config=multi_y_config) + assert len(request.config.y) == 3 + assert request.config.y[1].aggregate == "AVG" + + # Test filter operators + operators = ["=", "!=", ">", ">=", "<", "<="] + filters = 
[FilterConfig(column="col", op=op, value="val") for op in operators] + table_config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + filters=filters, + ) + request = UpdateChartRequest(identifier=1, config=table_config) + assert len(request.config.filters) == 6 + + @pytest.mark.asyncio + async def test_update_chart_response_structure(self): + """Test the expected response structure for chart updates.""" + # The response should contain these fields + expected_response = { + "chart": { + "id": int, + "slice_name": str, + "viz_type": str, + "url": str, + "uuid": str, + "updated": bool, + }, + "error": None, + "success": bool, + "schema_version": str, + "api_version": str, + } + + # When chart update succeeds, these additional fields may be present + optional_fields = [ + "previews", + "capabilities", + "semantics", + "explore_url", + "api_endpoints", + "performance", + "accessibility", + ] + + # Validate structure expectations + assert "chart" in expected_response + assert "success" in expected_response + assert len(optional_fields) > 0 + + @pytest.mark.asyncio + async def test_update_chart_axis_configurations(self): + """Test axis configuration updates.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="sales")], + x_axis=AxisConfig( + title="Date", + format="smart_date", + scale="linear", + ), + y_axis=AxisConfig( + title="Sales Amount", + format="$,.2f", + scale="log", + ), + ) + request = UpdateChartRequest(identifier=1, config=config) + assert request.config.x_axis.title == "Date" + assert request.config.x_axis.format == "smart_date" + assert request.config.x_axis.scale == "linear" + assert request.config.y_axis.title == "Sales Amount" + assert request.config.y_axis.format == "$,.2f" + assert request.config.y_axis.scale == "log" + + @pytest.mark.asyncio + async def test_update_chart_legend_configurations(self): + """Test legend configuration updates.""" + positions = ["top", "bottom", 
"left", "right"] + for pos in positions: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + legend=LegendConfig(show=True, position=pos), + ) + request = UpdateChartRequest(identifier=1, config=config) + assert request.config.legend.position == pos + assert request.config.legend.show is True + + # Hidden legend + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + legend=LegendConfig(show=False), + ) + request = UpdateChartRequest(identifier=1, config=config) + assert request.config.legend.show is False + + @pytest.mark.asyncio + async def test_update_chart_aggregation_functions(self): + """Test all supported aggregation functions in updates.""" + aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"] + for agg in aggs: + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="value", aggregate=agg)], + ) + request = UpdateChartRequest(identifier=1, config=config) + assert request.config.columns[0].aggregate == agg + + @pytest.mark.asyncio + async def test_update_chart_error_responses(self): + """Test expected error response structures.""" + # Chart not found error + error_response = { + "chart": None, + "error": "No chart found with identifier: 999", + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + assert error_response["success"] is False + assert error_response["chart"] is None + assert "chart found" in error_response["error"].lower() + + # General update error + update_error = { + "chart": None, + "error": "Chart update failed: Permission denied", + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + assert update_error["success"] is False + assert "failed" in update_error["error"].lower() + + @pytest.mark.asyncio + async def test_chart_name_sanitization(self): + """Test that chart names are properly sanitized.""" + config = TableChartConfig( + chart_type="table", + 
columns=[ColumnRef(name="col1")], + ) + + # Test with potentially problematic characters + test_names = [ + "Normal Chart Name", + "Chart with 'quotes'", + 'Chart with "double quotes"', + "Chart with ", + ] + + for name in test_names: + request = UpdateChartRequest(identifier=1, config=config, chart_name=name) + # Chart name should be set (sanitization happens in the validator) + assert request.chart_name is not None + + @pytest.mark.asyncio + async def test_update_chart_with_filters(self): + """Test updating chart with various filter configurations.""" + filters = [ + FilterConfig(column="region", op="=", value="North"), + FilterConfig(column="sales", op=">=", value=1000), + FilterConfig(column="date", op=">", value="2024-01-01"), + ] + + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region"), + ColumnRef(name="sales"), + ColumnRef(name="date"), + ], + filters=filters, + ) + + request = UpdateChartRequest(identifier=1, config=config) + assert len(request.config.filters) == 3 + assert request.config.filters[0].column == "region" + assert request.config.filters[1].op == ">=" + assert request.config.filters[2].value == "2024-01-01" + + @pytest.mark.asyncio + async def test_update_chart_cache_control(self): + """Test cache control parameters in update request.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Default cache settings + request1 = UpdateChartRequest(identifier=1, config=config) + assert request1.use_cache is True + assert request1.force_refresh is False + assert request1.cache_timeout is None + + # Custom cache settings + request2 = UpdateChartRequest( + identifier=1, + config=config, + use_cache=False, + force_refresh=True, + cache_timeout=300, + ) + assert request2.use_cache is False + assert request2.force_refresh is True + assert request2.cache_timeout == 300 diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py 
b/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py new file mode 100644 index 00000000000..f51a1e80afe --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py @@ -0,0 +1,474 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Unit tests for update_chart_preview MCP tool +""" + +import pytest + +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + FilterConfig, + LegendConfig, + TableChartConfig, + UpdateChartPreviewRequest, + XYChartConfig, +) + + +class TestUpdateChartPreview: + """Tests for update_chart_preview MCP tool.""" + + @pytest.mark.asyncio + async def test_update_chart_preview_request_structure(self): + """Test that chart preview update request structures are properly formed.""" + # Table chart preview update + table_config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region", label="Region"), + ColumnRef(name="sales", label="Sales", aggregate="SUM"), + ], + filters=[FilterConfig(column="year", op="=", value="2024")], + sort_by=["sales"], + ) + table_request = UpdateChartPreviewRequest( + form_data_key="abc123def456", dataset_id=1, config=table_config + ) + assert table_request.form_data_key == "abc123def456" + assert table_request.dataset_id == 1 + assert table_request.config.chart_type == "table" + assert len(table_request.config.columns) == 2 + assert table_request.config.columns[0].name == "region" + + # XY chart preview update + xy_config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="value", aggregate="SUM")], + kind="line", + group_by=ColumnRef(name="category"), + x_axis=AxisConfig(title="Date", format="smart_date"), + y_axis=AxisConfig(title="Value", format="$,.2f"), + legend=LegendConfig(show=True, position="top"), + ) + xy_request = UpdateChartPreviewRequest( + form_data_key="xyz789ghi012", dataset_id="2", config=xy_config + ) + assert xy_request.form_data_key == "xyz789ghi012" + assert xy_request.dataset_id == "2" + assert xy_request.config.chart_type == "xy" + assert xy_request.config.x.name == "date" + assert xy_request.config.kind == "line" + + @pytest.mark.asyncio + async def test_update_chart_preview_dataset_id_types(self): + """Test that dataset_id can be int or 
string (UUID).""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Integer dataset_id + request1 = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=123, config=config + ) + assert request1.dataset_id == 123 + assert isinstance(request1.dataset_id, int) + + # String numeric dataset_id + request2 = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id="456", config=config + ) + assert request2.dataset_id == "456" + assert isinstance(request2.dataset_id, str) + + # UUID string dataset_id + request3 = UpdateChartPreviewRequest( + form_data_key="abc123", + dataset_id="a1b2c3d4-e5f6-7890-abcd-ef1234567890", + config=config, + ) + assert request3.dataset_id == "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + assert isinstance(request3.dataset_id, str) + + @pytest.mark.asyncio + async def test_update_chart_preview_generation_options(self): + """Test preview generation options in update preview request.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + ) + + # Default preview generation + request1 = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request1.generate_preview is True + assert request1.preview_formats == ["url"] + + # Custom preview formats + request2 = UpdateChartPreviewRequest( + form_data_key="abc123", + dataset_id=1, + config=config, + generate_preview=True, + preview_formats=["url", "ascii", "table"], + ) + assert request2.generate_preview is True + assert set(request2.preview_formats) == {"url", "ascii", "table"} + + # Disable preview generation + request3 = UpdateChartPreviewRequest( + form_data_key="abc123", + dataset_id=1, + config=config, + generate_preview=False, + ) + assert request3.generate_preview is False + + @pytest.mark.asyncio + async def test_update_chart_preview_config_variations(self): + """Test various chart configuration options in preview updates.""" + # Test all XY chart types + 
chart_types = ["line", "bar", "area", "scatter"] + for chart_type in chart_types: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x_col"), + y=[ColumnRef(name="y_col")], + kind=chart_type, + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.kind == chart_type + + # Test multiple Y columns + multi_y_config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ + ColumnRef(name="sales", aggregate="SUM"), + ColumnRef(name="profit", aggregate="AVG"), + ColumnRef(name="quantity", aggregate="COUNT"), + ], + kind="line", + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=multi_y_config + ) + assert len(request.config.y) == 3 + assert request.config.y[1].aggregate == "AVG" + + # Test filter operators + operators = ["=", "!=", ">", ">=", "<", "<="] + filters = [FilterConfig(column="col", op=op, value="val") for op in operators] + table_config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="col1")], + filters=filters, + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=table_config + ) + assert len(request.config.filters) == 6 + + @pytest.mark.asyncio + async def test_update_chart_preview_response_structure(self): + """Test the expected response structure for chart preview updates.""" + # The response should contain these fields + expected_response = { + "chart": { + "id": None, # No ID for unsaved previews + "slice_name": str, + "viz_type": str, + "url": str, + "uuid": None, # No UUID for unsaved previews + "saved": bool, + "updated": bool, + }, + "error": None, + "success": bool, + "schema_version": str, + "api_version": str, + } + + # When preview update succeeds, these additional fields may be present + optional_fields = [ + "previews", + "capabilities", + "semantics", + "explore_url", + "form_data_key", + "previous_form_data_key", + "api_endpoints", + "performance", + 
"accessibility", + ] + + # Validate structure expectations + assert "chart" in expected_response + assert "success" in expected_response + assert len(optional_fields) > 0 + assert expected_response["chart"]["id"] is None + assert expected_response["chart"]["uuid"] is None + + @pytest.mark.asyncio + async def test_update_chart_preview_axis_configurations(self): + """Test axis configuration updates in preview.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="sales")], + x_axis=AxisConfig( + title="Date", + format="smart_date", + scale="linear", + ), + y_axis=AxisConfig( + title="Sales Amount", + format="$,.2f", + scale="log", + ), + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.x_axis.title == "Date" + assert request.config.x_axis.format == "smart_date" + assert request.config.y_axis.title == "Sales Amount" + assert request.config.y_axis.format == "$,.2f" + + @pytest.mark.asyncio + async def test_update_chart_preview_legend_configurations(self): + """Test legend configuration updates in preview.""" + positions = ["top", "bottom", "left", "right"] + for pos in positions: + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + legend=LegendConfig(show=True, position=pos), + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.legend.position == pos + assert request.config.legend.show is True + + # Hidden legend + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + legend=LegendConfig(show=False), + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.legend.show is False + + @pytest.mark.asyncio + async def test_update_chart_preview_aggregation_functions(self): + """Test all supported aggregation functions in preview 
updates.""" + aggs = ["SUM", "AVG", "COUNT", "MIN", "MAX", "COUNT_DISTINCT"] + for agg in aggs: + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="value", aggregate=agg)], + ) + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.columns[0].aggregate == agg + + @pytest.mark.asyncio + async def test_update_chart_preview_error_responses(self): + """Test expected error response structures for preview updates.""" + # General update error + error_response = { + "chart": None, + "error": "Chart preview update failed: Invalid form_data_key", + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + assert error_response["success"] is False + assert error_response["chart"] is None + assert "failed" in error_response["error"].lower() + + # Missing dataset error + dataset_error = { + "chart": None, + "error": "Chart preview update failed: Dataset not found", + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + assert dataset_error["success"] is False + assert "dataset" in dataset_error["error"].lower() + + @pytest.mark.asyncio + async def test_update_chart_preview_with_filters(self): + """Test updating preview with various filter configurations.""" + filters = [ + FilterConfig(column="region", op="=", value="North"), + FilterConfig(column="sales", op=">=", value=1000), + FilterConfig(column="date", op=">", value="2024-01-01"), + ] + + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region"), + ColumnRef(name="sales"), + ColumnRef(name="date"), + ], + filters=filters, + ) + + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert len(request.config.filters) == 3 + assert request.config.filters[0].column == "region" + assert request.config.filters[1].op == ">=" + assert request.config.filters[2].value == "2024-01-01" + + @pytest.mark.asyncio + async def 
@pytest.mark.asyncio
async def test_update_chart_preview_cache_control(self):
    """Cache-control flags: defaults and explicit overrides."""
    cfg = TableChartConfig(
        chart_type="table",
        columns=[ColumnRef(name="col1")],
    )

    # Defaults: cached reads and form-data caching on, no forced refresh.
    defaults = UpdateChartPreviewRequest(
        form_data_key="abc123", dataset_id=1, config=cfg
    )
    assert defaults.use_cache is True
    assert defaults.force_refresh is False
    assert defaults.cache_form_data is True

    # Explicit overrides flip every flag.
    overridden = UpdateChartPreviewRequest(
        form_data_key="abc123",
        dataset_id=1,
        config=cfg,
        use_cache=False,
        force_refresh=True,
        cache_form_data=False,
    )
    assert overridden.use_cache is False
    assert overridden.force_refresh is True
    assert overridden.cache_form_data is False
"abc123" + assert expected_unsaved_fields["id"] is None + assert expected_unsaved_fields["uuid"] is None + assert expected_unsaved_fields["saved"] is False + + @pytest.mark.asyncio + async def test_update_chart_preview_multiple_y_columns(self): + """Test preview updates with multiple Y-axis columns.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ + ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue"), + ColumnRef(name="cost", aggregate="SUM", label="Total Cost"), + ColumnRef(name="profit", aggregate="SUM", label="Total Profit"), + ColumnRef(name="orders", aggregate="COUNT", label="Order Count"), + ], + kind="line", + ) + + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert len(request.config.y) == 4 + assert request.config.y[0].name == "revenue" + assert request.config.y[1].name == "cost" + assert request.config.y[2].name == "profit" + assert request.config.y[3].name == "orders" + assert request.config.y[3].aggregate == "COUNT" + + @pytest.mark.asyncio + async def test_update_chart_preview_table_sorting(self): + """Test table chart sorting in preview updates.""" + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region"), + ColumnRef(name="sales", aggregate="SUM"), + ColumnRef(name="profit", aggregate="AVG"), + ], + sort_by=["sales", "profit"], + ) + + request = UpdateChartPreviewRequest( + form_data_key="abc123", dataset_id=1, config=config + ) + assert request.config.sort_by == ["sales", "profit"] + assert len(request.config.columns) == 3 diff --git a/tests/unit_tests/mcp_service/conftest.py b/tests/unit_tests/mcp_service/conftest.py new file mode 100644 index 00000000000..92868309552 --- /dev/null +++ b/tests/unit_tests/mcp_service/conftest.py @@ -0,0 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. 
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +MCP service test configuration. + +Tool imports are handled by app.py, not here. +This conftest is empty to prevent test pollution. +""" diff --git a/tests/unit_tests/mcp_service/dashboard/__init__.py b/tests/unit_tests/mcp_service/dashboard/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/dashboard/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/unit_tests/mcp_service/dashboard/tool/__init__.py b/tests/unit_tests/mcp_service/dashboard/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/dashboard/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py new file mode 100644 index 00000000000..b24d21e457a --- /dev/null +++ b/tests/unit_tests/mcp_service/dashboard/tool/test_dashboard_generation.py @@ -0,0 +1,450 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Unit tests for dashboard generation MCP tools +""" + +import logging +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.mcp_service.app import mcp +from superset.utils import json + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +def _mock_chart(id: int = 1, slice_name: str = "Test Chart") -> Mock: + """Create a mock chart object.""" + chart = Mock() + chart.id = id + chart.slice_name = slice_name + chart.uuid = f"chart-uuid-{id}" + return chart + + +def _mock_dashboard(id: int = 1, title: str = "Test Dashboard") -> Mock: + """Create a mock dashboard object.""" + dashboard = Mock() + dashboard.id = id + dashboard.dashboard_title = title + dashboard.slug = f"test-dashboard-{id}" + dashboard.description = "Test dashboard description" + dashboard.published = True + dashboard.created_on = "2024-01-01" + dashboard.changed_on = "2024-01-01" + dashboard.created_by = Mock() + dashboard.created_by.username = "test_user" + dashboard.changed_by = Mock() + dashboard.changed_by.username = "test_user" + dashboard.uuid = f"dashboard-uuid-{id}" + dashboard.slices = [] + dashboard.owners = [] # Add missing owners attribute 
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_basic(
    self, mock_db_session, mock_create_command, mcp_server
):
    """Happy path: two existing charts produce a published dashboard."""
    # Chart lookup returns both requested charts.
    charts = [
        _mock_chart(id=1, slice_name="Sales Chart"),
        _mock_chart(id=2, slice_name="Revenue Chart"),
    ]
    query = Mock()
    query.filter.return_value = Mock(all=Mock(return_value=charts))
    mock_db_session.query.return_value = query

    # Creation succeeds and hands back the new dashboard.
    mock_create_command.return_value.run.return_value = _mock_dashboard(
        id=10, title="Analytics Dashboard"
    )

    payload = {
        "chart_ids": [1, 2],
        "dashboard_title": "Analytics Dashboard",
        "description": "Dashboard for analytics",
        "published": True,
    }

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_dashboard", {"request": payload})

    data = result.data
    assert data.error is None
    assert data.dashboard is not None
    assert data.dashboard.id == 10
    assert data.dashboard.dashboard_title == "Analytics Dashboard"
    assert data.dashboard.chart_count == 2
    assert "/superset/dashboard/10/" in data.dashboard_url
@patch("superset.commands.dashboard.create.CreateDashboardCommand")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_generate_dashboard_single_chart(
    self, mock_db_session, mock_create_command, mcp_server
):
    """A one-chart dashboard is created successfully."""
    # Chart lookup returns the single requested chart.
    query = Mock()
    query.filter.return_value = Mock(
        all=Mock(return_value=[_mock_chart(id=5, slice_name="Single Chart")])
    )
    mock_db_session.query.return_value = query

    mock_create_command.return_value.run.return_value = _mock_dashboard(
        id=20, title="Single Chart Dashboard"
    )

    payload = {
        "chart_ids": [5],
        "dashboard_title": "Single Chart Dashboard",
        "published": False,
    }

    async with Client(mcp_server) as client:
        result = await client.call_tool("generate_dashboard", {"request": payload})

    assert result.data.error is None
    assert result.data.dashboard.chart_count == 1
    # The mocked dashboard reports published=True regardless of the request.
    assert result.data.dashboard.published is True
chart_ids + ] + mock_db_session.query.return_value = mock_query + + mock_dashboard = _mock_dashboard(id=30, title="Multi Chart Dashboard") + mock_create_command.return_value.run.return_value = mock_dashboard + + request = {"chart_ids": chart_ids, "dashboard_title": "Multi Chart Dashboard"} + + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + + assert result.data.error is None + assert result.data.dashboard.chart_count == 6 + + # Verify CreateDashboardCommand was called with proper layout + mock_create_command.assert_called_once() + call_args = mock_create_command.call_args[0][0] + + # Check position_json contains proper layout + position_json = json.loads(call_args["position_json"]) + assert "ROOT_ID" in position_json + assert "GRID_ID" in position_json + assert "DASHBOARD_VERSION_KEY" in position_json + assert position_json["DASHBOARD_VERSION_KEY"] == "v2" + + # ROOT should only contain GRID + assert position_json["ROOT_ID"]["children"] == ["GRID_ID"] + + # GRID should contain rows (6 charts = 3 rows in 2-chart layout) + grid_children = position_json["GRID_ID"]["children"] + assert len(grid_children) == 3 + + # Check each chart has proper structure + for i, chart_id in enumerate(chart_ids): + chart_key = f"CHART-{chart_id}" + row_index = i // 2 # 2 charts per row + row_key = f"ROW-{row_index}" + + # Chart should exist + assert chart_key in position_json + chart_data = position_json[chart_key] + assert chart_data["type"] == "CHART" + assert "meta" in chart_data + assert chart_data["meta"]["chartId"] == chart_id + + # Row should exist and contain charts (up to 2 per row) + assert row_key in position_json + row_data = position_json[row_key] + assert row_data["type"] == "ROW" + assert chart_key in row_data["children"] + + @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_generate_dashboard_creation_failure( + 
self, mock_db_session, mock_create_command, mcp_server + ): + """Test error handling when dashboard creation fails.""" + mock_query = Mock() + mock_filter = Mock() + mock_query.filter.return_value = mock_filter + mock_filter.all.return_value = [_mock_chart(id=1)] + mock_db_session.query.return_value = mock_query + mock_create_command.return_value.run.side_effect = Exception("Creation failed") + + request = {"chart_ids": [1], "dashboard_title": "Failed Dashboard"} + + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + + assert result.data.error is not None + assert "Failed to create dashboard" in result.data.error + assert result.data.dashboard is None + + @patch("superset.commands.dashboard.create.CreateDashboardCommand") + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_generate_dashboard_minimal_request( + self, mock_db_session, mock_create_command, mcp_server + ): + """Test dashboard generation with minimal required parameters.""" + # Mock database query for single chart + mock_query = Mock() + mock_filter = Mock() + mock_query.filter.return_value = mock_filter + mock_filter.all.return_value = [_mock_chart(id=3)] + mock_db_session.query.return_value = mock_query + + mock_dashboard = _mock_dashboard(id=40, title="Minimal Dashboard") + mock_create_command.return_value.run.return_value = mock_dashboard + + request = { + "chart_ids": [3], + "dashboard_title": "Minimal Dashboard", + # No description, published defaults to True + } + + async with Client(mcp_server) as client: + result = await client.call_tool("generate_dashboard", {"request": request}) + + assert result.data.error is None + assert result.data.dashboard.dashboard_title == "Minimal Dashboard" + + # Check that description was not included in call + call_args = mock_create_command.call_args[0][0] + assert call_args["published"] is True # Default value + assert ( + "description" not in call_args or 
call_args.get("description") is None + ) + + +class TestAddChartToExistingDashboard: + """Tests for add_chart_to_existing_dashboard MCP tool.""" + + @patch("superset.commands.dashboard.update.UpdateDashboardCommand") + @patch("superset.daos.dashboard.DashboardDAO.find_by_id") + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_add_chart_to_dashboard_basic( + self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server + ): + """Test adding a chart to an existing dashboard.""" + # Mock existing dashboard with some charts + mock_dashboard = _mock_dashboard(id=1, title="Existing Dashboard") + mock_dashboard.slices = [Mock(id=10), Mock(id=20)] # Existing charts + mock_dashboard.position_json = json.dumps( + { + "ROOT_ID": { + "children": ["CHART-10", "CHART-20"], + "id": "ROOT_ID", + "type": "ROOT", + }, + "CHART-10": {"id": "CHART-10", "type": "CHART", "parents": ["ROOT_ID"]}, + "CHART-10_POSITION": {"h": 16, "w": 24, "x": 0, "y": 0}, + "CHART-20": {"id": "CHART-20", "type": "CHART", "parents": ["ROOT_ID"]}, + "CHART-20_POSITION": {"h": 16, "w": 24, "x": 24, "y": 0}, + } + ) + mock_find_dashboard.return_value = mock_dashboard + + # Mock chart to add + mock_chart = _mock_chart(id=30, slice_name="New Chart") + mock_db_session.get.return_value = mock_chart + + # Mock updated dashboard + updated_dashboard = _mock_dashboard(id=1, title="Existing Dashboard") + updated_dashboard.slices = [Mock(id=10), Mock(id=20), Mock(id=30)] + mock_update_command.return_value.run.return_value = updated_dashboard + + request = {"dashboard_id": 1, "chart_id": 30} + + async with Client(mcp_server) as client: + result = await client.call_tool( + "add_chart_to_existing_dashboard", {"request": request} + ) + + assert result.data.error is None + assert result.data.dashboard is not None + assert result.data.dashboard.chart_count == 3 + assert result.data.position is not None + assert "row" in result.data.position # Should have row info + assert "chart_key" in 
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@pytest.mark.asyncio
async def test_add_chart_dashboard_not_found(self, mock_find_dashboard, mcp_server):
    """Adding to a nonexistent dashboard surfaces a not-found error."""
    mock_find_dashboard.return_value = None

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "add_chart_to_existing_dashboard",
            {"request": {"dashboard_id": 999, "chart_id": 1}},
        )

    assert result.data.error is not None
    assert "Dashboard with ID 999 not found" in result.data.error


@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_chart_not_found(
    self, mock_db_session, mock_find_dashboard, mcp_server
):
    """Adding a nonexistent chart surfaces a not-found error."""
    mock_find_dashboard.return_value = _mock_dashboard()
    mock_db_session.get.return_value = None  # chart lookup misses

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "add_chart_to_existing_dashboard",
            {"request": {"dashboard_id": 1, "chart_id": 999}},
        )

    assert result.data.error is not None
    assert "Chart with ID 999 not found" in result.data.error
@patch("superset.commands.dashboard.update.UpdateDashboardCommand")
@patch("superset.daos.dashboard.DashboardDAO.find_by_id")
@patch("superset.db.session")
@pytest.mark.asyncio
async def test_add_chart_empty_dashboard(
    self, mock_db_session, mock_find_dashboard, mock_update_command, mcp_server
):
    """A chart added to an empty layout seeds the ROOT/GRID/ROW-0 structure."""
    # Dashboard with no charts and an empty position_json.
    dashboard = _mock_dashboard(id=2)
    dashboard.slices = []
    dashboard.position_json = "{}"
    mock_find_dashboard.return_value = dashboard

    mock_db_session.get.return_value = _mock_chart(id=15)

    updated = _mock_dashboard(id=2)
    updated.slices = [Mock(id=15)]
    mock_update_command.return_value.run.return_value = updated

    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "add_chart_to_existing_dashboard",
            {"request": {"dashboard_id": 2, "chart_id": 15}},
        )

    assert result.data.error is None
    assert "row" in result.data.position  # row info is present
    assert result.data.position.get("row") == 0  # placed on the first row

    # The update command received a fully-formed layout.
    layout = json.loads(mock_update_command.call_args[0][1]["position_json"])
    for key in ("ROOT_ID", "GRID_ID", "ROW-0", "CHART-15"):
        assert key in layout
@pytest.fixture
def mcp_server():
    """Expose the shared MCP app instance to the tests."""
    return mcp


@pytest.fixture(autouse=True)
def mock_auth():
    """Pretend every request is made by an authenticated admin user."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        admin = Mock()
        admin.id = 1
        admin.username = "admin"
        mock_get_user.return_value = admin
        yield mock_get_user
dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + dashboard.description = None + dashboard.css = None + dashboard.certified_by = None + dashboard.certification_details = None + dashboard.json_metadata = None + dashboard.position_json = None + dashboard.is_managed_externally = False + dashboard.external_url = None + dashboard.uuid = "test-dashboard-uuid-1" + dashboard.thumbnail_url = None + dashboard.roles = [] + dashboard.charts = [] + dashboard._mapping = { + "id": dashboard.id, + "dashboard_title": dashboard.dashboard_title, + "slug": dashboard.slug, + "url": dashboard.url, + "published": dashboard.published, + "changed_by_name": dashboard.changed_by_name, + "changed_on": dashboard.changed_on, + "changed_on_humanized": dashboard.changed_on_humanized, + "created_by_name": dashboard.created_by_name, + "created_on": dashboard.created_on, + "created_on_humanized": dashboard.created_on_humanized, + "tags": dashboard.tags, + "owners": dashboard.owners, + "charts": [], + } + mock_list.return_value = ([dashboard], 1) + async with Client(mcp_server) as client: + request = ListDashboardsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_dashboards", {"request": request.model_dump()} + ) + dashboards = result.data.dashboards + assert len(dashboards) == 1 + assert dashboards[0].dashboard_title == "Test Dashboard" + assert dashboards[0].uuid == "test-dashboard-uuid-1" + assert dashboards[0].slug == "test-dashboard" + assert dashboards[0].published is True + + # Verify UUID and slug are in default columns + assert "uuid" in result.data.columns_requested + assert "slug" in result.data.columns_requested + assert "uuid" in result.data.columns_loaded + assert "slug" in result.data.columns_loaded + + +@patch("superset.daos.dashboard.DashboardDAO.list") +@pytest.mark.asyncio +async def test_list_dashboards_with_filters(mock_list, mcp_server): + dashboard = Mock() + dashboard.id = 1 + 
dashboard.dashboard_title = "Filtered Dashboard" + dashboard.slug = "filtered-dashboard" + dashboard.url = "/dashboard/2" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + dashboard.description = None + dashboard.css = None + dashboard.certified_by = None + dashboard.certification_details = None + dashboard.json_metadata = None + dashboard.position_json = None + dashboard.is_managed_externally = False + dashboard.external_url = None + dashboard.uuid = None + dashboard.thumbnail_url = None + dashboard.roles = [] + dashboard.charts = [] + dashboard._mapping = { + "id": dashboard.id, + "dashboard_title": dashboard.dashboard_title, + "slug": dashboard.slug, + "url": dashboard.url, + "published": dashboard.published, + "changed_by_name": dashboard.changed_by_name, + "changed_on": dashboard.changed_on, + "changed_on_humanized": dashboard.changed_on_humanized, + "created_by_name": dashboard.created_by_name, + "created_on": dashboard.created_on, + "created_on_humanized": dashboard.created_on_humanized, + "tags": dashboard.tags, + "owners": dashboard.owners, + "charts": [], + } + mock_list.return_value = ([dashboard], 1) + async with Client(mcp_server) as client: + filters = [ + {"col": "dashboard_title", "opr": "sw", "value": "Sales"}, + {"col": "published", "opr": "eq", "value": True}, + ] + request = ListDashboardsRequest( + filters=filters, + select_columns=["id", "dashboard_title"], + order_column="changed_on", + order_direction="desc", + page=1, + page_size=50, + ) + result = await client.call_tool( + "list_dashboards", {"request": request.model_dump()} + ) + assert result.data.count == 1 + assert result.data.dashboards[0].dashboard_title == "Filtered Dashboard" + + 
+@patch("superset.daos.dashboard.DashboardDAO.list") +@pytest.mark.asyncio +async def test_list_dashboards_with_string_filters(mock_list, mcp_server): + mock_list.return_value = ([], 0) + async with Client(mcp_server) as client: # noqa: F841 + filters = '[{"col": "dashboard_title", "opr": "sw", "value": "Sales"}]' + + # Test that string filters cause validation error at schema level + with pytest.raises(ValueError, match="validation error"): + ListDashboardsRequest(filters=filters) # noqa: F841 + + +@patch("superset.daos.dashboard.DashboardDAO.list") +@pytest.mark.asyncio +async def test_list_dashboards_api_error(mock_list, mcp_server): + mock_list.side_effect = ToolError("API request failed") + async with Client(mcp_server) as client: + with pytest.raises(ToolError) as excinfo: # noqa: PT012 + request = ListDashboardsRequest() + await client.call_tool("list_dashboards", {"request": request.model_dump()}) + assert "API request failed" in str(excinfo.value) + + +@patch("superset.daos.dashboard.DashboardDAO.list") +@pytest.mark.asyncio +async def test_list_dashboards_with_search(mock_list, mcp_server): + dashboard = Mock() + dashboard.id = 1 + dashboard.dashboard_title = "search_dashboard" + dashboard.slug = "search-dashboard" + dashboard.url = "/dashboard/1" + dashboard.published = True + dashboard.changed_by_name = "admin" + dashboard.changed_on = None + dashboard.changed_on_humanized = None + dashboard.created_by_name = "admin" + dashboard.created_on = None + dashboard.created_on_humanized = None + dashboard.tags = [] + dashboard.owners = [] + dashboard.slices = [] + dashboard.description = None + dashboard.css = None + dashboard.certified_by = None + dashboard.certification_details = None + dashboard.json_metadata = None + dashboard.position_json = None + dashboard.is_managed_externally = False + dashboard.external_url = None + dashboard.uuid = None + dashboard.thumbnail_url = None + dashboard.roles = [] + dashboard.charts = [] + dashboard._mapping = { + "id": 
@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_with_simple_filters(mock_list, mcp_server):
    """Structured filter dicts are accepted and yield a countable result."""
    mock_list.return_value = ([], 0)
    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            filters=[
                {"col": "dashboard_title", "opr": "eq", "value": "Sales"},
                {"col": "published", "opr": "eq", "value": True},
            ]
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
    assert hasattr(result.data, "count")
= None + dashboard.json_metadata = None + dashboard.position_json = None + dashboard.published = True + dashboard.is_managed_externally = False + dashboard.external_url = None + dashboard.created_on = None + dashboard.changed_on = None + dashboard.created_by = None + dashboard.changed_by = None + dashboard.uuid = None + dashboard.url = "/dashboard/1" + dashboard.thumbnail_url = None + dashboard.created_on_humanized = None + dashboard.changed_on_humanized = None + dashboard.slices = [] + dashboard.owners = [] + dashboard.tags = [] + dashboard.roles = [] + dashboard.charts = [] + dashboard._mapping = { + "id": dashboard.id, + "dashboard_title": dashboard.dashboard_title, + "slug": dashboard.slug, + "url": dashboard.url, + "published": dashboard.published, + "changed_by_name": dashboard.changed_by_name, + "changed_on": dashboard.changed_on, + "changed_on_humanized": dashboard.changed_on_humanized, + "created_by_name": dashboard.created_by_name, + "created_on": dashboard.created_on, + "created_on_humanized": dashboard.created_on_humanized, + "tags": dashboard.tags, + "owners": dashboard.owners, + "charts": [], + } + mock_info.return_value = dashboard # Only the dashboard object + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dashboard_info", {"request": {"identifier": 1}} + ) + assert result.data["dashboard_title"] == "Test Dashboard" + + +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dashboard_info_not_found(mock_info, mcp_server): + mock_info.return_value = None # Not found returns None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dashboard_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found" + + +@patch("superset.daos.dashboard.DashboardDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dashboard_info_access_denied(mock_info, mcp_server): + mock_info.return_value = None # Access denied 
# TODO (Phase 3+): Add tests for get_dashboard_available_filters tool


@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_uuid(mock_find_object, mcp_server):
    """Test getting dashboard info using UUID identifier.

    The identifier must be a syntactically valid UUID (hex digits only) so
    identifier parsing takes the UUID path instead of silently falling back
    to slug lookup.
    """
    dashboard_uuid = "c3d4e5f6-a7b8-9012-cdef-ab3456789012"
    dashboard = Mock()
    dashboard.id = 1
    dashboard.dashboard_title = "Test Dashboard UUID"
    dashboard.slug = "test-dashboard-uuid"
    dashboard.description = "Test description"
    dashboard.css = ""
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = "{}"
    dashboard.position_json = "{}"
    dashboard.published = True
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.created_on = None
    dashboard.changed_on = None
    dashboard.created_by = None
    dashboard.changed_by = None
    dashboard.uuid = dashboard_uuid
    dashboard.url = "/dashboard/1"
    dashboard.thumbnail_url = None
    dashboard.created_on_humanized = "2 days ago"
    dashboard.changed_on_humanized = "1 day ago"
    dashboard.slices = []
    dashboard.owners = []
    dashboard.tags = []
    dashboard.roles = []

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": dashboard_uuid}}
        )
        assert result.data["dashboard_title"] == "Test Dashboard UUID"


@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object")
@pytest.mark.asyncio
async def test_get_dashboard_info_by_slug(mock_find_object, mcp_server):
    """Test getting dashboard info using slug identifier."""
    dashboard = Mock()
    dashboard.id = 2
    dashboard.dashboard_title = "Test Dashboard Slug"
    dashboard.slug = "test-dashboard-slug"
    dashboard.description = "Test description"
    dashboard.css = ""
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = "{}"
    dashboard.position_json = "{}"
    dashboard.published = True
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.created_on = None
    dashboard.changed_on = None
    dashboard.created_by = None
    dashboard.changed_by = None
    # Valid RFC 4122 hex UUID; the lookup below is by slug, but the field
    # still has to round-trip through UUID-typed serialization.
    dashboard.uuid = "d4e5f6a7-b8c9-0123-defa-bc4567890123"
    dashboard.url = "/dashboard/2"
    dashboard.thumbnail_url = None
    dashboard.created_on_humanized = "2 days ago"
    dashboard.changed_on_humanized = "1 day ago"
    dashboard.slices = []
    dashboard.owners = []
    dashboard.tags = []
    dashboard.roles = []

    mock_find_object.return_value = dashboard
    async with Client(mcp_server) as client:
        result = await client.call_tool(
            "get_dashboard_info", {"request": {"identifier": "test-dashboard-slug"}}
        )
        assert result.data["dashboard_title"] == "Test Dashboard Slug"


@patch("superset.daos.dashboard.DashboardDAO.list")
@pytest.mark.asyncio
async def test_list_dashboards_custom_uuid_slug_columns(mock_list, mcp_server):
    """Test that custom column selection includes UUID and slug when explicitly
    requested."""
    dashboard = Mock()
    dashboard.id = 1
    dashboard.dashboard_title = "Custom Columns Dashboard"
    dashboard.slug = "custom-dashboard"
    dashboard.uuid = "test-custom-uuid-123"
    dashboard.url = "/dashboard/1"
    dashboard.published = True
    dashboard.changed_by_name = "admin"
    dashboard.changed_on = None
    dashboard.changed_on_humanized = None
    dashboard.created_by_name = "admin"
    dashboard.created_on = None
    dashboard.created_on_humanized = None
    dashboard.tags = []
    dashboard.owners = []
    dashboard.slices = []
    dashboard.description = None
    dashboard.css = None
    dashboard.certified_by = None
    dashboard.certification_details = None
    dashboard.json_metadata = None
    dashboard.position_json = None
    dashboard.is_managed_externally = False
    dashboard.external_url = None
    dashboard.thumbnail_url = None
    dashboard.roles = []
    dashboard.charts = []
    dashboard._mapping = {
        "id": dashboard.id,
        "dashboard_title": dashboard.dashboard_title,
        "slug": dashboard.slug,
        "uuid": dashboard.uuid,
        "url": dashboard.url,
        "published": dashboard.published,
        "changed_by_name": dashboard.changed_by_name,
        "changed_on": dashboard.changed_on,
        "changed_on_humanized": dashboard.changed_on_humanized,
        "created_by_name": dashboard.created_by_name,
        "created_on": dashboard.created_on,
        "created_on_humanized": dashboard.created_on_humanized,
        "tags": dashboard.tags,
        "owners": dashboard.owners,
        "charts": [],
    }
    mock_list.return_value = ([dashboard], 1)
    async with Client(mcp_server) as client:
        request = ListDashboardsRequest(
            select_columns=["id", "dashboard_title", "uuid", "slug"],
            page=1,
            page_size=10,
        )
        result = await client.call_tool(
            "list_dashboards", {"request": request.model_dump()}
        )
        dashboards = result.data.dashboards
        assert len(dashboards) == 1
        assert dashboards[0].uuid == "test-custom-uuid-123"
        assert dashboards[0].slug == "custom-dashboard"

        # Verify custom columns include UUID and slug
        assert "uuid" in result.data.columns_requested
        assert "slug" in result.data.columns_requested
        assert "uuid" in result.data.columns_loaded
        assert "slug" in result.data.columns_loaded
class TestDashboardSortableColumns:
    """Sortable-column configuration for the dashboard listing tools."""

    def test_dashboard_sortable_columns_definition(self):
        """Only real model columns are sortable; computed fields are excluded."""
        from superset.mcp_service.dashboard.tool.list_dashboards import (
            SORTABLE_DASHBOARD_COLUMNS,
        )

        expected = [
            "id",
            "dashboard_title",
            "slug",
            "published",
            "changed_on",
            "created_on",
        ]
        assert SORTABLE_DASHBOARD_COLUMNS == expected
        # Computed/joined properties must never be offered for sorting.
        for computed in ("changed_on_delta_humanized", "changed_by_name", "uuid"):
            assert computed not in SORTABLE_DASHBOARD_COLUMNS

    @patch("superset.daos.dashboard.DashboardDAO.list")
    @pytest.mark.asyncio
    async def test_list_dashboards_with_valid_order_column(self, mock_list, mcp_server):
        """A valid order column and direction are forwarded verbatim to the DAO."""
        mock_list.return_value = ([], 0)

        async with Client(mcp_server) as client:
            req = ListDashboardsRequest(
                order_column="dashboard_title", order_direction="desc"
            )
            response = await client.call_tool(
                "list_dashboards", {"request": req.model_dump()}
            )

            # The DAO must receive exactly the requested ordering.
            mock_list.assert_called_once()
            dao_kwargs = mock_list.call_args[1]
            assert dao_kwargs["order_column"] == "dashboard_title"
            assert dao_kwargs["order_direction"] == "desc"

            # And the (empty) result round-trips cleanly.
            assert response.data.count == 0
            assert response.data.dashboards == []

    def test_sortable_columns_in_docstring(self):
        """Every sortable column is advertised in the tool description."""
        from superset.mcp_service.dashboard.tool.list_dashboards import (
            list_dashboards,
            SORTABLE_DASHBOARD_COLUMNS,
        )

        # The @mcp.tool decorator stores the docstring as `description`.
        assert hasattr(list_dashboards, "description")
        assert "Sortable columns for order_column:" in list_dashboards.description
        for column in SORTABLE_DASHBOARD_COLUMNS:
            assert column in list_dashboards.description
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/dataset/tool/__init__.py b/tests/unit_tests/mcp_service/dataset/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/dataset/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
diff --git a/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py new file mode 100644 index 00000000000..6ef51191cec --- /dev/null +++ b/tests/unit_tests/mcp_service/dataset/tool/test_dataset_tools.py @@ -0,0 +1,1231 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
import logging
from unittest.mock import AsyncMock, MagicMock, Mock, patch

import fastmcp
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError

from superset.mcp_service.app import mcp
from superset.mcp_service.dataset.schemas import ListDatasetsRequest
from superset.utils import json

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


def create_mock_dataset(
    dataset_id=1,
    table_name="Test DatasetInfo",
    schema="main",
    database_name="examples",
    columns=None,
    metrics=None,
):
    """Build a MagicMock dataset carrying the attributes the MCP tools read.

    Only the identity fields vary per test; everything else gets a neutral
    default so individual tests stay short.
    """
    mock_ds = MagicMock()
    defaults = {
        "id": dataset_id,
        "table_name": table_name,
        "schema": schema,
        "description": "desc",
        "changed_by_name": "admin",
        "changed_on": None,
        "changed_on_humanized": None,
        "created_by_name": "admin",
        "created_on": None,
        "created_on_humanized": None,
        "tags": [],
        "owners": [],
        "is_virtual": False,
        "database_id": 1,
        "schema_perm": f"[{database_name}].[{schema}]",
        "url": f"/tablemodelview/edit/{dataset_id}",
        "sql": None,
        "main_dttm_col": None,
        "offset": 0,
        "cache_timeout": 0,
        "params": {},
        "template_params": {},
        "extra": {},
        "uuid": f"test-dataset-uuid-{dataset_id}",
        "columns": columns or [],
        "metrics": metrics or [],
    }
    for attr, value in defaults.items():
        setattr(mock_ds, attr, value)
    mock_ds.database = MagicMock()
    mock_ds.database.database_name = database_name
    return mock_ds


@pytest.fixture
def mcp_server():
    """The module-level FastMCP app under test."""
    return mcp


@pytest.fixture(autouse=True)
def mock_auth():
    """Stub request authentication so every tool call runs as an admin user."""
    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
        fake_user = Mock()
        fake_user.id = 1
        fake_user.username = "admin"
        mock_get_user.return_value = fake_user
        yield mock_get_user
+ +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_basic(mock_list, mcp_server): + """Test basic dataset listing functionality. + + Note: Dataset tests use json.loads(result.content[0].text) pattern + for response parsing, which differs from dashboard/chart tests that + use result.data directly. This is intentional based on how the + dataset tool responses are structured. + """ + dataset = MagicMock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.uuid = "test-dataset-uuid-1" + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + col2 = MagicMock() + col2.column_name = "name" + col2.verbose_name = "Name" + col2.type = "VARCHAR" + col2.is_dttm = False + col2.groupby = True + col2.filterable = True + col2.description = "Name column" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1, col2] + dataset.metrics = [metric1] + dataset._mapping = { + "id": dataset.id, 
+ "table_name": dataset.table_name, + "schema": dataset.schema, + "database_name": dataset.database.database_name, + "description": dataset.description, + "changed_by_name": dataset.changed_by_name, + "changed_on": dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + } + mock_list.return_value = ([dataset], 1) + async with Client(mcp_server) as client: + request = ListDatasetsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "Test DatasetInfo" + assert data["datasets"][0]["uuid"] == "test-dataset-uuid-1" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 2 + assert len(data["datasets"][0]["metrics"]) == 1 + assert data["datasets"][0]["columns"][0]["column_name"] == "id" + assert data["datasets"][0]["metrics"][0]["metric_name"] == "count" + + # Verify UUID is in default columns (datasets don't have slugs) + assert "uuid" in data["columns_requested"] + assert "uuid" in data["columns_loaded"] + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_custom_uuid_columns(mock_list, mcp_server): + 
"""Test that custom column selection includes UUID when explicitly requested.""" + dataset = MagicMock() + dataset.id = 1 + dataset.table_name = "custom_dataset" + dataset.schema = "public" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[public]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = MagicMock() + dataset.database.database_name = "test_db" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.uuid = "test-custom-dataset-uuid" + dataset.columns = [] + dataset.metrics = [] + dataset._mapping = { + "id": dataset.id, + "table_name": dataset.table_name, + "schema": dataset.schema, + "database_name": dataset.database.database_name, + "uuid": dataset.uuid, + "description": dataset.description, + "changed_by_name": dataset.changed_by_name, + "changed_on": dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + "columns": dataset.columns, + "metrics": dataset.metrics, + } + mock_list.return_value = ([dataset], 1) + async with 
Client(mcp_server) as client: + request = ListDatasetsRequest( + select_columns=["id", "table_name", "uuid"], page=1, page_size=10 + ) + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + data = json.loads(result.content[0].text) + assert data["count"] == 1 + assert data["datasets"][0]["uuid"] == "test-custom-dataset-uuid" + + # Verify custom columns include UUID + assert "uuid" in data["columns_requested"] + assert "uuid" in data["columns_loaded"] + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_with_filters(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 2 + dataset.table_name = "Filtered Dataset" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/2" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "sum" + metric1.verbose_name = "Sum" + metric1.expression = "SUM(value)" + metric1.description = "Sum of values" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + dataset._mapping = { + "id": dataset.id, + "table_name": dataset.table_name, + "schema": 
@patch("superset.daos.dataset.DatasetDAO.list")
@pytest.mark.asyncio
async def test_list_datasets_with_string_filters(mock_list, mcp_server):
    """A JSON string is rejected for ``filters`` — a real list is required.

    The validation error is raised synchronously by pydantic at
    request-construction time, so no tool call (and no mock dataset)
    is needed.
    """
    with pytest.raises(ValueError, match="Input should be a valid list"):
        ListDatasetsRequest(
            filters='[{"col": "table_name", "opr": "sw", "value": "Sales"}]'
        )


@patch("superset.daos.dataset.DatasetDAO.list")
@pytest.mark.asyncio
async def test_list_datasets_api_error(mock_list, mcp_server):
    """DAO failures propagate to the caller as a ToolError."""
    mock_list.side_effect = ToolError("API request failed")
    async with Client(mcp_server) as client:
        # Build the request outside the raises block so only the tool call
        # is allowed to raise.
        request = ListDatasetsRequest()
        with pytest.raises(ToolError) as excinfo:
            await client.call_tool("list_datasets", {"request": request.model_dump()})
        assert "API request failed" in str(excinfo.value)


@patch("superset.daos.dataset.DatasetDAO.list")
@pytest.mark.asyncio
async def test_list_datasets_with_search(mock_list, mcp_server):
    """Free-text search returns the matching dataset with columns and metrics."""
    id_col = MagicMock()
    id_col.column_name = "id"
    id_col.verbose_name = "ID"
    id_col.type = "INTEGER"
    id_col.is_dttm = False
    id_col.groupby = True
    id_col.filterable = True
    id_col.description = "Primary key"

    count_metric = MagicMock()
    count_metric.metric_name = "count"
    count_metric.verbose_name = "Count"
    count_metric.expression = "COUNT(*)"
    count_metric.description = "Row count"
    count_metric.d3format = None

    dataset = create_mock_dataset(
        table_name="search_table",
        schema="public",
        database_name="test_db",
        columns=[id_col],
        metrics=[count_metric],
    )
    # This fixture models a row without a loaded database relationship, so
    # database_name lives directly on the row.
    dataset.database = None
    dataset.database_name = "test_db"
    dataset.description = "A test dataset"
    dataset.changed_by = "admin"
    dataset.created_by = "admin"
    dataset.schema_perm = None
    dataset.url = None
    dataset._mapping = {
        "id": dataset.id,
        "table_name": dataset.table_name,
        "schema": dataset.schema,
        "database_name": dataset.database_name,
        "description": dataset.description,
        "changed_by_name": dataset.changed_by_name,
        "changed_on": dataset.changed_on,
        "changed_on_humanized": dataset.changed_on_humanized,
        "created_by_name": dataset.created_by_name,
        "created_on": dataset.created_on,
        "created_on_humanized": dataset.created_on_humanized,
        "tags": dataset.tags,
        "owners": dataset.owners,
        "is_virtual": dataset.is_virtual,
        "database_id": dataset.database_id,
        "schema_perm": dataset.schema_perm,
        "url": dataset.url,
        "sql": dataset.sql,
        "main_dttm_col": dataset.main_dttm_col,
        "offset": dataset.offset,
        "cache_timeout": dataset.cache_timeout,
        "params": dataset.params,
        "template_params": dataset.template_params,
        "extra": dataset.extra,
    }
    mock_list.return_value = ([dataset], 1)
    async with Client(mcp_server) as client:
        request = ListDatasetsRequest(search="search_table")
        result = await client.call_tool(
            "list_datasets", {"request": request.model_dump()}
        )
        assert result.content is not None
        data = json.loads(result.content[0].text)
        assert data["datasets"] is not None
        assert len(data["datasets"]) == 1
        assert data["datasets"][0]["id"] == 1
        assert data["datasets"][0]["table_name"] == "search_table"
        # Check that columns and metrics are included
        assert len(data["datasets"][0]["columns"]) == 1
        assert len(data["datasets"][0]["metrics"]) == 1
"changed_by_name": dataset.changed_by_name, + "changed_on": dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + } + mock_list.return_value = ([dataset], 1) + async with Client(mcp_server) as client: + request = ListDatasetsRequest(search="search_table") + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "search_table" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_simple_with_search(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 2 + dataset.table_name = "simple_search" + dataset.schema = "analytics" + dataset.database_name = "analytics_db" + dataset.database = None + dataset.description = "Another test dataset" + dataset.changed_by = "user" + dataset.changed_by_name = "user" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by = "user" + dataset.created_by_name = "user" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + 
dataset.owners = [] + dataset.is_virtual = True + dataset.database_id = 2 + dataset.schema_perm = None + dataset.url = None + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + dataset._mapping = { + "id": dataset.id, + "table_name": dataset.table_name, + "schema": dataset.schema, + "database_name": dataset.database_name, + "description": dataset.description, + "changed_by_name": dataset.changed_by_name, + "changed_on": dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + } + mock_list.return_value = ([dataset], 1) + async with Client(mcp_server) as client: + request = ListDatasetsRequest(search="simple_search") + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert 
data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 2 + assert data["datasets"][0]["table_name"] == "simple_search" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_simple_basic(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + dataset._mapping = { + "id": dataset.id, + "table_name": dataset.table_name, + "schema": dataset.schema, + "database_name": dataset.database.database_name, + "description": dataset.description, + "changed_by_name": dataset.changed_by_name, + "changed_on": 
dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + } + mock_list.return_value = ([dataset], 1) + filters = [ + {"col": "table_name", "opr": "eq", "value": "Test DatasetInfo"}, + {"col": "schema", "opr": "eq", "value": "main"}, + ] + async with Client(mcp_server) as client: + request = ListDatasetsRequest(filters=filters) + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 1 + assert data["datasets"][0]["table_name"] == "Test DatasetInfo" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_simple_with_filters(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 2 + dataset.table_name = "Sales Dataset" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + 
dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/2" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "sum" + metric1.verbose_name = "Sum" + metric1.expression = "SUM(value)" + metric1.description = "Sum of values" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + dataset._mapping = { + "id": dataset.id, + "table_name": dataset.table_name, + "schema": dataset.schema, + "database_name": dataset.database.database_name, + "description": dataset.description, + "changed_by_name": dataset.changed_by_name, + "changed_on": dataset.changed_on, + "changed_on_humanized": dataset.changed_on_humanized, + "created_by_name": dataset.created_by_name, + "created_on": dataset.created_on, + "created_on_humanized": dataset.created_on_humanized, + "tags": dataset.tags, + "owners": dataset.owners, + "is_virtual": dataset.is_virtual, + "database_id": dataset.database_id, + "schema_perm": dataset.schema_perm, + "url": dataset.url, + "sql": dataset.sql, + "main_dttm_col": dataset.main_dttm_col, + "offset": dataset.offset, + "cache_timeout": dataset.cache_timeout, + "params": dataset.params, + "template_params": dataset.template_params, + "extra": dataset.extra, + } + mock_list.return_value = ([dataset], 1) + filters = [ + {"col": "table_name", "opr": "sw", "value": "Sales"}, + {"col": "schema", "opr": "eq", "value": "main"}, + ] + async with Client(mcp_server) as client: + request = ListDatasetsRequest(filters=filters) + result = await 
client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["datasets"] is not None + assert len(data["datasets"]) == 1 + assert data["datasets"][0]["id"] == 2 + assert data["datasets"][0]["table_name"] == "Sales Dataset" + # Check that columns and metrics are included + assert len(data["datasets"][0]["columns"]) == 1 + assert len(data["datasets"][0]["metrics"]) == 1 + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_simple_api_error(mock_list, mcp_server): + mock_list.side_effect = Exception("API request failed") + filters = [ + {"col": "table_name", "opr": "sw", "value": "Sales"}, + {"col": "schema", "opr": "eq", "value": "main"}, + ] + async with Client(mcp_server) as client: + with pytest.raises(ToolError) as excinfo: # noqa: PT012 + request = ListDatasetsRequest(filters=filters) + await client.call_tool("list_datasets", {"request": request.model_dump()}) + assert "API request failed" in str(excinfo.value) + + +@patch("superset.daos.dataset.DatasetDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dataset_info_success(mock_info, mcp_server): + dataset = MagicMock() + dataset.id = 1 + dataset.table_name = "Test DatasetInfo" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + 
dataset.extra = {} + # Add proper mock columns and metrics + col1 = MagicMock() + col1.column_name = "id" + col1.verbose_name = "ID" + col1.type = "INTEGER" + col1.is_dttm = False + col1.groupby = True + col1.filterable = True + col1.description = "Primary key" + + metric1 = MagicMock() + metric1.metric_name = "count" + metric1.verbose_name = "Count" + metric1.expression = "COUNT(*)" + metric1.description = "Row count" + metric1.d3format = None + + dataset.columns = [col1] + dataset.metrics = [metric1] + mock_info.return_value = dataset + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dataset_info", {"request": {"identifier": 1}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["table_name"] == "Test DatasetInfo" + assert data["database_name"] == "examples" + # Check that columns and metrics are included + assert len(data["columns"]) == 1 + assert len(data["metrics"]) == 1 + assert data["columns"][0]["column_name"] == "id" + assert data["metrics"][0]["metric_name"] == "count" + + +@patch("superset.daos.dataset.DatasetDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dataset_info_not_found(mock_info, mcp_server): + mock_info.return_value = None # Not found returns None + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dataset_info", {"request": {"identifier": 999}} + ) + assert result.data["error_type"] == "not_found" + + +# TODO (Phase 3+): Add tests for get_dataset_available_filters tool + + +@pytest.mark.asyncio +async def test_invalid_filter_column_raises(mcp_server): + async with fastmcp.Client(mcp_server) as client: # noqa: F841 + with pytest.raises(ValueError, match="Input should be"): + # This should fail validation at the schema level due to invalid column name + ListDatasetsRequest( # noqa: F841 + filters=[{"col": "not_a_column", "opr": "eq", "value": "foo"}] + ) + + 
+@patch("superset.daos.dataset.DatasetDAO.find_by_id") +@pytest.mark.asyncio +async def test_get_dataset_info_includes_columns_and_metrics(mock_info, mcp_server): + dataset = MagicMock() + dataset.id = 10 + dataset.table_name = "Dataset With Columns" + dataset.schema = "main" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/10" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.columns = [ + MagicMock( + column_name="col1", + verbose_name="Column 1", + type="INTEGER", + is_dttm=False, + groupby=True, + filterable=True, + description="First column", + ), + MagicMock( + column_name="col2", + verbose_name="Column 2", + type="VARCHAR", + is_dttm=False, + groupby=False, + filterable=True, + description="Second column", + ), + ] + dataset.metrics = [ + MagicMock( + metric_name="sum_sales", + verbose_name="Sum Sales", + expression="SUM(sales)", + description="Total sales", + d3format=None, + ), + MagicMock( + metric_name="count_orders", + verbose_name="Count Orders", + expression="COUNT(orders)", + description="Order count", + d3format=None, + ), + ] + mock_info.return_value = dataset + async with Client(mcp_server) as client: + result = await client.call_tool( + "get_dataset_info", {"request": {"identifier": 10}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["table_name"] == "Dataset With Columns" + assert data["database_name"] == "examples" + # Check 
that columns and metrics are included + assert len(data["columns"]) == 2 + assert len(data["metrics"]) == 2 + assert data["columns"][0]["column_name"] == "col1" + assert data["columns"][1]["column_name"] == "col2" + assert data["metrics"][0]["metric_name"] == "sum_sales" + assert data["metrics"][1]["metric_name"] == "count_orders" + + +@patch("superset.daos.dataset.DatasetDAO.list") +@pytest.mark.asyncio +async def test_list_datasets_includes_columns_and_metrics(mock_list, mcp_server): + dataset = MagicMock() + dataset.id = 11 + dataset.table_name = "DatasetList With Columns" + dataset.schema = "main" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/11" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.columns = [ + MagicMock( + column_name="colA", + verbose_name="Column A", + type="FLOAT", + is_dttm=False, + groupby=True, + filterable=True, + description="A column", + ), + ] + dataset.metrics = [ + MagicMock( + metric_name="avg_value", + verbose_name="Avg Value", + expression="AVG(value)", + description="Average value", + d3format=None, + ), + ] + mock_list.return_value = ([dataset], 1) + async with Client(mcp_server) as client: + request = ListDatasetsRequest(page=1, page_size=10) + result = await client.call_tool( + "list_datasets", {"request": request.model_dump()} + ) + datasets = result.data.datasets + assert len(datasets) == 1 + ds = datasets[0] + assert hasattr(ds, "columns") + assert 
hasattr(ds, "metrics") + assert isinstance(ds.columns, list) + assert isinstance(ds.metrics, list) + assert len(ds.columns) == 1 + assert len(ds.metrics) == 1 + assert ds.columns[0].column_name == "colA" + assert ds.metrics[0].metric_name == "avg_value" + + +@patch("superset.mcp_service.mcp_core.ModelGetInfoCore._find_object") +@pytest.mark.asyncio +async def test_get_dataset_info_by_uuid(mock_find_object, mcp_server): + """Test getting dataset info using UUID identifier.""" + dataset = MagicMock() + dataset.id = 1 + dataset.table_name = "Test Dataset UUID" + dataset.schema = "main" + dataset.description = "desc" + dataset.changed_by_name = "admin" + dataset.changed_on = None + dataset.changed_on_humanized = None + dataset.created_by_name = "admin" + dataset.created_on = None + dataset.created_on_humanized = None + dataset.tags = [] + dataset.owners = [] + dataset.is_virtual = False + dataset.database_id = 1 + dataset.schema_perm = "[examples].[main]" + dataset.url = "/tablemodelview/edit/1" + dataset.database = MagicMock() + dataset.database.database_name = "examples" + dataset.sql = None + dataset.main_dttm_col = None + dataset.offset = 0 + dataset.cache_timeout = 0 + dataset.params = {} + dataset.template_params = {} + dataset.extra = {} + dataset.columns = [] + dataset.metrics = [] + + mock_find_object.return_value = dataset + async with Client(mcp_server) as client: + uuid_str = "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + result = await client.call_tool( + "get_dataset_info", {"request": {"identifier": uuid_str}} + ) + assert result.content is not None + data = json.loads(result.content[0].text) + assert data["id"] == 1 + assert data["table_name"] == "Test Dataset UUID" + + +class TestDatasetSortableColumns: + """Test sortable columns configuration for dataset tools.""" + + def test_dataset_sortable_columns_definition(self): + """Test that dataset sortable columns are properly defined.""" + from superset.mcp_service.dataset.tool.list_datasets import ( + 
SORTABLE_DATASET_COLUMNS, + ) + + assert SORTABLE_DATASET_COLUMNS == [ + "id", + "table_name", + "schema", + "changed_on", + "created_on", + ] + # Ensure no computed properties are included + assert "changed_on_delta_humanized" not in SORTABLE_DATASET_COLUMNS + assert "changed_by_name" not in SORTABLE_DATASET_COLUMNS + assert "database_name" not in SORTABLE_DATASET_COLUMNS + assert "uuid" not in SORTABLE_DATASET_COLUMNS + + @patch("superset.daos.dataset.DatasetDAO") + @patch("superset.mcp_service.auth.get_user_from_request") + @pytest.mark.asyncio + async def test_list_datasets_with_valid_order_column( + self, mock_get_user, mock_dataset_dao + ): + """Test list_datasets with valid order column.""" + from superset.mcp_service.dataset.tool.list_datasets import list_datasets + + mock_get_user.return_value = MagicMock(id=1) + mock_tool = MagicMock() + mock_tool.run_tool.return_value = MagicMock(datasets=[], count=0) + mock_ctx = MagicMock() + mock_ctx.info = AsyncMock() + mock_ctx.debug = AsyncMock() + + with patch( + "superset.mcp_service.dataset.tool.list_datasets.ModelListCore", + return_value=mock_tool, + ): + # Test with valid sortable column + request = ListDatasetsRequest( + order_column="table_name", order_direction="asc" + ) + await list_datasets.fn(request, mock_ctx) + + # Verify the tool was called with the correct order column + mock_tool.run_tool.assert_called_once() + call_args = mock_tool.run_tool.call_args[1] + assert call_args["order_column"] == "table_name" + assert call_args["order_direction"] == "asc" + + def test_sortable_columns_in_docstring(self): + """Test that sortable columns are documented in tool docstring.""" + from superset.mcp_service.dataset.tool.list_datasets import ( + list_datasets, + SORTABLE_DATASET_COLUMNS, + ) + + # Check list_datasets docstring (stored in description after @mcp.tool) + assert hasattr(list_datasets, "description") + assert "Sortable columns for order_column:" in list_datasets.description + for col in 
SORTABLE_DATASET_COLUMNS: + assert col in list_datasets.description + + @pytest.mark.asyncio + async def test_default_ordering(self): + """Test default ordering behavior for datasets.""" + from superset.mcp_service.dataset.tool.list_datasets import list_datasets + + # Test that when no order_column is specified, None is passed + with patch( + "superset.mcp_service.dataset.tool.list_datasets.ModelListCore" + ) as mock_tool: + with patch("superset.mcp_service.auth.get_user_from_request"): + mock_tool.return_value.run_tool.return_value = MagicMock( + datasets=[], count=0 + ) + mock_ctx = MagicMock() + mock_ctx.info = AsyncMock() + mock_ctx.debug = AsyncMock() + request = ListDatasetsRequest() # No order specified + await list_datasets.fn(request, mock_ctx) + + call_args = mock_tool.return_value.run_tool.call_args[1] + assert call_args["order_column"] is None # None when not specified + assert ( + call_args["order_direction"] == "desc" + ) # From ListDatasetsRequest default diff --git a/tests/unit_tests/mcp_service/explore/__init__.py b/tests/unit_tests/mcp_service/explore/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/explore/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/explore/tool/__init__.py b/tests/unit_tests/mcp_service/explore/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/explore/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py new file mode 100644 index 00000000000..56588666db4 --- /dev/null +++ b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py @@ -0,0 +1,580 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Comprehensive unit tests for MCP generate_explore_link tool +""" + +import logging +from unittest.mock import Mock, patch + +import pytest +from fastmcp import Client + +from superset.mcp_service.app import mcp +from superset.mcp_service.chart.schemas import ( + AxisConfig, + ColumnRef, + FilterConfig, + GenerateExploreLinkRequest, + LegendConfig, + TableChartConfig, + XYChartConfig, +) + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +def _mock_dataset(id: int = 1) -> Mock: + """Create a mock dataset object.""" + dataset = Mock() + dataset.id = id + return dataset + + +class TestGenerateExploreLink: + """Comprehensive tests for generate_explore_link MCP tool.""" + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_table_explore_link_minimal( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for minimal table chart.""" + mock_create_form_data.return_value = "test_form_data_key_123" + mock_find_dataset.return_value = _mock_dataset(id=1) + + config = 
TableChartConfig( + chart_type="table", columns=[ColumnRef(name="region")] + ) + request = GenerateExploreLinkRequest(dataset_id="1", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=test_form_data_key_123" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_table_explore_link_with_features( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for table chart with features.""" + mock_create_form_data.return_value = "comprehensive_key_456" + mock_find_dataset.return_value = _mock_dataset(id=5) + + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="region", label="Sales Region"), + ColumnRef(name="revenue", label="Total Revenue", aggregate="SUM"), + ColumnRef(name="orders", label="Order Count", aggregate="COUNT"), + ], + filters=[ + FilterConfig(column="year", op="=", value="2024"), + FilterConfig(column="status", op="!=", value="cancelled"), + ], + sort_by=["revenue", "orders"], + ) + request = GenerateExploreLinkRequest(dataset_id="5", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=comprehensive_key_456" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def 
test_generate_line_chart_explore_link( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for line chart.""" + mock_create_form_data.return_value = "line_chart_key_789" + mock_find_dataset.return_value = _mock_dataset(id=3) + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date", label="Date"), + y=[ + ColumnRef(name="sales", label="Daily Sales", aggregate="SUM"), + ColumnRef(name="orders", label="Order Count", aggregate="COUNT"), + ], + kind="line", + group_by=ColumnRef(name="region", label="Sales Region"), + x_axis=AxisConfig(title="Time Period", format="smart_date"), + y_axis=AxisConfig(title="Sales Metrics", format="$,.2f"), + legend=LegendConfig(show=True, position="bottom"), + ) + request = GenerateExploreLinkRequest(dataset_id="3", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=line_chart_key_789" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_bar_chart_explore_link( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for bar chart.""" + mock_create_form_data.return_value = "bar_chart_key_abc" + mock_find_dataset.return_value = _mock_dataset(id=7) + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="product_category", label="Category"), + y=[ColumnRef(name="revenue", label="Revenue", aggregate="SUM")], + kind="bar", + group_by=ColumnRef(name="quarter", label="Quarter"), + y_axis=AxisConfig(title="Revenue ($)", format="$,.0f"), + ) + request = GenerateExploreLinkRequest(dataset_id="7", config=config) + 
+ async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=bar_chart_key_abc" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_area_chart_explore_link( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for area chart.""" + mock_create_form_data.return_value = "area_chart_key_def" + mock_find_dataset.return_value = _mock_dataset(id=2) + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="month", label="Month"), + y=[ + ColumnRef( + name="cumulative_sales", label="Cumulative Sales", aggregate="SUM" + ) + ], + kind="area", + legend=LegendConfig(show=False), + ) + request = GenerateExploreLinkRequest(dataset_id="2", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=area_chart_key_def" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_scatter_chart_explore_link( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link for scatter chart.""" + mock_create_form_data.return_value = "scatter_chart_key_ghi" + mock_find_dataset.return_value = _mock_dataset(id=4) + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="price", label="Unit 
Price"), + y=[ColumnRef(name="quantity", label="Quantity Sold", aggregate="SUM")], + kind="scatter", + group_by=ColumnRef(name="product_type", label="Product Type"), + x_axis=AxisConfig(title="Price ($)", format="$,.2f"), + y_axis=AxisConfig(title="Quantity", format=",.0f"), + ) + request = GenerateExploreLinkRequest(dataset_id="4", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=scatter_chart_key_ghi" + ) + mock_create_form_data.assert_called_once() + + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_cache_failure_fallback( + self, mock_create_form_data, mcp_server + ): + """Test fallback when form_data cache creation fails.""" + mock_create_form_data.side_effect = Exception("Cache storage failed") + + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="test_col")] + ) + request = GenerateExploreLinkRequest(dataset_id="1", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + # Should fallback to basic URL format + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?datasource_type=table&datasource_id=1" + ) + + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_database_lock_fallback( + self, mock_create_form_data, mcp_server + ): + """Test fallback when database is locked.""" + from sqlalchemy.exc import OperationalError + + mock_create_form_data.side_effect = OperationalError( + "database is locked", None, None + ) + + config = XYChartConfig( + 
chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="sales")], + kind="line", + ) + request = GenerateExploreLinkRequest(dataset_id="5", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + # Should fallback to basic dataset URL + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?datasource_type=table&datasource_id=5" + ) + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_with_many_columns( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link with many columns.""" + mock_create_form_data.return_value = "many_columns_key" + mock_find_dataset.return_value = _mock_dataset(id=1) + + # Create 15 columns + columns = [ + ColumnRef( + name=f"metric_{i}", + label=f"Metric {i}", + aggregate="SUM" if i % 2 == 0 else "COUNT", + ) + for i in range(15) + ] + + config = TableChartConfig(chart_type="table", columns=columns) + request = GenerateExploreLinkRequest(dataset_id="1", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=many_columns_key" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_with_many_filters( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test generating explore link with many filters.""" + 
mock_create_form_data.return_value = "many_filters_key" + mock_find_dataset.return_value = _mock_dataset(id=1) + + # Create 12 filters + filters = [ + FilterConfig( + column=f"filter_col_{i}", + op="=" if i % 3 == 0 else "!=", + value=f"value_{i}", + ) + for i in range(12) + ] + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x_col"), + y=[ColumnRef(name="y_col")], + kind="bar", + filters=filters, + ) + request = GenerateExploreLinkRequest(dataset_id="1", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=many_filters_key" + ) + mock_create_form_data.assert_called_once() + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_explore_link_url_format_consistency( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test that all generated URLs follow consistent format.""" + mock_create_form_data.return_value = "consistency_test_key" + mock_find_dataset.return_value = _mock_dataset(id=1) + + configs = [ + TableChartConfig(chart_type="table", columns=[ColumnRef(name="col1")]), + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + kind="line", + ), + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + kind="bar", + ), + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + kind="area", + ), + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="x"), + y=[ColumnRef(name="y")], + kind="scatter", + ), + ] + + for i, config in enumerate(configs): + request = GenerateExploreLinkRequest(dataset_id=str(i + 1), config=config) + + async with Client(mcp_server) as client: + 
result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + # All URLs should follow the same format + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=consistency_test_key" + ) + assert result.data["error"] is None + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_dataset_id_types( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test explore link generation with different dataset_id formats.""" + mock_create_form_data.return_value = "dataset_test_key" + mock_find_dataset.return_value = _mock_dataset(id=1) + + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="test_col")] + ) + + # Test various dataset_id formats + dataset_ids = ["1", "42", "999", "123456789"] + + for dataset_id in dataset_ids: + request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config) + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=dataset_test_key" + ) + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_generate_explore_link_complex_configuration( + self, mock_create_form_data, mock_find_dataset, mcp_server + ): + """Test explore link generation with complex chart configuration.""" + mock_create_form_data.return_value = "complex_config_key" + mock_find_dataset.return_value = _mock_dataset(id=10) + + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="timestamp", label="Time"), + y=[ + ColumnRef(name="sales", label="Sales", aggregate="SUM"), 
+ ColumnRef(name="orders", label="Orders", aggregate="COUNT"), + ColumnRef(name="profit", label="Profit", aggregate="AVG"), + ], + kind="line", + group_by=ColumnRef(name="region", label="Region"), + x_axis=AxisConfig(title="Time Period", format="smart_date"), + y_axis=AxisConfig(title="Metrics", format="$,.2f", scale="linear"), + legend=LegendConfig(show=True, position="bottom"), + filters=[ + FilterConfig(column="status", op="=", value="active"), + FilterConfig(column="date", op=">=", value="2024-01-01"), + FilterConfig(column="revenue", op=">", value="1000"), + ], + ) + request = GenerateExploreLinkRequest(dataset_id="10", config=config) + + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + assert result.data["error"] is None + assert ( + result.data["url"] + == "http://localhost:9001/explore/?form_data_key=complex_config_key" + ) + mock_create_form_data.assert_called_once() + + @patch( + "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run" + ) + @pytest.mark.asyncio + async def test_fallback_url_different_datasets( + self, mock_create_form_data, mcp_server + ): + """Test fallback URLs are correct for different dataset IDs.""" + mock_create_form_data.side_effect = Exception( + "Always fail for fallback testing" + ) + + config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="col")]) + + dataset_ids = ["1", "5", "100", "999"] + + for dataset_id in dataset_ids: + request = GenerateExploreLinkRequest(dataset_id=dataset_id, config=config) + async with Client(mcp_server) as client: + result = await client.call_tool( + "generate_explore_link", {"request": request.model_dump()} + ) + + # Should fallback to basic URL with correct dataset_id + expected_url = f"http://localhost:9001/explore/?datasource_type=table&datasource_id={dataset_id}" + assert result.data["error"] is None + assert result.data["url"] == expected_url diff --git 
a/tests/unit_tests/mcp_service/sql_lab/__init__.py b/tests/unit_tests/mcp_service/sql_lab/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/sql_lab/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/sql_lab/test_execute_sql_helper.py b/tests/unit_tests/mcp_service/sql_lab/test_execute_sql_helper.py new file mode 100644 index 00000000000..d366c234a73 --- /dev/null +++ b/tests/unit_tests/mcp_service/sql_lab/test_execute_sql_helper.py @@ -0,0 +1,64 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Helper function to extract row data from MCP responses. + +The MCP client seems to wrap dict rows in Root objects. +This helper handles the extraction properly. +""" + + +def extract_row_data(row): + """Extract dictionary data from a row object.""" + # Handle different possible formats + if isinstance(row, dict): + return row + + # Check for Pydantic Root object + if hasattr(row, "__root__"): + return row.__root__ + + # Check if it's a Pydantic model with model_dump + if hasattr(row, "model_dump"): + return row.model_dump() + + # Try to access __dict__ directly + if hasattr(row, "__dict__"): + # Filter out private attributes + return {k: v for k, v in row.__dict__.items() if not k.startswith("_")} + + # Last resort - convert to string and parse + # This is for the Root object issue + row_str = str(row) + if row_str == "Root()": + # Empty Root object - the actual data might be elsewhere + # Let's check all attributes + attrs = dir(row) + for attr in attrs: + if not attr.startswith("_") and attr not in [ + "model_dump", + "model_validate", + ]: + try: + val = getattr(row, attr) + if isinstance(val, dict): + return val + except AttributeError: + pass + + raise ValueError(f"Cannot extract data from row of type {type(row)}: {row}") diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/__init__.py b/tests/unit_tests/mcp_service/sql_lab/tool/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/mcp_service/sql_lab/tool/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one 
+# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py new file mode 100644 index 00000000000..b6b9d70009b --- /dev/null +++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py @@ -0,0 +1,497 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" +Unit tests for execute_sql MCP tool +""" + +import logging +from unittest.mock import MagicMock, Mock, patch + +import pytest +from fastmcp import Client +from fastmcp.exceptions import ToolError + +from superset.mcp_service.app import mcp + +logging.basicConfig(level=logging.DEBUG) +logger = logging.getLogger(__name__) + + +@pytest.fixture +def mcp_server(): + return mcp + + +@pytest.fixture(autouse=True) +def mock_auth(): + """Mock authentication for all tests.""" + with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user: + mock_user = Mock() + mock_user.id = 1 + mock_user.username = "admin" + mock_get_user.return_value = mock_user + yield mock_get_user + + +def _mock_database( + id: int = 1, + database_name: str = "test_db", + allow_dml: bool = False, +) -> Mock: + """Create a mock database object.""" + database = Mock() + database.id = id + database.database_name = database_name + database.allow_dml = allow_dml + + # Mock raw connection context manager + mock_cursor = Mock() + mock_cursor.description = [ + ("id", "INTEGER", None, None, None, None, False), + ("name", "VARCHAR", None, None, None, None, True), + ] + mock_cursor.fetchmany.return_value = [(1, "test_name")] + mock_cursor.rowcount = 1 + + mock_conn = Mock() + mock_conn.cursor.return_value = mock_cursor + mock_conn.commit = Mock() + + mock_context = MagicMock() + mock_context.__enter__.return_value = mock_conn + mock_context.__exit__.return_value = None + + database.get_raw_connection.return_value = mock_context + + return database + + +class TestExecuteSql: + """Tests for execute_sql MCP tool.""" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_basic_select( + self, mock_db, mock_security_manager, mcp_server + ): + """Test basic SELECT query execution.""" + # Setup mocks + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + 
) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT id, name FROM users LIMIT 10", + "limit": 10, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + assert result.data.error is None + assert result.data.row_count == 1 + assert len(result.data.rows) == 1 + assert result.data.rows[0]["id"] == 1 + assert result.data.rows[0]["name"] == "test_name" + assert len(result.data.columns) == 2 + assert result.data.columns[0].name == "id" + assert result.data.columns[0].type == "INTEGER" + assert result.data.execution_time > 0 + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_with_parameters( + self, mock_db, mock_security_manager, mcp_server + ): + """Test SQL execution with parameter substitution.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT {limit}", + "parameters": {"table": "orders", "status": "active", "limit": "5"}, + "limit": 10, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + assert result.data.error is None + # Verify parameter substitution happened + mock_database.get_raw_connection.assert_called_once() + cursor = ( # fmt: skip + mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value + ) + # Check that the SQL was formatted with parameters + executed_sql = cursor.execute.call_args[0][0] + assert "orders" in executed_sql + assert "active" in executed_sql + + @patch("superset.security_manager") + @patch("superset.db") + 
@pytest.mark.asyncio + async def test_execute_sql_database_not_found( + self, mock_db, mock_security_manager, mcp_server + ): + """Test error when database is not found.""" + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + None + ) + + request = { + "database_id": 999, + "sql": "SELECT 1", + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert "Database with ID 999 not found" in result.data.error + assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR" + assert result.data.rows is None + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_access_denied( + self, mock_db, mock_security_manager, mcp_server + ): + """Test error when user lacks database access.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + # Use Mock instead of AsyncMock for synchronous call + from unittest.mock import Mock + + mock_security_manager.can_access_database = Mock(return_value=False) + + request = { + "database_id": 1, + "sql": "SELECT 1", + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert "Access denied to database" in result.data.error + assert result.data.error_type == "SECURITY_ERROR" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_dml_not_allowed( + self, mock_db, mock_security_manager, mcp_server + ): + """Test error when DML operations are not allowed.""" + mock_database = _mock_database(allow_dml=False) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + 
mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "UPDATE users SET name = 'test' WHERE id = 1", + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert result.data.error_type == "DML_NOT_ALLOWED" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_dml_allowed( + self, mock_db, mock_security_manager, mcp_server + ): + """Test successful DML execution when allowed.""" + mock_database = _mock_database(allow_dml=True) + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + # Mock cursor for DML operation + cursor = ( # fmt: skip + mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value + ) + cursor.rowcount = 3 # 3 rows affected + + request = { + "database_id": 1, + "sql": "UPDATE users SET active = true WHERE last_login > '2024-01-01'", + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + assert result.data.error is None + assert result.data.affected_rows == 3 + assert result.data.rows == [] # Empty rows for DML + assert result.data.row_count == 0 + # Verify commit was called + ( + mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once() + ) + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_empty_results( + self, mock_db, mock_security_manager, mcp_server + ): + """Test query that returns no results.""" + mock_database = _mock_database() + 
mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + # Mock empty results + cursor = ( # fmt: skip + mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value + ) + cursor.fetchmany.return_value = [] + + request = { + "database_id": 1, + "sql": "SELECT * FROM users WHERE id = 999999", + "limit": 10, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + assert result.data.error is None + assert result.data.row_count == 0 + assert len(result.data.rows) == 0 + assert len(result.data.columns) == 2 # Column metadata still returned + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_missing_parameter( + self, mock_db, mock_security_manager, mcp_server + ): + """Test error when required parameter is missing.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT * FROM {table_name} WHERE id = {user_id}", + "parameters": {"table_name": "users"}, # Missing user_id + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert "user_id" in result.data.error # Error contains parameter name + assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_empty_parameters_with_placeholders( + self, mock_db, mock_security_manager, mcp_server + ): + """Test error when empty parameters dict is 
provided but SQL has + placeholders.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT * FROM {table_name} LIMIT 5", + "parameters": {}, # Empty dict but SQL has {table_name} + "limit": 5, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert "Missing parameter: table_name" in result.data.error + assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR" + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_with_schema( + self, mock_db, mock_security_manager, mcp_server + ): + """Test SQL execution with schema specification.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT COUNT(*) as total FROM orders", + "schema": "sales", + "limit": 1, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + assert result.data.error is None + # Verify schema was passed to get_raw_connection + # Verify schema was passed + call_args = mock_database.get_raw_connection.call_args + assert call_args[1]["schema"] == "sales" + assert call_args[1]["catalog"] is None + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_limit_enforcement( + self, mock_db, mock_security_manager, mcp_server + ): + """Test that LIMIT is added to SELECT queries without one.""" + mock_database = _mock_database() + 
mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + request = { + "database_id": 1, + "sql": "SELECT * FROM users", # No LIMIT + "limit": 50, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is True + # Verify LIMIT was added + cursor = ( # fmt: skip + mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value + ) + executed_sql = cursor.execute.call_args[0][0] + assert "LIMIT 50" in executed_sql + + @patch("superset.security_manager") + @patch("superset.db") + @pytest.mark.asyncio + async def test_execute_sql_sql_injection_prevention( + self, mock_db, mock_security_manager, mcp_server + ): + """Test that SQL injection attempts are handled safely.""" + mock_database = _mock_database() + mock_db.session.query.return_value.filter_by.return_value.first.return_value = ( + mock_database + ) + mock_security_manager.can_access_database.return_value = True + + # Mock execute to raise an exception + cursor = ( # fmt: skip + mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value + ) + cursor.execute.side_effect = Exception("Syntax error") + + request = { + "database_id": 1, + "sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--", + "limit": 10, + } + + async with Client(mcp_server) as client: + result = await client.call_tool("execute_sql", {"request": request}) + + assert result.data.success is False + assert result.data.error is not None + assert "Syntax error" in result.data.error # Contains actual error + assert result.data.error_type == "EXECUTION_ERROR" + + @pytest.mark.asyncio + async def test_execute_sql_empty_query_validation(self, mcp_server): + """Test validation of empty SQL query.""" + request = { + "database_id": 1, + "sql": " ", # Empty/whitespace only + "limit": 10, + } + 
+ async with Client(mcp_server) as client: + with pytest.raises(ToolError, match="SQL query cannot be empty"): + await client.call_tool("execute_sql", {"request": request}) + + @pytest.mark.asyncio + async def test_execute_sql_invalid_limit(self, mcp_server): + """Test validation of invalid limit values.""" + # Test limit too low + request = { + "database_id": 1, + "sql": "SELECT 1", + "limit": 0, + } + + async with Client(mcp_server) as client: + with pytest.raises(ToolError, match="minimum of 1"): + await client.call_tool("execute_sql", {"request": request}) + + # Test limit too high + request = { + "database_id": 1, + "sql": "SELECT 1", + "limit": 20000, + } + + async with Client(mcp_server) as client: + with pytest.raises(ToolError, match="maximum of 10000"): + await client.call_tool("execute_sql", {"request": request})