mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
1190 lines
42 KiB
Python
1190 lines
42 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""
|
|
Pydantic schemas for chart-related responses
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import html
|
|
import re
|
|
from datetime import datetime, timezone
|
|
from typing import Annotated, Any, Dict, List, Literal, Protocol
|
|
|
|
from pydantic import (
|
|
BaseModel,
|
|
ConfigDict,
|
|
Field,
|
|
field_validator,
|
|
model_validator,
|
|
PositiveInt,
|
|
)
|
|
|
|
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
|
|
from superset.mcp_service.common.cache_schemas import (
|
|
CacheStatus,
|
|
FormDataCacheControl,
|
|
MetadataCacheControl,
|
|
QueryCacheControl,
|
|
)
|
|
from superset.mcp_service.common.error_schemas import ChartGenerationError
|
|
from superset.mcp_service.system.schemas import (
|
|
PaginationInfo,
|
|
TagInfo,
|
|
UserInfo,
|
|
)
|
|
|
|
|
|
class ChartLike(Protocol):
|
|
"""Protocol for chart-like objects with expected attributes."""
|
|
|
|
id: int
|
|
slice_name: str | None
|
|
viz_type: str | None
|
|
datasource_name: str | None
|
|
datasource_type: str | None
|
|
url: str | None
|
|
description: str | None
|
|
cache_timeout: int | None
|
|
form_data: Dict[str, Any] | None
|
|
query_context: Any | None
|
|
changed_by: Any | None # User object
|
|
changed_by_name: str | None
|
|
changed_on: str | datetime | None
|
|
changed_on_humanized: str | None
|
|
created_by: Any | None # User object
|
|
created_by_name: str | None
|
|
created_on: str | datetime | None
|
|
created_on_humanized: str | None
|
|
uuid: str | None
|
|
tags: List[Any] | None
|
|
owners: List[Any] | None
|
|
|
|
|
|
class ChartInfo(BaseModel):
    """Serializable chart record with every attribute the tools may expose.

    Only ``id`` and ``slice_name`` are required. ``from_attributes`` allows
    validating directly from attribute-bearing objects, not only dicts.
    """

    model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601")

    id: int = Field(description="Chart ID")
    slice_name: str = Field(description="Chart name")
    viz_type: str | None = Field(default=None, description="Visualization type")
    datasource_name: str | None = Field(default=None, description="Datasource name")
    datasource_type: str | None = Field(default=None, description="Datasource type")
    url: str | None = Field(default=None, description="Chart URL")
    description: str | None = Field(default=None, description="Chart description")
    cache_timeout: int | None = Field(default=None, description="Cache timeout")
    form_data: dict[str, Any] | None = Field(
        default=None, description="Chart form data"
    )
    query_context: Any | None = Field(default=None, description="Query context")

    # Modification audit fields
    changed_by: str | None = Field(
        default=None, description="Last modifier (username)"
    )
    changed_by_name: str | None = Field(
        default=None, description="Last modifier (display name)"
    )
    changed_on: str | datetime | None = Field(
        default=None, description="Last modification timestamp"
    )
    changed_on_humanized: str | None = Field(
        default=None, description="Humanized modification time"
    )

    # Creation audit fields
    created_by: str | None = Field(
        default=None, description="Chart creator (username)"
    )
    created_on: str | datetime | None = Field(
        default=None, description="Creation timestamp"
    )
    created_on_humanized: str | None = Field(
        default=None, description="Humanized creation time"
    )

    uuid: str | None = Field(default=None, description="Chart UUID")
    tags: list[TagInfo] = Field(default_factory=list, description="Chart tags")
    owners: list[UserInfo] = Field(default_factory=list, description="Chart owners")
|
|
|
|
|
|
class GetChartAvailableFiltersRequest(BaseModel):
    """Input for the get_chart_available_filters tool.

    The tool takes no parameters today. The model exists so the API shape
    stays stable if parameters are added; unknown fields are rejected and
    string whitespace is stripped.
    """

    model_config = ConfigDict(str_strip_whitespace=True, extra="forbid")
|
|
|
|
|
|
class ChartAvailableFiltersResponse(BaseModel):
    """Per-column filter operators and metadata available for chart listing."""

    column_operators: dict[str, Any] = Field(
        description="Available filter operators and metadata for each column",
    )
|
|
|
|
|
|
class ChartError(BaseModel):
    """Error payload; ``timestamp`` defaults to the current UTC time."""

    model_config = ConfigDict(ser_json_timedelta="iso8601")

    error: str = Field(description="Error message")
    error_type: str = Field(description="Type of error")
    timestamp: datetime = Field(
        description="Error timestamp",
        default_factory=lambda: datetime.now(tz=timezone.utc),
    )
|
|
|
|
|
|
class ChartCapabilities(BaseModel):
    """Capability flags and formats describing a chart, for LLM consumers.

    Every field is required.
    """

    # Interaction capabilities
    supports_interaction: bool = Field(description="Chart supports user interaction")
    supports_real_time: bool = Field(description="Chart supports live data updates")
    supports_drill_down: bool = Field(
        description="Chart supports drill-down navigation",
    )
    supports_export: bool = Field(description="Chart can be exported to other formats")

    # Output / content descriptors
    optimal_formats: list[str] = Field(description="Recommended preview formats")
    data_types: list[str] = Field(
        description="Types of data shown (time_series, categorical, etc)",
    )
