Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)

* Add function to fix dataframe column case

* Fix broken handle_nulls method

* Add case sensitivity option to dedup

* Refactor function definition and call location

* Remove added blank line

* Move df column rename logic to db_engine_spec

* Remove redundant variable

* Update comments in db_engine_specs

* Tie df adjustment to db_engine_spec class attribute

* Fix dedup error

* Linting

* Check for db_engine_spec attribute prior to adjustment

* Rename case sensitivity flag

* Linting

* Remove function that was moved to db_engine_specs

* Get metrics names from utils

* Remove double import and rename dedup variable
This commit is contained in:
Ville Brofeldt
2018-08-03 19:53:56 +03:00
committed by Maxime Beauchemin
parent aa9b30cf55
commit e1f4db8e24
4 changed files with 91 additions and 16 deletions

View File

@@ -27,23 +27,26 @@ INFER_COL_TYPES_THRESHOLD = 95
INFER_COL_TYPES_SAMPLE_SIZE = 100
def dedup(l, suffix='__', case_sensitive=True):
    """De-duplicates a list of string by suffixing a counter

    Always returns the same number of entries as provided, and always returns
    unique values. Case sensitive comparison by default.

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
    foo,bar,bar__1,bar__2
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
    foo,bar,bar__1,bar__2,Bar
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
    foo,bar,bar__1,bar__2,Bar__3
    """
    new_l = []
    seen = {}
    for s in l:
        # Use a case-folded key for duplicate detection when requested, but
        # always keep the entry's original casing in the output; only the
        # suffix counter is shared across case variants.
        s_fixed_case = s if case_sensitive else s.lower()
        if s_fixed_case in seen:
            seen[s_fixed_case] += 1
            s += suffix + str(seen[s_fixed_case])
        else:
            seen[s_fixed_case] = 0
        new_l.append(s)
    return new_l
