mirror of
https://github.com/apache/superset.git
synced 2026-04-07 10:31:50 +00:00
feat(explore): Postgres datatype conversion (#13294)
* test * unnecessary import * fix lint * changes * fix lint * changes * changes * changes * changes * answering comments & changes * answering comments * answering comments * changes * changes * changes * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests * fix tests
This commit is contained in:
@@ -69,6 +69,7 @@ from superset.result_set import SupersetResultSet
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.typing import Metric, QueryObjectDict
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import GenericDataType
|
||||
|
||||
config = app.config
|
||||
metadata = Model.metadata # pylint: disable=no-member
|
||||
@@ -186,20 +187,20 @@ class TableColumn(Model, BaseColumn):
|
||||
"""
|
||||
Check if the column has a numeric datatype.
|
||||
"""
|
||||
db_engine_spec = self.table.database.db_engine_spec
|
||||
return db_engine_spec.is_db_column_type_match(
|
||||
self.type, utils.GenericDataType.NUMERIC
|
||||
)
|
||||
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
|
||||
if column_spec is None:
|
||||
return False
|
||||
return column_spec.generic_type == GenericDataType.NUMERIC
|
||||
|
||||
@property
|
||||
def is_string(self) -> bool:
|
||||
"""
|
||||
Check if the column has a string datatype.
|
||||
"""
|
||||
db_engine_spec = self.table.database.db_engine_spec
|
||||
return db_engine_spec.is_db_column_type_match(
|
||||
self.type, utils.GenericDataType.STRING
|
||||
)
|
||||
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
|
||||
if column_spec is None:
|
||||
return False
|
||||
return column_spec.generic_type == GenericDataType.STRING
|
||||
|
||||
@property
|
||||
def is_temporal(self) -> bool:
|
||||
@@ -211,10 +212,10 @@ class TableColumn(Model, BaseColumn):
|
||||
"""
|
||||
if self.is_dttm is not None:
|
||||
return self.is_dttm
|
||||
db_engine_spec = self.table.database.db_engine_spec
|
||||
return db_engine_spec.is_db_column_type_match(
|
||||
self.type, utils.GenericDataType.TEMPORAL
|
||||
)
|
||||
column_spec = self.table.database.db_engine_spec.get_column_spec(self.type)
|
||||
if column_spec is None:
|
||||
return False
|
||||
return column_spec.is_dttm
|
||||
|
||||
def get_sqla_col(self, label: Optional[str] = None) -> Column:
|
||||
label = label or self.column_name
|
||||
@@ -222,7 +223,8 @@ class TableColumn(Model, BaseColumn):
|
||||
col = literal_column(self.expression)
|
||||
else:
|
||||
db_engine_spec = self.table.database.db_engine_spec
|
||||
type_ = db_engine_spec.get_sqla_column_type(self.type)
|
||||
column_spec = db_engine_spec.get_column_spec(self.type)
|
||||
type_ = column_spec.sqla_type if column_spec else None
|
||||
col = column(self.column_name, type_=type_)
|
||||
col = self.table.make_sqla_column_compatible(col, label)
|
||||
return col
|
||||
|
||||
@@ -41,7 +41,7 @@ import pandas as pd
|
||||
import sqlparse
|
||||
from flask import g
|
||||
from flask_babel import gettext as __, lazy_gettext as _
|
||||
from sqlalchemy import column, DateTime, select
|
||||
from sqlalchemy import column, DateTime, select, types
|
||||
from sqlalchemy.engine.base import Engine
|
||||
from sqlalchemy.engine.interfaces import Compiled, Dialect
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
@@ -50,13 +50,14 @@ from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql import quoted_name, text
|
||||
from sqlalchemy.sql.expression import ColumnClause, ColumnElement, Select, TextAsFrom
|
||||
from sqlalchemy.types import TypeEngine
|
||||
from sqlalchemy.types import String, TypeEngine, UnicodeText
|
||||
|
||||
from superset import app, security_manager, sql_parse
|
||||
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
|
||||
from superset.models.sql_lab import Query
|
||||
from superset.sql_parse import ParsedQuery, Table
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
@@ -145,8 +146,87 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
_date_trunc_functions: Dict[str, str] = {}
|
||||
_time_grain_expressions: Dict[Optional[str], str] = {}
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[Pattern[str], Union[TypeEngine, Callable[[Match[str]], TypeEngine]]], ...,
|
||||
] = ()
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = (
|
||||
(
|
||||
re.compile(r"^smallint", re.IGNORECASE),
|
||||
types.SmallInteger(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^integer", re.IGNORECASE),
|
||||
types.Integer(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^bigint", re.IGNORECASE),
|
||||
types.BigInteger(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^decimal", re.IGNORECASE),
|
||||
types.Numeric(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^numeric", re.IGNORECASE),
|
||||
types.Numeric(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(re.compile(r"^real", re.IGNORECASE), types.REAL, GenericDataType.NUMERIC,),
|
||||
(
|
||||
re.compile(r"^smallserial", re.IGNORECASE),
|
||||
types.SmallInteger(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^serial", re.IGNORECASE),
|
||||
types.Integer(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^bigserial", re.IGNORECASE),
|
||||
types.BigInteger(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^string", re.IGNORECASE),
|
||||
types.String(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE),
|
||||
UnicodeText(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE),
|
||||
String(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(re.compile(r"^date", re.IGNORECASE), types.Date(), GenericDataType.TEMPORAL,),
|
||||
(
|
||||
re.compile(r"^timestamp", re.IGNORECASE),
|
||||
types.TIMESTAMP(),
|
||||
GenericDataType.TEMPORAL,
|
||||
),
|
||||
(
|
||||
re.compile(r"^interval", re.IGNORECASE),
|
||||
types.Interval(),
|
||||
GenericDataType.TEMPORAL,
|
||||
),
|
||||
(re.compile(r"^time", re.IGNORECASE), types.Time(), GenericDataType.TEMPORAL,),
|
||||
(
|
||||
re.compile(r"^boolean", re.IGNORECASE),
|
||||
types.Boolean(),
|
||||
GenericDataType.BOOLEAN,
|
||||
),
|
||||
)
|
||||
time_groupby_inline = False
|
||||
limit_method = LimitMethod.FORCE_LIMIT
|
||||
time_secondary_columns = False
|
||||
@@ -160,25 +240,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
try_remove_schema_from_table_name = True # pylint: disable=invalid-name
|
||||
run_multiple_statements_as_one = False
|
||||
|
||||
# default matching patterns to convert database specific column types to
|
||||
# more generic types
|
||||
db_column_types: Dict[utils.GenericDataType, Tuple[Pattern[str], ...]] = {
|
||||
utils.GenericDataType.NUMERIC: (
|
||||
re.compile(r"BIT", re.IGNORECASE),
|
||||
re.compile(
|
||||
r".*(DOUBLE|FLOAT|INT|NUMBER|REAL|NUMERIC|DECIMAL|MONEY).*",
|
||||
re.IGNORECASE,
|
||||
),
|
||||
re.compile(r".*LONG$", re.IGNORECASE),
|
||||
),
|
||||
utils.GenericDataType.STRING: (
|
||||
re.compile(r".*(CHAR|STRING|TEXT).*", re.IGNORECASE),
|
||||
),
|
||||
utils.GenericDataType.TEMPORAL: (
|
||||
re.compile(r".*(DATE|TIME).*", re.IGNORECASE),
|
||||
),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_dbapi_exception_mapping(cls) -> Dict[Type[Exception], Type[Exception]]:
|
||||
"""
|
||||
@@ -208,25 +269,6 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
return exception
|
||||
return new_exception(str(exception))
|
||||
|
||||
@classmethod
|
||||
def is_db_column_type_match(
|
||||
cls, db_column_type: Optional[str], target_column_type: utils.GenericDataType
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a column type satisfies a pattern in a collection of regexes found in
|
||||
`db_column_types`. For example, if `db_column_type == "NVARCHAR"`,
|
||||
it would be a match for "STRING" due to being a match for the regex ".*CHAR.*".
|
||||
|
||||
:param db_column_type: Column type to evaluate
|
||||
:param target_column_type: The target type to evaluate for
|
||||
:return: `True` if a `db_column_type` matches any pattern corresponding to
|
||||
`target_column_type`
|
||||
"""
|
||||
if not db_column_type:
|
||||
return False
|
||||
patterns = cls.db_column_types[target_column_type]
|
||||
return any(pattern.match(db_column_type) for pattern in patterns)
|
||||
|
||||
@classmethod
|
||||
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
|
||||
return False
|
||||
@@ -967,24 +1009,35 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
return label_mutated
|
||||
|
||||
@classmethod
|
||||
def get_sqla_column_type(cls, type_: Optional[str]) -> Optional[TypeEngine]:
|
||||
def get_sqla_column_type(
|
||||
cls,
|
||||
column_type: Optional[str],
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = column_type_mappings,
|
||||
) -> Union[Tuple[TypeEngine, GenericDataType], None]:
|
||||
"""
|
||||
Return a sqlalchemy native column type that corresponds to the column type
|
||||
defined in the data source (return None to use default type inferred by
|
||||
SQLAlchemy). Override `column_type_mappings` for specific needs
|
||||
(see MSSQL for example of NCHAR/NVARCHAR handling).
|
||||
|
||||
:param type_: Column type returned by inspector
|
||||
:param column_type: Column type returned by inspector
|
||||
:return: SqlAlchemy column type
|
||||
"""
|
||||
if not type_:
|
||||
if not column_type:
|
||||
return None
|
||||
for regex, sqla_type in cls.column_type_mappings:
|
||||
match = regex.match(type_)
|
||||
for regex, sqla_type, generic_type in column_type_mappings:
|
||||
match = regex.match(column_type)
|
||||
if match:
|
||||
if callable(sqla_type):
|
||||
return sqla_type(match)
|
||||
return sqla_type
|
||||
return sqla_type(match), generic_type
|
||||
return sqla_type, generic_type
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
@@ -1101,3 +1154,43 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
|
||||
or parsed_query.is_explain()
|
||||
or parsed_query.is_show()
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_column_spec(
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = column_type_mappings,
|
||||
) -> Union[ColumnSpec, None]:
|
||||
"""
|
||||
Converts native database type to sqlalchemy column type.
|
||||
:param native_type: Native database typee
|
||||
:param source: Type coming from the database table or cursor description
|
||||
:return: ColumnSpec object
|
||||
"""
|
||||
column_type = None
|
||||
|
||||
if (
|
||||
cls.get_sqla_column_type(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
is not None
|
||||
):
|
||||
column_type, generic_type = cls.get_sqla_column_type( # type: ignore
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
is_dttm = generic_type == GenericDataType.TEMPORAL
|
||||
|
||||
if column_type:
|
||||
return ColumnSpec(
|
||||
sqla_type=column_type, generic_type=generic_type, is_dttm=is_dttm
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -15,18 +15,12 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, List, Optional, Tuple, TYPE_CHECKING
|
||||
|
||||
from sqlalchemy.types import String, UnicodeText
|
||||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec, LimitMethod
|
||||
from superset.utils import core as utils
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -77,11 +71,6 @@ class MssqlEngineSpec(BaseEngineSpec):
|
||||
# Lists of `pyodbc.Row` need to be unpacked further
|
||||
return cls.pyodbc_rows_to_tuples(data)
|
||||
|
||||
column_type_mappings = (
|
||||
(re.compile(r"^N((VAR)?CHAR|TEXT)", re.IGNORECASE), UnicodeText()),
|
||||
(re.compile(r"^((VAR)?CHAR|TEXT|STRING)", re.IGNORECASE), String()),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def extract_error_message(cls, ex: Exception) -> str:
|
||||
if str(ex).startswith("(8155,"):
|
||||
|
||||
@@ -14,14 +14,29 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Match, Optional, Pattern, Tuple, Union
|
||||
from urllib import parse
|
||||
|
||||
from sqlalchemy.dialects.mysql import (
|
||||
BIT,
|
||||
DECIMAL,
|
||||
DOUBLE,
|
||||
FLOAT,
|
||||
INTEGER,
|
||||
LONGTEXT,
|
||||
MEDIUMINT,
|
||||
MEDIUMTEXT,
|
||||
TINYINT,
|
||||
TINYTEXT,
|
||||
)
|
||||
from sqlalchemy.engine.url import URL
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
||||
|
||||
class MySQLEngineSpec(BaseEngineSpec):
|
||||
@@ -29,6 +44,34 @@ class MySQLEngineSpec(BaseEngineSpec):
|
||||
engine_name = "MySQL"
|
||||
max_column_name_length = 64
|
||||
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = (
|
||||
(re.compile(r"^int.*", re.IGNORECASE), INTEGER(), GenericDataType.NUMERIC,),
|
||||
(re.compile(r"^tinyint", re.IGNORECASE), TINYINT(), GenericDataType.NUMERIC,),
|
||||
(
|
||||
re.compile(r"^mediumint", re.IGNORECASE),
|
||||
MEDIUMINT(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(re.compile(r"^decimal", re.IGNORECASE), DECIMAL(), GenericDataType.NUMERIC,),
|
||||
(re.compile(r"^float", re.IGNORECASE), FLOAT(), GenericDataType.NUMERIC,),
|
||||
(re.compile(r"^double", re.IGNORECASE), DOUBLE(), GenericDataType.NUMERIC,),
|
||||
(re.compile(r"^bit", re.IGNORECASE), BIT(), GenericDataType.NUMERIC,),
|
||||
(re.compile(r"^tinytext", re.IGNORECASE), TINYTEXT(), GenericDataType.STRING,),
|
||||
(
|
||||
re.compile(r"^mediumtext", re.IGNORECASE),
|
||||
MEDIUMTEXT(),
|
||||
GenericDataType.STRING,
|
||||
),
|
||||
(re.compile(r"^longtext", re.IGNORECASE), LONGTEXT(), GenericDataType.STRING,),
|
||||
)
|
||||
|
||||
_time_grain_expressions = {
|
||||
None: "{col}",
|
||||
"PT1S": "DATE_ADD(DATE({col}), "
|
||||
@@ -98,3 +141,26 @@ class MySQLEngineSpec(BaseEngineSpec):
|
||||
except (AttributeError, KeyError):
|
||||
pass
|
||||
return message
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = column_type_mappings,
|
||||
) -> Union[ColumnSpec, None]:
|
||||
|
||||
column_spec = super().get_column_spec(native_type)
|
||||
if column_spec:
|
||||
return column_spec
|
||||
|
||||
return super().get_column_spec(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
|
||||
@@ -18,14 +18,28 @@ import json
|
||||
import logging
|
||||
import re
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
Dict,
|
||||
List,
|
||||
Match,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
|
||||
from pytz import _FixedOffset # type: ignore
|
||||
from sqlalchemy.dialects.postgresql import ARRAY, DOUBLE_PRECISION, ENUM, JSON
|
||||
from sqlalchemy.dialects.postgresql.base import PGInspector
|
||||
from sqlalchemy.types import String, TypeEngine
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.exceptions import SupersetException
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.models.core import Database # pragma: no cover
|
||||
@@ -77,6 +91,21 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
|
||||
max_column_name_length = 63
|
||||
try_remove_schema_from_table_name = False
|
||||
|
||||
column_type_mappings = (
|
||||
(
|
||||
re.compile(r"^double precision", re.IGNORECASE),
|
||||
DOUBLE_PRECISION(),
|
||||
GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^array.*", re.IGNORECASE),
|
||||
lambda match: ARRAY(int(match[2])) if match[2] else String(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(re.compile(r"^json.*", re.IGNORECASE), JSON(), utils.GenericDataType.STRING,),
|
||||
(re.compile(r"^enum.*", re.IGNORECASE), ENUM(), utils.GenericDataType.STRING,),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_allow_cost_estimate(cls, extra: Dict[str, Any]) -> bool:
|
||||
return True
|
||||
@@ -144,3 +173,26 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
|
||||
engine_params["connect_args"] = connect_args
|
||||
extra["engine_params"] = engine_params
|
||||
return extra
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = column_type_mappings,
|
||||
) -> Union[ColumnSpec, None]:
|
||||
|
||||
column_spec = super().get_column_spec(native_type)
|
||||
if column_spec:
|
||||
return column_spec
|
||||
|
||||
return super().get_column_spec(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
|
||||
@@ -23,7 +23,19 @@ from collections import defaultdict, deque
|
||||
from contextlib import closing
|
||||
from datetime import datetime
|
||||
from distutils.version import StrictVersion
|
||||
from typing import Any, cast, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
|
||||
from typing import (
|
||||
Any,
|
||||
Callable,
|
||||
cast,
|
||||
Dict,
|
||||
List,
|
||||
Match,
|
||||
Optional,
|
||||
Pattern,
|
||||
Tuple,
|
||||
TYPE_CHECKING,
|
||||
Union,
|
||||
)
|
||||
from urllib import parse
|
||||
|
||||
import pandas as pd
|
||||
@@ -36,6 +48,7 @@ from sqlalchemy.engine.result import RowProxy
|
||||
from sqlalchemy.engine.url import make_url, URL
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.sql.expression import ColumnClause, Select
|
||||
from sqlalchemy.types import TypeEngine
|
||||
|
||||
from superset import app, cache_manager, is_feature_enabled
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
@@ -52,6 +65,7 @@ from superset.models.sql_types.presto_sql_types import (
|
||||
from superset.result_set import destringify
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.utils import core as utils
|
||||
from superset.utils.core import ColumnSpec, GenericDataType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# prevent circular imports
|
||||
@@ -293,7 +307,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
||||
field_info = cls._split_data_type(single_field, r"\s")
|
||||
# check if there is a structural data type within
|
||||
# overall structural data type
|
||||
column_type = cls.get_sqla_column_type(field_info[1])
|
||||
column_spec = cls.get_column_spec(field_info[1])
|
||||
column_type = column_spec.sqla_type if column_spec else None
|
||||
if column_type is None:
|
||||
column_type = types.String()
|
||||
logger.info(
|
||||
@@ -356,31 +371,89 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
||||
return columns
|
||||
|
||||
column_type_mappings = (
|
||||
(re.compile(r"^boolean.*", re.IGNORECASE), types.Boolean()),
|
||||
(re.compile(r"^tinyint.*", re.IGNORECASE), TinyInteger()),
|
||||
(re.compile(r"^smallint.*", re.IGNORECASE), types.SmallInteger()),
|
||||
(re.compile(r"^integer.*", re.IGNORECASE), types.Integer()),
|
||||
(re.compile(r"^bigint.*", re.IGNORECASE), types.BigInteger()),
|
||||
(re.compile(r"^real.*", re.IGNORECASE), types.Float()),
|
||||
(re.compile(r"^double.*", re.IGNORECASE), types.Float()),
|
||||
(re.compile(r"^decimal.*", re.IGNORECASE), types.DECIMAL()),
|
||||
(
|
||||
re.compile(r"^boolean.*", re.IGNORECASE),
|
||||
types.BOOLEAN,
|
||||
utils.GenericDataType.BOOLEAN,
|
||||
),
|
||||
(
|
||||
re.compile(r"^tinyint.*", re.IGNORECASE),
|
||||
TinyInteger(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^smallint.*", re.IGNORECASE),
|
||||
types.SMALLINT(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^integer.*", re.IGNORECASE),
|
||||
types.INTEGER(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^bigint.*", re.IGNORECASE),
|
||||
types.BIGINT(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^real.*", re.IGNORECASE),
|
||||
types.FLOAT(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^double.*", re.IGNORECASE),
|
||||
types.FLOAT(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^decimal.*", re.IGNORECASE),
|
||||
types.DECIMAL(),
|
||||
utils.GenericDataType.NUMERIC,
|
||||
),
|
||||
(
|
||||
re.compile(r"^varchar(\((\d+)\))*$", re.IGNORECASE),
|
||||
lambda match: types.VARCHAR(int(match[2])) if match[2] else types.String(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^char(\((\d+)\))*$", re.IGNORECASE),
|
||||
lambda match: types.CHAR(int(match[2])) if match[2] else types.CHAR(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(re.compile(r"^varbinary.*", re.IGNORECASE), types.VARBINARY()),
|
||||
(re.compile(r"^json.*", re.IGNORECASE), types.JSON()),
|
||||
(re.compile(r"^date.*", re.IGNORECASE), types.DATE()),
|
||||
(re.compile(r"^timestamp.*", re.IGNORECASE), types.TIMESTAMP()),
|
||||
(re.compile(r"^time.*", re.IGNORECASE), types.Time()),
|
||||
(re.compile(r"^interval.*", re.IGNORECASE), Interval()),
|
||||
(re.compile(r"^array.*", re.IGNORECASE), Array()),
|
||||
(re.compile(r"^map.*", re.IGNORECASE), Map()),
|
||||
(re.compile(r"^row.*", re.IGNORECASE), Row()),
|
||||
(
|
||||
re.compile(r"^varbinary.*", re.IGNORECASE),
|
||||
types.VARBINARY(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^json.*", re.IGNORECASE),
|
||||
types.JSON(),
|
||||
utils.GenericDataType.STRING,
|
||||
),
|
||||
(
|
||||
re.compile(r"^date.*", re.IGNORECASE),
|
||||
types.DATE(),
|
||||
utils.GenericDataType.TEMPORAL,
|
||||
),
|
||||
(
|
||||
re.compile(r"^timestamp.*", re.IGNORECASE),
|
||||
types.TIMESTAMP(),
|
||||
utils.GenericDataType.TEMPORAL,
|
||||
),
|
||||
(
|
||||
re.compile(r"^interval.*", re.IGNORECASE),
|
||||
Interval(),
|
||||
utils.GenericDataType.TEMPORAL,
|
||||
),
|
||||
(
|
||||
re.compile(r"^time.*", re.IGNORECASE),
|
||||
types.Time(),
|
||||
utils.GenericDataType.TEMPORAL,
|
||||
),
|
||||
(re.compile(r"^array.*", re.IGNORECASE), Array(), utils.GenericDataType.STRING),
|
||||
(re.compile(r"^map.*", re.IGNORECASE), Map(), utils.GenericDataType.STRING),
|
||||
(re.compile(r"^row.*", re.IGNORECASE), Row(), utils.GenericDataType.STRING),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@@ -412,7 +485,8 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
||||
continue
|
||||
|
||||
# otherwise column is a basic data type
|
||||
column_type = cls.get_sqla_column_type(column.Type)
|
||||
column_spec = cls.get_column_spec(column.Type)
|
||||
column_type = column_spec.sqla_type if column_spec else None
|
||||
if column_type is None:
|
||||
column_type = types.String()
|
||||
logger.info(
|
||||
@@ -1111,3 +1185,27 @@ class PrestoEngineSpec(BaseEngineSpec): # pylint: disable=too-many-public-metho
|
||||
def is_readonly_query(cls, parsed_query: ParsedQuery) -> bool:
|
||||
"""Pessimistic readonly, 100% sure statement won't mutate anything"""
|
||||
return super().is_readonly_query(parsed_query) or parsed_query.is_show()
|
||||
|
||||
@classmethod
|
||||
def get_column_spec( # type: ignore
|
||||
cls,
|
||||
native_type: Optional[str],
|
||||
source: utils.ColumnTypeSource = utils.ColumnTypeSource.GET_TABLE,
|
||||
column_type_mappings: Tuple[
|
||||
Tuple[
|
||||
Pattern[str],
|
||||
Union[TypeEngine, Callable[[Match[str]], TypeEngine]],
|
||||
GenericDataType,
|
||||
],
|
||||
...,
|
||||
] = column_type_mappings,
|
||||
) -> Union[ColumnSpec, None]:
|
||||
|
||||
column_spec = super().get_column_spec(
|
||||
native_type, column_type_mappings=column_type_mappings
|
||||
)
|
||||
|
||||
if column_spec:
|
||||
return column_spec
|
||||
|
||||
return super().get_column_spec(native_type)
|
||||
|
||||
@@ -181,9 +181,10 @@ class SupersetResultSet:
|
||||
return next((i for i in items if i), None)
|
||||
|
||||
def is_temporal(self, db_type_str: Optional[str]) -> bool:
|
||||
return self.db_engine_spec.is_db_column_type_match(
|
||||
db_type_str, utils.GenericDataType.TEMPORAL
|
||||
)
|
||||
column_spec = self.db_engine_spec.get_column_spec(db_type_str)
|
||||
if column_spec is None:
|
||||
return False
|
||||
return column_spec.is_dttm
|
||||
|
||||
def data_type(self, col_name: str, pa_dtype: pa.DataType) -> Optional[str]:
|
||||
"""Given a pyarrow data type, Returns a generic database type"""
|
||||
|
||||
@@ -82,7 +82,7 @@ from sqlalchemy.dialects.mysql import MEDIUMTEXT
|
||||
from sqlalchemy.engine import Connection, Engine
|
||||
from sqlalchemy.engine.reflection import Inspector
|
||||
from sqlalchemy.sql.type_api import Variant
|
||||
from sqlalchemy.types import TEXT, TypeDecorator
|
||||
from sqlalchemy.types import TEXT, TypeDecorator, TypeEngine
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
import _thread # pylint: disable=C0411
|
||||
@@ -148,6 +148,10 @@ class GenericDataType(IntEnum):
|
||||
STRING = 1
|
||||
TEMPORAL = 2
|
||||
BOOLEAN = 3
|
||||
# ARRAY = 4 # Mapping all the complex data types to STRING for now
|
||||
# JSON = 5 # and leaving these as a reminder.
|
||||
# MAP = 6
|
||||
# ROW = 7
|
||||
|
||||
|
||||
class ChartDataResultFormat(str, Enum):
|
||||
@@ -306,6 +310,18 @@ class TemporalType(str, Enum):
|
||||
TIMESTAMP = "TIMESTAMP"
|
||||
|
||||
|
||||
class ColumnTypeSource(Enum):
|
||||
GET_TABLE = 1
|
||||
CURSOR_DESCRIPION = 2
|
||||
|
||||
|
||||
class ColumnSpec(NamedTuple):
|
||||
sqla_type: Union[TypeEngine, str]
|
||||
generic_type: GenericDataType
|
||||
is_dttm: bool
|
||||
python_date_format: Optional[str] = None
|
||||
|
||||
|
||||
try:
|
||||
# Having might not have been imported.
|
||||
class DimSelector(Having):
|
||||
|
||||
@@ -24,32 +24,37 @@ from sqlalchemy.types import String, UnicodeText
|
||||
|
||||
from superset.db_engine_specs.base import BaseEngineSpec
|
||||
from superset.db_engine_specs.mssql import MssqlEngineSpec
|
||||
from superset.utils.core import GenericDataType
|
||||
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
class TestMssqlEngineSpec(TestDbEngineSpec):
|
||||
def test_mssql_column_types(self):
|
||||
def assert_type(type_string, type_expected):
|
||||
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
|
||||
def assert_type(type_string, type_expected, generic_type_expected):
|
||||
if type_expected is None:
|
||||
type_assigned = MssqlEngineSpec.get_sqla_column_type(type_string)
|
||||
self.assertIsNone(type_assigned)
|
||||
else:
|
||||
self.assertIsInstance(type_assigned, type_expected)
|
||||
column_spec = MssqlEngineSpec.get_column_spec(type_string)
|
||||
if column_spec != None:
|
||||
self.assertIsInstance(column_spec.sqla_type, type_expected)
|
||||
self.assertEquals(column_spec.generic_type, generic_type_expected)
|
||||
|
||||
assert_type("INT", None)
|
||||
assert_type("STRING", String)
|
||||
assert_type("CHAR(10)", String)
|
||||
assert_type("VARCHAR(10)", String)
|
||||
assert_type("TEXT", String)
|
||||
assert_type("NCHAR(10)", UnicodeText)
|
||||
assert_type("NVARCHAR(10)", UnicodeText)
|
||||
assert_type("NTEXT", UnicodeText)
|
||||
assert_type("STRING", String, GenericDataType.STRING)
|
||||
assert_type("CHAR(10)", String, GenericDataType.STRING)
|
||||
assert_type("VARCHAR(10)", String, GenericDataType.STRING)
|
||||
assert_type("TEXT", String, GenericDataType.STRING)
|
||||
assert_type("NCHAR(10)", UnicodeText, GenericDataType.STRING)
|
||||
assert_type("NVARCHAR(10)", UnicodeText, GenericDataType.STRING)
|
||||
assert_type("NTEXT", UnicodeText, GenericDataType.STRING)
|
||||
|
||||
def test_where_clause_n_prefix(self):
|
||||
dialect = mssql.dialect()
|
||||
spec = MssqlEngineSpec
|
||||
str_col = column("col", type_=spec.get_sqla_column_type("VARCHAR(10)"))
|
||||
unicode_col = column("unicode_col", type_=spec.get_sqla_column_type("NTEXT"))
|
||||
type_, _ = spec.get_sqla_column_type("VARCHAR(10)")
|
||||
str_col = column("col", type_=type_)
|
||||
type_, _ = spec.get_sqla_column_type("NTEXT")
|
||||
unicode_col = column("unicode_col", type_=type_)
|
||||
tbl = table("tbl")
|
||||
sel = (
|
||||
select([str_col, unicode_col])
|
||||
|
||||
@@ -89,18 +89,9 @@ class TestMySQLEngineSpecsDbEngineSpec(TestDbEngineSpec):
|
||||
("TIME", GenericDataType.TEMPORAL),
|
||||
)
|
||||
|
||||
for type_expectation in type_expectations:
|
||||
type_str = type_expectation[0]
|
||||
col_type = type_expectation[1]
|
||||
assert MySQLEngineSpec.is_db_column_type_match(
|
||||
type_str, GenericDataType.NUMERIC
|
||||
) is (col_type == GenericDataType.NUMERIC)
|
||||
assert MySQLEngineSpec.is_db_column_type_match(
|
||||
type_str, GenericDataType.STRING
|
||||
) is (col_type == GenericDataType.STRING)
|
||||
assert MySQLEngineSpec.is_db_column_type_match(
|
||||
type_str, GenericDataType.TEMPORAL
|
||||
) is (col_type == GenericDataType.TEMPORAL)
|
||||
for type_str, col_type in type_expectations:
|
||||
column_spec = MySQLEngineSpec.get_column_spec(type_str)
|
||||
assert column_spec.generic_type == col_type
|
||||
|
||||
def test_extract_error_message(self):
|
||||
from MySQLdb._exceptions import OperationalError
|
||||
|
||||
@@ -24,7 +24,7 @@ from sqlalchemy.sql import select
|
||||
|
||||
from superset.db_engine_specs.presto import PrestoEngineSpec
|
||||
from superset.sql_parse import ParsedQuery
|
||||
from superset.utils.core import DatasourceName
|
||||
from superset.utils.core import DatasourceName, GenericDataType
|
||||
from tests.db_engine_specs.base_tests import TestDbEngineSpec
|
||||
|
||||
|
||||
@@ -535,30 +535,37 @@ class TestPrestoDbEngineSpec(TestDbEngineSpec):
|
||||
self.assertEqual(actual_expanded_cols, expected_expanded_cols)
|
||||
|
||||
def test_get_sqla_column_type(self):
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar(255)")
|
||||
assert isinstance(sqla_type, types.VARCHAR)
|
||||
assert sqla_type.length == 255
|
||||
column_spec = PrestoEngineSpec.get_column_spec("varchar(255)")
|
||||
assert isinstance(column_spec.sqla_type, types.VARCHAR)
|
||||
assert column_spec.sqla_type.length == 255
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("varchar")
|
||||
assert isinstance(sqla_type, types.String)
|
||||
assert sqla_type.length is None
|
||||
column_spec = PrestoEngineSpec.get_column_spec("varchar")
|
||||
assert isinstance(column_spec.sqla_type, types.String)
|
||||
assert column_spec.sqla_type.length is None
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("char(10)")
|
||||
assert isinstance(sqla_type, types.CHAR)
|
||||
assert sqla_type.length == 10
|
||||
column_spec = PrestoEngineSpec.get_column_spec("char(10)")
|
||||
assert isinstance(column_spec.sqla_type, types.CHAR)
|
||||
assert column_spec.sqla_type.length == 10
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("char")
|
||||
assert isinstance(sqla_type, types.CHAR)
|
||||
assert sqla_type.length is None
|
||||
column_spec = PrestoEngineSpec.get_column_spec("char")
|
||||
assert isinstance(column_spec.sqla_type, types.CHAR)
|
||||
assert column_spec.sqla_type.length is None
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.STRING)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("integer")
|
||||
assert isinstance(sqla_type, types.Integer)
|
||||
column_spec = PrestoEngineSpec.get_column_spec("integer")
|
||||
assert isinstance(column_spec.sqla_type, types.Integer)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.NUMERIC)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("time")
|
||||
assert isinstance(sqla_type, types.Time)
|
||||
column_spec = PrestoEngineSpec.get_column_spec("time")
|
||||
assert isinstance(column_spec.sqla_type, types.Time)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type("timestamp")
|
||||
assert isinstance(sqla_type, types.TIMESTAMP)
|
||||
column_spec = PrestoEngineSpec.get_column_spec("timestamp")
|
||||
assert isinstance(column_spec.sqla_type, types.TIMESTAMP)
|
||||
self.assertEqual(column_spec.generic_type, GenericDataType.TEMPORAL)
|
||||
|
||||
sqla_type = PrestoEngineSpec.get_sqla_column_type(None)
|
||||
assert sqla_type is None
|
||||
|
||||
@@ -84,11 +84,9 @@ class TestDatabaseModel(SupersetTestCase):
|
||||
"TEXT": GenericDataType.STRING,
|
||||
"NTEXT": GenericDataType.STRING,
|
||||
# numeric
|
||||
"INT": GenericDataType.NUMERIC,
|
||||
"INTEGER": GenericDataType.NUMERIC,
|
||||
"BIGINT": GenericDataType.NUMERIC,
|
||||
"FLOAT": GenericDataType.NUMERIC,
|
||||
"DECIMAL": GenericDataType.NUMERIC,
|
||||
"MONEY": GenericDataType.NUMERIC,
|
||||
# temporal
|
||||
"DATE": GenericDataType.TEMPORAL,
|
||||
"DATETIME": GenericDataType.TEMPORAL,
|
||||
|
||||
Reference in New Issue
Block a user