|
|
|
|
|
|
class ChartSemantics(BaseModel):
    """Narrative and statistical interpretation of a chart for LLM reasoning.

    Every field is required.
    """

    primary_insight: str = Field(
        description="Main insight or pattern the chart reveals",
    )
    data_story: str = Field(description="Narrative description of what the data shows")
    recommended_actions: list[str] = Field(
        description="Suggested next steps based on data",
    )
    anomalies: list[str] = Field(description="Notable outliers or unusual patterns")
    statistical_summary: dict[str, Any] = Field(
        description="Key statistics (mean, median, trends)",
    )
|
|
|
|
|
|
class PerformanceMetadata(BaseModel):
    """Query cost/performance details so an LLM can reason about expense.

    ``estimated_cost`` and ``optimization_suggestions`` are optional; the other
    fields are required.
    """

    query_duration_ms: int = Field(description="Query execution time")
    estimated_cost: str | None = Field(
        default=None, description="Resource cost estimate"
    )
    cache_status: str = Field(description="Cache hit/miss status")
    optimization_suggestions: list[str] = Field(
        description="Performance improvement tips",
        default_factory=list,
    )
|
|
|
|
|
|
class AccessibilityMetadata(BaseModel):
    """Accessibility attributes of a rendered chart (all required)."""

    color_blind_safe: bool = Field(description="Uses colorblind-safe palette")
    alt_text: str = Field(description="Screen reader description")
    high_contrast_available: bool = Field(
        description="High contrast version available",
    )
|
|
|
|
|
|
class VersionedResponse(BaseModel):
    """Mixin-style base carrying response schema and API version tags."""

    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
|
|
|
|
|
|
class GetChartInfoRequest(BaseModel):
    """Input for get_chart_info: a chart looked up by numeric ID or UUID."""

    # Required; accepts either an int ID or a UUID string.
    identifier: int | str = Field(
        description="Chart identifier - can be numeric ID or UUID string",
    )
|
|
|
|
|
|
def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
    """Build a ``ChartInfo`` from a chart-like object.

    Returns ``None`` when ``chart`` is falsy. ``url`` is the MCP service
    screenshot URL (only computed when the chart has a truthy ``id``), not the
    chart's own URL. ``changed_by`` / ``created_by`` prefer the ``*_by_name``
    attribute and fall back to ``str()`` of the user object. Tags and owners
    are validated into ``TagInfo`` / ``UserInfo`` from their attributes.
    """
    if not chart:
        return None

    # Imported lazily inside the function, as in the original module.
    from superset.mcp_service.utils.url_utils import get_chart_screenshot_url

    def read(attr_name: str) -> Any:
        # Missing attributes are treated as None.
        return getattr(chart, attr_name, None)

    chart_id = read("id")
    screenshot_url = get_chart_screenshot_url(chart_id) if chart_id else None

    def display_user(name_attr: str, user_attr: str) -> str | None:
        # A truthy display name wins; otherwise stringify the user object.
        display_name = read(name_attr)
        if display_name:
            return display_name
        user = read(user_attr)
        return str(user) if user else None

    raw_uuid = read("uuid")
    tag_models = [
        TagInfo.model_validate(tag, from_attributes=True)
        for tag in (read("tags") or [])
    ]
    owner_models = [
        UserInfo.model_validate(owner, from_attributes=True)
        for owner in (read("owners") or [])
    ]

    return ChartInfo(
        id=chart_id,
        slice_name=read("slice_name"),
        viz_type=read("viz_type"),
        datasource_name=read("datasource_name"),
        datasource_type=read("datasource_type"),
        url=screenshot_url,
        description=read("description"),
        cache_timeout=read("cache_timeout"),
        form_data=read("form_data"),
        query_context=read("query_context"),
        changed_by=display_user("changed_by_name", "changed_by"),
        changed_by_name=read("changed_by_name"),
        changed_on=read("changed_on"),
        changed_on_humanized=read("changed_on_humanized"),
        created_by=display_user("created_by_name", "created_by"),
        created_on=read("created_on"),
        created_on_humanized=read("created_on_humanized"),
        uuid=str(raw_uuid) if raw_uuid else None,
        tags=tag_models,
        owners=owner_models,
    )
|
|
|
|
|
|
class ChartFilter(ColumnOperator):
    """One advanced filter for chart listing.

    ``col`` is restricted to the filterable chart columns below, ``opr`` to
    ``ColumnOperatorEnum`` members, and ``value`` may be a scalar or a list of
    scalars depending on the operator.
    """

    col: Literal["slice_name", "viz_type", "datasource_name"] = Field(
        description="Column to filter on. See get_chart_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        description="Operator to use. See get_chart_available_filters for "
        "allowed values.",
    )
    value: str | int | float | bool | list[str | int | float | bool] = Field(
        description="Value to filter by (type depends on col and opr)",
    )
