mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
Improve database type inference (#4724)
* Improve database type inference Python's DBAPI isn't super clear and homogeneous on the cursor.description specification, and this PR attempts to improve inferring the datatypes returned in the cursor. This work started around Presto's TIMESTAMP type being mishandled as string as the database driver (pyhive) returns it as a string. The work here fixes this bug and does a better job at inferring MySQL and Presto types. It also creates a new method in db_engine_specs allowing for other databases engines to implement and become more precise on type-inference as needed. * Fixing tests * Adressing comments * Using infer_objects * Removing faulty line * Addressing PrestoSpec redundant method comment * Fix rebase issue * Fix tests
This commit is contained in:
committed by
GitHub
parent
04fc1d1089
commit
777d876a52
@@ -10,8 +10,6 @@ import uuid
|
||||
|
||||
from celery.exceptions import SoftTimeLimitExceeded
|
||||
from contextlib2 import contextmanager
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import NullPool
|
||||
@@ -31,27 +29,6 @@ class SqlLabException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def dedup(l, suffix='__'):
|
||||
"""De-duplicates a list of string by suffixing a counter
|
||||
|
||||
Always returns the same number of entries as provided, and always returns
|
||||
unique values.
|
||||
|
||||
>>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
|
||||
foo,bar,bar__1,bar__2
|
||||
"""
|
||||
new_l = []
|
||||
seen = {}
|
||||
for s in l:
|
||||
if s in seen:
|
||||
seen[s] += 1
|
||||
s += suffix + str(seen[s])
|
||||
else:
|
||||
seen[s] = 0
|
||||
new_l.append(s)
|
||||
return new_l
|
||||
|
||||
|
||||
def get_query(query_id, session, retry_count=5):
|
||||
"""attemps to get the query and retry if it cannot"""
|
||||
query = None
|
||||
@@ -96,24 +73,6 @@ def session_scope(nullpool):
|
||||
session.close()
|
||||
|
||||
|
||||
def convert_results_to_df(column_names, data):
|
||||
"""Convert raw query results to a DataFrame."""
|
||||
column_names = dedup(column_names)
|
||||
|
||||
# check whether the result set has any nested dict columns
|
||||
if data:
|
||||
first_row = data[0]
|
||||
has_dict_col = any([isinstance(c, dict) for c in first_row])
|
||||
df_data = list(data) if has_dict_col else np.array(data, dtype=object)
|
||||
else:
|
||||
df_data = []
|
||||
|
||||
cdf = dataframe.SupersetDataFrame(
|
||||
pd.DataFrame(df_data, columns=column_names))
|
||||
|
||||
return cdf
|
||||
|
||||
|
||||
@celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
|
||||
def get_sql_results(
|
||||
ctask, query_id, rendered_query, return_results=True, store_results=False,
|
||||
@@ -233,7 +192,6 @@ def execute_sql(
|
||||
return handle_error(db_engine_spec.extract_error_message(e))
|
||||
|
||||
logging.info('Fetching cursor description')
|
||||
column_names = db_engine_spec.get_normalized_column_names(cursor.description)
|
||||
|
||||
if conn is not None:
|
||||
conn.commit()
|
||||
@@ -242,7 +200,7 @@ def execute_sql(
|
||||
if query.status == utils.QueryStatus.STOPPED:
|
||||
return handle_error('The query has been stopped')
|
||||
|
||||
cdf = convert_results_to_df(column_names, data)
|
||||
cdf = dataframe.SupersetDataFrame(data, cursor.description, db_engine_spec)
|
||||
|
||||
query.rows = cdf.size
|
||||
query.progress = 100
|
||||
|
||||
Reference in New Issue
Block a user