diff --git a/superset/mcp_service/chart/compile.py b/superset/mcp_service/chart/compile.py index 6809e7534a5..62abf3e5f15 100644 --- a/superset/mcp_service/chart/compile.py +++ b/superset/mcp_service/chart/compile.py @@ -36,7 +36,10 @@ import logging from dataclasses import dataclass, field from typing import Any, Dict, List, Literal +from sqlalchemy.exc import SQLAlchemyError + from superset.commands.exceptions import CommandException +from superset.errors import SupersetErrorType from superset.mcp_service.chart.validation.dataset_validator import DatasetValidator from superset.mcp_service.common.error_schemas import ( ChartGenerationError, @@ -46,6 +49,31 @@ from superset.mcp_service.common.error_schemas import ( logger = logging.getLogger(__name__) +# Error types from db_engine_spec.extract_errors() that indicate a database +# connectivity or authentication issue rather than a query/config problem. +# +# GENERIC_DB_ENGINE_ERROR is included because many engines (BigQuery, +# Snowflake, Athena, Databricks, Trino) lack specific CONNECTION_* regex +# patterns in their engine specs — all their connection failures fall back +# to this generic type. This is safe here because _compile_chart only runs +# after Tier 1 schema validation has already verified columns, metrics, and +# filters against the dataset. At that point the SQL is auto-generated by +# Superset's query builder, so genuine SQL/config errors are very unlikely. +_CONNECTION_ERROR_TYPES = { + SupersetErrorType.CONNECTION_INVALID_USERNAME_ERROR, + SupersetErrorType.CONNECTION_INVALID_PASSWORD_ERROR, + SupersetErrorType.CONNECTION_INVALID_HOSTNAME_ERROR, + SupersetErrorType.CONNECTION_PORT_CLOSED_ERROR, + SupersetErrorType.CONNECTION_INVALID_PORT_ERROR, + SupersetErrorType.CONNECTION_HOST_DOWN_ERROR, + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + SupersetErrorType.CONNECTION_UNKNOWN_DATABASE_ERROR, + SupersetErrorType.CONNECTION_DATABASE_PERMISSIONS_ERROR, + SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + SupersetErrorType.CONNECTION_DATABASE_TIMEOUT, + SupersetErrorType.GENERIC_DB_ENGINE_ERROR, +} + @dataclass class CompileResult: @@ -186,6 +214,19 @@ def _compile_chart( return CompileResult(success=True, warnings=warnings, row_count=row_count) except (ChartDataQueryFailedError, ChartDataCacheLoadError) as exc: + if _classify_as_database_error(exc, dataset_id): + logger.warning( + "Database connection error during chart compile check: %s: %s", + type(exc).__name__, + str(exc), + ) + return CompileResult( + success=False, + error=f"Database connection error: {exc}", + error_code="CHART_COMPILE_FAILED", + tier="compile", + error_obj=_build_database_error(str(exc)), + ) return CompileResult( success=False, error=str(exc), @@ -201,6 +242,19 @@ def _compile_chart( tier="compile", error_obj=_build_compile_error(str(exc)), ) + except SQLAlchemyError as exc: + logger.warning( + "Database connection error during chart compile check: %s: %s", + type(exc).__name__, + str(exc), + ) + return CompileResult( + success=False, + error=f"Database connection error: {exc}", + error_code="CHART_COMPILE_FAILED", + tier="compile", + error_obj=_build_database_error(str(exc)), + ) def _adhoc_filter_column_valid( @@ -278,6 +332,53 @@ def _validate_adhoc_filter_columns( ) +def _classify_as_database_error(exc: BaseException, dataset_id: int) -> bool: + """Use the dataset's DB engine spec to classify the error. + + Walks the ``__cause__`` chain for direct ``SQLAlchemyError`` instances, + then falls back to the engine spec's ``extract_errors`` regex patterns — + the same classification the Superset UI uses. + """ + # Direct SQLAlchemy errors (unwrapped or in cause chain) + current: BaseException | None = exc + while current is not None: + if isinstance(current, SQLAlchemyError): + return True + current = current.__cause__ + + # Use the dataset's engine spec to classify (same as the UI) + try: + from superset.daos.dataset import DatasetDAO + + dataset = DatasetDAO.find_by_id(dataset_id) + if dataset and dataset.database and isinstance(exc, Exception): + errors = dataset.database.db_engine_spec.extract_errors(exc) + return any(e.error_type in _CONNECTION_ERROR_TYPES for e in errors) + except Exception: # pylint: disable=broad-except + logger.debug( + "Failed to classify error via engine spec for dataset %s: %s", + dataset_id, + exc, + ) + + return False + + +def _build_database_error(message: str) -> ChartGenerationError: + """Wrap a database connection failure in the structured response envelope.""" + return ChartGenerationError( + error_type="database_connection_error", + message="Unable to connect to the database.", + details=message or "", + suggestions=[ + "Check that the database is online and reachable", + "Verify database credentials and connection settings", + "Contact your administrator if the issue persists", + ], + error_code="DATABASE_CONNECTION_ERROR", + ) + + def _build_compile_error(message: str) -> ChartGenerationError: """Wrap a raw compile-failure string in the structured response envelope.""" return ChartGenerationError( diff --git a/superset/mcp_service/chart/tool/generate_chart.py b/superset/mcp_service/chart/tool/generate_chart.py index 45f5d30a17a..caa6db7e41c 100644 --- a/superset/mcp_service/chart/tool/generate_chart.py +++ b/superset/mcp_service/chart/tool/generate_chart.py @@ -416,7 +416,7 @@ async def generate_chart( # noqa: C901 ) execution_time = int((time.time() - start_time) * 1000) - error = ChartGenerationError( + error = compile_result.error_obj or ChartGenerationError( error_type="compile_error", message=( "Chart query failed to execute. The chart was not saved." @@ -546,7 +546,7 @@ async def generate_chart( # noqa: C901 ) execution_time = int((time.time() - start_time) * 1000) - error = ChartGenerationError( + error = compile_result.error_obj or ChartGenerationError( error_type="compile_error", message=( "Chart query failed to execute. " diff --git a/tests/unit_tests/mcp_service/chart/test_compile.py b/tests/unit_tests/mcp_service/chart/test_compile.py index 4c52063e5de..e65e6b801d4 100644 --- a/tests/unit_tests/mcp_service/chart/test_compile.py +++ b/tests/unit_tests/mcp_service/chart/test_compile.py @@ -423,6 +423,93 @@ class TestValidateAndCompileTier2: assert result.error_code == "DATASET_NOT_FOUND" +@patch("superset.daos.dataset.DatasetDAO") +@patch("superset.commands.chart.data.get_data_command.ChartDataCommand") +@patch("superset.common.query_context_factory.QueryContextFactory") +def test_compile_chart_returns_database_error_when_wrapped_in_query_failed( + mock_factory, mock_cmd_cls, mock_dataset_dao +): + """ChartDataCommand converts OperationalError to a string inside + ChartDataQueryFailedError (no __cause__ set). _classify_as_database_error + should use db_engine_spec.extract_errors() to detect the DB error.""" + from superset.commands.chart.exceptions import ChartDataQueryFailedError + from superset.errors import ErrorLevel, SupersetError, SupersetErrorType + from superset.mcp_service.chart.compile import _compile_chart + + mock_factory.return_value.create.return_value = Mock() + mock_cmd_cls.return_value.validate.return_value = None + + # Real scenario: __cause__ is NOT set, error is just a string + wrapped = ChartDataQueryFailedError( + "Error: (psycopg2.OperationalError) connection to server at '10.0.0.1'," + " port 5432 failed: FATAL: tenant not found" + ) + mock_cmd_cls.return_value.run.side_effect = wrapped + + # Mock the dataset's db_engine_spec to return GENERIC_DB_ENGINE_ERROR + mock_db = Mock() + mock_db.db_engine_spec.extract_errors.return_value = [ + SupersetError( + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + message="connection to server failed", + level=ErrorLevel.ERROR, + extra={"engine_name": "PostgreSQL"}, + ) + ] + mock_dataset = Mock() + mock_dataset.database = mock_db + mock_dataset_dao.find_by_id.return_value = mock_dataset + + result = _compile_chart( + form_data={ + "metrics": [{"label": "count", "expressionType": "SIMPLE"}], + "adhoc_filters": [], + }, + dataset_id=1, + ) + + assert not result.success + assert "Database connection error" in result.error + assert result.error_obj is not None + assert result.error_obj.error_type == "database_connection_error" + assert result.error_obj.error_code == "DATABASE_CONNECTION_ERROR" + mock_db.db_engine_spec.extract_errors.assert_called_once() + + +@patch("superset.commands.chart.data.get_data_command.ChartDataCommand") +@patch("superset.common.query_context_factory.QueryContextFactory") +def test_compile_chart_returns_database_error_on_raw_sqlalchemy_error( + mock_factory, mock_cmd_cls +): + """When SQLAlchemyError escapes unwrapped, _compile_chart should + catch it and return a database_connection_error.""" + from sqlalchemy.exc import OperationalError + + from superset.mcp_service.chart.compile import _compile_chart + + mock_factory.return_value.create.return_value = Mock() + mock_cmd_cls.return_value.validate.return_value = None + mock_cmd_cls.return_value.run.side_effect = OperationalError( + "connection to server at '10.0.0.1', port 5432 failed: Connection timed out", + None, + None, + ) + + result = _compile_chart( + form_data={ + "metrics": [{"label": "count", "expressionType": "SIMPLE"}], + "adhoc_filters": [], + }, + dataset_id=1, + ) + + assert not result.success + assert "Database connection error" in result.error + assert result.error_obj is not None + assert result.error_obj.error_type == "database_connection_error" + assert result.error_obj.error_code == "DATABASE_CONNECTION_ERROR" + + @pytest.mark.parametrize( "config_factory", [