diff --git a/superset/mcp_service/chart/chart_utils.py b/superset/mcp_service/chart/chart_utils.py index 633e3c6a9c6..d3801e66885 100644 --- a/superset/mcp_service/chart/chart_utils.py +++ b/superset/mcp_service/chart/chart_utils.py @@ -139,8 +139,9 @@ def map_table_config(config: TableChartConfig) -> Dict[str, Any]: if not raw_columns and not aggregated_metrics: raise ValueError("Table chart configuration resulted in no displayable columns") + # Use the viz_type from config (defaults to "table", can be "ag-grid-table") form_data: Dict[str, Any] = { - "viz_type": "table", + "viz_type": config.viz_type, } # Handle raw columns (no aggregation) @@ -370,7 +371,8 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit } viz_type = viz_type_map.get(kind, "echarts_timeseries_line") elif chart_type == "table": - viz_type = "table" + # Use the viz_type from config if available (table or ag-grid-table) + viz_type = getattr(config, "viz_type", "table") else: viz_type = "unknown" @@ -382,10 +384,11 @@ def analyze_chart_capabilities(chart: Any | None, config: Any) -> ChartCapabilit "echarts_timeseries_scatter", "deck_scatter", "deck_hex", + "ag-grid-table", # AG Grid tables are interactive ] supports_interaction = viz_type in interactive_types - supports_drill_down = viz_type in ["table", "pivot_table_v2"] + supports_drill_down = viz_type in ["table", "pivot_table_v2", "ag-grid-table"] supports_real_time = viz_type in [ "echarts_timeseries_line", "echarts_timeseries_bar", @@ -433,7 +436,8 @@ def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics: } viz_type = viz_type_map.get(kind, "echarts_timeseries_line") elif chart_type == "table": - viz_type = "table" + # Use the viz_type from config if available (table or ag-grid-table) + viz_type = getattr(config, "viz_type", "table") else: viz_type = "unknown" @@ -442,6 +446,10 @@ def analyze_chart_semantics(chart: Any | None, config: Any) -> ChartSemantics: "echarts_timeseries_line": "Shows trends and changes over time", "echarts_timeseries_bar": "Compares values across categories or time periods", "table": "Displays detailed data in tabular format", + "ag-grid-table": ( + "Interactive table with advanced features like column resizing, " + "sorting, filtering, and server-side pagination" + ), "pie": "Shows proportional relationships within a dataset", "echarts_area": "Emphasizes cumulative totals and part-to-whole relationships", } diff --git a/superset/mcp_service/chart/schemas.py b/superset/mcp_service/chart/schemas.py index 1f6a5a0d34a..f28d7e75f69 100644 --- a/superset/mcp_service/chart/schemas.py +++ b/superset/mcp_service/chart/schemas.py @@ -608,6 +608,14 @@ class TableChartConfig(BaseModel): chart_type: Literal["table"] = Field( ..., description="Chart type (REQUIRED: must be 'table')" ) + viz_type: Literal["table", "ag-grid-table"] = Field( + "table", + description=( + "Visualization type: 'table' for standard table, 'ag-grid-table' for " + "AG Grid Interactive Table with advanced features like column resizing, " + "sorting, filtering, and server-side pagination" + ), + ) columns: List[ColumnRef] = Field( ..., min_length=1, diff --git a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py index 7292220cd75..5bbae63912b 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_schemas.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_schemas.py @@ -54,6 +54,59 @@ class TestTableChartConfig: ) assert len(config.columns) == 2 + def test_default_viz_type_is_table(self) -> None: + """Test that default viz_type is 'table'.""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + ) + assert config.viz_type == "table" + + def test_ag_grid_table_viz_type_accepted(self) -> None: + """Test that viz_type='ag-grid-table' is accepted for AG Grid table.""" + config = TableChartConfig( + chart_type="table", + viz_type="ag-grid-table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="sales", aggregate="SUM", label="Total Sales"), + ], + ) + assert config.viz_type == "ag-grid-table" + assert len(config.columns) == 2 + + def test_ag_grid_table_with_all_options(self) -> None: + """Test AG Grid table with filters and sorting.""" + from superset.mcp_service.chart.schemas import FilterConfig + + config = TableChartConfig( + chart_type="table", + viz_type="ag-grid-table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="quantity", aggregate="SUM", label="Total Quantity"), + ColumnRef(name="sales", aggregate="SUM", label="Total Sales"), + ], + filters=[FilterConfig(column="status", op="=", value="active")], + sort_by=["product_line"], + ) + assert config.viz_type == "ag-grid-table" + assert len(config.columns) == 3 + assert config.filters is not None + assert len(config.filters) == 1 + assert config.sort_by == ["product_line"] + + def test_invalid_viz_type_rejected(self) -> None: + """Test that invalid viz_type values are rejected.""" + from pydantic import ValidationError + + with pytest.raises(ValidationError): + TableChartConfig( + chart_type="table", + viz_type="invalid-type", + columns=[ColumnRef(name="product")], + ) + class TestXYChartConfig: """Test XYChartConfig validation.""" diff --git a/tests/unit_tests/mcp_service/chart/test_chart_utils.py b/tests/unit_tests/mcp_service/chart/test_chart_utils.py index 560af4e8632..4f4f3a3857b 100644 --- a/tests/unit_tests/mcp_service/chart/test_chart_utils.py +++ b/tests/unit_tests/mcp_service/chart/test_chart_utils.py @@ -152,6 +152,58 @@ class TestMapTableConfig: result = map_table_config(config) assert result["order_by_cols"] == ["product", "revenue"] + def test_map_table_config_ag_grid_table(self) -> None: + """Test table config mapping with AG Grid Interactive Table viz_type""" + config = TableChartConfig( + chart_type="table", + viz_type="ag-grid-table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="sales", aggregate="SUM", label="Total Sales"), + ], + ) + + result = map_table_config(config) + + # AG Grid tables use 'ag-grid-table' viz_type + assert result["viz_type"] == "ag-grid-table" + assert result["query_mode"] == "aggregate" + assert len(result["metrics"]) == 1 + assert result["metrics"][0]["aggregate"] == "SUM" + # Non-aggregated columns should be in groupby + assert "groupby" in result + assert "product_line" in result["groupby"] + + def test_map_table_config_ag_grid_raw_mode(self) -> None: + """Test AG Grid table with raw columns (no aggregates)""" + config = TableChartConfig( + chart_type="table", + viz_type="ag-grid-table", + columns=[ + ColumnRef(name="product_line"), + ColumnRef(name="category"), + ColumnRef(name="region"), + ], + ) + + result = map_table_config(config) + + assert result["viz_type"] == "ag-grid-table" + assert result["query_mode"] == "raw" + assert result["all_columns"] == ["product_line", "category", "region"] + assert "metrics" not in result + + def test_map_table_config_default_viz_type(self) -> None: + """Test that default viz_type is 'table' not 'ag-grid-table'""" + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="product")], + ) + + result = map_table_config(config) + + assert result["viz_type"] == "table" + class TestMapXYConfig: """Test map_xy_config function"""