feat(explore): Postgres datatype conversion (#13294)

* test

* unnecessary import

* fix lint

* changes

* fix lint

* changes

* changes

* changes

* changes

* answering comments & changes

* answering comments

* answering comments

* changes

* changes

* changes

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests

* fix tests
This commit is contained in:
Nikola Gigić
2021-03-12 09:36:43 +01:00
committed by GitHub
parent 1a46f93057
commit 609c3594ef
12 changed files with 465 additions and 147 deletions

View File

@@ -69,6 +69,7 @@ from superset.result_set import SupersetResultSet
from superset.sql_parse import ParsedQuery
from superset.typing import Metric, QueryObjectDict
from superset.utils import core as utils
from superset.utils.core import GenericDataType
config = app.config
metadata = Model.metadata # pylint: disable=no-member
@@ -186,20 +187,20 @@ class TableColumn(Model, BaseColumn):
"""
Check if the column has a numeric datatype.
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.NUMERIC
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.generic_type == GenericDataType.NUMERIC
@property
def is_string(self) -> bool:
"""
Check if the column has a string datatype.
"""
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.STRING
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.generic_type == GenericDataType.STRING
@property
def is_temporal(self) -> bool:
@@ -211,10 +212,10 @@ class TableColumn(Model, BaseColumn):
"""
if self.is_dttm is not None:
return self.is_dttm
db_engine_spec = self.table.database.db_engine_spec
return db_engine_spec.is_db_column_type_match(
self.type, utils.GenericDataType.TEMPORAL
)
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
if column_spec is None:
return False
return column_spec.is_dttm
def get_sqla_col(self, label: Optional[str] = None) -> Column:
label = label or self.column_name
@@ -222,7 +223,8 @@ class TableColumn(Model, BaseColumn):
col = literal_column(self.expression)
else:
db_engine_spec = self.table.database.db_engine_spec
type_ = db_engine_spec.get_sqla_column_type(self.type)
column_spec = db_engine_spec.get_column_spec(self.type)
type_ = column_spec.sqla_type if column_spec else None
col = column(self.column_name, type_=type_)
col = self.table.make_sqla_column_compatible(col, label)
return col

View File

@@ -41,7 +41,7 @@ import pandas as pd
import sqlparse
from flask import g
from flask_babel import gettext as __, lazy_gettext as _
from sqlalchemy import column, DateTime, select
from sqlalchemy import column, DateTime, select, types
from sqlalchemy.engine.base import Engine
from sqlalchemy.engine.interfaces import Compiled, Dialect
from sqlalchemy.engine.reflection import Inspector
@@ -50,13 +50,14 @@ from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import Session
from sqlalchemy.sql import quoted_name, text
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
from sqlalchemy.types import TypeEngine
from sqlalchemy.types import String, TypeEngine, UnicodeText
from superset import app, security_manager, sql_parse
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.models.sql_lab import Query
from superset.sql_parse import ParsedQuery, Table
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
# prevent circular imports
@@ -145,8 +146,87 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
_date_trunc_functions: Dict[str, str] = {}
_time_grain_expressions: Dict[Optional[str], str] = {}
column_type_mappings: Tuple[
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
] = ()
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = (
(
re.compile(r"^smallint", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^integer", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^numeric", re.IGNORECASE),
types.Numeric(),
GenericDataType.NUMERIC,
),
(re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
(
re.compile(r"^smallserial", re.IGNORECASE),
types.SmallInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^serial", re.IGNORECASE),
types.Integer(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^bigserial", re.IGNORECASE),
types.BigInteger(),
GenericDataType.NUMERIC,
),
(
re.compile(r"^string", re.IGNORECASE),
types.String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
UnicodeText(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
String(),
utils.GenericDataType.STRING,
),
(re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
(
re.compile(r"^timestamp", re.IGNORECASE),
types.TIMESTAMP(),
GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval", re.IGNORECASE),
types.Interval(),
GenericDataType.TEMPORAL,
),
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
(
re.compile(r"^boolean", re.IGNORECASE),
types.Boolean(),
GenericDataType.BOOLEAN,
),
)
time_groupby_inline = False
limit_method = LimitMethod.FORCE_LIMIT
time_secondary_columns = False
@@ -160,25 +240,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
run_multiple_statements_as_one = False
# default matching patterns to convert database specific column types to
# more generic types
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = {
utils.GenericDataType.NUMERIC: (
re.compile(r"BIT", re.IGNORECASE),
re.compile(
r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*",
re.IGNORECASE,
),
re.compile(r".*LONG$", re.IGNORECASE),
),
utils.GenericDataType.STRING: (
re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),
),
utils.GenericDataType.TEMPORAL: (
re.compile(r".*(DATE|TIME).*", re.IGNORECASE),
),
}
@classmethod
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
"""
@@ -208,25 +269,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return exception
return new_exception(str(exception))
@classmethod
def is_db_column_type_match(
cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
) -> bool:
"""
Check if a column type satisfies a pattern in a collection of regexes found in
`db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
:param db_column_type: Column type to evaluate
:param target_column_type: The target type to evaluate for
:return: `True` if a `db_column_type` matches any pattern corresponding to
`target_column_type`
"""
if not db_column_type:
return False
patterns = cls.db_column_types[target_column_type]
return any(pattern.match(db_column_type) for pattern in patterns)
@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return False
@@ -967,24 +1009,35 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
return label_mutated
@classmethod
def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]:
def get_sqla_column_type(
cls,
column_type: Optional[str],
column_type_mappings: Tuple[
Tuple[
Pattern[str],
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
GenericDataType,
],
...,
] = column_type_mappings,
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
"""
Return a sqlalchemy native column type that corresponds to the column type
defined in the data source (return None to use default type inferred by
SQLAlchemy). Override `column_type_mappings` for specific needs
(see MSSQL for example of NCHAR/NVARCHAR handling).
:param type_: Column type returned by inspector
:param column_type: Column type returned by inspector
:return: SqlAlchemy column type
"""
if not type_:
if not column_type:
return None
for regex, sqla_type in cls.column_type_mappings:
match = regex.match(type_)
for regex, sqla_type, generic_type in column_type_mappings:
match = regex.match(column_type)
if match:
if callable(sqla_type):
return sqla_type(match)
return sqla_type
return sqla_type(match), generic_type
return sqla_type, generic_type
return None
@staticmethod
@@ -1101,3 +1154,43 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
or parsed_query.is_explain()
or parsed_query.is_show()
)
@classmethod
def get_column_spec(
    cls,
    native_type: Optional[str],
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
    column_type_mappings: Tuple[
        Tuple[
            Pattern[str],
            Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
            GenericDataType,
        ],
        ...,
    ] = column_type_mappings,
) -> Union[ColumnSpec, None]:
    """
    Convert a native database type into a `ColumnSpec`.

    :param native_type: Native database type
    :param source: Type coming from the database table or cursor description
        (accepted for interface compatibility; this implementation does not
        read it)
    :param column_type_mappings: `(regex, sqla type, generic type)` entries
        handed to `get_sqla_column_type` for the lookup
    :return: ColumnSpec object, or None when no mapping matches
    """
    # Resolve the type once; the lookup returns None when nothing matches.
    resolved = cls.get_sqla_column_type(
        native_type, column_type_mappings=column_type_mappings
    )
    if resolved is None:
        return None

    column_type, generic_type = resolved
    if not column_type:
        return None

    return ColumnSpec(
        sqla_type=column_type,
        generic_type=generic_type,
        # a column is flagged as datetime exactly when its generic type
        # is TEMPORAL
        is_dttm=generic_type == GenericDataType.TEMPORAL,
    )

View File

@@ -15,18 +15,12 @@
# specific language governing permissions and limitations
# under the License.
import logging
import re
from datetime import datetime
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
from sqlalchemy.types import String, UnicodeText
from typing import Any, List, Optional, Tuple
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
from superset.utils import core as utils
if TYPE_CHECKING:
from superset.models.core import Database
logger = logging.getLogger(__name__)
@@ -77,11 +71,6 @@ class MssqlEngineSpec(BaseEngineSpec):
# Lists of `pyodbc.Row` need to be unpacked further
return cls.pyodbc_rows_to_tuples(data)
column_type_mappings = (
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
)
@classmethod
def extract_error_message(cls, ex: Exception) -> str:
if str(ex).startswith("(8155,"):

View File

@@ -14,14 +14,29 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import re
from datetime import datetime
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union
from urllib import parse
from sqlalchemy.dialects.mysql import (
BIT,
DECIMAL,
DOUBLE,
FLOAT,
INTEGER,
LONGTEXT,
MEDIUMINT,
MEDIUMTEXT,
TINYINT,
TINYTEXT,
)
from sqlalchemy.engine.url import URL
from sqlalchemy.types import TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
class MySQLEngineSpec(BaseEngineSpec):
@@ -29,6 +44,34 @@ class MySQLEngineSpec(BaseEngineSpec):
engine_name = "MySQL"
max_column_name_length = 64
# MySQL native type patterns. Every pattern is compiled case-insensitively and
# paired with the SQLAlchemy MySQL dialect type and the generic type it maps to.
column_type_mappings: Tuple[
    Tuple[
        Pattern[str],
        Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
        GenericDataType,
    ],
    ...,
] = tuple(
    (re.compile(type_pattern, re.IGNORECASE), sqla_type, generic_type)
    for type_pattern, sqla_type, generic_type in (
        # numeric types
        (r"^int.*", INTEGER(), GenericDataType.NUMERIC),
        (r"^tinyint", TINYINT(), GenericDataType.NUMERIC),
        (r"^mediumint", MEDIUMINT(), GenericDataType.NUMERIC),
        (r"^decimal", DECIMAL(), GenericDataType.NUMERIC),
        (r"^float", FLOAT(), GenericDataType.NUMERIC),
        (r"^double", DOUBLE(), GenericDataType.NUMERIC),
        (r"^bit", BIT(), GenericDataType.NUMERIC),
        # text types
        (r"^tinytext", TINYTEXT(), GenericDataType.STRING),
        (r"^mediumtext", MEDIUMTEXT(), GenericDataType.STRING),
        (r"^longtext", LONGTEXT(), GenericDataType.STRING),
    )
)
_time_grain_expressions = {
None: "{col}",
"PT1S": "DATE_ADD(DATE({col}), "
@@ -98,3 +141,26 @@ class MySQLEngineSpec(BaseEngineSpec):
except (AttributeError, KeyError):
pass
return message
@classmethod
def get_column_spec(  # type: ignore
    cls,
    native_type: Optional[str],
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
    column_type_mappings: Tuple[
        Tuple[
            Pattern[str],
            Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
            GenericDataType,
        ],
        ...,
    ] = column_type_mappings,
) -> Union[ColumnSpec, None]:
    """
    Resolve a native MySQL type into a `ColumnSpec`.

    The parent implementation is asked first with its own default mappings;
    only when that yields nothing is it asked again with
    `column_type_mappings`.

    :param native_type: Native database type
    :param source: Type coming from the database table or cursor description;
        forwarded unchanged to the parent implementation
    :param column_type_mappings: Mappings used for the second lookup
    :return: ColumnSpec object, or None when neither lookup matches
    """
    # Forward `source` so the caller's choice is not silently replaced by the
    # parent's default.
    column_spec = super().get_column_spec(native_type, source=source)
    if column_spec:
        return column_spec

    return super().get_column_spec(
        native_type, source=source, column_type_mappings=column_type_mappings
    )

View File

@@ -18,14 +18,28 @@ import json
import logging
import re
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
from typing import (
Any,
Callable,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from pytz import _FixedOffset # type: ignore
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
from sqlalchemy.dialects.postgresql.base import PGInspector
from sqlalchemy.types import String, TypeEngine
from superset.db_engine_specs.base import BaseEngineSpec
from superset.exceptions import SupersetException
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
from superset.models.core import Database # pragma: no cover
@@ -77,6 +91,21 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
max_column_name_length = 63
try_remove_schema_from_table_name = False
# Postgres native type patterns, each paired with a SQLAlchemy type and the
# generic type it is reported as.
column_type_mappings = (
    (
        re.compile(r"^double precision", re.IGNORECASE),
        DOUBLE_PRECISION(),
        GenericDataType.NUMERIC,
    ),
    (
        # The pattern has no capture groups, so a callable indexing
        # `match[2]` would raise IndexError on every array type. Map arrays
        # to a plain string type, consistent with the STRING generic type.
        re.compile(r"^array.*", re.IGNORECASE),
        String(),
        GenericDataType.STRING,
    ),
    (re.compile(r"^json.*", re.IGNORECASE), JSON(), GenericDataType.STRING),
    (re.compile(r"^enum.*", re.IGNORECASE), ENUM(), GenericDataType.STRING),
)
@classmethod
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
return True
@@ -144,3 +173,26 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
engine_params["connect_args"] = connect_args
extra["engine_params"] = engine_params
return extra
@classmethod
def get_column_spec(  # type: ignore
    cls,
    native_type: Optional[str],
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
    column_type_mappings: Tuple[
        Tuple[
            Pattern[str],
            Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
            GenericDataType,
        ],
        ...,
    ] = column_type_mappings,
) -> Union[ColumnSpec, None]:
    """
    Resolve a native Postgres type into a `ColumnSpec`.

    The parent implementation is asked first with its own default mappings;
    only when that yields nothing is it asked again with
    `column_type_mappings`.

    :param native_type: Native database type
    :param source: Type coming from the database table or cursor description;
        forwarded unchanged to the parent implementation
    :param column_type_mappings: Mappings used for the second lookup
    :return: ColumnSpec object, or None when neither lookup matches
    """
    # Forward `source` so the caller's choice is not silently replaced by the
    # parent's default.
    column_spec = super().get_column_spec(native_type, source=source)
    if column_spec:
        return column_spec

    return super().get_column_spec(
        native_type, source=source, column_type_mappings=column_type_mappings
    )

View File

@@ -23,7 +23,19 @@ from collections import defaultdict, deque
from contextlib import closing
from datetime import datetime
from distutils.version import StrictVersion
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
from typing import (
Any,
Callable,
cast,
Dict,
List,
Match,
Optional,
Pattern,
Tuple,
TYPE_CHECKING,
Union,
)
from urllib import parse
import pandas as pd
@@ -36,6 +48,7 @@ from sqlalchemy.engine.result import RowProxy
from sqlalchemy.engine.url import make_url, URL
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import ColumnClause, Select
from sqlalchemy.types import TypeEngine
from superset import app, cache_manager, is_feature_enabled
from superset.db_engine_specs.base import BaseEngineSpec
@@ -52,6 +65,7 @@ from superset.models.sql_types.presto_sql_types import (
from superset.result_set import destringify
from superset.sql_parse import ParsedQuery
from superset.utils import core as utils
from superset.utils.core import ColumnSpec, GenericDataType
if TYPE_CHECKING:
# prevent circular imports
@@ -293,7 +307,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
field_info = cls._split_data_type(single_field, r"\s")
# check if there is a structural data type within
# overall structural data type
column_type = cls.get_sqla_column_type(field_info[1])
column_spec = cls.get_column_spec(field_info[1])
column_type = column_spec.sqla_type if column_spec else None
if column_type is None:
column_type = types.String()
logger.info(
@@ -356,31 +371,89 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
return columns
column_type_mappings = (
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
(re.compile(r"^real.*", re.IGNORECASE), types.Float()),
(re.compile(r"^double.*", re.IGNORECASE), types.Float()),
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
(
re.compile(r"^boolean.*", re.IGNORECASE),
types.BOOLEAN,
utils.GenericDataType.BOOLEAN,
),
(
re.compile(r"^tinyint.*", re.IGNORECASE),
TinyInteger(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^smallint.*", re.IGNORECASE),
types.SMALLINT(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^integer.*", re.IGNORECASE),
types.INTEGER(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^bigint.*", re.IGNORECASE),
types.BIGINT(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^real.*", re.IGNORECASE),
types.FLOAT(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^double.*", re.IGNORECASE),
types.FLOAT(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^decimal.*", re.IGNORECASE),
types.DECIMAL(),
utils.GenericDataType.NUMERIC,
),
(
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
utils.GenericDataType.STRING,
),
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
(re.compile(r"^time.*", re.IGNORECASE), types.Time()),
(re.compile(r"^interval.*", re.IGNORECASE), Interval()),
(re.compile(r"^array.*", re.IGNORECASE), Array()),
(re.compile(r"^map.*", re.IGNORECASE), Map()),
(re.compile(r"^row.*", re.IGNORECASE), Row()),
(
re.compile(r"^varbinary.*", re.IGNORECASE),
types.VARBINARY(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^json.*", re.IGNORECASE),
types.JSON(),
utils.GenericDataType.STRING,
),
(
re.compile(r"^date.*", re.IGNORECASE),
types.DATE(),
utils.GenericDataType.TEMPORAL,
),
(
re.compile(r"^timestamp.*", re.IGNORECASE),
types.TIMESTAMP(),
utils.GenericDataType.TEMPORAL,
),
(
re.compile(r"^interval.*", re.IGNORECASE),
Interval(),
utils.GenericDataType.TEMPORAL,
),
(
re.compile(r"^time.*", re.IGNORECASE),
types.Time(),
utils.GenericDataType.TEMPORAL,
),
(re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING),
(re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING),
(re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING),
)
@classmethod
@@ -412,7 +485,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
continue
# otherwise column is a basic data type
column_type = cls.get_sqla_column_type(column.Type)
column_spec = cls.get_column_spec(column.Type)
column_type = column_spec.sqla_type if column_spec else None
if column_type is None:
column_type = types.String()
logger.info(
@@ -1111,3 +1185,27 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
return super().is_readonly_query(parsed_query) or parsed_query.is_show()
@classmethod
def get_column_spec(  # type: ignore
    cls,
    native_type: Optional[str],
    source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
    column_type_mappings: Tuple[
        Tuple[
            Pattern[str],
            Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
            GenericDataType,
        ],
        ...,
    ] = column_type_mappings,
) -> Union[ColumnSpec, None]:
    """
    Resolve a native Presto type into a `ColumnSpec`.

    The parent implementation is asked first with `column_type_mappings`;
    only when that yields nothing is it asked again with its own default
    mappings.

    :param native_type: Native database type
    :param source: Type coming from the database table or cursor description;
        forwarded unchanged to the parent implementation
    :param column_type_mappings: Mappings used for the first lookup
    :return: ColumnSpec object, or None when neither lookup matches
    """
    # Forward `source` so the caller's choice is not silently replaced by the
    # parent's default.
    column_spec = super().get_column_spec(
        native_type, source=source, column_type_mappings=column_type_mappings
    )
    if column_spec:
        return column_spec

    return super().get_column_spec(native_type, source=source)

View File

@@ -181,9 +181,10 @@ class SupersetResultSet:
return next((i for i in items if i), None)
def is_temporal(self, db_type_str: Optional[str]) -> bool:
return self.db_engine_spec.is_db_column_type_match(
db_type_str, utils.GenericDataType.TEMPORAL
)
column_spec = self.db_engine_spec.get_column_spec(db_type_str)
if column_spec is None:
return False
return column_spec.is_dttm
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
"""Given a pyarrow data type, Returns a generic database type"""

View File

@@ -82,7 +82,7 @@ from sqlalchemy.dialects.mysql import MEDIUMTEXT
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.engine.reflection import Inspector
from sqlalchemy.sql.type_api import Variant
from sqlalchemy.types import TEXT, TypeDecorator
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
from typing_extensions import TypedDict
import _thread # pylint: disable=C0411
@@ -148,6 +148,10 @@ class GenericDataType(IntEnum):
STRING = 1
TEMPORAL = 2
BOOLEAN = 3
# ARRAY = 4 # Mapping all the complex data types to STRING for now
# JSON = 5 # and leaving these as a reminder.
# MAP = 6
# ROW = 7
class ChartDataResultFormat(str, Enum):
@@ -306,6 +310,18 @@ class TemporalType(str, Enum):
TIMESTAMP = "TIMESTAMP"
class ColumnTypeSource(Enum):
    """Origin of a native column type: the database table or a cursor description."""

    GET_TABLE = 1
    # Misspelled name kept as the canonical member so existing references
    # keep working.
    CURSOR_DESCRIPION = 2
    # Correctly spelled alias: same value, so it resolves to the member above.
    CURSOR_DESCRIPTION = 2
class ColumnSpec(NamedTuple):
    """Description of a column type: SQLAlchemy type, generic type and temporal flag."""

    # SQLAlchemy type object, or a string standing in for one
    sqla_type: Union[TypeEngine, str]
    # generic category (a GenericDataType member) the column belongs to
    generic_type: GenericDataType
    # whether the column is treated as a datetime column
    is_dttm: bool
    # NOTE(review): presumably a Python date format string for parsing values
    # of this column; optional, None by default -- confirm against callers
    python_date_format: Optional[str] = None
try:
# Having might not have been imported.
class DimSelector(Having):

View File

@@ -24,32 +24,37 @@ from sqlalchemy.types import String, UnicodeText
from superset.db_engine_specs.base import BaseEngineSpec
from superset.db_engine_specs.mssql import MssqlEngineSpec
from superset.utils.core import GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
class TestMssqlEngineSpec(TestDbEngineSpec):
def test_mssql_column_types(self):
def assert_type(type_string, type_expected):
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
def assert_type(type_string, type_expected, generic_type_expected):
if type_expected is None:
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
self.assertIsNone(type_assigned)
else:
self.assertIsInstance(type_assigned, type_expected)
column_spec = MssqlEngineSpec.get_column_spec(type_string)
if column_spec != None:
self.assertIsInstance(column_spec.sqla_type, type_expected)
self.assertEquals(column_spec.generic_type, generic_type_expected)
assert_type("INT", None)
assert_type("STRING", String)
assert_type("CHAR(10)", String)
assert_type("VARCHAR(10)", String)
assert_type("TEXT", String)
assert_type("NCHAR(10)", UnicodeText)
assert_type("NVARCHAR(10)", UnicodeText)
assert_type("NTEXT", UnicodeText)
assert_type("STRING", String, GenericDataType.STRING)
assert_type("CHAR(10)", String, GenericDataType.STRING)
assert_type("VARCHAR(10)", String, GenericDataType.STRING)
assert_type("TEXT", String, GenericDataType.STRING)
assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING)
assert_type("NTEXT", UnicodeText, GenericDataType.STRING)
def test_where_clause_n_prefix(self):
dialect = mssql.dialect()
spec = MssqlEngineSpec
str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)"))
unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT"))
type_, _ = spec.get_sqla_column_type("VARCHAR(10)")
str_col = column("col", type_=type_)
type_, _ = spec.get_sqla_column_type("NTEXT")
unicode_col = column("unicode_col", type_=type_)
tbl = table("tbl")
sel = (
select([str_col, unicode_col])

View File

@@ -89,18 +89,9 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
("TIME", GenericDataType.TEMPORAL),
)
for type_expectation in type_expectations:
type_str = type_expectation[0]
col_type = type_expectation[1]
assert MySQLEngineSpec.is_db_column_type_match(
type_str, GenericDataType.NUMERIC
) is (col_type == GenericDataType.NUMERIC)
assert MySQLEngineSpec.is_db_column_type_match(
type_str, GenericDataType.STRING
) is (col_type == GenericDataType.STRING)
assert MySQLEngineSpec.is_db_column_type_match(
type_str, GenericDataType.TEMPORAL
) is (col_type == GenericDataType.TEMPORAL)
for type_str, col_type in type_expectations:
column_spec = MySQLEngineSpec.get_column_spec(type_str)
assert column_spec.generic_type == col_type
def test_extract_error_message(self):
from MySQLdb._exceptions import OperationalError

View File

@@ -24,7 +24,7 @@ from sqlalchemy.sql import select
from superset.db_engine_specs.presto import PrestoEngineSpec
from superset.sql_parse import ParsedQuery
from superset.utils.core import DatasourceName
from superset.utils.core import DatasourceName, GenericDataType
from tests.db_engine_specs.base_tests import TestDbEngineSpec
@@ -535,30 +535,37 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
def test_get_sqla_column_type(self):
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
assert isinstance(sqla_type, types.VARCHAR)
assert sqla_type.length == 255
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
assert isinstance(column_spec.sqla_type, types.VARCHAR)
assert column_spec.sqla_type.length == 255
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
assert isinstance(sqla_type, types.String)
assert sqla_type.length is None
column_spec = PrestoEngineSpec.get_column_spec("varchar")
assert isinstance(column_spec.sqla_type, types.String)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length == 10
column_spec = PrestoEngineSpec.get_column_spec("char(10)")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length == 10
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
assert isinstance(sqla_type, types.CHAR)
assert sqla_type.length is None
column_spec = PrestoEngineSpec.get_column_spec("char")
assert isinstance(column_spec.sqla_type, types.CHAR)
assert column_spec.sqla_type.length is None
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
assert isinstance(sqla_type, types.Integer)
column_spec = PrestoEngineSpec.get_column_spec("integer")
assert isinstance(column_spec.sqla_type, types.Integer)
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
sqla_type = PrestoEngineSpec.get_sqla_column_type("time")
assert isinstance(sqla_type, types.Time)
column_spec = PrestoEngineSpec.get_column_spec("time")
assert isinstance(column_spec.sqla_type, types.Time)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
sqla_type = PrestoEngineSpec.get_sqla_column_type("timestamp")
assert isinstance(sqla_type, types.TIMESTAMP)
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
assert sqla_type is None

View File

@@ -84,11 +84,9 @@ class TestDatabaseModel(SupersetTestCase):
"TEXT": GenericDataType.STRING,
"NTEXT": GenericDataType.STRING,
# numeric
"INT": GenericDataType.NUMERIC,
"INTEGER": GenericDataType.NUMERIC,
"BIGINT": GenericDataType.NUMERIC,
"FLOAT": GenericDataType.NUMERIC,
"DECIMAL": GenericDataType.NUMERIC,
"MONEY": GenericDataType.NUMERIC,
# temporal
"DATE": GenericDataType.TEMPORAL,
"DATETIME": GenericDataType.TEMPORAL,