# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""Tests for chart utilities module.

Covers the form-data mapping helpers, chart naming, explore-link generation
and temporal-column detection in ``superset.mcp_service.chart.chart_utils``.
"""

from typing import Any
from unittest.mock import MagicMock, patch
from urllib.parse import urlparse

import pytest
from sqlalchemy.exc import SQLAlchemyError

from superset.mcp_service.chart.chart_utils import (
    configure_temporal_handling,
    create_metric_object,
    generate_chart_name,
    generate_explore_link,
    is_column_truly_temporal,
    map_config_to_form_data,
    map_filter_operator,
    map_table_config,
    map_xy_config,
)
from superset.mcp_service.chart.schemas import (
    AxisConfig,
    ColumnRef,
    FilterConfig,
    LegendConfig,
    TableChartConfig,
    XYChartConfig,
)
from superset.utils.core import ColumnSpec, GenericDataType


class TestCreateMetricObject:
    """Test create_metric_object function"""

    def test_create_metric_object_with_aggregate(self) -> None:
        """Test creating metric object with specified aggregate"""
        col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
        result = create_metric_object(col)

        assert result["aggregate"] == "SUM"
        assert result["column"]["column_name"] == "revenue"
        assert result["label"] == "Total Revenue"
        assert result["optionName"] == "metric_revenue"
        assert result["expressionType"] == "SIMPLE"

    def test_create_metric_object_default_aggregate(self) -> None:
        """Test creating metric object with default aggregate"""
        col = ColumnRef(name="orders")
        result = create_metric_object(col)

        assert result["aggregate"] == "SUM"
        assert result["column"]["column_name"] == "orders"
        assert result["label"] == "SUM(orders)"
        assert result["optionName"] == "metric_orders"


class TestMapFilterOperator:
    """Test map_filter_operator function"""

    # Parametrized so each operator is reported as its own test case.
    @pytest.mark.parametrize(
        ("operator", "expected"),
        [
            ("=", "=="),
            (">", ">"),
            ("<", "<"),
            (">=", ">="),
            ("<=", "<="),
            ("!=", "!="),
        ],
    )
    def test_map_filter_operators(self, operator: str, expected: str) -> None:
        """Test mapping of various filter operators"""
        assert map_filter_operator(operator) == expected

    def test_map_filter_operator_unknown(self) -> None:
        """Test mapping of unknown operator returns original"""
        assert map_filter_operator("UNKNOWN") == "UNKNOWN"


class TestMapTableConfig:
    """Test map_table_config function"""

    def test_map_table_config_basic(self) -> None:
        """Test basic table config mapping with aggregated columns"""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product", aggregate="COUNT"),
                ColumnRef(name="revenue", aggregate="SUM"),
            ],
        )

        result = map_table_config(config)

        assert result["viz_type"] == "table"
        assert result["query_mode"] == "aggregate"
        # Aggregated columns should be in metrics, not all_columns
        assert "all_columns" not in result
        assert len(result["metrics"]) == 2
        assert result["metrics"][0]["aggregate"] == "COUNT"
        assert result["metrics"][1]["aggregate"] == "SUM"

    def test_map_table_config_raw_columns(self) -> None:
        """Test table config mapping with raw columns (no aggregates)"""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="category"),
            ],
        )

        result = map_table_config(config)

        assert result["viz_type"] == "table"
        assert result["query_mode"] == "raw"
        # Raw columns should be in all_columns
        assert result["all_columns"] == ["product", "category"]
        assert "metrics" not in result

    def test_map_table_config_with_filters(self) -> None:
        """Test table config mapping with filters"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            filters=[FilterConfig(column="status", op="=", value="active")],
        )

        result = map_table_config(config)

        assert "adhoc_filters" in result
        assert len(result["adhoc_filters"]) == 1
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["subject"] == "status"
        assert filter_obj["operator"] == "=="
        assert filter_obj["comparator"] == "active"
        assert filter_obj["expressionType"] == "SIMPLE"

    def test_map_table_config_with_sort(self) -> None:
        """Test table config mapping with sort"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            sort_by=["product", "revenue"],
        )

        result = map_table_config(config)
        assert result["order_by_cols"] == ["product", "revenue"]

    def test_map_table_config_ag_grid_table(self) -> None:
        """Test table config mapping with AG Grid Interactive Table viz_type"""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )

        result = map_table_config(config)

        # AG Grid tables use 'ag-grid-table' viz_type
        assert result["viz_type"] == "ag-grid-table"
        assert result["query_mode"] == "aggregate"
        assert len(result["metrics"]) == 1
        assert result["metrics"][0]["aggregate"] == "SUM"
        # Non-aggregated columns should be in groupby
        assert "groupby" in result
        assert "product_line" in result["groupby"]

    def test_map_table_config_ag_grid_raw_mode(self) -> None:
        """Test AG Grid table with raw columns (no aggregates)"""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="category"),
                ColumnRef(name="region"),
            ],
        )

        result = map_table_config(config)

        assert result["viz_type"] == "ag-grid-table"
        assert result["query_mode"] == "raw"
        assert result["all_columns"] == ["product_line", "category", "region"]
        assert "metrics" not in result

    def test_map_table_config_default_viz_type(self) -> None:
        """Test that default viz_type is 'table' not 'ag-grid-table'"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
        )

        result = map_table_config(config)

        assert result["viz_type"] == "table"


