diff --git a/superset/mcp_service/CLAUDE.md b/superset/mcp_service/CLAUDE.md index 2f300fbd725..954d353cb6e 100644 --- a/superset/mcp_service/CLAUDE.md +++ b/superset/mcp_service/CLAUDE.md @@ -393,17 +393,26 @@ Used by: `get_chart_info`, `get_chart_preview`, `get_chart_data`, `generate_char ### 11. Compile Check for Chart Creation -When creating or saving charts, run a compile check to verify the query executes: +When creating, saving, or previewing charts, run schema validation (Tier 1) +and optionally a compile check (Tier 2) before persisting or caching. +``validate_and_compile`` glues both together; tools with tight SLAs +(``generate_explore_link``, ``update_chart_preview``) opt out of Tier 2. ```python -from superset.mcp_service.chart.tool.generate_chart import _compile_chart +from superset.mcp_service.chart.compile import validate_and_compile -compile_result = _compile_chart(form_data, dataset.id) -if not compile_result.success: - # Delete broken chart, return error +result = validate_and_compile( + config, form_data, dataset, run_compile_check=True +) +if not result.success: + # ``result.error_obj`` is a ``ChartGenerationError`` with fuzzy-match + # suggestions ("did you mean sum_boys?") so the LLM can self-correct. ... ``` +The lower-level ``_compile_chart(form_data, dataset_id)`` is still exported +for callers that have already done their own schema validation. + ### 12. Flexible Input Parsing `ModelListCore` handles JSON string vs. native object parsing automatically via utilities in `superset.mcp_service.utils.schema_utils`: diff --git a/superset/mcp_service/chart/compile.py b/superset/mcp_service/chart/compile.py new file mode 100644 index 00000000000..6809e7534a5 --- /dev/null +++ b/superset/mcp_service/chart/compile.py @@ -0,0 +1,362 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +Shared compile/validation helpers for MCP chart-generating tools. + +Two tiers are exposed: + +* **Tier 1 — schema validation** (``DatasetValidator.validate_against_dataset``): + cheap, no SQL execution, catches references to columns or metrics that do + not exist in the dataset and returns fuzzy-match suggestions. +* **Tier 2 — compile check** (``_compile_chart``): runs a small (``row_limit=2``) + ``ChartDataCommand`` against the underlying database to surface anything Tier + 1 cannot catch (incompatible aggregates, virtual-dataset SQL bugs, etc.). + +``validate_and_compile`` glues both together so each MCP tool can opt into the +tier(s) appropriate for its SLA. 
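+
+A minimal usage sketch, assuming ``config``/``form_data`` come from
+``map_config_to_form_data`` and ``dataset`` is an already-fetched ORM object
+(``build_error_response`` is a hypothetical caller-side helper)::
+
+    result = validate_and_compile(
+        config, form_data, dataset, run_compile_check=False  # Tier 1 only
+    )
+    if not result.success:
+        # result.error_obj carries fuzzy-match suggestions the caller can
+        # embed in its response envelope so LLM clients self-correct.
+        return build_error_response(result.error_obj)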
+""" + +from __future__ import annotations + +import logging +from dataclasses import dataclass, field +from typing import Any, Dict, List, Literal + +from superset.commands.exceptions import CommandException +from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator +from superset.mcp_service.common.error_schemas import ( + ChartGenerationError, + ColumnSuggestion, + DatasetContext, +) + +logger = logging.getLogger(__name__) + + +@dataclass +class CompileResult: + """Result of a chart validate-and-compile check. + + ``error_obj`` carries the structured ``ChartGenerationError`` (with + suggestions, dataset context, etc.) that callers should embed in their + response envelope so LLM clients can self-correct. ``error`` retains the + plain-string form for backwards compatibility with existing call sites. + """ + + success: bool + error: str | None = None + error_code: str | None = None + tier: Literal["validation", "compile"] | None = None + error_obj: ChartGenerationError | None = None + warnings: List[str] = field(default_factory=list) + row_count: int | None = None + + +def build_dataset_context_from_orm(dataset: Any) -> DatasetContext | None: + """Construct a ``DatasetContext`` from an already-fetched ORM dataset. + + Mirrors :py:meth:`DatasetValidator._get_dataset_context` but skips the + ``DatasetDAO.find_by_id`` round trip. Callers that have already loaded + the dataset (for permission checks, etc.) should use this instead. + """ + if dataset is None: + return None + + columns: List[Dict[str, Any]] = [] + for col in getattr(dataset, "columns", []) or []: + columns.append( + { + "name": col.column_name, + "type": str(col.type) if col.type else "UNKNOWN", + "is_temporal": getattr(col, "is_temporal", False), + "is_numeric": getattr(col, "is_numeric", False), + } + ) + + metrics: List[Dict[str, Any]] = [] + for metric in getattr(dataset, "metrics", []) or []: + metrics.append( + { + "name": metric.metric_name, + "expression": metric.expression, + "description": metric.description, + } + ) + + database = getattr(dataset, "database", None) + # ``DatasetContext.database_name`` is typed as required ``str``; default to + # an empty string when the relationship isn't loaded so we don't blow up + # Pydantic validation. The field is purely informational in error messages. + database_name = getattr(database, "database_name", None) or "" + return DatasetContext( + id=dataset.id, + table_name=dataset.table_name, + schema=dataset.schema, + database_name=database_name, + available_columns=columns, + available_metrics=metrics, + ) + + +def _compile_chart( + form_data: Dict[str, Any], + dataset_id: int, +) -> CompileResult: + """Execute the chart's query to verify it renders without errors. + + Builds a ``QueryContext`` from *form_data* and runs it through + ``ChartDataCommand``. A small ``row_limit`` is used so the check is + fast — we only need to know the query compiles and returns data, not + fetch the full result set. + + Returns a :class:`CompileResult` with ``success=True`` when the + query executes cleanly. 
+ """ + from superset.commands.chart.data.get_data_command import ChartDataCommand + from superset.commands.chart.exceptions import ( + ChartDataCacheLoadError, + ChartDataQueryFailedError, + ) + from superset.common.query_context_factory import QueryContextFactory + from superset.mcp_service.chart.chart_utils import adhoc_filters_to_query_filters + from superset.mcp_service.chart.preview_utils import _build_query_columns + + try: + columns = _build_query_columns(form_data) + query_filters = adhoc_filters_to_query_filters( + form_data.get("adhoc_filters", []) + ) + + # Big Number charts use singular "metric" instead of "metrics" + metrics = form_data.get("metrics", []) + if not metrics and form_data.get("metric"): + metrics = [form_data["metric"]] + + # Big Number with trendline uses granularity_sqla as the time column + if not columns and form_data.get("granularity_sqla"): + columns = [form_data["granularity_sqla"]] + + factory = QueryContextFactory() + query_context = factory.create( + datasource={"id": dataset_id, "type": "table"}, + queries=[ + { + "columns": columns, + "metrics": metrics, + "orderby": form_data.get("orderby", []), + "row_limit": 2, + "filters": query_filters, + "time_range": form_data.get("time_range", "No filter"), + } + ], + form_data=form_data, + ) + + command = ChartDataCommand(query_context) + command.validate() + result = command.run() + + warnings: List[str] = [] + row_count = 0 + for query in result.get("queries", []): + if query.get("error"): + error_str = str(query["error"]) + return CompileResult( + success=False, + error=error_str, + error_code="CHART_COMPILE_FAILED", + tier="compile", + error_obj=_build_compile_error(error_str), + ) + row_count += len(query.get("data", [])) + + return CompileResult(success=True, warnings=warnings, row_count=row_count) + except (ChartDataQueryFailedError, ChartDataCacheLoadError) as exc: + return CompileResult( + success=False, + error=str(exc), + error_code="CHART_COMPILE_FAILED", + tier="compile", + error_obj=_build_compile_error(str(exc)), + ) + except (CommandException, ValueError, KeyError) as exc: + return CompileResult( + success=False, + error=str(exc), + error_code="CHART_COMPILE_FAILED", + tier="compile", + error_obj=_build_compile_error(str(exc)), + ) + + +def _adhoc_filter_column_valid( + column: str, clause: str, dataset_context: DatasetContext +) -> bool: + """Return True if *column* is a valid reference for this filter clause. + + WHERE filters must reference a physical column; HAVING filters may also + reference a saved metric because Superset resolves metric names there. + """ + if clause == "HAVING": + return DatasetValidator._column_exists(column, dataset_context) + return any( + col["name"].lower() == column.lower() + for col in dataset_context.available_columns + ) + + +def _validate_adhoc_filter_columns( + form_data: Dict[str, Any], dataset_context: DatasetContext +) -> ChartGenerationError | None: + """Tier-1 check for adhoc-filter column references stored in ``form_data``. + + ``DatasetValidator._extract_column_references`` walks the typed + ``ChartConfig`` and only sees ``config.filters``. Tools like + ``update_chart_preview`` and ``update_chart`` (preview path) also merge + *previously cached* ``adhoc_filters`` into ``form_data`` that aren't + represented on the new config — those would otherwise bypass validation + and surface only when Explore tries to run the query. 
+ """ + adhoc_filters = form_data.get("adhoc_filters") or [] + invalid: List[str] = [] + for f in adhoc_filters: + if not isinstance(f, dict): + continue + # SIMPLE filters expose the column via "subject"; SQL-expression + # filters carry a free-form ``sqlExpression`` we can't safely parse, + # so skip those. + if f.get("expressionType") and f.get("expressionType") != "SIMPLE": + continue + column = f.get("subject") or f.get("col") + if not column or not isinstance(column, str): + continue + clause = f.get("clause", "WHERE").upper() + if not _adhoc_filter_column_valid(column, clause, dataset_context): + invalid.append(column) + + if not invalid: + return None + + suggestions: List[str] = [] + for column in invalid: + for suggestion in DatasetValidator._get_column_suggestions( + column, dataset_context + ): + name = ( + suggestion.name + if isinstance(suggestion, ColumnSuggestion) + else str(suggestion) + ) + if name and name not in suggestions: + suggestions.append(name) + + bad = ", ".join(sorted(set(invalid))) + return ChartGenerationError( + error_type="invalid_column", + message=(f"Filter references column(s) not in dataset: {bad}"), + details=( + "Adhoc filter columns must exist on the dataset. " + "If these filters were preserved from a previous chart preview, " + "remove them or pass an explicit ``filters`` list on the new config." + ), + suggestions=suggestions, + error_code="CHART_VALIDATION_FAILED", + ) + + +def _build_compile_error(message: str) -> ChartGenerationError: + """Wrap a raw compile-failure string in the structured response envelope.""" + return ChartGenerationError( + error_type="compile_error", + message="Chart query failed to execute. The chart was not saved.", + details=message or "", + suggestions=[ + "Check that all columns exist in the dataset", + "Verify aggregate functions are compatible with column types", + "Ensure filters reference valid columns", + "Try simplifying the chart configuration", + ], + error_code="CHART_COMPILE_FAILED", + ) + + +def validate_and_compile( + config: Any, + form_data: Dict[str, Any], + dataset: Any, + *, + run_compile_check: bool = True, +) -> CompileResult: + """Run schema validation (Tier 1) and optionally a compile check (Tier 2). + + ``dataset`` must be an already-fetched ORM dataset; this avoids a second + ``DatasetDAO.find_by_id`` round trip inside the validator. + + ``run_compile_check`` lets fast-path tools (``generate_explore_link``, + ``update_chart_preview``) skip the live DB query while still rejecting + obviously bad column references with fuzzy-match suggestions. + + Returns a :class:`CompileResult`. On failure, ``error_obj`` carries the + structured :class:`ChartGenerationError` (with ``suggestions``) that the + caller should embed in its response envelope so LLM clients can + self-correct. 
+ """ + if dataset is None: + return CompileResult( + success=False, + error="Dataset not provided to validate_and_compile", + error_code="DATASET_NOT_FOUND", + tier="validation", + ) + + dataset_context = build_dataset_context_from_orm(dataset) + + is_valid, error = DatasetValidator.validate_against_dataset( + config, dataset.id, dataset_context=dataset_context + ) + if not is_valid: + details = "" + if error is not None: + details = error.details or error.message + if error.error_code is None: + error.error_code = "CHART_VALIDATION_FAILED" + return CompileResult( + success=False, + error=details, + error_code="CHART_VALIDATION_FAILED", + tier="validation", + error_obj=error, + ) + + # Validate adhoc-filter columns living only in form_data (e.g. filters + # preserved from a previously cached preview). The typed config-level + # validator above doesn't see those. + if dataset_context is not None: + filter_error = _validate_adhoc_filter_columns(form_data, dataset_context) + if filter_error is not None: + return CompileResult( + success=False, + error=filter_error.details or filter_error.message, + error_code="CHART_VALIDATION_FAILED", + tier="validation", + error_obj=filter_error, + ) + + if not run_compile_check: + return CompileResult(success=True) + + return _compile_chart(form_data, dataset.id) diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 646ac4d4c2a..8ad907c1abe 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -20,8 +20,7 @@ MCP tool: generate_chart (simplified schema) import logging import time -from dataclasses import dataclass, field -from typing import Any, Dict, List +from typing import Any from fastmcp import Context from sqlalchemy.exc import SQLAlchemyError @@ -39,6 +38,11 @@ from superset.mcp_service.chart.chart_utils import ( map_config_to_form_data, validate_chart_dataset, ) +from superset.mcp_service.chart.compile import ( + _compile_chart, + CompileResult, + validate_and_compile, +) from superset.mcp_service.chart.schemas import ( AccessibilityMetadata, CHART_FORM_DATA_EXCLUDED_FIELD_NAMES, @@ -74,86 +78,7 @@ def _sanitize_generate_chart_form_data_for_llm_context( ) -@dataclass -class CompileResult: - """Result of a chart compile check (test query execution).""" - - success: bool - error: str | None = None - warnings: List[str] = field(default_factory=list) - row_count: int | None = None - - -def _compile_chart( - form_data: Dict[str, Any], - dataset_id: int, -) -> CompileResult: - """Execute the chart's query to verify it renders without errors. - - Builds a ``QueryContext`` from *form_data* and runs it through - ``ChartDataCommand``. A small ``row_limit`` is used so the check is - fast — we only need to know the query compiles and returns data, not - fetch the full result set. - - Returns a :class:`CompileResult` with ``success=True`` when the - query executes cleanly. 
- """ - from superset.commands.chart.data.get_data_command import ChartDataCommand - from superset.commands.chart.exceptions import ( - ChartDataCacheLoadError, - ChartDataQueryFailedError, - ) - from superset.common.query_context_factory import QueryContextFactory - from superset.mcp_service.chart.chart_utils import adhoc_filters_to_query_filters - from superset.mcp_service.chart.preview_utils import _build_query_columns - - try: - columns = _build_query_columns(form_data) - query_filters = adhoc_filters_to_query_filters( - form_data.get("adhoc_filters", []) - ) - - # Big Number charts use singular "metric" instead of "metrics" - metrics = form_data.get("metrics", []) - if not metrics and form_data.get("metric"): - metrics = [form_data["metric"]] - - # Big Number with trendline uses granularity_sqla as the time column - if not columns and form_data.get("granularity_sqla"): - columns = [form_data["granularity_sqla"]] - - factory = QueryContextFactory() - query_context = factory.create( - datasource={"id": dataset_id, "type": "table"}, - queries=[ - { - "columns": columns, - "metrics": metrics, - "orderby": form_data.get("orderby", []), - "row_limit": 2, - "filters": query_filters, - "time_range": form_data.get("time_range", "No filter"), - } - ], - form_data=form_data, - ) - - command = ChartDataCommand(query_context) - command.validate() - result = command.run() - - warnings: List[str] = [] - row_count = 0 - for query in result.get("queries", []): - if query.get("error"): - return CompileResult(success=False, error=str(query["error"])) - row_count += len(query.get("data", [])) - - return CompileResult(success=True, warnings=warnings, row_count=row_count) - except (ChartDataQueryFailedError, ChartDataCacheLoadError) as exc: - return CompileResult(success=False, error=str(exc)) - except (CommandException, ValueError, KeyError) as exc: - return CompileResult(success=False, error=str(exc)) +__all__ = ["CompileResult", "_compile_chart", "validate_and_compile", "generate_chart"] @tool( diff --git a/superset/mcp_service/chart/tool/update_chart.py b/superset/mcp_service/chart/tool/update_chart.py index abce23ddddb..3e3057bdd2e 100644 --- a/superset/mcp_service/chart/tool/update_chart.py +++ b/superset/mcp_service/chart/tool/update_chart.py @@ -40,6 +40,7 @@ from superset.mcp_service.chart.chart_utils import ( generate_chart_name, map_config_to_form_data, ) +from superset.mcp_service.chart.compile import validate_and_compile from superset.mcp_service.chart.schemas import ( AccessibilityMetadata, GenerateChartResponse, @@ -162,6 +163,70 @@ def _build_preview_form_data( return merged +def _validate_update_against_dataset( + parsed_config: Any, + form_data: dict[str, Any], + chart: Any, +) -> GenerateChartResponse | None: + """Run Tier 1 (schema) + Tier 2 (compile) validation against the chart's + dataset. Returns ``None`` on success, or a :class:`GenerateChartResponse` + error envelope on failure that callers should return as-is. + """ + from superset.daos.dataset import DatasetDAO + + dataset = getattr(chart, "datasource", None) + if dataset is None and getattr(chart, "datasource_id", None) is not None: + dataset = DatasetDAO.find_by_id(chart.datasource_id) + if dataset is None: + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": { + "error_type": "DatasetNotAccessible", + "message": "Chart's dataset is not accessible", + "details": ( + f"Dataset {getattr(chart, 'datasource_id', None)} " + "is missing or inaccessible." 
+ ), + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + + compile_result = validate_and_compile( + parsed_config, form_data, dataset, run_compile_check=True + ) + if compile_result.success: + return None + + logger.warning( + "update_chart validation failed for chart %s: %s", + getattr(chart, "id", None), + compile_result.error, + ) + if compile_result.error_obj is not None: + error_payload = compile_result.error_obj.model_dump() + else: + error_payload = { + "error_type": "validation_error", + "message": "Chart update validation failed", + "details": compile_result.error or "", + "error_code": compile_result.error_code, + "suggestions": [], + } + return GenerateChartResponse.model_validate( + { + "chart": None, + "error": error_payload, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + ) + + def _create_preview_url( chart: Any, form_data: dict[str, Any] ) -> tuple[str, str | None, list[str]]: @@ -334,6 +399,18 @@ async def update_chart( # noqa: C901 if "params" in payload_or_error: new_form_data = json.loads(payload_or_error["params"]) + # Validate before persisting — catches bad column refs and runtime + # SQL errors so we don't commit a chart that can't be queried. + # Renames (no parsed_config) skip validation since form_data is + # untouched. + if parsed_config is not None and new_form_data is not None: + with event_logger.log_context(action="mcp.update_chart.validation"): + validation_error = _validate_update_against_dataset( + parsed_config, new_form_data, chart + ) + if validation_error is not None: + return validation_error + with event_logger.log_context(action="mcp.update_chart.db_write"): command = UpdateChartCommand(chart.id, payload_or_error) updated_chart = command.run() @@ -346,6 +423,15 @@ async def update_chart( # noqa: C901 if isinstance(preview_or_error, GenerateChartResponse): return preview_or_error + # Validate before caching the form_data — same rationale as above. 
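+    # Renames with no parsed_config skip this check too, mirroring the
+    # persist path.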
+ if parsed_config is not None: + with event_logger.log_context(action="mcp.update_chart.validation"): + validation_error = _validate_update_against_dataset( + parsed_config, preview_or_error, chart + ) + if validation_error is not None: + return validation_error + with event_logger.log_context(action="mcp.update_chart.preview_link"): explore_url, form_data_key, warnings = _create_preview_url( chart, preview_or_error diff --git a/superset/mcp_service/chart/tool/update_chart_preview.py b/superset/mcp_service/chart/tool/update_chart_preview.py index 9da8f16fc9b..82ace0df0ad 100644 --- a/superset/mcp_service/chart/tool/update_chart_preview.py +++ b/superset/mcp_service/chart/tool/update_chart_preview.py @@ -30,6 +30,7 @@ from superset_core.mcp.decorators import tool, ToolAnnotations from superset.commands.exceptions import CommandException from superset.exceptions import OAuth2Error, OAuth2RedirectError, SupersetException from superset.extensions import event_logger +from superset.mcp_service.auth import has_dataset_access from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url from superset.mcp_service.chart.chart_utils import ( analyze_chart_capabilities, @@ -38,6 +39,7 @@ from superset.mcp_service.chart.chart_utils import ( generate_explore_link, map_config_to_form_data, ) +from superset.mcp_service.chart.compile import validate_and_compile from superset.mcp_service.chart.schemas import ( AccessibilityMetadata, PerformanceMetadata, @@ -88,7 +90,7 @@ def _get_previous_form_data(form_data_key: str) -> dict[str, Any] | None: destructiveHint=True, ), ) -def update_chart_preview( +def update_chart_preview( # noqa: C901 request: UpdateChartPreviewRequest, ctx: Context ) -> Dict[str, Any]: """Update cached chart preview without saving. @@ -133,6 +135,61 @@ def update_chart_preview( if old_adhoc_filters: new_form_data["adhoc_filters"] = old_adhoc_filters + # Tier-1 schema validation against the dataset (no DB roundtrip). + # Runs AFTER the filter merge so filter columns are also validated. + from superset.daos.dataset import DatasetDAO + + if isinstance(request.dataset_id, int) or ( + isinstance(request.dataset_id, str) and request.dataset_id.isdigit() + ): + dataset = DatasetDAO.find_by_id(int(request.dataset_id)) + else: + dataset = DatasetDAO.find_by_id(request.dataset_id, id_column="uuid") + + if dataset is None or not has_dataset_access(dataset): + return { + "chart": None, + "error": { + "error_type": "DatasetNotAccessible", + "message": ( + f"Dataset not found: {request.dataset_id}. " + "Use list_datasets to find valid dataset IDs." + ), + "details": ( + f"Dataset {request.dataset_id} is missing or inaccessible." 
+ ), + }, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + + compile_result = validate_and_compile( + config, new_form_data, dataset, run_compile_check=False + ) + if not compile_result.success: + logger.warning( + "update_chart_preview validation failed: %s", + compile_result.error, + ) + if compile_result.error_obj is not None: + error_payload = compile_result.error_obj.model_dump() + else: + error_payload = { + "error_type": "validation_error", + "message": "Chart preview validation failed", + "details": compile_result.error or "", + "error_code": compile_result.error_code, + "suggestions": [], + } + return { + "chart": None, + "error": error_payload, + "success": False, + "schema_version": "2.0", + "api_version": "v1", + } + # Generate new explore link with updated form_data explore_url = generate_explore_link(request.dataset_id, new_form_data) diff --git a/superset/mcp_service/chart/validation/dataset_validator.py b/superset/mcp_service/chart/validation/dataset_validator.py index c7f497a426b..5602b7af116 100644 --- a/superset/mcp_service/chart/validation/dataset_validator.py +++ b/superset/mcp_service/chart/validation/dataset_validator.py @@ -25,7 +25,12 @@ import logging from typing import Any, Dict, List, Tuple from superset.mcp_service.chart.schemas import ( + BigNumberChartConfig, ColumnRef, + HandlebarsChartConfig, + MixedTimeseriesChartConfig, + PieChartConfig, + PivotTableChartConfig, TableChartConfig, XYChartConfig, ) @@ -53,7 +58,7 @@ class DatasetValidator: @staticmethod def validate_against_dataset( - config: TableChartConfig | XYChartConfig, + config: Any, dataset_id: int | str, dataset_context: DatasetContext | None = None, ) -> Tuple[bool, ChartGenerationError | None]: @@ -96,13 +101,16 @@ class DatasetValidator: if column_error: return False, column_error - # Validate aggregation compatibility - if isinstance(config, (TableChartConfig, XYChartConfig)): - aggregation_errors = DatasetValidator._validate_aggregations( - column_refs, dataset_context - ) - if aggregation_errors: - return False, aggregation_errors[0] + # Validate aggregation compatibility for every config that produced + # column refs. ``_validate_aggregations`` is config-agnostic — gating + # it to Table/XY would let pie / pivot table / mixed timeseries / + # handlebars / big number slip through ``SUM(non_numeric)`` patterns + # for the fast-path tools that skip Tier 2. + aggregation_errors = DatasetValidator._validate_aggregations( + column_refs, dataset_context + ) + if aggregation_errors: + return False, aggregation_errors[0] return True, None @@ -110,14 +118,41 @@ class DatasetValidator: def _validate_columns_exist( column_refs: List[ColumnRef], dataset_context: DatasetContext ) -> ChartGenerationError | None: - """Validate that non-saved-metric column refs exist in the dataset.""" - invalid_columns = [] + """Validate that non-saved-metric column refs exist in the dataset. + + A ``ColumnRef`` with ``saved_metric=False`` must match an entry in + ``available_columns``. Saved-metric *names* don't satisfy this check — + otherwise ``{name: "sum_boys", aggregate: "SUM"}`` (no + ``saved_metric=true``) would slip through and downstream code would + emit ``SUM(sum_boys)`` as an ad-hoc SIMPLE metric, producing the + broken-SQL pattern this validator is meant to prevent. 
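+
+        The corrective shape is ``{"name": "sum_boys", "saved_metric": true}``;
+        ``_build_saved_metric_hint_error`` below surfaces exactly that hint so
+        the caller can self-correct.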
+ """ + column_names_lower = { + col["name"].lower() for col in dataset_context.available_columns + } + metric_names_lower = { + metric["name"].lower() for metric in dataset_context.available_metrics + } + + invalid_columns: List[ColumnRef] = [] + saved_metric_typo: List[ColumnRef] = [] for col_ref in column_refs: if col_ref.saved_metric: continue - if not DatasetValidator._column_exists(col_ref.name, dataset_context): + name_lower = col_ref.name.lower() + if name_lower in column_names_lower: + continue + if name_lower in metric_names_lower: + # Name matches a saved metric but the ref didn't opt into + # saved-metric resolution. Surface a tailored hint so the + # caller (typically an LLM) can flip ``saved_metric=true``. + saved_metric_typo.append(col_ref) + else: invalid_columns.append(col_ref) + if saved_metric_typo: + return DatasetValidator._build_saved_metric_hint_error(saved_metric_typo) + if not invalid_columns: return None @@ -132,6 +167,36 @@ class DatasetValidator: invalid_columns, suggestions_map, dataset_context ) + @staticmethod + def _build_saved_metric_hint_error( + refs: List[ColumnRef], + ) -> ChartGenerationError: + """Error response when a non-saved-metric ref names a saved metric.""" + names = [r.name for r in refs] + names_str = ", ".join(f"'{n}'" for n in names) + first = names[0] + return ChartGenerationError( + error_type="saved_metric_not_marked", + message=( + f"{names_str} matches a saved metric but the ref doesn't " + f"have saved_metric=true" + ), + details=( + f"The dataset has a saved metric named {names_str}. To use " + f"it, set 'saved_metric': true on the column ref instead of " + f"providing an 'aggregate'. With the current shape, the " + f"chart would emit ad-hoc SQL like SUM({first}) — which is " + f"invalid because {first} is a metric expression, not a " + f"column." + ), + suggestions=[ + f'Did you mean: {{"name": "{first}", "saved_metric": true}}?', + "Use saved_metric=true to reference a saved dataset metric", + "Or pick a real column name and apply an aggregate to it", + ], + error_code="SAVED_METRIC_NOT_MARKED", + ) + @staticmethod def _get_dataset_context(dataset_id: int | str) -> DatasetContext | None: """Get dataset context with column information.""" @@ -195,11 +260,16 @@ class DatasetValidator: return None @staticmethod - def _extract_column_references( - config: TableChartConfig | XYChartConfig, - ) -> List[ColumnRef]: - """Extract all column references from configuration.""" - refs = [] + def _extract_column_references(config: Any) -> List[ColumnRef]: # noqa: C901 + """Extract all column references from a chart configuration. + + Covers every supported ``ChartConfig`` variant so fast-path tools + (``generate_explore_link``, ``update_chart_preview``) that only run + Tier-1 validation still catch bad column refs in pie / pivot table / + mixed timeseries / handlebars / big number charts — not just XY and + table. 
+ """ + refs: List[ColumnRef] = [] if isinstance(config, TableChartConfig): refs.extend(config.columns) @@ -209,10 +279,37 @@ class DatasetValidator: refs.extend(config.y) if config.group_by: refs.extend(config.group_by) + elif isinstance(config, PieChartConfig): + refs.append(config.dimension) + refs.append(config.metric) + elif isinstance(config, PivotTableChartConfig): + refs.extend(config.rows) + if config.columns: + refs.extend(config.columns) + refs.extend(config.metrics) + elif isinstance(config, MixedTimeseriesChartConfig): + refs.append(config.x) + refs.extend(config.y) + if config.group_by: + refs.extend(config.group_by) + refs.extend(config.y_secondary) + if config.group_by_secondary: + refs.extend(config.group_by_secondary) + elif isinstance(config, HandlebarsChartConfig): + if config.columns: + refs.extend(config.columns) + if config.groupby: + refs.extend(config.groupby) + if config.metrics: + refs.extend(config.metrics) + elif isinstance(config, BigNumberChartConfig): + refs.append(config.metric) + if config.temporal_column: + refs.append(ColumnRef(name=config.temporal_column)) - # Add filter columns - if hasattr(config, "filters") and config.filters: - for filter_config in config.filters: + # Filter columns (shared by every config type that defines ``filters``). + if filters := getattr(config, "filters", None): + for filter_config in filters: refs.append(ColumnRef(name=filter_config.column)) return refs @@ -379,20 +476,28 @@ class DatasetValidator: # Find close matches column_lower = column_name.lower() + candidate_lookup = [name[0].lower() for name in all_names] close_matches = difflib.get_close_matches( column_lower, - [name[0].lower() for name in all_names], + candidate_lookup, n=max_suggestions, cutoff=0.6, ) - # Build suggestions with proper case and type info + # Build suggestions with proper case and type info. ``ColumnSuggestion`` + # requires ``similarity_score`` and does not have a ``data_type`` field; + # we score via difflib ratio and store the candidate kind in ``type``. suggestions = [] for match in close_matches: - for name, col_type, data_type in all_names: + for name, col_type, _data_type in all_names: if name.lower() == match: + score = difflib.SequenceMatcher(None, column_lower, match).ratio() suggestions.append( - ColumnSuggestion(name=name, type=col_type, data_type=data_type) + ColumnSuggestion( + name=name, + type=col_type, + similarity_score=round(score, 3), + ) ) break @@ -503,8 +608,12 @@ class DatasetValidator: break if col_info: - # Check numeric aggregates on non-numeric columns - numeric_aggs = ["SUM", "AVG", "MIN", "MAX", "STDDEV", "VAR", "MEDIAN"] + # Check numeric aggregates on non-numeric columns. + # MIN and MAX are intentionally excluded: they work on dates + # and text in most SQL engines, so restricting them here would + # produce false-positive errors. Leave those to the Tier-2 + # compile check. + numeric_aggs = ["SUM", "AVG", "STDDEV", "VAR", "MEDIAN"] if ( col_ref.aggregate in numeric_aggs and not col_info.get("is_numeric", False) diff --git a/superset/mcp_service/explore/tool/generate_explore_link.py b/superset/mcp_service/explore/tool/generate_explore_link.py index 1ff0ea0d6c1..c6af97ecca7 100644 --- a/superset/mcp_service/explore/tool/generate_explore_link.py +++ b/superset/mcp_service/explore/tool/generate_explore_link.py @@ -22,21 +22,26 @@ This tool generates a URL to the Superset explore interface with the specified chart configuration. 
""" +import logging from typing import Any, Dict from fastmcp import Context from superset_core.mcp.decorators import tool, ToolAnnotations from superset.extensions import event_logger +from superset.mcp_service.auth import has_dataset_access from superset.mcp_service.chart.chart_helpers import extract_form_data_key_from_url from superset.mcp_service.chart.chart_utils import ( generate_explore_link as generate_url, map_config_to_form_data, ) +from superset.mcp_service.chart.compile import validate_and_compile from superset.mcp_service.chart.schemas import ( GenerateExploreLinkRequest, ) +logger = logging.getLogger(__name__) + @tool( tags=["explore"], @@ -131,6 +136,24 @@ async def generate_explore_link( ), } + if not has_dataset_access(dataset): + logger.warning( + "User attempted to access dataset %s without permission", + request.dataset_id, + ) + await ctx.warning( + "Dataset access denied: dataset_id=%s" % (request.dataset_id,) + ) + return { + "url": "", + "form_data": {}, + "form_data_key": None, + "error": ( + f"Dataset not found: {request.dataset_id}. " + "Use list_datasets to find valid dataset IDs." + ), + } + await ctx.report_progress(2, 4, "Converting configuration to form data") with event_logger.log_context(action="mcp.generate_explore_link.form_data"): # Normalize column names to match canonical dataset column names @@ -165,6 +188,38 @@ async def generate_explore_link( ) ) + # Tier-1 schema validation against the dataset (no DB roundtrip). + # Catches references to non-existent columns/metrics with fuzzy + # suggestions so the LLM can self-correct ("did you mean sum_boys?"). + with event_logger.log_context(action="mcp.generate_explore_link.validation"): + compile_result = validate_and_compile( + normalized_config, + form_data, + dataset, + run_compile_check=False, + ) + if not compile_result.success: + await ctx.warning( + "Explore link validation failed: error=%s" % (compile_result.error,) + ) + error_payload: Dict[str, Any] + if compile_result.error_obj is not None: + error_payload = compile_result.error_obj.model_dump() + else: + error_payload = { + "error_type": "validation_error", + "message": "Explore link validation failed", + "details": compile_result.error or "", + "error_code": compile_result.error_code, + "suggestions": [], + } + return { + "url": "", + "form_data": form_data, + "form_data_key": None, + "error": error_payload, + } + await ctx.report_progress(3, 4, "Generating explore URL") with event_logger.log_context( action="mcp.generate_explore_link.url_generation" diff --git a/tests/unit_tests/mcp_service/chart/test_compile.py b/tests/unit_tests/mcp_service/chart/test_compile.py new file mode 100644 index 00000000000..4c52063e5de --- /dev/null +++ b/tests/unit_tests/mcp_service/chart/test_compile.py @@ -0,0 +1,445 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Integration-style tests for ``validate_and_compile``. + +These tests exercise the real ``DatasetValidator.validate_against_dataset`` +path so fast-path tools (``generate_explore_link``, ``update_chart_preview``) +that only use Tier-1 validation are exercised end-to-end. +""" + +from unittest.mock import Mock, patch + +import pytest + +from superset.mcp_service.chart.compile import ( + build_dataset_context_from_orm, + CompileResult, + validate_and_compile, +) +from superset.mcp_service.chart.schemas import ( + BigNumberChartConfig, + ColumnRef, + FilterConfig, + PieChartConfig, + PivotTableChartConfig, + TableChartConfig, + XYChartConfig, +) + + +def _orm_dataset( + *, + column_names: list[str] | None = None, + metric_names: list[str] | None = None, + has_database: bool = True, +) -> Mock: + """Build a Mock dataset that satisfies build_dataset_context_from_orm.""" + columns = [] + for name in column_names or ["ds", "gender", "name", "num"]: + col = Mock() + col.column_name = name + col.type = "TEXT" + col.is_temporal = name == "ds" + col.is_numeric = name == "num" + columns.append(col) + + metrics = [] + for name in metric_names or ["sum_boys", "sum_girls"]: + m = Mock() + m.metric_name = name + m.expression = f"SUM({name})" + m.description = None + metrics.append(m) + + dataset = Mock() + dataset.id = 3 + dataset.table_name = "birth_names" + dataset.schema = None + dataset.columns = columns + dataset.metrics = metrics + if has_database: + db = Mock() + db.database_name = "examples" + dataset.database = db + else: + dataset.database = None + return dataset + + +class TestBuildDatasetContextFromOrm: + """Cover the helper that converts ORM dataset → DatasetContext.""" + + def test_handles_missing_database_relationship(self): + """``database_name`` defaults to '' when ``dataset.database`` is None + so Pydantic validation doesn't blow up.""" + ds = _orm_dataset(has_database=False) + ctx = build_dataset_context_from_orm(ds) + assert ctx is not None + assert ctx.database_name == "" + assert ctx.id == 3 + assert {c["name"] for c in ctx.available_columns} == { + "ds", + "gender", + "name", + "num", + } + assert {m["name"] for m in ctx.available_metrics} == { + "sum_boys", + "sum_girls", + } + + def test_returns_none_for_none_input(self): + assert build_dataset_context_from_orm(None) is None + + +class TestValidateAndCompileChartTypeCoverage: + """Tier-1 validation must catch bad column refs in every supported + chart-config variant — not just XY and table.""" + + def test_xy_bad_metric_column_rejected(self): + ds = _orm_dataset() + config = XYChartConfig( + chart_type="xy", + x=ColumnRef(name="ds"), + y=[ColumnRef(name="num_boys", aggregate="SUM")], + kind="line", + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success + assert result.tier == "validation" + assert result.error_obj is not None + assert any("sum_boys" in s for s in (result.error_obj.suggestions or [])) + + def test_pie_bad_metric_column_rejected(self): + ds = _orm_dataset() + config = PieChartConfig( + dimension=ColumnRef(name="gender"), + metric=ColumnRef(name="num_boys", aggregate="SUM"), + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success, "Pie chart with bad metric column should fail" + assert result.tier == "validation" + assert result.error_obj is not None + assert any("sum_boys" in s for s in 
(result.error_obj.suggestions or [])) + + def test_pie_valid_dimension_and_saved_metric_passes(self): + ds = _orm_dataset() + config = PieChartConfig( + dimension=ColumnRef(name="gender"), + metric=ColumnRef(name="sum_boys", saved_metric=True), + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert result.success, result.error + + def test_pivot_table_bad_row_rejected(self): + ds = _orm_dataset() + config = PivotTableChartConfig( + rows=[ColumnRef(name="bogus_dim")], + metrics=[ColumnRef(name="sum_boys", saved_metric=True)], + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success + assert result.error_obj is not None + + def test_big_number_bad_temporal_column_rejected(self): + ds = _orm_dataset() + config = BigNumberChartConfig( + chart_type="big_number", + metric=ColumnRef(name="sum_boys", saved_metric=True), + temporal_column="not_a_real_temporal", + show_trendline=True, + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success, "BigNumber temporal_column must be validated" + assert result.error_obj is not None + assert "not_a_real_temporal" in (result.error_obj.message or "") + + def test_pie_with_sum_on_non_numeric_column_rejected(self): + """Tier-1 aggregation compatibility now runs for non-Table/XY too — + a pie ``metric={"name": "gender", "aggregate": "SUM"}`` would emit + ``SUM(gender)`` which the DB rejects, so the validator must catch it + before we hand back an explore URL.""" + ds = _orm_dataset() + config = PieChartConfig( + dimension=ColumnRef(name="name"), + metric=ColumnRef(name="gender", aggregate="SUM"), + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success, "SUM on a TEXT column must reject" + assert result.error_obj is not None + assert result.error_obj.error_code == "INVALID_AGGREGATION" + + def test_pivot_table_sum_on_non_numeric_column_rejected(self): + ds = _orm_dataset() + config = PivotTableChartConfig( + rows=[ColumnRef(name="gender")], + metrics=[ColumnRef(name="name", aggregate="SUM")], + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success + assert result.error_obj is not None + assert result.error_obj.error_code == "INVALID_AGGREGATION" + + def test_pivot_table_min_on_non_numeric_column_passes(self): + """MIN and MAX are not numeric-only (valid on dates/text in SQL). + + They are left to the Tier-2 compile check rather than being rejected + by Tier-1 schema validation. + """ + ds = _orm_dataset() + config = PivotTableChartConfig( + rows=[ColumnRef(name="gender")], + metrics=[ColumnRef(name="name", aggregate="MIN")], + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert result.success, ( + "MIN on a text column should not be rejected by Tier-1 validation" + ) + + def test_table_with_invalid_filter_column_rejected(self): + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", + columns=[ColumnRef(name="gender")], + filters=[FilterConfig(column="bogus", op="=", value="x")], + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success + assert result.error_obj is not None + + +class TestSavedMetricNotMarked: + """A non-saved-metric ColumnRef whose name matches a saved metric is a + common LLM mistake (forgetting to set ``saved_metric=true``). 
The + validator should surface a tailored hint instead of letting the bad SQL + through.""" + + def test_table_metric_name_without_saved_metric_flag_rejected(self): + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="gender"), + # ``sum_boys`` is a saved metric on the dataset, but + # saved_metric=False (default) would render as + # ``SUM(sum_boys)`` ad-hoc SQL — broken. + ColumnRef(name="sum_boys", aggregate="SUM"), + ], + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success, ( + "ref.name matches a saved metric but saved_metric=False -> reject" + ) + assert result.error_obj is not None + assert result.error_obj.error_code == "SAVED_METRIC_NOT_MARKED" + # Suggestion should point the LLM at the right correction. + suggestions_text = " ".join(result.error_obj.suggestions or []) + assert "saved_metric" in suggestions_text + assert "sum_boys" in suggestions_text + + def test_pie_metric_name_without_saved_metric_flag_rejected(self): + ds = _orm_dataset() + config = PieChartConfig( + dimension=ColumnRef(name="gender"), + metric=ColumnRef(name="sum_boys", aggregate="SUM"), + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert not result.success + assert result.error_obj is not None + assert result.error_obj.error_code == "SAVED_METRIC_NOT_MARKED" + + def test_explicit_saved_metric_passes(self): + ds = _orm_dataset() + config = PieChartConfig( + dimension=ColumnRef(name="gender"), + metric=ColumnRef(name="sum_boys", saved_metric=True), + ) + result = validate_and_compile(config, {}, ds, run_compile_check=False) + assert result.success, result.error + + +class TestAdhocFiltersFromFormData: + """Filters merged into form_data (not present on the typed config) must + also be validated. 
Without this hook, ``update_chart_preview`` could + smuggle bad column refs through preserved adhoc filters.""" + + def test_unknown_adhoc_filter_subject_rejected(self): + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + form_data = { + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "subject": "removed_column", + "operator": "==", + "comparator": "x", + } + ] + } + result = validate_and_compile(config, form_data, ds, run_compile_check=False) + assert not result.success + assert result.error_obj is not None + assert "removed_column" in (result.error_obj.message or "") + + def test_known_adhoc_filter_subject_passes(self): + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + form_data = { + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "subject": "gender", + "operator": "==", + "comparator": "boy", + } + ] + } + result = validate_and_compile(config, form_data, ds, run_compile_check=False) + assert result.success, result.error + + def test_sql_expression_filter_skipped(self): + """SQL-expression filters carry a free-form ``sqlExpression`` we can't + safely parse, so they should pass Tier-1 untouched.""" + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + form_data = { + "adhoc_filters": [ + { + "expressionType": "SQL", + "clause": "WHERE", + "sqlExpression": "1 = 1", + } + ] + } + result = validate_and_compile(config, form_data, ds, run_compile_check=False) + assert result.success + + def test_where_filter_with_metric_name_rejected(self): + """A saved-metric name used as a WHERE filter subject must be rejected. + + WHERE filters need a physical column; metric names are only valid in + HAVING clauses where Superset can resolve them. + """ + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + form_data = { + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "clause": "WHERE", + "subject": "sum_boys", # saved metric, not a physical column + "operator": ">", + "comparator": "0", + } + ] + } + result = validate_and_compile(config, form_data, ds, run_compile_check=False) + assert not result.success, ( + "A saved-metric name used in a WHERE filter must not pass Tier-1" + ) + assert result.error_obj is not None + assert "sum_boys" in (result.error_obj.message or "") + + def test_having_filter_with_metric_name_passes(self): + """A saved-metric name used in a HAVING filter must be accepted. + + HAVING filters are aggregate-level conditions; Superset resolves metric + names there so they are valid references. 
+ """ + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + form_data = { + "adhoc_filters": [ + { + "expressionType": "SIMPLE", + "clause": "HAVING", + "subject": "sum_boys", # saved metric — valid in HAVING + "operator": ">", + "comparator": "0", + } + ] + } + result = validate_and_compile(config, form_data, ds, run_compile_check=False) + assert result.success, ( + "A saved-metric name in a HAVING filter should pass Tier-1 validation" + ) + + +class TestValidateAndCompileTier2: + """When ``run_compile_check=True`` and Tier-1 passes, the helper must + invoke ``_compile_chart`` and surface its outcome.""" + + @patch("superset.mcp_service.chart.compile._compile_chart") + def test_tier2_runs_when_tier1_passes(self, mock_compile): + mock_compile.return_value = CompileResult(success=True) + ds = _orm_dataset() + config = TableChartConfig( + chart_type="table", columns=[ColumnRef(name="gender")] + ) + result = validate_and_compile( + config, {"adhoc_filters": []}, ds, run_compile_check=True + ) + assert result.success + mock_compile.assert_called_once() + + @patch("superset.mcp_service.chart.compile._compile_chart") + def test_tier2_skipped_on_tier1_failure(self, mock_compile): + ds = _orm_dataset() + config = TableChartConfig(chart_type="table", columns=[ColumnRef(name="bogus")]) + result = validate_and_compile(config, {}, ds, run_compile_check=True) + assert not result.success + assert result.tier == "validation" + mock_compile.assert_not_called() + + def test_dataset_none_returns_dataset_not_found(self): + result = validate_and_compile(None, {}, None, run_compile_check=True) + assert not result.success + assert result.error_code == "DATASET_NOT_FOUND" + + +@pytest.mark.parametrize( + "config_factory", + [ + lambda: PieChartConfig( + dimension=ColumnRef(name="gender"), + metric=ColumnRef(name="sum_boys", saved_metric=True), + ), + lambda: TableChartConfig( + chart_type="table", + columns=[ + ColumnRef(name="gender"), + ColumnRef(name="sum_boys", saved_metric=True), + ], + ), + ], +) +def test_valid_configs_pass_tier1(config_factory): + ds = _orm_dataset() + result = validate_and_compile(config_factory(), {}, ds, run_compile_check=False) + assert result.success, result.error diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py index f7ded95c9b1..c504d8bca59 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart.py @@ -745,6 +745,9 @@ class TestUpdateChartNameOnly: class TestUpdateChartPreviewFirst: """Integration-style tests for the preview-first default flow.""" + @patch.object( + update_chart_module, "_validate_update_against_dataset", return_value=None + ) @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock) @patch( "superset.commands.chart.update.UpdateChartCommand", @@ -764,6 +767,7 @@ class TestUpdateChartPreviewFirst: mock_check_access, mock_update_cmd_cls, mock_create_preview, + unused_validate_mock, mcp_server, ): """Default update flow returns a preview URL and does NOT save.""" @@ -934,6 +938,9 @@ class TestBuildPreviewFormData: class TestUpdateChartSaveWithConfig: """Save-path integration tests for update_chart with a full config payload.""" + @patch.object( + update_chart_module, "_validate_update_against_dataset", return_value=None + ) @patch( "superset.commands.chart.update.UpdateChartCommand", new_callable=Mock, @@ -951,6 +958,7 @@ 
class TestUpdateChartSaveWithConfig: mock_find_by_id, mock_check_access, mock_update_cmd_cls, + unused_validate_mock, mcp_server, ): """generate_preview=False with config persists and returns saved chart.""" @@ -1086,6 +1094,9 @@ class TestUpdateChartErrorPaths: assert error["error_type"] == "CommandException" assert "boom" in error["details"] + @patch.object( + update_chart_module, "_validate_update_against_dataset", return_value=None + ) @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock) @patch( "superset.mcp_service.auth.check_chart_data_access", @@ -1100,6 +1111,7 @@ class TestUpdateChartErrorPaths: mock_find_by_id, mock_check_access, mock_create_preview, + unused_validate_mock, mcp_server, ): """If _create_preview_url returns (url, None), form_data_key comes from url.""" @@ -1137,3 +1149,140 @@ class TestUpdateChartErrorPaths: assert result.structured_content["success"] is True assert result.structured_content["form_data_key"] == "url_embedded_key" + + +class TestUpdateChartValidationGate: + """Tier-1+2 validation prevents bad config from reaching DB or cache.""" + + @staticmethod + def _mock_chart_with_dataset(chart_id: int = 1) -> Mock: + chart = Mock() + chart.id = chart_id + chart.datasource_id = 10 + chart.slice_name = "Existing" + chart.viz_type = "table" + chart.uuid = "abc-123" + chart.params = '{"viz_type": "table", "datasource": "10__table"}' + # validate_and_compile is mocked, so dataset shape doesn't matter. + chart.datasource = Mock() + return chart + + @patch.object(update_chart_module, "validate_and_compile") + @patch.object(update_chart_module, "_create_preview_url", new_callable=Mock) + @patch( + "superset.mcp_service.auth.check_chart_data_access", + new_callable=Mock, + ) + @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock) + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_preview_path_validation_failure_skips_cache( + self, + mock_db_session, + mock_find_by_id, + mock_check_access, + mock_create_preview, + mock_validate, + mcp_server, + ): + """Preview path: bad column → structured error, _create_preview_url + must NOT be called.""" + from superset.mcp_service.chart.compile import CompileResult + from superset.mcp_service.common.error_schemas import ChartGenerationError + + mock_find_by_id.return_value = self._mock_chart_with_dataset() + mock_check_access.return_value = DatasetValidationResult( + is_valid=True, dataset_id=10, dataset_name="ds", warnings=[] + ) + mock_validate.return_value = CompileResult( + success=False, + error="Column 'num_boys' does not exist", + error_code="CHART_VALIDATION_FAILED", + tier="validation", + error_obj=ChartGenerationError( + error_type="invalid_column", + message="Column 'num_boys' does not exist", + details="Available: ds, gender, sum_boys", + suggestions=["sum_boys"], + error_code="CHART_VALIDATION_FAILED", + ), + ) + + request = { + "identifier": 1, + "config": { + "chart_type": "xy", + "x": {"name": "ds"}, + "y": [{"name": "num_boys", "aggregate": "SUM"}], + "kind": "line", + }, + } + + async with Client(mcp) as client: + result = await client.call_tool("update_chart", {"request": request}) + + assert result.structured_content["success"] is False + error = result.structured_content["error"] + assert error["error_code"] == "CHART_VALIDATION_FAILED" + assert "sum_boys" in error["suggestions"] + mock_create_preview.assert_not_called() + + @patch.object(update_chart_module, "validate_and_compile") + @patch( + "superset.commands.chart.update.UpdateChartCommand", + 
new_callable=Mock, + ) + @patch( + "superset.mcp_service.auth.check_chart_data_access", + new_callable=Mock, + ) + @patch("superset.daos.chart.ChartDAO.find_by_id", new_callable=Mock) + @patch("superset.db.session") + @pytest.mark.asyncio + async def test_persist_path_validation_failure_skips_db_write( + self, + mock_db_session, + mock_find_by_id, + mock_check_access, + mock_update_cmd_cls, + mock_validate, + mcp_server, + ): + """Persist path: validation failure → UpdateChartCommand NOT called.""" + from superset.mcp_service.chart.compile import CompileResult + from superset.mcp_service.common.error_schemas import ChartGenerationError + + mock_find_by_id.return_value = self._mock_chart_with_dataset(chart_id=42) + mock_check_access.return_value = DatasetValidationResult( + is_valid=True, dataset_id=10, dataset_name="ds", warnings=[] + ) + mock_validate.return_value = CompileResult( + success=False, + error="Column 'bad_col' does not exist", + error_code="CHART_VALIDATION_FAILED", + tier="validation", + error_obj=ChartGenerationError( + error_type="invalid_column", + message="Column 'bad_col' does not exist", + details="Available: a, b, c", + suggestions=["a"], + error_code="CHART_VALIDATION_FAILED", + ), + ) + + request = { + "identifier": 42, + "generate_preview": False, + "config": { + "chart_type": "table", + "columns": [{"name": "bad_col"}], + }, + } + + async with Client(mcp) as client: + result = await client.call_tool("update_chart", {"request": request}) + + assert result.structured_content["success"] is False + error = result.structured_content["error"] + assert error["error_code"] == "CHART_VALIDATION_FAILED" + mock_update_cmd_cls.assert_not_called() diff --git a/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py b/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py index 79f3c90c5e5..2aa63efdf1c 100644 --- a/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py +++ b/tests/unit_tests/mcp_service/chart/tool/test_update_chart_preview.py @@ -23,7 +23,9 @@ import importlib from unittest.mock import Mock, patch import pytest +from fastmcp import Client +from superset.mcp_service.app import mcp from superset.mcp_service.chart.schemas import ( AxisConfig, ColumnRef, @@ -34,11 +36,48 @@ from superset.mcp_service.chart.schemas import ( XYChartConfig, ) +# The package ``__init__.py`` re-exports the ``update_chart_preview`` tool +# function under the same dotted path as the module, so mock.patch's string +# lookup of ``...update_chart_preview.`` can resolve to the function on +# some Python versions. Hold a direct module reference for ``patch.object``. 
 update_chart_preview_module = importlib.import_module(
     "superset.mcp_service.chart.tool.update_chart_preview"
 )
 
 
+@pytest.fixture
+def mcp_server():
+    return mcp
+
+
+@pytest.fixture
+def mock_auth():
+    """Mock authentication for tool-invocation tests."""
+    with patch("superset.mcp_service.auth.get_user_from_request") as mock_get_user:
+        user = Mock()
+        user.id = 1
+        user.username = "admin"
+        mock_get_user.return_value = user
+        yield mock_get_user
+
+
+def _mock_dataset(id: int = 1) -> Mock:
+    """Mock SqlaTable with the attributes the tool reads."""
+    column = Mock()
+    column.column_name = "ds"
+    column.type = "TIMESTAMP"
+    database = Mock()
+    database.database_name = "main"
+    dataset = Mock()
+    dataset.id = id
+    dataset.table_name = "birth_names"
+    dataset.schema = None
+    dataset.columns = [column]
+    dataset.metrics = []
+    dataset.database = database
+    return dataset
+
+
 class TestUpdateChartPreview:
     """Tests for update_chart_preview MCP tool."""
@@ -528,6 +567,9 @@ class TestUpdateChartPreview:
 
         assert result is None
 
+    @patch.object(update_chart_preview_module, "validate_and_compile")
+    @patch.object(update_chart_preview_module, "has_dataset_access", return_value=True)
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
     @patch.object(update_chart_preview_module, "analyze_chart_semantics")
     @patch.object(update_chart_preview_module, "analyze_chart_capabilities")
     @patch.object(update_chart_preview_module, "generate_explore_link")
@@ -541,11 +583,16 @@ class TestUpdateChartPreview:
         mock_generate_explore_link,
         mock_analyze_chart_capabilities,
         mock_analyze_chart_semantics,
+        mock_find_by_id,
+        unused_access_mock,
+        mock_validate_and_compile,
     ) -> None:
         """Invalid previous form_data_key is warning-only for preview updates."""
         mock_user = Mock()
         mock_user.id = 1
         mock_get_user_from_request.return_value = mock_user
+        mock_find_by_id.return_value = _mock_dataset(id=3)
+        mock_validate_and_compile.return_value = Mock(success=True)
         mock_get_previous_form_data.return_value = None
         mock_generate_explore_link.return_value = (
             "http://localhost:8088/explore/?form_data_key=new_preview_key"
@@ -581,6 +628,9 @@ class TestUpdateChartPreview:
         ]
         mock_get_previous_form_data.assert_called_once_with("nonexistent_key_12345")
 
+    @patch.object(update_chart_preview_module, "validate_and_compile")
+    @patch.object(update_chart_preview_module, "has_dataset_access", return_value=True)
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
     @patch.object(update_chart_preview_module, "analyze_chart_semantics")
     @patch.object(update_chart_preview_module, "analyze_chart_capabilities")
     @patch.object(update_chart_preview_module, "generate_explore_link")
@@ -594,11 +644,16 @@ class TestUpdateChartPreview:
         mock_generate_explore_link,
         mock_analyze_chart_capabilities,
         mock_analyze_chart_semantics,
+        mock_find_by_id,
+        unused_access_mock,
+        mock_validate_and_compile,
     ) -> None:
         """Valid previous form_data preserves filters without a cache warning."""
         mock_user = Mock()
         mock_user.id = 1
         mock_get_user_from_request.return_value = mock_user
+        mock_find_by_id.return_value = _mock_dataset(id=3)
+        mock_validate_and_compile.return_value = Mock(success=True)
         cached_adhoc_filters = [
             {
                 "clause": "WHERE",
@@ -642,3 +697,101 @@ class TestUpdateChartPreview:
         assert result["error"] is None
         assert result["warnings"] == []
         mock_get_previous_form_data.assert_called_once_with("valid_key_12345")
+
+
+class TestUpdateChartPreviewValidation:
+    """Tier-1 validation gate and dataset access checks."""
+
+    @patch.object(update_chart_preview_module,
+        "has_dataset_access", return_value=True)
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch.object(update_chart_preview_module, "validate_and_compile")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_validation_failure_skips_cache_write(
+        self,
+        mock_create_form_data,
+        mock_validate,
+        mock_find_dataset,
+        unused_access_mock,
+        mcp_server,
+        mock_auth,
+    ):
+        """Bad column ref → structured error with suggestions, no cache write."""
+        from superset.mcp_service.chart.compile import CompileResult
+        from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+        mock_find_dataset.return_value = _mock_dataset(id=3)
+        mock_validate.return_value = CompileResult(
+            success=False,
+            error="Column 'num_boys' does not exist in dataset",
+            error_code="CHART_VALIDATION_FAILED",
+            tier="validation",
+            error_obj=ChartGenerationError(
+                error_type="invalid_column",
+                message="Column 'num_boys' does not exist in dataset",
+                details="Available columns: ds, gender, name, num, sum_boys",
+                suggestions=["sum_boys"],
+                error_code="CHART_VALIDATION_FAILED",
+            ),
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="num_boys", aggregate="SUM")],
+            kind="line",
+        )
+        request = UpdateChartPreviewRequest(
+            form_data_key="prev_key", dataset_id="3", config=config
+        )
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "update_chart_preview", {"request": request.model_dump()}
+            )
+
+        assert result.data["success"] is False
+        assert result.data["chart"] is None
+        error = result.data["error"]
+        assert isinstance(error, dict)
+        assert error["error_code"] == "CHART_VALIDATION_FAILED"
+        assert "sum_boys" in error["suggestions"]
+        mock_create_form_data.assert_not_called()
+
+    @patch.object(update_chart_preview_module, "has_dataset_access", return_value=False)
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_dataset_access_denied_short_circuits(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        unused_access_mock,
+        mcp_server,
+        mock_auth,
+    ):
+        """has_dataset_access=False → DatasetNotAccessible, no cache write."""
+        mock_find_dataset.return_value = _mock_dataset(id=3)
+
+        config = TableChartConfig(
+            chart_type="table", columns=[ColumnRef(name="region")]
+        )
+        request = UpdateChartPreviewRequest(
+            form_data_key="prev_key", dataset_id="3", config=config
+        )
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "update_chart_preview", {"request": request.model_dump()}
+            )
+
+        assert result.data["success"] is False
+        assert result.data["chart"] is None
+        error = result.data["error"]
+        assert isinstance(error, dict)
+        assert error["error_type"] == "DatasetNotAccessible"
+        mock_create_form_data.assert_not_called()
diff --git a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
index fb8aee539b0..b9e7c3e4b07 100644
--- a/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
+++ b/tests/unit_tests/mcp_service/explore/tool/test_generate_explore_link.py
@@ -19,6 +19,7 @@
 Comprehensive unit tests for MCP generate_explore_link tool
 """
 
+import importlib
 import logging
 from unittest.mock import Mock, patch
 
@@ -37,6 +38,14 @@ from superset.mcp_service.chart.schemas import (
     )
 from superset.mcp_service.common.error_schemas import DatasetContext
 
+# The package ``__init__.py`` re-exports the ``generate_explore_link`` tool
+# function under the same dotted path as the module, so ``mock.patch``'s
+# string lookup of ``...generate_explore_link`` can resolve to the re-exported
+# function rather than the module on some Python versions. Hold a direct
+# module reference for ``patch.object`` instead.
+generate_explore_link_module = importlib.import_module(
+    "superset.mcp_service.explore.tool.generate_explore_link"
+)
+
 logging.basicConfig(level=logging.DEBUG)
 logger = logging.getLogger(__name__)
 
@@ -57,6 +66,29 @@ def mock_auth():
     yield mock_get_user
 
 
+@pytest.fixture(autouse=True)
+def mock_dataset_access_granted():
+    """Grant dataset access by default; tests that need a denial override this."""
+    with patch.object(
+        generate_explore_link_module, "has_dataset_access", return_value=True
+    ):
+        yield
+
+
+@pytest.fixture(autouse=True)
+def mock_validation_passes():
+    """Skip Tier-1 dataset validation by default so Mock datasets don't trip the
+    real validator. Individual tests that exercise validation override this."""
+    from superset.mcp_service.chart.compile import CompileResult
+
+    with patch.object(
+        generate_explore_link_module,
+        "validate_and_compile",
+        return_value=CompileResult(success=True),
+    ):
+        yield
+
+
 @pytest.fixture(autouse=True)
 def mock_webdriver_baseurl(app_context):
     """Mock WEBDRIVER_BASEURL_USER_FRIENDLY for consistent test URLs."""
@@ -922,3 +954,102 @@ class TestGenerateExploreLinkColumnNormalization:
         assert result.data["error"] is None
         # original names should pass through unchanged
         assert result.data["form_data"]["x_axis"] == "orderdate"
+
+
+class TestGenerateExploreLinkValidation:
+    """Tier-1 validation gate (DatasetValidator) and dataset access checks."""
+
+    @pytest.fixture(autouse=True)
+    def mock_validation_passes(self):
+        """Override the module-level autouse patch so each test in this class
+        can stub ``validate_and_compile`` itself. The fixture name MUST match
+        the module-level fixture for pytest's override-by-name to take effect.
+        """
+        return
+
+    @patch.object(generate_explore_link_module, "validate_and_compile")
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_validation_failure_returns_structured_error(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        mock_validate,
+        mcp_server,
+    ):
+        """Non-existent column → structured ChartGenerationError with suggestions,
+        and MCPCreateFormDataCommand must NOT be called (no cache write)."""
+        from superset.mcp_service.chart.compile import CompileResult
+        from superset.mcp_service.common.error_schemas import ChartGenerationError
+
+        mock_find_dataset.return_value = _mock_dataset(id=3)
+        mock_validate.return_value = CompileResult(
+            success=False,
+            error="Column 'num_boys' does not exist in dataset",
+            error_code="CHART_VALIDATION_FAILED",
+            tier="validation",
+            error_obj=ChartGenerationError(
+                error_type="invalid_column",
+                message="Column 'num_boys' does not exist in dataset",
+                details="Available columns: ds, gender, name, num, sum_boys",
+                suggestions=["sum_boys"],
+                error_code="CHART_VALIDATION_FAILED",
+            ),
+        )
+
+        config = XYChartConfig(
+            chart_type="xy",
+            x=ColumnRef(name="ds"),
+            y=[ColumnRef(name="num_boys", aggregate="SUM")],
+            kind="line",
+        )
+        request = GenerateExploreLinkRequest(dataset_id="3", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+        assert result.data["url"] == ""
+        assert result.data["form_data_key"] is None
+        error = result.data["error"]
+        assert isinstance(error, dict)
+        assert error["error_code"] == "CHART_VALIDATION_FAILED"
+        assert "sum_boys" in error["suggestions"]
+        mock_create_form_data.assert_not_called()
+
+    @patch.object(
+        generate_explore_link_module, "has_dataset_access", return_value=False
+    )
+    @patch("superset.daos.dataset.DatasetDAO.find_by_id")
+    @patch(
+        "superset.mcp_service.commands.create_form_data.MCPCreateFormDataCommand.run"
+    )
+    @pytest.mark.asyncio
+    async def test_dataset_access_denied_short_circuits(
+        self,
+        mock_create_form_data,
+        mock_find_dataset,
+        unused_access_mock,
+        mcp_server,
+    ):
+        """has_dataset_access=False blocks the tool before any cache write."""
+        mock_find_dataset.return_value = _mock_dataset(id=3)
+
+        config = TableChartConfig(
+            chart_type="table", columns=[ColumnRef(name="region")]
+        )
+        request = GenerateExploreLinkRequest(dataset_id="3", config=config)
+
+        async with Client(mcp_server) as client:
+            result = await client.call_tool(
+                "generate_explore_link", {"request": request.model_dump()}
+            )
+
+        assert result.data["url"] == ""
+        # Surface as "not found" rather than leaking that the dataset exists.
+        assert "Dataset not found" in result.data["error"]
+        mock_create_form_data.assert_not_called()
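
Every validation-gate class above asserts the same ordering contract: ``validate_and_compile`` runs before any cache or DB write (``MCPCreateFormDataCommand.run``, ``_create_preview_url``, ``UpdateChartCommand``), and a failure short-circuits with the structured ``error_obj``. A minimal sketch of that contract, using the real ``validate_and_compile``/``CompileResult`` API but a hypothetical ``persist`` callback in place of the tool-specific write, and assuming ``run_compile_check=False`` is the Tier-2 opt-out:

```python
from typing import Any, Callable, Dict

from superset.mcp_service.chart.compile import CompileResult, validate_and_compile


def run_with_validation_gate(
    config: Any,
    form_data: Dict[str, Any],
    dataset: Any,
    persist: Callable[[Dict[str, Any]], Any],
) -> Dict[str, Any]:
    """Hypothetical helper illustrating the order the tests assert on."""
    # Tier 1 (schema validation) always runs; Tier 2 (compile check) is
    # skipped here, as the tight-SLA preview/explore-link tools do.
    result: CompileResult = validate_and_compile(
        config, form_data, dataset, run_compile_check=False
    )
    if not result.success:
        # Failure path: return the structured error and never call
        # ``persist`` -- exactly what ``assert_not_called()`` verifies.
        return {"success": False, "chart": None, "error": result.error_obj}
    # Only a config that passed validation reaches the cache/DB write.
    return {"success": True, "chart": persist(form_data), "error": None}
```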