chore(db_engine_specs): clean up column spec logic and add tests (#22871)

This commit is contained in:
Ville Brofeldt
2023-01-31 15:54:07 +02:00
committed by GitHub
parent 8466eec228
commit cd6fc35f60
73 changed files with 1953 additions and 1463 deletions

View File

@@ -20,6 +20,7 @@ from typing import Any, Dict, Optional, Pattern, Tuple
from urllib import parse
from flask_babel import gettext as __
from sqlalchemy import types
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
@@ -34,15 +35,10 @@ from sqlalchemy.dialects.mysql import (
)
from sqlalchemy.engine.url import URL
from superset.db_engine_specs.base import (
BaseEngineSpec,
BasicParametersMixin,
ColumnTypeMapping,
)
from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin
from superset.errors import SupersetErrorType
from superset.models.sql_lab import Query
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
from superset.utils.core import GenericDataType
# Regular expressions to catch custom errors
CONNECTION_ACCESS_DENIED_REGEX = re.compile(
@@ -182,10 +178,11 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
def convert_dttm(
cls, target_type: str, dttm: datetime, db_extra: Optional[Dict[str, Any]] = None
) -> Optional[str]:
tt = target_type.upper()
if tt == utils.TemporalType.DATE:
sqla_type = cls.get_sqla_column_type(target_type)
if isinstance(sqla_type, types.Date):
return f"STR_TO_DATE('{dttm.date().isoformat()}', '%Y-%m-%d')"
if tt == utils.TemporalType.DATETIME:
if isinstance(sqla_type, types.DateTime):
datetime_formatted = dttm.isoformat(sep=" ", timespec="microseconds")
return f"""STR_TO_DATE('{datetime_formatted}', '%Y-%m-%d %H:%i:%s.%f')"""
return None
@@ -232,23 +229,6 @@ class MySQLEngineSpec(BaseEngineSpec, BasicParametersMixin):
pass
return message
@classmethod
def get_column_spec(
cls,
native_type: Optional[str],
db_extra: Optional[Dict[str, Any]] = None,
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
column_type_mappings: Tuple[ColumnTypeMapping, ...] = column_type_mappings,
) -> Optional[ColumnSpec]:
column_spec = super().get_column_spec(native_type)
if column_spec:
return column_spec
return super().get_column_spec(
native_type, column_type_mappings=column_type_mappings
)
@classmethod
def get_cancel_query_id(cls, cursor: Any, query: Query) -> Optional[str]:
"""