diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 136ca55712c..331961c7c5e 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -156,11 +156,12 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # default matching patterns for identifying column types db_column_types: Dict[utils.DbColumnType, Tuple[Pattern[Any], ...]] = { utils.DbColumnType.NUMERIC: ( + re.compile(r"BIT", re.IGNORECASE), re.compile(r".*DOUBLE.*", re.IGNORECASE), re.compile(r".*FLOAT.*", re.IGNORECASE), re.compile(r".*INT.*", re.IGNORECASE), re.compile(r".*NUMBER.*", re.IGNORECASE), - re.compile(r".*LONG.*", re.IGNORECASE), + re.compile(r".*LONG$", re.IGNORECASE), re.compile(r".*REAL.*", re.IGNORECASE), re.compile(r".*NUMERIC.*", re.IGNORECASE), re.compile(r".*DECIMAL.*", re.IGNORECASE), diff --git a/tests/db_engine_specs/mysql_tests.py b/tests/db_engine_specs/mysql_tests.py index f284f8801ad..b3ebdfc6ca3 100644 --- a/tests/db_engine_specs/mysql_tests.py +++ b/tests/db_engine_specs/mysql_tests.py @@ -20,6 +20,7 @@ from sqlalchemy.dialects import mysql from sqlalchemy.dialects.mysql import DATE, NVARCHAR, TEXT, VARCHAR from superset.db_engine_specs.mysql import MySQLEngineSpec +from superset.utils.core import DbColumnType from tests.db_engine_specs.base_tests import TestDbEngineSpec @@ -62,3 +63,41 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec): original, mysql.dialect() ) self.assertEqual(actual, expected) + + def test_is_db_column_type_match(self): + type_expectations = ( + # Numeric + ("TINYINT", DbColumnType.NUMERIC), + ("SMALLINT", DbColumnType.NUMERIC), + ("MEDIUMINT", DbColumnType.NUMERIC), + ("INT", DbColumnType.NUMERIC), + ("BIGINT", DbColumnType.NUMERIC), + ("DECIMAL", DbColumnType.NUMERIC), + ("FLOAT", DbColumnType.NUMERIC), + ("DOUBLE", DbColumnType.NUMERIC), + ("BIT", DbColumnType.NUMERIC), + # String + ("CHAR", DbColumnType.STRING), + ("VARCHAR", DbColumnType.STRING), + ("TINYTEXT", DbColumnType.STRING), + ("MEDIUMTEXT", DbColumnType.STRING), + ("LONGTEXT", DbColumnType.STRING), + # Temporal + ("DATE", DbColumnType.TEMPORAL), + ("DATETIME", DbColumnType.TEMPORAL), + ("TIMESTAMP", DbColumnType.TEMPORAL), + ("TIME", DbColumnType.TEMPORAL), + ) + + for type_expectation in type_expectations: + type_str = type_expectation[0] + col_type = type_expectation[1] + assert MySQLEngineSpec.is_db_column_type_match( + type_str, DbColumnType.NUMERIC + ) is (col_type == DbColumnType.NUMERIC) + assert MySQLEngineSpec.is_db_column_type_match( + type_str, DbColumnType.STRING + ) is (col_type == DbColumnType.STRING) + assert MySQLEngineSpec.is_db_column_type_match( + type_str, DbColumnType.TEMPORAL + ) is (col_type == DbColumnType.TEMPORAL)