diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index a249408cd05..425962186c7 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -1295,7 +1295,7 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods # using datetimes if generic_type == GenericDataType.TEMPORAL: column_type = literal_dttm_type_factory( - type(column_type), cls, native_type or "" + column_type, cls, native_type or "" ) is_dttm = generic_type == GenericDataType.TEMPORAL return ColumnSpec( diff --git a/superset/models/sql_types/base.py b/superset/models/sql_types/base.py index 0d9eda43aa1..669181b3a48 100644 --- a/superset/models/sql_types/base.py +++ b/superset/models/sql_types/base.py @@ -26,10 +26,8 @@ if TYPE_CHECKING: def literal_dttm_type_factory( - sqla_type: Type[types.TypeEngine], - db_engine_spec: Type["BaseEngineSpec"], - col_type: str, -) -> Type[types.TypeEngine]: + sqla_type: types.TypeEngine, db_engine_spec: Type["BaseEngineSpec"], col_type: str, +) -> types.TypeEngine: """ Create a custom SQLAlchemy type that supports datetime literal binds. @@ -39,7 +37,7 @@ def literal_dttm_type_factory( :return: SQLAlchemy type that supports using datetima as literal bind """ # pylint: disable=too-few-public-methods - class TemporalWrapperType(sqla_type): # type: ignore + class TemporalWrapperType(type(sqla_type)): # type: ignore # pylint: disable=unused-argument def literal_processor(self, dialect: Dialect) -> Callable[[Any], Any]: def process(value: Any) -> Any: @@ -58,4 +56,4 @@ def literal_dttm_type_factory( return process - return TemporalWrapperType + return TemporalWrapperType() diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 3b48f678867..3b7596b1615 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -561,11 +561,13 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec): self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC) column_spec = PrestoEngineSpec.get_column_spec("time") - assert issubclass(column_spec.sqla_type, types.Time) + assert isinstance(column_spec.sqla_type, types.Time) + assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType" self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) column_spec = PrestoEngineSpec.get_column_spec("timestamp") - assert issubclass(column_spec.sqla_type, types.TIMESTAMP) + assert isinstance(column_spec.sqla_type, types.TIMESTAMP) + assert type(column_spec.sqla_type).__name__ == "TemporalWrapperType" self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL) sqla_type = PrestoEngineSpec.get_sqla_column_type(None) diff --git a/tests/integration_tests/model_tests.py b/tests/integration_tests/model_tests.py index deb78460b68..567fdfe719b 100644 --- a/tests/integration_tests/model_tests.py +++ b/tests/integration_tests/model_tests.py @@ -24,11 +24,14 @@ from tests.integration_tests.fixtures.birth_names_dashboard import ( import pytest from sqlalchemy.engine.url import make_url +from sqlalchemy.types import DateTime import tests.integration_tests.test_app from superset import app, db as metadata_db +from superset.db_engine_specs.postgres import PostgresEngineSpec from superset.models.core import Database from superset.models.slice import Slice +from superset.models.sql_types.base import literal_dttm_type_factory from superset.utils.core import get_example_database, QueryStatus from .base_tests import SupersetTestCase @@ -516,3 +519,10 @@ class TestSqlaTableModel(SupersetTestCase): assert set(data_for_slices["verbose_map"].keys()) == set( ["__timestamp", "sum__num", "gender",] ) + + +def test_literal_dttm_type_factory(): + orig_type = DateTime() + new_type = literal_dttm_type_factory(orig_type, PostgresEngineSpec, "TIMESTAMP") + assert type(new_type).__name__ == "TemporalWrapperType" + assert str(new_type) == str(orig_type)