# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
import re
from abc import ABC, abstractmethod
from datetime import datetime, timedelta, timezone
from typing import Any, Callable, Dict, Generic, List, Literal, Type, TypeVar

from flask_appbuilder.models.sqla.interface import SQLAInterface
from pydantic import BaseModel
from sqlalchemy import func

from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum
from superset.extensions import db
from superset.mcp_service.constants import MAX_PAGE_SIZE, ModelType
from superset.mcp_service.privacy import (
    filter_user_directory_columns,
    SELF_REFERENCING_FILTER_COLUMNS,
    USER_DIRECTORY_FIELDS,
)
from superset.mcp_service.system.schemas import PaginationInfo
from superset.mcp_service.utils import _is_uuid
from superset.mcp_service.utils.permissions_utils import get_current_user
from superset.mcp_service.utils.schema_utils import (
    parse_json_or_list,
    parse_json_or_passthrough,
)
from superset.utils import json


def _slugify(value: str) -> str:
    """Normalize a string to a slug-like form for comparison.

    Lowercases, drops apostrophes so possessives collapse
    ("World Bank's" → "world-banks" rather than "world-bank-s"), then
    collapses any remaining non-alphanumerics to single hyphens and trims
    leading/trailing hyphens. Mirrors how agents typically guess slugs
    from a dashboard title (e.g. "World Bank's Data" → "world-banks-data").
    """
    lowered = value.lower()
    # Drop apostrophes entirely so "bank's" collapses to "banks" rather than
    # splitting into "bank-s". Covers straight and curly variants.
    stripped = re.sub(r"['’]", "", lowered)
    return re.sub(r"[^a-z0-9]+", "-", stripped).strip("-")
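
# Illustrative outputs, derived from the rules above:
#     _slugify("World Bank's Data")   -> "world-banks-data"
#     _slugify("Sales & Marketing!")  -> "sales-marketing"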


# Type variables for generic model tools
T = TypeVar("T")  # For model objects
S = TypeVar("S", bound=BaseModel)  # For Pydantic schemas
F = TypeVar("F", bound=BaseModel)  # For filter types
L = TypeVar("L", bound=BaseModel)  # For list response schemas


class BaseCore(ABC):
    """
    Abstract base class for all MCP Core classes.

    Provides common functionality:
    - Logger initialization
    - Abstract run_tool method that all subclasses must implement
    - Common error handling patterns
    """

    def __init__(self, logger: logging.Logger | None = None) -> None:
        """Initialize the core with an optional logger."""
        self.logger = logger or logging.getLogger(self.__class__.__name__)

    @abstractmethod
    def run_tool(self, *args: Any, **kwargs: Any) -> Any:
        """
        Execute the core tool logic.

        This method must be implemented by all subclasses.
        """

    def _log_error(self, error: Exception, context: str = "") -> None:
        """Log an error at DEBUG level for stack-trace context.

        Callers must re-raise the exception after calling this method.
        The GlobalErrorHandlerMiddleware is the single source of truth
        for error classification and logging level.
        """
        error_msg = f"Error in {self.__class__.__name__}"
        if context:
            error_msg += f" ({context})"
        error_msg += f": {str(error)}"
        self.logger.debug(error_msg, exc_info=True)

    def _log_info(self, message: str) -> None:
        """Log an info message."""
        self.logger.info(message)

    def _log_warning(self, message: str) -> None:
        """Log a warning message."""
        self.logger.warning(message)
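

# A minimal sketch of a BaseCore subclass (illustrative only; "EchoCore" is
# a hypothetical name, not part of this module):
#
#     class EchoCore(BaseCore):
#         def run_tool(self, message: str) -> str:
#             try:
#                 self._log_info(f"Echoing {message!r}")
#                 return message
#             except Exception as exc:
#                 self._log_error(exc, "echoing message")
#                 raise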


class ModelListCore(BaseCore, Generic[L]):
    """
    Generic tool for listing model objects with filtering, search, pagination, and
    column selection.

    - Paging input is 0-based: page=0 is the first page (matching backend and
      API conventions); the response reports a 1-based page number to match
      the list tool wrappers.
    - total_pages is 0 if there are no results; otherwise, it's
      ceil(total_count / page_size).
    - has_previous is True if page > 0.
    - has_next is True if there are more results after the current page.
    - columns_requested/columns_loaded track what columns were requested/returned
      for LLM/OpenAPI friendliness.
    - Returns a strongly-typed Pydantic list schema (output_list_schema) with all
      metadata.
    - Handles both object-based and JSON string filters.
    - Designed for use by LLM agents and API clients.
    """

    output_list_schema: Type[L]

    def __init__(
        self,
        dao_class: Type[BaseDAO[Any]],
        output_schema: Type[S],
        item_serializer: Callable[[T, List[str]], S | None],
        filter_type: Type[F],
        default_columns: List[str],
        search_columns: List[str],
        list_field_name: str,
        output_list_schema: Type[L],
        logger: logging.Logger | None = None,
        all_columns: List[str] | None = None,
        sortable_columns: List[str] | None = None,
    ) -> None:
        super().__init__(logger)
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.item_serializer = item_serializer
        self.filter_type = filter_type
        self.default_columns = filter_user_directory_columns(default_columns)
        self.search_columns = filter_user_directory_columns(search_columns)
        self.list_field_name = list_field_name
        self.output_list_schema = output_list_schema
        self._all_columns = filter_user_directory_columns(
            all_columns if all_columns else default_columns
        )
        self._sortable_columns = filter_user_directory_columns(
            sortable_columns if sortable_columns else []
        )

    @property
    def all_columns(self) -> List[str]:
        """Return a copy of all_columns to prevent external mutation."""
        return list(self._all_columns)

    @property
    def sortable_columns(self) -> List[str]:
        """Return a copy of sortable_columns to prevent external mutation."""
        return list(self._sortable_columns)

    def _get_columns_to_load(
        self, select_columns: Any | None
    ) -> tuple[List[str], List[str]]:
        """Return requested and loaded columns after privacy filtering."""
        if not select_columns:
            return self.default_columns, list(self.default_columns)

        parsed_columns = parse_json_or_list(select_columns, param_name="select_columns")
        columns_to_load = filter_user_directory_columns(parsed_columns)
        if not columns_to_load:
            raise ValueError("select_columns contains no valid columns")

        return columns_to_load, list(columns_to_load)

    def _validate_order_column(self, order_column: str | None) -> None:
        """Reject privacy-filtered or unknown sort columns.

        Validation is skipped when no sortable_columns were declared, to preserve
        backward-compatible passthrough behaviour for tools that rely on DAO-level
        sort handling.
        """
        if (
            order_column
            and self._sortable_columns
            and order_column not in self._sortable_columns
        ):
            raise ValueError(
                f"Invalid order_column '{order_column}'. "
                f"Allowed columns: {', '.join(self._sortable_columns)}"
            )

    @staticmethod
    def _prepend_self_lookup_filters(
        filters: Any,
        created_by_me: bool,
        owned_by_me: bool,
        user: Any,
    ) -> Any:
        """Translate created_by_me/owned_by_me flags into ColumnOperator filters.

        Validates authentication and injects the current user's ID in one step,
        so no placeholder value ever reaches the DAO layer.

        When both flags are set, a single combined OR filter is used so results
        include items where the user is either the creator or an owner.
        """
        if not (created_by_me or owned_by_me):
            return filters

        if not user or not getattr(user, "is_authenticated", False):
            raise ValueError("This operation requires an authenticated user")

        user_id: int = user.id
        extra: ColumnOperator
        if created_by_me and owned_by_me:
            extra = ColumnOperator(
                col="created_by_fk_or_owner", opr="eq", value=user_id
            )
        elif created_by_me:
            extra = ColumnOperator(col="created_by_fk", opr="eq", value=user_id)
        else:
            extra = ColumnOperator(col="owner", opr="eq", value=user_id)

        if filters is None:
            return [extra]
        if isinstance(filters, list):
            return [extra] + filters
        return [extra, filters]
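
    # Worked example of the flag translation above (values are hypothetical):
    # with created_by_me=True, owned_by_me=False and user.id == 7,
    #     filters=[ColumnOperator(col="id", opr="eq", value=42)]
    # comes back with the self-lookup filter prepended:
    #     [ColumnOperator(col="created_by_fk", opr="eq", value=7),
    #      ColumnOperator(col="id", opr="eq", value=42)]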

    def run_tool(
        self,
        filters: Any | None = None,
        search: str | None = None,
        select_columns: Any | None = None,
        order_column: str | None = None,
        order_direction: Literal["asc", "desc"] | None = "asc",
        page: int = 0,
        page_size: int = 10,
        created_by_me: bool = False,
        owned_by_me: bool = False,
    ) -> L:
        # Clamp page_size to MAX_PAGE_SIZE as defense-in-depth
        page_size = min(page_size, MAX_PAGE_SIZE)

        # Parse filters using generic utility (accepts JSON string or object)
        filters = parse_json_or_passthrough(filters, param_name="filters")

        filters = self._prepend_self_lookup_filters(
            filters, created_by_me, owned_by_me, get_current_user()
        )

        # Parse select_columns using generic utility (accepts JSON, list, or CSV)
        columns_requested, columns_to_load = self._get_columns_to_load(select_columns)

        # Ensure computed columns have their dependencies loaded.
        # Humanized timestamps are derived from their raw counterparts —
        # if the raw column isn't loaded, the serializer produces null.
        computed_deps: dict[str, str] = {
            "changed_on_humanized": "changed_on",
            "created_on_humanized": "created_on",
        }
        for computed, dependency in computed_deps.items():
            if computed in columns_to_load and dependency not in columns_to_load:
                columns_to_load.append(dependency)

        self._validate_order_column(order_column)

        # Query the DAO
        items: List[Any]
        items, total_count = self.dao_class.list(
            column_operators=filters,
            order_column=order_column or "changed_on",
            order_direction=str(order_direction or "desc"),
            page=page,
            page_size=page_size,
            search=search,
            search_columns=self.search_columns,
            columns=columns_to_load,
        )
        # Serialize items
        item_objs = []
        for item in items:
            obj = self.item_serializer(item, columns_to_load)
            if obj is not None:
                item_objs.append(obj)
        total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 0
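
        # Worked example of the math above (hypothetical numbers):
        # total_count=25, page_size=10 -> total_pages = (25 + 9) // 10 = 3;
        # with 0-based page=1, has_previous is True (1 > 0) and has_next is
        # True (1 < 3 - 1); on the last page (page=2) has_next is False.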

        # Report 1-based page in response to match the 1-based input convention
        # used by all list tool wrappers (list_charts, list_datasets, etc.)
        page_1based = page + 1
        pagination_info = PaginationInfo(
            page=page_1based,
            page_size=page_size,
            total_count=total_count,
            total_pages=total_pages,
            has_next=page < total_pages - 1,
            has_previous=page > 0,
        )

        # Build response
        def get_keys(obj: BaseModel | dict[str, Any] | Any) -> List[str]:
            if hasattr(obj, "model_dump"):
                return list(obj.model_dump().keys())
            elif isinstance(obj, dict):
                return list(obj.keys())
            return []

        response_kwargs = {
            self.list_field_name: item_objs,
            "count": len(item_objs),
            "total_count": total_count,
            "page": page_1based,
            "page_size": page_size,
            "total_pages": total_pages,
            "has_previous": page > 0,
            "has_next": page < total_pages - 1,
            "columns_requested": columns_requested,
            "columns_loaded": columns_to_load,
            "columns_available": self.all_columns,
            "sortable_columns": self.sortable_columns,
            "filters_applied": [
                f
                for f in (filters if isinstance(filters, list) else [])
                if (f.get("col") if isinstance(f, dict) else getattr(f, "col", None))
                not in SELF_REFERENCING_FILTER_COLUMNS
            ],
            "pagination": pagination_info,
            "timestamp": datetime.now(timezone.utc),
        }
        response = self.output_list_schema(**response_kwargs)
        self._log_info(
            f"Successfully retrieved {len(item_objs)} {self.list_field_name}"
        )
        return response
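

# A sketch of wiring ModelListCore for a concrete model; ChartDAO, ChartInfo,
# ChartFilter, ChartList and serialize_chart are hypothetical stand-ins, not
# names defined in this module:
#
#     list_charts_core = ModelListCore(
#         dao_class=ChartDAO,
#         output_schema=ChartInfo,
#         item_serializer=serialize_chart,  # Callable[[chart, cols], ChartInfo | None]
#         filter_type=ChartFilter,
#         default_columns=["id", "slice_name", "changed_on"],
#         search_columns=["slice_name", "description"],
#         list_field_name="charts",
#         output_list_schema=ChartList,
#         sortable_columns=["slice_name", "changed_on"],
#     )
#     result = list_charts_core.run_tool(search="sales", page=0, page_size=25)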


class ModelGetInfoCore(BaseCore):
    """
    Enhanced tool for retrieving a single model object by ID, UUID, or slug.

    For datasets and charts: supports ID and UUID.
    For dashboards: supports ID, UUID, and slug.

    Uses the appropriate DAO method to find the object based on identifier type.
    """

    def __init__(
        self,
        dao_class: Type[BaseDAO[Any]],
        output_schema: Type[BaseModel],
        error_schema: Type[BaseModel],
        serializer: Callable[[T], BaseModel],
        supports_slug: bool = False,
        logger: logging.Logger | None = None,
        query_options: list[Any] | None = None,
        title_column_name: str | None = None,
    ) -> None:
        super().__init__(logger)
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.error_schema = error_schema
        self.serializer = serializer
        self.supports_slug = supports_slug
        self.query_options = query_options or []
        # When set, enables a slugified-title fallback after slug lookup
        # fails, so identifiers like "world-banks-data" still resolve to
        # "World Bank's Data" when the dashboard's slug field is empty.
        # Defaults to the DAO's `title_column` attribute when not overridden.
        self.title_column_name = title_column_name or getattr(
            dao_class, "title_column", None
        )

    def _base_filtered_query(self) -> Any:
        """Build a query for this DAO's model with base_filter applied.

        Ensures slug-like and title-based lookups respect RBAC — e.g.
        DashboardAccessFilter excludes rows the current user is not
        allowed to see. Mirrors DashboardDAO.get_by_id_or_slug.
        """
        model_class = self.dao_class.model_cls
        query = db.session.query(model_class)

        if (base_filter := getattr(self.dao_class, "base_filter", None)) is not None:
            query = base_filter(
                self.dao_class.id_column_name,
                SQLAInterface(model_class, db.session),
            ).apply(query, None)

        if self.query_options:
            query = query.options(*self.query_options)
        return query

    def _find_by_slugified_title(self, identifier: str) -> Any:
        """Resolve a slug-like identifier by matching against slugified titles.

        First narrows candidates with an ILIKE on the title column so the
        DB does the heavy filtering — a slug like "world-banks-data" maps
        to the pattern "%world%banks%data%". The ILIKE side strips
        apostrophes from the title (via SQL REPLACE) so it matches the
        same way `_slugify` does in Python — without that, "World Bank's
        Data" wouldn't match "%banks%" because the raw title has "bank's".
        Then confirms each candidate with `_slugify` to weed out
        coincidental ILIKE matches (e.g. "Worldwide Bank Sandbox Data").

        Orders by primary key so the returned row is deterministic when
        multiple titles slugify to the same value. The caller can always
        disambiguate by id or UUID; in the rare collision case we log a
        warning and return the lowest-id match.
        """
        if not self.title_column_name:
            return None
        target = _slugify(identifier)
        if not target:
            return None

        model_class = self.dao_class.model_cls
        title_col = getattr(model_class, self.title_column_name, None)
        if title_col is None:
            return None

        parts = [p for p in target.split("-") if p]
        # parts is non-empty: target is non-empty and contains at least one
        # alphanumeric run. The pattern preserves the agent's word order so
        # we don't return rows whose titles only happen to share the same
        # tokens shuffled.
        pattern = "%" + "%".join(parts) + "%"
        # Strip both straight and curly apostrophes from the title before
        # comparing — matches `_slugify`'s Python-side handling.
        normalized_title = func.replace(func.replace(title_col, "'", ""), "’", "")
        id_col = getattr(model_class, self.dao_class.id_column_name)
        candidates = (
            self._base_filtered_query()
            .filter(normalized_title.ilike(pattern))
            .order_by(id_col)
            .all()
        )

        matches = [
            obj
            for obj in candidates
            if _slugify(getattr(obj, self.title_column_name, "") or "") == target
        ]
        if not matches:
            return None
        if len(matches) > 1:
            ids = [getattr(m, "id", None) for m in matches]
            self._log_warning(
                f"Identifier '{identifier}' matched {len(matches)} rows by "
                f"slugified title (ids={ids}); returning the first. Pass an "
                "id or UUID to disambiguate."
            )
        return matches[0]

    def _find_object(self, identifier: int | str) -> Any:
        """Find object by identifier using appropriate method."""
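        # Resolution order for the branches below (identifier values are
        # illustrative):
        #   42 or "42"        -> find_by_id(42)
        #   a UUID string     -> find_by_id(..., id_column="uuid")
        #   any other string  -> slug lookup (when supports_slug), then the
        #                        slugified-title fallback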
        opts = self.query_options or None
        # If it's an integer or string that can be converted to int, use find_by_id
        if isinstance(identifier, int):
            return self.dao_class.find_by_id(identifier, query_options=opts)

        try:
            # Try to convert string to int
            id_val = int(identifier)
            return self.dao_class.find_by_id(id_val, query_options=opts)
        except ValueError:
            pass

        # Check if it's a UUID
        if _is_uuid(identifier):
            # Use the new flexible find_by_id with uuid column
            return self.dao_class.find_by_id(
                identifier, id_column="uuid", query_options=opts
            )

        # For dashboards, also check slug
        if self.supports_slug:
            # Try to find by slug using the new flexible method
            result = self.dao_class.find_by_id(
                identifier, id_column="slug", query_options=opts
            )
            if result:
                return result

            # Fallback to the existing id_or_slug_filter for complex cases.
            # Apply base_filter so disallowed rows aren't exposed here.
            from superset.models.dashboard import id_or_slug_filter

            slug_result = (
                self._base_filtered_query()
                .filter(id_or_slug_filter(identifier))
                .one_or_none()
            )
            if slug_result is not None:
                return slug_result

            # Many dashboards have empty slugs, so slug lookup alone silently
            # fails when agents pass a slug-like string derived from the
            # dashboard title. Fall back to slugified-title matching.
            return self._find_by_slugified_title(identifier)

        # Non-slug models: a non-numeric, non-UUID string cannot be resolved
        return None

    def run_tool(self, identifier: int | str) -> BaseModel:
        try:
            obj = self._find_object(identifier)
            if obj is None:
                error_data = self.error_schema(
                    error=(
                        f"{self.output_schema.__name__} with identifier "
                        f"'{identifier}' not found"
                    ),
                    error_type="not_found",
                    timestamp=datetime.now(timezone.utc),
                )
                self._log_warning(
                    f"{self.output_schema.__name__} {identifier} error: not_found"
                )
                return error_data
            response = self.serializer(obj)
            self._log_info(
                f"{self.output_schema.__name__} response created successfully for "
                f"identifier {identifier}"
            )
            return response
        except Exception as context_error:
            self._log_error(context_error)
            raise
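

# A sketch of wiring ModelGetInfoCore for dashboards; DashboardDAO,
# DashboardInfo, DashboardError and serialize_dashboard are hypothetical
# stand-ins, not names defined in this module:
#
#     get_dashboard_core = ModelGetInfoCore(
#         dao_class=DashboardDAO,
#         output_schema=DashboardInfo,
#         error_schema=DashboardError,
#         serializer=serialize_dashboard,
#         supports_slug=True,
#     )
#     # Either identifier style can resolve the same dashboard
#     # (UUID strings are also accepted via _is_uuid detection):
#     get_dashboard_core.run_tool(42)                  # primary key
#     get_dashboard_core.run_tool("world-banks-data")  # slug or slugified title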


class InstanceInfoCore(BaseCore):
    """
    Configurable tool for generating comprehensive instance information.

    Provides a flexible way to gather and present statistics about a Superset
    instance with configurable metrics, time windows, and data aggregations.
    Supports custom metric calculators and result transformers for extensibility.
    """

    def __init__(
        self,
        dao_classes: Dict[str, Type[BaseDAO[Any]]],
        output_schema: Type[BaseModel],
        metric_calculators: Dict[str, Callable[..., Any]],
        time_windows: Dict[str, int] | None = None,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialize the instance info tool.

        Args:
            dao_classes: Dict mapping entity names to their DAO classes
            output_schema: Pydantic schema for the response
            metric_calculators: Dict of custom metric calculation functions
            time_windows: Dict of time window configurations (days)
            logger: Optional logger instance
        """
        super().__init__(logger)
        self.dao_classes = dao_classes
        self.output_schema = output_schema
        self.metric_calculators = metric_calculators
        self.time_windows = time_windows or {
            "recent": 7,
            "monthly": 30,
            "quarterly": 90,
        }

    def _calculate_basic_counts(self) -> Dict[str, int]:
        """Calculate basic entity counts using DAOs."""
        counts = {}
        for entity_name, dao_class in self.dao_classes.items():
            try:
                counts[f"total_{entity_name}"] = dao_class.count()
            except Exception as e:
                self._log_warning(f"Failed to count {entity_name}: {e}")
                counts[f"total_{entity_name}"] = 0
        return counts

    def _calculate_time_based_metrics(
        self, base_counts: Dict[str, int]
    ) -> Dict[str, Dict[str, int]]:
        """Calculate time-based metrics for recent activity."""
        now = datetime.now(timezone.utc)
        time_metrics = {}

        for window_name, days in self.time_windows.items():
            cutoff_date = now - timedelta(days=days)
            window_metrics = {}

            # Calculate created and modified counts for each entity
            for entity_name, dao_class in self.dao_classes.items():
                # Skip entities without time tracking
                if not hasattr(dao_class.model_cls, "created_on"):
                    continue

                try:
                    # Use list() with filters (count() accepts no filter params)
                    _, created_count = dao_class.list(
                        column_operators=[
                            ColumnOperator(
                                col="created_on",
                                opr=ColumnOperatorEnum.gte,
                                value=cutoff_date,
                            )
                        ],
                        page_size=1,  # We only need the count
                        columns=["id"],  # Minimal data transfer
                    )
                    window_metrics[f"{entity_name}_created"] = created_count

                    # Modified count (if changed_on exists)
                    if hasattr(dao_class.model_cls, "changed_on"):
                        _, modified_count = dao_class.list(
                            column_operators=[
                                ColumnOperator(
                                    col="changed_on",
                                    opr=ColumnOperatorEnum.gte,
                                    value=cutoff_date,
                                )
                            ],
                            page_size=1,  # We only need the count
                            columns=["id"],  # Minimal data transfer
                        )
                        window_metrics[f"{entity_name}_modified"] = modified_count

                except Exception as e:
                    self._log_warning(
                        f"Failed to calculate {window_name} metrics for "
                        f"{entity_name}: {e}"
                    )
                    window_metrics[f"{entity_name}_created"] = 0
                    window_metrics[f"{entity_name}_modified"] = 0

            time_metrics[window_name] = window_metrics

        return time_metrics

    def _calculate_custom_metrics(
        self, base_counts: Dict[str, int], time_metrics: Dict[str, Dict[str, int]]
    ) -> Dict[str, Any]:
        """Calculate custom metrics using provided calculators."""
        custom_metrics = {}

        for metric_name, calculator in self.metric_calculators.items():
            try:
                # Pass context to calculator functions
                result = calculator(
                    base_counts=base_counts,
                    time_metrics=time_metrics,
                    dao_classes=self.dao_classes,
                )
                # Only include successful calculations
                if result is not None:
                    custom_metrics[metric_name] = result
            except Exception as e:
                self._log_warning(f"Failed to calculate {metric_name}: {e}")
                # Don't add failed metrics to avoid validation errors

        return custom_metrics

    def run_tool(self) -> BaseModel:
        """Tool interface for generating comprehensive instance information."""
        return self._generate_instance_info()

    def get_resource(self) -> str:
        """Resource interface for generating instance metadata as JSON."""
        instance_info = self._generate_instance_info()
        return json.dumps(instance_info.model_dump(), indent=2)

    def _generate_instance_info(self) -> BaseModel:
        """Generate comprehensive instance information."""
        try:
            # Calculate all metrics
            base_counts = self._calculate_basic_counts()
            time_metrics = self._calculate_time_based_metrics(base_counts)
            custom_metrics = self._calculate_custom_metrics(base_counts, time_metrics)

            # Combine all data with fallbacks for required fields
            response_data = {
                **base_counts,
                **time_metrics,
                **custom_metrics,
                "timestamp": datetime.now(timezone.utc).isoformat(),
            }

            # Create response using the configured schema
            response = self.output_schema(**response_data)

            self._log_info("Successfully generated instance information")
            return response

        except Exception as e:
            self._log_error(e, "generating instance info")
            raise
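

# A sketch of a custom metric calculator; the keyword arguments mirror how
# _calculate_custom_metrics invokes calculators, while the metric itself and
# the DAO/schema names are hypothetical:
#
#     def avg_charts_per_dashboard(base_counts, time_metrics, dao_classes):
#         dashboards = base_counts.get("total_dashboards", 0)
#         charts = base_counts.get("total_charts", 0)
#         # Returning None marks the metric as skipped rather than failed.
#         return round(charts / dashboards, 2) if dashboards else None
#
#     info_core = InstanceInfoCore(
#         dao_classes={"charts": ChartDAO, "dashboards": DashboardDAO},
#         output_schema=InstanceSummary,
#         metric_calculators={"avg_charts_per_dashboard": avg_charts_per_dashboard},
#     )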


class ModelGetSchemaCore(BaseCore, Generic[S]):
    """
    Generic tool for retrieving comprehensive schema metadata for a model type.

    Provides unified schema discovery for list tools:
    - select_columns: All columns available for selection
    - filter_columns: Filterable columns with their operators
    - sortable_columns: Columns valid for order_column
    - default_columns: Columns returned when select_columns is not specified
    - search_columns: Columns searched by the search parameter
    - default_sort: Default column for sorting
    - default_sort_direction: Default sort direction ("asc" or "desc")

    Replaces the individual get_*_available_filters tools with a unified approach.
    """

    def __init__(
        self,
        model_type: ModelType,
        dao_class: Type[BaseDAO[Any]],
        output_schema: Type[S],
        select_columns: List[Any],
        sortable_columns: List[str],
        default_columns: List[str],
        search_columns: List[str],
        default_sort: str = "changed_on",
        default_sort_direction: Literal["asc", "desc"] = "desc",
        exclude_filter_columns: set[str] | None = None,
        logger: logging.Logger | None = None,
    ) -> None:
        """
        Initialize the schema discovery core.

        Args:
            model_type: The type of model (chart, dataset, dashboard, database)
            dao_class: The DAO class to query for filter columns
            output_schema: Pydantic schema for the response (e.g., ModelSchemaInfo)
            select_columns: Column metadata (List[ColumnMetadata] or similar)
            sortable_columns: Column names that support sorting
            default_columns: Column names returned by default
            search_columns: Column names used for text search
            default_sort: Default sort column
            default_sort_direction: Default sort direction
            exclude_filter_columns: Column names to omit from filter discovery
                (e.g., sensitive fields like passwords or connection URIs)
            logger: Optional logger instance
        """
        super().__init__(logger)
        self.model_type = model_type
        self.dao_class = dao_class
        self.output_schema = output_schema
        self.select_columns = [
            column
            for column in select_columns
            if getattr(column, "name", None) not in USER_DIRECTORY_FIELDS
        ]
        self.sortable_columns = filter_user_directory_columns(sortable_columns)
        self.default_columns = filter_user_directory_columns(default_columns)
        self.search_columns = filter_user_directory_columns(search_columns)
        self.default_sort = default_sort
        self.default_sort_direction = default_sort_direction
        self.exclude_filter_columns = set(exclude_filter_columns or set())
        self.exclude_filter_columns.update(USER_DIRECTORY_FIELDS)

    def _get_filter_columns(self) -> Dict[str, List[str]]:
        """Get filterable columns and operators from the DAO."""
        try:
            filterable = self.dao_class.get_filterable_columns_and_operators()
            # Defensive handling: ensure we have a valid mapping
            if filterable is None:
                return {}
            # Convert to dict safely - handle both dict and dict-like objects
            if isinstance(filterable, dict):
                result = dict(filterable)
            else:
                # Try to convert mapping-like objects
                try:
                    result = dict(filterable)
                except (TypeError, ValueError):
                    self._log_warning(
                        f"Unexpected filter columns type for {self.model_type}: "
                        f"{type(filterable)}"
                    )
                    return {}
            # Remove excluded columns (e.g., sensitive fields)
            if self.exclude_filter_columns:
                result = {
                    k: v
                    for k, v in result.items()
                    if k not in self.exclude_filter_columns
                }
            return result
        except Exception as e:
            self._log_warning(
                f"Failed to get filter columns for {self.model_type}: {e}"
            )
            return {}

    def run_tool(self) -> S:
        """Execute schema discovery and return comprehensive schema info."""
        try:
            filter_columns = self._get_filter_columns()

            response = self.output_schema(
                model_type=self.model_type,
                select_columns=self.select_columns,
                filter_columns=filter_columns,
                sortable_columns=self.sortable_columns,
                default_select=self.default_columns,
                default_sort=self.default_sort,
                default_sort_direction=self.default_sort_direction,
                search_columns=self.search_columns,
            )

            select_count = len(self.select_columns) if self.select_columns else 0
            self._log_info(
                f"Successfully retrieved schema for {self.model_type}: "
                f"{select_count} select columns, "
                f"{len(filter_columns)} filter columns, "
                f"{len(self.sortable_columns)} sortable columns"
            )
            return response
        except (AttributeError, TypeError, ValueError) as e:
            self._log_error(e, f"getting schema for {self.model_type}")
            raise
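

# A sketch of schema-discovery wiring; ModelSchemaInfo, CHART_COLUMN_METADATA
# and ChartDAO are hypothetical stand-ins, and ModelType.CHART assumes such a
# member exists on the ModelType enum:
#
#     chart_schema_core = ModelGetSchemaCore(
#         model_type=ModelType.CHART,
#         dao_class=ChartDAO,
#         output_schema=ModelSchemaInfo,
#         select_columns=CHART_COLUMN_METADATA,
#         sortable_columns=["slice_name", "changed_on"],
#         default_columns=["id", "slice_name", "changed_on"],
#         search_columns=["slice_name", "description"],
#         exclude_filter_columns={"query_context"},
#     )
#     schema_info = chart_schema_core.run_tool()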