mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
Match viz dataframe column case to form_data fields for Snowflake, Oracle and Redshift (#5487)
* Add function to fix dataframe column case
* Fix broken handle_nulls method
* Add case sensitivity option to dedup
* Refactor function definition and call location
* Remove added blank line
* Move df column rename logic to db_engine_spec
* Remove redundant variable
* Update comments in db_engine_specs
* Tie df adjustment to db_engine_spec class attribute
* Fix dedup error
* Linting
* Check for db_engine_spec attribute prior to adjustment
* Rename case sensitivity flag
* Linting
* Remove function that was moved to db_engine_specs
* Get metrics names from utils
* Remove double import and rename dedup variable
This commit is contained in:
committed by
Maxime Beauchemin
parent
aa9b30cf55
commit
e1f4db8e24
@@ -27,23 +27,26 @@ INFER_COL_TYPES_THRESHOLD = 95
|
||||
INFER_COL_TYPES_SAMPLE_SIZE = 100
|
||||
|
||||
|
||||
def dedup(l, suffix='__', case_sensitive=True):
    """De-duplicates a list of string by suffixing a counter

    Always returns the same number of entries as provided, and always returns
    unique values. Case sensitive comparison by default.

    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar'])))
    foo,bar,bar__1,bar__2
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'])))
    foo,bar,bar__1,bar__2,Bar
    >>> print(','.join(dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False)))
    foo,bar,bar__1,bar__2,Bar__3
    """
    new_l = []
    seen = {}
    for s in l:
        s_fixed_case = s if case_sensitive else s.lower()
        if s_fixed_case in seen:
            # Keep bumping the counter until the suffixed name is itself
            # unused; a single suffix pass would emit a duplicate for
            # inputs such as ['bar', 'bar__1', 'bar'].
            candidate = s
            candidate_key = s_fixed_case
            while candidate_key in seen:
                seen[s_fixed_case] += 1
                candidate = s + suffix + str(seen[s_fixed_case])
                candidate_key = candidate if case_sensitive else candidate.lower()
            s = candidate
            seen[candidate_key] = 0
        else:
            seen[s_fixed_case] = 0
        new_l.append(s)
    return new_l
|
||||
|
||||
@@ -70,7 +73,9 @@ class SupersetDataFrame(object):
|
||||
if cursor_description:
|
||||
column_names = [col[0] for col in cursor_description]
|
||||
|
||||
self.column_names = dedup(column_names)
|
||||
case_sensitive = db_engine_spec.consistent_case_sensitivity
|
||||
self.column_names = dedup(column_names,
|
||||
case_sensitive=case_sensitive)
|
||||
|
||||
data = data or []
|
||||
self.df = (
|
||||
|
||||
@@ -101,6 +101,7 @@ class BaseEngineSpec(object):
|
||||
time_secondary_columns = False
|
||||
inner_joins = True
|
||||
allows_subquery = True
|
||||
consistent_case_sensitivity = True # do results have same case as qry for col names?
|
||||
|
||||
@classmethod
|
||||
def get_time_grains(cls):
|
||||
@@ -318,7 +319,6 @@ class BaseEngineSpec(object):
|
||||
|
||||
if show_cols:
|
||||
fields = [sqla.column(c.get('name')) for c in cols]
|
||||
full_table_name = table_name
|
||||
quote = engine.dialect.identifier_preparer.quote
|
||||
if schema:
|
||||
full_table_name = quote(schema) + '.' + quote(table_name)
|
||||
@@ -366,6 +366,57 @@ class BaseEngineSpec(object):
|
||||
def execute(cursor, query, async_=False):
    """Run *query* on an open DB-API cursor.

    NOTE(review): the parameter was renamed from ``async`` because
    ``async`` became a reserved keyword in Python 3.7, which turns the
    original signature into a SyntaxError. The flag is unused in this
    base implementation — presumably consumed by engine-specific
    overrides; confirm against the subclasses.
    """
    cursor.execute(query)
|
||||
|
||||
@classmethod
def adjust_df_column_names(cls, df, fd):
    """Return *df* with column names matched to the form_data fields.

    Most SQLAlchemy engines echo column names back in the same case the
    query used (``SELECT 1 as col1, 2 as COL2`` -> ``['col1', 'COL2']``),
    and such dataframes are returned untouched. Engines flagged with
    ``consistent_case_sensitivity = False`` (at least Snowflake, Oracle
    and Redshift) may return the names in a different case — usually all
    uppercase — so their dataframe columns are realigned with the field
    names from the form data before the Viz layer consumes them.
    """
    # Guard clause: nothing to do for engines with consistent casing.
    if cls.consistent_case_sensitivity:
        return df
    return cls.align_df_col_names_with_form_data(df, fd)
|
||||
|
||||
@staticmethod
def align_df_col_names_with_form_data(df, fd):
    """Rename dataframe columns whose case drifted during the query.

    Returns a dataframe whose column names have been adjusted to match
    the form-data field names, compared case-insensitively. Examples:
        dataframe: 'col1', form_data: 'col1' -> no change
        dataframe: 'COL1', form_data: 'col1' -> dataframe column renamed: 'col1'
        dataframe: 'col1', form_data: 'Col1' -> dataframe column renamed: 'Col1'
    """
    # Every field name the form data may reference, plus the synthetic
    # time-column alias.
    form_fields = (
        utils.get_metric_names(fd.get('metrics', [])) +
        fd.get('groupby', []) +
        [utils.DTTM_ALIAS]
    )
    exact_names = set(form_fields)
    by_lowercase = {field.lower(): field for field in form_fields}

    # Map each dataframe column without an exact match onto the
    # form-data spelling that matches it case-insensitively.
    rename_cols = {
        col: by_lowercase[col.lower()]
        for col in df.columns
        if col not in exact_names and by_lowercase.get(col.lower())
    }
    return df.rename(index=str, columns=rename_cols)
|
||||
|
||||
|
||||
class PostgresBaseEngineSpec(BaseEngineSpec):
|
||||
""" Abstract class for Postgres 'like' databases """
|
||||
@@ -414,6 +465,7 @@ class PostgresEngineSpec(PostgresBaseEngineSpec):
|
||||
|
||||
class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'snowflake'
|
||||
consistent_case_sensitivity = False
|
||||
time_grain_functions = {
|
||||
None: '{col}',
|
||||
'PT1S': "DATE_TRUNC('SECOND', {col})",
|
||||
@@ -434,6 +486,15 @@ class SnowflakeEngineSpec(PostgresBaseEngineSpec):
|
||||
'P1Y': "DATE_TRUNC('YEAR', {col})",
|
||||
}
|
||||
|
||||
@classmethod
def adjust_database_uri(cls, uri, selected_schema=None):
    """Point *uri* at *selected_schema* and return it.

    Snowflake encodes the schema as ``<database>/<schema>`` inside the
    URI's database field; keep only the database part and append the
    newly selected schema. The URI is mutated in place only when a
    schema is given.
    """
    # str.partition yields the full string when no '/' is present.
    base_database = uri.database.partition('/')[0]
    if selected_schema:
        uri.database = base_database + '/' + selected_schema
    return uri
|
||||
|
||||
|
||||
class VerticaEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'vertica'
|
||||
@@ -441,11 +502,13 @@ class VerticaEngineSpec(PostgresBaseEngineSpec):
|
||||
|
||||
class RedshiftEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'redshift'
|
||||
consistent_case_sensitivity = False
|
||||
|
||||
|
||||
class OracleEngineSpec(PostgresBaseEngineSpec):
|
||||
engine = 'oracle'
|
||||
limit_method = LimitMethod.WRAP_SQL
|
||||
consistent_case_sensitivity = False
|
||||
|
||||
time_grain_functions = {
|
||||
None: '{col}',
|
||||
|
||||
@@ -153,7 +153,7 @@ class BaseViz(object):
|
||||
|
||||
def handle_nulls(self, df):
    """Return *df* with NULLs replaced by per-column fill values.

    The replacement values come from ``get_fillna_for_columns``; the
    filled dataframe is returned rather than mutated in place.
    """
    return df.fillna(self.get_fillna_for_columns(df.columns))
|
||||
|
||||
def get_fillna_for_col(self, col):
|
||||
"""Returns the value to use as filler for a specific Column.type"""
|
||||
@@ -217,7 +217,7 @@ class BaseViz(object):
|
||||
self.df_metrics_to_num(df, query_obj.get('metrics') or [])
|
||||
|
||||
df.replace([np.inf, -np.inf], np.nan)
|
||||
self.handle_nulls(df)
|
||||
df = self.handle_nulls(df)
|
||||
return df
|
||||
|
||||
@staticmethod
|
||||
@@ -382,6 +382,9 @@ class BaseViz(object):
|
||||
if query_obj and not is_loaded:
|
||||
try:
|
||||
df = self.get_df(query_obj)
|
||||
if hasattr(self.datasource.database, 'db_engine_spec'):
|
||||
db_engine_spec = self.datasource.database.db_engine_spec
|
||||
df = db_engine_spec.adjust_df_column_names(df, self.form_data)
|
||||
if self.status != utils.QueryStatus.FAILED:
|
||||
stats_logger.incr('loaded_from_source')
|
||||
is_loaded = True
|
||||
|
||||
@@ -16,12 +16,16 @@ class SupersetDataFrameTestCase(SupersetTestCase):
|
||||
['foo', 'bar'],
|
||||
)
|
||||
self.assertEquals(
|
||||
dedup(['foo', 'bar', 'foo', 'bar']),
|
||||
['foo', 'bar', 'foo__1', 'bar__1'],
|
||||
dedup(['foo', 'bar', 'foo', 'bar', 'Foo']),
|
||||
['foo', 'bar', 'foo__1', 'bar__1', 'Foo'],
|
||||
)
|
||||
self.assertEquals(
|
||||
dedup(['foo', 'bar', 'bar', 'bar']),
|
||||
['foo', 'bar', 'bar__1', 'bar__2'],
|
||||
dedup(['foo', 'bar', 'bar', 'bar', 'Bar']),
|
||||
['foo', 'bar', 'bar__1', 'bar__2', 'Bar'],
|
||||
)
|
||||
self.assertEquals(
|
||||
dedup(['foo', 'bar', 'bar', 'bar', 'Bar'], case_sensitive=False),
|
||||
['foo', 'bar', 'bar__1', 'bar__2', 'Bar__3'],
|
||||
)
|
||||
|
||||
def test_get_columns_basic(self):
|
||||
|
||||
Reference in New Issue
Block a user