# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for chart utilities module"""

from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from superset.constants import NO_TIME_RANGE
from superset.mcp_service.chart.chart_utils import (
    _add_adhoc_filters,
    _ensure_temporal_adhoc_filter,
    adhoc_filters_to_query_filters,
    configure_temporal_handling,
    create_metric_object,
    generate_chart_name,
    generate_explore_link,
    is_column_truly_temporal,
    map_config_to_form_data,
    map_filter_operator,
    map_table_config,
    map_xy_config,
    validate_chart_dataset,
)
from superset.mcp_service.chart.schemas import (
    AxisConfig,
    ColumnRef,
    FilterConfig,
    LegendConfig,
    TableChartConfig,
    XYChartConfig,
)
from superset.utils.core import FilterOperator, GenericDataType


class TestCreateMetricObject:
    """Test create_metric_object function"""

    def test_create_metric_object_with_aggregate(self) -> None:
        """Test creating metric object with specified aggregate"""
        col = ColumnRef(name="revenue", aggregate="SUM", label="Total Revenue")
        result = create_metric_object(col)
        assert isinstance(result, dict)
        assert result["aggregate"] == "SUM"
        assert result["column"]["column_name"] == "revenue"
        assert result["label"] == "Total Revenue"
        assert result["optionName"] == "metric_revenue"
        assert result["expressionType"] == "SIMPLE"

    def test_create_metric_object_default_aggregate(self) -> None:
        """Test creating metric object with default aggregate"""
        col = ColumnRef(name="orders")
        result = create_metric_object(col)
        assert isinstance(result, dict)
        assert result["aggregate"] == "SUM"
        assert result["column"]["column_name"] == "orders"
        assert result["label"] == "SUM(orders)"
        assert result["optionName"] == "metric_orders"

    def test_create_metric_object_saved_metric_returns_string(self) -> None:
        """Test that saved metrics return a plain string metric name"""
        col = ColumnRef(name="total_revenue", saved_metric=True)
        result = create_metric_object(col)
        assert result == "total_revenue"
        assert isinstance(result, str)

    def test_create_metric_object_saved_metric_ignores_aggregate(self) -> None:
        """Test that saved metrics ignore aggregate even if somehow set"""
        col = ColumnRef(name="total_revenue", saved_metric=True, aggregate="SUM")
        result = create_metric_object(col)
        # saved_metric validator clears aggregate, result is plain string
        assert result == "total_revenue"


class TestMapFilterOperator:
    """Test map_filter_operator function"""

    def test_map_filter_operators(self) -> None:
        """Test mapping of various filter operators"""
        assert map_filter_operator("=") == "=="
        assert map_filter_operator(">") == ">"
        assert map_filter_operator("<") == "<"
        assert map_filter_operator(">=") == ">="
        assert map_filter_operator("<=") == "<="
        assert map_filter_operator("!=") == "!="

    def test_map_filter_operators_pattern_matching(self) -> None:
        """Test mapping of pattern matching operators"""
        assert map_filter_operator("LIKE") == "LIKE"
        assert map_filter_operator("ILIKE") == "ILIKE"
        assert map_filter_operator("NOT LIKE") == "NOT LIKE"

    def test_map_filter_operators_set(self) -> None:
        """Test mapping of set operators"""
        assert map_filter_operator("IN") == "IN"
        assert map_filter_operator("NOT IN") == "NOT IN"

    def test_map_filter_operator_unknown(self) -> None:
        """Test mapping of unknown operator returns original"""
        assert map_filter_operator("UNKNOWN") == "UNKNOWN"


