diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index bdaae915703..ece8e1310a3 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -30,6 +30,7 @@ from superset.mcp_service.chart.schemas import ( ChartCapabilities, ChartSemantics, ColumnRef, + FilterConfig, MixedTimeseriesChartConfig, PieChartConfig, PivotTableChartConfig, @@ -326,6 +327,24 @@ def map_config_to_form_data( raise ValueError(f"Unsupported config type: {type(config)}") +def _add_adhoc_filters( + form_data: Dict[str, Any], filters: list[FilterConfig] | None +) -> None: + """Add adhoc filters to form_data if any are specified.""" + if filters: + form_data["adhoc_filters"] = [ + { + "clause": "WHERE", + "expressionType": "SIMPLE", + "subject": filter_config.column, + "operator": map_filter_operator(filter_config.op), + "comparator": filter_config.value, + } + for filter_config in filters + if filter_config is not None + ] + + def map_table_config(config: TableChartConfig) -> Dict[str, Any]: """Map table chart config to form_data with defensive validation.""" # Early validation to prevent empty charts @@ -362,7 +381,6 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: "query_mode": "raw", "include_time": False, "order_desc": True, - "row_limit": 1000, # Reasonable limit for raw data } ) @@ -388,22 +406,13 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: } ) - if config.filters: - form_data["adhoc_filters"] = [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "subject": filter_config.column, - "operator": map_filter_operator(filter_config.op), - "comparator": filter_config.value, - } - for filter_config in config.filters - if filter_config is not None - ] + _add_adhoc_filters(form_data, config.filters) if config.sort_by: form_data["order_by_cols"] = config.sort_by + form_data["row_limit"] = config.row_limit + return form_data @@ -574,19 +583,9 @@ def map_xy_config( if groupby_columns: form_data["groupby"] = groupby_columns - # Add filters if specified - if config.filters: - form_data["adhoc_filters"] = [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "subject": filter_config.column, - "operator": map_filter_operator(filter_config.op), - "comparator": filter_config.value, - } - for filter_config in config.filters - if filter_config is not None - ] + _add_adhoc_filters(form_data, config.filters) + + form_data["row_limit"] = config.row_limit # Add stacking configuration if getattr(config, "stacked", False): @@ -623,18 +622,7 @@ def map_pie_config(config: PieChartConfig) -> Dict[str, Any]: "date_format": "smart_date", } - if config.filters: - form_data["adhoc_filters"] = [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "subject": filter_config.column, - "operator": map_filter_operator(filter_config.op), - "comparator": filter_config.value, - } - for filter_config in config.filters - if filter_config is not None - ] + _add_adhoc_filters(form_data, config.filters) return form_data @@ -667,18 +655,7 @@ def map_pivot_table_config(config: PivotTableChartConfig) -> Dict[str, Any]: "row_limit": config.row_limit, } - if config.filters: - form_data["adhoc_filters"] = [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "subject": filter_config.column, - "operator": map_filter_operator(filter_config.op), - "comparator": filter_config.value, - } - for filter_config in config.filters - if filter_config is not None - ] + _add_adhoc_filters(form_data, config.filters) return form_data @@ -772,21 +749,11 @@ def map_mixed_timeseries_config( if config.group_by_secondary and config.group_by_secondary.name != config.x.name: form_data["groupby_b"] = [config.group_by_secondary.name] + form_data["row_limit"] = config.row_limit + _add_mixed_axis_config(form_data, config) - # Filters - if config.filters: - form_data["adhoc_filters"] = [ - { - "clause": "WHERE", - "expressionType": "SIMPLE", - "subject": filter_config.column, - "operator": map_filter_operator(filter_config.op), - "comparator": filter_config.value, - } - for filter_config in config.filters - if filter_config is not None - ] + _add_adhoc_filters(form_data, config.filters) return form_data @@ -820,7 +787,7 @@ def _humanize_column(col: ColumnRef) -> str: def _summarize_filters( - filters: list[Any] | None, + filters: list[FilterConfig] | None, ) -> str | None: """Extract a short context string from filter configs.""" if not filters: diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 373f667e586..91eeac5fc6b 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -588,6 +588,7 @@ class MixedTimeseriesChartConfig(BaseModel): y_axis: AxisConfig | None = None y_axis_secondary: AxisConfig | None = None filters: List[FilterConfig] | None = None + row_limit: int = Field(10000, description="Max data points", ge=1, le=50000) class TableChartConfig(BaseModel): @@ -604,6 +605,7 @@ class TableChartConfig(BaseModel): ) filters: List[FilterConfig] | None = None sort_by: List[str] | None = None + row_limit: int = Field(1000, description="Max rows returned", ge=1, le=50000) @model_validator(mode="after") def validate_unique_column_labels(self) -> "TableChartConfig": @@ -656,6 +658,7 @@ class XYChartConfig(BaseModel): y_axis: AxisConfig | None = None legend: LegendConfig | None = None filters: List[FilterConfig] | None = None + row_limit: int = Field(10000, description="Max data points", ge=1, le=50000) @model_validator(mode="after") def validate_unique_column_labels(self) -> "XYChartConfig": diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index cea45bf0244..f0b35f74376 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -299,6 +299,78 @@ class TestXYChartConfig: assert config.orientation == "horizontal" +class TestRowLimit: + """Test row_limit field on chart configs.""" + + def test_xy_chart_default_row_limit(self) -> None: + """Test that XYChartConfig has default row_limit of 10000.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + ) + assert config.row_limit == 10000 + + def test_xy_chart_custom_row_limit(self) -> None: + """Test that XYChartConfig accepts custom row_limit.""" + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + row_limit=100, + ) + assert config.row_limit == 100 + + def test_xy_chart_row_limit_validation(self) -> None: + """Test that XYChartConfig rejects invalid row_limit.""" + with pytest.raises(ValidationError): + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + row_limit=0, + ) + with pytest.raises(ValidationError): + XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + row_limit=100000, + ) + + def test_table_chart_default_row_limit(self) -> None: + """Test that TableChartConfig has default row_limit of 1000.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + ) + assert config.row_limit == 1000 + + def test_table_chart_custom_row_limit(self) -> None: + """Test that TableChartConfig accepts custom row_limit.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + row_limit=500, + ) + assert config.row_limit == 500 + + def test_table_chart_row_limit_validation(self) -> None: + """Test that TableChartConfig rejects invalid row_limit.""" + with pytest.raises(ValidationError): + TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + row_limit=0, + ) + with pytest.raises(ValidationError): + TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + row_limit=100000, + ) + + class TestTableChartConfigExtraFields: """Test TableChartConfig rejects unknown fields.""" diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index ce5c618e03b..480399f5cec 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -23,6 +23,7 @@ from unittest.mock import MagicMock, patch import pytest from superset.mcp_service.chart.chart_utils import ( + _add_adhoc_filters, configure_temporal_handling, create_metric_object, generate_chart_name, @@ -313,6 +314,66 @@ class TestMapTableConfig: assert result["viz_type"] == "table" + def test_map_table_config_row_limit(self) -> None: + """Test that row_limit is mapped to form_data.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + row_limit=500, + ) + + result = map_table_config(config) + + assert result["row_limit"] == 500 + + def test_map_table_config_default_row_limit(self) -> None: + """Test that default row_limit is mapped to form_data.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product", aggregate="SUM")], + ) + + result = map_table_config(config) + + assert result["row_limit"] == 1000 + + +class TestAddAdhocFilters: + """Test _add_adhoc_filters helper function""" + + def test_adds_filters_to_form_data(self) -> None: + """Test that filters are correctly added to form_data.""" + form_data: dict[str, Any] = {} + filters = [ + FilterConfig(column="region", op="=", value="US"), + FilterConfig(column="year", op=">", value=2020), + ] + + _add_adhoc_filters(form_data, filters) + + assert "adhoc_filters" in form_data + assert len(form_data["adhoc_filters"]) == 2 + assert form_data["adhoc_filters"][0]["subject"] == "region" + assert form_data["adhoc_filters"][0]["operator"] == "==" + assert form_data["adhoc_filters"][1]["subject"] == "year" + assert form_data["adhoc_filters"][1]["operator"] == ">" + + def test_no_filters_does_nothing(self) -> None: + """Test that None filters leave form_data unchanged.""" + form_data: dict[str, Any] = {"viz_type": "table"} + + _add_adhoc_filters(form_data, None) + + assert "adhoc_filters" not in form_data + + def test_empty_list_does_nothing(self) -> None: + """Test that empty filter list leaves form_data unchanged.""" + form_data: dict[str, Any] = {"viz_type": "table"} + + _add_adhoc_filters(form_data, []) + + assert "adhoc_filters" not in form_data + class TestMapXYConfig: """Test map_xy_config function""" @@ -538,6 +599,57 @@ class TestMapXYConfig: assert result["stack"] == "Stack" assert result["groupby"] == ["level"] + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_map_xy_config_with_filters(self, mock_is_temporal) -> None: + """Test that filters are mapped to adhoc_filters in XY form_data.""" + mock_is_temporal.return_value = True + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + filters=[FilterConfig(column="region", op="=", value="US")], + ) + + result = map_xy_config(config) + + assert "adhoc_filters" in result + assert len(result["adhoc_filters"]) == 1 + assert result["adhoc_filters"][0]["subject"] == "region" + assert result["adhoc_filters"][0]["operator"] == "==" + assert result["adhoc_filters"][0]["comparator"] == "US" + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_map_xy_config_row_limit(self, mock_is_temporal) -> None: + """Test that row_limit is mapped to form_data.""" + mock_is_temporal.return_value = True + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="line", + row_limit=250, + ) + + result = map_xy_config(config) + + assert result["row_limit"] == 250 + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_map_xy_config_default_row_limit(self, mock_is_temporal) -> None: + """Test that default row_limit is mapped to form_data.""" + mock_is_temporal.return_value = True + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + kind="bar", + ) + + result = map_xy_config(config) + + assert result["row_limit"] == 10000 + class TestMapConfigToFormData: """Test map_config_to_form_data function""" diff --git a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py index e2e37ba99e6..9221d2d2bf1 100644 --- a/tests/unit_tests/mcp_service/chart/test_new_chart_types.py +++ b/tests/unit_tests/mcp_service/chart/test_new_chart_types.py @@ -512,6 +512,25 @@ class TestMixedTimeseriesChartConfigSchema: unknown_field="bad", ) + def test_mixed_timeseries_default_row_limit(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + assert config.row_limit == 10000 + + def test_mixed_timeseries_custom_row_limit(self) -> None: + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + row_limit=500, + ) + assert config.row_limit == 500 + # ============================================================ # Mixed Timeseries Form Data Mapping Tests @@ -636,6 +655,35 @@ class TestMapMixedTimeseriesConfig: assert result["y_axis_format_secondary"] == ",d" assert result["logAxisSecondary"] is True + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_row_limit(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + row_limit=300, + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["row_limit"] == 300 + + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") + def test_mixed_form_data_default_row_limit(self, mock_is_temporal) -> None: + mock_is_temporal.return_value = True + + config = MixedTimeseriesChartConfig( + chart_type="mixed_timeseries", + x=ColumnRef(name="date"), + y=[ColumnRef(name="revenue", aggregate="SUM")], + y_secondary=[ColumnRef(name="orders", aggregate="COUNT")], + ) + result = map_mixed_timeseries_config(config, dataset_id=1) + + assert result["row_limit"] == 10000 + @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_mixed_form_data_with_filters(self, mock_is_temporal) -> None: mock_is_temporal.return_value = True