fix(mcp): compress chart config schemas to reduce search_tools token usage (#39018)

This commit is contained in:
Amin Ghadersohi
2026-04-06 19:52:03 -04:00
committed by GitHub
parent b05764d070
commit bf9aff19b5
10 changed files with 267 additions and 110 deletions

View File

@@ -36,6 +36,7 @@ from pydantic import (
model_serializer, model_serializer,
model_validator, model_validator,
PositiveInt, PositiveInt,
TypeAdapter,
) )
from typing_extensions import Self from typing_extensions import Self
@@ -1145,7 +1146,7 @@ class XYChartConfig(UnknownFieldCheckMixin):
return self return self
# Discriminated union entry point with custom error handling # Discriminated union for runtime validation (not exposed in JSON Schema)
ChartConfig = Annotated[ ChartConfig = Annotated[
XYChartConfig XYChartConfig
| TableChartConfig | TableChartConfig
@@ -1164,6 +1165,66 @@ ChartConfig = Annotated[
), ),
] ]
# Module-level TypeAdapter avoids repeated schema compilation in
# parse_chart_config() — safe because ChartConfig is fully defined above.
_CHART_CONFIG_ADAPTER: TypeAdapter[ChartConfig] = TypeAdapter(ChartConfig)
# Compact description for JSON Schema — keeps tool inputSchema small while
# giving LLMs enough context to construct valid configs. Used as the
# Field(...) description for the `config` field on the chart request
# schemas below, so keep it terse.
_CHART_CONFIG_DESCRIPTION = (
    "Chart configuration object. MUST include 'chart_type' to select the "
    "schema. Types: 'xy' (x, y, kind: line/bar/area/scatter), "
    "'table' (columns), 'pie' (dimension, metric), "
    "'pivot_table' (rows, metrics), 'mixed_timeseries' (x, y, y_secondary), "
    "'handlebars' (columns, handlebars_template), "
    "'big_number' (metric). "
    "See chart://configs resource for full field reference and examples."
)
def parse_chart_config(
    config: Dict[str, Any],
) -> (
    XYChartConfig
    | TableChartConfig
    | PieChartConfig
    | PivotTableChartConfig
    | MixedTimeseriesChartConfig
    | HandlebarsChartConfig
    | BigNumberChartConfig
):
    """Parse a raw dict into the appropriate typed ChartConfig subclass.

    Dispatches on the ``chart_type`` key via the module-level
    discriminated-union adapter. Call this in tool function bodies to turn
    the loosely-typed request payload into a typed config object.

    Raises:
        ValueError: when validation fails; the original error message is
            preserved and a hint pointing at the chart://configs resource
            is appended.
    """
    try:
        return _CHART_CONFIG_ADAPTER.validate_python(config)
    except Exception as exc:  # re-raised below with added context
        message = (
            f"{exc}\n\n"
            f"Hint: read the chart://configs resource for valid configuration "
            f"examples and field reference."
        )
        raise ValueError(message) from exc
def _coerce_config_to_dict(v: Any) -> Dict[str, Any]:
    """Accept ChartConfig objects, dicts, or JSON strings for the config field.

    Coercion order: JSON strings are decoded first; anything exposing a
    ``model_dump`` method (pydantic models) is dumped to a dict; plain dicts
    pass through unchanged. Everything else is rejected.

    Raises:
        ValueError: when a string is given but is not valid JSON.
        TypeError: when the value is neither a dict, a JSON string, nor a
            model with ``model_dump``.
    """
    value = v
    if isinstance(value, str):
        # Deferred import keeps the schemas module free of a hard dependency
        # at import time.
        from superset.utils import json as json_utils

        try:
            value = json_utils.loads(value)
        except (ValueError, TypeError) as exc:
            raise ValueError(
                f"config must be a JSON object string, got: {value!r}"
            ) from exc
    dump = getattr(value, "model_dump", None)
    if dump is not None:
        return dump()
    if isinstance(value, dict):
        return value
    raise TypeError(f"config must be a dict or JSON string, got {type(value).__name__}")
