# File: superset2/superset/mcp_service/chart/schemas.py
# (1190 lines, 42 KiB, Python)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Pydantic schemas for chart-related responses
"""
from __future__ import annotations
import html
import re
from datetime import datetime, timezone
from typing import Annotated, Any, Dict, List, Literal, Protocol
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
PositiveInt,
)
from superset.daos.base import ColumnOperator, ColumnOperatorEnum
from superset.mcp_service.common.cache_schemas import (
CacheStatus,
FormDataCacheControl,
MetadataCacheControl,
QueryCacheControl,
)
from superset.mcp_service.common.error_schemas import ChartGenerationError
from superset.mcp_service.system.schemas import (
PaginationInfo,
TagInfo,
UserInfo,
)
class ChartLike(Protocol):
    """Protocol for chart-like objects with expected attributes."""

    # Structural type only: it lists the attributes a chart object is expected
    # to expose. Nothing here is enforced at runtime.

    # Identity and presentation
    id: int
    slice_name: str | None
    viz_type: str | None
    datasource_name: str | None
    datasource_type: str | None
    url: str | None
    description: str | None
    cache_timeout: int | None
    # Chart definition payloads
    form_data: Dict[str, Any] | None
    query_context: Any | None
    # Audit trail: last modification
    changed_by: Any | None  # User object
    changed_by_name: str | None
    changed_on: str | datetime | None
    changed_on_humanized: str | None
    # Audit trail: creation
    created_by: Any | None  # User object
    created_by_name: str | None
    created_on: str | datetime | None
    created_on_humanized: str | None
    uuid: str | None
    # Related collections
    tags: List[Any] | None
    owners: List[Any] | None
class ChartInfo(BaseModel):
    """Full chart model with all possible attributes."""

    # from_attributes=True lets the model be validated from attribute-bearing
    # objects as well as from dicts.
    model_config = ConfigDict(from_attributes=True, ser_json_timedelta="iso8601")

    # --- Required identity ---
    id: int = Field(description="Chart ID")
    slice_name: str = Field(description="Chart name")

    # --- Optional descriptive metadata ---
    viz_type: str | None = Field(default=None, description="Visualization type")
    datasource_name: str | None = Field(default=None, description="Datasource name")
    datasource_type: str | None = Field(default=None, description="Datasource type")
    url: str | None = Field(default=None, description="Chart URL")
    description: str | None = Field(default=None, description="Chart description")
    cache_timeout: int | None = Field(default=None, description="Cache timeout")
    form_data: Dict[str, Any] | None = Field(
        default=None, description="Chart form data"
    )
    query_context: Any | None = Field(default=None, description="Query context")

    # --- Audit trail: last modification ---
    changed_by: str | None = Field(
        default=None, description="Last modifier (username)"
    )
    changed_by_name: str | None = Field(
        default=None, description="Last modifier (display name)"
    )
    changed_on: str | datetime | None = Field(
        default=None, description="Last modification timestamp"
    )
    changed_on_humanized: str | None = Field(
        default=None, description="Humanized modification time"
    )

    # --- Audit trail: creation ---
    created_by: str | None = Field(
        default=None, description="Chart creator (username)"
    )
    created_on: str | datetime | None = Field(
        default=None, description="Creation timestamp"
    )
    created_on_humanized: str | None = Field(
        default=None, description="Humanized creation time"
    )

    uuid: str | None = Field(default=None, description="Chart UUID")

    # --- Related collections (empty list when absent) ---
    tags: List[TagInfo] = Field(description="Chart tags", default_factory=list)
    owners: List[UserInfo] = Field(description="Chart owners", default_factory=list)
class GetChartAvailableFiltersRequest(BaseModel):
    """
    Request schema for get_chart_available_filters tool.
    Currently has no parameters but provides consistent API for future extensibility.
    """

    # No fields are declared; extra="forbid" therefore rejects any key a caller
    # sends. str_strip_whitespace only matters once string fields are added.
    model_config = ConfigDict(
        extra="forbid",
        str_strip_whitespace=True,
    )
class ChartAvailableFiltersResponse(BaseModel):
    # Response carrying, per column name, the operators/metadata usable in filters.
    # (Kept as a comment so the generated JSON schema description is unchanged.)
    column_operators: Dict[str, Any] = Field(
        description="Available filter operators and metadata for each column",
    )
class ChartError(BaseModel):
    # Error payload: a message, an error category, and a creation timestamp.
    model_config = ConfigDict(ser_json_timedelta="iso8601")

    error: str = Field(description="Error message")
    error_type: str = Field(description="Type of error")
    # Timezone-aware UTC "now", evaluated per instance rather than at import time.
    timestamp: datetime = Field(
        description="Error timestamp",
        default_factory=lambda: datetime.now(timezone.utc),
    )
class ChartCapabilities(BaseModel):
    """Describes what the chart can do for LLM understanding."""

    # Every field is required: none of them declares a default.

    # Boolean capability flags
    supports_interaction: bool = Field(description="Chart supports user interaction")
    supports_real_time: bool = Field(description="Chart supports live data updates")
    supports_drill_down: bool = Field(
        description="Chart supports drill-down navigation"
    )
    supports_export: bool = Field(description="Chart can be exported to other formats")
    # Free-form string lists; allowed values are not constrained by this schema
    optimal_formats: List[str] = Field(description="Recommended preview formats")
    data_types: List[str] = Field(
        description="Types of data shown (time_series, categorical, etc)"
    )
class ChartSemantics(BaseModel):
    """Semantic information for LLM reasoning."""

    # Every field is required: none of them declares a default.

    # Narrative text
    primary_insight: str = Field(
        description="Main insight or pattern the chart reveals"
    )
    data_story: str = Field(description="Narrative description of what the data shows")
    # Lists of short, human-readable statements
    recommended_actions: List[str] = Field(
        description="Suggested next steps based on data"
    )
    anomalies: List[str] = Field(description="Notable outliers or unusual patterns")
    # Open-ended mapping; its keys are not fixed by this schema
    statistical_summary: Dict[str, Any] = Field(
        description="Key statistics (mean, median, trends)"
    )
class PerformanceMetadata(BaseModel):
    """Performance information for LLM cost understanding."""

    # Required measurements
    query_duration_ms: int = Field(description="Query execution time")
    # Optional free-text estimate
    estimated_cost: str | None = Field(
        default=None, description="Resource cost estimate"
    )
    cache_status: str = Field(description="Cache hit/miss status")
    # Defaults to a fresh empty list per instance
    optimization_suggestions: List[str] = Field(
        description="Performance improvement tips",
        default_factory=list,
    )
class AccessibilityMetadata(BaseModel):
    """Accessibility information for inclusive visualization."""

    # All three fields are required (no defaults declared).
    color_blind_safe: bool = Field(description="Uses colorblind-safe palette")
    alt_text: str = Field(description="Screen reader description")
    high_contrast_available: bool = Field(description="High contrast version available")
class VersionedResponse(BaseModel):
    """Base class for versioned API responses."""

    # Both fields default to fixed version strings.
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
class GetChartInfoRequest(BaseModel):
    """Request schema for get_chart_info with support for ID or UUID."""

    # Required. Typed as int | str; this schema does not itself check that a
    # string value is a well-formed UUID.
    identifier: Annotated[
        int | str,
        Field(description="Chart identifier - can be numeric ID or UUID string"),
    ]
def serialize_chart_object(chart: ChartLike | None) -> ChartInfo | None:
    """Convert a chart-like object into a ``ChartInfo`` schema instance.

    Attributes are read defensively with ``getattr(..., None)``, so a missing
    attribute is treated as ``None``.

    Args:
        chart: Object exposing chart attributes, or ``None``.

    Returns:
        ``None`` when ``chart`` is falsy, otherwise a populated ``ChartInfo``
        whose ``url`` is the MCP screenshot URL rather than the chart's own URL.
    """
    if not chart:
        return None

    # Imported at call time, as in the original implementation.
    from superset.mcp_service.utils.url_utils import get_chart_screenshot_url

    def read(attribute: str) -> Any:
        return getattr(chart, attribute, None)

    chart_id = read("id")
    # Generate MCP service screenshot URL instead of chart's native URL
    screenshot_url = get_chart_screenshot_url(chart_id) if chart_id else None

    # Prefer the display name; otherwise fall back to str() of the user object.
    modifier_name = read("changed_by_name")
    modifier = read("changed_by")
    creator = read("created_by")
    changed_by = modifier_name or (str(modifier) if modifier else None)
    created_by = read("created_by_name") or (str(creator) if creator else None)

    raw_uuid = read("uuid")
    tag_objects = read("tags") or []
    owner_objects = read("owners") or []

    return ChartInfo(
        id=chart_id,
        slice_name=read("slice_name"),
        viz_type=read("viz_type"),
        datasource_name=read("datasource_name"),
        datasource_type=read("datasource_type"),
        url=screenshot_url,
        description=read("description"),
        cache_timeout=read("cache_timeout"),
        form_data=read("form_data"),
        query_context=read("query_context"),
        changed_by=changed_by,
        changed_by_name=modifier_name,
        changed_on=read("changed_on"),
        changed_on_humanized=read("changed_on_humanized"),
        created_by=created_by,
        created_on=read("created_on"),
        created_on_humanized=read("created_on_humanized"),
        uuid=str(raw_uuid) if raw_uuid else None,
        tags=[TagInfo.model_validate(tag, from_attributes=True) for tag in tag_objects],
        owners=[
            UserInfo.model_validate(owner, from_attributes=True)
            for owner in owner_objects
        ],
    )
class ChartFilter(ColumnOperator):
    """
    Filter object for chart listing.
    col: The column to filter on. Must be one of the allowed filter fields.
    opr: The operator to use. Must be one of the supported operators.
    value: The value to filter by (type depends on col and opr).
    """

    # Only these three columns may be filtered on.
    col: Literal["slice_name", "viz_type", "datasource_name"] = Field(
        description="Column to filter on. See get_chart_available_filters for "
        "allowed values.",
    )
    opr: ColumnOperatorEnum = Field(
        description="Operator to use. See get_chart_available_filters for "
        "allowed values.",
    )
    # A scalar, or a list of scalars.
    value: str | int | float | bool | List[str | int | float | bool] = Field(
        description="Value to filter by (type depends on col and opr)",
    )
class ChartList(BaseModel):
    # Paginated chart-listing response. Documented with comments rather than a
    # docstring so the generated JSON schema description is unchanged.

    # Page contents
    charts: List[ChartInfo]
    # Required pagination counters and flags
    count: int
    total_count: int
    page: int
    page_size: int
    total_pages: int
    has_previous: bool
    has_next: bool
    # Optional echo of the column selection
    columns_requested: List[str] | None = None
    columns_loaded: List[str] | None = None
    filters_applied: List[ChartFilter] = Field(
        default_factory=list,
        description="List of advanced filter dicts applied to the query.",
    )
    # Optional structured pagination info, alongside the flat fields above
    pagination: PaginationInfo | None = None
    timestamp: datetime | None = None
    model_config = ConfigDict(ser_json_timedelta="iso8601")
# --- Simplified schemas for generate_chart tool ---
# Common pieces
class ColumnRef(BaseModel):
    # Reference to a dataset column used by chart configs, with an optional
    # display label and aggregate. Documented with comments rather than a
    # docstring so the generated JSON schema description is unchanged.

    name: str = Field(
        ...,
        description="Column name",
        min_length=1,
        max_length=255,
        pattern=r"^[a-zA-Z0-9_][a-zA-Z0-9_\s\-\.]*$",
    )
    label: str | None = Field(
        None, description="Display label for the column", max_length=500
    )
    dtype: str | None = Field(None, description="Data type hint")
    aggregate: (
        Literal[
            "SUM",
            "COUNT",
            "AVG",
            "MIN",
            "MAX",
            "COUNT_DISTINCT",
            "STDDEV",
            "VAR",
            "MEDIAN",
            "PERCENTILE",
        ]
        | None
    ) = Field(
        None,
        description="SQL aggregation function. Only these validated functions are "
        "supported to prevent SQL errors.",
    )

    @field_validator("name")
    @classmethod
    def sanitize_name(cls, v: str) -> str:
        """Validate a column name and return it stripped and HTML-escaped.

        Raises:
            ValueError: If the name is blank, longer than 255 characters, or
                contains a blocked HTML tag, URL scheme, shell character, SQL
                keyword, or SQL comment marker.
        """
        if not v or not v.strip():
            raise ValueError("Column name cannot be empty")

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Column name too long ({len(v)} characters). "
                f"Maximum allowed length is 255 characters."
            )

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = ["<script", "</script>", "<iframe", "<object", "<embed"]
        v_lower = v.lower()
        if any(tag in v_lower for tag in dangerous_tags):
            raise ValueError(
                "Column name contains potentially malicious script content"
            )

        # Check URL schemes with word boundaries to match only actual URLs
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Column name contains potentially malicious URL scheme")

        # Basic SQL injection patterns (basic protection)
        # Use simple patterns without backtracking
        dangerous_patterns = [
            r"[;|&$`]",  # Dangerous shell characters
            r"\b(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
            r"--",  # SQL comment
            r"/\*",  # SQL comment start (just check for start, not full pattern)
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Column name contains potentially unsafe characters or SQL keywords"
                )

        # html.escape() *escapes* &, <, >, and quotes; it neither removes tags
        # nor decodes entities.
        return html.escape(v.strip())

    @field_validator("label")
    @classmethod
    def sanitize_label(cls, v: str | None) -> str | None:
        """Validate a display label and return it cleaned and HTML-escaped.

        Returns:
            ``None`` when the label is ``None`` or empty after cleaning,
            otherwise the cleaned, HTML-escaped label.

        Raises:
            ValueError: If the label is longer than 500 characters or contains
                a blocked HTML tag, URL scheme, or event-handler attribute.
        """
        if v is None:
            return v

        # Strip whitespace
        v = v.strip()
        if not v:
            return None

        # Length check first to prevent ReDoS attacks
        if len(v) > 500:
            raise ValueError(
                f"Label too long ({len(v)} characters). "
                f"Maximum allowed length is 500 characters."
            )

        # Remove zero-width and control characters BEFORE the security checks.
        # Doing this afterwards (as the code used to) let an input such as
        # "java<U+200B>script:" pass every check and then collapse into a live
        # "javascript:" scheme in the returned value.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        ).strip()
        if not v:
            return None

        # Check for dangerous HTML tags and JavaScript protocols using substring checks
        # This avoids ReDoS vulnerabilities from regex patterns
        dangerous_tags = [
            "<script",
            "</script>",
            "<iframe",
            "</iframe>",
            "<object",
            "</object>",
            "<embed",
            "</embed>",
            "<link",
            "<meta",
        ]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Label contains potentially malicious content. "
                    "HTML tags, JavaScript, and event handlers are not allowed "
                    "in labels."
                )

        # Check URL schemes and event handlers
        dangerous_patterns = [
            r"\b(javascript|vbscript|data):",  # URL schemes
            r"on\w+\s*=",  # Event handlers
        ]
        for pattern in dangerous_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Label contains potentially malicious content. "
                    "HTML tags, JavaScript, and event handlers are not allowed."
                )

        # HTML escape the cleaned content
        sanitized = html.escape(v)
        return sanitized if sanitized else None
class AxisConfig(BaseModel):
    # Optional per-axis display settings; every field has a default.
    title: str | None = Field(default=None, description="Axis title", max_length=200)
    scale: Literal["linear", "log"] | None = Field(
        default="linear", description="Axis scale type"
    )
    format: str | None = Field(
        default=None, description="Format string (e.g. '$,.2f')", max_length=50
    )
class LegendConfig(BaseModel):
    # Legend display settings: shown by default, positioned on the right.
    show: bool = Field(default=True, description="Whether to show legend")
    position: Literal["top", "bottom", "left", "right"] | None = Field(
        default="right", description="Legend position"
    )
class FilterConfig(BaseModel):
    # A single "column <op> value" filter for chart configs. Documented with
    # comments rather than a docstring so the JSON schema description is unchanged.

    column: str = Field(
        ..., description="Column to filter on", min_length=1, max_length=255
    )
    op: Literal["=", ">", "<", ">=", "<=", "!="] = Field(
        ..., description="Filter operator"
    )
    value: str | int | float | bool = Field(..., description="Filter value")

    @field_validator("column")
    @classmethod
    def sanitize_column(cls, v: str) -> str:
        """Validate a filter column name and return it stripped and HTML-escaped.

        Raises:
            ValueError: If the name is blank, longer than 255 characters, or
                contains a ``<script`` tag or a blocked URL scheme.
        """
        if not v or not v.strip():
            raise ValueError("Filter column name cannot be empty")

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Filter column name too long ({len(v)} characters). "
                f"Maximum allowed length is 255 characters."
            )

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = ["<script", "</script>"]
        v_lower = v.lower()
        if any(tag in v_lower for tag in dangerous_tags):
            raise ValueError(
                "Filter column contains potentially malicious script content"
            )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Filter column contains potentially malicious URL scheme")

        # html.escape() *escapes* &, <, >, and quotes; it neither removes tags
        # nor decodes entities.
        return html.escape(v.strip())

    @staticmethod
    def _validate_string_value(v: str) -> None:
        """Raise ``ValueError`` if a string filter value matches a blocked pattern.

        Checks, in order: HTML tags / SQL procedure names, URL schemes, SQL
        injection patterns, shell metacharacters, event-handler attributes,
        and ``\\xNN`` hex escapes. Returns ``None`` when nothing matches.
        """
        # Check for dangerous HTML tags and SQL procedures
        dangerous_substrings = [
            "<script",
            "</script>",
            "<iframe",
            "<object",
            "<embed",
            "xp_cmdshell",
            "sp_executesql",
        ]
        v_lower = v.lower()
        for substring in dangerous_substrings:
            if substring in v_lower:
                raise ValueError(
                    "Filter value contains potentially malicious content. "
                    "HTML tags and JavaScript are not allowed."
                )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Filter value contains potentially malicious URL scheme")

        # SQL injection patterns
        sql_patterns = [
            r";\s*(DROP|DELETE|INSERT|UPDATE|CREATE|ALTER|EXEC|EXECUTE)\b",
            r"'\s*OR\s*'",
            r"'\s*AND\s*'",
            r"--\s*",
            r"/\*",
            r"UNION\s+SELECT",
        ]
        for pattern in sql_patterns:
            if re.search(pattern, v, re.IGNORECASE):
                raise ValueError(
                    "Filter value contains potentially malicious SQL patterns."
                )

        # Check for other dangerous patterns
        if re.search(r"[;&|`$()]", v):
            raise ValueError(
                "Filter value contains potentially unsafe shell characters."
            )
        if re.search(r"on\w+\s*=", v, re.IGNORECASE):
            raise ValueError(
                "Filter value contains potentially malicious event handlers."
            )
        if re.search(r"\\x[0-9a-fA-F]{2}", v):
            raise ValueError("Filter value contains hex encoding which is not allowed.")

    @field_validator("value")
    @classmethod
    def sanitize_value(cls, v: str | int | float | bool) -> str | int | float | bool:
        """Validate a filter value; strings are cleaned and HTML-escaped.

        Returns:
            Non-string values unchanged; strings stripped, cleaned of
            zero-width/control characters, and HTML-escaped.

        Raises:
            ValueError: If a string is longer than 1000 characters or fails
                ``_validate_string_value``.
        """
        if not isinstance(v, str):
            return v  # Return non-string values as-is

        v = v.strip()

        # Length check FIRST to prevent ReDoS attacks
        if len(v) > 1000:
            raise ValueError(
                f"Filter value too long ({len(v)} characters). "
                f"Maximum allowed length is 1000 characters."
            )

        # Remove zero-width and control characters BEFORE validating.
        # Doing this afterwards (as the code used to) let an input such as
        # "java<U+200B>script:" pass validation and then collapse into a live
        # "javascript:" scheme in the returned value.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        ).strip()

        # Validate security on the text that will actually be returned
        cls._validate_string_value(v)

        # HTML escape the cleaned content
        return html.escape(v)
# Actual chart types
class TableChartConfig(BaseModel):
    # Configuration for a table chart (chart_type == "table").
    chart_type: Literal["table"] = Field(default="table", description="Chart type")
    columns: List[ColumnRef] = Field(
        min_length=1,
        description=(
            "Columns to display. Must have at least one column. Each column must have "
            "a unique label "
            "(either explicitly set via 'label' field or auto-generated "
            "from name/aggregate)"
        ),
    )
    filters: List[FilterConfig] | None = Field(
        default=None, description="Filters to apply"
    )
    sort_by: List[str] | None = Field(default=None, description="Columns to sort by")

    @model_validator(mode="after")
    def validate_unique_column_labels(self) -> "TableChartConfig":
        """Reject configurations in which two columns resolve to the same label.

        A column's effective label is its explicit ``label`` if set; otherwise
        ``AGG(name)`` for aggregated columns and ``name`` for plain ones.

        Raises:
            ValueError: Listing every column whose label repeats an earlier one.
        """
        seen: set[str] = set()
        clashes: list[str] = []

        for position, column in enumerate(self.columns):
            if column.aggregate:
                fallback = f"{column.aggregate}({column.name})"
            else:
                fallback = column.name
            effective = column.label or fallback

            if effective in seen:
                clashes.append(f"columns[{position}]: '{effective}'")
                continue
            seen.add(effective)

        if not clashes:
            return self

        raise ValueError(
            f"Duplicate column/metric labels: {', '.join(clashes)}. "
            f"Please make sure all columns and metrics have a unique label. "
            f"Use the 'label' field to provide custom names for columns."
        )
class XYChartConfig(BaseModel):
    # Configuration for an XY chart (chart_type == "xy"): line, bar, area, scatter.
    chart_type: Literal["xy"] = Field(default="xy", description="Chart type")
    x: ColumnRef = Field(description="X-axis column")
    y: List[ColumnRef] = Field(
        min_length=1,
        description="Y-axis columns (metrics). Must have at least one Y-axis column. "
        "Each column must have a unique label "
        "that doesn't conflict with x-axis or group_by labels",
    )
    kind: Literal["line", "bar", "area", "scatter"] = Field(
        default="line", description="Chart visualization type"
    )
    group_by: ColumnRef | None = Field(default=None, description="Column to group by")
    x_axis: AxisConfig | None = Field(default=None, description="X-axis configuration")
    y_axis: AxisConfig | None = Field(default=None, description="Y-axis configuration")
    legend: LegendConfig | None = Field(
        default=None, description="Legend configuration"
    )
    filters: List[FilterConfig] | None = Field(
        default=None, description="Filters to apply"
    )

    @model_validator(mode="after")
    def validate_unique_column_labels(self) -> "XYChartConfig":
        """Reject configurations whose x, y, and group_by labels collide.

        The x and group_by labels are ``label or name``. A y label is its
        explicit ``label`` if set, otherwise ``AGG(name)`` when aggregated and
        ``name`` when not.

        Raises:
            ValueError: Listing each conflicting entry and the field it
                conflicts with.
        """
        # Maps each label seen so far to the field that first claimed it.
        owner_by_label: dict[str, str] = {(self.x.label or self.x.name): "x"}
        conflicts: list[str] = []

        for idx, metric in enumerate(self.y):
            if metric.aggregate:
                default_label = f"{metric.aggregate}({metric.name})"
            else:
                default_label = metric.name
            label = metric.label or default_label

            owner = owner_by_label.get(label)
            if owner is None:
                owner_by_label[label] = f"y[{idx}]"
            else:
                conflicts.append(f"y[{idx}]: '{label}' (conflicts with {owner})")

        # group_by is checked last, so it never needs to be recorded itself.
        if self.group_by:
            group_label = self.group_by.label or self.group_by.name
            if group_label in owner_by_label:
                conflicts.append(
                    f"group_by: '{group_label}' "
                    f"(conflicts with {owner_by_label[group_label]})"
                )

        if not conflicts:
            return self

        raise ValueError(
            f"Duplicate column/metric labels: {', '.join(conflicts)}. "
            f"Please make sure all columns and metrics have a unique label. "
            f"Use the 'label' field to provide custom names for columns."
        )
# Discriminated union entry point.
# Pydantic picks XYChartConfig or TableChartConfig from the value of the
# "chart_type" field ("xy" or "table").
# NOTE(review): the original comment mentioned "custom error handling"; none is
# visible in this alias, so it presumably lives elsewhere. Confirm.
ChartConfig = Annotated[
    XYChartConfig | TableChartConfig,
    Field(
        discriminator="chart_type",
        description="Chart configuration - specify chart_type as 'xy' or 'table'",
    ),
]
class ListChartsRequest(MetadataCacheControl):
    """Request schema for list_charts with clear, unambiguous types."""

    # ------------------------------------------------------------------ fields
    # Declaration order is unchanged: filters, select_columns, search,
    # order_column, order_direction, page, page_size.
    filters: Annotated[
        List[ChartFilter],
        Field(
            default_factory=list,
            description="List of filter objects (column, operator, value). Each "
            "filter is an object with 'col', 'opr', and 'value' "
            "properties. Cannot be used together with 'search'.",
        ),
    ]
    select_columns: Annotated[
        List[str],
        Field(
            default_factory=lambda: [
                "id",
                "slice_name",
                "viz_type",
                "datasource_name",
                "description",
                "changed_by_name",
                "created_by_name",
                "changed_on",
                "created_on",
                "uuid",
            ],
            description="List of columns to select. Defaults to common columns if not "
            "specified.",
        ),
    ]
    search: Annotated[
        str | None,
        Field(
            default=None,
            description="Text search string to match against chart fields. Cannot be "
            "used together with 'filters'.",
        ),
    ]
    order_column: Annotated[
        str | None, Field(default=None, description="Column to order results by")
    ]
    order_direction: Annotated[
        Literal["asc", "desc"],
        Field(
            default="asc", description="Direction to order results ('asc' or 'desc')"
        ),
    ]
    page: Annotated[
        PositiveInt,
        Field(default=1, description="Page number for pagination (1-based)"),
    ]
    page_size: Annotated[
        PositiveInt, Field(default=10, description="Number of items per page")
    ]

    # -------------------------------------------------------------- validators
    @field_validator("filters", mode="before")
    @classmethod
    def parse_filters(cls, v: Any) -> List[ChartFilter]:
        """
        Normalize ``filters`` before field validation.

        Delegates to ``parse_json_or_model_list`` so that a JSON string is
        accepted as well as a list. This works around a Claude Code bug where
        objects arrive double-serialized as strings.
        See: https://github.com/anthropics/claude-code/issues/5504
        """
        from superset.mcp_service.utils.schema_utils import parse_json_or_model_list

        return parse_json_or_model_list(v, ChartFilter, "filters")

    @field_validator("select_columns", mode="before")
    @classmethod
    def parse_select_columns(cls, v: Any) -> List[str]:
        """
        Normalize ``select_columns`` before field validation.

        Delegates to ``parse_json_or_list`` so that a JSON string or CSV string
        is accepted as well as a list. This works around a Claude Code bug
        where arrays arrive double-serialized as strings.
        See: https://github.com/anthropics/claude-code/issues/5504
        """
        from superset.mcp_service.utils.schema_utils import parse_json_or_list

        return parse_json_or_list(v, "select_columns")

    @model_validator(mode="after")
    def validate_search_and_filters(self) -> "ListChartsRequest":
        """Reject requests that supply both ``search`` and ``filters``.

        Raises:
            ValueError: If both are truthy.
        """
        if not (self.search and self.filters):
            return self

        raise ValueError(
            "Cannot use both 'search' and 'filters' parameters simultaneously. "
            "Use either 'search' for text-based searching across multiple fields, "
            "or 'filters' for precise column-based filtering, but not both."
        )
# The tool input models
class GenerateChartRequest(QueryCacheControl):
    # Input for the generate_chart tool: target dataset, chart config, and
    # save/preview options.
    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="Chart configuration")
    save_chart: bool = Field(
        description="Whether to permanently save the chart in Superset",
        default=True,
    )
    generate_preview: bool = Field(
        description="Whether to generate a preview image",
        default=True,
    )
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        description="List of preview formats to generate",
        default_factory=lambda: ["url"],
    )

    @model_validator(mode="after")
    def validate_cache_timeout(self) -> "GenerateChartRequest":
        """Reject a negative ``cache_timeout``.

        ``cache_timeout`` is not declared on this class; it is looked up
        defensively, so an absent or ``None`` value passes.

        Raises:
            ValueError: If ``cache_timeout`` is set and below zero.
        """
        timeout = getattr(self, "cache_timeout", None)
        if timeout is not None and timeout < 0:
            raise ValueError(
                "cache_timeout must be non-negative (0 or positive integer)"
            )
        return self
class GenerateExploreLinkRequest(FormDataCacheControl):
    # Input for generating an explore link: target dataset plus chart config.
    # Both fields are required.
    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="Chart configuration")
class UpdateChartRequest(QueryCacheControl):
    # Input for updating an existing chart: identifier, new config, optional
    # new name, and preview options. Documented with comments rather than a
    # docstring so the generated JSON schema description is unchanged.

    identifier: int | str = Field(..., description="Chart identifier (ID, UUID)")
    config: ChartConfig = Field(..., description="New chart configuration")
    chart_name: str | None = Field(
        None,
        description="New chart name (optional, will auto-generate if not provided)",
        max_length=255,
    )
    generate_preview: bool = Field(
        default=True,
        description="Whether to generate a preview after updating",
    )
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        default_factory=lambda: ["url"],
        description="List of preview formats to generate",
    )

    @field_validator("chart_name")
    @classmethod
    def sanitize_chart_name(cls, v: str | None) -> str | None:
        """Validate a chart name and return it cleaned and HTML-escaped.

        Returns:
            ``None`` when the name is ``None`` or empty after cleaning,
            otherwise the cleaned, HTML-escaped name.

        Raises:
            ValueError: If the name is longer than 255 characters or contains a
                blocked HTML tag, URL scheme, or event-handler attribute.
        """
        if v is None:
            return v

        # Strip whitespace
        v = v.strip()
        if not v:
            return None

        # Length check first to prevent ReDoS attacks
        if len(v) > 255:
            raise ValueError(
                f"Chart name too long ({len(v)} characters). "
                f"Maximum allowed length is 255 characters."
            )

        # Remove zero-width and control characters BEFORE the security checks.
        # Doing this afterwards (as the code used to) let an input such as
        # "java<U+200B>script:" pass every check and then collapse into a live
        # "javascript:" scheme in the returned value.
        v = re.sub(
            r"[\u200B-\u200D\uFEFF\u0000-\u0008\u000B\u000C\u000E-\u001F]", "", v
        ).strip()
        if not v:
            return None

        # Check for dangerous HTML tags using substring checks (safe)
        dangerous_tags = [
            "<script",
            "</script>",
            "<iframe",
            "</iframe>",
            "<object",
            "</object>",
            "<embed",
            "</embed>",
            "<link",
            "<meta",
        ]
        v_lower = v.lower()
        for tag in dangerous_tags:
            if tag in v_lower:
                raise ValueError(
                    "Chart name contains potentially malicious content. "
                    "HTML tags and JavaScript are not allowed in chart names."
                )

        # Check URL schemes with word boundaries
        if re.search(r"\b(javascript|vbscript|data):", v, re.IGNORECASE):
            raise ValueError("Chart name contains potentially malicious URL scheme")

        # Check for event handlers with simple regex
        if re.search(r"on\w+\s*=", v, re.IGNORECASE):
            raise ValueError(
                "Chart name contains potentially malicious event handlers."
            )

        # HTML escape the cleaned content
        sanitized = html.escape(v)
        return sanitized if sanitized else None
class UpdateChartPreviewRequest(FormDataCacheControl):
    # Input for updating a cached (unsaved) chart preview identified by its
    # form_data_key.
    form_data_key: str = Field(description="Existing form_data_key to update")
    dataset_id: int | str = Field(description="Dataset identifier (ID, UUID)")
    config: ChartConfig = Field(description="New chart configuration")
    generate_preview: bool = Field(
        description="Whether to generate a preview after updating",
        default=True,
    )
    preview_formats: List[
        Literal["url", "interactive", "ascii", "vega_lite", "table", "base64"]
    ] = Field(
        description="List of preview formats to generate",
        default_factory=lambda: ["url"],
    )
class GetChartDataRequest(QueryCacheControl):
    """Request for chart data with cache control."""

    # Required chart identifier
    identifier: int | str = Field(description="Chart identifier (ID, UUID)")
    # Defaults to 100; None is also accepted by the type. How None is
    # interpreted is decided by the consumer, not by this schema.
    limit: int | None = Field(
        default=100, description="Maximum number of data rows to return"
    )
    # Restricted to the three listed export formats
    format: Literal["json", "csv", "excel"] = Field(
        default="json", description="Data export format"
    )
class DataColumn(BaseModel):
    """Enhanced column metadata with semantic information."""

    # Required descriptive fields
    name: str = Field(description="Column name")
    display_name: str = Field(description="Human-readable column name")
    data_type: str = Field(description="Inferred data type")
    # Required profiling fields
    sample_values: List[Any] = Field(description="Representative sample values")
    null_count: int = Field(description="Number of null values")
    unique_count: int = Field(description="Number of unique values")
    # Optional extras
    statistics: Dict[str, Any] | None = Field(
        default=None, description="Column statistics if numeric"
    )
    semantic_type: str | None = Field(
        default=None, description="Semantic type (currency, percentage, etc)"
    )
class ChartData(BaseModel):
    """Rich chart data response with statistical insights."""

    # --- Basic information (required) ---
    chart_id: int
    chart_name: str
    chart_type: str

    # --- Enhanced data description ---
    columns: List[DataColumn] = Field(description="Rich column metadata")
    data: List[Dict[str, Any]] = Field(description="Actual data rows")

    # --- Data insights ---
    row_count: int = Field(description="Rows returned")
    # Nullable but still required: no default is declared for these two.
    total_rows: int | None = Field(description="Total available rows")
    data_freshness: datetime | None = Field(description="When data was last updated")

    # --- LLM-friendly summaries ---
    summary: str = Field(description="Human-readable data summary")
    insights: List[str] = Field(description="Key patterns discovered in the data")
    data_quality: Dict[str, Any] = Field(description="Data quality assessment")
    recommended_visualizations: List[str] = Field(
        description="Suggested chart types for this data"
    )

    # --- Performance and metadata ---
    performance: PerformanceMetadata = Field(description="Query performance metrics")
    cache_status: CacheStatus | None = Field(
        default=None, description="Cache usage information"
    )

    # --- Export format fields ---
    csv_data: str | None = Field(
        default=None, description="CSV content when format='csv'"
    )
    excel_data: str | None = Field(
        default=None, description="Base64-encoded Excel content when format='excel'"
    )
    format: str | None = Field(
        default=None, description="Export format used (json, csv, excel)"
    )

    # --- Versioning (redeclared here; the class derives from BaseModel) ---
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
class GetChartPreviewRequest(QueryCacheControl):
    """Request for chart preview with cache control."""

    # Required chart identifier
    identifier: int | str = Field(description="Chart identifier (ID, UUID)")
    # Requested preview format; "interactive" is not among the accepted values here
    format: Literal["url", "ascii", "table", "base64", "vega_lite"] = Field(
        default="url",
        description=(
            "Preview format: 'url' for image URL, 'ascii' for text art, "
            "'table' for data table, 'base64' for embedded image, "
            "'vega_lite' for interactive JSON specification"
        ),
    )
    # Image dimensions; no lower or upper bound is enforced by this schema
    width: int | None = Field(
        default=800,
        description="Preview image width in pixels (for url/base64 formats)",
    )
    height: int | None = Field(
        default=600,
        description="Preview image height in pixels (for url/base64 formats)",
    )
    # ASCII rendering dimensions; likewise unbounded here
    ascii_width: int | None = Field(
        default=80, description="ASCII chart width in characters (for ascii format)"
    )
    ascii_height: int | None = Field(
        default=20, description="ASCII chart height in lines (for ascii format)"
    )
# Discriminated union preview formats for type safety
class URLPreview(BaseModel):
    """URL-based image preview format."""

    # Discriminator tag for the preview union
    type: Literal["url"] = "url"
    preview_url: str = Field(description="Direct image URL")
    width: int = Field(description="Image width in pixels")
    height: int = Field(description="Image height in pixels")
    supports_interaction: bool = Field(
        default=False, description="Static image, no interaction"
    )
class InteractivePreview(BaseModel):
    """Interactive HTML preview with JavaScript controls."""

    # Discriminator tag for the preview union
    type: Literal["interactive"] = "interactive"
    html_content: str = Field(description="Embeddable HTML with Plotly/D3")
    preview_url: str = Field(description="Iframe-compatible URL")
    width: int = Field(description="Viewport width")
    height: int = Field(description="Viewport height")
    # Interaction flags, all enabled by default
    supports_pan: bool = Field(default=True, description="Supports pan interaction")
    supports_zoom: bool = Field(default=True, description="Supports zoom interaction")
    supports_hover: bool = Field(default=True, description="Supports hover details")
class ASCIIPreview(BaseModel):
    """ASCII art text representation."""

    # Discriminator tag for the preview union
    type: Literal["ascii"] = "ascii"
    ascii_content: str = Field(description="Unicode art representation")
    width: int = Field(description="Character width")
    height: int = Field(description="Line height")
    supports_color: bool = Field(default=False, description="Uses ANSI color codes")
class VegaLitePreview(BaseModel):
    """Vega-Lite grammar of graphics specification."""

    # Discriminator tag for the preview union
    type: Literal["vega_lite"] = "vega_lite"
    specification: Dict[str, Any] = Field(description="Vega-Lite JSON spec")
    data_url: str | None = Field(default=None, description="External data URL")
    supports_streaming: bool = Field(
        default=False, description="Supports live data updates"
    )
class TablePreview(BaseModel):
    """Tabular data preview format."""

    # Discriminator tag for the preview union
    type: Literal["table"] = "table"
    table_data: str = Field(description="Formatted table content")
    row_count: int = Field(description="Number of rows displayed")
    supports_sorting: bool = Field(default=False, description="Table supports sorting")
# Modern discriminated union using | syntax.
# Pydantic selects the concrete preview model from the value of its "type"
# field ("url", "interactive", "ascii", "vega_lite", or "table").
ChartPreviewContent = Annotated[
    URLPreview | InteractivePreview | ASCIIPreview | VegaLitePreview | TablePreview,
    Field(discriminator="type"),
]
class GenerateChartResponse(BaseModel):
    """Comprehensive chart creation response with rich metadata."""

    # --- Core chart information ---
    chart: ChartInfo | None = Field(default=None, description="Complete chart metadata")

    # --- Multiple preview formats available ---
    previews: Dict[str, ChartPreviewContent] = Field(
        description="Available preview formats keyed by format type",
        default_factory=dict,
    )

    # --- LLM-friendly capabilities ---
    capabilities: ChartCapabilities | None = Field(
        default=None, description="Chart interaction capabilities"
    )
    semantics: ChartSemantics | None = Field(
        default=None, description="Semantic chart understanding"
    )

    # --- Navigation and context ---
    explore_url: str | None = Field(default=None, description="Edit chart in Superset")
    embed_code: str | None = Field(default=None, description="HTML embed snippet")
    api_endpoints: Dict[str, str] = Field(
        description="Related API endpoints for data/updates",
        default_factory=dict,
    )

    # --- Performance and accessibility ---
    performance: PerformanceMetadata | None = Field(
        default=None, description="Performance metrics"
    )
    accessibility: AccessibilityMetadata | None = Field(
        default=None, description="Accessibility info"
    )

    # --- Success/error handling ---
    success: bool = Field(default=True, description="Whether chart creation succeeded")
    error: ChartGenerationError | None = Field(
        default=None, description="Error details if creation failed"
    )
    warnings: List[str] = Field(description="Non-fatal warnings", default_factory=list)

    # --- Versioning (redeclared here; the class derives from BaseModel) ---
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")
class ChartPreview(BaseModel):
    """Enhanced chart preview with discriminated union content."""

    # --- Required identity ---
    chart_id: int
    chart_name: str
    chart_type: str = Field(description="Type of chart visualization")
    explore_url: str = Field(description="URL to open chart in Superset for editing")

    # --- Type-safe preview content ---
    content: ChartPreviewContent = Field(
        description="Preview content in requested format"
    )

    # --- Rich metadata (required) ---
    chart_description: str = Field(
        description="Human-readable description of the chart"
    )
    accessibility: AccessibilityMetadata = Field(
        description="Accessibility information"
    )
    performance: PerformanceMetadata = Field(description="Performance metrics")

    # --- Backward compatibility fields (populated based on content type) ---
    format: str | None = Field(
        default=None, description="Format of the preview (url, ascii, table, base64)"
    )
    preview_url: str | None = Field(
        default=None, description="Image URL for 'url' format"
    )
    ascii_chart: str | None = Field(
        default=None, description="ASCII art chart for 'ascii' format"
    )
    table_data: str | None = Field(
        default=None, description="Formatted table for 'table' format"
    )
    width: int | None = Field(
        default=None, description="Width (pixels for images, characters for ASCII)"
    )
    height: int | None = Field(
        default=None, description="Height (pixels for images, lines for ASCII)"
    )

    # --- Versioning (redeclared here; the class derives from BaseModel) ---
    schema_version: str = Field(default="2.0", description="Response schema version")
    api_version: str = Field(default="v1", description="MCP API version")