class TestMapXYConfig:
    """Test map_xy_config function"""

    def test_map_xy_config_line_chart(self) -> None:
        """Test XY config mapping for line chart"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_timeseries_line"
        assert result["x_axis"] == "date"
        assert len(result["metrics"]) == 1
        assert result["metrics"][0]["aggregate"] == "SUM"

    def test_map_xy_config_with_groupby(self) -> None:
        """Test XY config mapping with group by"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="bar",
            group_by=ColumnRef(name="region"),
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["groupby"] == ["region"]

    def test_map_xy_config_with_axes(self) -> None:
        """Test XY config mapping with axis configurations"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="area",
            x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
            y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_area"
        assert result["x_axis_title"] == "Date"
        assert result["x_axis_format"] == "%Y-%m-%d"
        assert result["y_axis_title"] == "Revenue"
        assert result["y_axis_format"] == "$,.2f"
        assert result["y_axis_scale"] == "log"

    def test_map_xy_config_with_legend(self) -> None:
        """Test XY config mapping with legend configuration"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="scatter",
            legend=LegendConfig(show=False, position="top"),
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_timeseries_scatter"
        assert result["show_legend"] is False
        assert result["legend_orientation"] == "top"

    def test_map_xy_config_with_time_grain_month(self) -> None:
        """Test XY config mapping with monthly time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="bar",
            time_grain="P1M",
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["x_axis"] == "order_date"
        assert result["time_grain_sqla"] == "P1M"

    def test_map_xy_config_with_time_grain_day(self) -> None:
        """Test XY config mapping with daily time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="created_at"),
            y=[ColumnRef(name="count", aggregate="COUNT")],
            kind="line",
            time_grain="P1D",
        )

        result = map_xy_config(config)

        assert result["viz_type"] == "echarts_timeseries_line"
        assert result["x_axis"] == "created_at"
        assert result["time_grain_sqla"] == "P1D"

    def test_map_xy_config_with_time_grain_hour(self) -> None:
        """Test XY config mapping with hourly time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="timestamp"),
            y=[ColumnRef(name="requests", aggregate="SUM")],
            kind="area",
            time_grain="PT1H",
        )

        result = map_xy_config(config)

        assert result["time_grain_sqla"] == "PT1H"

    def test_map_xy_config_without_time_grain(self) -> None:
        """Test XY config mapping without time grain (no time_grain_sqla set)"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="line",
        )

        result = map_xy_config(config)

        assert "time_grain_sqla" not in result

    def test_map_xy_config_with_time_grain_and_groupby(self) -> None:
        """Test XY config mapping with time grain and group by"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            time_grain="P1W",
            group_by=ColumnRef(name="category"),
        )

        result = map_xy_config(config)

        assert result["time_grain_sqla"] == "P1W"
        assert result["groupby"] == ["category"]
        assert result["x_axis"] == "order_date"


class TestMapConfigToFormData:
    """Test map_config_to_form_data function"""

    def test_map_table_config_type(self) -> None:
        """Test mapping table config type"""
        config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="test")])

        result = map_config_to_form_data(config)

        assert result["viz_type"] == "table"

    def test_map_xy_config_type(self) -> None:
        """Test mapping XY config type"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="line",
        )

        result = map_config_to_form_data(config)

        assert result["viz_type"] == "echarts_timeseries_line"

    def test_map_unsupported_config_type(self) -> None:
        """Test mapping unsupported config type raises error"""
        with pytest.raises(ValueError, match="Unsupported config type"):
            map_config_to_form_data("invalid_config")  # type: ignore