class ListChartsRequest(MetadataCacheControl): class ListChartsRequest(MetadataCacheControl):
"""Request schema for list_charts with clear, unambiguous types.""" """Request schema for list_charts with clear, unambiguous types."""
@@ -1259,7 +1320,7 @@ class ListChartsRequest(MetadataCacheControl):
# The tool input models # The tool input models
class GenerateChartRequest(QueryCacheControl): class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration") config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
chart_name: str | None = Field( chart_name: str | None = Field(
None, description="Auto-generates if omitted", max_length=255 None, description="Auto-generates if omitted", max_length=255
) )
@@ -1269,6 +1330,11 @@ class GenerateChartRequest(QueryCacheControl):
default_factory=lambda: ["url"], default_factory=lambda: ["url"],
) )
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
@field_validator("chart_name") @field_validator("chart_name")
@classmethod @classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None: def sanitize_chart_name(cls, v: str | None) -> str | None:
@@ -1301,16 +1367,20 @@ class GenerateChartRequest(QueryCacheControl):
class GenerateExploreLinkRequest(FormDataCacheControl): class GenerateExploreLinkRequest(FormDataCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)") dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration") config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
class UpdateChartRequest(QueryCacheControl): class UpdateChartRequest(QueryCacheControl):
identifier: int | str = Field(..., description="Chart ID or UUID") identifier: int | str = Field(..., description="Chart ID or UUID")
config: ChartConfig | None = Field( config: Dict[str, Any] | None = Field(
None, None,
description=( description=(
"Chart configuration. Required for visualization changes. " f"{_CHART_CONFIG_DESCRIPTION} Optional; omit to only update chart_name."
"Omit to only update the chart name."
), ),
) )
chart_name: str | None = Field( chart_name: str | None = Field(
@@ -1321,6 +1391,13 @@ class UpdateChartRequest(QueryCacheControl):
default_factory=lambda: ["url"], default_factory=lambda: ["url"],
) )
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any] | None:
if v is None:
return None
return _coerce_config_to_dict(v)
@field_validator("chart_name") @field_validator("chart_name")
@classmethod @classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None: def sanitize_chart_name(cls, v: str | None) -> str | None:
@@ -1331,12 +1408,17 @@ class UpdateChartRequest(QueryCacheControl):
class UpdateChartPreviewRequest(FormDataCacheControl): class UpdateChartPreviewRequest(FormDataCacheControl):
form_data_key: str = Field(..., description="Existing form_data_key to update") form_data_key: str = Field(..., description="Existing form_data_key to update")
dataset_id: int | str = Field(..., description="Dataset ID or UUID") dataset_id: int | str = Field(..., description="Dataset ID or UUID")
config: ChartConfig config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
generate_preview: bool = True generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field( preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
default_factory=lambda: ["url"], default_factory=lambda: ["url"],
) )
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
class GetChartDataRequest(QueryCacheControl): class GetChartDataRequest(QueryCacheControl):
"""Request for chart data with cache control. """Request for chart data with cache control.

View File