|
|
|
|
|
|
class ChartList(BaseModel):
    """A page of ``ChartInfo`` records together with paging metadata."""

    model_config = ConfigDict(ser_json_timedelta="iso8601")

    charts: list[ChartInfo]

    # Counts and paging state (all required)
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool

    # Optional echo of what was requested / loaded
    columns_requested: list[str] | None = None
    columns_loaded: list[str] | None = None
    filters_applied: list[ChartFilter] = Field(
        description="List of advanced filter dicts applied to the query.",
        default_factory=list,
    )
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None
|
|
|
|
|
|
# --- Simplified schemas for generate_chart tool ---
|
|
|
|
|
|
# Common pieces
|
|
class ColumnRef(BaseModel):
    """A dataset column, optionally aggregated, referenced by a chart config.

    ``name`` must match a restrictive identifier pattern and is further
    screened by ``sanitize_name``; ``label`` is free display text screened and
    HTML-escaped by ``sanitize_label``. ``aggregate`` limits metrics to a
    fixed set of SQL aggregation functions.
    """

    name: str = Field(
        ...,
        description="Column name",
        min_length=1,
        max_length=255,
        pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
    )
    label: str | None = Field(
        None, description="Display label for the column", max_length=500
    )
    dtype: str | None = Field(None, description="Data type hint")
    aggregate: (
        Literal[
            "SUM",
            "COUNT",
            "AVG",
            "MIN",
            "MAX",
            "COUNT_DISTINCT",
            "STDDEV",
            "VAR",
            "MEDIAN",
            "PERCENTILE",
        ]
        | None
    ) = Field(
        None,
        description="SQL aggregation function. Only these validated functions are "
        "supported to prevent SQL errors.",
    )

    @field_validator("name")
    @classmethod
    def sanitize_name(cls, v: str) -> str:
        """Screen a column name and return it stripped and HTML-escaped.

        Raises:
            ValueError: if the name is blank, longer than 255 characters, or
                contains script/embed tags, a javascript/vbscript/data URL
                scheme, shell metacharacters, SQL DDL/DML keywords, or SQL
                comment markers.
        """
        if not v or not v.strip():
            raise ValueError("Column name cannot be empty")

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Column name too long ({len(v)} characters). "
                "Maximum allowed length is 255 characters."
            )

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = ["<script", "</script>", "<iframe", "<object", "<embed"]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Column name contains potentially malicious script content"
                )

        # Check URL schemes with word boundaries to match only actual URLs
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Column name contains potentially malicious URL scheme")

        # Basic SQL injection patterns; simple patterns without backtracking
        dangerous_patterns = [
            r"[;|&$`]",  # Dangerous shell characters
            r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
            r"--",  # SQL comment
            r"/\*",  # SQL comment start (just check for start, not full pattern)
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Column name contains potentially unsafe characters or SQL keywords"
                )

        # Escape (not remove) HTML special characters in the returned value.
        return html.escape(v.strip())

    @field_validator("label")
    @classmethod
    def sanitize_label(cls, v: str | None) -> str | None:
        """Screen a display label and return it HTML-escaped.

        Blank or whitespace-only labels become ``None``. Invisible/control
        characters are removed before screening, so the checks inspect the
        same text that is returned.

        Raises:
            ValueError: if the label is longer than 500 characters or contains
                HTML tags (script/iframe/object/embed/link/meta), a
                javascript/vbscript/data URL scheme, or an event-handler
                attribute such as ``onclick=``.
        """
        if v is None:
            return v

        # Strip whitespace
        v = v.strip()
        if not v:
            return None

        # Length check first to prevent ReDoS attacks
        if len(v) > 500:
            raise ValueError(
                f"Label too long ({len(v)} characters). "
                "Maximum allowed length is 500 characters."
            )

        # Remove zero-width and control characters BEFORE screening. Doing it
        # afterwards let them split a blocked token (e.g. "<scr" + ZWSP +
        # "ipt") so the checks passed on text that differs from the result.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        )

        # Substring checks for dangerous HTML tags (no regex, so no ReDoS risk)
        dangerous_tags = [
            "<script",
            "</script>",
            "<iframe",
            "</iframe>",
            "<object",
            "</object>",
            "<embed",
            "</embed>",
            "<link",
            "<meta",
        ]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Label contains potentially malicious content. "
                    "HTML tags, JavaScript, and event handlers are not allowed "
                    "in labels."
                )

        dangerous_patterns = [
            r"\b(javascript|vbscript|data):",  # URL schemes
            # Event-handler attribute names start at a word boundary; without
            # \b, ordinary text such as "Month=Jan" matched via "onth=".
            r"\bon\w+\s*=",
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Label contains potentially malicious content. "
                    "HTML tags, JavaScript, and event handlers are not allowed."
                )

        # HTML escape the cleaned content
        sanitized = html.escape(v)
        return sanitized if sanitized else None
|
|
|
|
|
|
class AxisConfig(BaseModel):
    """Optional title, scale and number format for one chart axis."""

    title: str | None = Field(default=None, max_length=200, description="Axis title")
    # Defaults to a linear scale; None is also accepted.
    scale: Literal["linear", "log"] | None = Field(
        default="linear", description="Axis scale type"
    )
    format: str | None = Field(
        default=None, max_length=50, description="Format string (e.g. '$,.2f')"
    )