@@ -70,7 +73,9 @@ class SupersetDataFrame(object):
if cursor_description:
column_names = [col[0] for col in cursor_description]
self.column_names = dedup(column_names)
case_sensitive = db_engine_spec.consistent_case_sensitivity
self.column_names = dedup(column_names,
case_sensitive=case_sensitive)
data = data or []
self.df = (

View File

@@ -101,6 +101,7 @@ class BaseEngineSpec(object):
time_secondary_columns = False
inner_joins = True
allows_subquery = True
consistent_case_sensitivity = True # do results have same case as qry for col names?
@classmethod
def get_time_grains(cls):
@@ -318,7 +319,6 @@ class BaseEngineSpec(object):
if show_cols:
fields = [sqla.column(c.get('name')) for c in cols]
full_table_name = table_name
quote = engine.dialect.identifier_preparer.quote
if schema:
full_table_name = quote(schema) + '.' + quote(table_name)
@@ -366,6 +366,57 @@ class BaseEngineSpec(object):
def execute(cursor, query, async=False):
cursor.execute(query)
@classmethod
def adjust_df_column_names(cls, df, fd):
    """Based on fields in form_data, return dataframe with new column names

    Usually sqla engines return column names whose case matches that of the
    original query. For example:
        SELECT 1 as col1, 2 as COL2, 3 as Col_3
    will usually result in the following df.columns:
        ['col1', 'COL2', 'Col_3'].
    For these engines there is no need to adjust the dataframe column names
    (default behavior). However, some engines (at least Snowflake, Oracle and
    Redshift) return column names with different case than in the original query,
    usually all uppercase. For these the column names need to be adjusted to
    correspond to the case of the fields specified in the form data for Viz
    to work properly. This adjustment can be done here.

    :param df: result dataframe whose columns may need case realignment
    :param fd: the viz form_data dict supplying the reference column names
    :returns: `df` unchanged for case-consistent engines, otherwise the
        realigned dataframe
    """
    # Engines flagged consistent_case_sensitivity=True echo query casing
    # back, so no renaming is needed.
    if cls.consistent_case_sensitivity:
        return df
    else:
        return cls.align_df_col_names_with_form_data(df, fd)
@staticmethod
def align_df_col_names_with_form_data(df, fd):
    """Helper function to rename columns that have changed case during query.

    Returns a dataframe where column names have been adjusted to correspond with
    column names in form data (case insensitive). Examples:
        dataframe: 'col1', form_data: 'col1' -> no change
        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'

    :param df: result dataframe returned by the engine
    :param fd: the viz form_data dict (metrics/groupby supply expected names)
    :returns: dataframe with columns renamed to match form_data casing
    """
    columns = set()
    lowercase_mapping = {}

    # Collect every column name the Viz layer expects: metric labels,
    # groupby columns and the reserved datetime alias.
    metrics = utils.get_metric_names(fd.get('metrics', []))
    groupby = fd.get('groupby', [])
    other_cols = [utils.DTTM_ALIAS]
    for col in metrics + groupby + other_cols:
        columns.add(col)
        lowercase_mapping[col.lower()] = col

    # Only rename result columns that do not already match exactly but do
    # match an expected name case-insensitively.
    rename_cols = {}
    for col in df.columns:
        if col not in columns:
            orig_col = lowercase_mapping.get(col.lower())
            if orig_col:
                rename_cols[col] = orig_col
    return df.rename(index=str, columns=rename_cols)
class PostgresBaseEngineSpec(BaseEngineSpec):
""" Abstract class for Postgres 'like' databases """
@@ -414,6 +465,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
engine = 'snowflake'
consistent_case_sensitivity = False
time_grain_functions = {
None: '{col}',
'PT1S': "DATE_TRUNC('SECOND', {col})",
@@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
'P1Y': "DATE_TRUNC('YEAR', {col})",
}
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point *uri* at ``selected_schema`` for Snowflake connections.

    Snowflake encodes the schema inside the URI's database component as
    'database/schema'. Any previously selected schema suffix is stripped
    before the new one is appended; when no schema is selected the uri is
    left untouched.

    :param uri: SQLAlchemy URL object to adjust (mutated in place)
    :param selected_schema: schema name to select, or None for no change
    :returns: the (possibly mutated) uri
    """
    database = uri.database
    if '/' in uri.database:
        # Keep only the database part, dropping any existing schema suffix.
        database = uri.database.split('/')[0]
    if selected_schema:
        uri.database = database + '/' + selected_schema
    return uri
class VerticaEngineSpec(PostgresBaseEngineSpec):
engine = 'vertica'
@@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
class RedshiftEngineSpec(PostgresBaseEngineSpec):
engine = 'redshift'
consistent_case_sensitivity = False
class OracleEngineSpec(PostgresBaseEngineSpec):
engine = 'oracle'
limit_method = LimitMethod.WRAP_SQL
consistent_case_sensitivity = False
time_grain_functions = {
None: '{col}',

View File

@@ -153,7 +153,7 @@ class BaseViz(object):
def handle_nulls(self, df):
    """Fill NULLs in *df* with per-column replacement values.

    Returns the filled dataframe rather than discarding it (the pre-fix
    version assigned the result to a local and implicitly returned None).

    :param df: result dataframe to sanitize
    :returns: dataframe with NULLs replaced per get_fillna_for_columns
    """
    fillna = self.get_fillna_for_columns(df.columns)
    return df.fillna(fillna)
def get_fillna_for_col(self, col):
"""Returns the value to use as filler for a specific Column.type"""
@@ -217,7 +217,7 @@ class BaseViz(object):
self.df_metrics_to_num(df, query_obj.get('metrics') or [])
df.replace([np.inf, -np.inf], np.nan)
self.handle_nulls(df)
df = self.handle_nulls(df)
return df
@staticmethod
@@ -382,6 +382,9 @@ class BaseViz(object):
if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
if hasattr(self.datasource.database, 'db_engine_spec'):
db_engine_spec = self.datasource.database.db_engine_spec
df = db_engine_spec.adjust_df_column_names(df, self.form_data)
if self.status != utils.QueryStatus.FAILED:
stats_logger.incr('loaded_from_source')
is_loaded = True

View File

@@ -16,12 +16,16 @@ class SupersetDataFrameTestCase(SupersetTestCase):
['foo', 'bar'],
)
self.assertEquals(
dedup(['foo', 'bar', 'foo', 'bar']),
['foo', 'bar', 'foo__1', 'bar__1'],
dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
)
self.assertEquals(
dedup(['foo', 'bar', 'bar', 'bar']),
['foo', 'bar', 'bar__1', 'bar__2'],
dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
)
self.assertEquals(
dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
)
def test_get_columns_basic(self):