class TestGenerateChartName:
    """Test generate_chart_name function"""

    def test_generate_table_chart_name(self) -> None:
        """Test generating name for table chart"""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="revenue"),
            ],
        )

        result = generate_chart_name(config)

        assert result == "Table Chart - product, revenue"

    def test_generate_xy_chart_name(self) -> None:
        """Test generating name for XY chart"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue"), ColumnRef(name="orders")],
            kind="line",
        )

        result = generate_chart_name(config)

        assert result == "Line Chart - date vs revenue, orders"

    def test_generate_chart_name_unsupported(self) -> None:
        """Test generating name for unsupported config type"""
        result = generate_chart_name("invalid_config")  # type: ignore
        assert result == "Chart"


class TestGenerateExploreLink:
    """Test generate_explore_link function"""

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_uses_base_url(
        self, mock_get_base_url: MagicMock
    ) -> None:
        """Test that generate_explore_link uses the configured base URL"""
        mock_get_base_url.return_value = "https://superset.example.com"
        form_data = {"viz_type": "table", "metrics": ["count"]}

        # Mock dataset not found to trigger fallback URL
        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            result = generate_explore_link("123", form_data)

        # Should use the configured base URL - use urlparse to avoid CodeQL warning
        parsed_url = urlparse(result)
        expected_netloc = "superset.example.com"
        assert parsed_url.scheme == "https"
        assert parsed_url.netloc == expected_netloc
        assert "/explore/" in parsed_url.path
        assert "datasource_id=123" in result

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_fallback_url(
        self, mock_get_base_url: MagicMock
    ) -> None:
        """Test generate_explore_link returns fallback URL when dataset not found"""
        mock_get_base_url.return_value = "http://localhost:9001"
        form_data = {"viz_type": "table"}

        # Mock dataset not found scenario
        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            result = generate_explore_link("999", form_data)

        assert (
            result
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
        )

    # Decorators apply bottom-up, so the innermost patch (the command) is
    # passed as the first mock argument.
    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    @patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
    def test_generate_explore_link_with_form_data_key(
        self, mock_command: MagicMock, mock_get_base_url: MagicMock
    ) -> None:
        """Test generate_explore_link creates form_data_key when dataset exists"""
        mock_get_base_url.return_value = "http://localhost:9001"
        mock_command.return_value.run.return_value = "test_form_data_key"

        # Mock dataset exists
        mock_dataset = type("Dataset", (), {"id": 123})()

        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_dataset
        ):
            result = generate_explore_link(123, {"viz_type": "table"})

        assert (
            result == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
        )
        mock_command.assert_called_once()

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_exception_handling(
        self, mock_get_base_url: MagicMock
    ) -> None:
        """Test generate_explore_link handles SQLAlchemy exceptions gracefully"""
        mock_get_base_url.return_value = "http://localhost:9001"

        # Mock SQLAlchemy exception during dataset lookup
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id",
            side_effect=SQLAlchemyError("DB error"),
        ):
            result = generate_explore_link("123", {"viz_type": "table"})

        # Should fallback to basic URL
        assert (
            result
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
        )


class TestCriticalBugFixes:
    """Test critical bug fixes for chart utilities."""

    def test_time_series_aggregation_fix(self) -> None:
        """Test that time series charts preserve temporal dimension."""
        # Create a time series chart configuration
        config = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
        )

        form_data = map_xy_config(config)

        # Verify the fix: x_axis should be set correctly
        assert form_data["x_axis"] == "order_date"

        # Verify the fix: groupby should not duplicate x_axis
        # This prevents the "Duplicate column/metric labels" error
        assert "groupby" not in form_data or "order_date" not in form_data.get(
            "groupby", []
        )

        # Verify chart type mapping
        assert form_data["viz_type"] == "echarts_timeseries_line"

    def test_time_series_with_explicit_group_by(self) -> None:
        """Test time series with explicit group_by different from x_axis."""
        config = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            group_by=ColumnRef(name="category"),
        )

        form_data = map_xy_config(config)

        # Verify x_axis is set
        assert form_data["x_axis"] == "order_date"

        # Verify groupby only contains the explicit group_by, not x_axis
        assert form_data.get("groupby") == ["category"]
        assert "order_date" not in form_data.get("groupby", [])

    def test_duplicate_label_prevention(self) -> None:
        """Test that duplicate column/metric labels are prevented."""
        # This configuration would previously cause:
        # "Duplicate column/metric labels: 'price_each'"
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="price_each", label="Price Each"),  # Custom label
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
            ],
            group_by=ColumnRef(name="price_each"),  # Same column as x_axis
            kind="line",
        )

        form_data = map_xy_config(config)

        # Verify the fix: x_axis is set
        assert form_data["x_axis"] == "price_each"

        # Verify the fix: groupby is empty because group_by == x_axis
        # This prevents the duplicate label error
        assert "groupby" not in form_data or not form_data["groupby"]

    def test_metric_object_creation_with_labels(self) -> None:
        """Test that metric objects are created correctly with proper labels."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG"),  # No custom label
            ],
            kind="bar",
        )

        form_data = map_xy_config(config)

        # Verify metrics are created correctly
        metrics = form_data["metrics"]
        assert len(metrics) == 2

        # First metric with custom label
        assert metrics[0]["label"] == "Total Sales"
        assert metrics[0]["aggregate"] == "SUM"
        assert metrics[0]["column"]["column_name"] == "sales"

        # Second metric with auto-generated label
        assert metrics[1]["label"] == "AVG(profit)"
        assert metrics[1]["aggregate"] == "AVG"
        assert metrics[1]["column"]["column_name"] == "profit"

    # Parametrized (rather than looping inside one test) so a wrong mapping
    # for one kind does not mask the results for the remaining kinds.
    @pytest.mark.parametrize(
        ("kind", "expected_viz_type"),
        [
            ("line", "echarts_timeseries_line"),
            ("bar", "echarts_timeseries_bar"),
            ("area", "echarts_area"),
            ("scatter", "echarts_timeseries_scatter"),
        ],
    )
    def test_chart_type_mapping_comprehensive(
        self, kind: str, expected_viz_type: str
    ) -> None:
        """Test that chart types are mapped correctly to Superset viz types."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="value", aggregate="SUM")],
            kind=kind,
        )

        form_data = map_xy_config(config)
        assert form_data["viz_type"] == expected_viz_type


class TestIsColumnTrulyTemporal:
    """Test is_column_truly_temporal function using db_engine_spec"""

    def _create_mock_dataset(
        self,
        column_name: str,
        column_type: str,
        generic_type: GenericDataType,
    ) -> MagicMock:
        """Helper to create a mock dataset with proper db_engine_spec.

        :param column_name: name given to the dataset's single mock column
        :param column_type: raw type string assigned to that column
        :param generic_type: generic type returned by the mocked
            ``db_engine_spec.get_column_spec``
        :return: a mock dataset whose ``columns`` holds one column and whose
            ``database.db_engine_spec`` returns a fixed ``ColumnSpec``
        """
        mock_column = MagicMock()
        mock_column.column_name = column_name
        mock_column.type = column_type

        # The spec's is_dttm is fixed to False; tests vary only generic_type.
        mock_db_engine_spec = MagicMock()
        mock_column_spec = ColumnSpec(
            sqla_type=MagicMock(), generic_type=generic_type, is_dttm=False
        )
        mock_db_engine_spec.get_column_spec.return_value = mock_column_spec

        mock_database = MagicMock()
        mock_database.db_engine_spec = mock_db_engine_spec

        mock_dataset = MagicMock()
        mock_dataset.columns = [mock_column]
        mock_dataset.database = mock_database

        return mock_dataset

    def test_returns_true_when_no_dataset_id(self) -> None:
        """Test returns True (default) when dataset_id is None"""
        result = is_column_truly_temporal("year", None)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_when_dataset_not_found(self, mock_dao: MagicMock) -> None:
        """Test returns True when dataset is not found"""
        mock_dao.find_by_id.return_value = None

        result = is_column_truly_temporal("year", 123)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_false_for_numeric_column(self, mock_dao: MagicMock) -> None:
        """Test returns False for NUMERIC generic type (e.g., BIGINT)"""
        mock_dataset = self._create_mock_dataset(
            "year", "BIGINT", GenericDataType.NUMERIC
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("year", 123)
        assert result is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_false_for_integer_column(self, mock_dao: MagicMock) -> None:
        """Test returns False for INTEGER column (NUMERIC generic type)"""
        mock_dataset = self._create_mock_dataset(
            "month", "INTEGER", GenericDataType.NUMERIC
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("month", 123)
        assert result is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_for_temporal_column(self, mock_dao: MagicMock) -> None:
        """Test returns True for TEMPORAL generic type (e.g., TIMESTAMP)"""
        mock_dataset = self._create_mock_dataset(
            "created_at", "TIMESTAMP", GenericDataType.TEMPORAL
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("created_at", 123)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_for_date_column(self, mock_dao: MagicMock) -> None:
        """Test returns True for DATE column (TEMPORAL generic type)"""
        mock_dataset = self._create_mock_dataset(
            "order_date", "DATE", GenericDataType.TEMPORAL
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("order_date", 123)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_case_insensitive_column_name_lookup(self, mock_dao: MagicMock) -> None:
        """Test column name lookup is case insensitive"""
        mock_dataset = self._create_mock_dataset(
            "Year", "BIGINT", GenericDataType.NUMERIC
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("year", 123)
        assert result is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_on_value_error(self, mock_dao: MagicMock) -> None:
        """Test returns True (default) when ValueError occurs"""
        mock_dao.find_by_id.side_effect = ValueError("Invalid ID")

        result = is_column_truly_temporal("year", 123)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_returns_true_on_attribute_error(self, mock_dao: MagicMock) -> None:
        """Test returns True (default) when AttributeError occurs"""
        mock_dao.find_by_id.side_effect = AttributeError("Missing attribute")

        result = is_column_truly_temporal("year", 123)
        assert result is True

    @patch("superset.daos.dataset.DatasetDAO")
    def test_handles_uuid_dataset_id(self, mock_dao: MagicMock) -> None:
        """Test handles UUID string as dataset_id"""
        mock_dataset = self._create_mock_dataset(
            "year", "BIGINT", GenericDataType.NUMERIC
        )
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("year", "abc-123-uuid")
        assert result is False
        mock_dao.find_by_id.assert_called_with("abc-123-uuid", id_column="uuid")

    @patch("superset.daos.dataset.DatasetDAO")
    def test_falls_back_to_is_dttm_when_no_column_spec(
        self, mock_dao: MagicMock
    ) -> None:
        """Test falls back to is_dttm flag when get_column_spec returns None"""
        mock_column = MagicMock()
        mock_column.column_name = "year"
        mock_column.type = "UNKNOWN_TYPE"
        mock_column.is_dttm = False

        mock_db_engine_spec = MagicMock()
        mock_db_engine_spec.get_column_spec.return_value = None

        mock_database = MagicMock()
        mock_database.db_engine_spec = mock_db_engine_spec

        mock_dataset = MagicMock()
        mock_dataset.columns = [mock_column]
        mock_dataset.database = mock_database
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("year", 123)
        assert result is False

    @patch("superset.daos.dataset.DatasetDAO")
    def test_falls_back_to_is_dttm_when_no_type(self, mock_dao: MagicMock) -> None:
        """Test falls back to is_dttm flag when column has no type"""
        mock_column = MagicMock()
        mock_column.column_name = "year"
        mock_column.type = None
        mock_column.is_dttm = True

        mock_dataset = MagicMock()
        mock_dataset.columns = [mock_column]
        mock_dao.find_by_id.return_value = mock_dataset

        result = is_column_truly_temporal("year", 123)
        assert result is True


class TestConfigureTemporalHandling:
    """Test configure_temporal_handling function"""

    def test_temporal_column_with_time_grain(self) -> None:
        """Test temporal column sets time_grain_sqla"""
        form_data: dict[str, Any] = {}
        configure_temporal_handling(form_data, x_is_temporal=True, time_grain="P1M")

        assert form_data["time_grain_sqla"] == "P1M"

    def test_temporal_column_without_time_grain(self) -> None:
        """Test temporal column without time_grain doesn't set time_grain_sqla"""
        form_data: dict[str, Any] = {}
        configure_temporal_handling(form_data, x_is_temporal=True, time_grain=None)

        assert "time_grain_sqla" not in form_data

    def test_non_temporal_column_sets_categorical_config(self) -> None:
        """Test non-temporal column sets categorical configuration"""
        form_data: dict[str, Any] = {}
        configure_temporal_handling(form_data, x_is_temporal=False, time_grain=None)

        assert form_data["x_axis_sort_series_type"] == "name"
        assert form_data["x_axis_sort_series_ascending"] is True
        assert form_data["time_grain_sqla"] is None
        assert form_data["granularity_sqla"] is None

    def test_non_temporal_column_ignores_time_grain(self) -> None:
        """Test non-temporal column ignores time_grain parameter"""
        form_data: dict[str, Any] = {}
        configure_temporal_handling(form_data, x_is_temporal=False, time_grain="P1M")

        # Should still set categorical config, not time_grain
        assert form_data["time_grain_sqla"] is None
        assert form_data["x_axis_sort_series_type"] == "name"