|
|
|
|
|
|
class LegendConfig(BaseModel):
    """Legend visibility and placement (shown on the right by default)."""

    show: bool = Field(default=True, description="Whether to show legend")
    position: Literal["top", "bottom", "left", "right"] | None = Field(
        default="right", description="Legend position"
    )
|
|
|
|
|
|
class FilterConfig(BaseModel):
    """A single ``column <op> value`` filter attached to a chart config.

    The column name and string values are screened for markup/injection
    patterns and returned HTML-escaped; non-string values pass through.
    """

    column: str = Field(
        ..., description="Column to filter on", min_length=1, max_length=255
    )
    op: Literal["=", ">", "<", ">=", "<=", "!="] = Field(
        ..., description="Filter operator"
    )
    value: str | int | float | bool = Field(..., description="Filter value")

    @field_validator("column")
    @classmethod
    def sanitize_column(cls, v: str) -> str:
        """Screen a filter column name and return it stripped and HTML-escaped.

        Raises:
            ValueError: if the name is blank, longer than 255 characters, or
                contains a ``<script`` tag or a javascript/vbscript/data URL
                scheme.
        """
        if not v or not v.strip():
            raise ValueError("Filter column name cannot be empty")

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Filter column name too long ({len(v)} characters). "
                "Maximum allowed length is 255 characters."
            )

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = ["<script", "</script>"]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Filter column contains potentially malicious script content"
                )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Filter column contains potentially malicious URL scheme")

        # Escape (not remove) HTML special characters in the returned name.
        return html.escape(v.strip())

    @staticmethod
    def _validate_string_value(v: str) -> None:
        """Raise ValueError if a string filter value looks like markup or injection.

        Rejects script/embed tags, the SQL Server procedures xp_cmdshell and
        sp_executesql, javascript/vbscript/data URL schemes, common SQL
        injection fragments, shell metacharacters (including parentheses),
        event-handler attributes, and backslash-x hex escapes.
        """
        # Check for dangerous HTML tags and SQL procedures
        dangerous_substrings = [
            "<script",
            "</script>",
            "<iframe",
            "<object",
            "<embed",
            "xp_cmdshell",
            "sp_executesql",
        ]
        v_lower = v.lower()
        for substring in dangerous_substrings:
            if substring in v_lower:
                raise ValueError(
                    "Filter value contains potentially malicious content. "
                    "HTML tags and JavaScript are not allowed."
                )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Filter value contains potentially malicious URL scheme")

        # SQL injection patterns
        sql_patterns = [
            r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
            r"'\s*OR\s*'",
            r"'\s*AND\s*'",
            r"--",  # SQL line comment
            r"/\*",  # SQL block comment start
            r"UNION\s+SELECT",
        ]
        for pattern in sql_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Filter value contains potentially malicious SQL patterns."
                )

        # Check for other dangerous patterns
        if re.search(r"[;&|`$()]", v):
            raise ValueError(
                "Filter value contains potentially unsafe shell characters."
            )
        # Event-handler attribute names start at a word boundary; without \b,
        # ordinary values such as "month=1" matched via their "onth=" suffix.
        if re.search(r"\bon\w+\s*=", v, re.IGNORECASE):
            raise ValueError(
                "Filter value contains potentially malicious event handlers."
            )
        if re.search(r"\\x[0-9a-fA-F]{2}", v):
            raise ValueError("Filter value contains hex encoding which is not allowed.")

    @field_validator("value")
    @classmethod
    def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool:
        """Screen and HTML-escape string filter values; pass others through.

        Strings are stripped, length-checked (max 1000 characters), and have
        zero-width/control characters removed before screening, so the checks
        inspect exactly the text that is returned.
        """
        if not isinstance(v, str):
            return v  # Return non-string values as-is

        v = v.strip()

        # Length check FIRST to prevent ReDoS attacks
        if len(v) > 1000:
            raise ValueError(
                f"Filter value too long ({len(v)} characters). "
                "Maximum allowed length is 1000 characters."
            )

        # Remove zero-width and control characters BEFORE screening. Doing it
        # afterwards let them split a blocked token so the checks passed on
        # text that differs from the returned value.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        )

        # Validate security
        cls._validate_string_value(v)

        # HTML escape the cleaned content
        return html.escape(v)
|
|
|
|
|
|
# Actual chart types
|
|
class TableChartConfig(BaseModel):
    """Configuration for a tabular chart (``chart_type="table"``).

    Requires at least one column; every column's effective label (explicit
    ``label`` or one derived from name/aggregate) must be unique.
    """

    chart_type: Literal["table"] = Field(default="table", description="Chart type")
    columns: list[ColumnRef] = Field(
        min_length=1,
        description=(
            "Columns to display. Must have at least one column. "
            "Each column must have a unique label "
            "(either explicitly set via 'label' field or auto-generated from "
            "name/aggregate)"
        ),
    )
    filters: list[FilterConfig] | None = Field(
        default=None, description="Filters to apply"
    )
    sort_by: list[str] | None = Field(default=None, description="Columns to sort by")

    @model_validator(mode="after")
    def validate_unique_column_labels(self) -> "TableChartConfig":
        """Raise ValueError listing every column whose effective label repeats."""
        seen_labels: set[str] = set()
        clashes: list[str] = []

        for position, column in enumerate(self.columns):
            # Same derivation as create_metric_object: explicit label, else
            # "AGG(name)" for aggregates, else the bare column name.
            derived = (
                f"{column.aggregate}({column.name})" if column.aggregate else column.name
            )
            effective_label = column.label or derived
            if effective_label in seen_labels:
                clashes.append(f"columns[{position}]: '{effective_label}'")
            seen_labels.add(effective_label)

        if not clashes:
            return self

        raise ValueError(
            f"Duplicate column/metric labels: {', '.join(clashes)}. "
            "Please make sure all columns and metrics have a unique label. "
            "Use the 'label' field to provide custom names for columns."
        )
