# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

"""
Dataset-specific validation for chart configurations.
Validates that referenced columns exist in the dataset schema.
"""

import difflib
import logging
from typing import Any, Dict, List, Tuple, TypeVar

from superset.mcp_service.chart.schemas import (
    ChartConfig,
    ColumnRef,
)
from superset.mcp_service.common.error_schemas import (
    ChartGenerationError,
    ColumnSuggestion,
    DatasetContext,
)

_C = TypeVar("_C", bound=ChartConfig)

logger = logging.getLogger(__name__)

# Exceptions that can occur during column name normalization.
# Shared by the validation pipeline and tool-level normalization calls.
NORMALIZATION_EXCEPTIONS = (
    ImportError,
    AttributeError,
    KeyError,
    ValueError,
    TypeError,
)

# Aggregates that only make sense on numeric columns. MIN and MAX are
# intentionally excluded: they work on dates and text in most SQL engines, so
# restricting them here would produce false-positive errors. Those are left to
# the Tier-2 compile check.
_NUMERIC_ONLY_AGGREGATES = ("SUM", "AVG", "STDDEV", "VAR", "MEDIAN")

# Type names accepted as numeric when a column's ``is_numeric`` flag is not set.
_NUMERIC_TYPE_NAMES = ("INTEGER", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC")


class DatasetValidator:
    """Validates chart configuration against dataset schema."""

    @staticmethod
    def validate_against_dataset(
        config: ChartConfig,
        dataset_id: int | str,
        dataset_context: DatasetContext | None = None,
    ) -> Tuple[bool, ChartGenerationError | None]:
        """
        Validate chart configuration against dataset schema.

        Checks run in order and the first failure is returned: saved-metric
        refs, column existence, then aggregation compatibility.

        Args:
            config: Chart configuration to validate
            dataset_id: Dataset ID to validate against
            dataset_context: Pre-fetched dataset context to avoid duplicate
                DB queries. If None, fetches from the database.

        Returns:
            Tuple of (is_valid, error)
        """
        # Get dataset context (reuse if provided)
        if dataset_context is None:
            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
        if not dataset_context:
            from superset.mcp_service.utils.error_builder import (
                ChartErrorBuilder,
            )

            return False, ChartErrorBuilder.dataset_not_found_error(dataset_id)

        # Collect all column references
        column_refs = DatasetValidator._extract_column_references(config)

        # Validate saved metrics exist in dataset metrics specifically
        invalid_saved = DatasetValidator._validate_saved_metrics(
            column_refs, dataset_context
        )
        if invalid_saved:
            return False, invalid_saved

        # Validate columns exist (skip saved metrics — already validated above)
        column_error = DatasetValidator._validate_columns_exist(
            column_refs, dataset_context
        )
        if column_error:
            return False, column_error

        # Validate aggregation compatibility for every config that produced
        # column refs. ``_validate_aggregations`` is config-agnostic — gating
        # it to Table/XY would let pie / pivot table / mixed timeseries /
        # handlebars / big number slip through ``SUM(non_numeric)`` patterns
        # for the fast-path tools that skip Tier 2.
        aggregation_errors = DatasetValidator._validate_aggregations(
            column_refs, dataset_context
        )
        if aggregation_errors:
            # Only the first aggregation error is surfaced to the caller.
            return False, aggregation_errors[0]

        return True, None

    @staticmethod
    def _validate_columns_exist(
        column_refs: List[ColumnRef], dataset_context: DatasetContext
    ) -> ChartGenerationError | None:
        """Validate that non-saved-metric column refs exist in the dataset.

        A ``ColumnRef`` with ``saved_metric=False`` must match an entry in
        ``available_columns``. Saved-metric *names* don't satisfy this check —
        otherwise ``{name: "sum_boys", aggregate: "SUM"}`` (no
        ``saved_metric=true``) would slip through and downstream code would
        emit ``SUM(sum_boys)`` as an ad-hoc SIMPLE metric, producing the
        broken-SQL pattern this validator is meant to prevent.

        Returns:
            ``None`` when every ref resolves to a column; otherwise a
            saved-metric hint error (takes precedence) or a column-not-found
            error with fuzzy suggestions.
        """
        column_names_lower = {
            col["name"].lower() for col in dataset_context.available_columns
        }
        metric_names_lower = {
            metric["name"].lower() for metric in dataset_context.available_metrics
        }

        invalid_columns: List[ColumnRef] = []
        saved_metric_typo: List[ColumnRef] = []
        for col_ref in column_refs:
            if col_ref.saved_metric:
                continue
            name_lower = col_ref.name.lower()
            if name_lower in column_names_lower:
                continue
            if name_lower in metric_names_lower:
                # Name matches a saved metric but the ref didn't opt into
                # saved-metric resolution. Surface a tailored hint so the
                # caller (typically an LLM) can flip ``saved_metric=true``.
                saved_metric_typo.append(col_ref)
            else:
                invalid_columns.append(col_ref)

        if saved_metric_typo:
            return DatasetValidator._build_saved_metric_hint_error(saved_metric_typo)

        if not invalid_columns:
            return None

        # Keyed by the name as the caller wrote it; repeated names collapse
        # onto a single entry.
        suggestions_map = {
            col_ref.name: DatasetValidator._get_column_suggestions(
                col_ref.name, dataset_context
            )
            for col_ref in invalid_columns
        }
        return DatasetValidator._build_column_error(
            invalid_columns, suggestions_map, dataset_context
        )

    @staticmethod
    def _build_saved_metric_hint_error(
        refs: List[ColumnRef],
    ) -> ChartGenerationError:
        """Error response when a non-saved-metric ref names a saved metric.

        Args:
            refs: Non-empty list of refs whose names match saved metrics. The
                first ref's name is used in the concrete "did you mean" hint.
        """
        names = [r.name for r in refs]
        names_str = ", ".join(f"'{n}'" for n in names)
        first = names[0]
        return ChartGenerationError(
            error_type="saved_metric_not_marked",
            message=(
                f"{names_str} matches a saved metric but the ref doesn't "
                f"have saved_metric=true"
            ),
            details=(
                f"The dataset has a saved metric named {names_str}. To use "
                f"it, set 'saved_metric': true on the column ref instead of "
                f"providing an 'aggregate'. With the current shape, the "
                f"chart would emit ad-hoc SQL like SUM({first}) — which is "
                f"invalid because {first} is a metric expression, not a "
                f"column."
            ),
            suggestions=[
                f'Did you mean: {{"name": "{first}", "saved_metric": true}}?',
                "Use saved_metric=true to reference a saved dataset metric",
                "Or pick a real column name and apply an aggregate to it",
            ],
            error_code="SAVED_METRIC_NOT_MARKED",
        )

    @staticmethod
    def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None:
        """Get dataset context with column information.

        Args:
            dataset_id: Integer ID, a digit-only string, or any other string,
                which is looked up against the ``uuid`` column.

        Returns:
            The populated context, or ``None`` when the dataset is not found
            or any error occurs while loading it (the error is logged).
        """
        try:
            from superset.daos.dataset import DatasetDAO

            # Find dataset
            if isinstance(dataset_id, int) or (
                isinstance(dataset_id, str) and dataset_id.isdigit()
            ):
                dataset = DatasetDAO.find_by_id(int(dataset_id))
            else:
                # Try UUID lookup
                dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")

            if not dataset:
                return None

            # Table columns; flags fall back to False when the column object
            # does not expose them.
            columns = [
                {
                    "name": col.column_name,
                    "type": str(col.type) if col.type else "UNKNOWN",
                    "is_temporal": getattr(col, "is_temporal", False),
                    "is_numeric": getattr(col, "is_numeric", False),
                }
                for col in dataset.columns
            ]

            # Saved metrics
            metrics = [
                {
                    "name": metric.metric_name,
                    "expression": metric.expression,
                    "description": metric.description,
                }
                for metric in dataset.metrics
            ]

            return DatasetContext(
                id=dataset.id,
                table_name=dataset.table_name,
                schema=dataset.schema,
                database_name=dataset.database.database_name
                if dataset.database
                else None,
                available_columns=columns,
                available_metrics=metrics,
            )

        except Exception as e:
            # Deliberate best-effort: any failure is reported to callers as
            # "no context" rather than propagated.
            logger.error("Error getting dataset context for %s: %s", dataset_id, e)
            return None

    @staticmethod
    def _extract_column_references(
        config: ChartConfig,
    ) -> List[ColumnRef]:
        """Extract all column references from configuration via the plugin registry.

        Previously only handled TableChartConfig and XYChartConfig, causing 5 of 7
        chart types to silently skip column validation. Now delegates to the plugin
        for each chart type so all types are covered.

        Returns:
            The plugin's column refs, or an empty list when the config has no
            ``chart_type`` or no plugin is registered for it.
        """
        # Local import: plugins call DatasetValidator helpers from
        # normalize_column_refs().
        # A top-level import of registry in dataset_validator would make loading this
        # module implicitly trigger plugin registration, creating a circular
        # dependency.
        from superset.mcp_service.chart.registry import get_registry

        chart_type = getattr(config, "chart_type", None)
        if chart_type is None:
            return []
        plugin = get_registry().get(chart_type)
        if plugin is None:
            logger.warning("No plugin registered for chart_type=%r", chart_type)
            return []
        return plugin.extract_column_refs(config)

    @staticmethod
    def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
        """Check if a name matches a column or saved metric (case-insensitive)."""
        column_lower = column_name.lower()
        return any(
            col["name"].lower() == column_lower
            for col in dataset_context.available_columns
        ) or any(
            metric["name"].lower() == column_lower
            for metric in dataset_context.available_metrics
        )

    @staticmethod
    def _get_canonical_column_name(
        column_name: str, dataset_context: DatasetContext
    ) -> str:
        """
        Get the canonical column name from the dataset.

        Performs case-insensitive matching and returns the actual column name
        as stored in the dataset. This ensures column names in form_data match
        exactly with what the frontend expects.

        Args:
            column_name: The column name to normalize
            dataset_context: Dataset context with column information

        Returns:
            The canonical column name from the dataset, or the original name
            if no match is found.
        """
        column_lower = column_name.lower()

        # Check regular columns first
        for col in dataset_context.available_columns:
            if col["name"].lower() == column_lower:
                return col["name"]

        # Check metrics
        for metric in dataset_context.available_metrics:
            if metric["name"].lower() == column_lower:
                return metric["name"]

        # Return original if not found (validation should catch this case)
        return column_name

    @staticmethod
    def _normalize_filters(
        config_dict: Dict[str, Any], dataset_context: DatasetContext
    ) -> None:
        """Normalize filter column names in a config dict in place."""
        # A missing, None or empty "filters" entry means nothing to normalize.
        for filter_config in config_dict.get("filters") or []:
            if filter_config and "column" in filter_config:
                filter_config["column"] = DatasetValidator._get_canonical_column_name(
                    filter_config["column"], dataset_context
                )

    @staticmethod
    def normalize_column_names(
        config: _C,
        dataset_id: int | str,
        dataset_context: DatasetContext | None = None,
    ) -> _C:
        """
        Normalize column names in config to match the canonical dataset column names.

        This fixes case sensitivity issues where user-provided column names
        (e.g., 'order_date') don't match exactly with the dataset column names
        (e.g., 'OrderDate'). The frontend performs case-sensitive comparisons,
        so we need to ensure column names match exactly.

        Previously only XYChartConfig and TableChartConfig were normalized; now
        all 7 chart types are handled via the plugin registry.

        Args:
            config: Chart configuration with column references
            dataset_id: Dataset ID to get canonical column names from
            dataset_context: Pre-fetched dataset context to avoid duplicate
                DB queries. If None, fetches from the database.

        Returns:
            A new config with normalized column names. The input config is
            returned unchanged when no dataset context, ``chart_type`` or
            plugin is available.
        """
        if dataset_context is None:
            dataset_context = DatasetValidator._get_dataset_context(dataset_id)
        if not dataset_context:
            return config

        # Local import: plugins call DatasetValidator helpers from
        # normalize_column_refs().
        # A top-level import of registry in dataset_validator would make loading this
        # module implicitly trigger plugin registration, creating a circular
        # dependency.
        from superset.mcp_service.chart.registry import get_registry

        chart_type = getattr(config, "chart_type", None)
        if chart_type is None:
            return config
        plugin = get_registry().get(chart_type)
        if plugin is None:
            logger.warning(
                "No plugin for chart_type=%r; skipping column normalization",
                chart_type,
            )
            return config
        return plugin.normalize_column_refs(config, dataset_context)

    @staticmethod
    def _get_column_suggestions(
        column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
    ) -> List[ColumnSuggestion]:
        """Get column name suggestions using fuzzy matching.

        Args:
            column_name: The unknown name supplied by the caller
            dataset_context: Dataset context with columns and saved metrics
            max_suggestions: Upper bound on the number of suggestions

        Returns:
            Up to ``max_suggestions`` distinct suggestions, best match first.
        """
        # difflib.get_close_matches raises ValueError for n <= 0.
        if max_suggestions <= 0:
            return []

        # Map lowercased name -> (canonical name, kind). Columns are added
        # before metrics and the first entry wins, so names differing only by
        # case cannot yield duplicate suggestions.
        candidates: Dict[str, Tuple[str, str]] = {}
        for col in dataset_context.available_columns:
            candidates.setdefault(col["name"].lower(), (col["name"], "column"))
        for metric in dataset_context.available_metrics:
            candidates.setdefault(metric["name"].lower(), (metric["name"], "metric"))

        # Find close matches
        column_lower = column_name.lower()
        close_matches = difflib.get_close_matches(
            column_lower,
            list(candidates),
            n=max_suggestions,
            cutoff=0.6,
        )

        # Build suggestions with proper case and type info. ``ColumnSuggestion``
        # requires ``similarity_score`` and does not have a ``data_type`` field;
        # we score via difflib ratio and store the candidate kind in ``type``.
        suggestions = []
        for match in close_matches:
            name, kind = candidates[match]
            score = difflib.SequenceMatcher(None, column_lower, match).ratio()
            suggestions.append(
                ColumnSuggestion(
                    name=name,
                    type=kind,
                    similarity_score=round(score, 3),
                )
            )

        return suggestions

    @staticmethod
    def _build_column_error(
        invalid_columns: List[ColumnRef],
        suggestions_map: Dict[str, List[ColumnSuggestion]],
        dataset_context: DatasetContext,
    ) -> ChartGenerationError:
        """Build error for invalid columns.

        Args:
            invalid_columns: Non-empty list of refs that matched no column
            suggestions_map: Fuzzy suggestions keyed by the invalid name
            dataset_context: Accepted for signature compatibility; not read here
        """
        from superset.mcp_service.utils.error_builder import (
            ChartErrorBuilder,
        )

        # Single invalid column: dedicated error, with suggestions if any.
        if len(invalid_columns) == 1:
            col = invalid_columns[0]
            suggestions = suggestions_map.get(col.name, [])
            if suggestions:
                return ChartErrorBuilder.column_not_found_error(
                    col.name, [s.name for s in suggestions]
                )
            return ChartErrorBuilder.column_not_found_error(col.name)

        # Multiple invalid columns: the headline shows at most three names,
        # the custom suggestion lists all of them.
        invalid_names = [col.name for col in invalid_columns]
        return ChartErrorBuilder.build_error(
            error_type="multiple_invalid_columns",
            template_key="column_not_found",
            template_vars={
                "column": ", ".join(invalid_names[:3])
                + ("..." if len(invalid_names) > 3 else ""),
                "suggestions": "Use get_dataset_info to see all available columns",
            },
            custom_suggestions=[
                f"Invalid columns: {', '.join(invalid_names)}",
                "Check spelling and case sensitivity",
                "Use get_dataset_info to list available columns",
            ],
            error_code="MULTIPLE_INVALID_COLUMNS",
        )

    @staticmethod
    def _validate_saved_metrics(
        column_refs: List[ColumnRef], dataset_context: DatasetContext
    ) -> ChartGenerationError | None:
        """Validate that saved_metric refs exist in dataset metrics.

        A ColumnRef with saved_metric=True must match an entry in
        available_metrics, not just available_columns. Without this check
        a regular column name marked as saved_metric would pass
        _column_exists (which checks both lists) but fail at query time.

        Returns:
            ``None`` when every saved-metric ref resolves; otherwise a single
            error naming all unresolved saved-metric refs.
        """
        metric_names = {m["name"].lower() for m in dataset_context.available_metrics}
        invalid = [
            col_ref.name
            for col_ref in column_refs
            if col_ref.saved_metric and col_ref.name.lower() not in metric_names
        ]
        if not invalid:
            return None

        from superset.mcp_service.utils.error_builder import ChartErrorBuilder

        available = [m["name"] for m in dataset_context.available_metrics]
        return ChartErrorBuilder.build_error(
            error_type="invalid_saved_metric",
            template_key="column_not_found",
            template_vars={
                "column": ", ".join(invalid),
                "suggestions": (
                    # Cap the listing at ten metrics.
                    f"Available saved metrics: {', '.join(available[:10])}"
                    if available
                    else "This dataset has no saved metrics"
                ),
            },
            custom_suggestions=[
                f"'{name}' is not a saved metric in this dataset. "
                "Remove saved_metric=True to use it as a column with an aggregate, "
                "or use get_dataset_info to see available saved metrics."
                for name in invalid
            ],
            error_code="INVALID_SAVED_METRIC",
        )

    @staticmethod
    def _validate_aggregations(
        column_refs: List[ColumnRef], dataset_context: DatasetContext
    ) -> List[ChartGenerationError]:
        """Validate that aggregations are appropriate for column types.

        Saved metrics (built-in aggregation), refs without an aggregate and
        refs that match no dataset column are skipped.

        Returns:
            One error per ref that applies a numeric-only aggregate to a
            column that is neither flagged ``is_numeric`` nor of a recognised
            numeric type; empty when everything is compatible.
        """
        aggregated_refs = [
            col_ref
            for col_ref in column_refs
            if not col_ref.saved_metric and col_ref.aggregate
        ]
        if not aggregated_refs:
            return []

        # Case-insensitive column lookup; the first column wins on collisions.
        columns_by_lower: Dict[str, Any] = {}
        for col in dataset_context.available_columns:
            columns_by_lower.setdefault(col["name"].lower(), col)

        offenders: List[Tuple[ColumnRef, str]] = []
        for col_ref in aggregated_refs:
            if col_ref.aggregate not in _NUMERIC_ONLY_AGGREGATES:
                continue
            col_info = columns_by_lower.get(col_ref.name.lower())
            if not col_info:
                continue
            # A missing, None or non-string type is treated as "UNKNOWN"
            # instead of raising on ``.upper()``.
            col_type = str(col_info.get("type") or "UNKNOWN")
            if (
                not col_info.get("is_numeric", False)
                and col_type.upper() not in _NUMERIC_TYPE_NAMES
            ):
                offenders.append((col_ref, col_type))

        if not offenders:
            return []

        # Imported lazily, and only when there is an error to build.
        from superset.mcp_service.utils.error_builder import ChartErrorBuilder

        return [
            ChartErrorBuilder.build_error(
                error_type="invalid_aggregation",
                template_key="incompatible_configuration",
                template_vars={
                    "reason": f"Cannot apply {col_ref.aggregate} to "
                    f"non-numeric column "
                    f"'{col_ref.name}' (type:"
                    f" {col_type})",
                    "primary_suggestion": "Use COUNT or COUNT_DISTINCT "
                    "for text columns",
                },
                custom_suggestions=[
                    "Remove the aggregate function for raw values",
                    "Use COUNT to count occurrences",
                    "Use COUNT_DISTINCT to count unique values",
                ],
                error_code="INVALID_AGGREGATION",
            )
            for col_ref, col_type in offenders
        ]