class TestMapXYConfigWithNonTemporalColumn:
    """Test map_xy_config with non-temporal columns (DATE_TRUNC fix)"""

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_non_temporal_column_disables_time_grain(
        self, mock_is_temporal: MagicMock
    ) -> None:
        """Test non-temporal column sets categorical config"""
        mock_is_temporal.return_value = False

        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="year"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
        )

        result = map_xy_config(config, dataset_id=123)

        assert result["x_axis"] == "year"
        assert result["x_axis_sort_series_type"] == "name"
        assert result["x_axis_sort_series_ascending"] is True
        assert result["time_grain_sqla"] is None
        assert result["granularity_sqla"] is None

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_temporal_column_allows_time_grain(
        self, mock_is_temporal: MagicMock
    ) -> None:
        """Test temporal column allows time_grain to be set"""
        mock_is_temporal.return_value = True

        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="created_at"),
            y=[ColumnRef(name="count", aggregate="COUNT")],
            kind="line",
            time_grain="P1W",
        )

        result = map_xy_config(config, dataset_id=123)

        assert result["x_axis"] == "created_at"
        assert result["time_grain_sqla"] == "P1W"
        assert "x_axis_sort_series_type" not in result

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_non_temporal_ignores_time_grain_param(
        self, mock_is_temporal: MagicMock
    ) -> None:
        """Test non-temporal column ignores time_grain even if specified"""
        mock_is_temporal.return_value = False

        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="year"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
            time_grain="P1M",  # This should be ignored for non-temporal
        )

        result = map_xy_config(config, dataset_id=123)

        # time_grain_sqla should be None, not P1M
        assert result["time_grain_sqla"] is None
        assert result["x_axis_sort_series_type"] == "name"