|
|
|
|
|
|
class XYChartConfig(BaseModel):
    """Configuration for an x/y chart (line, bar, area or scatter).

    Labels must be unique across the x column, all y metrics and the optional
    group_by column.
    """

    chart_type: Literal["xy"] = Field(default="xy", description="Chart type")
    x: ColumnRef = Field(description="X-axis column")
    y: list[ColumnRef] = Field(
        min_length=1,
        description=(
            "Y-axis columns (metrics). Must have at least one Y-axis column. "
            "Each column must have a unique label that doesn't conflict with "
            "x-axis or group_by labels"
        ),
    )
    kind: Literal["line", "bar", "area", "scatter"] = Field(
        default="line", description="Chart visualization type"
    )
    group_by: ColumnRef | None = Field(default=None, description="Column to group by")
    x_axis: AxisConfig | None = Field(default=None, description="X-axis configuration")
    y_axis: AxisConfig | None = Field(default=None, description="Y-axis configuration")
    legend: LegendConfig | None = Field(
        default=None, description="Legend configuration"
    )
    filters: list[FilterConfig] | None = Field(
        default=None, description="Filters to apply"
    )

    @model_validator(mode="after")
    def validate_unique_column_labels(self) -> "XYChartConfig":
        """Raise ValueError naming each y / group_by label that is already taken."""
        # Maps each claimed label to the field that claimed it first.
        claimed_by: dict[str, str] = {self.x.label or self.x.name: "x"}
        conflicts: list[str] = []

        for position, metric in enumerate(self.y):
            fallback = (
                f"{metric.aggregate}({metric.name})" if metric.aggregate else metric.name
            )
            metric_label = metric.label or fallback
            owner = claimed_by.get(metric_label)
            if owner is not None:
                conflicts.append(
                    f"y[{position}]: '{metric_label}' (conflicts with {owner})"
                )
            else:
                claimed_by[metric_label] = f"y[{position}]"

        if self.group_by:
            grouping_label = self.group_by.label or self.group_by.name
            if grouping_label in claimed_by:
                conflicts.append(
                    f"group_by: '{grouping_label}' "
                    f"(conflicts with {claimed_by[grouping_label]})"
                )

        if not conflicts:
            return self

        raise ValueError(
            f"Duplicate column/metric labels: {', '.join(conflicts)}. "
            "Please make sure all columns and metrics have a unique label. "
            "Use the 'label' field to provide custom names for columns."
        )
|
|
|
|
|
|
# Discriminated union entry point: the "chart_type" field selects the config model
|
|
ChartConfig = Annotated[
    XYChartConfig | TableChartConfig,
    # The "chart_type" literal ("xy" / "table") picks which model validates
    # the payload.
    Field(
        description="Chart configuration - specify chart_type as 'xy' or 'table'",
        discriminator="chart_type",
    ),
]
|
|
|
|
|
|
class ListChartsRequest(MetadataCacheControl):
    """Request schema for list_charts.

    Takes either structured ``filters`` or a free-text ``search`` (mutually
    exclusive), plus column selection, ordering and 1-based paging.
    ``filters`` and ``select_columns`` also accept JSON-encoded strings.
    """

    filters: Annotated[
        list[ChartFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        list[str],
        Field(
            default_factory=lambda: [
                "id",
                "slice_name",
                "viz_type",
                "datasource_name",
                "description",
                "changed_by_name",
                "created_by_name",
                "changed_on",
                "created_on",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            description="Text search string to match against chart fields. Cannot be "
            "used together with 'filters'.",
        ),
    ] = None
    order_column: Annotated[
        str | None, Field(description="Column to order results by")
    ] = None
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(description="Direction to order results ('asc' or 'desc')"),
    ] = "asc"
    page: Annotated[
        PositiveInt, Field(description="Page number for pagination (1-based)")
    ] = 1
    page_size: Annotated[
        PositiveInt, Field(description="Number of items per page")
    ] = 10

    @field_validator("filters", mode="before")
    @classmethod
    def parse_filters(cls, v: Any) -> list[ChartFilter]:
        """Coerce raw ``filters`` (list or JSON string) into ChartFilter objects.

        Some clients double-serialize objects as strings; see
        https://github.com/anthropics/claude-code/issues/5504
        """
        from superset.mcp_service.utils.schema_utils import parse_json_or_model_list

        return parse_json_or_model_list(v, ChartFilter, "filters")

    @field_validator("select_columns", mode="before")
    @classmethod
    def parse_select_columns(cls, v: Any) -> list[str]:
        """Coerce raw ``select_columns`` (list, JSON string or CSV) into a list.

        Some clients double-serialize arrays as strings; see
        https://github.com/anthropics/claude-code/issues/5504
        """
        from superset.mcp_service.utils.schema_utils import parse_json_or_list

        return parse_json_or_list(v, "select_columns")

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListChartsRequest":
        """Reject requests that set both ``search`` and non-empty ``filters``."""
        if not (self.search and self.filters):
            return self
        raise ValueError(
            "Cannot use both 'search' and 'filters' parameters simultaneously. "
            "Use either 'search' for text-based searching across multiple fields, "
            "or 'filters' for precise column-based filtering, but not both."
        )
|
|
|
|
|
|
# The tool input models
|
|
class GenerateChartRequest(QueryCacheControl):
    """Input for generate_chart: dataset, chart config, save/preview options."""

    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="Chart configuration")
    save_chart: bool = Field(
        description="Whether to permanently save the chart in Superset",
        default=True,
    )
    generate_preview: bool = Field(
        description="Whether to generate a preview image",
        default=True,
    )
    preview_formats: list[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        description="List of preview formats to generate",
        default_factory=lambda: ["url"],
    )

    @model_validator(mode="after")
    def validate_cache_timeout(self) -> "GenerateChartRequest":
        """Reject a negative ``cache_timeout`` (None or a missing field is fine)."""
        timeout = getattr(self, "cache_timeout", None)
        if timeout is not None and timeout < 0:
            raise ValueError(
                "cache_timeout must be non-negative (0 or positive integer)"
            )
        return self