class TestMapTableConfig:
    """Test map_table_config function"""

    def test_map_table_config_basic(self) -> None:
        """Test basic table config mapping with aggregated columns"""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product", aggregate="COUNT"),
                ColumnRef(name="revenue", aggregate="SUM"),
            ],
        )
        result = map_table_config(config)
        assert result["viz_type"] == "table"
        assert result["query_mode"] == "aggregate"
        # Aggregated columns should be in metrics, not all_columns
        assert "all_columns" not in result
        assert len(result["metrics"]) == 2
        assert result["metrics"][0]["aggregate"] == "COUNT"
        assert result["metrics"][1]["aggregate"] == "SUM"

    def test_map_table_config_raw_columns(self) -> None:
        """Test table config mapping with raw columns (no aggregates)"""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="category"),
            ],
        )
        result = map_table_config(config)
        assert result["viz_type"] == "table"
        assert result["query_mode"] == "raw"
        # Raw columns should be in all_columns
        assert result["all_columns"] == ["product", "category"]
        assert "metrics" not in result

    def test_map_table_config_with_filters(self) -> None:
        """Test table config mapping with filters"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            filters=[FilterConfig(column="status", op="=", value="active")],
        )
        result = map_table_config(config)
        assert "adhoc_filters" in result
        assert len(result["adhoc_filters"]) == 1
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["subject"] == "status"
        assert filter_obj["operator"] == "=="
        assert filter_obj["comparator"] == "active"
        assert filter_obj["expressionType"] == "SIMPLE"

    def test_map_table_config_with_like_filter(self) -> None:
        """Test table config mapping with LIKE filter for pattern matching"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="name")],
            filters=[FilterConfig(column="name", op="LIKE", value="%mario%")],
        )
        result = map_table_config(config)
        assert "adhoc_filters" in result
        assert len(result["adhoc_filters"]) == 1
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["subject"] == "name"
        assert filter_obj["operator"] == "LIKE"
        assert filter_obj["comparator"] == "%mario%"
        assert filter_obj["expressionType"] == "SIMPLE"

    def test_map_table_config_with_ilike_filter(self) -> None:
        """Test table config mapping with ILIKE filter for case-insensitive matching"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="name")],
            filters=[FilterConfig(column="name", op="ILIKE", value="%mario%")],
        )
        result = map_table_config(config)
        assert "adhoc_filters" in result
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["operator"] == "ILIKE"
        assert filter_obj["comparator"] == "%mario%"

    def test_map_table_config_with_in_filter(self) -> None:
        """Test table config mapping with IN filter for list matching"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="platform")],
            filters=[
                FilterConfig(
                    column="platform", op="IN", value=["Wii", "PS3", "Xbox360"]
                )
            ],
        )
        result = map_table_config(config)
        assert "adhoc_filters" in result
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["subject"] == "platform"
        assert filter_obj["operator"] == "IN"
        assert filter_obj["comparator"] == ["Wii", "PS3", "Xbox360"]

    def test_map_table_config_with_not_in_filter(self) -> None:
        """Test table config mapping with NOT IN filter"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="status")],
            filters=[
                FilterConfig(
                    column="status", op="NOT IN", value=["archived", "deleted"]
                )
            ],
        )
        result = map_table_config(config)
        assert "adhoc_filters" in result
        filter_obj = result["adhoc_filters"][0]
        assert filter_obj["operator"] == "NOT IN"
        assert filter_obj["comparator"] == ["archived", "deleted"]

    def test_map_table_config_with_mixed_filters(self) -> None:
        """Test table config mapping with mixed filter operators"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="name"), ColumnRef(name="sales")],
            filters=[
                FilterConfig(column="platform", op="=", value="Wii"),
                FilterConfig(column="name", op="ILIKE", value="%mario%"),
                FilterConfig(column="genre", op="IN", value=["Sports", "Racing"]),
            ],
        )
        result = map_table_config(config)
        assert len(result["adhoc_filters"]) == 3
        assert result["adhoc_filters"][0]["operator"] == "=="
        assert result["adhoc_filters"][1]["operator"] == "ILIKE"
        assert result["adhoc_filters"][1]["comparator"] == "%mario%"
        assert result["adhoc_filters"][2]["operator"] == "IN"
        assert result["adhoc_filters"][2]["comparator"] == ["Sports", "Racing"]

    def test_map_table_config_with_sort(self) -> None:
        """Test table config mapping with sort"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            sort_by=["product", "revenue"],
        )
        result = map_table_config(config)
        assert result["order_by_cols"] == ["product", "revenue"]

    def test_map_table_config_ag_grid_table(self) -> None:
        """Test table config mapping with AG Grid Interactive Table viz_type"""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
            ],
        )
        result = map_table_config(config)
        # AG Grid tables use 'ag-grid-table' viz_type
        assert result["viz_type"] == "ag-grid-table"
        assert result["query_mode"] == "aggregate"
        assert len(result["metrics"]) == 1
        assert result["metrics"][0]["aggregate"] == "SUM"
        # Non-aggregated columns should be in groupby
        assert "groupby" in result
        assert "product_line" in result["groupby"]

    def test_map_table_config_ag_grid_raw_mode(self) -> None:
        """Test AG Grid table with raw columns (no aggregates)"""
        config = TableChartConfig(
            chart_type="table",
            viz_type="ag-grid-table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="category"),
                ColumnRef(name="region"),
            ],
        )
        result = map_table_config(config)
        assert result["viz_type"] == "ag-grid-table"
        assert result["query_mode"] == "raw"
        assert result["all_columns"] == ["product_line", "category", "region"]
        assert "metrics" not in result

    def test_map_table_config_default_viz_type(self) -> None:
        """Test that default viz_type is 'table' not 'ag-grid-table'"""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
        )
        result = map_table_config(config)
        assert result["viz_type"] == "table"

    def test_map_table_config_row_limit(self) -> None:
        """Test that row_limit is mapped to form_data."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            row_limit=500,
        )
        result = map_table_config(config)
        assert result["row_limit"] == 500

    def test_map_table_config_default_row_limit(self) -> None:
        """Test that default row_limit is mapped to form_data."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product", aggregate="SUM")],
        )
        result = map_table_config(config)
        assert result["row_limit"] == 1000

    def test_map_table_config_saved_metric_as_metric(self) -> None:
        """Test that saved metrics are routed to metrics, not raw columns."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product_line"),
                ColumnRef(name="total_revenue", saved_metric=True),
            ],
        )
        result = map_table_config(config)
        assert result["query_mode"] == "aggregate"
        assert result["metrics"] == ["total_revenue"]
        assert "product_line" in result["groupby"]

    def test_map_table_config_saved_metric_only(self) -> None:
        """Test table with only saved metrics (no raw columns)."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="total_revenue", saved_metric=True),
                ColumnRef(name="avg_order_value", saved_metric=True),
            ],
        )
        result = map_table_config(config)
        assert result["query_mode"] == "aggregate"
        assert result["metrics"] == ["total_revenue", "avg_order_value"]
        assert "all_columns" not in result
class TestAddAdhocFilters:
    """Test _add_adhoc_filters helper function"""

    def test_adds_filters_to_form_data(self) -> None:
        """Test that filters are correctly added to form_data."""
        form_data: dict[str, Any] = {}
        filters = [
            FilterConfig(column="region", op="=", value="US"),
            FilterConfig(column="year", op=">", value=2020),
        ]
        _add_adhoc_filters(form_data, filters)
        assert "adhoc_filters" in form_data
        assert len(form_data["adhoc_filters"]) == 2
        assert form_data["adhoc_filters"][0]["subject"] == "region"
        assert form_data["adhoc_filters"][0]["operator"] == "=="
        assert form_data["adhoc_filters"][1]["subject"] == "year"
        assert form_data["adhoc_filters"][1]["operator"] == ">"

    def test_no_filters_does_nothing(self) -> None:
        """Test that None filters leave form_data unchanged."""
        form_data: dict[str, Any] = {"viz_type": "table"}
        _add_adhoc_filters(form_data, None)
        assert "adhoc_filters" not in form_data

    def test_empty_list_does_nothing(self) -> None:
        """Test that empty filter list leaves form_data unchanged."""
        form_data: dict[str, Any] = {"viz_type": "table"}
        _add_adhoc_filters(form_data, [])
        assert "adhoc_filters" not in form_data


class TestMapXYConfig:
    """Test map_xy_config function"""

    def test_map_xy_config_line_chart(self) -> None:
        """Test XY config mapping for line chart"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_line"
        assert result["x_axis"] == "date"
        assert len(result["metrics"]) == 1
        assert result["metrics"][0]["aggregate"] == "SUM"

    def test_map_xy_config_with_groupby(self) -> None:
        """Test XY config mapping with group by"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="bar",
            group_by=ColumnRef(name="region"),
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["groupby"] == ["region"]

    def test_map_xy_config_with_axes(self) -> None:
        """Test XY config mapping with axis configurations"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="area",
            x_axis=AxisConfig(title="Date", format="%Y-%m-%d"),
            y_axis=AxisConfig(title="Revenue", scale="log", format="$,.2f"),
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_area"
        assert result["x_axis_title"] == "Date"
        assert result["x_axis_format"] == "%Y-%m-%d"
        assert result["y_axis_title"] == "Revenue"
        assert result["y_axis_format"] == "$,.2f"
        assert result["y_axis_scale"] == "log"

    def test_map_xy_config_with_legend(self) -> None:
        """Test XY config mapping with legend configuration"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="scatter",
            legend=LegendConfig(show=False, position="top"),
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_scatter"
        assert result["show_legend"] is False
        assert result["legend_orientation"] == "top"

    def test_map_xy_config_with_time_grain_month(self) -> None:
        """Test XY config mapping with monthly time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="bar",
            time_grain="P1M",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["x_axis"] == "order_date"
        assert result["time_grain_sqla"] == "P1M"

    def test_map_xy_config_with_time_grain_day(self) -> None:
        """Test XY config mapping with daily time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="created_at"),
            y=[ColumnRef(name="count", aggregate="COUNT")],
            kind="line",
            time_grain="P1D",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_line"
        assert result["x_axis"] == "created_at"
        assert result["time_grain_sqla"] == "P1D"

    def test_map_xy_config_with_time_grain_hour(self) -> None:
        """Test XY config mapping with hourly time grain"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="timestamp"),
            y=[ColumnRef(name="requests", aggregate="SUM")],
            kind="area",
            time_grain="PT1H",
        )
        result = map_xy_config(config)
        assert result["time_grain_sqla"] == "PT1H"

    def test_map_xy_config_without_time_grain(self) -> None:
        """Test XY config mapping without time grain (should not set time_grain_sqla)"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="line",
        )
        result = map_xy_config(config)
        assert "time_grain_sqla" not in result

    def test_map_xy_config_with_time_grain_and_groupby(self) -> None:
        """Test XY config mapping with time grain and group by"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            time_grain="P1W",
            group_by=ColumnRef(name="category"),
        )
        result = map_xy_config(config)
        assert result["time_grain_sqla"] == "P1W"
        assert result["groupby"] == ["category"]
        assert result["x_axis"] == "order_date"

    def test_map_xy_config_bar_horizontal_orientation(self) -> None:
        """Test XY config mapping for horizontal bar chart"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="department"),
            y=[ColumnRef(name="headcount", aggregate="SUM")],
            kind="bar",
            orientation="horizontal",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["orientation"] == "horizontal"

    def test_map_xy_config_bar_vertical_orientation(self) -> None:
        """Test XY config mapping for vertical bar chart (explicit)"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
            orientation="vertical",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["orientation"] == "vertical"

    def test_map_xy_config_bar_no_orientation(self) -> None:
        """Test XY config mapping for bar chart without orientation."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="category"),
            y=[ColumnRef(name="sales", aggregate="SUM")],
            kind="bar",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert "orientation" not in result

    def test_map_xy_config_line_orientation_ignored(self) -> None:
        """Test that orientation is ignored for non-bar chart types"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            orientation="horizontal",
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_line"
        assert "orientation" not in result

    def test_map_xy_config_bar_horizontal_with_stacked(self) -> None:
        """Test horizontal bar chart with stacked option"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="department"),
            y=[ColumnRef(name="headcount", aggregate="SUM")],
            kind="bar",
            orientation="horizontal",
            stacked=True,
            group_by=ColumnRef(name="level"),
        )
        result = map_xy_config(config)
        assert result["viz_type"] == "echarts_timeseries_bar"
        assert result["orientation"] == "horizontal"
        assert result["stack"] == "Stack"
        assert result["groupby"] == ["level"]

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_map_xy_config_with_filters(self, mock_is_temporal) -> None:
        """Test that filters are mapped to adhoc_filters in XY form_data."""
        mock_is_temporal.return_value = True
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            filters=[FilterConfig(column="region", op="=", value="US")],
        )
        result = map_xy_config(config)
        assert "adhoc_filters" in result
        # User filter + auto-added TEMPORAL_RANGE filter for temporal x-axis
        assert len(result["adhoc_filters"]) == 2
        assert result["adhoc_filters"][0]["subject"] == "region"
        assert result["adhoc_filters"][0]["operator"] == "=="
        assert result["adhoc_filters"][0]["comparator"] == "US"
        assert result["adhoc_filters"][1]["operator"] == "TEMPORAL_RANGE"
        assert result["adhoc_filters"][1]["subject"] == "date"

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_map_xy_config_row_limit(self, mock_is_temporal) -> None:
        """Test that row_limit is mapped to form_data."""
        mock_is_temporal.return_value = True
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            row_limit=250,
        )
        result = map_xy_config(config)
        assert result["row_limit"] == 250

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_map_xy_config_default_row_limit(self, mock_is_temporal) -> None:
        """Test that default row_limit is mapped to form_data."""
        mock_is_temporal.return_value = True
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="bar",
        )
        result = map_xy_config(config)
        assert result["row_limit"] == 10000

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_map_xy_config_saved_metric(self, mock_is_temporal: Any) -> None:
        """Test XY config with saved metric emits string in metrics list"""
        mock_is_temporal.return_value = True
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="total_revenue", saved_metric=True)],
            kind="line",
        )
        result = map_xy_config(config, dataset_id=1)
        assert result["metrics"] == ["total_revenue"]

    @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal")
    def test_map_xy_config_mixed_saved_and_adhoc_metrics(
        self, mock_is_temporal: Any
    ) -> None:
        """Test XY config with both saved and ad-hoc metrics"""
        mock_is_temporal.return_value = True
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[
                ColumnRef(name="total_revenue", saved_metric=True),
                ColumnRef(name="quantity", aggregate="SUM"),
            ],
            kind="line",
        )
        result = map_xy_config(config, dataset_id=1)
        assert len(result["metrics"]) == 2
        assert result["metrics"][0] == "total_revenue"
        assert isinstance(result["metrics"][1], dict)
        assert result["metrics"][1]["aggregate"] == "SUM"


class TestMapConfigToFormData:
    """Test map_config_to_form_data function"""

    def test_map_table_config_type(self) -> None:
        """Test mapping table config type"""
        config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="test")])
        result = map_config_to_form_data(config)
        assert result["viz_type"] == "table"

    def test_map_xy_config_type(self) -> None:
        """Test mapping XY config type"""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue")],
            kind="line",
        )
        result = map_config_to_form_data(config)
        assert result["viz_type"] == "echarts_timeseries_line"

    def test_map_unsupported_config_type(self) -> None:
        """Test mapping unsupported config type raises error"""
        with pytest.raises(ValueError, match="Unsupported config type"):
            map_config_to_form_data("invalid_config")  # type: ignore


class TestGenerateChartName:
    """Test generate_chart_name function"""

    def test_table_no_aggregates(self) -> None:
        """Table without aggregates uses column names."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="revenue"),
            ],
        )
        result = generate_chart_name(config)
        assert result == "Product, Revenue Table"

    def test_table_no_aggregates_with_dataset_name(self) -> None:
        """Table without aggregates includes dataset name when available."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
        )
        result = generate_chart_name(config, dataset_name="Orders")
        assert result == "Orders Records"

    def test_table_with_aggregates(self) -> None:
        """Table with aggregates produces a summary name."""
        config = TableChartConfig(
            chart_type="table",
            columns=[
                ColumnRef(name="product"),
                ColumnRef(name="revenue", aggregate="SUM"),
            ],
        )
        result = generate_chart_name(config)
        assert result == "Sum(Revenue) Summary"

    def test_line_chart_over_time(self) -> None:
        """Line chart without group_by uses 'Over Time' format."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
        )
        result = generate_chart_name(config)
        assert result == "Sum(Revenue) Over Time"

    def test_bar_chart_by_dimension(self) -> None:
        """Bar chart uses 'by [X]' format."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="product_category"),
            y=[ColumnRef(name="order_count", aggregate="COUNT")],
            kind="bar",
        )
        result = generate_chart_name(config)
        assert result == "Count(Order Count) by Product Category"

    def test_line_chart_with_group_by(self) -> None:
        """Line chart with group_by uses 'by [group]' format."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            group_by=ColumnRef(name="sales_rep"),
        )
        result = generate_chart_name(config)
        assert result == "Sum(Revenue) by Sales Rep"

    def test_scatter_plot(self) -> None:
        """Scatter plot uses 'Y vs X' format."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="age"),
            y=[ColumnRef(name="income")],
            kind="scatter",
        )
        result = generate_chart_name(config)
        assert result == "Income vs Age"

    def test_time_grain_in_context(self) -> None:
        """Time grain is appended as context."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[ColumnRef(name="revenue", aggregate="SUM")],
            kind="line",
            time_grain="P1M",
        )
        result = generate_chart_name(config)
        assert result == "Sum(Revenue) Over Time \u2013 Monthly"

    def test_filter_context(self) -> None:
        """Filters are appended as context."""
        config = TableChartConfig(
            chart_type="table",
            columns=[ColumnRef(name="product")],
            filters=[FilterConfig(column="region", op="=", value="West")],
        )
        result = generate_chart_name(config, dataset_name="Orders")
        assert result == "Orders Records \u2013 Region West"

    def test_name_truncation(self) -> None:
        """Names exceeding 60 chars are truncated."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(
                    name="very_long_metric_name_that_goes_on_and_on", aggregate="SUM"
                )
            ],
            kind="line",
            group_by=ColumnRef(name="another_very_long_dimension_name_here"),
        )
        result = generate_chart_name(config)
        assert len(result) <= 60

    def test_unsupported_config_type(self) -> None:
        """Unsupported config type returns generic name."""
        result = generate_chart_name("invalid_config")  # type: ignore
        assert result == "Chart"

    def test_custom_labels_used(self) -> None:
        """Column labels are preferred over names."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="ds", label="Date"),
            y=[ColumnRef(name="cnt", aggregate="COUNT", label="Order Count")],
            kind="bar",
        )
        result = generate_chart_name(config)
        assert result == "Order Count by Date"