@@ -43,6 +43,7 @@ from superset.mcp_service.chart.schemas import (
ChartError, ChartError,
GenerateChartRequest, GenerateChartRequest,
GenerateChartResponse, GenerateChartResponse,
parse_chart_config,
PerformanceMetadata, PerformanceMetadata,
) )
from superset.mcp_service.utils.url_utils import get_superset_base_url from superset.mcp_service.utils.url_utils import get_superset_base_url
@@ -209,13 +210,17 @@ async def generate_chart( # noqa: C901
"save_chart=%s, preview_formats=%s" "save_chart=%s, preview_formats=%s"
% ( % (
request.dataset_id, request.dataset_id,
request.config.chart_type, request.config.get("chart_type", "unknown"),
request.save_chart, request.save_chart,
request.preview_formats, request.preview_formats,
) )
) )
await ctx.debug( await ctx.debug(
"Chart configuration details: config=%s" % (request.config.model_dump(),) "Chart configuration details: chart_type=%s, keys=%s"
% (
request.config.get("chart_type", "unknown"),
sorted(request.config.keys()),
)
) )
# Track runtime warnings to include in response # Track runtime warnings to include in response
@@ -269,11 +274,12 @@ async def generate_chart( # noqa: C901
} }
) )
# Parse the raw config dict into a typed ChartConfig for downstream use
config = parse_chart_config(request.config)
# Map the simplified config to Superset's form_data format # Map the simplified config to Superset's form_data format
# Pass dataset_id to enable column type checking for proper viz_type selection # Pass dataset_id to enable column type checking for proper viz_type selection
form_data = map_config_to_form_data( form_data = map_config_to_form_data(config, dataset_id=request.dataset_id)
request.config, dataset_id=request.dataset_id
)
chart = None chart = None
chart_id = None chart_id = None
@@ -367,7 +373,7 @@ async def generate_chart( # noqa: C901
dataset, "table_name", None dataset, "table_name", None
) )
chart_name = request.chart_name or generate_chart_name( chart_name = request.chart_name or generate_chart_name(
request.config, dataset_name=dataset_name config, dataset_name=dataset_name
) )
await ctx.debug("Chart name: chart_name=%s" % (chart_name,)) await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
@@ -607,8 +613,8 @@ async def generate_chart( # noqa: C901
response_warnings.extend(compile_result.warnings) response_warnings.extend(compile_result.warnings)
# Generate semantic analysis # Generate semantic analysis
capabilities = analyze_chart_capabilities(chart, request.config) capabilities = analyze_chart_capabilities(chart, config)
semantics = analyze_chart_semantics(chart, request.config) semantics = analyze_chart_semantics(chart, config)
# Create performance metadata # Create performance metadata
execution_time = int((time.time() - start_time) * 1000) execution_time = int((time.time() - start_time) * 1000)
@@ -622,7 +628,7 @@ async def generate_chart( # noqa: C901
chart_name = ( chart_name = (
chart.slice_name chart.slice_name
if chart and hasattr(chart, "slice_name") if chart and hasattr(chart, "slice_name")
else generate_chart_name(request.config) else generate_chart_name(config)
) )
accessibility = AccessibilityMetadata( accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis color_blind_safe=True, # Would need actual analysis
@@ -843,9 +849,9 @@ async def generate_chart( # noqa: C901
# Extract chart_type from different sources for better error context # Extract chart_type from different sources for better error context
chart_type = "unknown" chart_type = "unknown"
try: try:
if hasattr(request, "config") and hasattr(request.config, "chart_type"): if hasattr(request, "config") and isinstance(request.config, dict):
chart_type = request.config.chart_type chart_type = request.config.get("chart_type", "unknown")
except AttributeError as extract_error: except (AttributeError, TypeError) as extract_error:
# Ignore errors when extracting chart type for error context # Ignore errors when extracting chart type for error context
logger.debug("Could not extract chart type: %s", extract_error) logger.debug("Could not extract chart type: %s", extract_error)

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.chart_utils import (
from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.schemas import (
AccessibilityMetadata, AccessibilityMetadata,
GenerateChartResponse, GenerateChartResponse,
parse_chart_config,
PerformanceMetadata, PerformanceMetadata,
UpdateChartRequest, UpdateChartRequest,
) )
@@ -69,14 +70,15 @@ def _build_update_payload(
when neither config nor chart_name is provided. when neither config nor chart_name is provided.
""" """
if request.config is not None: if request.config is not None:
config = parse_chart_config(request.config)
dataset_id = chart.datasource_id if chart.datasource_id else None dataset_id = chart.datasource_id if chart.datasource_id else None
new_form_data = map_config_to_form_data(request.config, dataset_id=dataset_id) new_form_data = map_config_to_form_data(config, dataset_id=dataset_id)
new_form_data.pop("_mcp_warnings", None) new_form_data.pop("_mcp_warnings", None)
chart_name = ( chart_name = (
request.chart_name request.chart_name
if request.chart_name if request.chart_name
else chart.slice_name or generate_chart_name(request.config) else chart.slice_name or generate_chart_name(config)
) )
return { return {
@@ -222,9 +224,12 @@ async def update_chart(
command = UpdateChartCommand(chart.id, payload_or_error) command = UpdateChartCommand(chart.id, payload_or_error)
updated_chart = command.run() updated_chart = command.run()
# Parse config for analysis (may be None for name-only updates)
config = parse_chart_config(request.config) if request.config else None
# Generate semantic analysis # Generate semantic analysis
capabilities = analyze_chart_capabilities(updated_chart, request.config) capabilities = analyze_chart_capabilities(updated_chart, config)
semantics = analyze_chart_semantics(updated_chart, request.config) semantics = analyze_chart_semantics(updated_chart, config)
# Create performance metadata # Create performance metadata
execution_time = int((time.time() - start_time) * 1000) execution_time = int((time.time() - start_time) * 1000)
@@ -238,11 +243,7 @@ async def update_chart(
chart_name = ( chart_name = (
updated_chart.slice_name updated_chart.slice_name
if updated_chart and hasattr(updated_chart, "slice_name") if updated_chart and hasattr(updated_chart, "slice_name")
else ( else (generate_chart_name(config) if config else "Updated chart")
generate_chart_name(request.config)
if request.config
else "Updated chart"
)
) )
accessibility = AccessibilityMetadata( accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis color_blind_safe=True, # Would need actual analysis

View File

@@ -36,6 +36,7 @@ from superset.mcp_service.chart.chart_utils import (
) )
from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.schemas import (
AccessibilityMetadata, AccessibilityMetadata,
parse_chart_config,
PerformanceMetadata, PerformanceMetadata,
UpdateChartPreviewRequest, UpdateChartPreviewRequest,
) )
@@ -95,20 +96,20 @@ def update_chart_preview(
start_time = time.time() start_time = time.time()
try: try:
# Parse the raw config dict into a typed ChartConfig
config = parse_chart_config(request.config)
with event_logger.log_context(action="mcp.update_chart_preview.form_data"): with event_logger.log_context(action="mcp.update_chart_preview.form_data"):
# Map the new config to form_data format # Map the new config to form_data format
# Pass dataset_id to enable column type checking # Pass dataset_id to enable column type checking
new_form_data = map_config_to_form_data( new_form_data = map_config_to_form_data(
request.config, dataset_id=request.dataset_id config, dataset_id=request.dataset_id
) )
new_form_data.pop("_mcp_warnings", None) new_form_data.pop("_mcp_warnings", None)
# Preserve adhoc filters from the previous cached form_data # Preserve adhoc filters from the previous cached form_data
# when the new config doesn't explicitly specify filters # when the new config doesn't explicitly specify filters
if ( if getattr(config, "filters", None) is None and request.form_data_key:
getattr(request.config, "filters", None) is None
and request.form_data_key
):
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key) old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
if old_adhoc_filters: if old_adhoc_filters:
new_form_data["adhoc_filters"] = old_adhoc_filters new_form_data["adhoc_filters"] = old_adhoc_filters
@@ -123,8 +124,8 @@ def update_chart_preview(
with event_logger.log_context(action="mcp.update_chart_preview.metadata"): with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
# Generate semantic analysis # Generate semantic analysis
capabilities = analyze_chart_capabilities(None, request.config) capabilities = analyze_chart_capabilities(None, config)
semantics = analyze_chart_semantics(None, request.config) semantics = analyze_chart_semantics(None, config)
# Create performance metadata # Create performance metadata
execution_time = int((time.time() - start_time) * 1000) execution_time = int((time.time() - start_time) * 1000)
@@ -135,7 +136,7 @@ def update_chart_preview(
) )
# Create accessibility metadata # Create accessibility metadata
chart_name = generate_chart_name(request.config) chart_name = generate_chart_name(config)
accessibility = AccessibilityMetadata( accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis color_blind_safe=True, # Would need actual analysis
alt_text=f"Updated chart preview showing {chart_name}", alt_text=f"Updated chart preview showing {chart_name}",

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.schemas import (
ChartConfig, ChartConfig,
GenerateChartRequest, GenerateChartRequest,
parse_chart_config,
) )
from superset.mcp_service.common.error_schemas import ( from superset.mcp_service.common.error_schemas import (
ChartGenerationError, ChartGenerationError,
@@ -171,6 +172,10 @@ class ValidationPipeline:
if request is None: if request is None:
return ValidationResult(is_valid=False, error=error) return ValidationResult(is_valid=False, error=error)
# Parse the raw config dict into a typed ChartConfig for
# downstream validators that need typed access.
typed_config = parse_chart_config(request.config)
# Fetch dataset context once and reuse across validation layers # Fetch dataset context once and reuse across validation layers
dataset_context = ValidationPipeline._get_dataset_context( dataset_context = ValidationPipeline._get_dataset_context(
request.dataset_id request.dataset_id
@@ -178,20 +183,20 @@ class ValidationPipeline:
# Layer 2: Dataset validation (reuses context) # Layer 2: Dataset validation (reuses context)
is_valid, error = ValidationPipeline._validate_dataset( is_valid, error = ValidationPipeline._validate_dataset(
request.config, request.dataset_id, dataset_context typed_config, request.dataset_id, dataset_context
) )
if not is_valid: if not is_valid:
return ValidationResult(is_valid=False, request=request, error=error) return ValidationResult(is_valid=False, request=request, error=error)
# Layer 3: Runtime validation - returns warnings as metadata, not errors # Layer 3: Runtime validation - returns warnings as metadata, not errors
_is_valid, warnings_metadata = ValidationPipeline._validate_runtime( _is_valid, warnings_metadata = ValidationPipeline._validate_runtime(
request.config, request.dataset_id typed_config, request.dataset_id
) )
# Runtime validation always returns True now, warnings are informational # Runtime validation always returns True now, warnings are informational
# Layer 4: Column name normalization (reuses context) # Layer 4: Column name normalization (reuses context)
normalized_request = ValidationPipeline._normalize_column_names( normalized_request = ValidationPipeline._normalize_column_names(
request, dataset_context request, dataset_context, typed_config=typed_config
) )
return ValidationResult( return ValidationResult(
@@ -284,6 +289,7 @@ class ValidationPipeline:
def _normalize_column_names( def _normalize_column_names(
request: GenerateChartRequest, request: GenerateChartRequest,
dataset_context: DatasetContext | None = None, dataset_context: DatasetContext | None = None,
typed_config: ChartConfig | None = None,
) -> GenerateChartRequest: ) -> GenerateChartRequest:
""" """
Normalize column names in the request to match canonical dataset names. Normalize column names in the request to match canonical dataset names.
@@ -297,6 +303,8 @@ class ValidationPipeline:
request: The validated chart generation request request: The validated chart generation request
dataset_context: Pre-fetched dataset context to avoid duplicate dataset_context: Pre-fetched dataset context to avoid duplicate
DB queries. If None, fetches from the database. DB queries. If None, fetches from the database.
typed_config: Pre-parsed typed ChartConfig. If None, parses from
request.config dict.
Returns: Returns:
A new request with normalized column names A new request with normalized column names
@@ -304,8 +312,9 @@ class ValidationPipeline:
try: try:
from .dataset_validator import DatasetValidator from .dataset_validator import DatasetValidator
config = typed_config or parse_chart_config(request.config)
normalized_config = DatasetValidator.normalize_column_names( normalized_config = DatasetValidator.normalize_column_names(
request.config, config,
request.dataset_id, request.dataset_id,
dataset_context=dataset_context, dataset_context=dataset_context,
) )

View File

@@ -35,6 +35,7 @@ from superset.mcp_service.chart.chart_utils import (
) )
from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.schemas import (
GenerateExploreLinkRequest, GenerateExploreLinkRequest,
parse_chart_config,
) )
@@ -89,7 +90,7 @@ async def generate_explore_link(
""" """
await ctx.info( await ctx.info(
"Generating explore link for dataset_id=%s, chart_type=%s" "Generating explore link for dataset_id=%s, chart_type=%s"
% (request.dataset_id, request.config.chart_type) % (request.dataset_id, request.config.get("chart_type", "unknown"))
) )
await ctx.debug( await ctx.debug(
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s" "Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
@@ -97,6 +98,9 @@ async def generate_explore_link(
) )
try: try:
# Parse the raw config dict into a typed ChartConfig
config = parse_chart_config(request.config)
await ctx.report_progress(1, 4, "Validating dataset exists") await ctx.report_progress(1, 4, "Validating dataset exists")
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"): with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
from superset.daos.dataset import DatasetDAO from superset.daos.dataset import DatasetDAO
@@ -138,10 +142,10 @@ async def generate_explore_link(
) )
normalized_config = DatasetValidator.normalize_column_names( normalized_config = DatasetValidator.normalize_column_names(
request.config, request.dataset_id config, request.dataset_id
) )
except (ImportError, AttributeError, KeyError, ValueError, TypeError): except (ImportError, AttributeError, KeyError, ValueError, TypeError):
normalized_config = request.config normalized_config = config
# Map config to form_data using shared utilities # Map config to form_data using shared utilities
form_data = map_config_to_form_data( form_data = map_config_to_form_data(
@@ -197,7 +201,12 @@ async def generate_explore_link(
except Exception as e: except Exception as e:
await ctx.error( await ctx.error(
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s" "Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
% (request.dataset_id, request.config.chart_type, type(e).__name__, str(e)) % (
request.dataset_id,
request.config.get("chart_type", "unknown"),
type(e).__name__,
str(e),
)
) )
return { return {
"url": "", "url": "",

View File

@@ -24,6 +24,8 @@ from pydantic import ValidationError
from superset.mcp_service.chart.schemas import ( from superset.mcp_service.chart.schemas import (
ColumnRef, ColumnRef,
GenerateChartRequest,
parse_chart_config,
TableChartConfig, TableChartConfig,
XYChartConfig, XYChartConfig,
) )
@@ -625,3 +627,48 @@ class TestColumnRefSavedMetric:
), ),
], ],
) )
class TestParseChartConfig:
    """Tests for parse_chart_config and config coercion."""
    def test_parse_valid_xy_config(self) -> None:
        # A minimal xy config is dispatched to the xy schema via chart_type
        # and comes back as a typed object with attribute access.
        config = parse_chart_config(
            {"chart_type": "xy", "x": {"name": "date"}, "y": [{"name": "v"}]}
        )
        assert config.chart_type == "xy"
        assert config.x.name == "date"
        assert len(config.y) == 1
        assert config.y[0].name == "v"
    def test_parse_valid_table_config(self) -> None:
        # The 'table' discriminator selects the table schema.
        config = parse_chart_config(
            {"chart_type": "table", "columns": [{"name": "col1"}]}
        )
        assert config.chart_type == "table"
        assert len(config.columns) == 1
        assert config.columns[0].name == "col1"
    def test_parse_missing_chart_type_raises(self) -> None:
        # Without the discriminator key, the error must carry the
        # chart://configs resource hint.
        with pytest.raises(ValueError, match="chart://configs"):
            parse_chart_config({"x": {"name": "date"}, "y": [{"name": "v"}]})
    def test_parse_unknown_chart_type_raises(self) -> None:
        # An unrecognized discriminator value surfaces the same hint.
        with pytest.raises(ValueError, match="chart://configs"):
            parse_chart_config({"chart_type": "nonexistent", "x": {"name": "d"}})
    def test_coerce_json_string_config(self) -> None:
        # A JSON object string is decoded into a plain dict on the request.
        raw = '{"chart_type": "table", "columns": [{"name": "c"}]}'
        req = GenerateChartRequest(dataset_id=1, config=raw)
        assert isinstance(req.config, dict)
        assert req.config["chart_type"] == "table"
    def test_coerce_typed_config_object(self) -> None:
        # A typed ChartConfig instance is dumped back to a dict.
        typed = TableChartConfig(chart_type="table", columns=[ColumnRef(name="c")])
        req = GenerateChartRequest(dataset_id=1, config=typed)
        assert isinstance(req.config, dict)
        assert req.config["chart_type"] == "table"
    def test_coerce_invalid_json_string_raises(self) -> None:
        # An undecodable string fails request validation; pydantic wraps the
        # validator's ValueError in a ValidationError.
        with pytest.raises(ValidationError):
            GenerateChartRequest(dataset_id=1, config="not valid json")

View File

@@ -57,10 +57,11 @@ class TestGenerateChart:
) )
table_request = GenerateChartRequest(dataset_id="1", config=table_config) table_request = GenerateChartRequest(dataset_id="1", config=table_config)
assert table_request.dataset_id == "1" assert table_request.dataset_id == "1"
assert table_request.config.chart_type == "table" # config is now Dict[str, Any] in the schema; validate via dict access
assert len(table_request.config.columns) == 2 assert table_request.config["chart_type"] == "table"
assert table_request.config.columns[0].name == "region" assert len(table_request.config["columns"]) == 2
assert table_request.config.columns[1].aggregate == "SUM" assert table_request.config["columns"][0]["name"] == "region"
assert table_request.config["columns"][1]["aggregate"] == "SUM"
# XY chart request # XY chart request
xy_config = XYChartConfig( xy_config = XYChartConfig(
@@ -74,12 +75,12 @@ class TestGenerateChart:
legend=LegendConfig(show=True, position="top"), legend=LegendConfig(show=True, position="top"),
) )
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config) xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
assert xy_request.config.chart_type == "xy" assert xy_request.config["chart_type"] == "xy"
assert xy_request.config.x.name == "date" assert xy_request.config["x"]["name"] == "date"
assert xy_request.config.y[0].aggregate == "SUM" assert xy_request.config["y"][0]["aggregate"] == "SUM"
assert xy_request.config.kind == "line" assert xy_request.config["kind"] == "line"
assert xy_request.config.x_axis.title == "Date" assert xy_request.config["x_axis"]["title"] == "Date"
assert xy_request.config.legend.show is True assert xy_request.config["legend"]["show"] is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_generate_chart_validation_error_handling(self): async def test_generate_chart_validation_error_handling(self):

View File

@@ -60,10 +60,11 @@ class TestUpdateChart:
) )
table_request = UpdateChartRequest(identifier=123, config=table_config) table_request = UpdateChartRequest(identifier=123, config=table_config)
assert table_request.identifier == 123 assert table_request.identifier == 123
assert table_request.config.chart_type == "table" # config is now Dict[str, Any] in the schema; validate via dict access
assert len(table_request.config.columns) == 2 assert table_request.config["chart_type"] == "table"
assert table_request.config.columns[0].name == "region" assert len(table_request.config["columns"]) == 2
assert table_request.config.columns[1].aggregate == "SUM" assert table_request.config["columns"][0]["name"] == "region"
assert table_request.config["columns"][1]["aggregate"] == "SUM"
# XY chart update with UUID # XY chart update with UUID
xy_config = XYChartConfig( xy_config = XYChartConfig(
@@ -80,10 +81,10 @@ class TestUpdateChart:
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
) )
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890" assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert xy_request.config.chart_type == "xy" assert xy_request.config["chart_type"] == "xy"
assert xy_request.config.x.name == "date" assert xy_request.config["x"]["name"] == "date"
assert xy_request.config.y[0].aggregate == "SUM" assert xy_request.config["y"][0]["aggregate"] == "SUM"
assert xy_request.config.kind == "line" assert xy_request.config["kind"] == "line"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_with_chart_name(self): async def test_update_chart_with_chart_name(self):
@@ -170,7 +171,7 @@ class TestUpdateChart:
kind=chart_type, kind=chart_type,
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert request.config.kind == chart_type assert request.config["kind"] == chart_type
# Test multiple Y columns # Test multiple Y columns
multi_y_config = XYChartConfig( multi_y_config = XYChartConfig(
@@ -184,8 +185,8 @@ class TestUpdateChart:
kind="line", kind="line",
) )
request = UpdateChartRequest(identifier=1, config=multi_y_config) request = UpdateChartRequest(identifier=1, config=multi_y_config)
assert len(request.config.y) == 3 assert len(request.config["y"]) == 3
assert request.config.y[1].aggregate == "AVG" assert request.config["y"][1]["aggregate"] == "AVG"
# Test filter operators # Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="] operators = ["=", "!=", ">", ">=", "<", "<="]
@@ -196,7 +197,7 @@ class TestUpdateChart:
filters=filters, filters=filters,
) )
request = UpdateChartRequest(identifier=1, config=table_config) request = UpdateChartRequest(identifier=1, config=table_config)
assert len(request.config.filters) == 6 assert len(request.config["filters"]) == 6
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_response_structure(self): async def test_update_chart_response_structure(self):
@@ -252,12 +253,12 @@ class TestUpdateChart:
), ),
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert request.config.x_axis.title == "Date" assert request.config["x_axis"]["title"] == "Date"
assert request.config.x_axis.format == "smart_date" assert request.config["x_axis"]["format"] == "smart_date"
assert request.config.x_axis.scale == "linear" assert request.config["x_axis"]["scale"] == "linear"
assert request.config.y_axis.title == "Sales Amount" assert request.config["y_axis"]["title"] == "Sales Amount"
assert request.config.y_axis.format == "$,.2f" assert request.config["y_axis"]["format"] == "$,.2f"
assert request.config.y_axis.scale == "log" assert request.config["y_axis"]["scale"] == "log"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_legend_configurations(self): async def test_update_chart_legend_configurations(self):
@@ -271,8 +272,8 @@ class TestUpdateChart:
legend=LegendConfig(show=True, position=pos), legend=LegendConfig(show=True, position=pos),
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.position == pos assert request.config["legend"]["position"] == pos
assert request.config.legend.show is True assert request.config["legend"]["show"] is True
# Hidden legend # Hidden legend
config = XYChartConfig( config = XYChartConfig(
@@ -282,7 +283,7 @@ class TestUpdateChart:
legend=LegendConfig(show=False), legend=LegendConfig(show=False),
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.show is False assert request.config["legend"]["show"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_aggregation_functions(self): async def test_update_chart_aggregation_functions(self):
@@ -294,7 +295,7 @@ class TestUpdateChart:
columns=[ColumnRef(name="value", aggregate=agg)], columns=[ColumnRef(name="value", aggregate=agg)],
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert request.config.columns[0].aggregate == agg assert request.config["columns"][0]["aggregate"] == agg
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_error_responses(self): async def test_update_chart_error_responses(self):
@@ -378,10 +379,10 @@ class TestUpdateChart:
) )
request = UpdateChartRequest(identifier=1, config=config) request = UpdateChartRequest(identifier=1, config=config)
assert len(request.config.filters) == 3 assert len(request.config["filters"]) == 3
assert request.config.filters[0].column == "region" assert request.config["filters"][0]["column"] == "region"
assert request.config.filters[1].op == ">=" assert request.config["filters"][1]["op"] == ">="
assert request.config.filters[2].value == "2024-01-01" assert request.config["filters"][2]["value"] == "2024-01-01"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_cache_control(self): async def test_update_chart_cache_control(self):

View File

@@ -53,9 +53,9 @@ class TestUpdateChartPreview:
) )
assert table_request.form_data_key == "abc123def456" assert table_request.form_data_key == "abc123def456"
assert table_request.dataset_id == 1 assert table_request.dataset_id == 1
assert table_request.config.chart_type == "table" assert table_request.config["chart_type"] == "table"
assert len(table_request.config.columns) == 2 assert len(table_request.config["columns"]) == 2
assert table_request.config.columns[0].name == "region" assert table_request.config["columns"][0]["name"] == "region"
# XY chart preview update # XY chart preview update
xy_config = XYChartConfig( xy_config = XYChartConfig(
@@ -73,9 +73,9 @@ class TestUpdateChartPreview:
) )
assert xy_request.form_data_key == "xyz789ghi012" assert xy_request.form_data_key == "xyz789ghi012"
assert xy_request.dataset_id == "2" assert xy_request.dataset_id == "2"
assert xy_request.config.chart_type == "xy" assert xy_request.config["chart_type"] == "xy"
assert xy_request.config.x.name == "date" assert xy_request.config["x"]["name"] == "date"
assert xy_request.config.kind == "line" assert xy_request.config["kind"] == "line"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_dataset_id_types(self): async def test_update_chart_preview_dataset_id_types(self):
@@ -158,7 +158,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.kind == chart_type assert request.config["kind"] == chart_type
# Test multiple Y columns # Test multiple Y columns
multi_y_config = XYChartConfig( multi_y_config = XYChartConfig(
@@ -174,8 +174,8 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=multi_y_config form_data_key="abc123", dataset_id=1, config=multi_y_config
) )
assert len(request.config.y) == 3 assert len(request.config["y"]) == 3
assert request.config.y[1].aggregate == "AVG" assert request.config["y"][1]["aggregate"] == "AVG"
# Test filter operators # Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="] operators = ["=", "!=", ">", ">=", "<", "<="]
@@ -188,7 +188,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=table_config form_data_key="abc123", dataset_id=1, config=table_config
) )
assert len(request.config.filters) == 6 assert len(request.config["filters"]) == 6
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_response_structure(self): async def test_update_chart_preview_response_structure(self):
@@ -251,10 +251,10 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.x_axis.title == "Date" assert request.config["x_axis"]["title"] == "Date"
assert request.config.x_axis.format == "smart_date" assert request.config["x_axis"]["format"] == "smart_date"
assert request.config.y_axis.title == "Sales Amount" assert request.config["y_axis"]["title"] == "Sales Amount"
assert request.config.y_axis.format == "$,.2f" assert request.config["y_axis"]["format"] == "$,.2f"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_legend_configurations(self): async def test_update_chart_preview_legend_configurations(self):
@@ -270,8 +270,8 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.legend.position == pos assert request.config["legend"]["position"] == pos
assert request.config.legend.show is True assert request.config["legend"]["show"] is True
# Hidden legend # Hidden legend
config = XYChartConfig( config = XYChartConfig(
@@ -283,7 +283,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.legend.show is False assert request.config["legend"]["show"] is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_aggregation_functions(self): async def test_update_chart_preview_aggregation_functions(self):
@@ -297,7 +297,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.columns[0].aggregate == agg assert request.config["columns"][0]["aggregate"] == agg
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_error_responses(self): async def test_update_chart_preview_error_responses(self):
@@ -347,10 +347,10 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert len(request.config.filters) == 3 assert len(request.config["filters"]) == 3
assert request.config.filters[0].column == "region" assert request.config["filters"][0]["column"] == "region"
assert request.config.filters[1].op == ">=" assert request.config["filters"][1]["op"] == ">="
assert request.config.filters[2].value == "2024-01-01" assert request.config["filters"][2]["value"] == "2024-01-01"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_form_data_key_handling(self): async def test_update_chart_preview_form_data_key_handling(self):
@@ -447,12 +447,12 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert len(request.config.y) == 4 assert len(request.config["y"]) == 4
assert request.config.y[0].name == "revenue" assert request.config["y"][0]["name"] == "revenue"
assert request.config.y[1].name == "cost" assert request.config["y"][1]["name"] == "cost"
assert request.config.y[2].name == "profit" assert request.config["y"][2]["name"] == "profit"
assert request.config.y[3].name == "orders" assert request.config["y"][3]["name"] == "orders"
assert request.config.y[3].aggregate == "COUNT" assert request.config["y"][3]["aggregate"] == "COUNT"
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_update_chart_preview_table_sorting(self): async def test_update_chart_preview_table_sorting(self):
@@ -470,5 +470,5 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest( request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config form_data_key="abc123", dataset_id=1, config=config
) )
assert request.config.sort_by == ["sales", "profit"] assert request.config["sort_by"] == ["sales", "profit"]
assert len(request.config.columns) == 3 assert len(request.config["columns"]) == 3