|
|
|
|
|
|
class GenerateExploreLinkRequest(FormDataCacheControl):
    """Input for building an explore link: a dataset plus a chart config."""

    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="Chart configuration")
|
|
|
|
|
|
class UpdateChartRequest(QueryCacheControl):
    """Input for updating an existing chart's configuration and optional name.

    ``chart_name`` is screened for markup, URL schemes and event-handler
    attributes, stripped of zero-width/control characters, and HTML-escaped.
    """

    identifier: int | str = Field(..., description="Chart identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="New chart configuration")
    chart_name: str | None = Field(
        None,
        description="New chart name (optional, will auto-generate if not provided)",
        max_length=255,
    )
    generate_preview: bool = Field(
        default=True,
        description="Whether to generate a preview after updating",
    )
    preview_formats: list[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        default_factory=lambda: ["url"],
        description="List of preview formats to generate",
    )

    @field_validator("chart_name")
    @classmethod
    def sanitize_chart_name(cls, v: str | None) -> str | None:
        """Screen a chart name and return it HTML-escaped.

        Blank names become ``None``. Invisible/control characters are removed
        before screening, so the checks inspect the same text that is returned.

        Raises:
            ValueError: if the name is longer than 255 characters or contains
                HTML tags (script/iframe/object/embed/link/meta), a
                javascript/vbscript/data URL scheme, or an event-handler
                attribute such as ``onclick=``.
        """
        if v is None:
            return v

        # Strip whitespace
        v = v.strip()
        if not v:
            return None

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Chart name too long ({len(v)} characters). "
                "Maximum allowed length is 255 characters."
            )

        # Remove zero-width and control characters BEFORE screening. Doing it
        # afterwards let them split a blocked token so the checks passed on
        # text that differs from the returned value.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        )

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = [
            "<script",
            "</script>",
            "<iframe",
            "</iframe>",
            "<object",
            "</object>",
            "<embed",
            "</embed>",
            "<link",
            "<meta",
        ]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Chart name contains potentially malicious content. "
                    "HTML tags and JavaScript are not allowed in chart names."
                )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Chart name contains potentially malicious URL scheme")

        # Event-handler attribute names start at a word boundary; without \b,
        # ordinary names such as "Sales by month=2024" matched via "onth=".
        if re.search(r"\bon\w+\s*=", v, re.IGNORECASE):
            raise ValueError(
                "Chart name contains potentially malicious event handlers."
            )

        # HTML escape the cleaned content
        sanitized = html.escape(v)
        return sanitized if sanitized else None
|
|
|
|
|
|
class UpdateChartPreviewRequest(FormDataCacheControl):
    """Input for replacing the config stored under an existing form_data_key."""

    form_data_key: str = Field(description="Existing form_data_key to update")
    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="New chart configuration")
    generate_preview: bool = Field(
        description="Whether to generate a preview after updating",
        default=True,
    )
    preview_formats: list[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        description="List of preview formats to generate",
        default_factory=lambda: ["url"],
    )
|
|
|
|
|
|
class GetChartDataRequest(QueryCacheControl):
    """Input for fetching a chart's data, with row limit and export format.

    Cache-control options come from ``QueryCacheControl``.
    """

    identifier: int | str = Field(description="Chart identifier (ID, UUID)")
    # Defaults to 100 rows; None is also accepted.
    limit: int | None = Field(
        description="Maximum number of data rows to return", default=100
    )
    format: Literal["json", "csv", "excel"] = Field(
        description="Data export format", default="json"
    )
