mirror of
https://github.com/apache/superset.git
synced 2026-05-21 15:55:10 +00:00
- Move error_schemas import above _C TypeVar definition (E402) - Split two over-length comment lines to ≤88 chars (E501, lines 268 and 380)
586 lines
22 KiB
Python
586 lines
22 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""
|
|
Dataset-specific validation for chart configurations.
|
|
Validates that referenced columns exist in the dataset schema.
|
|
"""
|
|
|
|
import difflib
|
|
import logging
|
|
from typing import Any, Dict, List, Tuple, TypeVar
|
|
|
|
from superset.mcp_service.chart.schemas import (
|
|
ChartConfig,
|
|
ColumnRef,
|
|
)
|
|
from superset.mcp_service.common.error_schemas import (
|
|
ChartGenerationError,
|
|
ColumnSuggestion,
|
|
DatasetContext,
|
|
)
|
|
|
|
# Type variable bound to ChartConfig; normalize_column_names uses it so the
# declared return type is the same config type as the argument it received.
_C = TypeVar("_C", bound=ChartConfig)

# Logger named after this module.
logger = logging.getLogger(__name__)
|
|
|
|
# Exception types that column-name normalization may raise.
# Defined once so the validation pipeline and the tool-level normalization
# calls share the same tuple.
NORMALIZATION_EXCEPTIONS = (
    ImportError,
    AttributeError,
    KeyError,
    ValueError,
    TypeError,
)
|
|
|
|
|
|
class DatasetValidator:
|
|
"""Validates chart configuration against dataset schema."""
|
|
|
|
@staticmethod
def validate_against_dataset(
    config: ChartConfig,
    dataset_id: int | str,
    dataset_context: DatasetContext | None = None,
) -> Tuple[bool, ChartGenerationError | None]:
    """
    Check a chart configuration against the schema of its dataset.

    The checks run in a fixed order and the first failure is returned:
    saved-metric references, then column existence, then aggregate/type
    compatibility.

    Args:
        config: Chart configuration to validate
        dataset_id: Dataset ID, used for the lookup when no context is
            supplied and for the dataset-not-found error
        dataset_context: Already-fetched dataset context; when None it is
            loaded through ``_get_dataset_context``

    Returns:
        ``(True, None)`` when every check passes, otherwise
        ``(False, error)`` carrying the first error encountered
    """
    # Reuse the caller's context when one was supplied.
    context = dataset_context
    if context is None:
        context = DatasetValidator._get_dataset_context(dataset_id)
    if not context:
        from superset.mcp_service.utils.error_builder import (
            ChartErrorBuilder,
        )

        return False, ChartErrorBuilder.dataset_not_found_error(dataset_id)

    refs = DatasetValidator._extract_column_references(config)

    # Saved metrics are checked first; the column check skips them, so the
    # two passes never report the same ref twice.
    for check in (
        DatasetValidator._validate_saved_metrics,
        DatasetValidator._validate_columns_exist,
    ):
        error = check(refs, context)
        if error:
            return False, error

    # Aggregation compatibility is checked for every config that produced
    # column refs, not just Table/XY: ``_validate_aggregations`` only looks
    # at the refs, and gating it by chart type would let pie / pivot table /
    # mixed timeseries / handlebars / big number slip through
    # ``SUM(non_numeric)`` patterns for the fast-path tools that skip Tier 2.
    aggregation_errors = DatasetValidator._validate_aggregations(refs, context)
    if aggregation_errors:
        return False, aggregation_errors[0]

    return True, None
|
|
|
|
@staticmethod
|
|
def _validate_columns_exist(
|
|
column_refs: List[ColumnRef], dataset_context: DatasetContext
|
|
) -> ChartGenerationError | None:
|
|
"""Validate that non-saved-metric column refs exist in the dataset.
|
|
|
|
A ``ColumnRef`` with ``saved_metric=False`` must match an entry in
|
|
``available_columns``. Saved-metric *names* don't satisfy this check —
|
|
otherwise ``{name: "sum_boys", aggregate: "SUM"}`` (no
|
|
``saved_metric=true``) would slip through and downstream code would
|
|
emit ``SUM(sum_boys)`` as an ad-hoc SIMPLE metric, producing the
|
|
broken-SQL pattern this validator is meant to prevent.
|
|
"""
|
|
column_names_lower = {
|
|
col["name"].lower() for col in dataset_context.available_columns
|
|
}
|
|
metric_names_lower = {
|
|
metric["name"].lower() for metric in dataset_context.available_metrics
|
|
}
|
|
|
|
invalid_columns: List[ColumnRef] = []
|
|
saved_metric_typo: List[ColumnRef] = []
|
|
for col_ref in column_refs:
|
|
if col_ref.saved_metric:
|
|
continue
|
|
name_lower = col_ref.name.lower()
|
|
if name_lower in column_names_lower:
|
|
continue
|
|
if name_lower in metric_names_lower:
|
|
# Name matches a saved metric but the ref didn't opt into
|
|
# saved-metric resolution. Surface a tailored hint so the
|
|
# caller (typically an LLM) can flip ``saved_metric=true``.
|
|
saved_metric_typo.append(col_ref)
|
|
else:
|
|
invalid_columns.append(col_ref)
|
|
|
|
if saved_metric_typo:
|
|
return DatasetValidator._build_saved_metric_hint_error(saved_metric_typo)
|
|
|
|
if not invalid_columns:
|
|
return None
|
|
|
|
suggestions_map = {}
|
|
for col_ref in invalid_columns:
|
|
suggestions = DatasetValidator._get_column_suggestions(
|
|
col_ref.name, dataset_context
|
|
)
|
|
suggestions_map[col_ref.name] = suggestions
|
|
|
|
return DatasetValidator._build_column_error(
|
|
invalid_columns, suggestions_map, dataset_context
|
|
)
|
|
|
|
@staticmethod
def _build_saved_metric_hint_error(
    refs: List[ColumnRef],
) -> ChartGenerationError:
    """Build the error for refs that name a saved metric without opting in.

    Args:
        refs: Offending column refs; must be non-empty because the first
            name is used in the example text (an empty list raises
            ``IndexError``).

    Returns:
        A ``ChartGenerationError`` with code ``SAVED_METRIC_NOT_MARKED``
        listing every offending name and showing the corrected shape for
        the first one.
    """
    quoted_names = ", ".join([f"'{ref.name}'" for ref in refs])
    example = refs[0].name

    message = (
        f"{quoted_names} matches a saved metric but the ref doesn't "
        f"have saved_metric=true"
    )
    details = (
        f"The dataset has a saved metric named {quoted_names}. To use "
        f"it, set 'saved_metric': true on the column ref instead of "
        f"providing an 'aggregate'. With the current shape, the "
        f"chart would emit ad-hoc SQL like SUM({example}) — which is "
        f"invalid because {example} is a metric expression, not a "
        f"column."
    )
    hints = [
        f'Did you mean: {{"name": "{example}", "saved_metric": true}}?',
        "Use saved_metric=true to reference a saved dataset metric",
        "Or pick a real column name and apply an aggregate to it",
    ]

    return ChartGenerationError(
        error_type="saved_metric_not_marked",
        message=message,
        details=details,
        suggestions=hints,
        error_code="SAVED_METRIC_NOT_MARKED",
    )
|
|
|
|
@staticmethod
def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None:
    """Load a dataset and summarise its columns and metrics.

    Integers and all-digit strings are looked up by primary key; any other
    value is looked up through the ``uuid`` column.

    Returns:
        A ``DatasetContext`` for the dataset, or None when the dataset is
        not found or when any exception is raised while loading it (the
        exception is logged, not re-raised).
    """
    try:
        from superset.daos.dataset import DatasetDAO

        looks_numeric = isinstance(dataset_id, int) or (
            isinstance(dataset_id, str) and dataset_id.isdigit()
        )
        if looks_numeric:
            dataset = DatasetDAO.find_by_id(int(dataset_id))
        else:
            # Anything that is not a plain number is tried as a UUID.
            dataset = DatasetDAO.find_by_id(dataset_id, id_column="uuid")

        if not dataset:
            return None

        # Columns: the temporal/numeric flags default to False when the
        # column object does not expose them.
        column_entries = [
            {
                "name": col.column_name,
                "type": str(col.type) if col.type else "UNKNOWN",
                "is_temporal": getattr(col, "is_temporal", False),
                "is_numeric": getattr(col, "is_numeric", False),
            }
            for col in dataset.columns
        ]

        metric_entries = [
            {
                "name": metric.metric_name,
                "expression": metric.expression,
                "description": metric.description,
            }
            for metric in dataset.metrics
        ]

        return DatasetContext(
            id=dataset.id,
            table_name=dataset.table_name,
            schema=dataset.schema,
            database_name=dataset.database.database_name
            if dataset.database
            else None,
            available_columns=column_entries,
            available_metrics=metric_entries,
        )

    except Exception as exc:
        logger.error("Error getting dataset context for %s: %s", dataset_id, exc)
        return None
|
|
|
|
@staticmethod
def _extract_column_references(
    config: ChartConfig,
) -> List[ColumnRef]:
    """Collect the column refs of a config through its chart-type plugin.

    Extraction is delegated to the plugin registered for
    ``config.chart_type`` so that every chart type with a plugin is
    covered, rather than a hard-coded subset.

    Returns:
        The refs reported by the plugin, or an empty list when the config
        has no ``chart_type`` or no plugin is registered for it (the latter
        is logged as a warning).
    """
    # Imported here, not at module top: plugins call DatasetValidator
    # helpers from normalize_column_refs(), and a top-level import of the
    # registry would make loading this module implicitly trigger plugin
    # registration, creating a circular dependency.
    from superset.mcp_service.chart.registry import get_registry

    kind = getattr(config, "chart_type", None)
    if kind is None:
        return []

    plugin = get_registry().get(kind)
    if plugin is not None:
        return plugin.extract_column_refs(config)

    logger.warning("No plugin registered for chart_type=%r", kind)
    return []
|
|
|
|
@staticmethod
|
|
def _column_exists(column_name: str, dataset_context: DatasetContext) -> bool:
|
|
"""Check if column exists in dataset (case-insensitive)."""
|
|
column_lower = column_name.lower()
|
|
|
|
# Check regular columns
|
|
for col in dataset_context.available_columns:
|
|
if col["name"].lower() == column_lower:
|
|
return True
|
|
|
|
# Check metrics
|
|
for metric in dataset_context.available_metrics:
|
|
if metric["name"].lower() == column_lower:
|
|
return True
|
|
|
|
return False
|
|
|
|
@staticmethod
|
|
def _get_canonical_column_name(
|
|
column_name: str, dataset_context: DatasetContext
|
|
) -> str:
|
|
"""
|
|
Get the canonical column name from the dataset.
|
|
|
|
Performs case-insensitive matching and returns the actual column name
|
|
as stored in the dataset. This ensures column names in form_data match
|
|
exactly with what the frontend expects.
|
|
|
|
Args:
|
|
column_name: The column name to normalize
|
|
dataset_context: Dataset context with column information
|
|
|
|
Returns:
|
|
The canonical column name from the dataset, or the original name
|
|
if no match is found.
|
|
"""
|
|
column_lower = column_name.lower()
|
|
|
|
# Check regular columns first
|
|
for col in dataset_context.available_columns:
|
|
if col["name"].lower() == column_lower:
|
|
return col["name"]
|
|
|
|
# Check metrics
|
|
for metric in dataset_context.available_metrics:
|
|
if metric["name"].lower() == column_lower:
|
|
return metric["name"]
|
|
|
|
# Return original if not found (validation should catch this case)
|
|
return column_name
|
|
|
|
@staticmethod
|
|
def _normalize_filters(
|
|
config_dict: Dict[str, Any], dataset_context: DatasetContext
|
|
) -> None:
|
|
"""Normalize filter column names in a config dict in place."""
|
|
if "filters" in config_dict and config_dict["filters"]:
|
|
for filter_config in config_dict["filters"]:
|
|
if filter_config and "column" in filter_config:
|
|
filter_config["column"] = (
|
|
DatasetValidator._get_canonical_column_name(
|
|
filter_config["column"], dataset_context
|
|
)
|
|
)
|
|
|
|
@staticmethod
def normalize_column_names(
    config: _C,
    dataset_id: int | str,
    dataset_context: DatasetContext | None = None,
) -> _C:
    """
    Rewrite the column names in a config to the dataset's canonical spelling.

    User-supplied names may differ from the dataset only by case (e.g.
    'order_date' vs 'OrderDate'). The frontend compares names
    case-sensitively, so they are normalized here. The rewrite itself is
    delegated to the plugin registered for ``config.chart_type``, which
    covers every chart type that has a plugin.

    Args:
        config: Chart configuration with column references
        dataset_id: Dataset ID, used to load the context when none is given
        dataset_context: Already-fetched dataset context; when None it is
            loaded through ``_get_dataset_context``

    Returns:
        The plugin's normalized config. The input config is returned
        unchanged when the dataset context is unavailable, the config has
        no ``chart_type``, or no plugin is registered for it (the last case
        is logged as a warning).
    """
    context = dataset_context
    if context is None:
        context = DatasetValidator._get_dataset_context(dataset_id)
    if not context:
        return config

    # Imported here, not at module top: plugins call DatasetValidator
    # helpers from normalize_column_refs(), and a top-level import of the
    # registry would make loading this module implicitly trigger plugin
    # registration, creating a circular dependency.
    from superset.mcp_service.chart.registry import get_registry

    kind = getattr(config, "chart_type", None)
    if kind is None:
        return config

    plugin = get_registry().get(kind)
    if plugin is not None:
        return plugin.normalize_column_refs(config, context)

    logger.warning(
        "No plugin for chart_type=%r; skipping column normalization", kind
    )
    return config
|
|
|
|
@staticmethod
|
|
def _get_column_suggestions(
|
|
column_name: str, dataset_context: DatasetContext, max_suggestions: int = 3
|
|
) -> List[ColumnSuggestion]:
|
|
"""Get column name suggestions using fuzzy matching."""
|
|
all_names = []
|
|
|
|
# Collect all column names
|
|
for col in dataset_context.available_columns:
|
|
all_names.append((col["name"], "column", col.get("type", "UNKNOWN")))
|
|
|
|
for metric in dataset_context.available_metrics:
|
|
all_names.append((metric["name"], "metric", "METRIC"))
|
|
|
|
# Find close matches
|
|
column_lower = column_name.lower()
|
|
candidate_lookup = [name[0].lower() for name in all_names]
|
|
close_matches = difflib.get_close_matches(
|
|
column_lower,
|
|
candidate_lookup,
|
|
n=max_suggestions,
|
|
cutoff=0.6,
|
|
)
|
|
|
|
# Build suggestions with proper case and type info. ``ColumnSuggestion``
|
|
# requires ``similarity_score`` and does not have a ``data_type`` field;
|
|
# we score via difflib ratio and store the candidate kind in ``type``.
|
|
suggestions = []
|
|
for match in close_matches:
|
|
for name, col_type, _data_type in all_names:
|
|
if name.lower() == match:
|
|
score = difflib.SequenceMatcher(None, column_lower, match).ratio()
|
|
suggestions.append(
|
|
ColumnSuggestion(
|
|
name=name,
|
|
type=col_type,
|
|
similarity_score=round(score, 3),
|
|
)
|
|
)
|
|
break
|
|
|
|
return suggestions
|
|
|
|
@staticmethod
def _build_column_error(
    invalid_columns: List[ColumnRef],
    suggestions_map: Dict[str, List[ColumnSuggestion]],
    dataset_context: DatasetContext,
) -> ChartGenerationError:
    """Build the error returned when column refs do not exist in the dataset.

    Args:
        invalid_columns: Refs whose names were not found
        suggestions_map: Fuzzy-match suggestions keyed by ref name; only
            consulted when exactly one ref is invalid
        dataset_context: Not used by this method

    Returns:
        A column-not-found error (with suggestion names when available) for
        a single invalid ref; otherwise a ``MULTIPLE_INVALID_COLUMNS`` error
        whose summary shows at most three names.
    """
    from superset.mcp_service.utils.error_builder import (
        ChartErrorBuilder,
    )

    if len(invalid_columns) != 1:
        names = [ref.name for ref in invalid_columns]
        # The template shows only the first three names; the full list goes
        # into the first custom suggestion.
        shown = ", ".join(names[:3])
        if len(names) > 3:
            shown += "..."
        return ChartErrorBuilder.build_error(
            error_type="multiple_invalid_columns",
            template_key="column_not_found",
            template_vars={
                "column": shown,
                "suggestions": "Use get_dataset_info to see all available columns",
            },
            custom_suggestions=[
                f"Invalid columns: {', '.join(names)}",
                "Check spelling and case sensitivity",
                "Use get_dataset_info to list available columns",
            ],
            error_code="MULTIPLE_INVALID_COLUMNS",
        )

    only = invalid_columns[0]
    hits = suggestions_map.get(only.name, [])
    if not hits:
        return ChartErrorBuilder.column_not_found_error(only.name)
    return ChartErrorBuilder.column_not_found_error(
        only.name, [hit.name for hit in hits]
    )
|
|
|
|
@staticmethod
|
|
def _validate_saved_metrics(
|
|
column_refs: List[ColumnRef], dataset_context: DatasetContext
|
|
) -> ChartGenerationError | None:
|
|
"""Validate that saved_metric refs exist in dataset metrics.
|
|
|
|
A ColumnRef with saved_metric=True must match an entry in
|
|
available_metrics, not just available_columns. Without this check
|
|
a regular column name marked as saved_metric would pass
|
|
_column_exists (which checks both lists) but fail at query time.
|
|
"""
|
|
metric_names = {m["name"].lower() for m in dataset_context.available_metrics}
|
|
invalid = [
|
|
col_ref.name
|
|
for col_ref in column_refs
|
|
if col_ref.saved_metric and col_ref.name.lower() not in metric_names
|
|
]
|
|
if not invalid:
|
|
return None
|
|
|
|
from superset.mcp_service.utils.error_builder import ChartErrorBuilder
|
|
|
|
available = [m["name"] for m in dataset_context.available_metrics]
|
|
return ChartErrorBuilder.build_error(
|
|
error_type="invalid_saved_metric",
|
|
template_key="column_not_found",
|
|
template_vars={
|
|
"column": ", ".join(invalid),
|
|
"suggestions": (
|
|
f"Available saved metrics: {', '.join(available[:10])}"
|
|
if available
|
|
else "This dataset has no saved metrics"
|
|
),
|
|
},
|
|
custom_suggestions=[
|
|
f"'{name}' is not a saved metric in this dataset. "
|
|
"Remove saved_metric=True to use it as a column with an aggregate, "
|
|
"or use get_dataset_info to see available saved metrics."
|
|
for name in invalid
|
|
],
|
|
error_code="INVALID_SAVED_METRIC",
|
|
)
|
|
|
|
@staticmethod
|
|
def _validate_aggregations(
|
|
column_refs: List[ColumnRef], dataset_context: DatasetContext
|
|
) -> List[ChartGenerationError]:
|
|
"""Validate that aggregations are appropriate for column types."""
|
|
errors = []
|
|
|
|
for col_ref in column_refs:
|
|
if col_ref.saved_metric:
|
|
continue # Saved metrics have built-in aggregation
|
|
if not col_ref.aggregate:
|
|
continue
|
|
|
|
# Find column info
|
|
col_info = None
|
|
for col in dataset_context.available_columns:
|
|
if col["name"].lower() == col_ref.name.lower():
|
|
col_info = col
|
|
break
|
|
|
|
if col_info:
|
|
# Check numeric aggregates on non-numeric columns.
|
|
# MIN and MAX are intentionally excluded: they work on dates
|
|
# and text in most SQL engines, so restricting them here would
|
|
# produce false-positive errors. Leave those to the Tier-2
|
|
# compile check.
|
|
numeric_aggs = ["SUM", "AVG", "STDDEV", "VAR", "MEDIAN"]
|
|
if (
|
|
col_ref.aggregate in numeric_aggs
|
|
and not col_info.get("is_numeric", False)
|
|
and col_info.get("type", "").upper()
|
|
not in ["INTEGER", "FLOAT", "DOUBLE", "DECIMAL", "NUMERIC"]
|
|
):
|
|
from superset.mcp_service.utils.error_builder import ( # noqa: E501
|
|
ChartErrorBuilder,
|
|
)
|
|
|
|
errors.append(
|
|
ChartErrorBuilder.build_error(
|
|
error_type="invalid_aggregation",
|
|
template_key="incompatible_configuration",
|
|
template_vars={
|
|
"reason": f"Cannot apply {col_ref.aggregate} to "
|
|
f"non-numeric column "
|
|
f"'{col_ref.name}' (type:"
|
|
f" {col_info.get('type', 'UNKNOWN')})",
|
|
"primary_suggestion": "Use COUNT or COUNT_DISTINCT "
|
|
"for text columns",
|
|
},
|
|
custom_suggestions=[
|
|
"Remove the aggregate function for raw values",
|
|
"Use COUNT to count occurrences",
|
|
"Use COUNT_DISTINCT to count unique values",
|
|
],
|
|
error_code="INVALID_AGGREGATION",
|
|
)
|
|
)
|
|
|
|
return errors
|