class TestGenerateExploreLink:
    """Test generate_explore_link function"""

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_uses_base_url(self, mock_get_base_url) -> None:
        """Test that generate_explore_link uses the configured base URL"""
        from urllib.parse import urlparse

        mock_get_base_url.return_value = "https://superset.example.com"
        form_data = {"viz_type": "table", "metrics": ["count"]}
        # Mock dataset not found to trigger fallback URL
        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            result = generate_explore_link("123", form_data)
        # Should use the configured base URL - use urlparse to avoid CodeQL warning
        parsed_url = urlparse(result)
        expected_netloc = "superset.example.com"
        assert parsed_url.scheme == "https"
        assert parsed_url.netloc == expected_netloc
        assert "/explore/" in parsed_url.path
        assert "datasource_id=123" in result

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_fallback_url(self, mock_get_base_url) -> None:
        """Test generate_explore_link returns fallback URL when dataset not found"""
        mock_get_base_url.return_value = "http://localhost:9001"
        form_data = {"viz_type": "table"}
        # Mock dataset not found scenario
        with patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None):
            result = generate_explore_link("999", form_data)
        assert (
            result
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=999"
        )

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    @patch("superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand")
    def test_generate_explore_link_with_form_data_key(
        self, mock_command, mock_get_base_url
    ) -> None:
        """Test generate_explore_link creates form_data_key when dataset exists"""
        mock_get_base_url.return_value = "http://localhost:9001"
        mock_command.return_value.run.return_value = "test_form_data_key"
        # Mock dataset exists
        mock_dataset = type("Dataset", (), {"id": 123})()
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id", return_value=mock_dataset
        ):
            result = generate_explore_link(123, {"viz_type": "table"})
        assert (
            result == "http://localhost:9001/explore/?form_data_key=test_form_data_key"
        )
        mock_command.assert_called_once()

    @patch("superset.mcp_service.chart.chart_utils.get_superset_base_url")
    def test_generate_explore_link_exception_handling(self, mock_get_base_url) -> None:
        """Test generate_explore_link handles SQLAlchemy exceptions gracefully"""
        from sqlalchemy.exc import SQLAlchemyError

        mock_get_base_url.return_value = "http://localhost:9001"
        # Mock SQLAlchemy exception during dataset lookup
        with patch(
            "superset.daos.dataset.DatasetDAO.find_by_id",
            side_effect=SQLAlchemyError("DB error"),
        ):
            result = generate_explore_link("123", {"viz_type": "table"})
        # Should fallback to basic URL
        assert (
            result
            == "http://localhost:9001/explore/?datasource_type=table&datasource_id=123"
        )


