mirror of
https://github.com/apache/superset.git
synced 2026-04-07 02:21:51 +00:00
fix(mcp): compress chart config schemas to reduce search_tools token usage (#39018)
This commit is contained in:
@@ -36,6 +36,7 @@ from pydantic import (
|
||||
model_serializer,
|
||||
model_validator,
|
||||
PositiveInt,
|
||||
TypeAdapter,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
@@ -1145,7 +1146,7 @@ class XYChartConfig(UnknownFieldCheckMixin):
|
||||
return self
|
||||
|
||||
|
||||
# Discriminated union entry point with custom error handling
|
||||
# Discriminated union for runtime validation (not exposed in JSON Schema)
|
||||
ChartConfig = Annotated[
|
||||
XYChartConfig
|
||||
| TableChartConfig
|
||||
@@ -1164,6 +1165,66 @@ ChartConfig = Annotated[
|
||||
),
|
||||
]
|
||||
|
||||
# Module-level TypeAdapter avoids repeated schema compilation in
|
||||
# parse_chart_config() — safe because ChartConfig is fully defined above.
|
||||
_CHART_CONFIG_ADAPTER: TypeAdapter[ChartConfig] = TypeAdapter(ChartConfig)
|
||||
|
||||
# Compact description for JSON Schema — keeps tool inputSchema small while
|
||||
# giving LLMs enough context to construct valid configs.
|
||||
_CHART_CONFIG_DESCRIPTION = (
|
||||
"Chart configuration object. MUST include 'chart_type' to select the "
|
||||
"schema. Types: 'xy' (x, y, kind: line/bar/area/scatter), "
|
||||
"'table' (columns), 'pie' (dimension, metric), "
|
||||
"'pivot_table' (rows, metrics), 'mixed_timeseries' (x, y, y_secondary), "
|
||||
"'handlebars' (columns, handlebars_template), "
|
||||
"'big_number' (metric). "
|
||||
"See chart://configs resource for full field reference and examples."
|
||||
)
|
||||
|
||||
|
||||
def parse_chart_config(
|
||||
config: Dict[str, Any],
|
||||
) -> (
|
||||
XYChartConfig
|
||||
| TableChartConfig
|
||||
| PieChartConfig
|
||||
| PivotTableChartConfig
|
||||
| MixedTimeseriesChartConfig
|
||||
| HandlebarsChartConfig
|
||||
| BigNumberChartConfig
|
||||
):
|
||||
"""Parse a raw dict into the appropriate typed ChartConfig subclass.
|
||||
|
||||
Validates the dict against the discriminated union using chart_type.
|
||||
Call this in tool function bodies to get a typed config object.
|
||||
"""
|
||||
try:
|
||||
return _CHART_CONFIG_ADAPTER.validate_python(config)
|
||||
except Exception as e:
|
||||
raise ValueError(
|
||||
f"{e}\n\n"
|
||||
f"Hint: read the chart://configs resource for valid configuration "
|
||||
f"examples and field reference."
|
||||
) from e
|
||||
|
||||
|
||||
def _coerce_config_to_dict(v: Any) -> Dict[str, Any]:
|
||||
"""Accept ChartConfig objects, dicts, or JSON strings for the config field."""
|
||||
if isinstance(v, str):
|
||||
from superset.utils import json as json_utils
|
||||
|
||||
try:
|
||||
v = json_utils.loads(v)
|
||||
except (ValueError, TypeError) as exc:
|
||||
raise ValueError(
|
||||
f"config must be a JSON object string, got: {v!r}"
|
||||
) from exc
|
||||
if hasattr(v, "model_dump"):
|
||||
return v.model_dump()
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
raise TypeError(f"config must be a dict or JSON string, got {type(v).__name__}")
|
||||
|
||||
|
||||
class ListChartsRequest(MetadataCacheControl):
|
||||
"""Request schema for list_charts with clear, unambiguous types."""
|
||||
@@ -1259,7 +1320,7 @@ class ListChartsRequest(MetadataCacheControl):
|
||||
# The tool input models
|
||||
class GenerateChartRequest(QueryCacheControl):
|
||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="Chart configuration")
|
||||
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||
chart_name: str | None = Field(
|
||||
None, description="Auto-generates if omitted", max_length=255
|
||||
)
|
||||
@@ -1269,6 +1330,11 @@ class GenerateChartRequest(QueryCacheControl):
|
||||
default_factory=lambda: ["url"],
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||
return _coerce_config_to_dict(v)
|
||||
|
||||
@field_validator("chart_name")
|
||||
@classmethod
|
||||
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
||||
@@ -1301,16 +1367,20 @@ class GenerateChartRequest(QueryCacheControl):
|
||||
|
||||
class GenerateExploreLinkRequest(FormDataCacheControl):
|
||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||
config: ChartConfig = Field(..., description="Chart configuration")
|
||||
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||
return _coerce_config_to_dict(v)
|
||||
|
||||
|
||||
class UpdateChartRequest(QueryCacheControl):
|
||||
identifier: int | str = Field(..., description="Chart ID or UUID")
|
||||
config: ChartConfig | None = Field(
|
||||
config: Dict[str, Any] | None = Field(
|
||||
None,
|
||||
description=(
|
||||
"Chart configuration. Required for visualization changes. "
|
||||
"Omit to only update the chart name."
|
||||
f"{_CHART_CONFIG_DESCRIPTION} Optional; omit to only update chart_name."
|
||||
),
|
||||
)
|
||||
chart_name: str | None = Field(
|
||||
@@ -1321,6 +1391,13 @@ class UpdateChartRequest(QueryCacheControl):
|
||||
default_factory=lambda: ["url"],
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def coerce_config(cls, v: Any) -> Dict[str, Any] | None:
|
||||
if v is None:
|
||||
return None
|
||||
return _coerce_config_to_dict(v)
|
||||
|
||||
@field_validator("chart_name")
|
||||
@classmethod
|
||||
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
||||
@@ -1331,12 +1408,17 @@ class UpdateChartRequest(QueryCacheControl):
|
||||
class UpdateChartPreviewRequest(FormDataCacheControl):
|
||||
form_data_key: str = Field(..., description="Existing form_data_key to update")
|
||||
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
|
||||
config: ChartConfig
|
||||
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||
generate_preview: bool = True
|
||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||
default_factory=lambda: ["url"],
|
||||
)
|
||||
|
||||
@field_validator("config", mode="before")
|
||||
@classmethod
|
||||
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||
return _coerce_config_to_dict(v)
|
||||
|
||||
|
||||
class GetChartDataRequest(QueryCacheControl):
|
||||
"""Request for chart data with cache control.
|
||||
|
||||
@@ -43,6 +43,7 @@ from superset.mcp_service.chart.schemas import (
|
||||
ChartError,
|
||||
GenerateChartRequest,
|
||||
GenerateChartResponse,
|
||||
parse_chart_config,
|
||||
PerformanceMetadata,
|
||||
)
|
||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||
@@ -209,13 +210,17 @@ async def generate_chart( # noqa: C901
|
||||
"save_chart=%s, preview_formats=%s"
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.config.chart_type,
|
||||
request.config.get("chart_type", "unknown"),
|
||||
request.save_chart,
|
||||
request.preview_formats,
|
||||
)
|
||||
)
|
||||
await ctx.debug(
|
||||
"Chart configuration details: config=%s" % (request.config.model_dump(),)
|
||||
"Chart configuration details: chart_type=%s, keys=%s"
|
||||
% (
|
||||
request.config.get("chart_type", "unknown"),
|
||||
sorted(request.config.keys()),
|
||||
)
|
||||
)
|
||||
|
||||
# Track runtime warnings to include in response
|
||||
@@ -269,11 +274,12 @@ async def generate_chart( # noqa: C901
|
||||
}
|
||||
)
|
||||
|
||||
# Parse the raw config dict into a typed ChartConfig for downstream use
|
||||
config = parse_chart_config(request.config)
|
||||
|
||||
# Map the simplified config to Superset's form_data format
|
||||
# Pass dataset_id to enable column type checking for proper viz_type selection
|
||||
form_data = map_config_to_form_data(
|
||||
request.config, dataset_id=request.dataset_id
|
||||
)
|
||||
form_data = map_config_to_form_data(config, dataset_id=request.dataset_id)
|
||||
|
||||
chart = None
|
||||
chart_id = None
|
||||
@@ -367,7 +373,7 @@ async def generate_chart( # noqa: C901
|
||||
dataset, "table_name", None
|
||||
)
|
||||
chart_name = request.chart_name or generate_chart_name(
|
||||
request.config, dataset_name=dataset_name
|
||||
config, dataset_name=dataset_name
|
||||
)
|
||||
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
|
||||
|
||||
@@ -607,8 +613,8 @@ async def generate_chart( # noqa: C901
|
||||
response_warnings.extend(compile_result.warnings)
|
||||
|
||||
# Generate semantic analysis
|
||||
capabilities = analyze_chart_capabilities(chart, request.config)
|
||||
semantics = analyze_chart_semantics(chart, request.config)
|
||||
capabilities = analyze_chart_capabilities(chart, config)
|
||||
semantics = analyze_chart_semantics(chart, config)
|
||||
|
||||
# Create performance metadata
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
@@ -622,7 +628,7 @@ async def generate_chart( # noqa: C901
|
||||
chart_name = (
|
||||
chart.slice_name
|
||||
if chart and hasattr(chart, "slice_name")
|
||||
else generate_chart_name(request.config)
|
||||
else generate_chart_name(config)
|
||||
)
|
||||
accessibility = AccessibilityMetadata(
|
||||
color_blind_safe=True, # Would need actual analysis
|
||||
@@ -843,9 +849,9 @@ async def generate_chart( # noqa: C901
|
||||
# Extract chart_type from different sources for better error context
|
||||
chart_type = "unknown"
|
||||
try:
|
||||
if hasattr(request, "config") and hasattr(request.config, "chart_type"):
|
||||
chart_type = request.config.chart_type
|
||||
except AttributeError as extract_error:
|
||||
if hasattr(request, "config") and isinstance(request.config, dict):
|
||||
chart_type = request.config.get("chart_type", "unknown")
|
||||
except (AttributeError, TypeError) as extract_error:
|
||||
# Ignore errors when extracting chart type for error context
|
||||
logger.debug("Could not extract chart type: %s", extract_error)
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.chart_utils import (
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AccessibilityMetadata,
|
||||
GenerateChartResponse,
|
||||
parse_chart_config,
|
||||
PerformanceMetadata,
|
||||
UpdateChartRequest,
|
||||
)
|
||||
@@ -69,14 +70,15 @@ def _build_update_payload(
|
||||
when neither config nor chart_name is provided.
|
||||
"""
|
||||
if request.config is not None:
|
||||
config = parse_chart_config(request.config)
|
||||
dataset_id = chart.datasource_id if chart.datasource_id else None
|
||||
new_form_data = map_config_to_form_data(request.config, dataset_id=dataset_id)
|
||||
new_form_data = map_config_to_form_data(config, dataset_id=dataset_id)
|
||||
new_form_data.pop("_mcp_warnings", None)
|
||||
|
||||
chart_name = (
|
||||
request.chart_name
|
||||
if request.chart_name
|
||||
else chart.slice_name or generate_chart_name(request.config)
|
||||
else chart.slice_name or generate_chart_name(config)
|
||||
)
|
||||
|
||||
return {
|
||||
@@ -222,9 +224,12 @@ async def update_chart(
|
||||
command = UpdateChartCommand(chart.id, payload_or_error)
|
||||
updated_chart = command.run()
|
||||
|
||||
# Parse config for analysis (may be None for name-only updates)
|
||||
config = parse_chart_config(request.config) if request.config else None
|
||||
|
||||
# Generate semantic analysis
|
||||
capabilities = analyze_chart_capabilities(updated_chart, request.config)
|
||||
semantics = analyze_chart_semantics(updated_chart, request.config)
|
||||
capabilities = analyze_chart_capabilities(updated_chart, config)
|
||||
semantics = analyze_chart_semantics(updated_chart, config)
|
||||
|
||||
# Create performance metadata
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
@@ -238,11 +243,7 @@ async def update_chart(
|
||||
chart_name = (
|
||||
updated_chart.slice_name
|
||||
if updated_chart and hasattr(updated_chart, "slice_name")
|
||||
else (
|
||||
generate_chart_name(request.config)
|
||||
if request.config
|
||||
else "Updated chart"
|
||||
)
|
||||
else (generate_chart_name(config) if config else "Updated chart")
|
||||
)
|
||||
accessibility = AccessibilityMetadata(
|
||||
color_blind_safe=True, # Would need actual analysis
|
||||
|
||||
@@ -36,6 +36,7 @@ from superset.mcp_service.chart.chart_utils import (
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
AccessibilityMetadata,
|
||||
parse_chart_config,
|
||||
PerformanceMetadata,
|
||||
UpdateChartPreviewRequest,
|
||||
)
|
||||
@@ -95,20 +96,20 @@ def update_chart_preview(
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
# Parse the raw config dict into a typed ChartConfig
|
||||
config = parse_chart_config(request.config)
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart_preview.form_data"):
|
||||
# Map the new config to form_data format
|
||||
# Pass dataset_id to enable column type checking
|
||||
new_form_data = map_config_to_form_data(
|
||||
request.config, dataset_id=request.dataset_id
|
||||
config, dataset_id=request.dataset_id
|
||||
)
|
||||
new_form_data.pop("_mcp_warnings", None)
|
||||
|
||||
# Preserve adhoc filters from the previous cached form_data
|
||||
# when the new config doesn't explicitly specify filters
|
||||
if (
|
||||
getattr(request.config, "filters", None) is None
|
||||
and request.form_data_key
|
||||
):
|
||||
if getattr(config, "filters", None) is None and request.form_data_key:
|
||||
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
|
||||
if old_adhoc_filters:
|
||||
new_form_data["adhoc_filters"] = old_adhoc_filters
|
||||
@@ -123,8 +124,8 @@ def update_chart_preview(
|
||||
|
||||
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
|
||||
# Generate semantic analysis
|
||||
capabilities = analyze_chart_capabilities(None, request.config)
|
||||
semantics = analyze_chart_semantics(None, request.config)
|
||||
capabilities = analyze_chart_capabilities(None, config)
|
||||
semantics = analyze_chart_semantics(None, config)
|
||||
|
||||
# Create performance metadata
|
||||
execution_time = int((time.time() - start_time) * 1000)
|
||||
@@ -135,7 +136,7 @@ def update_chart_preview(
|
||||
)
|
||||
|
||||
# Create accessibility metadata
|
||||
chart_name = generate_chart_name(request.config)
|
||||
chart_name = generate_chart_name(config)
|
||||
accessibility = AccessibilityMetadata(
|
||||
color_blind_safe=True, # Would need actual analysis
|
||||
alt_text=f"Updated chart preview showing {chart_name}",
|
||||
|
||||
@@ -26,6 +26,7 @@ from typing import Any, Dict, List, Tuple
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ChartConfig,
|
||||
GenerateChartRequest,
|
||||
parse_chart_config,
|
||||
)
|
||||
from superset.mcp_service.common.error_schemas import (
|
||||
ChartGenerationError,
|
||||
@@ -171,6 +172,10 @@ class ValidationPipeline:
|
||||
if request is None:
|
||||
return ValidationResult(is_valid=False, error=error)
|
||||
|
||||
# Parse the raw config dict into a typed ChartConfig for
|
||||
# downstream validators that need typed access.
|
||||
typed_config = parse_chart_config(request.config)
|
||||
|
||||
# Fetch dataset context once and reuse across validation layers
|
||||
dataset_context = ValidationPipeline._get_dataset_context(
|
||||
request.dataset_id
|
||||
@@ -178,20 +183,20 @@ class ValidationPipeline:
|
||||
|
||||
# Layer 2: Dataset validation (reuses context)
|
||||
is_valid, error = ValidationPipeline._validate_dataset(
|
||||
request.config, request.dataset_id, dataset_context
|
||||
typed_config, request.dataset_id, dataset_context
|
||||
)
|
||||
if not is_valid:
|
||||
return ValidationResult(is_valid=False, request=request, error=error)
|
||||
|
||||
# Layer 3: Runtime validation - returns warnings as metadata, not errors
|
||||
_is_valid, warnings_metadata = ValidationPipeline._validate_runtime(
|
||||
request.config, request.dataset_id
|
||||
typed_config, request.dataset_id
|
||||
)
|
||||
# Runtime validation always returns True now, warnings are informational
|
||||
|
||||
# Layer 4: Column name normalization (reuses context)
|
||||
normalized_request = ValidationPipeline._normalize_column_names(
|
||||
request, dataset_context
|
||||
request, dataset_context, typed_config=typed_config
|
||||
)
|
||||
|
||||
return ValidationResult(
|
||||
@@ -284,6 +289,7 @@ class ValidationPipeline:
|
||||
def _normalize_column_names(
|
||||
request: GenerateChartRequest,
|
||||
dataset_context: DatasetContext | None = None,
|
||||
typed_config: ChartConfig | None = None,
|
||||
) -> GenerateChartRequest:
|
||||
"""
|
||||
Normalize column names in the request to match canonical dataset names.
|
||||
@@ -297,6 +303,8 @@ class ValidationPipeline:
|
||||
request: The validated chart generation request
|
||||
dataset_context: Pre-fetched dataset context to avoid duplicate
|
||||
DB queries. If None, fetches from the database.
|
||||
typed_config: Pre-parsed typed ChartConfig. If None, parses from
|
||||
request.config dict.
|
||||
|
||||
Returns:
|
||||
A new request with normalized column names
|
||||
@@ -304,8 +312,9 @@ class ValidationPipeline:
|
||||
try:
|
||||
from .dataset_validator import DatasetValidator
|
||||
|
||||
config = typed_config or parse_chart_config(request.config)
|
||||
normalized_config = DatasetValidator.normalize_column_names(
|
||||
request.config,
|
||||
config,
|
||||
request.dataset_id,
|
||||
dataset_context=dataset_context,
|
||||
)
|
||||
|
||||
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.chart_utils import (
|
||||
)
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
GenerateExploreLinkRequest,
|
||||
parse_chart_config,
|
||||
)
|
||||
|
||||
|
||||
@@ -89,7 +90,7 @@ async def generate_explore_link(
|
||||
"""
|
||||
await ctx.info(
|
||||
"Generating explore link for dataset_id=%s, chart_type=%s"
|
||||
% (request.dataset_id, request.config.chart_type)
|
||||
% (request.dataset_id, request.config.get("chart_type", "unknown"))
|
||||
)
|
||||
await ctx.debug(
|
||||
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
|
||||
@@ -97,6 +98,9 @@ async def generate_explore_link(
|
||||
)
|
||||
|
||||
try:
|
||||
# Parse the raw config dict into a typed ChartConfig
|
||||
config = parse_chart_config(request.config)
|
||||
|
||||
await ctx.report_progress(1, 4, "Validating dataset exists")
|
||||
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
|
||||
from superset.daos.dataset import DatasetDAO
|
||||
@@ -138,10 +142,10 @@ async def generate_explore_link(
|
||||
)
|
||||
|
||||
normalized_config = DatasetValidator.normalize_column_names(
|
||||
request.config, request.dataset_id
|
||||
config, request.dataset_id
|
||||
)
|
||||
except (ImportError, AttributeError, KeyError, ValueError, TypeError):
|
||||
normalized_config = request.config
|
||||
normalized_config = config
|
||||
|
||||
# Map config to form_data using shared utilities
|
||||
form_data = map_config_to_form_data(
|
||||
@@ -197,7 +201,12 @@ async def generate_explore_link(
|
||||
except Exception as e:
|
||||
await ctx.error(
|
||||
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
|
||||
% (request.dataset_id, request.config.chart_type, type(e).__name__, str(e))
|
||||
% (
|
||||
request.dataset_id,
|
||||
request.config.get("chart_type", "unknown"),
|
||||
type(e).__name__,
|
||||
str(e),
|
||||
)
|
||||
)
|
||||
return {
|
||||
"url": "",
|
||||
|
||||
@@ -24,6 +24,8 @@ from pydantic import ValidationError
|
||||
|
||||
from superset.mcp_service.chart.schemas import (
|
||||
ColumnRef,
|
||||
GenerateChartRequest,
|
||||
parse_chart_config,
|
||||
TableChartConfig,
|
||||
XYChartConfig,
|
||||
)
|
||||
@@ -625,3 +627,48 @@ class TestColumnRefSavedMetric:
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
class TestParseChartConfig:
|
||||
"""Tests for parse_chart_config and config coercion."""
|
||||
|
||||
def test_parse_valid_xy_config(self) -> None:
|
||||
config = parse_chart_config(
|
||||
{"chart_type": "xy", "x": {"name": "date"}, "y": [{"name": "v"}]}
|
||||
)
|
||||
assert config.chart_type == "xy"
|
||||
assert config.x.name == "date"
|
||||
assert len(config.y) == 1
|
||||
assert config.y[0].name == "v"
|
||||
|
||||
def test_parse_valid_table_config(self) -> None:
|
||||
config = parse_chart_config(
|
||||
{"chart_type": "table", "columns": [{"name": "col1"}]}
|
||||
)
|
||||
assert config.chart_type == "table"
|
||||
assert len(config.columns) == 1
|
||||
assert config.columns[0].name == "col1"
|
||||
|
||||
def test_parse_missing_chart_type_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="chart://configs"):
|
||||
parse_chart_config({"x": {"name": "date"}, "y": [{"name": "v"}]})
|
||||
|
||||
def test_parse_unknown_chart_type_raises(self) -> None:
|
||||
with pytest.raises(ValueError, match="chart://configs"):
|
||||
parse_chart_config({"chart_type": "nonexistent", "x": {"name": "d"}})
|
||||
|
||||
def test_coerce_json_string_config(self) -> None:
|
||||
raw = '{"chart_type": "table", "columns": [{"name": "c"}]}'
|
||||
req = GenerateChartRequest(dataset_id=1, config=raw)
|
||||
assert isinstance(req.config, dict)
|
||||
assert req.config["chart_type"] == "table"
|
||||
|
||||
def test_coerce_typed_config_object(self) -> None:
|
||||
typed = TableChartConfig(chart_type="table", columns=[ColumnRef(name="c")])
|
||||
req = GenerateChartRequest(dataset_id=1, config=typed)
|
||||
assert isinstance(req.config, dict)
|
||||
assert req.config["chart_type"] == "table"
|
||||
|
||||
def test_coerce_invalid_json_string_raises(self) -> None:
|
||||
with pytest.raises(ValidationError):
|
||||
GenerateChartRequest(dataset_id=1, config="not valid json")
|
||||
|
||||
@@ -57,10 +57,11 @@ class TestGenerateChart:
|
||||
)
|
||||
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
|
||||
assert table_request.dataset_id == "1"
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
assert table_request.config.columns[1].aggregate == "SUM"
|
||||
# config is now Dict[str, Any] in the schema; validate via dict access
|
||||
assert table_request.config["chart_type"] == "table"
|
||||
assert len(table_request.config["columns"]) == 2
|
||||
assert table_request.config["columns"][0]["name"] == "region"
|
||||
assert table_request.config["columns"][1]["aggregate"] == "SUM"
|
||||
|
||||
# XY chart request
|
||||
xy_config = XYChartConfig(
|
||||
@@ -74,12 +75,12 @@ class TestGenerateChart:
|
||||
legend=LegendConfig(show=True, position="top"),
|
||||
)
|
||||
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.y[0].aggregate == "SUM"
|
||||
assert xy_request.config.kind == "line"
|
||||
assert xy_request.config.x_axis.title == "Date"
|
||||
assert xy_request.config.legend.show is True
|
||||
assert xy_request.config["chart_type"] == "xy"
|
||||
assert xy_request.config["x"]["name"] == "date"
|
||||
assert xy_request.config["y"][0]["aggregate"] == "SUM"
|
||||
assert xy_request.config["kind"] == "line"
|
||||
assert xy_request.config["x_axis"]["title"] == "Date"
|
||||
assert xy_request.config["legend"]["show"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generate_chart_validation_error_handling(self):
|
||||
|
||||
@@ -60,10 +60,11 @@ class TestUpdateChart:
|
||||
)
|
||||
table_request = UpdateChartRequest(identifier=123, config=table_config)
|
||||
assert table_request.identifier == 123
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
assert table_request.config.columns[1].aggregate == "SUM"
|
||||
# config is now Dict[str, Any] in the schema; validate via dict access
|
||||
assert table_request.config["chart_type"] == "table"
|
||||
assert len(table_request.config["columns"]) == 2
|
||||
assert table_request.config["columns"][0]["name"] == "region"
|
||||
assert table_request.config["columns"][1]["aggregate"] == "SUM"
|
||||
|
||||
# XY chart update with UUID
|
||||
xy_config = XYChartConfig(
|
||||
@@ -80,10 +81,10 @@ class TestUpdateChart:
|
||||
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
|
||||
)
|
||||
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.y[0].aggregate == "SUM"
|
||||
assert xy_request.config.kind == "line"
|
||||
assert xy_request.config["chart_type"] == "xy"
|
||||
assert xy_request.config["x"]["name"] == "date"
|
||||
assert xy_request.config["y"][0]["aggregate"] == "SUM"
|
||||
assert xy_request.config["kind"] == "line"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_with_chart_name(self):
|
||||
@@ -170,7 +171,7 @@ class TestUpdateChart:
|
||||
kind=chart_type,
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.kind == chart_type
|
||||
assert request.config["kind"] == chart_type
|
||||
|
||||
# Test multiple Y columns
|
||||
multi_y_config = XYChartConfig(
|
||||
@@ -184,8 +185,8 @@ class TestUpdateChart:
|
||||
kind="line",
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=multi_y_config)
|
||||
assert len(request.config.y) == 3
|
||||
assert request.config.y[1].aggregate == "AVG"
|
||||
assert len(request.config["y"]) == 3
|
||||
assert request.config["y"][1]["aggregate"] == "AVG"
|
||||
|
||||
# Test filter operators
|
||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||
@@ -196,7 +197,7 @@ class TestUpdateChart:
|
||||
filters=filters,
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=table_config)
|
||||
assert len(request.config.filters) == 6
|
||||
assert len(request.config["filters"]) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_response_structure(self):
|
||||
@@ -252,12 +253,12 @@ class TestUpdateChart:
|
||||
),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.x_axis.title == "Date"
|
||||
assert request.config.x_axis.format == "smart_date"
|
||||
assert request.config.x_axis.scale == "linear"
|
||||
assert request.config.y_axis.title == "Sales Amount"
|
||||
assert request.config.y_axis.format == "$,.2f"
|
||||
assert request.config.y_axis.scale == "log"
|
||||
assert request.config["x_axis"]["title"] == "Date"
|
||||
assert request.config["x_axis"]["format"] == "smart_date"
|
||||
assert request.config["x_axis"]["scale"] == "linear"
|
||||
assert request.config["y_axis"]["title"] == "Sales Amount"
|
||||
assert request.config["y_axis"]["format"] == "$,.2f"
|
||||
assert request.config["y_axis"]["scale"] == "log"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_legend_configurations(self):
|
||||
@@ -271,8 +272,8 @@ class TestUpdateChart:
|
||||
legend=LegendConfig(show=True, position=pos),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.legend.position == pos
|
||||
assert request.config.legend.show is True
|
||||
assert request.config["legend"]["position"] == pos
|
||||
assert request.config["legend"]["show"] is True
|
||||
|
||||
# Hidden legend
|
||||
config = XYChartConfig(
|
||||
@@ -282,7 +283,7 @@ class TestUpdateChart:
|
||||
legend=LegendConfig(show=False),
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.legend.show is False
|
||||
assert request.config["legend"]["show"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_aggregation_functions(self):
|
||||
@@ -294,7 +295,7 @@ class TestUpdateChart:
|
||||
columns=[ColumnRef(name="value", aggregate=agg)],
|
||||
)
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert request.config.columns[0].aggregate == agg
|
||||
assert request.config["columns"][0]["aggregate"] == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_error_responses(self):
|
||||
@@ -378,10 +379,10 @@ class TestUpdateChart:
|
||||
)
|
||||
|
||||
request = UpdateChartRequest(identifier=1, config=config)
|
||||
assert len(request.config.filters) == 3
|
||||
assert request.config.filters[0].column == "region"
|
||||
assert request.config.filters[1].op == ">="
|
||||
assert request.config.filters[2].value == "2024-01-01"
|
||||
assert len(request.config["filters"]) == 3
|
||||
assert request.config["filters"][0]["column"] == "region"
|
||||
assert request.config["filters"][1]["op"] == ">="
|
||||
assert request.config["filters"][2]["value"] == "2024-01-01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_cache_control(self):
|
||||
|
||||
@@ -53,9 +53,9 @@ class TestUpdateChartPreview:
|
||||
)
|
||||
assert table_request.form_data_key == "abc123def456"
|
||||
assert table_request.dataset_id == 1
|
||||
assert table_request.config.chart_type == "table"
|
||||
assert len(table_request.config.columns) == 2
|
||||
assert table_request.config.columns[0].name == "region"
|
||||
assert table_request.config["chart_type"] == "table"
|
||||
assert len(table_request.config["columns"]) == 2
|
||||
assert table_request.config["columns"][0]["name"] == "region"
|
||||
|
||||
# XY chart preview update
|
||||
xy_config = XYChartConfig(
|
||||
@@ -73,9 +73,9 @@ class TestUpdateChartPreview:
|
||||
)
|
||||
assert xy_request.form_data_key == "xyz789ghi012"
|
||||
assert xy_request.dataset_id == "2"
|
||||
assert xy_request.config.chart_type == "xy"
|
||||
assert xy_request.config.x.name == "date"
|
||||
assert xy_request.config.kind == "line"
|
||||
assert xy_request.config["chart_type"] == "xy"
|
||||
assert xy_request.config["x"]["name"] == "date"
|
||||
assert xy_request.config["kind"] == "line"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_dataset_id_types(self):
|
||||
@@ -158,7 +158,7 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.kind == chart_type
|
||||
assert request.config["kind"] == chart_type
|
||||
|
||||
# Test multiple Y columns
|
||||
multi_y_config = XYChartConfig(
|
||||
@@ -174,8 +174,8 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=multi_y_config
|
||||
)
|
||||
assert len(request.config.y) == 3
|
||||
assert request.config.y[1].aggregate == "AVG"
|
||||
assert len(request.config["y"]) == 3
|
||||
assert request.config["y"][1]["aggregate"] == "AVG"
|
||||
|
||||
# Test filter operators
|
||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||
@@ -188,7 +188,7 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=table_config
|
||||
)
|
||||
assert len(request.config.filters) == 6
|
||||
assert len(request.config["filters"]) == 6
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_response_structure(self):
|
||||
@@ -251,10 +251,10 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.x_axis.title == "Date"
|
||||
assert request.config.x_axis.format == "smart_date"
|
||||
assert request.config.y_axis.title == "Sales Amount"
|
||||
assert request.config.y_axis.format == "$,.2f"
|
||||
assert request.config["x_axis"]["title"] == "Date"
|
||||
assert request.config["x_axis"]["format"] == "smart_date"
|
||||
assert request.config["y_axis"]["title"] == "Sales Amount"
|
||||
assert request.config["y_axis"]["format"] == "$,.2f"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_legend_configurations(self):
|
||||
@@ -270,8 +270,8 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.position == pos
|
||||
assert request.config.legend.show is True
|
||||
assert request.config["legend"]["position"] == pos
|
||||
assert request.config["legend"]["show"] is True
|
||||
|
||||
# Hidden legend
|
||||
config = XYChartConfig(
|
||||
@@ -283,7 +283,7 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.legend.show is False
|
||||
assert request.config["legend"]["show"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_aggregation_functions(self):
|
||||
@@ -297,7 +297,7 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.columns[0].aggregate == agg
|
||||
assert request.config["columns"][0]["aggregate"] == agg
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_error_responses(self):
|
||||
@@ -347,10 +347,10 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.filters) == 3
|
||||
assert request.config.filters[0].column == "region"
|
||||
assert request.config.filters[1].op == ">="
|
||||
assert request.config.filters[2].value == "2024-01-01"
|
||||
assert len(request.config["filters"]) == 3
|
||||
assert request.config["filters"][0]["column"] == "region"
|
||||
assert request.config["filters"][1]["op"] == ">="
|
||||
assert request.config["filters"][2]["value"] == "2024-01-01"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_form_data_key_handling(self):
|
||||
@@ -447,12 +447,12 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert len(request.config.y) == 4
|
||||
assert request.config.y[0].name == "revenue"
|
||||
assert request.config.y[1].name == "cost"
|
||||
assert request.config.y[2].name == "profit"
|
||||
assert request.config.y[3].name == "orders"
|
||||
assert request.config.y[3].aggregate == "COUNT"
|
||||
assert len(request.config["y"]) == 4
|
||||
assert request.config["y"][0]["name"] == "revenue"
|
||||
assert request.config["y"][1]["name"] == "cost"
|
||||
assert request.config["y"][2]["name"] == "profit"
|
||||
assert request.config["y"][3]["name"] == "orders"
|
||||
assert request.config["y"][3]["aggregate"] == "COUNT"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_chart_preview_table_sorting(self):
|
||||
@@ -470,5 +470,5 @@ class TestUpdateChartPreview:
|
||||
request = UpdateChartPreviewRequest(
|
||||
form_data_key="abc123", dataset_id=1, config=config
|
||||
)
|
||||
assert request.config.sort_by == ["sales", "profit"]
|
||||
assert len(request.config.columns) == 3
|
||||
assert request.config["sort_by"] == ["sales", "profit"]
|
||||
assert len(request.config["columns"]) == 3
|
||||
|
||||
Reference in New Issue
Block a user