diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 99bf203f9a4..d2c57203416 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -402,7 +402,10 @@ class MySQLEngineSpec(BasicParametersMixin, BaseEngineSpec): if not cls.type_code_map: # only import and store if needed at least once # pylint: disable=import-outside-toplevel - import MySQLdb + try: + import MySQLdb + except ImportError: + import pymysql as MySQLdb # type: ignore[import-untyped] # noqa: N812 ft = MySQLdb.constants.FIELD_TYPE cls.type_code_map = { diff --git a/tests/unit_tests/db_engine_specs/test_mysql.py b/tests/unit_tests/db_engine_specs/test_mysql.py index 649af5c7c15..8c117b1cca1 100644 --- a/tests/unit_tests/db_engine_specs/test_mysql.py +++ b/tests/unit_tests/db_engine_specs/test_mysql.py @@ -15,8 +15,10 @@ # specific language governing permissions and limitations # under the License. +import builtins from datetime import datetime from decimal import Decimal +from types import ModuleType from typing import Any, Optional from unittest.mock import Mock, patch @@ -262,3 +264,42 @@ def test_column_type_mutator( mock_cursor.description = description assert spec.fetch_data(mock_cursor) == expected_result + + +def test_get_datatype_pymysql_fallback(): + """get_datatype() falls back to pymysql when MySQLdb is not installed.""" + from superset.db_engine_specs.mysql import MySQLEngineSpec + + # Reset cached type_code_map so the import path is exercised + original_type_code_map = MySQLEngineSpec.type_code_map + MySQLEngineSpec.type_code_map = {} + + try: + # Build a fake pymysql module with constants.FIELD_TYPE + fake_field_type = ModuleType("pymysql.constants.FIELD_TYPE") + fake_field_type.TINY = 1 + fake_field_type.VARCHAR = 15 + + fake_constants = ModuleType("pymysql.constants") + fake_constants.FIELD_TYPE = fake_field_type + + fake_pymysql = ModuleType("pymysql") + fake_pymysql.constants = fake_constants + + original_import = builtins.__import__ + + def mock_import(name: str, *args: Any, **kwargs: Any) -> Any: + if name == "MySQLdb": + raise ImportError("No module named 'MySQLdb'") + if name == "pymysql": + return fake_pymysql + return original_import(name, *args, **kwargs) + + with patch("builtins.__import__", side_effect=mock_import): + assert MySQLEngineSpec.get_datatype(1) == "TINY" + assert MySQLEngineSpec.get_datatype(15) == "VARCHAR" + assert MySQLEngineSpec.get_datatype("BIGINT") == "BIGINT" + assert MySQLEngineSpec.get_datatype(999) is None + finally: + # Restore original state + MySQLEngineSpec.type_code_map = original_type_code_map