|
|
|
|
|
|
class DataColumn(BaseModel):
    """Metadata for one result column, including samples and basic counts.

    ``statistics`` and ``semantic_type`` are optional; all other fields are
    required.
    """

    # Naming and typing
    name: str = Field(description="Column name")
    display_name: str = Field(description="Human-readable column name")
    data_type: str = Field(description="Inferred data type")

    # Content profile
    sample_values: list[Any] = Field(description="Representative sample values")
    null_count: int = Field(description="Number of null values")
    unique_count: int = Field(description="Number of unique values")
    statistics: dict[str, Any] | None = Field(
        default=None, description="Column statistics if numeric"
    )
    semantic_type: str | None = Field(
        default=None, description="Semantic type (currency, percentage, etc)"
    )
|
|
|
|
|
|
class ChartData(BaseModel):
    """Rich chart data response with statistical insights."""

    # Basic information
    chart_id: int
    chart_name: str
    chart_type: str

    # Enhanced data description
    columns: List[DataColumn] = Field(description="Rich column metadata")
    data: List[Dict[str, Any]] = Field(description="Actual data rows")

    # Data insights
    row_count: int = Field(description="Rows returned")
    # Like every other ``X | None`` field in this model, these two nullable
    # fields get an explicit ``None`` default. Under Pydantic v2, a nullable
    # annotation with no default is still *required*, so omitting them
    # previously raised a validation error instead of yielding ``None``.
    total_rows: int | None = Field(None, description="Total available rows")
    data_freshness: datetime | None = Field(
        None, description="When data was last updated"
    )

    # LLM-friendly summaries
    summary: str = Field(description="Human-readable data summary")
    insights: List[str] = Field(description="Key patterns discovered in the data")
    data_quality: Dict[str, Any] = Field(description="Data quality assessment")
    recommended_visualizations: List[str] = Field(
        description="Suggested chart types for this data"
    )

    # Performance and metadata
    performance: PerformanceMetadata = Field(description="Query performance metrics")
    cache_status: CacheStatus | None = Field(
        None, description="Cache usage information"
    )

    # Export format fields. Per their descriptions, csv_data / excel_data carry
    # content only for the matching ``format``.
    csv_data: str | None = Field(None, description="CSV content when format='csv'")
    excel_data: str | None = Field(
        None, description="Base64-encoded Excel content when format='excel'"
    )
    format: str | None = Field(
        None, description="Export format used (json, csv, excel)"
    )

    # Inherit versioning
    schema_version: str = Field("2.0", description="Response schema version")
    api_version: str = Field("v1", description="MCP API version")
|
|
|
|
|
|
class GetChartPreviewRequest(QueryCacheControl):
    """Request for chart preview with cache control."""

    # Charts can be referenced either by numeric ID or by UUID string.
    identifier: int | str = Field(description="Chart identifier (ID, UUID)")

    # Rendering mode for the preview; "url" unless the caller asks otherwise.
    format: Literal["url", "ascii", "table", "base64", "vega_lite"] = Field(
        "url",
        description=(
            "Preview format: 'url' for image URL, 'ascii' for text art, "
            "'table' for data table, 'base64' for embedded image, "
            "'vega_lite' for interactive JSON specification"
        ),
    )

    # Pixel dimensions, documented as applying to the url/base64 formats.
    width: int | None = Field(
        800, description="Preview image width in pixels (for url/base64 formats)"
    )
    height: int | None = Field(
        600, description="Preview image height in pixels (for url/base64 formats)"
    )

    # Character-grid dimensions, documented as applying to the ascii format.
    ascii_width: int | None = Field(
        80, description="ASCII chart width in characters (for ascii format)"
    )
    ascii_height: int | None = Field(
        20, description="ASCII chart height in lines (for ascii format)"
    )
|
|
|
|
|
|
# Discriminated union preview formats for type safety
|
|
class URLPreview(BaseModel):
    """URL-based image preview format."""

    # Discriminator tag used by the ChartPreviewContent union.
    type: Literal["url"] = "url"

    # Location and pixel size of the rendered image (all required).
    preview_url: str = Field(description="Direct image URL")
    width: int = Field(description="Image width in pixels")
    height: int = Field(description="Image height in pixels")

    # Defaults to False: a plain image offers no interaction.
    supports_interaction: bool = Field(
        default=False,
        description="Static image, no interaction",
    )
|
|
|
|
|
|
class InteractivePreview(BaseModel):
    """Interactive HTML preview with JavaScript controls."""

    # Discriminator tag used by the ChartPreviewContent union.
    type: Literal["interactive"] = "interactive"

    # Embeddable markup plus an iframe-friendly URL and viewport size.
    html_content: str = Field(description="Embeddable HTML with Plotly/D3")
    preview_url: str = Field(description="Iframe-compatible URL")
    width: int = Field(description="Viewport width")
    height: int = Field(description="Viewport height")

    # Interaction capability flags; each defaults to True.
    supports_pan: bool = Field(default=True, description="Supports pan interaction")
    supports_zoom: bool = Field(default=True, description="Supports zoom interaction")
    supports_hover: bool = Field(default=True, description="Supports hover details")
