diff --git a/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx b/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx index bea8e1b4a37..ce20aab4627 100644 --- a/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx +++ b/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx @@ -30,6 +30,7 @@ const defaultProps = { choices: [['1 year ago', '1 year ago'], ['today', 'today']], name: 'row_limit', label: 'Row Limit', + valueKey: 'value', // shallow isn't passing SelectControl.defaultProps.valueKey through onChange: sinon.spy(), }; @@ -43,6 +44,7 @@ describe('SelectControl', () => { beforeEach(() => { wrapper = shallow(); + wrapper.setProps(defaultProps); }); it('renders an OnPasteSelect', () => { @@ -55,6 +57,23 @@ describe('SelectControl', () => { expect(defaultProps.onChange.calledWith(50)).toBe(true); }); + it('returns all options on select all', () => { + const expectedValues = ['one', 'two']; + const selectAllProps = { + multi: true, + allowAll: true, + choices: expectedValues, + name: 'row_limit', + label: 'Row Limit', + valueKey: 'value', + onChange: sinon.spy(), + }; + wrapper.setProps(selectAllProps); + const select = wrapper.find(OnPasteSelect); + select.simulate('change', [{ meta: true, value: 'Select All' }]); + expect(selectAllProps.onChange.calledWith(expectedValues)).toBe(true); + }); + it('passes VirtualizedSelect as selectWrap', () => { const select = wrapper.find(OnPasteSelect); expect(select.props().selectWrap).toBe(VirtualizedSelect); @@ -82,19 +101,39 @@ describe('SelectControl', () => { describe('getOptions', () => { it('returns the correct options', () => { + wrapper.setProps(defaultProps); expect(wrapper.instance().getOptions(defaultProps)).toEqual(options); }); + it('shows Select-All when enabled', () => { + const selectAllProps = { + choices: ['one', 'two'], + name: 'name', + freeForm: true, + allowAll: true, + multi: true, + valueKey: 'value', + }; + wrapper.setProps(selectAllProps); + expect(wrapper.instance().getOptions(selectAllProps)) + .toContainEqual({ label: 'Select All', meta: true, value: 'Select All' }); + }); + it('returns the correct options when freeform is set to true', () => { - const freeFormProps = Object.assign(defaultProps, { + const freeFormProps = { choices: [], freeForm: true, value: ['one', 'two'], - }); + name: 'row_limit', + label: 'Row Limit', + valueKey: 'value', + onChange: sinon.spy(), + }; const newOptions = [ { value: 'one', label: 'one' }, { value: 'two', label: 'two' }, ]; + wrapper.setProps(freeFormProps); expect(wrapper.instance().getOptions(freeFormProps)).toEqual(newOptions); }); }); diff --git a/superset/assets/src/explore/components/controls/SelectControl.jsx b/superset/assets/src/explore/components/controls/SelectControl.jsx index f38c1d696e5..e15bf7ab03b 100644 --- a/superset/assets/src/explore/components/controls/SelectControl.jsx +++ b/superset/assets/src/explore/components/controls/SelectControl.jsx @@ -35,6 +35,7 @@ const propTypes = { isLoading: PropTypes.bool, label: PropTypes.string, multi: PropTypes.bool, + allowAll: PropTypes.bool, name: PropTypes.string.isRequired, onChange: PropTypes.func, onFocus: PropTypes.func, @@ -48,6 +49,8 @@ const propTypes = { noResultsText: PropTypes.string, refFunc: PropTypes.func, filterOption: PropTypes.func, + promptTextCreator: PropTypes.func, + commaChoosesOption: PropTypes.bool, }; const defaultProps = { @@ -66,6 +69,9 @@ const defaultProps = { valueRenderer: opt => opt.label, valueKey: 'value', noResultsText: t('No results found'), + promptTextCreator: label => `Create Option ${label}`, + commaChoosesOption: true, + allowAll: false, }; export default class SelectControl extends React.PureComponent { @@ -73,7 +79,9 @@ export default class SelectControl extends React.PureComponent { super(props); this.state = { options: this.getOptions(props) }; this.onChange = this.onChange.bind(this); + this.createMetaSelectAllOption = this.createMetaSelectAllOption.bind(this); } + componentWillReceiveProps(nextProps) { if (nextProps.choices !== this.props.choices || nextProps.options !== this.props.options) { @@ -81,40 +89,56 @@ export default class SelectControl extends React.PureComponent { this.setState({ options }); } } + onChange(opt) { - let optionValue = opt ? opt[this.props.valueKey] : null; - // if multi, return options values as an array + let optionValue = null; + if (!opt) { + return; + } if (this.props.multi) { - optionValue = opt ? opt.map(o => o[this.props.valueKey]) : null; + optionValue = []; + for (const o of opt) { + if (o.meta === true) { + optionValue = this.getOptions(this.props) + .filter(x => !x.meta) + .map(x => x[this.props.valueKey]); + break; + } else { + optionValue.push(o[this.props.valueKey]); + } + } + } else if (opt.meta === true) { + return; + } else { + optionValue = opt[this.props.valueKey]; } this.props.onChange(optionValue); } + getOptions(props) { + let options = []; if (props.options) { - return props.options; + options = props.options.map(x => x); + } else { + // Accepts different formats of input + options = props.choices.map((c) => { + let option; + if (Array.isArray(c)) { + const label = c.length > 1 ? c[1] : c[0]; + option = { label }; + option[props.valueKey] = c[0]; + } else if (Object.is(c)) { + option = c; + } else { + option = { label: c }; + option[props.valueKey] = c; + } + return option; + }); } - // Accepts different formats of input - const options = props.choices.map((c) => { - let option; - if (Array.isArray(c)) { - const label = c.length > 1 ? c[1] : c[0]; - option = { - value: c[0], - label, - }; - } else if (Object.is(c)) { - option = c; - } else { - option = { - value: c, - label: c, - }; - } - return option; - }); if (props.freeForm) { // For FreeFormSelect, insert value into options if not exist - const values = options.map(c => c.value); + const values = options.map(c => c[props.valueKey]); if (props.value) { let valuesToAdd = props.value; if (!Array.isArray(valuesToAdd)) { @@ -122,13 +146,33 @@ export default class SelectControl extends React.PureComponent { } valuesToAdd.forEach((v) => { if (values.indexOf(v) < 0) { - options.push({ value: v, label: v }); + const toAdd = { label: v }; + toAdd[props.valueKey] = v; + options.push(toAdd); } }); } } + if (props.allowAll === true && props.multi === true) { + if (options.findIndex(o => this.isMetaSelectAllOption(o)) < 0) { + options.unshift(this.createMetaSelectAllOption()); + } + } else { + options = options.filter(o => !this.isMetaSelectAllOption(o)); + } return options; } + + isMetaSelectAllOption(o) { + return o.meta && o.meta === true && o.label === 'Select All'; + } + + createMetaSelectAllOption() { + const option = { label: 'Select All', meta: true }; + option[this.props.valueKey] = 'Select All'; + return option; + } + render() { // Tab, comma or Enter will trigger a new option created for FreeFormSelect const placeholder = this.props.placeholder || t('%s option(s)', this.state.options.length); @@ -148,11 +192,23 @@ export default class SelectControl extends React.PureComponent { optionRenderer: VirtualizedRendererWrap(this.props.optionRenderer), valueRenderer: this.props.valueRenderer, noResultsText: this.props.noResultsText, - selectComponent: this.props.freeForm ? Creatable : Select, disabled: this.props.disabled, refFunc: this.props.refFunc, filterOption: this.props.filterOption, + promptTextCreator: this.props.promptTextCreator, }; + if (this.props.freeForm) { + selectProps.selectComponent = Creatable; + selectProps.shouldKeyDownEventCreateNewOption = (key) => { + const keyCode = key.keyCode; + if (this.props.commaChoosesOption && keyCode === 188) { + return true; + } + return (keyCode === 9 || keyCode === 13); + }; + } else { + selectProps.selectComponent = Select; + } return (
{this.props.showHeader && diff --git a/superset/assets/src/explore/controls.jsx b/superset/assets/src/explore/controls.jsx index c8d8b7bed96..b5589565eb2 100644 --- a/superset/assets/src/explore/controls.jsx +++ b/superset/assets/src/explore/controls.jsx @@ -120,6 +120,7 @@ const sortAxisChoices = [ const groupByControl = { type: 'SelectControl', multi: true, + freeForm: true, label: t('Group by'), default: [], includeTime: false, @@ -127,10 +128,12 @@ const groupByControl = { optionRenderer: c => , valueRenderer: c => , valueKey: 'column_name', + allowAll: true, filterOption: (opt, text) => ( - (opt.column_name && opt.column_name.toLowerCase().indexOf(text) >= 0) || - (opt.verbose_name && opt.verbose_name.toLowerCase().indexOf(text) >= 0) + (opt.column_name && opt.column_name.toLowerCase().indexOf(text.toLowerCase()) >= 0) || + (opt.verbose_name && opt.verbose_name.toLowerCase().indexOf(text.toLowerCase()) >= 0) ), + promptTextCreator: label => label, mapStateToProps: (state, control) => { const newState = {}; if (state.datasource) { @@ -141,6 +144,7 @@ const groupByControl = { } return newState; }, + commaChoosesOption: false, }; const metrics = { @@ -625,9 +629,12 @@ export const controls = { optionRenderer: c => , valueRenderer: c => , valueKey: 'column_name', + allowAll: true, mapStateToProps: state => ({ options: (state.datasource) ? state.datasource.columns : [], }), + commaChoosesOption: false, + freeForm: true, }, spatial: { diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 41b5dc75a11..64b38cfd508 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. # pylint: disable=C,R,W +from collections import OrderedDict from datetime import datetime import logging @@ -601,26 +602,25 @@ class SqlaTable(Model, BaseDatasource): main_metric_expr = literal_column('COUNT(*)').label(label) select_exprs = [] - groupby_exprs = [] + groupby_exprs_sans_timestamp = OrderedDict() if groupby: select_exprs = [] - inner_select_exprs = [] - inner_groupby_exprs = [] for s in groupby: - col = cols[s] - outer = col.get_sqla_col() - inner = col.get_sqla_col(col.column_name + '__') + if s in cols: + outer = cols[s].get_sqla_col() + else: + outer = literal_column(f'({s})').label(self.get_label(s)) - groupby_exprs.append(outer) + groupby_exprs_sans_timestamp[outer.name] = outer select_exprs.append(outer) - inner_groupby_exprs.append(inner) - inner_select_exprs.append(inner) elif columns: for s in columns: - select_exprs.append(cols[s].get_sqla_col()) + select_exprs.append( + cols[s].get_sqla_col() if s in cols else literal_column(s)) metrics_exprs = [] + groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items()) if granularity: dttm_col = cols[granularity] time_grain = extras.get('time_grain_sqla') @@ -629,7 +629,7 @@ class SqlaTable(Model, BaseDatasource): if is_timeseries: timestamp = dttm_col.get_timestamp_expression(time_grain) select_exprs += [timestamp] - groupby_exprs += [timestamp] + groupby_exprs_with_timestamp[timestamp.name] = timestamp # Use main dttm column to support index with secondary dttm columns if db_engine_spec.time_secondary_columns and \ @@ -645,7 +645,7 @@ class SqlaTable(Model, BaseDatasource): tbl = self.get_from_clause(template_processor) if not columns: - qry = qry.group_by(*groupby_exprs) + qry = qry.group_by(*groupby_exprs_with_timestamp.values()) where_clause_and = [] having_clause_and = [] @@ -725,9 +725,15 @@ class SqlaTable(Model, BaseDatasource): # require a unique inner alias label = self.get_label('mme_inner__') inner_main_metric_expr = main_metric_expr.label(label) + inner_groupby_exprs = [] + inner_select_exprs = [] + for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): + inner = gby_obj.label(gby_name + '__') + inner_groupby_exprs.append(inner) + inner_select_exprs.append(inner) + inner_select_exprs += [inner_main_metric_expr] - subq = select(inner_select_exprs) - subq = subq.select_from(tbl) + subq = select(inner_select_exprs).select_from(tbl) inner_time_filter = dttm_col.get_time_filter( inner_from_dttm or from_dttm, inner_to_dttm or to_dttm, @@ -751,12 +757,12 @@ class SqlaTable(Model, BaseDatasource): subq = subq.limit(timeseries_limit) on_clause = [] - for i, gb in enumerate(groupby): + for gby_name, gby_obj in groupby_exprs_sans_timestamp.items(): # in this case the column name, not the alias, needs to be # conditionally mutated, as it refers to the column alias in # the inner query - col_name = self.get_label(gb + '__') - on_clause.append(groupby_exprs[i] == column(col_name)) + col_name = self.get_label(gby_name + '__') + on_clause.append(gby_obj == column(col_name)) tbl = tbl.join(subq.alias(), and_(*on_clause)) else: @@ -778,24 +784,23 @@ class SqlaTable(Model, BaseDatasource): 'order_desc': True, } result = self.query(subquery_obj) - cols = {col.column_name: col for col in self.columns} dimensions = [ c for c in result.df.columns - if c not in metrics and c in cols + if c not in metrics and c in groupby_exprs_sans_timestamp ] - top_groups = self._get_top_groups(result.df, dimensions) + top_groups = self._get_top_groups(result.df, + dimensions, + groupby_exprs_sans_timestamp) qry = qry.where(top_groups) return qry.select_from(tbl) - def _get_top_groups(self, df, dimensions): - cols = {col.column_name: col for col in self.columns} + def _get_top_groups(self, df, dimensions, groupby_exprs): groups = [] for unused, row in df.iterrows(): group = [] for dimension in dimensions: - col_obj = cols.get(dimension) - group.append(col_obj.get_sqla_col() == row[dimension]) + group.append(groupby_exprs[dimension] == row[dimension]) groups.append(and_(*group)) return or_(*groups) diff --git a/tests/base_tests.py b/tests/base_tests.py index 1f3b4c5a95a..e50e8e0e0d1 100644 --- a/tests/base_tests.py +++ b/tests/base_tests.py @@ -26,6 +26,7 @@ from superset import app, db, is_feature_enabled, security_manager from superset.connectors.druid.models import DruidCluster, DruidDatasource from superset.connectors.sqla.models import SqlaTable from superset.models import core as models +from superset.models.core import Database from superset.utils.core import get_main_database BASE_DIR = app.config.get('BASE_DIR') @@ -99,6 +100,9 @@ class SupersetTestCase(unittest.TestCase): def get_table_by_name(self, name): return db.session.query(SqlaTable).filter_by(table_name=name).one() + def get_database_by_id(self, db_id): + return db.session.query(Database).filter_by(id=db_id).one() + def get_druid_ds_by_name(self, name): return db.session.query(DruidDatasource).filter_by( datasource_name=name).first() diff --git a/tests/core_tests.py b/tests/core_tests.py index 514ef67f692..9b8983ab7f2 100644 --- a/tests/core_tests.py +++ b/tests/core_tests.py @@ -757,7 +757,6 @@ class CoreTests(SupersetTestCase): {'form_data': json.dumps(form_data)}, ) self.assertEqual(data['status'], utils.QueryStatus.FAILED) - assert 'KeyError' in data['stacktrace'] def test_slice_payload_viz_markdown(self): self.login(username='admin') diff --git a/tests/model_tests.py b/tests/model_tests.py index 9a9fc70ab6e..cc2b30c485e 100644 --- a/tests/model_tests.py +++ b/tests/model_tests.py @@ -16,11 +16,12 @@ # under the License. import textwrap +import pandas from sqlalchemy.engine.url import make_url from superset import app, db from superset.models.core import Database -from superset.utils.core import get_main_database +from superset.utils.core import get_main_database, QueryStatus from .base_tests import SupersetTestCase @@ -150,11 +151,13 @@ class SqlaTableModelTestCase(SupersetTestCase): if tbl.database.backend == 'mysql': self.assertEquals(compiled, 'DATE(ds)') + prev_ds_expr = ds_col.expression ds_col.expression = 'DATE_ADD(ds, 1)' sqla_literal = ds_col.get_timestamp_expression('P1D') compiled = '{}'.format(sqla_literal.compile()) if tbl.database.backend == 'mysql': self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))') + ds_col.expression = prev_ds_expr def test_get_timestamp_expression_epoch(self): tbl = self.get_table_by_name('birth_names') @@ -173,11 +176,13 @@ class SqlaTableModelTestCase(SupersetTestCase): if tbl.database.backend == 'mysql': self.assertEquals(compiled, 'DATE(from_unixtime(ds))') + prev_ds_expr = ds_col.expression ds_col.expression = 'DATE_ADD(ds, 1)' sqla_literal = ds_col.get_timestamp_expression('P1D') compiled = '{}'.format(sqla_literal.compile()) if tbl.database.backend == 'mysql': self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))') + ds_col.expression = prev_ds_expr def test_get_timestamp_expression_backward(self): tbl = self.get_table_by_name('birth_names') @@ -197,6 +202,63 @@ class SqlaTableModelTestCase(SupersetTestCase): if tbl.database.backend == 'mysql': self.assertEquals(compiled, 'ds') + def query_with_expr_helper(self, is_timeseries, inner_join=True): + tbl = self.get_table_by_name('birth_names') + ds_col = tbl.get_column('ds') + ds_col.expression = None + ds_col.python_date_format = None + spec = self.get_database_by_id(tbl.database_id).db_engine_spec + if not spec.inner_joins and inner_join: + # if the db does not support inner joins, we cannot force it so + return None + old_inner_join = spec.inner_joins + spec.inner_joins = inner_join + arbitrary_gby = "state || gender || '_test'" + arbitrary_metric = (dict(label='arbitrary', expressionType='SQL', + sqlExpression='COUNT(1)')) + query_obj = dict( + groupby=[arbitrary_gby, 'name'], + metrics=[arbitrary_metric], + filter=[], + is_timeseries=is_timeseries, + prequeries=[], + columns=[], + granularity='ds', + from_dttm=None, + to_dttm=None, + is_prequery=False, + extras=dict(time_grain_sqla='P1Y'), + ) + qr = tbl.query(query_obj) + self.assertEqual(qr.status, QueryStatus.SUCCESS) + sql = qr.query + self.assertIn(arbitrary_gby, sql) + self.assertIn('name', sql) + if inner_join and is_timeseries: + self.assertIn('JOIN', sql.upper()) + else: + self.assertNotIn('JOIN', sql.upper()) + spec.inner_joins = old_inner_join + self.assertIsNotNone(qr.df) + return qr.df + + def test_query_with_expr_groupby_timeseries(self): + def cannonicalize_df(df): + ret = df.sort_values(by=list(df.columns.values), inplace=False) + ret.reset_index(inplace=True, drop=True) + return ret + + df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True) + df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False) + self.assertIsNotNone(df2) # df1 can be none if the db does not support join + if df1 is not None: + pandas.testing.assert_frame_equal( + cannonicalize_df(df1), + cannonicalize_df(df2)) + + def test_query_with_expr_groupby(self): + self.query_with_expr_helper(is_timeseries=False) + def test_sql_mutator(self): tbl = self.get_table_by_name('birth_names') query_obj = dict(