diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py
index b03d0ffe9c0..c50fccc6938 100644
--- a/superset/mcp_service/chart/validation/dataset_validator.py
+++ b/superset/mcp_service/chart/validation/dataset_validator.py
@@ -22,7 +22,7 @@ Validates that referenced columns exist in the dataset schema.
 
 import difflib
 import logging
-from typing import Dict, List, Tuple
+from typing import Any, Dict, List, Tuple
 
 from superset.mcp_service.chart.schemas import (
     ColumnRef,
@@ -37,13 +37,25 @@ from superset.mcp_service.common.error_schemas import (
 
 logger = logging.getLogger(__name__)
 
+# Exceptions that can occur during column name normalization.
+# Shared by the validation pipeline and tool-level normalization calls.
+NORMALIZATION_EXCEPTIONS = (
+    ImportError,
+    AttributeError,
+    KeyError,
+    ValueError,
+    TypeError,
+)
+
 
 class DatasetValidator:
     """Validates chart configuration against dataset schema."""
 
     @staticmethod
     def validate_against_dataset(
-        config: TableChartConfig | XYChartConfig, dataset_id: int | str
+        config: TableChartConfig | XYChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
     ) -> Tuple[bool, ChartGenerationError | None]:
         """
         Validate chart configuration against dataset schema.
@@ -51,12 +63,15 @@ class DatasetValidator:
         Args:
             config: Chart configuration to validate
             dataset_id: Dataset ID to validate against
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
 
         Returns:
             Tuple of (is_valid, error)
         """
-        # Get dataset context
-        dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+        # Get dataset context (reuse if provided)
+        if dataset_context is None:
+            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
         if not dataset_context:
             from superset.mcp_service.utils.error_builder import (
                 ChartErrorBuilder,
@@ -198,6 +213,136 @@ class DatasetValidator:
 
         return False
 
+    @staticmethod
+    def _get_canonical_column_name(
+        column_name: str, dataset_context: DatasetContext
+    ) -> str:
+        """
+        Get the canonical column name from the dataset.
+
+        Performs case-insensitive matching and returns the actual column name
+        as stored in the dataset. This ensures column names in form_data match
+        exactly with what the frontend expects.
+
+        Args:
+            column_name: The column name to normalize
+            dataset_context: Dataset context with column information
+
+        Returns:
+            The canonical column name from the dataset, or the original name
+            if no match is found.
+        """
+        column_lower = column_name.lower()
+
+        # Check regular columns first
+        for col in dataset_context.available_columns:
+            if col["name"].lower() == column_lower:
+                return col["name"]
+
+        # Check metrics
+        for metric in dataset_context.available_metrics:
+            if metric["name"].lower() == column_lower:
+                return metric["name"]
+
+        # Return original if not found (validation should catch this case)
+        return column_name
+
+    @staticmethod
+    def _normalize_xy_config(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize column names in an XY chart config dict in place."""
+        # Normalize x-axis column
+        if "x" in config_dict and config_dict["x"]:
+            config_dict["x"]["name"] = DatasetValidator._get_canonical_column_name(
+                config_dict["x"]["name"], dataset_context
+            )
+
+        # Normalize y-axis columns
+        if "y" in config_dict and config_dict["y"]:
+            for y_col in config_dict["y"]:
+                y_col["name"] = DatasetValidator._get_canonical_column_name(
+                    y_col["name"], dataset_context
+                )
+
+        # Normalize group_by column
+        if "group_by" in config_dict and config_dict["group_by"]:
+            config_dict["group_by"]["name"] = (
+                DatasetValidator._get_canonical_column_name(
+                    config_dict["group_by"]["name"], dataset_context
+                )
+            )
+
+    @staticmethod
+    def _normalize_table_config(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize column names in a table chart config dict in place."""
+        if "columns" in config_dict and config_dict["columns"]:
+            for col in config_dict["columns"]:
+                col["name"] = DatasetValidator._get_canonical_column_name(
+                    col["name"], dataset_context
+                )
+
+    @staticmethod
+    def _normalize_filters(
+        config_dict: Dict[str, Any], dataset_context: DatasetContext
+    ) -> None:
+        """Normalize filter column names in a config dict in place."""
+        if "filters" in config_dict and config_dict["filters"]:
+            for filter_config in config_dict["filters"]:
+                if filter_config and "column" in filter_config:
+                    filter_config["column"] = (
+                        DatasetValidator._get_canonical_column_name(
+                            filter_config["column"], dataset_context
+                        )
+                    )
+
+    @staticmethod
+    def normalize_column_names(
+        config: TableChartConfig | XYChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
+    ) -> TableChartConfig | XYChartConfig:
+        """
+        Normalize column names in config to match the canonical dataset column names.
+
+        This fixes case sensitivity issues where user-provided column names
+        (e.g., 'orderdate') don't match exactly with the dataset column names
+        (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
+        so we need to ensure column names match exactly.
+
+        Args:
+            config: Chart configuration with column references
+            dataset_id: Dataset ID to get canonical column names from
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
+
+        Returns:
+            A new config with normalized column names
+        """
+        if dataset_context is None:
+            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
+        if not dataset_context:
+            return config
+
+        # Create a mutable copy of the config
+        config_dict = config.model_dump()
+
+        # Normalize based on config type
+        if isinstance(config, XYChartConfig):
+            DatasetValidator._normalize_xy_config(config_dict, dataset_context)
+        elif isinstance(config, TableChartConfig):
+            DatasetValidator._normalize_table_config(config_dict, dataset_context)
+
+        # Normalize filter columns (common to both config types)
+        DatasetValidator._normalize_filters(config_dict, dataset_context)
+
+        # Reconstruct the config with normalized names
+        if isinstance(config, XYChartConfig):
+            return XYChartConfig.model_validate(config_dict)
+        return TableChartConfig.model_validate(config_dict)
+
     @staticmethod
     def _get_column_suggestions(
         column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
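Reviewer note: a minimal standalone sketch of the matching rule _get_canonical_column_name implements — columns are checked before metrics, matching is case-insensitive only, and unknown names pass through. The data here is hypothetical; the real helper reads a DatasetContext:

    columns = [{"name": "OrderDate"}, {"name": "Sales"}]
    metrics = [{"name": "TotalRevenue"}]

    def canonical(name: str) -> str:
        lowered = name.lower()
        for entry in columns + metrics:  # columns take precedence over metrics
            if entry["name"].lower() == lowered:
                return entry["name"]
        return name  # unknown names pass through; validation reports them

    assert canonical("orderdate") == "OrderDate"
    assert canonical("TOTALREVENUE") == "TotalRevenue"
    assert canonical("order_date") == "order_date"  # underscores are not folded

Note the last case: pure case folding cannot map 'order_date' to 'OrderDate', which is why the docstrings and tests in this change use 'orderdate' as the example input.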
diff --git a/superset/mcp_service/chart/validation/pipeline.py b/superset/mcp_service/chart/validation/pipeline.py
index 948f9d2e62d..a0f475ffa0e 100644
--- a/superset/mcp_service/chart/validation/pipeline.py
+++ b/superset/mcp_service/chart/validation/pipeline.py
@@ -27,7 +27,10 @@ from superset.mcp_service.chart.schemas import (
     ChartConfig,
     GenerateChartRequest,
 )
-from superset.mcp_service.common.error_schemas import ChartGenerationError
+from superset.mcp_service.common.error_schemas import (
+    ChartGenerationError,
+    DatasetContext,
+)
 
 logger = logging.getLogger(__name__)
 
@@ -168,9 +171,14 @@
             if request is None:
                 return ValidationResult(is_valid=False, error=error)
 
-            # Layer 2: Dataset validation
+            # Fetch dataset context once and reuse across validation layers
+            dataset_context = ValidationPipeline._get_dataset_context(
+                request.dataset_id
+            )
+
+            # Layer 2: Dataset validation (reuses context)
             is_valid, error = ValidationPipeline._validate_dataset(
-                request.config, request.dataset_id
+                request.config, request.dataset_id, dataset_context
             )
             if not is_valid:
                 return ValidationResult(is_valid=False, request=request, error=error)
@@ -181,8 +189,15 @@
             )
             # Runtime validation always returns True now, warnings are informational
 
+            # Layer 4: Column name normalization (reuses context)
+            normalized_request = ValidationPipeline._normalize_column_names(
+                request, dataset_context
+            )
+
             return ValidationResult(
-                is_valid=True, request=request, warnings=warnings_metadata
+                is_valid=True,
+                request=normalized_request,
+                warnings=warnings_metadata,
             )
 
         except Exception as e:
@@ -201,15 +216,32 @@
             )
             return ValidationResult(is_valid=False, error=error)
 
+    @staticmethod
+    def _get_dataset_context(
+        dataset_id: int | str,
+    ) -> DatasetContext | None:
+        """Fetch dataset context once to reuse across validation layers."""
+        try:
+            from .dataset_validator import DatasetValidator
+
+            return DatasetValidator._get_dataset_context(dataset_id)
+        except ImportError:
+            logger.warning("Dataset validator not available, skipping context fetch")
+            return None
+
     @staticmethod
     def _validate_dataset(
-        config: ChartConfig, dataset_id: int | str
+        config: ChartConfig,
+        dataset_id: int | str,
+        dataset_context: DatasetContext | None = None,
     ) -> Tuple[bool, ChartGenerationError | None]:
         """Validate configuration against dataset schema."""
         try:
             from .dataset_validator import DatasetValidator
 
-            return DatasetValidator.validate_against_dataset(config, dataset_id)
+            return DatasetValidator.validate_against_dataset(
+                config, dataset_id, dataset_context=dataset_context
+            )
         except ImportError:
             # Skip if dataset validator not available
             logger.warning(
@@ -248,6 +280,48 @@
             # Don't fail on runtime validation errors
             return True, None
 
+    @staticmethod
+    def _normalize_column_names(
+        request: GenerateChartRequest,
+        dataset_context: DatasetContext | None = None,
+    ) -> GenerateChartRequest:
+        """
+        Normalize column names in the request to match canonical dataset names.
+
+        This fixes case sensitivity issues where user-provided column names
+        don't match exactly with the dataset column names. For example,
+        if a user provides 'orderdate' but the dataset has 'OrderDate',
+        this method will normalize it to 'OrderDate'.
+
+        Args:
+            request: The validated chart generation request
+            dataset_context: Pre-fetched dataset context to avoid duplicate
+                DB queries. If None, fetches from the database.
+
+        Returns:
+            A new request with normalized column names
+        """
+        try:
+            from .dataset_validator import DatasetValidator
+
+            normalized_config = DatasetValidator.normalize_column_names(
+                request.config,
+                request.dataset_id,
+                dataset_context=dataset_context,
+            )
+
+            # Create a new request with the normalized config
+            request_dict = request.model_dump()
+            request_dict["config"] = normalized_config.model_dump()
+
+            return GenerateChartRequest.model_validate(request_dict)
+
+        except (ImportError, AttributeError, KeyError, ValueError, TypeError) as e:
+            # If normalization fails, return the original request
+            # Validation has already passed, so this is a non-critical failure
+            logger.warning("Column name normalization failed: %s", e)
+            return request
+
     @staticmethod
     def validate_filters(
         filters: List[Any],
diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py
index 3048a538c6c..dca721caa8b 100644
--- a/superset/mcp_service/explore/tool/generate_explore_link.py
+++ b/superset/mcp_service/explore/tool/generate_explore_link.py
@@ -93,9 +93,22 @@ async def generate_explore_link(
     try:
         await ctx.report_progress(1, 3, "Converting configuration to form data")
         with event_logger.log_context(action="mcp.generate_explore_link.form_data"):
+            # Normalize column names to match canonical dataset column names
+            # This fixes case sensitivity issues (e.g., 'orderdate' vs 'OrderDate')
+            try:
+                from superset.mcp_service.chart.validation.dataset_validator import (
+                    DatasetValidator,
+                )
+
+                normalized_config = DatasetValidator.normalize_column_names(
+                    request.config, request.dataset_id
+                )
+            except (ImportError, AttributeError, KeyError, ValueError, TypeError):
+                normalized_config = request.config
+
             # Map config to form_data using shared utilities
             form_data = map_config_to_form_data(
-                request.config, dataset_id=request.dataset_id
+                normalized_config, dataset_id=request.dataset_id
            )
 
             # Add datasource to form_data for consistency with generate_chart
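The motivation for normalizing before map_config_to_form_data: the frontend compares x_axis against filter subjects case-sensitively when deciding whether to prompt for an x-axis filter. A minimal illustration of the mismatch the change removes (values hypothetical, mirroring the tests below):

    x_axis = "orderdate"            # user's spelling, pre-normalization
    filter_subject = "OrderDate"    # canonical spelling stored in the filter
    assert x_axis != filter_subject  # mismatch -> spurious "add filter" prompt
    x_axis = "OrderDate"             # after normalize_column_names
    assert x_axis == filter_subject  # exact match -> no prompt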
diff --git a/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
new file mode 100644
index 00000000000..77fdf64143f
--- /dev/null
+++ b/tests/unit_tests/mcp_service/chart/validation/test_column_name_normalization.py
@@ -0,0 +1,681 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Tests for column name normalization in the MCP service.
+
+This tests the fix for the issue where time series charts would incorrectly
+prompt to add the x-axis to filters when the column name case didn't match
+exactly (e.g., 'orderdate' vs 'OrderDate').
+"""
+
+from typing import Any, Dict
+from unittest.mock import patch
+
+import pytest
+
+from superset.mcp_service.chart.schemas import (
+    ColumnRef,
+    FilterConfig,
+    TableChartConfig,
+    XYChartConfig,
+)
+from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator
+from superset.mcp_service.common.error_schemas import DatasetContext
+
+
+@pytest.fixture
+def mock_dataset_context() -> DatasetContext:
+    """Create a mock dataset context with mixed-case column names."""
+    return DatasetContext(
+        id=18,
+        table_name="Vehicle Sales",
+        schema="public",
+        database_name="examples",
+        available_columns=[
+            {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+            {"name": "ProductLine", "type": "VARCHAR", "is_temporal": False},
+            {"name": "Sales", "type": "DECIMAL", "is_numeric": True},
+            {"name": "quantity_ordered", "type": "INTEGER", "is_numeric": True},
+        ],
+        available_metrics=[
+            {"name": "TotalRevenue", "expression": "SUM(Sales)", "description": None},
+        ],
+    )
+
+
+class TestGetCanonicalColumnName:
+    """Test _get_canonical_column_name static method."""
+
+    def test_exact_match_returns_same_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that exact match returns the same column name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "OrderDate", mock_dataset_context
+        )
+        assert result == "OrderDate"
+
+    def test_lowercase_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that lowercase input returns the canonical (dataset) column name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "orderdate", mock_dataset_context
+        )
+        assert result == "OrderDate"
+
+    def test_snake_case_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that snake_case input returns the canonical column name."""
+        # Matching is case-insensitive only: 'order_date' would not match
+        # 'OrderDate' because the underscore differs, so this test uses
+        # 'productline' to verify pure case-insensitive matching.
+        result = DatasetValidator._get_canonical_column_name(
+            "productline", mock_dataset_context
+        )
+        assert result == "ProductLine"
+
+    def test_uppercase_returns_canonical_name(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that uppercase input returns the canonical column name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "SALES", mock_dataset_context
+        )
+        assert result == "Sales"
+
+    def test_metric_name_normalization(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that metric names are also normalized."""
+        result = DatasetValidator._get_canonical_column_name(
+            "totalrevenue", mock_dataset_context
+        )
+        assert result == "TotalRevenue"
+
+    def test_unknown_column_returns_original(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that unknown columns return the original name."""
+        result = DatasetValidator._get_canonical_column_name(
+            "unknown_column", mock_dataset_context
+        )
+        assert result == "unknown_column"
+
+
+class TestNormalizeXYConfig:
+    """Test _normalize_xy_config static method."""
+
+    def test_normalize_x_axis_column(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that x-axis column name is normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "orderdate"},
+            "y": [{"name": "Sales", "aggregate": "SUM"}],
+            "kind": "line",
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["x"]["name"] == "OrderDate"
+
+    def test_normalize_y_axis_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that y-axis column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "OrderDate"},
+            "y": [
+                {"name": "sales", "aggregate": "SUM"},
+                {"name": "QUANTITY_ORDERED", "aggregate": "COUNT"},
+            ],
+            "kind": "bar",
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["y"][0]["name"] == "Sales"
+        assert config_dict["y"][1]["name"] == "quantity_ordered"
+
+    def test_normalize_group_by_column(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that group_by column name is normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "xy",
+            "x": {"name": "OrderDate"},
+            "y": [{"name": "Sales", "aggregate": "SUM"}],
+            "kind": "line",
+            "group_by": {"name": "productline"},
+        }
+
+        DatasetValidator._normalize_xy_config(config_dict, mock_dataset_context)
+
+        assert config_dict["group_by"]["name"] == "ProductLine"
+
+
+class TestNormalizeTableConfig:
+    """Test _normalize_table_config static method."""
+
+    def test_normalize_table_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that table column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "chart_type": "table",
+            "columns": [
+                {"name": "orderdate"},
+                {"name": "PRODUCTLINE"},
+                {"name": "sales", "aggregate": "SUM"},
+            ],
+        }
+
+        DatasetValidator._normalize_table_config(config_dict, mock_dataset_context)
+
+        assert config_dict["columns"][0]["name"] == "OrderDate"
+        assert config_dict["columns"][1]["name"] == "ProductLine"
+        assert config_dict["columns"][2]["name"] == "Sales"
+
+
+class TestNormalizeFilters:
+    """Test _normalize_filters static method."""
+
+    def test_normalize_filter_columns(
+        self, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that filter column names are normalized."""
+        config_dict: Dict[str, Any] = {
+            "filters": [
+                {"column": "productline", "op": "=", "value": "Classic Cars"},
+                {"column": "ORDERDATE", "op": ">", "value": "2023-01-01"},
+            ],
+        }
+
+        DatasetValidator._normalize_filters(config_dict, mock_dataset_context)
+
+        assert config_dict["filters"][0]["column"] == "ProductLine"
+        assert config_dict["filters"][1]["column"] == "OrderDate"
+
+
+class TestNormalizeColumnNames:
+    """Test the main normalize_column_names method."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_xy_chart_config(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test full normalization of XY chart config."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),  # lowercase - should normalize to OrderDate
+            y=[
+                ColumnRef(name="sales", aggregate="SUM")
+            ],  # lowercase - should normalize to Sales
+            kind="line",
+            filters=[FilterConfig(column="productline", op="=", value="Classic Cars")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "ProductLine"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_table_chart_config(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test full normalization of table chart config."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = TableChartConfig(
+            chart_type="table",
+            columns=[
+                ColumnRef(name="orderdate"),
+                ColumnRef(name="productline"),
+                ColumnRef(name="sales", aggregate="SUM"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.columns[0].name == "OrderDate"
+        assert normalized.columns[1].name == "ProductLine"
+        assert normalized.columns[2].name == "Sales"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_returns_original_when_dataset_not_found(self, mock_get_context) -> None:
+        """Test that original config is returned when dataset context is unavailable."""
+        mock_get_context.return_value = None
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=999)
+
+        # Should return original config unchanged
+        assert normalized.x.name == "orderdate"
+        assert normalized.y[0].name == "sales"
+
+
+class TestTimeSeriesFilterPromptFix:
+    """Test the fix for time series charts incorrectly prompting x-axis filters."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_x_axis_matches_existing_filter_after_normalization(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """
+        Test the core fix: when creating a time series chart with
+        'orderdate' as x-axis, and there's already a filter with
+        'OrderDate', after normalization they should match.
+
+        This is the exact scenario from the bug report where:
+        - User creates chart with x_axis = 'orderdate'
+        - Dataset has column named 'OrderDate'
+        - Existing filter has subject = 'OrderDate'
+        - Without normalization: 'orderdate' != 'OrderDate' -> prompt shown
+        - With normalization: 'OrderDate' == 'OrderDate' -> no prompt
+        """
+        mock_get_context.return_value = mock_dataset_context
+
+        # Simulate what the MCP service receives from the user
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),  # User provides lowercase
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            # Simulating an existing filter with the canonical name
+            filters=[
+                FilterConfig(column="OrderDate", op=">", value="2023-01-01"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        # After normalization, x.name should match the filter column exactly
+        assert normalized.x.name == "OrderDate"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "OrderDate"
+
+        # This equality is what the frontend checks - now they match!
+        assert normalized.x.name == normalized.filters[0].column
+
+
+@pytest.fixture
+def uppercase_dataset_context() -> DatasetContext:
+    """Create a mock dataset context with all-uppercase column names (like flights)."""
+    return DatasetContext(
+        id=24,
+        table_name="flights",
+        schema="public",
+        database_name="examples",
+        available_columns=[
+            {"name": "DEPARTURE_DELAY", "type": "FLOAT", "is_numeric": True},
+            {"name": "ARRIVAL_DELAY", "type": "FLOAT", "is_numeric": True},
+            {"name": "DISTANCE", "type": "BIGINT", "is_numeric": True},
+            {"name": "AIRLINE", "type": "VARCHAR", "is_temporal": False},
+            {"name": "ds", "type": "TIMESTAMP", "is_temporal": True},
+        ],
+        available_metrics=[
+            {"name": "count", "expression": "COUNT(*)", "description": None},
+        ],
+    )
+
+
+class TestNormalizeMultipleYAxisColumns:
+    """Test normalization of multiple y-axis columns."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalize_multiple_y_columns(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that all y-axis columns are normalized."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[
+                ColumnRef(name="departure_delay", aggregate="AVG"),
+                ColumnRef(name="arrival_delay", aggregate="AVG"),
+            ],
+            kind="area",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.y[0].name == "DEPARTURE_DELAY"
+        assert normalized.y[1].name == "ARRIVAL_DELAY"
+
+
+class TestNormalizeUppercaseDataset:
+    """Test normalization against dataset with all-uppercase column names."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_lowercase_to_uppercase(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test lowercase input normalizes to uppercase canonical names."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="distance", aggregate="AVG")],
+            kind="bar",
+            group_by=ColumnRef(name="airline"),
+            filters=[FilterConfig(column="airline", op="=", value="AA")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.x.name == "ds"
+        assert normalized.y[0].name == "DISTANCE"
+        assert normalized.group_by is not None
+        assert normalized.group_by.name == "AIRLINE"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "AIRLINE"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_exact_match_preserved(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that already-correct names are preserved unchanged."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="DEPARTURE_DELAY", aggregate="AVG")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.x.name == "ds"
+        assert normalized.y[0].name == "DEPARTURE_DELAY"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_metric_normalized_in_y_axis(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """Test that metric names used in y-axis are normalized."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="COUNT", aggregate="SUM")],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        # 'COUNT' should normalize to 'count' (the metric name)
+        assert normalized.y[0].name == "count"
+
+
+class TestNormalizeEdgeCases:
+    """Test edge cases for column name normalization."""
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_no_filters(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has no filters."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.filters is None
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_empty_filters(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has empty filters list."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.filters is not None
+        assert len(normalized.filters) == 0
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_config_with_no_group_by(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test normalization when config has no group_by."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.group_by is None
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_all_fields_normalized_together(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that x, y, group_by, and filters are all normalized in one call."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ORDERDATE"),
+            y=[
+                ColumnRef(name="sales", aggregate="SUM"),
+                ColumnRef(name="QUANTITY_ORDERED", aggregate="COUNT"),
+            ],
+            kind="bar",
+            group_by=ColumnRef(name="PRODUCTLINE"),
+            filters=[
+                FilterConfig(column="productline", op="=", value="Classic Cars"),
+                FilterConfig(column="ORDERDATE", op=">", value="2023-01-01"),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.x.name == "OrderDate"
+        assert normalized.y[0].name == "Sales"
+        assert normalized.y[1].name == "quantity_ordered"
+        assert normalized.group_by is not None
+        assert normalized.group_by.name == "ProductLine"
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "ProductLine"
+        assert normalized.filters[1].column == "OrderDate"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_normalization_is_idempotent(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that normalizing already-normalized config returns same result."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[FilterConfig(column="productline", op="=", value="Cars")],
+        )
+
+        first = DatasetValidator.normalize_column_names(config, dataset_id=18)
+        second = DatasetValidator.normalize_column_names(first, dataset_id=18)
+
+        assert first.x.name == second.x.name == "OrderDate"
+        assert first.y[0].name == second.y[0].name == "Sales"
+        assert first.filters is not None
+        assert second.filters is not None
+        assert first.filters[0].column == second.filters[0].column == "ProductLine"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_aggregate_preserved_after_normalization(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that aggregate functions are preserved during normalization."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[
+                ColumnRef(name="sales", aggregate="SUM"),
+                ColumnRef(name="QUANTITY_ORDERED", aggregate="AVG"),
+            ],
+            kind="bar",
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.y[0].aggregate == "SUM"
+        assert normalized.y[1].aggregate == "AVG"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_filter_operator_and_value_preserved(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Test that filter op and value are preserved during normalization."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[
+                FilterConfig(column="ORDERDATE", op=">=", value="2023-01-01"),
+                FilterConfig(column="sales", op=">", value=1000),
+            ],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.filters is not None
+        assert normalized.filters[0].column == "OrderDate"
+        assert normalized.filters[0].op == ">="
+        assert normalized.filters[0].value == "2023-01-01"
+        assert normalized.filters[1].column == "Sales"
+        assert normalized.filters[1].op == ">"
+        assert normalized.filters[1].value == 1000
+
+
+class TestNormalizeXAxisFilterConsistency:
+    """Test that x-axis and filter column names are consistent after normalization.
+
+    These tests verify the core bug fix: when x-axis and filter reference
+    the same column but with different cases, normalization ensures they match.
+    """
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_both_wrong_case_normalized_to_same(
+        self, mock_get_context, mock_dataset_context: DatasetContext
+    ) -> None:
+        """Both x-axis and filter in wrong case normalize to same canonical name."""
+        mock_get_context.return_value = mock_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ORDERDATE"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[FilterConfig(column="orderdate", op=">", value="2023-01-01")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=18)
+
+        assert normalized.filters is not None
+        assert normalized.x.name == normalized.filters[0].column == "OrderDate"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_uppercase_dataset_x_filter_match(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """On uppercase-column dataset, both lowercase refs normalize to uppercase."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="departure_delay", aggregate="AVG")],
+            kind="line",
+            filters=[FilterConfig(column="ds", op=">", value="2015-01-01")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.filters is not None
+        assert normalized.x.name == normalized.filters[0].column == "ds"
+
+    @patch.object(DatasetValidator, "_get_dataset_context")
+    def test_group_by_matches_filter_after_normalization(
+        self, mock_get_context, uppercase_dataset_context: DatasetContext
+    ) -> None:
+        """group_by and filter for same column normalize to same canonical name."""
+        mock_get_context.return_value = uppercase_dataset_context
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="distance", aggregate="AVG")],
+            kind="bar",
+            group_by=ColumnRef(name="Airline"),
+            filters=[FilterConfig(column="airline", op="=", value="AA")],
+        )
+
+        normalized = DatasetValidator.normalize_column_names(config, dataset_id=24)
+
+        assert normalized.group_by is not None
+        assert normalized.filters is not None
+        assert normalized.group_by.name == normalized.filters[0].column == "AIRLINE"
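The tool-level tests below stub the context fetch rather than hitting a database; the same pattern works for quick local experiments. A minimal sketch, assuming the repo's models are importable:

    from unittest.mock import patch

    from superset.mcp_service.chart.schemas import ColumnRef, XYChartConfig
    from superset.mcp_service.chart.validation.dataset_validator import (
        DatasetValidator,
    )

    config = XYChartConfig(
        chart_type="xy",
        x=ColumnRef(name="orderdate"),
        y=[ColumnRef(name="sales", aggregate="SUM")],
        kind="line",
    )

    # With no dataset context, normalization is a no-op by design.
    with patch.object(DatasetValidator, "_get_dataset_context", return_value=None):
        normalized = DatasetValidator.normalize_column_names(config, dataset_id=99)
        assert normalized.x.name == "orderdate"  # no context -> pass-through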
diff --git a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
index af08834d57b..0a8771e48ba 100644
--- a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
+++ b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
@@ -35,6 +35,7 @@ from superset.mcp_service.chart.schemas import (
     TableChartConfig,
     XYChartConfig,
 )
+from superset.mcp_service.common.error_schemas import DatasetContext
 
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
@@ -706,3 +707,151 @@ class TestGenerateExploreLink:
         assert result.data["form_data"].get("x_axis") == "date"
         # Verify datasource field format: "{dataset_id}__table"
         assert result.data["form_data"].get("datasource") == "1__table"
+
+
+class TestGenerateExploreLinkColumnNormalization:
+    """Tests that generate_explore_link normalizes column names.
+
+    This verifies the fix where user-provided column names in wrong case
+    (e.g., 'orderdate') are normalized to the canonical dataset name
+    (e.g., 'OrderDate') before being used in form_data.
+    """
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_xy_chart_x_axis_normalized_in_form_data(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """x-axis column name in wrong case is normalized in form_data."""
+        mock_create_form_data.return_value = "norm_test_key_1"
+        mock_find_dataset.return_value = _mock_dataset(id=18)
+        mock_get_context.return_value = DatasetContext(
+            id=18,
+            table_name="Vehicle Sales",
+            schema="public",
+            database_name="examples",
+            available_columns=[
+                {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+                {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+            ],
+            available_metrics=[],
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+        request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+        assert result.data["error"] is None
+        # x-axis should be normalized from 'orderdate' to 'OrderDate'
+        assert result.data["form_data"]["x_axis"] == "OrderDate"
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_filter_column_normalized_in_form_data(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """Filter column name in wrong case is normalized in adhoc_filters."""
+        mock_create_form_data.return_value = "norm_test_key_2"
+        mock_find_dataset.return_value = _mock_dataset(id=18)
+        mock_get_context.return_value = DatasetContext(
+            id=18,
+            table_name="Vehicle Sales",
+            schema="public",
+            database_name="examples",
+            available_columns=[
+                {"name": "OrderDate", "type": "DATE", "is_temporal": True},
+                {"name": "Sales", "type": "FLOAT", "is_numeric": True},
+            ],
+            available_metrics=[],
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+            filters=[
+                FilterConfig(column="orderdate", op=">", value="2023-01-01"),
+            ],
+        )
+        request = GenerateExploreLinkRequest(dataset_id="18", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+        assert result.data["error"] is None
+        form_data = result.data["form_data"]
+        # x-axis normalized
+        assert form_data["x_axis"] == "OrderDate"
+        # filter subject normalized to match x-axis
+        adhoc_filters = form_data.get("adhoc_filters", [])
+        assert len(adhoc_filters) == 1
+        assert adhoc_filters[0]["subject"] == "OrderDate"
+
+    @patch(
+        "superset.mcp_service.chart.validation.dataset_validator.DatasetValidator._get_dataset_context"
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_normalization_fallback_when_dataset_not_found(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_get_context,
+        mcp_server,
+    ):
+        """When dataset context is unavailable, original names pass through."""
+        mock_create_form_data.return_value = "norm_test_key_3"
+        mock_find_dataset.return_value = _mock_dataset(id=99)
+        mock_get_context.return_value = None
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="orderdate"),
+            y=[ColumnRef(name="sales", aggregate="SUM")],
+            kind="line",
+        )
+        request = GenerateExploreLinkRequest(dataset_id="99", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+        assert result.data["error"] is None
+        # original names should pass through unchanged
+        assert result.data["form_data"]["x_axis"] == "orderdate"