|
|
|
|
|
|
class ASCIIPreview(BaseModel):
    """ASCII art text representation."""

    # Discriminator tag used by the ChartPreviewContent union.
    type: Literal["ascii"] = "ascii"

    # Rendered text and its size in characters and lines.
    ascii_content: str = Field(description="Unicode art representation")
    width: int = Field(description="Character width")
    height: int = Field(description="Line height")

    # Defaults to False (no ANSI color codes).
    supports_color: bool = Field(default=False, description="Uses ANSI color codes")
|
|
|
|
|
|
class VegaLitePreview(BaseModel):
    """Vega-Lite grammar of graphics specification."""

    # Discriminator tag used by the ChartPreviewContent union.
    type: Literal["vega_lite"] = "vega_lite"

    # Required spec payload; the external data URL is optional.
    specification: Dict[str, Any] = Field(description="Vega-Lite JSON spec")
    data_url: str | None = Field(default=None, description="External data URL")

    # Defaults to False.
    supports_streaming: bool = Field(
        default=False,
        description="Supports live data updates",
    )
|
|
|
|
|
|
class TablePreview(BaseModel):
    """Tabular data preview format."""

    # Discriminator tag used by the ChartPreviewContent union.
    type: Literal["table"] = "table"

    # Pre-formatted table text and how many rows it shows.
    table_data: str = Field(description="Formatted table content")
    row_count: int = Field(description="Number of rows displayed")

    # Defaults to False.
    supports_sorting: bool = Field(default=False, description="Table supports sorting")
|
|
|
|
|
|
# Every preview payload model, combined with PEP 604 ``|`` syntax. Member order
# is kept as-is because it determines the order of variants in the schema.
_ChartPreviewVariants = (
    URLPreview | InteractivePreview | ASCIIPreview | VegaLitePreview | TablePreview
)

# Tagged union: Pydantic picks the concrete model from each variant's
# ``type`` literal rather than trying the members one by one.
ChartPreviewContent = Annotated[_ChartPreviewVariants, Field(discriminator="type")]
|
|
|
|
|
|
class GenerateChartResponse(BaseModel):
    """Comprehensive chart creation response with rich metadata."""

    # Core chart information (absent when creation did not produce a chart).
    chart: ChartInfo | None = Field(
        default=None,
        description="Complete chart metadata",
    )

    # Preview payloads keyed by format type; a fresh empty dict per instance.
    previews: Dict[str, ChartPreviewContent] = Field(
        default_factory=dict,
        description="Available preview formats keyed by format type",
    )

    # LLM-oriented descriptions of what the chart can do and what it shows.
    capabilities: ChartCapabilities | None = Field(
        default=None,
        description="Chart interaction capabilities",
    )
    semantics: ChartSemantics | None = Field(
        default=None,
        description="Semantic chart understanding",
    )

    # Links and snippets for navigating to or embedding the chart.
    explore_url: str | None = Field(
        default=None,
        description="Edit chart in Superset",
    )
    embed_code: str | None = Field(
        default=None,
        description="HTML embed snippet",
    )
    api_endpoints: Dict[str, str] = Field(
        default_factory=dict,
        description="Related API endpoints for data/updates",
    )

    # Optional performance and accessibility metadata.
    performance: PerformanceMetadata | None = Field(
        default=None,
        description="Performance metrics",
    )
    accessibility: AccessibilityMetadata | None = Field(
        default=None,
        description="Accessibility info",
    )

    # Outcome: success flag (default True), structured error, warnings list.
    success: bool = Field(
        default=True,
        description="Whether chart creation succeeded",
    )
    error: ChartGenerationError | None = Field(
        default=None,
        description="Error details if creation failed",
    )
    warnings: List[str] = Field(
        default_factory=list,
        description="Non-fatal warnings",
    )

    # Versioning
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
|
|
|
|
|
|
class ChartPreview(BaseModel):
    """Enhanced chart preview with discriminated union content."""

    # Identity of the previewed chart and a link back to the editor.
    chart_id: int
    chart_name: str
    chart_type: str = Field(description="Type of chart visualization")
    explore_url: str = Field(description="URL to open chart in Superset for editing")

    # Typed payload; the concrete model is selected by its ``type`` field.
    content: ChartPreviewContent = Field(
        description="Preview content in requested format",
    )

    # Descriptive, accessibility and performance metadata (all required).
    chart_description: str = Field(
        description="Human-readable description of the chart",
    )
    accessibility: AccessibilityMetadata = Field(
        description="Accessibility information",
    )
    performance: PerformanceMetadata = Field(description="Performance metrics")

    # Flat legacy fields kept for backward compatibility; each defaults to None.
    format: str | None = Field(
        default=None,
        description="Format of the preview (url, ascii, table, base64)",
    )
    preview_url: str | None = Field(
        default=None,
        description="Image URL for 'url' format",
    )
    ascii_chart: str | None = Field(
        default=None,
        description="ASCII art chart for 'ascii' format",
    )
    table_data: str | None = Field(
        default=None,
        description="Formatted table for 'table' format",
    )
    width: int | None = Field(
        default=None,
        description="Width (pixels for images, characters for ASCII)",
    )
    height: int | None = Field(
        default=None,
        description="Height (pixels for images, lines for ASCII)",
    )

    # Versioning
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
|