Files
superset2/superset/mcp_service/common/schema_discovery.py

484 lines
16 KiB
Python

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Schema discovery models for MCP tools.
These schemas provide comprehensive metadata about available columns,
filters, and sorting options for each model type (chart, dataset, dashboard).
Column metadata is extracted dynamically from SQLAlchemy models.
"""
from typing import Any, Literal, Type
import sqlalchemy as sa
from pydantic import BaseModel, Field
from sqlalchemy.inspection import inspect
class ColumnMetadata(BaseModel):
    """Metadata for a selectable column."""

    # These fields end up in the get_schema tool response
    # (GetSchemaResponse -> ModelSchemaInfo.select_columns), so names, order
    # and descriptions are part of what clients see.

    # Name a client passes in select_columns.
    name: str = Field(..., description="Column name to use in select_columns")
    # Human-readable description; None when no source provided one.
    description: str | None = Field(None, description="Column description")
    # Friendly type label such as "str", "int" or "datetime" (a string, not a
    # Python type object).
    type: str | None = Field(None, description="Data type (str, int, datetime, etc.)")
    # True when the column is part of the default selection.
    is_default: bool = Field(
        False, description="Whether this column is included by default"
    )
class ModelSchemaInfo(BaseModel):
    """
    Comprehensive schema information for a model type.
    Provides all metadata needed for LLM clients to construct valid queries:
    - Which columns can be selected
    - Which columns can be used for filtering (with operators)
    - Which columns can be used for sorting
    - Default values for each
    """

    # Which of the three supported model types this schema describes.
    model_type: Literal["chart", "dataset", "dashboard"] = Field(
        ..., description="The model type this schema describes"
    )
    # Full list of selectable columns (model columns plus extra/computed ones).
    select_columns: list[ColumnMetadata] = Field(
        ..., description="All columns available for selection via select_columns"
    )
    # Maps a filterable column name to the operator names it accepts.
    filter_columns: dict[str, list[str]] = Field(
        ..., description="Filterable columns mapped to their supported operators"
    )
    # Column names accepted by order_column.
    sortable_columns: list[str] = Field(
        ..., description="Columns that can be used with order_column"
    )
    # Column names returned when the client does not pass select_columns.
    default_select: list[str] = Field(
        ..., description="Columns returned when select_columns is not specified"
    )
    # Column used for ordering when order_column is omitted.
    default_sort: str = Field(
        ..., description="Default column used for sorting when order_column is not set"
    )
    # Direction paired with default_sort; defaults to newest/largest first.
    default_sort_direction: Literal["asc", "desc"] = Field(
        "desc", description="Default sort direction"
    )
    # Columns matched by the free-text search parameter; empty list if none.
    search_columns: list[str] = Field(
        default_factory=list,
        description="Columns searched when using the search parameter",
    )
class GetSchemaRequest(BaseModel):
    """Request schema for unified get_schema tool."""

    # Required; one of the three model types this module describes.
    model_type: Literal["chart", "dataset", "dashboard"] = Field(
        ..., description="Model type to get schema for"
    )
class GetSchemaResponse(BaseModel):
    """Response for unified get_schema tool."""

    # Wraps a single ModelSchemaInfo for the requested model type.
    schema_info: ModelSchemaInfo = Field(
        ..., description="Comprehensive schema information"
    )
# Ordered (SQLAlchemy base type, friendly name) pairs; the first isinstance()
# match wins. Subclasses are covered by their base class: Text/Unicode by
# String, SmallInteger/BigInteger by Integer, Float by Numeric.
_SQLALCHEMY_FRIENDLY_TYPES: tuple[tuple[type, str], ...] = (
    (sa.String, "str"),
    (sa.Boolean, "bool"),
    (sa.Integer, "int"),
    (sa.Numeric, "float"),
    (sa.DateTime, "datetime"),
    (sa.Date, "date"),
    (sa.Time, "time"),
    (sa.JSON, "dict"),
    (sa.ARRAY, "list"),
)


def _get_sqlalchemy_type_name(col_type: Any) -> str:
    """
    Convert a SQLAlchemy column type to a friendly type name.

    ``TypeDecorator`` wrappers (including SQLAlchemy 1.4 ``Variant`` objects
    produced by ``with_variant()``) are unwrapped to their underlying ``impl``
    type before matching, so e.g. a decorated Integer reports "int" instead of
    falling through to the "str" fallback.

    Args:
        col_type: A SQLAlchemy type instance, e.g. ``Column.type``.

    Returns:
        One of "str", "bool", "int", "float", "datetime", "date", "time",
        "dict" or "list". Unrecognized types fall back to "str".
    """
    # A TypeDecorator is not a subclass of the type it wraps, so isinstance()
    # checks against the base types would miss it. Walk down to the concrete
    # impl. Interval is itself a TypeDecorator over DateTime but holds
    # timedeltas, so it is deliberately not unwrapped (keeps the "str" fallback).
    while isinstance(col_type, sa.TypeDecorator) and not isinstance(
        col_type, sa.Interval
    ):
        impl = getattr(col_type, "impl", None)
        if impl is None or impl is col_type:
            break
        col_type = impl

    for base_type, friendly_name in _SQLALCHEMY_FRIENDLY_TYPES:
        if isinstance(col_type, base_type):
            return friendly_name
    return "str"  # Default fallback
# Fallback descriptions for common model columns, keyed by column name.
# get_columns_from_model() consults this table only when the SQLAlchemy column
# has neither a truthy ``doc`` nor a truthy ``comment`` attribute; a name that
# is missing here yields a description of None.
_COLUMN_DESCRIPTIONS: dict[str, str] = {
    # Common across models
    "id": "Unique integer identifier",
    "uuid": "Unique UUID identifier",
    "created_on": "Timestamp when the resource was created",
    "changed_on": "Timestamp when the resource was last modified",
    "created_by_fk": "User ID of the creator",
    "changed_by_fk": "User ID of the last modifier",
    "description": "User-provided description text",
    "cache_timeout": "Cache timeout override in seconds",
    "perm": "Permission string for access control",
    "schema_perm": "Schema-level permission string",
    "catalog_perm": "Catalog-level permission string",
    "is_managed_externally": "Whether managed by an external system",
    "external_url": "URL of the external management system",
    "certified_by": "Name of the person who certified this resource",
    "certification_details": "Details about the certification",
    # Chart-specific
    "slice_name": "Chart display name",
    "datasource_id": "ID of the underlying dataset",
    "datasource_type": "Type of data source (e.g., table)",
    "viz_type": "Visualization type (e.g., echarts_timeseries_line, table)",
    "params": "JSON string of chart parameters/configuration",
    "query_context": "JSON string of the query context for data fetching",
    "last_saved_at": "Timestamp of the last explicit save",
    "last_saved_by_fk": "User ID who last saved this chart",
    # Dataset-specific
    "table_name": "Name of the database table or view",
    "schema": "Database schema name",
    "catalog": "Database catalog name",
    "database_id": "ID of the database connection",
    "sql": "Custom SQL expression (for virtual datasets)",
    "main_dttm_col": "Primary datetime column for time-series queries",
    "is_sqllab_view": "Whether this dataset was created from SQL Lab",
    "template_params": "Jinja template parameters as JSON",
    "extra": "Extra configuration as JSON",
    "filter_select_enabled": "Whether filter select is enabled",
    "normalize_columns": "Whether to normalize column names",
    "always_filter_main_dttm": "Whether to always filter on the main datetime column",
    "fetch_values_predicate": "SQL predicate for fetching filter values",
    "default_endpoint": "Default endpoint URL",
    "offset": "Row offset for queries",
    "is_featured": "Whether this dataset is featured",
    "currency_code_column": "Column containing currency codes",
    # Dashboard-specific
    "dashboard_title": "Dashboard display title",
    "slug": "URL-friendly identifier for the dashboard",
    "published": "Whether the dashboard is published and visible",
    "position_json": "JSON layout of dashboard components",
    "json_metadata": "JSON metadata including filters and settings",
    "css": "Custom CSS for the dashboard",
    "theme_id": "Theme ID for dashboard styling",
}
def get_columns_from_model(
    model_cls: Type[Any],
    default_columns: list[str],
    extra_columns: dict[str, ColumnMetadata] | None = None,
) -> list[ColumnMetadata]:
    """
    Dynamically extract column metadata from a SQLAlchemy model.

    Model columns come first, in mapper order, followed by any extra columns
    whose names are not already present. A model column always wins over an
    extra column of the same name.

    Args:
        model_cls: The SQLAlchemy model class to inspect
        default_columns: Column names that should be marked as defaults. This
            applies to model columns and extra columns alike.
        extra_columns: Additional columns not on the model (e.g., computed
            fields), keyed by column name

    Returns:
        List of ColumnMetadata objects for all columns
    """
    default_names = set(default_columns)
    mapper = inspect(model_cls)
    columns: list[ColumnMetadata] = []

    # Mapper.columns is keyed by the mapped *attribute* name, which is the name
    # callers select on. Column.key can differ when a column is declared as
    # ``attr = Column("db_name", ...)``, so use the collection key, not col.key.
    for attr_name, col in mapper.columns.items():
        # Prefer the column's own doc/comment, then the shared fallback table.
        description = (
            getattr(col, "doc", None)
            or getattr(col, "comment", None)
            or _COLUMN_DESCRIPTIONS.get(attr_name)
        )
        columns.append(
            ColumnMetadata(
                name=attr_name,
                description=description,
                type=_get_sqlalchemy_type_name(col.type),
                is_default=attr_name in default_names,
            )
        )

    # Add extra columns (computed fields, relationships, etc.)
    if extra_columns:
        seen = {column.name for column in columns}
        for name, metadata in extra_columns.items():
            if name in seen:
                continue
            if name in default_names and not metadata.is_default:
                # Honour default_columns for extras too, without mutating the
                # caller's (typically module-level, shared) metadata object.
                metadata = ColumnMetadata(
                    name=metadata.name,
                    description=metadata.description,
                    type=metadata.type,
                    is_default=True,
                )
            columns.append(metadata)
            seen.add(metadata.name)
    return columns
# =============================================================================
# Model Configuration
# =============================================================================
# Only business-logic decisions that can't be derived from the model:
# - Default columns (which columns to show by default for reduced token usage)
# - Sortable columns (which columns support ORDER BY)
# - Search columns (which columns to search in)
# - Extra columns (computed/relationship fields not on the model)
#
# Extra columns are appended by get_columns_from_model() after the model's own
# columns and are skipped when a model column with the same name exists, so an
# extra entry can add a field but never override a model column.

# Chart configuration
# Passed as default_columns when building chart column metadata.
CHART_DEFAULT_COLUMNS: list[str] = ["id", "slice_name", "viz_type", "uuid"]
# Chart columns that support ORDER BY.
CHART_SORTABLE_COLUMNS: list[str] = [
    "id",
    "slice_name",
    "viz_type",
    "description",
    "changed_on",
    "created_on",
]
# Chart columns to search in.
CHART_SEARCH_COLUMNS: list[str] = ["slice_name", "description"]
# Computed/relationship fields offered for charts in addition to model columns.
# NOTE(review): "datasource_type" is also listed as a chart column in
# _COLUMN_DESCRIPTIONS; if the Slice model maps it, this entry is shadowed by
# the model column and never used — confirm and drop if so.
CHART_EXTRA_COLUMNS: dict[str, ColumnMetadata] = {
    "datasource_name": ColumnMetadata(
        name="datasource_name",
        description="Data source name",
        type="str",
        is_default=False,
    ),
    "datasource_type": ColumnMetadata(
        name="datasource_type",
        description="Data source type",
        type="str",
        is_default=False,
    ),
    "url": ColumnMetadata(
        name="url", description="Chart URL", type="str", is_default=False
    ),
    "form_data": ColumnMetadata(
        name="form_data",
        description="Chart form data",
        type="dict",
        is_default=False,
    ),
    "changed_by": ColumnMetadata(
        name="changed_by",
        description="Last modifier username",
        type="str",
        is_default=False,
    ),
    "changed_by_name": ColumnMetadata(
        name="changed_by_name",
        description="Last modifier display name",
        type="str",
        is_default=False,
    ),
    "changed_on_humanized": ColumnMetadata(
        name="changed_on_humanized",
        description="Humanized modification time",
        type="str",
        is_default=False,
    ),
    "created_by": ColumnMetadata(
        name="created_by",
        description="Creator username",
        type="str",
        is_default=False,
    ),
    "created_by_name": ColumnMetadata(
        name="created_by_name",
        description="Creator display name",
        type="str",
        is_default=False,
    ),
    "created_on_humanized": ColumnMetadata(
        name="created_on_humanized",
        description="Humanized creation time",
        type="str",
        is_default=False,
    ),
    "tags": ColumnMetadata(
        name="tags", description="Chart tags", type="list", is_default=False
    ),
    "owners": ColumnMetadata(
        name="owners", description="Chart owners", type="list", is_default=False
    ),
}
# Dataset configuration
# Passed as default_columns when building dataset column metadata.
DATASET_DEFAULT_COLUMNS: list[str] = ["id", "table_name", "schema", "uuid"]
# Dataset columns that support ORDER BY.
DATASET_SORTABLE_COLUMNS: list[str] = [
    "id",
    "table_name",
    "schema",
    "changed_on",
    "created_on",
]
# Dataset columns to search in.
DATASET_SEARCH_COLUMNS: list[str] = ["table_name", "description"]
# Computed/relationship fields offered for datasets in addition to model
# columns; skipped if the model already exposes a column with the same name.
DATASET_EXTRA_COLUMNS: dict[str, ColumnMetadata] = {
    "database_name": ColumnMetadata(
        name="database_name",
        description="Database connection name",
        type="str",
        is_default=False,
    ),
    "changed_by": ColumnMetadata(
        name="changed_by",
        description="Last modifier username",
        type="str",
        is_default=False,
    ),
    "changed_by_name": ColumnMetadata(
        name="changed_by_name",
        description="Last modifier display name",
        type="str",
        is_default=False,
    ),
    "changed_on_humanized": ColumnMetadata(
        name="changed_on_humanized",
        description="Humanized modification time",
        type="str",
        is_default=False,
    ),
    "created_by": ColumnMetadata(
        name="created_by",
        description="Creator username",
        type="str",
        is_default=False,
    ),
    "created_by_name": ColumnMetadata(
        name="created_by_name",
        description="Creator display name",
        type="str",
        is_default=False,
    ),
    "created_on_humanized": ColumnMetadata(
        name="created_on_humanized",
        description="Humanized creation time",
        type="str",
        is_default=False,
    ),
    "metrics": ColumnMetadata(
        name="metrics",
        description="Dataset metrics definitions",
        type="list",
        is_default=False,
    ),
    "columns": ColumnMetadata(
        name="columns",
        description="Dataset column definitions",
        type="list",
        is_default=False,
    ),
    "tags": ColumnMetadata(
        name="tags", description="Dataset tags", type="list", is_default=False
    ),
    "owners": ColumnMetadata(
        name="owners", description="Dataset owners", type="list", is_default=False
    ),
}
# Dashboard configuration
# Passed as default_columns when building dashboard column metadata.
DASHBOARD_DEFAULT_COLUMNS: list[str] = ["id", "dashboard_title", "slug", "uuid"]
# Dashboard columns that support ORDER BY.
DASHBOARD_SORTABLE_COLUMNS: list[str] = [
    "id",
    "dashboard_title",
    "slug",
    "published",
    "changed_on",
    "created_on",
]
# Dashboard columns to search in.
DASHBOARD_SEARCH_COLUMNS: list[str] = ["dashboard_title", "slug"]
# Computed/relationship fields offered for dashboards in addition to model
# columns; skipped if the model already exposes a column with the same name.
DASHBOARD_EXTRA_COLUMNS: dict[str, ColumnMetadata] = {
    "url": ColumnMetadata(
        name="url", description="Dashboard URL", type="str", is_default=False
    ),
    "changed_by": ColumnMetadata(
        name="changed_by",
        description="Last modifier username",
        type="str",
        is_default=False,
    ),
    "changed_by_name": ColumnMetadata(
        name="changed_by_name",
        description="Last modifier display name",
        type="str",
        is_default=False,
    ),
    "changed_on_humanized": ColumnMetadata(
        name="changed_on_humanized",
        description="Humanized modification time",
        type="str",
        is_default=False,
    ),
    "created_by": ColumnMetadata(
        name="created_by",
        description="Creator username",
        type="str",
        is_default=False,
    ),
    "created_by_name": ColumnMetadata(
        name="created_by_name",
        description="Creator display name",
        type="str",
        is_default=False,
    ),
    "created_on_humanized": ColumnMetadata(
        name="created_on_humanized",
        description="Humanized creation time",
        type="str",
        is_default=False,
    ),
    "tags": ColumnMetadata(
        name="tags", description="Dashboard tags", type="list", is_default=False
    ),
    "owners": ColumnMetadata(
        name="owners", description="Dashboard owners", type="list", is_default=False
    ),
    "charts": ColumnMetadata(
        name="charts", description="Charts in dashboard", type="list", is_default=False
    ),
}
def get_chart_columns() -> list[ColumnMetadata]:
    """
    Build the select-column metadata for charts.

    Inspects the ``Slice`` model on every call and layers on the chart
    defaults and extra (computed/relationship) columns.
    """
    # Imported inside the function rather than at module load.
    # NOTE(review): presumably to avoid import cycles / app-context needs when
    # this module is imported — confirm.
    from superset.models.slice import Slice

    return get_columns_from_model(
        Slice,
        default_columns=CHART_DEFAULT_COLUMNS,
        extra_columns=CHART_EXTRA_COLUMNS,
    )
def get_dataset_columns() -> list[ColumnMetadata]:
    """
    Build the select-column metadata for datasets.

    Inspects the ``SqlaTable`` model on every call and layers on the dataset
    defaults and extra (computed/relationship) columns.
    """
    # Imported inside the function rather than at module load.
    # NOTE(review): presumably to avoid import cycles / app-context needs when
    # this module is imported — confirm.
    from superset.connectors.sqla.models import SqlaTable

    return get_columns_from_model(
        SqlaTable,
        default_columns=DATASET_DEFAULT_COLUMNS,
        extra_columns=DATASET_EXTRA_COLUMNS,
    )
def get_dashboard_columns() -> list[ColumnMetadata]:
    """
    Build the select-column metadata for dashboards.

    Inspects the ``Dashboard`` model on every call and layers on the dashboard
    defaults and extra (computed/relationship) columns.
    """
    # Imported inside the function rather than at module load.
    # NOTE(review): presumably to avoid import cycles / app-context needs when
    # this module is imported — confirm.
    from superset.models.dashboard import Dashboard

    return get_columns_from_model(
        Dashboard,
        default_columns=DASHBOARD_DEFAULT_COLUMNS,
        extra_columns=DASHBOARD_EXTRA_COLUMNS,
    )
def get_all_column_names(columns: list[ColumnMetadata]) -> list[str]:
    """
    Return the ``name`` of every entry in *columns*, preserving order.

    An empty input yields an empty list.
    """
    names = [metadata.name for metadata in columns]
    return names
# Backwards-compatibility names for existing code that imports these constants.
# NOTE(review): nothing in this module ever populates these lists, so unless
# another module fills them in they stay empty. Code that needs the actual
# names should call get_all_column_names(get_chart_columns()) (or the dataset /
# dashboard equivalents), which inspect the models at call time.
CHART_ALL_COLUMNS: list[str] = []
DATASET_ALL_COLUMNS: list[str] = []
DASHBOARD_ALL_COLUMNS: list[str] = []