class TestCriticalBugFixes:
    """Test critical bug fixes for chart utilities."""

    def test_time_series_aggregation_fix(self) -> None:
        """Test that time series charts preserve temporal dimension."""
        # Create a time series chart configuration
        config = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
        )
        form_data = map_xy_config(config)
        # Verify the fix: x_axis should be set correctly
        assert form_data["x_axis"] == "order_date"
        # Verify the fix: groupby should not duplicate x_axis
        # This prevents the "Duplicate column/metric labels" error
        assert "groupby" not in form_data or "order_date" not in form_data.get(
            "groupby", []
        )
        # Verify chart type mapping
        assert form_data["viz_type"] == "echarts_timeseries_line"

    def test_time_series_with_explicit_group_by(self) -> None:
        """Test time series with explicit group_by different from x_axis."""
        config = XYChartConfig(
            chart_type="xy",
            kind="line",
            x=ColumnRef(name="order_date"),
            y=[ColumnRef(name="sales", aggregate="SUM", label="Total Sales")],
            group_by=ColumnRef(name="category"),
        )
        form_data = map_xy_config(config)
        # Verify x_axis is set
        assert form_data["x_axis"] == "order_date"
        # Verify groupby only contains the explicit group_by, not x_axis
        assert form_data.get("groupby") == ["category"]
        assert "order_date" not in form_data.get("groupby", [])

    def test_duplicate_label_prevention(self) -> None:
        """Test that duplicate column/metric labels are prevented."""
        # This configuration would previously cause:
        # "Duplicate column/metric labels: 'price_each'"
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="price_each", label="Price Each"),  # Custom label
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="quantity", aggregate="COUNT", label="Order Count"),
            ],
            group_by=ColumnRef(name="price_each"),  # Same column as x_axis
            kind="line",
        )
        form_data = map_xy_config(config)
        # Verify the fix: x_axis is set
        assert form_data["x_axis"] == "price_each"
        # Verify the fix: groupby is empty because group_by == x_axis
        # This prevents the duplicate label error
        assert "groupby" not in form_data or not form_data["groupby"]

    def test_metric_object_creation_with_labels(self) -> None:
        """Test that metric objects are created correctly with proper labels."""
        config = XYChartConfig(
            chart_type="xy",
            x=ColumnRef(name="date"),
            y=[
                ColumnRef(name="sales", aggregate="SUM", label="Total Sales"),
                ColumnRef(name="profit", aggregate="AVG"),  # No custom label
            ],
            kind="bar",
        )
        form_data = map_xy_config(config)
        # Verify metrics are created correctly
        metrics = form_data["metrics"]
        assert len(metrics) == 2
        # First metric with custom label
        assert metrics[0]["label"] == "Total Sales"
        assert metrics[0]["aggregate"] == "SUM"
        assert metrics[0]["column"]["column_name"] == "sales"
        # Second metric with auto-generated label
        assert metrics[1]["label"] == "AVG(profit)"
        assert metrics[1]["aggregate"] == "AVG"
        assert metrics[1]["column"]["column_name"] == "profit"

    def test_chart_type_mapping_comprehensive(self) -> None:
        """Test that chart types are mapped correctly to Superset viz types."""
        test_cases = [
            ("line", "echarts_timeseries_line"),
            ("bar", "echarts_timeseries_bar"),
            ("area", "echarts_area"),
            ("scatter", "echarts_timeseries_scatter"),
        ]
        for kind, expected_viz_type in test_cases:
            config = XYChartConfig(
                chart_type="xy",
                x=ColumnRef(name="date"),
                y=[ColumnRef(name="value", aggregate="SUM")],
                kind=kind,
            )
            form_data = map_xy_config(config)
            assert form_data["viz_type"] == expected_viz_type


