diff --git a/superset/sql_lab.py b/superset/sql_lab.py
index 63225f3e2d0..87a6b44295b 100644
--- a/superset/sql_lab.py
+++ b/superset/sql_lab.py
@@ -10,6 +10,7 @@ from time import sleep
 import uuid
 
 from celery.exceptions import SoftTimeLimitExceeded
+import numpy as np
 import pandas as pd
 import sqlalchemy
 from sqlalchemy.orm import sessionmaker
@@ -85,6 +86,38 @@ def get_session(nullpool):
     return session
 
 
+def convert_results_to_df(cursor_description, data):
+    """Convert raw query results to a SupersetDataFrame.
+
+    :param cursor_description: sequence whose entries carry the column
+        name at index 0 (e.g. a cursor description); may be None or empty.
+    :param data: sequence of result rows; may be empty.
+    :returns: ``dataframe.SupersetDataFrame`` wrapping the rows
+    """
+    column_names = (
+        [col[0] for col in cursor_description] if cursor_description else [])
+    column_names = dedup(column_names)
+
+    if data:
+        # Only the first row is sampled to detect nested (dict/list) cells;
+        # rows holding such cells are passed to pandas as plain lists.
+        first_row = data[0]
+        has_nested_col = any(
+            isinstance(c, (dict, list)) for c in first_row)
+        # dtype=object stops numpy from coercing mixed-type rows such as
+        # ['a', 4] to an all-string array.
+        df_data = (
+            list(data) if has_nested_col else np.array(data, dtype=object))
+    else:
+        df_data = []
+
+    # infer_objects() attempts to recover concrete per-column dtypes.
+    cdf = dataframe.SupersetDataFrame(
+        pd.DataFrame(df_data, columns=column_names).infer_objects())
+
+    return cdf
+
+
 @celery_app.task(bind=True, soft_time_limit=SQLLAB_TIMEOUT)
 def get_sql_results(
     ctask, query_id, return_results=True, store_results=False,
@@ -224,11 +257,7 @@ def execute_sql(
             },
             default=utils.json_iso_dttm_ser)
 
-    column_names = (
-        [col[0] for col in cursor_description] if cursor_description else [])
-    column_names = dedup(column_names)
-    cdf = dataframe.SupersetDataFrame(
-        pd.DataFrame(list(data), columns=column_names))
+    cdf = convert_results_to_df(cursor_description, data)
 
     query.rows = cdf.size
     query.progress = 100
diff --git a/tests/sqllab_tests.py b/tests/sqllab_tests.py
index 2caf4c2ae16..53144eadacd 100644
--- a/tests/sqllab_tests.py
+++ b/tests/sqllab_tests.py
@@ -12,6 +12,7 @@ from flask_appbuilder.security.sqla import models as ab_models
 
 from superset import appbuilder, db, sm, utils
 from superset.models.sql_lab import Query
+from superset.sql_lab import convert_results_to_df
 from .base_tests import SupersetTestCase
 
 
@@ -200,6 +201,22 @@ class SqlLabTests(SupersetTestCase):
user_name='admin', raise_on_error=True) + def test_df_conversion_no_dict(self): + cols = [['string_col'], ['int_col']] + data = [['a', 4]] + cdf = convert_results_to_df(cols, data) + + self.assertEquals(len(data), cdf.size) + self.assertEquals(len(cols), len(cdf.columns)) + + def test_df_conversion_dict(self): + cols = [['string_col'], ['dict_col'], ['int_col']] + data = [['a', {'c1': 1, 'c2': 2, 'c3': 3}, 4]] + cdf = convert_results_to_df(cols, data) + + self.assertEquals(len(data), cdf.size) + self.assertEquals(len(cols), len(cdf.columns)) + if __name__ == '__main__': unittest.main()