mirror of
https://github.com/apache/superset.git
synced 2026-04-07 02:21:51 +00:00
fix(mcp): compress chart config schemas to reduce search_tools token usage (#39018)
This commit is contained in:
@@ -36,6 +36,7 @@ from pydantic import (
|
|||||||
model_serializer,
|
model_serializer,
|
||||||
model_validator,
|
model_validator,
|
||||||
PositiveInt,
|
PositiveInt,
|
||||||
|
TypeAdapter,
|
||||||
)
|
)
|
||||||
from typing_extensions import Self
|
from typing_extensions import Self
|
||||||
|
|
||||||
@@ -1145,7 +1146,7 @@ class XYChartConfig(UnknownFieldCheckMixin):
|
|||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# Discriminated union entry point with custom error handling
|
# Discriminated union for runtime validation (not exposed in JSON Schema)
|
||||||
ChartConfig = Annotated[
|
ChartConfig = Annotated[
|
||||||
XYChartConfig
|
XYChartConfig
|
||||||
| TableChartConfig
|
| TableChartConfig
|
||||||
@@ -1164,6 +1165,66 @@ ChartConfig = Annotated[
|
|||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Module-level TypeAdapter avoids repeated schema compilation in
|
||||||
|
# parse_chart_config() — safe because ChartConfig is fully defined above.
|
||||||
|
_CHART_CONFIG_ADAPTER: TypeAdapter[ChartConfig] = TypeAdapter(ChartConfig)
|
||||||
|
|
||||||
|
# Compact description for JSON Schema — keeps tool inputSchema small while
|
||||||
|
# giving LLMs enough context to construct valid configs.
|
||||||
|
_CHART_CONFIG_DESCRIPTION = (
|
||||||
|
"Chart configuration object. MUST include 'chart_type' to select the "
|
||||||
|
"schema. Types: 'xy' (x, y, kind: line/bar/area/scatter), "
|
||||||
|
"'table' (columns), 'pie' (dimension, metric), "
|
||||||
|
"'pivot_table' (rows, metrics), 'mixed_timeseries' (x, y, y_secondary), "
|
||||||
|
"'handlebars' (columns, handlebars_template), "
|
||||||
|
"'big_number' (metric). "
|
||||||
|
"See chart://configs resource for full field reference and examples."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_chart_config(
|
||||||
|
config: Dict[str, Any],
|
||||||
|
) -> (
|
||||||
|
XYChartConfig
|
||||||
|
| TableChartConfig
|
||||||
|
| PieChartConfig
|
||||||
|
| PivotTableChartConfig
|
||||||
|
| MixedTimeseriesChartConfig
|
||||||
|
| HandlebarsChartConfig
|
||||||
|
| BigNumberChartConfig
|
||||||
|
):
|
||||||
|
"""Parse a raw dict into the appropriate typed ChartConfig subclass.
|
||||||
|
|
||||||
|
Validates the dict against the discriminated union using chart_type.
|
||||||
|
Call this in tool function bodies to get a typed config object.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
return _CHART_CONFIG_ADAPTER.validate_python(config)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"{e}\n\n"
|
||||||
|
f"Hint: read the chart://configs resource for valid configuration "
|
||||||
|
f"examples and field reference."
|
||||||
|
) from e
|
||||||
|
|
||||||
|
|
||||||
|
def _coerce_config_to_dict(v: Any) -> Dict[str, Any]:
|
||||||
|
"""Accept ChartConfig objects, dicts, or JSON strings for the config field."""
|
||||||
|
if isinstance(v, str):
|
||||||
|
from superset.utils import json as json_utils
|
||||||
|
|
||||||
|
try:
|
||||||
|
v = json_utils.loads(v)
|
||||||
|
except (ValueError, TypeError) as exc:
|
||||||
|
raise ValueError(
|
||||||
|
f"config must be a JSON object string, got: {v!r}"
|
||||||
|
) from exc
|
||||||
|
if hasattr(v, "model_dump"):
|
||||||
|
return v.model_dump()
|
||||||
|
if isinstance(v, dict):
|
||||||
|
return v
|
||||||
|
raise TypeError(f"config must be a dict or JSON string, got {type(v).__name__}")
|
||||||
|
|
||||||
|
|
||||||
class ListChartsRequest(MetadataCacheControl):
|
class ListChartsRequest(MetadataCacheControl):
|
||||||
"""Request schema for list_charts with clear, unambiguous types."""
|
"""Request schema for list_charts with clear, unambiguous types."""
|
||||||
@@ -1259,7 +1320,7 @@ class ListChartsRequest(MetadataCacheControl):
|
|||||||
# The tool input models
|
# The tool input models
|
||||||
class GenerateChartRequest(QueryCacheControl):
|
class GenerateChartRequest(QueryCacheControl):
|
||||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||||
config: ChartConfig = Field(..., description="Chart configuration")
|
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||||
chart_name: str | None = Field(
|
chart_name: str | None = Field(
|
||||||
None, description="Auto-generates if omitted", max_length=255
|
None, description="Auto-generates if omitted", max_length=255
|
||||||
)
|
)
|
||||||
@@ -1269,6 +1330,11 @@ class GenerateChartRequest(QueryCacheControl):
|
|||||||
default_factory=lambda: ["url"],
|
default_factory=lambda: ["url"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||||
|
return _coerce_config_to_dict(v)
|
||||||
|
|
||||||
@field_validator("chart_name")
|
@field_validator("chart_name")
|
||||||
@classmethod
|
@classmethod
|
||||||
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
||||||
@@ -1301,16 +1367,20 @@ class GenerateChartRequest(QueryCacheControl):
|
|||||||
|
|
||||||
class GenerateExploreLinkRequest(FormDataCacheControl):
|
class GenerateExploreLinkRequest(FormDataCacheControl):
|
||||||
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
dataset_id: int | str = Field(..., description="Dataset identifier (ID, UUID)")
|
||||||
config: ChartConfig = Field(..., description="Chart configuration")
|
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||||
|
|
||||||
|
@field_validator("config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||||
|
return _coerce_config_to_dict(v)
|
||||||
|
|
||||||
|
|
||||||
class UpdateChartRequest(QueryCacheControl):
|
class UpdateChartRequest(QueryCacheControl):
|
||||||
identifier: int | str = Field(..., description="Chart ID or UUID")
|
identifier: int | str = Field(..., description="Chart ID or UUID")
|
||||||
config: ChartConfig | None = Field(
|
config: Dict[str, Any] | None = Field(
|
||||||
None,
|
None,
|
||||||
description=(
|
description=(
|
||||||
"Chart configuration. Required for visualization changes. "
|
f"{_CHART_CONFIG_DESCRIPTION} Optional; omit to only update chart_name."
|
||||||
"Omit to only update the chart name."
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
chart_name: str | None = Field(
|
chart_name: str | None = Field(
|
||||||
@@ -1321,6 +1391,13 @@ class UpdateChartRequest(QueryCacheControl):
|
|||||||
default_factory=lambda: ["url"],
|
default_factory=lambda: ["url"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def coerce_config(cls, v: Any) -> Dict[str, Any] | None:
|
||||||
|
if v is None:
|
||||||
|
return None
|
||||||
|
return _coerce_config_to_dict(v)
|
||||||
|
|
||||||
@field_validator("chart_name")
|
@field_validator("chart_name")
|
||||||
@classmethod
|
@classmethod
|
||||||
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
def sanitize_chart_name(cls, v: str | None) -> str | None:
|
||||||
@@ -1331,12 +1408,17 @@ class UpdateChartRequest(QueryCacheControl):
|
|||||||
class UpdateChartPreviewRequest(FormDataCacheControl):
|
class UpdateChartPreviewRequest(FormDataCacheControl):
|
||||||
form_data_key: str = Field(..., description="Existing form_data_key to update")
|
form_data_key: str = Field(..., description="Existing form_data_key to update")
|
||||||
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
|
dataset_id: int | str = Field(..., description="Dataset ID or UUID")
|
||||||
config: ChartConfig
|
config: Dict[str, Any] = Field(..., description=_CHART_CONFIG_DESCRIPTION)
|
||||||
generate_preview: bool = True
|
generate_preview: bool = True
|
||||||
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
preview_formats: List[Literal["url", "ascii", "vega_lite", "table"]] = Field(
|
||||||
default_factory=lambda: ["url"],
|
default_factory=lambda: ["url"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@field_validator("config", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def coerce_config(cls, v: Any) -> Dict[str, Any]:
|
||||||
|
return _coerce_config_to_dict(v)
|
||||||
|
|
||||||
|
|
||||||
class GetChartDataRequest(QueryCacheControl):
|
class GetChartDataRequest(QueryCacheControl):
|
||||||
"""Request for chart data with cache control.
|
"""Request for chart data with cache control.
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from superset.mcp_service.chart.schemas import (
|
|||||||
ChartError,
|
ChartError,
|
||||||
GenerateChartRequest,
|
GenerateChartRequest,
|
||||||
GenerateChartResponse,
|
GenerateChartResponse,
|
||||||
|
parse_chart_config,
|
||||||
PerformanceMetadata,
|
PerformanceMetadata,
|
||||||
)
|
)
|
||||||
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
from superset.mcp_service.utils.url_utils import get_superset_base_url
|
||||||
@@ -209,13 +210,17 @@ async def generate_chart( # noqa: C901
|
|||||||
"save_chart=%s, preview_formats=%s"
|
"save_chart=%s, preview_formats=%s"
|
||||||
% (
|
% (
|
||||||
request.dataset_id,
|
request.dataset_id,
|
||||||
request.config.chart_type,
|
request.config.get("chart_type", "unknown"),
|
||||||
request.save_chart,
|
request.save_chart,
|
||||||
request.preview_formats,
|
request.preview_formats,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
await ctx.debug(
|
await ctx.debug(
|
||||||
"Chart configuration details: config=%s" % (request.config.model_dump(),)
|
"Chart configuration details: chart_type=%s, keys=%s"
|
||||||
|
% (
|
||||||
|
request.config.get("chart_type", "unknown"),
|
||||||
|
sorted(request.config.keys()),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Track runtime warnings to include in response
|
# Track runtime warnings to include in response
|
||||||
@@ -269,11 +274,12 @@ async def generate_chart( # noqa: C901
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Parse the raw config dict into a typed ChartConfig for downstream use
|
||||||
|
config = parse_chart_config(request.config)
|
||||||
|
|
||||||
# Map the simplified config to Superset's form_data format
|
# Map the simplified config to Superset's form_data format
|
||||||
# Pass dataset_id to enable column type checking for proper viz_type selection
|
# Pass dataset_id to enable column type checking for proper viz_type selection
|
||||||
form_data = map_config_to_form_data(
|
form_data = map_config_to_form_data(config, dataset_id=request.dataset_id)
|
||||||
request.config, dataset_id=request.dataset_id
|
|
||||||
)
|
|
||||||
|
|
||||||
chart = None
|
chart = None
|
||||||
chart_id = None
|
chart_id = None
|
||||||
@@ -367,7 +373,7 @@ async def generate_chart( # noqa: C901
|
|||||||
dataset, "table_name", None
|
dataset, "table_name", None
|
||||||
)
|
)
|
||||||
chart_name = request.chart_name or generate_chart_name(
|
chart_name = request.chart_name or generate_chart_name(
|
||||||
request.config, dataset_name=dataset_name
|
config, dataset_name=dataset_name
|
||||||
)
|
)
|
||||||
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
|
await ctx.debug("Chart name: chart_name=%s" % (chart_name,))
|
||||||
|
|
||||||
@@ -607,8 +613,8 @@ async def generate_chart( # noqa: C901
|
|||||||
response_warnings.extend(compile_result.warnings)
|
response_warnings.extend(compile_result.warnings)
|
||||||
|
|
||||||
# Generate semantic analysis
|
# Generate semantic analysis
|
||||||
capabilities = analyze_chart_capabilities(chart, request.config)
|
capabilities = analyze_chart_capabilities(chart, config)
|
||||||
semantics = analyze_chart_semantics(chart, request.config)
|
semantics = analyze_chart_semantics(chart, config)
|
||||||
|
|
||||||
# Create performance metadata
|
# Create performance metadata
|
||||||
execution_time = int((time.time() - start_time) * 1000)
|
execution_time = int((time.time() - start_time) * 1000)
|
||||||
@@ -622,7 +628,7 @@ async def generate_chart( # noqa: C901
|
|||||||
chart_name = (
|
chart_name = (
|
||||||
chart.slice_name
|
chart.slice_name
|
||||||
if chart and hasattr(chart, "slice_name")
|
if chart and hasattr(chart, "slice_name")
|
||||||
else generate_chart_name(request.config)
|
else generate_chart_name(config)
|
||||||
)
|
)
|
||||||
accessibility = AccessibilityMetadata(
|
accessibility = AccessibilityMetadata(
|
||||||
color_blind_safe=True, # Would need actual analysis
|
color_blind_safe=True, # Would need actual analysis
|
||||||
@@ -843,9 +849,9 @@ async def generate_chart( # noqa: C901
|
|||||||
# Extract chart_type from different sources for better error context
|
# Extract chart_type from different sources for better error context
|
||||||
chart_type = "unknown"
|
chart_type = "unknown"
|
||||||
try:
|
try:
|
||||||
if hasattr(request, "config") and hasattr(request.config, "chart_type"):
|
if hasattr(request, "config") and isinstance(request.config, dict):
|
||||||
chart_type = request.config.chart_type
|
chart_type = request.config.get("chart_type", "unknown")
|
||||||
except AttributeError as extract_error:
|
except (AttributeError, TypeError) as extract_error:
|
||||||
# Ignore errors when extracting chart type for error context
|
# Ignore errors when extracting chart type for error context
|
||||||
logger.debug("Could not extract chart type: %s", extract_error)
|
logger.debug("Could not extract chart type: %s", extract_error)
|
||||||
|
|
||||||
|
|||||||
@@ -38,6 +38,7 @@ from superset.mcp_service.chart.chart_utils import (
|
|||||||
from superset.mcp_service.chart.schemas import (
|
from superset.mcp_service.chart.schemas import (
|
||||||
AccessibilityMetadata,
|
AccessibilityMetadata,
|
||||||
GenerateChartResponse,
|
GenerateChartResponse,
|
||||||
|
parse_chart_config,
|
||||||
PerformanceMetadata,
|
PerformanceMetadata,
|
||||||
UpdateChartRequest,
|
UpdateChartRequest,
|
||||||
)
|
)
|
||||||
@@ -69,14 +70,15 @@ def _build_update_payload(
|
|||||||
when neither config nor chart_name is provided.
|
when neither config nor chart_name is provided.
|
||||||
"""
|
"""
|
||||||
if request.config is not None:
|
if request.config is not None:
|
||||||
|
config = parse_chart_config(request.config)
|
||||||
dataset_id = chart.datasource_id if chart.datasource_id else None
|
dataset_id = chart.datasource_id if chart.datasource_id else None
|
||||||
new_form_data = map_config_to_form_data(request.config, dataset_id=dataset_id)
|
new_form_data = map_config_to_form_data(config, dataset_id=dataset_id)
|
||||||
new_form_data.pop("_mcp_warnings", None)
|
new_form_data.pop("_mcp_warnings", None)
|
||||||
|
|
||||||
chart_name = (
|
chart_name = (
|
||||||
request.chart_name
|
request.chart_name
|
||||||
if request.chart_name
|
if request.chart_name
|
||||||
else chart.slice_name or generate_chart_name(request.config)
|
else chart.slice_name or generate_chart_name(config)
|
||||||
)
|
)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
@@ -222,9 +224,12 @@ async def update_chart(
|
|||||||
command = UpdateChartCommand(chart.id, payload_or_error)
|
command = UpdateChartCommand(chart.id, payload_or_error)
|
||||||
updated_chart = command.run()
|
updated_chart = command.run()
|
||||||
|
|
||||||
|
# Parse config for analysis (may be None for name-only updates)
|
||||||
|
config = parse_chart_config(request.config) if request.config else None
|
||||||
|
|
||||||
# Generate semantic analysis
|
# Generate semantic analysis
|
||||||
capabilities = analyze_chart_capabilities(updated_chart, request.config)
|
capabilities = analyze_chart_capabilities(updated_chart, config)
|
||||||
semantics = analyze_chart_semantics(updated_chart, request.config)
|
semantics = analyze_chart_semantics(updated_chart, config)
|
||||||
|
|
||||||
# Create performance metadata
|
# Create performance metadata
|
||||||
execution_time = int((time.time() - start_time) * 1000)
|
execution_time = int((time.time() - start_time) * 1000)
|
||||||
@@ -238,11 +243,7 @@ async def update_chart(
|
|||||||
chart_name = (
|
chart_name = (
|
||||||
updated_chart.slice_name
|
updated_chart.slice_name
|
||||||
if updated_chart and hasattr(updated_chart, "slice_name")
|
if updated_chart and hasattr(updated_chart, "slice_name")
|
||||||
else (
|
else (generate_chart_name(config) if config else "Updated chart")
|
||||||
generate_chart_name(request.config)
|
|
||||||
if request.config
|
|
||||||
else "Updated chart"
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
accessibility = AccessibilityMetadata(
|
accessibility = AccessibilityMetadata(
|
||||||
color_blind_safe=True, # Would need actual analysis
|
color_blind_safe=True, # Would need actual analysis
|
||||||
|
|||||||
@@ -36,6 +36,7 @@ from superset.mcp_service.chart.chart_utils import (
|
|||||||
)
|
)
|
||||||
from superset.mcp_service.chart.schemas import (
|
from superset.mcp_service.chart.schemas import (
|
||||||
AccessibilityMetadata,
|
AccessibilityMetadata,
|
||||||
|
parse_chart_config,
|
||||||
PerformanceMetadata,
|
PerformanceMetadata,
|
||||||
UpdateChartPreviewRequest,
|
UpdateChartPreviewRequest,
|
||||||
)
|
)
|
||||||
@@ -95,20 +96,20 @@ def update_chart_preview(
|
|||||||
start_time = time.time()
|
start_time = time.time()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Parse the raw config dict into a typed ChartConfig
|
||||||
|
config = parse_chart_config(request.config)
|
||||||
|
|
||||||
with event_logger.log_context(action="mcp.update_chart_preview.form_data"):
|
with event_logger.log_context(action="mcp.update_chart_preview.form_data"):
|
||||||
# Map the new config to form_data format
|
# Map the new config to form_data format
|
||||||
# Pass dataset_id to enable column type checking
|
# Pass dataset_id to enable column type checking
|
||||||
new_form_data = map_config_to_form_data(
|
new_form_data = map_config_to_form_data(
|
||||||
request.config, dataset_id=request.dataset_id
|
config, dataset_id=request.dataset_id
|
||||||
)
|
)
|
||||||
new_form_data.pop("_mcp_warnings", None)
|
new_form_data.pop("_mcp_warnings", None)
|
||||||
|
|
||||||
# Preserve adhoc filters from the previous cached form_data
|
# Preserve adhoc filters from the previous cached form_data
|
||||||
# when the new config doesn't explicitly specify filters
|
# when the new config doesn't explicitly specify filters
|
||||||
if (
|
if getattr(config, "filters", None) is None and request.form_data_key:
|
||||||
getattr(request.config, "filters", None) is None
|
|
||||||
and request.form_data_key
|
|
||||||
):
|
|
||||||
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
|
old_adhoc_filters = _get_old_adhoc_filters(request.form_data_key)
|
||||||
if old_adhoc_filters:
|
if old_adhoc_filters:
|
||||||
new_form_data["adhoc_filters"] = old_adhoc_filters
|
new_form_data["adhoc_filters"] = old_adhoc_filters
|
||||||
@@ -123,8 +124,8 @@ def update_chart_preview(
|
|||||||
|
|
||||||
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
|
with event_logger.log_context(action="mcp.update_chart_preview.metadata"):
|
||||||
# Generate semantic analysis
|
# Generate semantic analysis
|
||||||
capabilities = analyze_chart_capabilities(None, request.config)
|
capabilities = analyze_chart_capabilities(None, config)
|
||||||
semantics = analyze_chart_semantics(None, request.config)
|
semantics = analyze_chart_semantics(None, config)
|
||||||
|
|
||||||
# Create performance metadata
|
# Create performance metadata
|
||||||
execution_time = int((time.time() - start_time) * 1000)
|
execution_time = int((time.time() - start_time) * 1000)
|
||||||
@@ -135,7 +136,7 @@ def update_chart_preview(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Create accessibility metadata
|
# Create accessibility metadata
|
||||||
chart_name = generate_chart_name(request.config)
|
chart_name = generate_chart_name(config)
|
||||||
accessibility = AccessibilityMetadata(
|
accessibility = AccessibilityMetadata(
|
||||||
color_blind_safe=True, # Would need actual analysis
|
color_blind_safe=True, # Would need actual analysis
|
||||||
alt_text=f"Updated chart preview showing {chart_name}",
|
alt_text=f"Updated chart preview showing {chart_name}",
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from typing import Any, Dict, List, Tuple
|
|||||||
from superset.mcp_service.chart.schemas import (
|
from superset.mcp_service.chart.schemas import (
|
||||||
ChartConfig,
|
ChartConfig,
|
||||||
GenerateChartRequest,
|
GenerateChartRequest,
|
||||||
|
parse_chart_config,
|
||||||
)
|
)
|
||||||
from superset.mcp_service.common.error_schemas import (
|
from superset.mcp_service.common.error_schemas import (
|
||||||
ChartGenerationError,
|
ChartGenerationError,
|
||||||
@@ -171,6 +172,10 @@ class ValidationPipeline:
|
|||||||
if request is None:
|
if request is None:
|
||||||
return ValidationResult(is_valid=False, error=error)
|
return ValidationResult(is_valid=False, error=error)
|
||||||
|
|
||||||
|
# Parse the raw config dict into a typed ChartConfig for
|
||||||
|
# downstream validators that need typed access.
|
||||||
|
typed_config = parse_chart_config(request.config)
|
||||||
|
|
||||||
# Fetch dataset context once and reuse across validation layers
|
# Fetch dataset context once and reuse across validation layers
|
||||||
dataset_context = ValidationPipeline._get_dataset_context(
|
dataset_context = ValidationPipeline._get_dataset_context(
|
||||||
request.dataset_id
|
request.dataset_id
|
||||||
@@ -178,20 +183,20 @@ class ValidationPipeline:
|
|||||||
|
|
||||||
# Layer 2: Dataset validation (reuses context)
|
# Layer 2: Dataset validation (reuses context)
|
||||||
is_valid, error = ValidationPipeline._validate_dataset(
|
is_valid, error = ValidationPipeline._validate_dataset(
|
||||||
request.config, request.dataset_id, dataset_context
|
typed_config, request.dataset_id, dataset_context
|
||||||
)
|
)
|
||||||
if not is_valid:
|
if not is_valid:
|
||||||
return ValidationResult(is_valid=False, request=request, error=error)
|
return ValidationResult(is_valid=False, request=request, error=error)
|
||||||
|
|
||||||
# Layer 3: Runtime validation - returns warnings as metadata, not errors
|
# Layer 3: Runtime validation - returns warnings as metadata, not errors
|
||||||
_is_valid, warnings_metadata = ValidationPipeline._validate_runtime(
|
_is_valid, warnings_metadata = ValidationPipeline._validate_runtime(
|
||||||
request.config, request.dataset_id
|
typed_config, request.dataset_id
|
||||||
)
|
)
|
||||||
# Runtime validation always returns True now, warnings are informational
|
# Runtime validation always returns True now, warnings are informational
|
||||||
|
|
||||||
# Layer 4: Column name normalization (reuses context)
|
# Layer 4: Column name normalization (reuses context)
|
||||||
normalized_request = ValidationPipeline._normalize_column_names(
|
normalized_request = ValidationPipeline._normalize_column_names(
|
||||||
request, dataset_context
|
request, dataset_context, typed_config=typed_config
|
||||||
)
|
)
|
||||||
|
|
||||||
return ValidationResult(
|
return ValidationResult(
|
||||||
@@ -284,6 +289,7 @@ class ValidationPipeline:
|
|||||||
def _normalize_column_names(
|
def _normalize_column_names(
|
||||||
request: GenerateChartRequest,
|
request: GenerateChartRequest,
|
||||||
dataset_context: DatasetContext | None = None,
|
dataset_context: DatasetContext | None = None,
|
||||||
|
typed_config: ChartConfig | None = None,
|
||||||
) -> GenerateChartRequest:
|
) -> GenerateChartRequest:
|
||||||
"""
|
"""
|
||||||
Normalize column names in the request to match canonical dataset names.
|
Normalize column names in the request to match canonical dataset names.
|
||||||
@@ -297,6 +303,8 @@ class ValidationPipeline:
|
|||||||
request: The validated chart generation request
|
request: The validated chart generation request
|
||||||
dataset_context: Pre-fetched dataset context to avoid duplicate
|
dataset_context: Pre-fetched dataset context to avoid duplicate
|
||||||
DB queries. If None, fetches from the database.
|
DB queries. If None, fetches from the database.
|
||||||
|
typed_config: Pre-parsed typed ChartConfig. If None, parses from
|
||||||
|
request.config dict.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A new request with normalized column names
|
A new request with normalized column names
|
||||||
@@ -304,8 +312,9 @@ class ValidationPipeline:
|
|||||||
try:
|
try:
|
||||||
from .dataset_validator import DatasetValidator
|
from .dataset_validator import DatasetValidator
|
||||||
|
|
||||||
|
config = typed_config or parse_chart_config(request.config)
|
||||||
normalized_config = DatasetValidator.normalize_column_names(
|
normalized_config = DatasetValidator.normalize_column_names(
|
||||||
request.config,
|
config,
|
||||||
request.dataset_id,
|
request.dataset_id,
|
||||||
dataset_context=dataset_context,
|
dataset_context=dataset_context,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.chart_utils import (
|
|||||||
)
|
)
|
||||||
from superset.mcp_service.chart.schemas import (
|
from superset.mcp_service.chart.schemas import (
|
||||||
GenerateExploreLinkRequest,
|
GenerateExploreLinkRequest,
|
||||||
|
parse_chart_config,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -89,7 +90,7 @@ async def generate_explore_link(
|
|||||||
"""
|
"""
|
||||||
await ctx.info(
|
await ctx.info(
|
||||||
"Generating explore link for dataset_id=%s, chart_type=%s"
|
"Generating explore link for dataset_id=%s, chart_type=%s"
|
||||||
% (request.dataset_id, request.config.chart_type)
|
% (request.dataset_id, request.config.get("chart_type", "unknown"))
|
||||||
)
|
)
|
||||||
await ctx.debug(
|
await ctx.debug(
|
||||||
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
|
"Configuration details: use_cache=%s, force_refresh=%s, cache_form_data=%s"
|
||||||
@@ -97,6 +98,9 @@ async def generate_explore_link(
|
|||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Parse the raw config dict into a typed ChartConfig
|
||||||
|
config = parse_chart_config(request.config)
|
||||||
|
|
||||||
await ctx.report_progress(1, 4, "Validating dataset exists")
|
await ctx.report_progress(1, 4, "Validating dataset exists")
|
||||||
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
|
with event_logger.log_context(action="mcp.generate_explore_link.dataset_check"):
|
||||||
from superset.daos.dataset import DatasetDAO
|
from superset.daos.dataset import DatasetDAO
|
||||||
@@ -138,10 +142,10 @@ async def generate_explore_link(
|
|||||||
)
|
)
|
||||||
|
|
||||||
normalized_config = DatasetValidator.normalize_column_names(
|
normalized_config = DatasetValidator.normalize_column_names(
|
||||||
request.config, request.dataset_id
|
config, request.dataset_id
|
||||||
)
|
)
|
||||||
except (ImportError, AttributeError, KeyError, ValueError, TypeError):
|
except (ImportError, AttributeError, KeyError, ValueError, TypeError):
|
||||||
normalized_config = request.config
|
normalized_config = config
|
||||||
|
|
||||||
# Map config to form_data using shared utilities
|
# Map config to form_data using shared utilities
|
||||||
form_data = map_config_to_form_data(
|
form_data = map_config_to_form_data(
|
||||||
@@ -197,7 +201,12 @@ async def generate_explore_link(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
await ctx.error(
|
await ctx.error(
|
||||||
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
|
"Explore link generation failed for dataset_id=%s, chart_type=%s: %s: %s"
|
||||||
% (request.dataset_id, request.config.chart_type, type(e).__name__, str(e))
|
% (
|
||||||
|
request.dataset_id,
|
||||||
|
request.config.get("chart_type", "unknown"),
|
||||||
|
type(e).__name__,
|
||||||
|
str(e),
|
||||||
|
)
|
||||||
)
|
)
|
||||||
return {
|
return {
|
||||||
"url": "",
|
"url": "",
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ from pydantic import ValidationError
|
|||||||
|
|
||||||
from superset.mcp_service.chart.schemas import (
|
from superset.mcp_service.chart.schemas import (
|
||||||
ColumnRef,
|
ColumnRef,
|
||||||
|
GenerateChartRequest,
|
||||||
|
parse_chart_config,
|
||||||
TableChartConfig,
|
TableChartConfig,
|
||||||
XYChartConfig,
|
XYChartConfig,
|
||||||
)
|
)
|
||||||
@@ -625,3 +627,48 @@ class TestColumnRefSavedMetric:
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseChartConfig:
|
||||||
|
"""Tests for parse_chart_config and config coercion."""
|
||||||
|
|
||||||
|
def test_parse_valid_xy_config(self) -> None:
|
||||||
|
config = parse_chart_config(
|
||||||
|
{"chart_type": "xy", "x": {"name": "date"}, "y": [{"name": "v"}]}
|
||||||
|
)
|
||||||
|
assert config.chart_type == "xy"
|
||||||
|
assert config.x.name == "date"
|
||||||
|
assert len(config.y) == 1
|
||||||
|
assert config.y[0].name == "v"
|
||||||
|
|
||||||
|
def test_parse_valid_table_config(self) -> None:
|
||||||
|
config = parse_chart_config(
|
||||||
|
{"chart_type": "table", "columns": [{"name": "col1"}]}
|
||||||
|
)
|
||||||
|
assert config.chart_type == "table"
|
||||||
|
assert len(config.columns) == 1
|
||||||
|
assert config.columns[0].name == "col1"
|
||||||
|
|
||||||
|
def test_parse_missing_chart_type_raises(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="chart://configs"):
|
||||||
|
parse_chart_config({"x": {"name": "date"}, "y": [{"name": "v"}]})
|
||||||
|
|
||||||
|
def test_parse_unknown_chart_type_raises(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="chart://configs"):
|
||||||
|
parse_chart_config({"chart_type": "nonexistent", "x": {"name": "d"}})
|
||||||
|
|
||||||
|
def test_coerce_json_string_config(self) -> None:
|
||||||
|
raw = '{"chart_type": "table", "columns": [{"name": "c"}]}'
|
||||||
|
req = GenerateChartRequest(dataset_id=1, config=raw)
|
||||||
|
assert isinstance(req.config, dict)
|
||||||
|
assert req.config["chart_type"] == "table"
|
||||||
|
|
||||||
|
def test_coerce_typed_config_object(self) -> None:
|
||||||
|
typed = TableChartConfig(chart_type="table", columns=[ColumnRef(name="c")])
|
||||||
|
req = GenerateChartRequest(dataset_id=1, config=typed)
|
||||||
|
assert isinstance(req.config, dict)
|
||||||
|
assert req.config["chart_type"] == "table"
|
||||||
|
|
||||||
|
def test_coerce_invalid_json_string_raises(self) -> None:
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
GenerateChartRequest(dataset_id=1, config="not valid json")
|
||||||
|
|||||||
@@ -57,10 +57,11 @@ class TestGenerateChart:
|
|||||||
)
|
)
|
||||||
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
|
table_request = GenerateChartRequest(dataset_id="1", config=table_config)
|
||||||
assert table_request.dataset_id == "1"
|
assert table_request.dataset_id == "1"
|
||||||
assert table_request.config.chart_type == "table"
|
# config is now Dict[str, Any] in the schema; validate via dict access
|
||||||
assert len(table_request.config.columns) == 2
|
assert table_request.config["chart_type"] == "table"
|
||||||
assert table_request.config.columns[0].name == "region"
|
assert len(table_request.config["columns"]) == 2
|
||||||
assert table_request.config.columns[1].aggregate == "SUM"
|
assert table_request.config["columns"][0]["name"] == "region"
|
||||||
|
assert table_request.config["columns"][1]["aggregate"] == "SUM"
|
||||||
|
|
||||||
# XY chart request
|
# XY chart request
|
||||||
xy_config = XYChartConfig(
|
xy_config = XYChartConfig(
|
||||||
@@ -74,12 +75,12 @@ class TestGenerateChart:
|
|||||||
legend=LegendConfig(show=True, position="top"),
|
legend=LegendConfig(show=True, position="top"),
|
||||||
)
|
)
|
||||||
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
|
xy_request = GenerateChartRequest(dataset_id="2", config=xy_config)
|
||||||
assert xy_request.config.chart_type == "xy"
|
assert xy_request.config["chart_type"] == "xy"
|
||||||
assert xy_request.config.x.name == "date"
|
assert xy_request.config["x"]["name"] == "date"
|
||||||
assert xy_request.config.y[0].aggregate == "SUM"
|
assert xy_request.config["y"][0]["aggregate"] == "SUM"
|
||||||
assert xy_request.config.kind == "line"
|
assert xy_request.config["kind"] == "line"
|
||||||
assert xy_request.config.x_axis.title == "Date"
|
assert xy_request.config["x_axis"]["title"] == "Date"
|
||||||
assert xy_request.config.legend.show is True
|
assert xy_request.config["legend"]["show"] is True
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_generate_chart_validation_error_handling(self):
|
async def test_generate_chart_validation_error_handling(self):
|
||||||
|
|||||||
@@ -60,10 +60,11 @@ class TestUpdateChart:
|
|||||||
)
|
)
|
||||||
table_request = UpdateChartRequest(identifier=123, config=table_config)
|
table_request = UpdateChartRequest(identifier=123, config=table_config)
|
||||||
assert table_request.identifier == 123
|
assert table_request.identifier == 123
|
||||||
assert table_request.config.chart_type == "table"
|
# config is now Dict[str, Any] in the schema; validate via dict access
|
||||||
assert len(table_request.config.columns) == 2
|
assert table_request.config["chart_type"] == "table"
|
||||||
assert table_request.config.columns[0].name == "region"
|
assert len(table_request.config["columns"]) == 2
|
||||||
assert table_request.config.columns[1].aggregate == "SUM"
|
assert table_request.config["columns"][0]["name"] == "region"
|
||||||
|
assert table_request.config["columns"][1]["aggregate"] == "SUM"
|
||||||
|
|
||||||
# XY chart update with UUID
|
# XY chart update with UUID
|
||||||
xy_config = XYChartConfig(
|
xy_config = XYChartConfig(
|
||||||
@@ -80,10 +81,10 @@ class TestUpdateChart:
|
|||||||
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
|
identifier="a1b2c3d4-e5f6-7890-abcd-ef1234567890", config=xy_config
|
||||||
)
|
)
|
||||||
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
assert xy_request.identifier == "a1b2c3d4-e5f6-7890-abcd-ef1234567890"
|
||||||
assert xy_request.config.chart_type == "xy"
|
assert xy_request.config["chart_type"] == "xy"
|
||||||
assert xy_request.config.x.name == "date"
|
assert xy_request.config["x"]["name"] == "date"
|
||||||
assert xy_request.config.y[0].aggregate == "SUM"
|
assert xy_request.config["y"][0]["aggregate"] == "SUM"
|
||||||
assert xy_request.config.kind == "line"
|
assert xy_request.config["kind"] == "line"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_with_chart_name(self):
|
async def test_update_chart_with_chart_name(self):
|
||||||
@@ -170,7 +171,7 @@ class TestUpdateChart:
|
|||||||
kind=chart_type,
|
kind=chart_type,
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert request.config.kind == chart_type
|
assert request.config["kind"] == chart_type
|
||||||
|
|
||||||
# Test multiple Y columns
|
# Test multiple Y columns
|
||||||
multi_y_config = XYChartConfig(
|
multi_y_config = XYChartConfig(
|
||||||
@@ -184,8 +185,8 @@ class TestUpdateChart:
|
|||||||
kind="line",
|
kind="line",
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=multi_y_config)
|
request = UpdateChartRequest(identifier=1, config=multi_y_config)
|
||||||
assert len(request.config.y) == 3
|
assert len(request.config["y"]) == 3
|
||||||
assert request.config.y[1].aggregate == "AVG"
|
assert request.config["y"][1]["aggregate"] == "AVG"
|
||||||
|
|
||||||
# Test filter operators
|
# Test filter operators
|
||||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||||
@@ -196,7 +197,7 @@ class TestUpdateChart:
|
|||||||
filters=filters,
|
filters=filters,
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=table_config)
|
request = UpdateChartRequest(identifier=1, config=table_config)
|
||||||
assert len(request.config.filters) == 6
|
assert len(request.config["filters"]) == 6
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_response_structure(self):
|
async def test_update_chart_response_structure(self):
|
||||||
@@ -252,12 +253,12 @@ class TestUpdateChart:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert request.config.x_axis.title == "Date"
|
assert request.config["x_axis"]["title"] == "Date"
|
||||||
assert request.config.x_axis.format == "smart_date"
|
assert request.config["x_axis"]["format"] == "smart_date"
|
||||||
assert request.config.x_axis.scale == "linear"
|
assert request.config["x_axis"]["scale"] == "linear"
|
||||||
assert request.config.y_axis.title == "Sales Amount"
|
assert request.config["y_axis"]["title"] == "Sales Amount"
|
||||||
assert request.config.y_axis.format == "$,.2f"
|
assert request.config["y_axis"]["format"] == "$,.2f"
|
||||||
assert request.config.y_axis.scale == "log"
|
assert request.config["y_axis"]["scale"] == "log"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_legend_configurations(self):
|
async def test_update_chart_legend_configurations(self):
|
||||||
@@ -271,8 +272,8 @@ class TestUpdateChart:
|
|||||||
legend=LegendConfig(show=True, position=pos),
|
legend=LegendConfig(show=True, position=pos),
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert request.config.legend.position == pos
|
assert request.config["legend"]["position"] == pos
|
||||||
assert request.config.legend.show is True
|
assert request.config["legend"]["show"] is True
|
||||||
|
|
||||||
# Hidden legend
|
# Hidden legend
|
||||||
config = XYChartConfig(
|
config = XYChartConfig(
|
||||||
@@ -282,7 +283,7 @@ class TestUpdateChart:
|
|||||||
legend=LegendConfig(show=False),
|
legend=LegendConfig(show=False),
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert request.config.legend.show is False
|
assert request.config["legend"]["show"] is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_aggregation_functions(self):
|
async def test_update_chart_aggregation_functions(self):
|
||||||
@@ -294,7 +295,7 @@ class TestUpdateChart:
|
|||||||
columns=[ColumnRef(name="value", aggregate=agg)],
|
columns=[ColumnRef(name="value", aggregate=agg)],
|
||||||
)
|
)
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert request.config.columns[0].aggregate == agg
|
assert request.config["columns"][0]["aggregate"] == agg
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_error_responses(self):
|
async def test_update_chart_error_responses(self):
|
||||||
@@ -378,10 +379,10 @@ class TestUpdateChart:
|
|||||||
)
|
)
|
||||||
|
|
||||||
request = UpdateChartRequest(identifier=1, config=config)
|
request = UpdateChartRequest(identifier=1, config=config)
|
||||||
assert len(request.config.filters) == 3
|
assert len(request.config["filters"]) == 3
|
||||||
assert request.config.filters[0].column == "region"
|
assert request.config["filters"][0]["column"] == "region"
|
||||||
assert request.config.filters[1].op == ">="
|
assert request.config["filters"][1]["op"] == ">="
|
||||||
assert request.config.filters[2].value == "2024-01-01"
|
assert request.config["filters"][2]["value"] == "2024-01-01"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_cache_control(self):
|
async def test_update_chart_cache_control(self):
|
||||||
|
|||||||
@@ -53,9 +53,9 @@ class TestUpdateChartPreview:
|
|||||||
)
|
)
|
||||||
assert table_request.form_data_key == "abc123def456"
|
assert table_request.form_data_key == "abc123def456"
|
||||||
assert table_request.dataset_id == 1
|
assert table_request.dataset_id == 1
|
||||||
assert table_request.config.chart_type == "table"
|
assert table_request.config["chart_type"] == "table"
|
||||||
assert len(table_request.config.columns) == 2
|
assert len(table_request.config["columns"]) == 2
|
||||||
assert table_request.config.columns[0].name == "region"
|
assert table_request.config["columns"][0]["name"] == "region"
|
||||||
|
|
||||||
# XY chart preview update
|
# XY chart preview update
|
||||||
xy_config = XYChartConfig(
|
xy_config = XYChartConfig(
|
||||||
@@ -73,9 +73,9 @@ class TestUpdateChartPreview:
|
|||||||
)
|
)
|
||||||
assert xy_request.form_data_key == "xyz789ghi012"
|
assert xy_request.form_data_key == "xyz789ghi012"
|
||||||
assert xy_request.dataset_id == "2"
|
assert xy_request.dataset_id == "2"
|
||||||
assert xy_request.config.chart_type == "xy"
|
assert xy_request.config["chart_type"] == "xy"
|
||||||
assert xy_request.config.x.name == "date"
|
assert xy_request.config["x"]["name"] == "date"
|
||||||
assert xy_request.config.kind == "line"
|
assert xy_request.config["kind"] == "line"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_dataset_id_types(self):
|
async def test_update_chart_preview_dataset_id_types(self):
|
||||||
@@ -158,7 +158,7 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.kind == chart_type
|
assert request.config["kind"] == chart_type
|
||||||
|
|
||||||
# Test multiple Y columns
|
# Test multiple Y columns
|
||||||
multi_y_config = XYChartConfig(
|
multi_y_config = XYChartConfig(
|
||||||
@@ -174,8 +174,8 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=multi_y_config
|
form_data_key="abc123", dataset_id=1, config=multi_y_config
|
||||||
)
|
)
|
||||||
assert len(request.config.y) == 3
|
assert len(request.config["y"]) == 3
|
||||||
assert request.config.y[1].aggregate == "AVG"
|
assert request.config["y"][1]["aggregate"] == "AVG"
|
||||||
|
|
||||||
# Test filter operators
|
# Test filter operators
|
||||||
operators = ["=", "!=", ">", ">=", "<", "<="]
|
operators = ["=", "!=", ">", ">=", "<", "<="]
|
||||||
@@ -188,7 +188,7 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=table_config
|
form_data_key="abc123", dataset_id=1, config=table_config
|
||||||
)
|
)
|
||||||
assert len(request.config.filters) == 6
|
assert len(request.config["filters"]) == 6
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_response_structure(self):
|
async def test_update_chart_preview_response_structure(self):
|
||||||
@@ -251,10 +251,10 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.x_axis.title == "Date"
|
assert request.config["x_axis"]["title"] == "Date"
|
||||||
assert request.config.x_axis.format == "smart_date"
|
assert request.config["x_axis"]["format"] == "smart_date"
|
||||||
assert request.config.y_axis.title == "Sales Amount"
|
assert request.config["y_axis"]["title"] == "Sales Amount"
|
||||||
assert request.config.y_axis.format == "$,.2f"
|
assert request.config["y_axis"]["format"] == "$,.2f"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_legend_configurations(self):
|
async def test_update_chart_preview_legend_configurations(self):
|
||||||
@@ -270,8 +270,8 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.legend.position == pos
|
assert request.config["legend"]["position"] == pos
|
||||||
assert request.config.legend.show is True
|
assert request.config["legend"]["show"] is True
|
||||||
|
|
||||||
# Hidden legend
|
# Hidden legend
|
||||||
config = XYChartConfig(
|
config = XYChartConfig(
|
||||||
@@ -283,7 +283,7 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.legend.show is False
|
assert request.config["legend"]["show"] is False
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_aggregation_functions(self):
|
async def test_update_chart_preview_aggregation_functions(self):
|
||||||
@@ -297,7 +297,7 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.columns[0].aggregate == agg
|
assert request.config["columns"][0]["aggregate"] == agg
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_error_responses(self):
|
async def test_update_chart_preview_error_responses(self):
|
||||||
@@ -347,10 +347,10 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert len(request.config.filters) == 3
|
assert len(request.config["filters"]) == 3
|
||||||
assert request.config.filters[0].column == "region"
|
assert request.config["filters"][0]["column"] == "region"
|
||||||
assert request.config.filters[1].op == ">="
|
assert request.config["filters"][1]["op"] == ">="
|
||||||
assert request.config.filters[2].value == "2024-01-01"
|
assert request.config["filters"][2]["value"] == "2024-01-01"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_form_data_key_handling(self):
|
async def test_update_chart_preview_form_data_key_handling(self):
|
||||||
@@ -447,12 +447,12 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert len(request.config.y) == 4
|
assert len(request.config["y"]) == 4
|
||||||
assert request.config.y[0].name == "revenue"
|
assert request.config["y"][0]["name"] == "revenue"
|
||||||
assert request.config.y[1].name == "cost"
|
assert request.config["y"][1]["name"] == "cost"
|
||||||
assert request.config.y[2].name == "profit"
|
assert request.config["y"][2]["name"] == "profit"
|
||||||
assert request.config.y[3].name == "orders"
|
assert request.config["y"][3]["name"] == "orders"
|
||||||
assert request.config.y[3].aggregate == "COUNT"
|
assert request.config["y"][3]["aggregate"] == "COUNT"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_update_chart_preview_table_sorting(self):
|
async def test_update_chart_preview_table_sorting(self):
|
||||||
@@ -470,5 +470,5 @@ class TestUpdateChartPreview:
|
|||||||
request = UpdateChartPreviewRequest(
|
request = UpdateChartPreviewRequest(
|
||||||
form_data_key="abc123", dataset_id=1, config=config
|
form_data_key="abc123", dataset_id=1, config=config
|
||||||
)
|
)
|
||||||
assert request.config.sort_by == ["sales", "profit"]
|
assert request.config["sort_by"] == ["sales", "profit"]
|
||||||
assert len(request.config.columns) == 3
|
assert len(request.config["columns"]) == 3
|
||||||
|
|||||||
Reference in New Issue
Block a user