class TestIsColumnTrulyTemporal:
    """Test is_column_truly_temporal function using db_engine_spec"""

    def _create_mock_dataset(
        self,
        column_name: str,
        column_type: str,
        generic_type: GenericDataType,
    ):
        """Helper to create a mock dataset with proper db_engine_spec"""
        from unittest.mock import MagicMock

        from superset.utils.core import ColumnSpec

        mock_column = MagicMock()
        mock_column.column_name = column_name
        mock_column.type = column_type
        mock_db_engine_spec = MagicMock()
        mock_column_spec = ColumnSpec(
            sqla_type=MagicMock(), generic_type=generic_type, is_dttm=False
        )
mock_db_engine_spec.get_column_spec.return_value = mock_column_spec mock_database = MagicMock() mock_database.db_engine_spec = mock_db_engine_spec mock_dataset = MagicMock() mock_dataset.columns = [mock_column] mock_dataset.database = mock_database return mock_dataset def test_returns_true_when_no_dataset_id(self) -> None: """Test returns True (default) when dataset_id is None""" result = is_column_truly_temporal("year", None) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_returns_true_when_dataset_not_found(self, mock_dao) -> None: """Test returns True when dataset is not found""" mock_dao.find_by_id.return_value = None result = is_column_truly_temporal("year", 123) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_returns_false_for_numeric_column(self, mock_dao) -> None: """Test returns False for NUMERIC generic type (e.g., BIGINT)""" mock_dataset = self._create_mock_dataset( "year", "BIGINT", GenericDataType.NUMERIC ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("year", 123) assert result is False @patch("superset.daos.dataset.DatasetDAO") def test_returns_false_for_integer_column(self, mock_dao) -> None: """Test returns False for INTEGER column (NUMERIC generic type)""" mock_dataset = self._create_mock_dataset( "month", "INTEGER", GenericDataType.NUMERIC ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("month", 123) assert result is False @patch("superset.daos.dataset.DatasetDAO") def test_returns_true_for_temporal_column(self, mock_dao) -> None: """Test returns True for TEMPORAL generic type (e.g., TIMESTAMP)""" mock_dataset = self._create_mock_dataset( "created_at", "TIMESTAMP", GenericDataType.TEMPORAL ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("created_at", 123) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_returns_true_for_date_column(self, mock_dao) -> None: """Test 
returns True for DATE column (TEMPORAL generic type)""" mock_dataset = self._create_mock_dataset( "order_date", "DATE", GenericDataType.TEMPORAL ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("order_date", 123) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_case_insensitive_column_name_lookup(self, mock_dao) -> None: """Test column name lookup is case insensitive""" mock_dataset = self._create_mock_dataset( "Year", "BIGINT", GenericDataType.NUMERIC ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("year", 123) assert result is False @patch("superset.daos.dataset.DatasetDAO") def test_returns_true_on_value_error(self, mock_dao) -> None: """Test returns True (default) when ValueError occurs""" mock_dao.find_by_id.side_effect = ValueError("Invalid ID") result = is_column_truly_temporal("year", 123) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_returns_true_on_attribute_error(self, mock_dao) -> None: """Test returns True (default) when AttributeError occurs""" mock_dao.find_by_id.side_effect = AttributeError("Missing attribute") result = is_column_truly_temporal("year", 123) assert result is True @patch("superset.daos.dataset.DatasetDAO") def test_handles_uuid_dataset_id(self, mock_dao) -> None: """Test handles UUID string as dataset_id""" mock_dataset = self._create_mock_dataset( "year", "BIGINT", GenericDataType.NUMERIC ) mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("year", "abc-123-uuid") assert result is False mock_dao.find_by_id.assert_called_with("abc-123-uuid", id_column="uuid") @patch("superset.daos.dataset.DatasetDAO") def test_falls_back_to_is_dttm_when_no_column_spec(self, mock_dao) -> None: """Test falls back to is_dttm flag when get_column_spec returns None""" from unittest.mock import MagicMock mock_column = MagicMock() mock_column.column_name = "year" mock_column.type = "UNKNOWN_TYPE" 
mock_column.is_dttm = False mock_db_engine_spec = MagicMock() mock_db_engine_spec.get_column_spec.return_value = None mock_database = MagicMock() mock_database.db_engine_spec = mock_db_engine_spec mock_dataset = MagicMock() mock_dataset.columns = [mock_column] mock_dataset.database = mock_database mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("year", 123) assert result is False @patch("superset.daos.dataset.DatasetDAO") def test_falls_back_to_is_dttm_when_no_type(self, mock_dao) -> None: """Test falls back to is_dttm flag when column has no type""" from unittest.mock import MagicMock mock_column = MagicMock() mock_column.column_name = "year" mock_column.type = None mock_column.is_dttm = True mock_dataset = MagicMock() mock_dataset.columns = [mock_column] mock_dao.find_by_id.return_value = mock_dataset result = is_column_truly_temporal("year", 123) assert result is True class TestConfigureTemporalHandling: """Test configure_temporal_handling function""" def test_temporal_column_with_time_grain(self) -> None: """Test temporal column sets time_grain_sqla and granularity_sqla""" form_data: dict[str, Any] = {"x_axis": "order_date"} configure_temporal_handling(form_data, x_is_temporal=True, time_grain="P1M") assert form_data["time_grain_sqla"] == "P1M" assert form_data["granularity_sqla"] == "order_date" def test_temporal_column_without_time_grain(self) -> None: """Test temporal column sets granularity_sqla but not time_grain_sqla""" form_data: dict[str, Any] = {"x_axis": "order_date"} configure_temporal_handling(form_data, x_is_temporal=True, time_grain=None) assert "time_grain_sqla" not in form_data assert form_data["granularity_sqla"] == "order_date" def test_non_temporal_column_sets_categorical_config(self) -> None: """Test non-temporal column sets categorical configuration""" form_data: dict[str, Any] = {} configure_temporal_handling(form_data, x_is_temporal=False, time_grain=None) assert form_data["x_axis_sort_series_type"] == 
"name" assert form_data["x_axis_sort_series_ascending"] is True assert form_data["time_grain_sqla"] is None assert form_data["granularity_sqla"] is None def test_non_temporal_column_ignores_time_grain(self) -> None: """Test non-temporal column ignores time_grain parameter""" form_data: dict[str, Any] = {} configure_temporal_handling(form_data, x_is_temporal=False, time_grain="P1M") # Should still set categorical config, not time_grain assert form_data["time_grain_sqla"] is None assert form_data["x_axis_sort_series_type"] == "name" class TestMapXYConfigWithNonTemporalColumn: """Test map_xy_config with non-temporal columns (DATE_TRUNC fix)""" @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_non_temporal_column_disables_time_grain(self, mock_is_temporal) -> None: """Test non-temporal column sets categorical config""" mock_is_temporal.return_value = False config = XYChartConfig( chart_type="xy", x=ColumnRef(name="year"), y=[ColumnRef(name="sales", aggregate="SUM")], kind="bar", ) result = map_xy_config(config, dataset_id=123) assert result["x_axis"] == "year" assert result["x_axis_sort_series_type"] == "name" assert result["x_axis_sort_series_ascending"] is True assert result["time_grain_sqla"] is None assert result["granularity_sqla"] is None @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_temporal_column_allows_time_grain(self, mock_is_temporal) -> None: """Test temporal column allows time_grain to be set""" mock_is_temporal.return_value = True config = XYChartConfig( chart_type="xy", x=ColumnRef(name="created_at"), y=[ColumnRef(name="count", aggregate="COUNT")], kind="line", time_grain="P1W", ) result = map_xy_config(config, dataset_id=123) assert result["x_axis"] == "created_at" assert result["time_grain_sqla"] == "P1W" assert result["granularity_sqla"] == "created_at" assert "x_axis_sort_series_type" not in result # Temporal x-axis should have a TEMPORAL_RANGE adhoc filter temporal_filters = [ f 
for f in result.get("adhoc_filters", []) if f.get("operator") == "TEMPORAL_RANGE" ] assert len(temporal_filters) == 1 assert temporal_filters[0]["subject"] == "created_at" @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_non_temporal_ignores_time_grain_param(self, mock_is_temporal) -> None: """Test non-temporal column ignores time_grain even if specified""" mock_is_temporal.return_value = False config = XYChartConfig( chart_type="xy", x=ColumnRef(name="year"), y=[ColumnRef(name="sales", aggregate="SUM")], kind="bar", time_grain="P1M", # This should be ignored for non-temporal ) result = map_xy_config(config, dataset_id=123) # time_grain_sqla should be None, not P1M assert result["time_grain_sqla"] is None assert result["x_axis_sort_series_type"] == "name" class TestEnsureTemporalAdhocFilter: """Test _ensure_temporal_adhoc_filter helper and its integration in map_xy_config""" def test_adds_filter_to_empty_form_data(self) -> None: """Test adds TEMPORAL_RANGE filter when no adhoc_filters exist""" form_data: dict[str, Any] = {} _ensure_temporal_adhoc_filter(form_data, "order_date") assert len(form_data["adhoc_filters"]) == 1 f = form_data["adhoc_filters"][0] assert f["operator"] == FilterOperator.TEMPORAL_RANGE.value assert f["subject"] == "order_date" assert f["comparator"] == NO_TIME_RANGE assert f["expressionType"] == "SIMPLE" assert f["clause"] == "WHERE" def test_appends_to_existing_filters(self) -> None: """Test appends temporal filter after existing user filters""" form_data: dict[str, Any] = { "adhoc_filters": [ {"subject": "region", "operator": "==", "comparator": "US"} ] } _ensure_temporal_adhoc_filter(form_data, "order_date") assert len(form_data["adhoc_filters"]) == 2 assert form_data["adhoc_filters"][0]["subject"] == "region" assert ( form_data["adhoc_filters"][1]["operator"] == FilterOperator.TEMPORAL_RANGE.value ) def test_does_not_duplicate_existing_temporal_filter(self) -> None: """Test skips adding if a TEMPORAL_RANGE 
filter already exists for the column""" form_data: dict[str, Any] = { "adhoc_filters": [ { "subject": "order_date", "operator": FilterOperator.TEMPORAL_RANGE.value, "comparator": "Last 7 days", } ] } _ensure_temporal_adhoc_filter(form_data, "order_date") # Should still be just 1 filter (no duplicate) assert len(form_data["adhoc_filters"]) == 1 def test_adds_filter_for_different_column(self) -> None: """Test adds filter when existing temporal filter is on a different column""" form_data: dict[str, Any] = { "adhoc_filters": [ { "subject": "created_at", "operator": FilterOperator.TEMPORAL_RANGE.value, "comparator": NO_TIME_RANGE, } ] } _ensure_temporal_adhoc_filter(form_data, "order_date") assert len(form_data["adhoc_filters"]) == 2 @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_temporal_x_axis_adds_filter_in_map_xy(self, mock_is_temporal) -> None: """Test map_xy_config adds TEMPORAL_RANGE filter for temporal x-axis""" mock_is_temporal.return_value = True config = XYChartConfig( chart_type="xy", x=ColumnRef(name="order_date"), y=[ColumnRef(name="revenue", aggregate="SUM")], kind="bar", ) result = map_xy_config(config, dataset_id=123) temporal_filters = [ f for f in result.get("adhoc_filters", []) if f.get("operator") == FilterOperator.TEMPORAL_RANGE.value ] assert len(temporal_filters) == 1 assert temporal_filters[0]["subject"] == "order_date" assert temporal_filters[0]["comparator"] == NO_TIME_RANGE @patch("superset.mcp_service.chart.chart_utils.is_column_truly_temporal") def test_non_temporal_x_axis_no_temporal_filter(self, mock_is_temporal) -> None: """Test non-temporal x-axis skips TEMPORAL_RANGE filter""" mock_is_temporal.return_value = False config = XYChartConfig( chart_type="xy", x=ColumnRef(name="year"), y=[ColumnRef(name="sales", aggregate="SUM")], kind="bar", ) result = map_xy_config(config, dataset_id=123) temporal_filters = [ f for f in result.get("adhoc_filters", []) if f.get("operator") == 
FilterOperator.TEMPORAL_RANGE.value ] assert len(temporal_filters) == 0 class TestFilterConfigValidation: """Test FilterConfig validation for new operators""" def test_like_operator_with_wildcard(self) -> None: """Test LIKE operator accepts string with % wildcards""" f = FilterConfig(column="name", op="LIKE", value="%mario%") assert f.op == "LIKE" assert f.value == "%mario%" def test_ilike_operator(self) -> None: """Test ILIKE operator accepts string value""" f = FilterConfig(column="name", op="ILIKE", value="%Mario%") assert f.op == "ILIKE" assert f.value == "%Mario%" def test_not_like_operator(self) -> None: """Test NOT LIKE operator accepts string value""" f = FilterConfig(column="name", op="NOT LIKE", value="%test%") assert f.op == "NOT LIKE" def test_in_operator_with_list(self) -> None: """Test IN operator accepts list of values""" f = FilterConfig(column="platform", op="IN", value=["Wii", "PS3", "Xbox360"]) assert f.op == "IN" assert f.value == ["Wii", "PS3", "Xbox360"] def test_not_in_operator_with_list(self) -> None: """Test NOT IN operator accepts list of values""" f = FilterConfig(column="status", op="NOT IN", value=["archived", "deleted"]) assert f.op == "NOT IN" assert f.value == ["archived", "deleted"] def test_in_operator_rejects_scalar_value(self) -> None: """Test IN operator rejects non-list value""" with pytest.raises(ValueError, match="requires a list of values"): FilterConfig(column="platform", op="IN", value="Wii") def test_not_in_operator_rejects_scalar_value(self) -> None: """Test NOT IN operator rejects non-list value""" with pytest.raises(ValueError, match="requires a list of values"): FilterConfig(column="status", op="NOT IN", value="active") def test_equals_operator_rejects_list_value(self) -> None: """Test = operator rejects list value""" with pytest.raises(ValueError, match="requires a single value, not a list"): FilterConfig(column="name", op="=", value=["a", "b"]) def test_like_operator_rejects_list_value(self) -> None: """Test LIKE 
operator rejects list value""" with pytest.raises(ValueError, match="requires a single value, not a list"): FilterConfig(column="name", op="LIKE", value=["%a%", "%b%"]) def test_in_operator_with_numeric_list(self) -> None: """Test IN operator with numeric values""" f = FilterConfig(column="year", op="IN", value=[2020, 2021, 2022]) assert f.value == [2020, 2021, 2022] def test_in_operator_with_empty_list(self) -> None: """Test IN operator with empty list""" f = FilterConfig(column="platform", op="IN", value=[]) assert f.value == [] class TestValidateChartDataset: """Test validate_chart_dataset function""" @patch("superset.mcp_service.auth.has_dataset_access") @patch("superset.daos.dataset.DatasetDAO.find_by_id") def test_validate_chart_dataset_no_datasource_id( self, mock_find: MagicMock, mock_access: MagicMock ) -> None: """Chart with no datasource_id returns invalid result.""" chart = MagicMock(spec=[]) # no datasource_id attribute result = validate_chart_dataset(chart) assert not result.is_valid assert result.dataset_id is None assert "no dataset reference" in (result.error or "").lower() mock_find.assert_not_called() @patch("superset.mcp_service.auth.has_dataset_access") @patch("superset.daos.dataset.DatasetDAO.find_by_id", return_value=None) def test_validate_chart_dataset_deleted_dataset( self, mock_find: MagicMock, mock_access: MagicMock ) -> None: """Chart whose dataset was deleted returns invalid result.""" chart = MagicMock() chart.datasource_id = 42 result = validate_chart_dataset(chart) assert not result.is_valid assert result.dataset_id == 42 assert "deleted" in (result.error or "").lower() @patch("superset.mcp_service.auth.has_dataset_access", return_value=True) @patch("superset.daos.dataset.DatasetDAO.find_by_id") def test_validate_chart_dataset_valid( self, mock_find: MagicMock, mock_access: MagicMock ) -> None: """Valid chart with accessible dataset returns valid result.""" dataset = MagicMock() dataset.table_name = "my_table" dataset.sql = None 
mock_find.return_value = dataset chart = MagicMock() chart.datasource_id = 7 result = validate_chart_dataset(chart) assert result.is_valid assert result.dataset_id == 7 assert result.dataset_name == "my_table" assert result.warnings == [] @patch("superset.mcp_service.auth.has_dataset_access", return_value=True) @patch("superset.daos.dataset.DatasetDAO.find_by_id") def test_validate_chart_dataset_virtual_warns( self, mock_find: MagicMock, mock_access: MagicMock ) -> None: """Virtual dataset emits a warning.""" dataset = MagicMock() dataset.table_name = "virt_ds" dataset.sql = "SELECT 1" mock_find.return_value = dataset chart = MagicMock() chart.datasource_id = 10 result = validate_chart_dataset(chart) assert result.is_valid assert len(result.warnings) == 1 assert "virtual" in result.warnings[0].lower() @patch("superset.mcp_service.auth.has_dataset_access") @patch("superset.daos.dataset.DatasetDAO.find_by_id") def test_validate_chart_dataset_sqlalchemy_error( self, mock_find: MagicMock, mock_access: MagicMock ) -> None: """SQLAlchemy errors are caught and produce an invalid result.""" from sqlalchemy.exc import SQLAlchemyError mock_find.side_effect = SQLAlchemyError("connection lost") chart = MagicMock() chart.datasource_id = 99 result = validate_chart_dataset(chart) assert not result.is_valid assert result.dataset_id == 99 assert "error" in (result.error or "").lower() @patch( "superset.mcp_service.chart.chart_utils.get_superset_base_url", return_value="http://localhost:8088", ) @patch("superset.daos.dataset.DatasetDAO.find_by_id") def test_generate_explore_link_sqlalchemy_error( self, mock_find: MagicMock, mock_base_url: MagicMock, ) -> None: """SQLAlchemy errors in generate_explore_link fall back to basic URL.""" from sqlalchemy.exc import SQLAlchemyError mock_find.side_effect = SQLAlchemyError("db gone") url = generate_explore_link(5, {"viz_type": "table"}) assert "datasource_id=5" in url class TestAdhocFiltersToQueryFilters: """Tests for 
adhoc_filters_to_query_filters conversion.""" def test_converts_simple_filters(self) -> None: adhoc = [ { "clause": "WHERE", "expressionType": "SIMPLE", "subject": "genre", "operator": "==", "comparator": "Action", } ] result = adhoc_filters_to_query_filters(adhoc) assert result == [{"col": "genre", "op": "==", "val": "Action"}] def test_converts_multiple_filters(self) -> None: adhoc = [ { "clause": "WHERE", "expressionType": "SIMPLE", "subject": "genre", "operator": "==", "comparator": "Action", }, { "clause": "WHERE", "expressionType": "SIMPLE", "subject": "year", "operator": ">=", "comparator": "2010", }, ] result = adhoc_filters_to_query_filters(adhoc) assert len(result) == 2 assert result[0] == {"col": "genre", "op": "==", "val": "Action"} assert result[1] == {"col": "year", "op": ">=", "val": "2010"} def test_empty_list(self) -> None: assert adhoc_filters_to_query_filters([]) == [] def test_skips_non_simple_expression_types(self) -> None: adhoc = [ { "clause": "WHERE", "expressionType": "SQL", "sqlExpression": "col > 5", } ] result = adhoc_filters_to_query_filters(adhoc) assert result == []