fix(mcp): compress chart config schemas to reduce search_tools token usage (#39018)

This commit is contained in:
Amin Ghadersohi
2026-04-06 19:52:03 -04:00
committed by GitHub
parent b05764d070
commit bf9aff19b5
10 changed files with 267 additions and 110 deletions

View File

@@ -36,6 +36,7 @@ from pydantic import (
model_serializer,
model_validator,
PositiveInt,
TypeAdapter,
)
from typing_extensions import Self
@@ -1145,7 +1146,7 @@ class XYChartConfig(UnknownFieldCheckMixin):
return self
# Discriminated union entry point with custom error handling
# Discriminated union for runtime validation (not exposed in JSON Schema)
ChartConfig = Annotated[
XYChartConfig
| TableChartConfig
@@ -1164,6 +1165,66 @@ ChartConfig = Annotated[
),
]
# Module-level TypeAdapter avoids repeated schema compilation in
# parse_chart_config() — safe because ChartConfig is fully defined above.
_CHART_CONFIG_ADAPTER: TypeAdapter[ChartConfig] = TypeAdapter(ChartConfig)
# Compact description for JSON Schema — keeps tool inputSchema small while
# giving LLMs enough context to construct valid configs.
_CHART_CONFIG_DESCRIPTION = (
"Chart configuration object. MUST include 'chart_type' to select the "
"schema. Types: 'xy' (x, y, kind: line/bar/area/scatter), "
"'table' (columns), 'pie' (dimension, metric), "
"'pivot_table' (rows, metrics), 'mixed_timeseries' (x, y, y_secondary), "
"'handlebars' (columns, handlebars_template), "
"'big_number' (metric). "
"See chart://configs resource for full field reference and examples."
)
def parse_chart_config(
config: Dict[str, Any],
) -> (
XYChartConfig
| TableChartConfig
| PieChartConfig
| PivotTableChartConfig
| MixedTimeseriesChartConfig
| HandlebarsChartConfig
| BigNumberChartConfig
):
"""Parse a raw dict into the appropriate typed ChartConfig subclass.
Validates the dict against the discriminated union using chart_type.
Call this in tool function bodies to get a typed config object.
"""
try:
return _CHART_CONFIG_ADAPTER.validate_python(config)
except Exception as e:
raise ValueError(
f"{e}\n\n"
f"Hint: read the chart://configs resource for valid configuration "
f"examples and field reference."
) from e
def _coerce_config_to_dict(v: Any) -> Dict[str, Any]:
"""Accept ChartConfig objects, dicts, or JSON strings for the config field."""
if isinstance(v, str):
from superset.utils import json as json_utils
try:
v = json_utils.loads(v)
except (ValueError, TypeError) as exc:
raise ValueError(
f"config must be a JSON object string, got: {v!r}"
) from exc
if hasattr(v, "model_dump"):
return v.model_dump()
if isinstance(v, dict):
return v
raise TypeError(f"config must be a dict or JSON string, got {type(v).__name__}")
class ListChartsRequest(MetadataCacheControl):
"""Request schema for list_charts with clear, unambiguous types."""
@@ -1259,7 +1320,7 @@ class ListChartsRequest(MetadataCacheControl):
# The tool input models
class GenerateChartRequest(QueryCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
chart_name: str | None = Field(
None, description="Auto-generates if omitted", max_length=255
)
@@ -1269,6 +1330,11 @@ class GenerateChartRequest(QueryCacheControl):
default_factory=lambda: ["url"],
)
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
@field_validator("chart_name")
@classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None:
@@ -1301,16 +1367,20 @@ class GenerateChartRequest(QueryCacheControl):
class GenerateExploreLinkRequest(FormDataCacheControl):
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
config: ChartConfig = Field(..., description="Chart configuration")
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
class UpdateChartRequest(QueryCacheControl):
identifier: int | str = Field(..., description="Chart ID or UUID")
config: ChartConfig | None = Field(
config: Dict[str, Any] | None = Field(
None,
description=(
"Chart configuration. Required for visualization changes. "
"Omit to only update the chart name."
f"{_CHART_CONFIG_DESCRIPTION} Optional; omit to only update chart_name."
),
)
chart_name: str | None = Field(
@@ -1321,6 +1391,13 @@ class UpdateChartRequest(QueryCacheControl):
default_factory=lambda: ["url"],
)
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any] | None:
if v is None:
return None
return _coerce_config_to_dict(v)
@field_validator("chart_name")
@classmethod
def sanitize_chart_name(cls, v: str | None) -> str | None:
@@ -1331,12 +1408,17 @@ class UpdateChartRequest(QueryCacheControl):
class UpdateChartPreviewRequest(FormDataCacheControl):
form_data_key: str = Field(..., description="Existing form_data_key to update")
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
config: ChartConfig
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
generate_preview: bool = True
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
default_factory=lambda: ["url"],
)
@field_validator("config", mode="before")
@classmethod
def coerce_config(cls, v: Any) -> Dict[str, Any]:
return _coerce_config_to_dict(v)
class GetChartDataRequest(QueryCacheControl):
"""Request for chart data with cache control.

View File

@@ -43,6 +43,7 @@ from superset.mcp_service.chart.schemas import (
ChartError,
GenerateChartRequest,
GenerateChartResponse,
parse_chart_config,
PerformanceMetadata,
)
from superset.mcp_service.utils.url_utils import get_superset_base_url
@@ -209,13 +210,17 @@ async def generate_chart( # noqa: C901
"save_chart=%s, preview_formats=%s"
% (
request.dataset_id,
request.config.chart_type,
request.config.get("chart_type", "unknown"),
request.save_chart,
request.preview_formats,
)
)
await ctx.debug(
"Chart configuration details: config=%s" % (request.config.model_dump(),)
"Chart configuration details: chart_type=%s, keys=%s"
% (
request.config.get("chart_type", "unknown"),
sorted(request.config.keys()),
)
)
# Track runtime warnings to include in response
@@ -269,11 +274,12 @@ async def generate_chart( # noqa: C901
}
)
# Parse the raw config dict into a typed ChartConfig for downstream use
config = parse_chart_config(request.config)
# Map the simplified config to Superset's form_data format
# Pass dataset_id to enable column type checking for proper viz_type selection
form_data = map_config_to_form_data(
request.config, dataset_id=request.dataset_id
)
form_data = map_config_to_form_data(config, dataset_id=request.dataset_id)
chart = None
chart_id = None
@@ -367,7 +373,7 @@ async def generate_chart( # noqa: C901
dataset, "table_name", None
)
chart_name = request.chart_name or generate_chart_name(
request.config, dataset_name=dataset_name
config, dataset_name=dataset_name
)
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
@@ -607,8 +613,8 @@ async def generate_chart( # noqa: C901
response_warnings.extend(compile_result.warnings)
# Generate semantic analysis
capabilities = analyze_chart_capabilities(chart, request.config)
semantics = analyze_chart_semantics(chart, request.config)
capabilities = analyze_chart_capabilities(chart, config)
semantics = analyze_chart_semantics(chart, config)
# Create performance metadata
execution_time = int((time.time() - start_time) * 1000)
@@ -622,7 +628,7 @@ async def generate_chart( # noqa: C901
chart_name = (
chart.slice_name
if chart and hasattr(chart, "slice_name")
else generate_chart_name(request.config)
else generate_chart_name(config)
)
accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis
@@ -843,9 +849,9 @@ async def generate_chart( # noqa: C901
# Extract chart_type from different sources for better error context
chart_type = "unknown"
try:
if hasattr(request, "config") and hasattr(request.config, "chart_type"):
chart_type = request.config.chart_type
except AttributeError as extract_error:
if hasattr(request, "config") and isinstance(request.config, dict):
chart_type = request.config.get("chart_type", "unknown")
except (AttributeError, TypeError) as extract_error:
# Ignore errors when extracting chart type for error context
logger.debug("Could not extract chart type: %s", extract_error)

View File

@@ -38,6 +38,7 @@ from superset.mcp_service.chart.chart_utils import (
from superset.mcp_service.chart.schemas import (
AccessibilityMetadata,
GenerateChartResponse,
parse_chart_config,
PerformanceMetadata,
UpdateChartRequest,
)
@@ -69,14 +70,15 @@ def _build_update_payload(
when neither config nor chart_name is provided.
"""
if request.config is not None:
config = parse_chart_config(request.config)
dataset_id = chart.datasource_id if chart.datasource_id else None
new_form_data = map_config_to_form_data(request.config, dataset_id=dataset_id)
new_form_data = map_config_to_form_data(config, dataset_id=dataset_id)
new_form_data.pop("_mcp_warnings", None)
chart_name = (
request.chart_name
if request.chart_name
else chart.slice_name or generate_chart_name(request.config)
else chart.slice_name or generate_chart_name(config)
)
return {
@@ -222,9 +224,12 @@ async def update_chart(
command = UpdateChartCommand(chart.id, payload_or_error)
updated_chart = command.run()
# Parse config for analysis (may be None for name-only updates)
config = parse_chart_config(request.config) if request.config else None
# Generate semantic analysis
capabilities = analyze_chart_capabilities(updated_chart, request.config)
semantics = analyze_chart_semantics(updated_chart, request.config)
capabilities = analyze_chart_capabilities(updated_chart, config)
semantics = analyze_chart_semantics(updated_chart, config)
# Create performance metadata
execution_time = int((time.time() - start_time) * 1000)
@@ -238,11 +243,7 @@ async def update_chart(
chart_name = (
updated_chart.slice_name
if updated_chart and hasattr(updated_chart, "slice_name")
else (
generate_chart_name(request.config)
if request.config
else "Updated chart"
)
else (generate_chart_name(config) if config else "Updated chart")
)
accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis

View File

@@ -36,6 +36,7 @@ from superset.mcp_service.chart.chart_utils import (
)
from superset.mcp_service.chart.schemas import (
AccessibilityMetadata,
parse_chart_config,
PerformanceMetadata,
UpdateChartPreviewRequest,
)
@@ -95,20 +96,20 @@ def update_chart_preview(
start_time = time.time()
try:
# Parse the raw config dict into a typed ChartConfig
config = parse_chart_config(request.config)
with event_logger.log_context(action="mcp.update_chart_preview.form_data"):
# Map the new config to form_data format
# Pass dataset_id to enable column type checking
new_form_data = map_config_to_form_data(
request.config, dataset_id=request.dataset_id
config, dataset_id=request.dataset_id
)
new_form_data.pop("_mcp_warnings", None)
# Preserve adhoc filters from the previous cached form_data
# when the new config doesn't explicitly specify filters
if (
getattr(request.config, "filters", None) is None
and request.form_data_key
):
if getattr(config, "filters", None) is None and request.form_data_key:
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
if old_adhoc_filters:
new_form_data["adhoc_filters"] = old_adhoc_filters
@@ -123,8 +124,8 @@ def update_chart_preview(
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
# Generate semantic analysis
capabilities = analyze_chart_capabilities(None, request.config)
semantics = analyze_chart_semantics(None, request.config)
capabilities = analyze_chart_capabilities(None, config)
semantics = analyze_chart_semantics(None, config)
# Create performance metadata
execution_time = int((time.time() - start_time) * 1000)
@@ -135,7 +136,7 @@ def update_chart_preview(
)
# Create accessibility metadata
chart_name = generate_chart_name(request.config)
chart_name = generate_chart_name(config)
accessibility = AccessibilityMetadata(
color_blind_safe=True, # Would need actual analysis
alt_text=f"Updated chart preview showing {chart_name}",

View File

@@ -26,6 +26,7 @@ from typing import Any, Dict, List, Tuple
from superset.mcp_service.chart.schemas import (
ChartConfig,
GenerateChartRequest,
parse_chart_config,
)
from superset.mcp_service.common.error_schemas import (
ChartGenerationError,
@@ -171,6 +172,10 @@ class ValidationPipeline:
if request is None:
return ValidationResult(is_valid=False, error=error)
# Parse the raw config dict into a typed ChartConfig for
# downstream validators that need typed access.
typed_config = parse_chart_config(request.config)
# Fetch dataset context once and reuse across validation layers
dataset_context = ValidationPipeline._get_dataset_context(
request.dataset_id
@@ -178,20 +183,20 @@ class ValidationPipeline:
# Layer 2: Dataset validation (reuses context)
is_valid, error = ValidationPipeline._validate_dataset(
request.config, request.dataset_id, dataset_context
typed_config, request.dataset_id, dataset_context
)
if not is_valid:
return ValidationResult(is_valid=False, request=request, error=error)
# Layer 3: Runtime validation - returns warnings as metadata, not errors
_is_valid, warnings_metadata = ValidationPipeline._validate_runtime(
request.config, request.dataset_id
typed_config, request.dataset_id
)
# Runtime validation always returns True now, warnings are informational
# Layer 4: Column name normalization (reuses context)
normalized_request = ValidationPipeline._normalize_column_names(
request, dataset_context
request, dataset_context, typed_config=typed_config
)
return ValidationResult(
@@ -284,6 +289,7 @@ class ValidationPipeline:
def _normalize_column_names(
request: GenerateChartRequest,
dataset_context: DatasetContext | None = None,
typed_config: ChartConfig | None = None,
) -> GenerateChartRequest:
"""
Normalize column names in the request to match canonical dataset names.
@@ -297,6 +303,8 @@ class ValidationPipeline:
request: The validated chart generation request
dataset_context: Pre-fetched dataset context to avoid duplicate
DB queries. If None, fetches from the database.
typed_config: Pre-parsed typed ChartConfig. If None, parses from
request.config dict.
Returns:
A new request with normalized column names
@@ -304,8 +312,9 @@ class ValidationPipeline:
try:
from .dataset_validator import DatasetValidator
config = typed_config or parse_chart_config(request.config)
normalized_config = DatasetValidator.normalize_column_names(
request.config,
config,
request.dataset_id,
dataset_context=dataset_context,
)

View File

@@ -35,6 +35,7 @@ from superset.mcp_service.chart.chart_utils import (
)
from superset.mcp_service.chart.schemas import (
GenerateExploreLinkRequest,
parse_chart_config,
)
@@ -89,7 +90,7 @@ async def generate_explore_link(
"""
await ctx.info(
"Generating explore link for dataset_id=%s, chart_type=%s"
% (request.dataset_id, request.config.chart_type)
% (request.dataset_id, request.config.get("chart_type", "unknown"))
)
await ctx.debug(
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
@@ -97,6 +98,9 @@ async def generate_explore_link(
)
try:
# Parse the raw config dict into a typed ChartConfig
config = parse_chart_config(request.config)
await ctx.report_progress(1, 4, "Validating dataset exists")
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
from superset.daos.dataset import DatasetDAO
@@ -138,10 +142,10 @@ async def generate_explore_link(
)
normalized_config = DatasetValidator.normalize_column_names(
request.config, request.dataset_id
config, request.dataset_id
)
except (ImportError, AttributeError, KeyError, ValueError, TypeError):
normalized_config = request.config
normalized_config = config
# Map config to form_data using shared utilities
form_data = map_config_to_form_data(
@@ -197,7 +201,12 @@ async def generate_explore_link(
except Exception as e:
await ctx.error(
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
% (request.dataset_id, request.config.chart_type, type(e).__name__, str(e))
% (
request.dataset_id,
request.config.get("chart_type", "unknown"),
type(e).__name__,
str(e),
)
)
return {
"url": "",

View File

@@ -24,6 +24,8 @@ from pydantic import ValidationError
from superset.mcp_service.chart.schemas import (
ColumnRef,
GenerateChartRequest,
parse_chart_config,
TableChartConfig,
XYChartConfig,
)
@@ -625,3 +627,48 @@ class TestColumnRefSavedMetric:
),
],
)
class TestParseChartConfig:
"""Tests for parse_chart_config and config coercion."""
def test_parse_valid_xy_config(self) -> None:
config = parse_chart_config(
{"chart_type": "xy", "x": {"name": "date"}, "y": [{"name": "v"}]}
)
assert config.chart_type == "xy"
assert config.x.name == "date"
assert len(config.y) == 1
assert config.y[0].name == "v"
def test_parse_valid_table_config(self) -> None:
config = parse_chart_config(
{"chart_type": "table", "columns": [{"name": "col1"}]}
)
assert config.chart_type == "table"
assert len(config.columns) == 1
assert config.columns[0].name == "col1"
def test_parse_missing_chart_type_raises(self) -> None:
with pytest.raises(ValueError, match="chart://configs"):
parse_chart_config({"x": {"name": "date"}, "y": [{"name": "v"}]})
def test_parse_unknown_chart_type_raises(self) -> None:
with pytest.raises(ValueError, match="chart://configs"):
parse_chart_config({"chart_type": "nonexistent", "x": {"name": "d"}})
def test_coerce_json_string_config(self) -> None:
raw = '{"chart_type": "table", "columns": [{"name": "c"}]}'
req = GenerateChartRequest(dataset_id=1, config=raw)
assert isinstance(req.config, dict)
assert req.config["chart_type"] == "table"
def test_coerce_typed_config_object(self) -> None:
typed = TableChartConfig(chart_type="table", columns=[ColumnRef(name="c")])
req = GenerateChartRequest(dataset_id=1, config=typed)
assert isinstance(req.config, dict)
assert req.config["chart_type"] == "table"
def test_coerce_invalid_json_string_raises(self) -> None:
with pytest.raises(ValidationError):
GenerateChartRequest(dataset_id=1, config="not valid json")

View File

@@ -57,10 +57,11 @@ class TestGenerateChart:
)
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
assert table_request.dataset_id == "1"
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# config is now Dict[str, Any] in the schema; validate via dict access
assert table_request.config["chart_type"] == "table"
assert len(table_request.config["columns"]) == 2
assert table_request.config["columns"][0]["name"] == "region"
assert table_request.config["columns"][1]["aggregate"] == "SUM"
# XY chart request
xy_config = XYChartConfig(
@@ -74,12 +75,12 @@ class TestGenerateChart:
legend=LegendConfig(show=True, position="top"),
)
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
assert xy_request.config.x_axis.title == "Date"
assert xy_request.config.legend.show is True
assert xy_request.config["chart_type"] == "xy"
assert xy_request.config["x"]["name"] == "date"
assert xy_request.config["y"][0]["aggregate"] == "SUM"
assert xy_request.config["kind"] == "line"
assert xy_request.config["x_axis"]["title"] == "Date"
assert xy_request.config["legend"]["show"] is True
@pytest.mark.asyncio
async def test_generate_chart_validation_error_handling(self):

View File

@@ -60,10 +60,11 @@ class TestUpdateChart:
)
table_request = UpdateChartRequest(identifier=123, config=table_config)
assert table_request.identifier == 123
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config.columns[1].aggregate == "SUM"
# config is now Dict[str, Any] in the schema; validate via dict access
assert table_request.config["chart_type"] == "table"
assert len(table_request.config["columns"]) == 2
assert table_request.config["columns"][0]["name"] == "region"
assert table_request.config["columns"][1]["aggregate"] == "SUM"
# XY chart update with UUID
xy_config = XYChartConfig(
@@ -80,10 +81,10 @@ class TestUpdateChart:
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
)
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.y[0].aggregate == "SUM"
assert xy_request.config.kind == "line"
assert xy_request.config["chart_type"] == "xy"
assert xy_request.config["x"]["name"] == "date"
assert xy_request.config["y"][0]["aggregate"] == "SUM"
assert xy_request.config["kind"] == "line"
@pytest.mark.asyncio
async def test_update_chart_with_chart_name(self):
@@ -170,7 +171,7 @@ class TestUpdateChart:
kind=chart_type,
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.kind == chart_type
assert request.config["kind"] == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
@@ -184,8 +185,8 @@ class TestUpdateChart:
kind="line",
)
request = UpdateChartRequest(identifier=1, config=multi_y_config)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
assert len(request.config["y"]) == 3
assert request.config["y"][1]["aggregate"] == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
@@ -196,7 +197,7 @@ class TestUpdateChart:
filters=filters,
)
request = UpdateChartRequest(identifier=1, config=table_config)
assert len(request.config.filters) == 6
assert len(request.config["filters"]) == 6
@pytest.mark.asyncio
async def test_update_chart_response_structure(self):
@@ -252,12 +253,12 @@ class TestUpdateChart:
),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.x_axis.scale == "linear"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
assert request.config.y_axis.scale == "log"
assert request.config["x_axis"]["title"] == "Date"
assert request.config["x_axis"]["format"] == "smart_date"
assert request.config["x_axis"]["scale"] == "linear"
assert request.config["y_axis"]["title"] == "Sales Amount"
assert request.config["y_axis"]["format"] == "$,.2f"
assert request.config["y_axis"]["scale"] == "log"
@pytest.mark.asyncio
async def test_update_chart_legend_configurations(self):
@@ -271,8 +272,8 @@ class TestUpdateChart:
legend=LegendConfig(show=True, position=pos),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.position == pos
assert request.config.legend.show is True
assert request.config["legend"]["position"] == pos
assert request.config["legend"]["show"] is True
# Hidden legend
config = XYChartConfig(
@@ -282,7 +283,7 @@ class TestUpdateChart:
legend=LegendConfig(show=False),
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.legend.show is False
assert request.config["legend"]["show"] is False
@pytest.mark.asyncio
async def test_update_chart_aggregation_functions(self):
@@ -294,7 +295,7 @@ class TestUpdateChart:
columns=[ColumnRef(name="value", aggregate=agg)],
)
request = UpdateChartRequest(identifier=1, config=config)
assert request.config.columns[0].aggregate == agg
assert request.config["columns"][0]["aggregate"] == agg
@pytest.mark.asyncio
async def test_update_chart_error_responses(self):
@@ -378,10 +379,10 @@ class TestUpdateChart:
)
request = UpdateChartRequest(identifier=1, config=config)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
assert len(request.config["filters"]) == 3
assert request.config["filters"][0]["column"] == "region"
assert request.config["filters"][1]["op"] == ">="
assert request.config["filters"][2]["value"] == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_cache_control(self):

View File

@@ -53,9 +53,9 @@ class TestUpdateChartPreview:
)
assert table_request.form_data_key == "abc123def456"
assert table_request.dataset_id == 1
assert table_request.config.chart_type == "table"
assert len(table_request.config.columns) == 2
assert table_request.config.columns[0].name == "region"
assert table_request.config["chart_type"] == "table"
assert len(table_request.config["columns"]) == 2
assert table_request.config["columns"][0]["name"] == "region"
# XY chart preview update
xy_config = XYChartConfig(
@@ -73,9 +73,9 @@ class TestUpdateChartPreview:
)
assert xy_request.form_data_key == "xyz789ghi012"
assert xy_request.dataset_id == "2"
assert xy_request.config.chart_type == "xy"
assert xy_request.config.x.name == "date"
assert xy_request.config.kind == "line"
assert xy_request.config["chart_type"] == "xy"
assert xy_request.config["x"]["name"] == "date"
assert xy_request.config["kind"] == "line"
@pytest.mark.asyncio
async def test_update_chart_preview_dataset_id_types(self):
@@ -158,7 +158,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.kind == chart_type
assert request.config["kind"] == chart_type
# Test multiple Y columns
multi_y_config = XYChartConfig(
@@ -174,8 +174,8 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=multi_y_config
)
assert len(request.config.y) == 3
assert request.config.y[1].aggregate == "AVG"
assert len(request.config["y"]) == 3
assert request.config["y"][1]["aggregate"] == "AVG"
# Test filter operators
operators = ["=", "!=", ">", ">=", "<", "<="]
@@ -188,7 +188,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=table_config
)
assert len(request.config.filters) == 6
assert len(request.config["filters"]) == 6
@pytest.mark.asyncio
async def test_update_chart_preview_response_structure(self):
@@ -251,10 +251,10 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.x_axis.title == "Date"
assert request.config.x_axis.format == "smart_date"
assert request.config.y_axis.title == "Sales Amount"
assert request.config.y_axis.format == "$,.2f"
assert request.config["x_axis"]["title"] == "Date"
assert request.config["x_axis"]["format"] == "smart_date"
assert request.config["y_axis"]["title"] == "Sales Amount"
assert request.config["y_axis"]["format"] == "$,.2f"
@pytest.mark.asyncio
async def test_update_chart_preview_legend_configurations(self):
@@ -270,8 +270,8 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.position == pos
assert request.config.legend.show is True
assert request.config["legend"]["position"] == pos
assert request.config["legend"]["show"] is True
# Hidden legend
config = XYChartConfig(
@@ -283,7 +283,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.legend.show is False
assert request.config["legend"]["show"] is False
@pytest.mark.asyncio
async def test_update_chart_preview_aggregation_functions(self):
@@ -297,7 +297,7 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.columns[0].aggregate == agg
assert request.config["columns"][0]["aggregate"] == agg
@pytest.mark.asyncio
async def test_update_chart_preview_error_responses(self):
@@ -347,10 +347,10 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.filters) == 3
assert request.config.filters[0].column == "region"
assert request.config.filters[1].op == ">="
assert request.config.filters[2].value == "2024-01-01"
assert len(request.config["filters"]) == 3
assert request.config["filters"][0]["column"] == "region"
assert request.config["filters"][1]["op"] == ">="
assert request.config["filters"][2]["value"] == "2024-01-01"
@pytest.mark.asyncio
async def test_update_chart_preview_form_data_key_handling(self):
@@ -447,12 +447,12 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert len(request.config.y) == 4
assert request.config.y[0].name == "revenue"
assert request.config.y[1].name == "cost"
assert request.config.y[2].name == "profit"
assert request.config.y[3].name == "orders"
assert request.config.y[3].aggregate == "COUNT"
assert len(request.config["y"]) == 4
assert request.config["y"][0]["name"] == "revenue"
assert request.config["y"][1]["name"] == "cost"
assert request.config["y"][2]["name"] == "profit"
assert request.config["y"][3]["name"] == "orders"
assert request.config["y"][3]["aggregate"] == "COUNT"
@pytest.mark.asyncio
async def test_update_chart_preview_table_sorting(self):
@@ -470,5 +470,5 @@ class TestUpdateChartPreview:
request = UpdateChartPreviewRequest(
form_data_key="abc123", dataset_id=1, config=config
)
assert request.config.sort_by == ["sales", "profit"]
assert len(request.config.columns) == 3
assert request.config["sort_by"] == ["sales", "profit"]
assert len(request.config["columns"]) == 3