diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 3f2f9e2cb31..5b865c1f0dd 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -611,6 +611,35 @@ def _ensure_temporal_adhoc_filter(form_data: Dict[str, Any], column: str) -> Non form_data["adhoc_filters"] = existing +def _resolve_default_x_axis( + config: XYChartConfig, dataset_id: int | str | None +) -> XYChartConfig: + """Resolve x-axis to the dataset's main_dttm_col when x is omitted.""" + if config.x is not None: + return config + + if not dataset_id: + raise ValueError("x-axis column is required when dataset_id is not provided") + from superset.daos.dataset import DatasetDAO + + if isinstance(dataset_id, int) or ( + isinstance(dataset_id, str) and dataset_id.isdigit() + ): + dataset = DatasetDAO.find_by_id(int(dataset_id)) + else: + dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid") + + if not dataset or not dataset.main_dttm_col: + raise ValueError( + "x-axis column is required: dataset has no primary datetime " + "column (main_dttm_col). Please specify the x-axis column " + "explicitly." + ) + from superset.mcp_service.chart.schemas import ColumnRef + + return config.model_copy(update={"x": ColumnRef(name=dataset.main_dttm_col)}) + + def map_xy_config( config: XYChartConfig, dataset_id: int | str | None = None ) -> Dict[str, Any]: @@ -619,6 +648,10 @@ def map_xy_config( if not config.y: raise ValueError("XY chart must have at least one Y-axis metric") + # Resolve x-axis default: use dataset's main_dttm_col when x is omitted + config = _resolve_default_x_axis(config, dataset_id) + assert config.x is not None # _resolve_default_x_axis guarantees x is set + # Check if x-axis column is truly temporal (based on actual SQL type) x_is_temporal = is_column_truly_temporal(config.x.name, dataset_id) @@ -1001,7 +1034,7 @@ def _table_chart_what(config: TableChartConfig, dataset_name: str | None) -> str def _xy_chart_what(config: XYChartConfig) -> str: """Build the descriptive fragment for an XY chart.""" primary_metric = _humanize_column(config.y[0]) if config.y else "Value" - dimension = _humanize_column(config.x) + dimension = _humanize_column(config.x) if config.x else "Dimension" if config.kind in ("line", "area") and not config.group_by: return f"{primary_metric} Over Time" diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 8e5e13d0c44..fae3faba721 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -1054,13 +1054,25 @@ class TableChartConfig(UnknownFieldCheckMixin): return self +def _metric_display_label(col: ColumnRef) -> str: + """Return the display label for a metric column reference.""" + if col.saved_metric: + return col.label or col.name + if col.aggregate: + return col.label or f"{col.aggregate}({col.name})" + return col.label or col.name + + class XYChartConfig(UnknownFieldCheckMixin): model_config = ConfigDict(extra="ignore", populate_by_name=True) chart_type: Literal["xy"] = "xy" - x: ColumnRef = Field( - ..., - description="X-axis column", + x: ColumnRef | None = Field( + None, + description=( + "X-axis column. If omitted, defaults to the dataset's " + "primary datetime column (main_dttm_col)." + ), validation_alias=AliasChoices("x", "x_axis", "x_column"), ) y: List[ColumnRef] = Field( @@ -1107,22 +1119,17 @@ class XYChartConfig(UnknownFieldCheckMixin): @model_validator(mode="after") def validate_unique_column_labels(self) -> "XYChartConfig": """Ensure all column labels are unique across x, y, and group_by.""" - labels_seen = {} # label -> field_name for error reporting - duplicates = [] + labels_seen: dict[str, str] = {} + duplicates: list[str] = [] - # Check X-axis label - x_label = self.x.label or self.x.name - labels_seen[x_label] = "x" + # Add x-axis label if present (x may be None, resolved later) + if self.x is not None: + x_label = self.x.label or self.x.name + labels_seen[x_label] = "x" # Check Y-axis labels for i, col in enumerate(self.y): - if col.saved_metric: - label = col.label or col.name - elif col.aggregate: - label = col.label or f"{col.aggregate}({col.name})" - else: - label = col.label or col.name - + label = _metric_display_label(col) if label in labels_seen: duplicates.append( f"y[{i}]: '{label}' (conflicts with {labels_seen[label]})" @@ -1133,7 +1140,7 @@ class XYChartConfig(UnknownFieldCheckMixin): # Check group_by labels if present if self.group_by: for i, col in enumerate(self.group_by): - if col.name == self.x.name: + if self.x is not None and col.name == self.x.name: # map_xy_config() strips group_by entries that match x # to prevent Superset "duplicate label" errors, so # we allow them through validation. diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index 202d7cc5afc..c7f497a426b 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -204,7 +204,8 @@ class DatasetValidator: if isinstance(config, TableChartConfig): refs.extend(config.columns) elif isinstance(config, XYChartConfig): - refs.append(config.x) + if config.x is not None: + refs.append(config.x) refs.extend(config.y) if config.group_by: refs.extend(config.group_by) diff --git a/superset/mcp_service/chart/validation/runtime/__init__.py b/superset/mcp_service/chart/validation/runtime/__init__.py index b9414551737..5e1c89d0a68 100644 --- a/superset/mcp_service/chart/validation/runtime/__init__.py +++ b/superset/mcp_service/chart/validation/runtime/__init__.py @@ -134,6 +134,8 @@ class RuntimeValidator: chart_type = config.kind if hasattr(config, "kind") else "default" # Check X-axis cardinality + if config.x is None: + return warnings, suggestions is_ok, cardinality_info = CardinalityValidator.check_cardinality( dataset_id=dataset_id, x_column=config.x.name, diff --git a/superset/mcp_service/chart/validation/runtime/chart_type_suggester.py b/superset/mcp_service/chart/validation/runtime/chart_type_suggester.py index 910b13ab66c..a707b14b5a8 100644 --- a/superset/mcp_service/chart/validation/runtime/chart_type_suggester.py +++ b/superset/mcp_service/chart/validation/runtime/chart_type_suggester.py @@ -68,6 +68,9 @@ class ChartTypeSuggester: issues = [] suggestions = [] + if config.x is None: + return True, None + x_analysis = ChartTypeSuggester._analyze_x_axis(config.x.name) y_analysis = ChartTypeSuggester._analyze_y_axis(config.y) @@ -147,6 +150,7 @@ class ChartTypeSuggester: config: XYChartConfig, x_analysis: Dict[str, Any], y_analysis: Dict[str, Any] ) -> Tuple[List[str], List[str]]: """Check for chart type specific issues.""" + assert config.x is not None # caller guards for None issues = [] suggestions = [] @@ -195,6 +199,7 @@ class ChartTypeSuggester: x_is_id: bool, ) -> Tuple[List[str], List[str]]: """Check line chart specific issues.""" + assert config.x is not None issues = [] suggestions = [] @@ -228,6 +233,7 @@ class ChartTypeSuggester: config: XYChartConfig, x_is_categorical: bool, num_metrics: int ) -> Tuple[List[str], List[str]]: """Check scatter chart specific issues.""" + assert config.x is not None issues = [] suggestions = [] @@ -258,6 +264,7 @@ class ChartTypeSuggester: config: XYChartConfig, x_is_temporal: bool ) -> Tuple[List[str], List[str]]: """Check area chart specific issues.""" + assert config.x is not None issues = [] suggestions = [] @@ -295,6 +302,7 @@ class ChartTypeSuggester: config: XYChartConfig, x_is_id: bool ) -> Tuple[List[str], List[str]]: """Check bar chart specific issues.""" + assert config.x is not None issues = [] suggestions = [] diff --git a/superset/mcp_service/chart/validation/runtime/format_validator.py b/superset/mcp_service/chart/validation/runtime/format_validator.py index 34971d47907..73f8c450fa9 100644 --- a/superset/mcp_service/chart/validation/runtime/format_validator.py +++ b/superset/mcp_service/chart/validation/runtime/format_validator.py @@ -79,8 +79,9 @@ class FormatTypeValidator: # Validate X-axis format (usually temporal or categorical) if config.x_axis and config.x_axis.format: + x_column = config.x or ColumnRef(name="default_x_axis") x_warnings = FormatTypeValidator._validate_x_axis_format( - config.x_axis.format, config.x + config.x_axis.format, x_column ) warnings.extend(x_warnings) diff --git a/superset/mcp_service/chart/validation/schema_validator.py b/superset/mcp_service/chart/validation/schema_validator.py index c82728b11bd..7cae450ff59 100644 --- a/superset/mcp_service/chart/validation/schema_validator.py +++ b/superset/mcp_service/chart/validation/schema_validator.py @@ -185,22 +185,15 @@ class SchemaValidator: config: Dict[str, Any], ) -> Tuple[bool, ChartGenerationError | None]: """Pre-validate XY chart configuration.""" - missing_fields = [] - - if "x" not in config: - missing_fields.append("'x' (X-axis column)") + # x is optional — defaults to dataset's main_dttm_col in map_xy_config if "y" not in config: - missing_fields.append("'y' (Y-axis metrics)") - - if missing_fields: return False, ChartGenerationError( error_type="missing_xy_fields", - message=f"XY chart missing required " - f"fields: {', '.join(missing_fields)}", - details="XY charts require both X-axis (dimension) and Y-axis (" - "metrics) specifications", + message="XY chart missing required field: 'y' (Y-axis metrics)", + details="XY charts require Y-axis (metrics) specifications. " + "X-axis is optional and defaults to the dataset's primary " + "datetime column when omitted.", suggestions=[ - "Add 'x' field: {'name': 'column_name'} for X-axis", "Add 'y' field: [{'name': 'metric_column', 'aggregate': 'SUM'}] " "for Y-axis", "Example: {'chart_type': 'xy', 'x': {'name': 'date'}, " diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 894440d4888..7c2da4e5de7 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -826,6 +826,62 @@ class TestMapConfigToFormData: with pytest.raises(ValueError, match="Unsupported config type"): map_config_to_form_data("invalid_config") # type: ignore + @patch( + "superset.mcp_service.chart.chart_utils.is_column_truly_temporal", + return_value=True, + ) + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + def test_map_xy_config_x_none_defaults_to_main_dttm_col( + self, mock_find_by_id: Any, mock_is_temporal: Any + ) -> None: + """When x is None, map_xy_config resolves it from dataset.main_dttm_col.""" + mock_dataset = MagicMock() + mock_dataset.main_dttm_col = "order_date" + mock_find_by_id.return_value = mock_dataset + + config = XYChartConfig( + chart_type="xy", + x=None, + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="bar", + ) + + result = map_xy_config(config, dataset_id=42) + + assert result["x_axis"] == "order_date" + mock_find_by_id.assert_called_once_with(42) + + def test_map_xy_config_x_none_no_dataset_id_raises(self) -> None: + """When x is None and no dataset_id, raise ValueError.""" + config = XYChartConfig( + chart_type="xy", + x=None, + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + ) + + with pytest.raises(ValueError, match="x-axis column is required"): + map_xy_config(config, dataset_id=None) + + @patch("superset.daos.dataset.DatasetDAO.find_by_id") + def test_map_xy_config_x_none_no_main_dttm_col_raises( + self, mock_find_by_id: Any + ) -> None: + """When x is None and dataset has no main_dttm_col, raise ValueError.""" + mock_dataset = MagicMock() + mock_dataset.main_dttm_col = None + mock_find_by_id.return_value = mock_dataset + + config = XYChartConfig( + chart_type="xy", + x=None, + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + ) + + with pytest.raises(ValueError, match="no primary datetime column"): + map_xy_config(config, dataset_id=42) + class TestGenerateChartName: """Test generate_chart_name function"""