Improve database type inference (#4724)

* Improve database type inference

Python's DBAPI isn't super clear and homogeneous on the
cursor.description specification, and this PR attempts to improve
inferring the datatypes returned in the cursor.

This work started around Presto's TIMESTAMP type being mishandled as
string as the database driver (pyhive) returns it as a string. The work
here fixes this bug and does a better job at inferring MySQL and Presto types.
It also creates a new method in db_engine_specs allowing for other
databases engines to implement and become more precise on type-inference
as needed.

* Fixing tests

* Adressing comments

* Using infer_objects

* Removing faulty line

* Addressing PrestoSpec redundant method comment

* Fix rebase issue

* Fix tests
This commit is contained in:
Maxime Beauchemin
2018-06-27 21:35:12 -07:00
committed by GitHub
parent 04fc1d1089
commit 777d876a52
8 changed files with 224 additions and 117 deletions

View File

@@ -10,8 +10,6 @@ import uuid
from celery.exceptions import SoftTimeLimitExceeded
from contextlib2 import contextmanager
import numpy as np
import pandas as pd
import sqlalchemy
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import NullPool
@@ -31,27 +29,6 @@ class SqlLabException(Exception):
pass
def dedup(l, suffix='__'):
"""De-duplicates a list of string by suffixing a counter
Always returns the same number of entries as provided, and always returns
unique values.
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
foo,bar,bar__1,bar__2
"""
new_l = []
seen = {}
for s in l:
if s in seen:
seen[s] += 1
s += suffix + str(seen[s])
else:
seen[s] = 0
new_l.append(s)
return new_l
def get_query(query_id, session, retry_count=5):
"""attemps to get the query and retry if it cannot"""
query = None
@@ -96,24 +73,6 @@ def session_scope(nullpool):
session.close()
def convert_results_to_df(column_names, data):
"""Convert raw query results to a DataFrame."""
column_names = dedup(column_names)
# check whether the result set has any nested dict columns
if data:
first_row = data[0]
has_dict_col = any([isinstance(c, dict) for c in first_row])
df_data = list(data) if has_dict_col else np.array(data, dtype=object)
else:
df_data = []
cdf = dataframe.SupersetDataFrame(
pd.DataFrame(df_data, columns=column_names))
return cdf
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
def get_sql_results(
ctask, query_id, rendered_query, return_results=True, store_results=False,
@@ -233,7 +192,6 @@ def execute_sql(
return handle_error(db_engine_spec.extract_error_message(e))
logging.info('Fetching cursor description')
column_names = db_engine_spec.get_normalized_column_names(cursor.description)
if conn is not None:
conn.commit()
@@ -242,7 +200,7 @@ def execute_sql(
if query.status == utils.QueryStatus.STOPPED:
return handle_error('The query has been stopped')
cdf = convert_results_to_df(column_names, data)
cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)
query.rows = cdf.size
query.progress = 100