diff --git a/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx b/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx
index bea8e1b4a37..ce20aab4627 100644
--- a/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx
+++ b/superset/assets/spec/javascripts/explore/components/SelectControl_spec.jsx
@@ -30,6 +30,7 @@ const defaultProps = {
choices: [['1 year ago', '1 year ago'], ['today', 'today']],
name: 'row_limit',
label: 'Row Limit',
+ valueKey: 'value', // shallow isn't passing SelectControl.defaultProps.valueKey through
onChange: sinon.spy(),
};
@@ -43,6 +44,7 @@ describe('SelectControl', () => {
beforeEach(() => {
wrapper = shallow();
+ wrapper.setProps(defaultProps);
});
it('renders an OnPasteSelect', () => {
@@ -55,6 +57,23 @@ describe('SelectControl', () => {
expect(defaultProps.onChange.calledWith(50)).toBe(true);
});
+ it('returns all options on select all', () => {
+ const expectedValues = ['one', 'two'];
+ const selectAllProps = {
+ multi: true,
+ allowAll: true,
+ choices: expectedValues,
+ name: 'row_limit',
+ label: 'Row Limit',
+ valueKey: 'value',
+ onChange: sinon.spy(),
+ };
+ wrapper.setProps(selectAllProps);
+ const select = wrapper.find(OnPasteSelect);
+ select.simulate('change', [{ meta: true, value: 'Select All' }]);
+ expect(selectAllProps.onChange.calledWith(expectedValues)).toBe(true);
+ });
+
it('passes VirtualizedSelect as selectWrap', () => {
const select = wrapper.find(OnPasteSelect);
expect(select.props().selectWrap).toBe(VirtualizedSelect);
@@ -82,19 +101,39 @@ describe('SelectControl', () => {
describe('getOptions', () => {
it('returns the correct options', () => {
+ wrapper.setProps(defaultProps);
expect(wrapper.instance().getOptions(defaultProps)).toEqual(options);
});
+ it('shows Select-All when enabled', () => {
+ const selectAllProps = {
+ choices: ['one', 'two'],
+ name: 'name',
+ freeForm: true,
+ allowAll: true,
+ multi: true,
+ valueKey: 'value',
+ };
+ wrapper.setProps(selectAllProps);
+ expect(wrapper.instance().getOptions(selectAllProps))
+ .toContainEqual({ label: 'Select All', meta: true, value: 'Select All' });
+ });
+
it('returns the correct options when freeform is set to true', () => {
- const freeFormProps = Object.assign(defaultProps, {
+ const freeFormProps = {
choices: [],
freeForm: true,
value: ['one', 'two'],
- });
+ name: 'row_limit',
+ label: 'Row Limit',
+ valueKey: 'value',
+ onChange: sinon.spy(),
+ };
const newOptions = [
{ value: 'one', label: 'one' },
{ value: 'two', label: 'two' },
];
+ wrapper.setProps(freeFormProps);
expect(wrapper.instance().getOptions(freeFormProps)).toEqual(newOptions);
});
});
diff --git a/superset/assets/src/explore/components/controls/SelectControl.jsx b/superset/assets/src/explore/components/controls/SelectControl.jsx
index f38c1d696e5..e15bf7ab03b 100644
--- a/superset/assets/src/explore/components/controls/SelectControl.jsx
+++ b/superset/assets/src/explore/components/controls/SelectControl.jsx
@@ -35,6 +35,7 @@ const propTypes = {
isLoading: PropTypes.bool,
label: PropTypes.string,
multi: PropTypes.bool,
+ allowAll: PropTypes.bool,
name: PropTypes.string.isRequired,
onChange: PropTypes.func,
onFocus: PropTypes.func,
@@ -48,6 +49,8 @@ const propTypes = {
noResultsText: PropTypes.string,
refFunc: PropTypes.func,
filterOption: PropTypes.func,
+ promptTextCreator: PropTypes.func,
+ commaChoosesOption: PropTypes.bool,
};
const defaultProps = {
@@ -66,6 +69,9 @@ const defaultProps = {
valueRenderer: opt => opt.label,
valueKey: 'value',
noResultsText: t('No results found'),
+ promptTextCreator: label => `Create Option ${label}`,
+ commaChoosesOption: true,
+ allowAll: false,
};
export default class SelectControl extends React.PureComponent {
@@ -73,7 +79,9 @@ export default class SelectControl extends React.PureComponent {
super(props);
this.state = { options: this.getOptions(props) };
this.onChange = this.onChange.bind(this);
+ this.createMetaSelectAllOption = this.createMetaSelectAllOption.bind(this);
}
+
componentWillReceiveProps(nextProps) {
if (nextProps.choices !== this.props.choices ||
nextProps.options !== this.props.options) {
@@ -81,40 +89,56 @@ export default class SelectControl extends React.PureComponent {
this.setState({ options });
}
}
+
onChange(opt) {
- let optionValue = opt ? opt[this.props.valueKey] : null;
- // if multi, return options values as an array
+ let optionValue = null;
+ if (!opt) {
+ return;
+ }
if (this.props.multi) {
- optionValue = opt ? opt.map(o => o[this.props.valueKey]) : null;
+ optionValue = [];
+ for (const o of opt) {
+ if (o.meta === true) {
+ optionValue = this.getOptions(this.props)
+ .filter(x => !x.meta)
+ .map(x => x[this.props.valueKey]);
+ break;
+ } else {
+ optionValue.push(o[this.props.valueKey]);
+ }
+ }
+ } else if (opt.meta === true) {
+ return;
+ } else {
+ optionValue = opt[this.props.valueKey];
}
this.props.onChange(optionValue);
}
+
getOptions(props) {
+ let options = [];
if (props.options) {
- return props.options;
+ options = props.options.map(x => x);
+ } else {
+ // Accepts different formats of input
+ options = props.choices.map((c) => {
+ let option;
+ if (Array.isArray(c)) {
+ const label = c.length > 1 ? c[1] : c[0];
+ option = { label };
+ option[props.valueKey] = c[0];
+ } else if (Object.is(c)) {
+ option = c;
+ } else {
+ option = { label: c };
+ option[props.valueKey] = c;
+ }
+ return option;
+ });
}
- // Accepts different formats of input
- const options = props.choices.map((c) => {
- let option;
- if (Array.isArray(c)) {
- const label = c.length > 1 ? c[1] : c[0];
- option = {
- value: c[0],
- label,
- };
- } else if (Object.is(c)) {
- option = c;
- } else {
- option = {
- value: c,
- label: c,
- };
- }
- return option;
- });
if (props.freeForm) {
// For FreeFormSelect, insert value into options if not exist
- const values = options.map(c => c.value);
+ const values = options.map(c => c[props.valueKey]);
if (props.value) {
let valuesToAdd = props.value;
if (!Array.isArray(valuesToAdd)) {
@@ -122,13 +146,33 @@ export default class SelectControl extends React.PureComponent {
}
valuesToAdd.forEach((v) => {
if (values.indexOf(v) < 0) {
- options.push({ value: v, label: v });
+ const toAdd = { label: v };
+ toAdd[props.valueKey] = v;
+ options.push(toAdd);
}
});
}
}
+ if (props.allowAll === true && props.multi === true) {
+ if (options.findIndex(o => this.isMetaSelectAllOption(o)) < 0) {
+ options.unshift(this.createMetaSelectAllOption());
+ }
+ } else {
+ options = options.filter(o => !this.isMetaSelectAllOption(o));
+ }
return options;
}
+
+ isMetaSelectAllOption(o) {
+ return o.meta && o.meta === true && o.label === 'Select All';
+ }
+
+ createMetaSelectAllOption() {
+ const option = { label: 'Select All', meta: true };
+ option[this.props.valueKey] = 'Select All';
+ return option;
+ }
+
render() {
// Tab, comma or Enter will trigger a new option created for FreeFormSelect
const placeholder = this.props.placeholder || t('%s option(s)', this.state.options.length);
@@ -148,11 +192,23 @@ export default class SelectControl extends React.PureComponent {
optionRenderer: VirtualizedRendererWrap(this.props.optionRenderer),
valueRenderer: this.props.valueRenderer,
noResultsText: this.props.noResultsText,
- selectComponent: this.props.freeForm ? Creatable : Select,
disabled: this.props.disabled,
refFunc: this.props.refFunc,
filterOption: this.props.filterOption,
+ promptTextCreator: this.props.promptTextCreator,
};
+ if (this.props.freeForm) {
+ selectProps.selectComponent = Creatable;
+ selectProps.shouldKeyDownEventCreateNewOption = (key) => {
+ const keyCode = key.keyCode;
+ if (this.props.commaChoosesOption && keyCode === 188) {
+ return true;
+ }
+ return (keyCode === 9 || keyCode === 13);
+ };
+ } else {
+ selectProps.selectComponent = Select;
+ }
return (
{this.props.showHeader &&
diff --git a/superset/assets/src/explore/controls.jsx b/superset/assets/src/explore/controls.jsx
index c8d8b7bed96..b5589565eb2 100644
--- a/superset/assets/src/explore/controls.jsx
+++ b/superset/assets/src/explore/controls.jsx
@@ -120,6 +120,7 @@ const sortAxisChoices = [
const groupByControl = {
type: 'SelectControl',
multi: true,
+ freeForm: true,
label: t('Group by'),
default: [],
includeTime: false,
@@ -127,10 +128,12 @@ const groupByControl = {
optionRenderer: c => ,
valueRenderer: c => ,
valueKey: 'column_name',
+ allowAll: true,
filterOption: (opt, text) => (
- (opt.column_name && opt.column_name.toLowerCase().indexOf(text) >= 0) ||
- (opt.verbose_name && opt.verbose_name.toLowerCase().indexOf(text) >= 0)
+ (opt.column_name && opt.column_name.toLowerCase().indexOf(text.toLowerCase()) >= 0) ||
+ (opt.verbose_name && opt.verbose_name.toLowerCase().indexOf(text.toLowerCase()) >= 0)
),
+ promptTextCreator: label => label,
mapStateToProps: (state, control) => {
const newState = {};
if (state.datasource) {
@@ -141,6 +144,7 @@ const groupByControl = {
}
return newState;
},
+ commaChoosesOption: false,
};
const metrics = {
@@ -625,9 +629,12 @@ export const controls = {
optionRenderer: c => ,
valueRenderer: c => ,
valueKey: 'column_name',
+ allowAll: true,
mapStateToProps: state => ({
options: (state.datasource) ? state.datasource.columns : [],
}),
+ commaChoosesOption: false,
+ freeForm: true,
},
spatial: {
diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index 41b5dc75a11..64b38cfd508 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -15,6 +15,7 @@
# specific language governing permissions and limitations
# under the License.
# pylint: disable=C,R,W
+from collections import OrderedDict
from datetime import datetime
import logging
@@ -601,26 +602,25 @@ class SqlaTable(Model, BaseDatasource):
main_metric_expr = literal_column('COUNT(*)').label(label)
select_exprs = []
- groupby_exprs = []
+ groupby_exprs_sans_timestamp = OrderedDict()
if groupby:
select_exprs = []
- inner_select_exprs = []
- inner_groupby_exprs = []
for s in groupby:
- col = cols[s]
- outer = col.get_sqla_col()
- inner = col.get_sqla_col(col.column_name + '__')
+ if s in cols:
+ outer = cols[s].get_sqla_col()
+ else:
+ outer = literal_column(f'({s})').label(self.get_label(s))
- groupby_exprs.append(outer)
+ groupby_exprs_sans_timestamp[outer.name] = outer
select_exprs.append(outer)
- inner_groupby_exprs.append(inner)
- inner_select_exprs.append(inner)
elif columns:
for s in columns:
- select_exprs.append(cols[s].get_sqla_col())
+ select_exprs.append(
+ cols[s].get_sqla_col() if s in cols else literal_column(s))
metrics_exprs = []
+ groupby_exprs_with_timestamp = OrderedDict(groupby_exprs_sans_timestamp.items())
if granularity:
dttm_col = cols[granularity]
time_grain = extras.get('time_grain_sqla')
@@ -629,7 +629,7 @@ class SqlaTable(Model, BaseDatasource):
if is_timeseries:
timestamp = dttm_col.get_timestamp_expression(time_grain)
select_exprs += [timestamp]
- groupby_exprs += [timestamp]
+ groupby_exprs_with_timestamp[timestamp.name] = timestamp
# Use main dttm column to support index with secondary dttm columns
if db_engine_spec.time_secondary_columns and \
@@ -645,7 +645,7 @@ class SqlaTable(Model, BaseDatasource):
tbl = self.get_from_clause(template_processor)
if not columns:
- qry = qry.group_by(*groupby_exprs)
+ qry = qry.group_by(*groupby_exprs_with_timestamp.values())
where_clause_and = []
having_clause_and = []
@@ -725,9 +725,15 @@ class SqlaTable(Model, BaseDatasource):
# require a unique inner alias
label = self.get_label('mme_inner__')
inner_main_metric_expr = main_metric_expr.label(label)
+ inner_groupby_exprs = []
+ inner_select_exprs = []
+ for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
+ inner = gby_obj.label(gby_name + '__')
+ inner_groupby_exprs.append(inner)
+ inner_select_exprs.append(inner)
+
inner_select_exprs += [inner_main_metric_expr]
- subq = select(inner_select_exprs)
- subq = subq.select_from(tbl)
+ subq = select(inner_select_exprs).select_from(tbl)
inner_time_filter = dttm_col.get_time_filter(
inner_from_dttm or from_dttm,
inner_to_dttm or to_dttm,
@@ -751,12 +757,12 @@ class SqlaTable(Model, BaseDatasource):
subq = subq.limit(timeseries_limit)
on_clause = []
- for i, gb in enumerate(groupby):
+ for gby_name, gby_obj in groupby_exprs_sans_timestamp.items():
# in this case the column name, not the alias, needs to be
# conditionally mutated, as it refers to the column alias in
# the inner query
- col_name = self.get_label(gb + '__')
- on_clause.append(groupby_exprs[i] == column(col_name))
+ col_name = self.get_label(gby_name + '__')
+ on_clause.append(gby_obj == column(col_name))
tbl = tbl.join(subq.alias(), and_(*on_clause))
else:
@@ -778,24 +784,23 @@ class SqlaTable(Model, BaseDatasource):
'order_desc': True,
}
result = self.query(subquery_obj)
- cols = {col.column_name: col for col in self.columns}
dimensions = [
c for c in result.df.columns
- if c not in metrics and c in cols
+ if c not in metrics and c in groupby_exprs_sans_timestamp
]
- top_groups = self._get_top_groups(result.df, dimensions)
+ top_groups = self._get_top_groups(result.df,
+ dimensions,
+ groupby_exprs_sans_timestamp)
qry = qry.where(top_groups)
return qry.select_from(tbl)
- def _get_top_groups(self, df, dimensions):
- cols = {col.column_name: col for col in self.columns}
+ def _get_top_groups(self, df, dimensions, groupby_exprs):
groups = []
for unused, row in df.iterrows():
group = []
for dimension in dimensions:
- col_obj = cols.get(dimension)
- group.append(col_obj.get_sqla_col() == row[dimension])
+ group.append(groupby_exprs[dimension] == row[dimension])
groups.append(and_(*group))
return or_(*groups)
diff --git a/tests/base_tests.py b/tests/base_tests.py
index 1f3b4c5a95a..e50e8e0e0d1 100644
--- a/tests/base_tests.py
+++ b/tests/base_tests.py
@@ -26,6 +26,7 @@ from superset import app, db, is_feature_enabled, security_manager
from superset.connectors.druid.models import DruidCluster, DruidDatasource
from superset.connectors.sqla.models import SqlaTable
from superset.models import core as models
+from superset.models.core import Database
from superset.utils.core import get_main_database
BASE_DIR = app.config.get('BASE_DIR')
@@ -99,6 +100,9 @@ class SupersetTestCase(unittest.TestCase):
def get_table_by_name(self, name):
return db.session.query(SqlaTable).filter_by(table_name=name).one()
+ def get_database_by_id(self, db_id):
+ return db.session.query(Database).filter_by(id=db_id).one()
+
def get_druid_ds_by_name(self, name):
return db.session.query(DruidDatasource).filter_by(
datasource_name=name).first()
diff --git a/tests/core_tests.py b/tests/core_tests.py
index 514ef67f692..9b8983ab7f2 100644
--- a/tests/core_tests.py
+++ b/tests/core_tests.py
@@ -757,7 +757,6 @@ class CoreTests(SupersetTestCase):
{'form_data': json.dumps(form_data)},
)
self.assertEqual(data['status'], utils.QueryStatus.FAILED)
- assert 'KeyError' in data['stacktrace']
def test_slice_payload_viz_markdown(self):
self.login(username='admin')
diff --git a/tests/model_tests.py b/tests/model_tests.py
index 9a9fc70ab6e..cc2b30c485e 100644
--- a/tests/model_tests.py
+++ b/tests/model_tests.py
@@ -16,11 +16,12 @@
# under the License.
import textwrap
+import pandas
from sqlalchemy.engine.url import make_url
from superset import app, db
from superset.models.core import Database
-from superset.utils.core import get_main_database
+from superset.utils.core import get_main_database, QueryStatus
from .base_tests import SupersetTestCase
@@ -150,11 +151,13 @@ class SqlaTableModelTestCase(SupersetTestCase):
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(ds)')
+ prev_ds_expr = ds_col.expression
ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(DATE_ADD(ds, 1))')
+ ds_col.expression = prev_ds_expr
def test_get_timestamp_expression_epoch(self):
tbl = self.get_table_by_name('birth_names')
@@ -173,11 +176,13 @@ class SqlaTableModelTestCase(SupersetTestCase):
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(ds))')
+ prev_ds_expr = ds_col.expression
ds_col.expression = 'DATE_ADD(ds, 1)'
sqla_literal = ds_col.get_timestamp_expression('P1D')
compiled = '{}'.format(sqla_literal.compile())
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'DATE(from_unixtime(DATE_ADD(ds, 1)))')
+ ds_col.expression = prev_ds_expr
def test_get_timestamp_expression_backward(self):
tbl = self.get_table_by_name('birth_names')
@@ -197,6 +202,63 @@ class SqlaTableModelTestCase(SupersetTestCase):
if tbl.database.backend == 'mysql':
self.assertEquals(compiled, 'ds')
+ def query_with_expr_helper(self, is_timeseries, inner_join=True):
+ tbl = self.get_table_by_name('birth_names')
+ ds_col = tbl.get_column('ds')
+ ds_col.expression = None
+ ds_col.python_date_format = None
+ spec = self.get_database_by_id(tbl.database_id).db_engine_spec
+ if not spec.inner_joins and inner_join:
+ # if the db does not support inner joins, we cannot force it so
+ return None
+ old_inner_join = spec.inner_joins
+ spec.inner_joins = inner_join
+ arbitrary_gby = "state || gender || '_test'"
+ arbitrary_metric = (dict(label='arbitrary', expressionType='SQL',
+ sqlExpression='COUNT(1)'))
+ query_obj = dict(
+ groupby=[arbitrary_gby, 'name'],
+ metrics=[arbitrary_metric],
+ filter=[],
+ is_timeseries=is_timeseries,
+ prequeries=[],
+ columns=[],
+ granularity='ds',
+ from_dttm=None,
+ to_dttm=None,
+ is_prequery=False,
+ extras=dict(time_grain_sqla='P1Y'),
+ )
+ qr = tbl.query(query_obj)
+ self.assertEqual(qr.status, QueryStatus.SUCCESS)
+ sql = qr.query
+ self.assertIn(arbitrary_gby, sql)
+ self.assertIn('name', sql)
+ if inner_join and is_timeseries:
+ self.assertIn('JOIN', sql.upper())
+ else:
+ self.assertNotIn('JOIN', sql.upper())
+ spec.inner_joins = old_inner_join
+ self.assertIsNotNone(qr.df)
+ return qr.df
+
+ def test_query_with_expr_groupby_timeseries(self):
+ def cannonicalize_df(df):
+ ret = df.sort_values(by=list(df.columns.values), inplace=False)
+ ret.reset_index(inplace=True, drop=True)
+ return ret
+
+ df1 = self.query_with_expr_helper(is_timeseries=True, inner_join=True)
+ df2 = self.query_with_expr_helper(is_timeseries=True, inner_join=False)
+ self.assertIsNotNone(df2) # df1 can be none if the db does not support join
+ if df1 is not None:
+ pandas.testing.assert_frame_equal(
+ cannonicalize_df(df1),
+ cannonicalize_df(df2))
+
+ def test_query_with_expr_groupby(self):
+ self.query_with_expr_helper(is_timeseries=False)
+
def test_sql_mutator(self):
tbl = self.get_table_by_name('birth_names')
query_obj = dict(