diff --git a/superset/assets/package.json b/superset/assets/package.json
index c944ad2fa0b..abc978c079d 100644
--- a/superset/assets/package.json
+++ b/superset/assets/package.json
@@ -93,8 +93,8 @@
"react-sortable-hoc": "^0.6.7",
"react-split-pane": "^0.1.66",
"react-syntax-highlighter": "^5.7.0",
- "react-virtualized": "^9.3.0",
- "react-virtualized-select": "^2.4.0",
+ "react-virtualized": "9.3.0",
+ "react-virtualized-select": "2.4.0",
"reactable": "^0.14.1",
"redux": "^3.5.2",
"redux-localstorage": "^0.4.1",
diff --git a/superset/models/core.py b/superset/models/core.py
index 142482bdbd7..cfd6d752037 100644
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -230,7 +230,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
name = escape(self.slice_name)
return Markup('{name}'.format(**locals()))
- def get_viz(self):
+ def get_viz(self, force=False):
"""Creates :py:class:viz.BaseViz object from the url_params_multidict.
:return: object of the 'viz_type' type that is taken from the
@@ -246,6 +246,7 @@ class Slice(Model, AuditMixinNullable, ImportMixin):
return viz_types[slice_params.get('viz_type')](
self.datasource,
form_data=slice_params,
+ force=force,
)
@classmethod
diff --git a/superset/views/core.py b/superset/views/core.py
index 44b6b2a5a42..7fb96a9bc13 100755
--- a/superset/views/core.py
+++ b/superset/views/core.py
@@ -954,7 +954,9 @@ class Superset(BaseSupersetView):
slice_id=None,
form_data=None,
datasource_type=None,
- datasource_id=None):
+ datasource_id=None,
+ force=False,
+ ):
if slice_id:
slc = (
db.session.query(models.Slice)
@@ -969,6 +971,7 @@ class Superset(BaseSupersetView):
viz_obj = viz.viz_types[viz_type](
datasource,
form_data=form_data,
+ force=force,
)
return viz_obj
@@ -1017,7 +1020,9 @@ class Superset(BaseSupersetView):
viz_obj = self.get_viz(
datasource_type=datasource_type,
datasource_id=datasource_id,
- form_data=form_data)
+ form_data=form_data,
+ force=force,
+ )
except Exception as e:
logging.exception(e)
return json_error_response(
@@ -1038,7 +1043,7 @@ class Superset(BaseSupersetView):
return self.get_query_string_response(viz_obj)
try:
- payload = viz_obj.get_payload(force=force)
+ payload = viz_obj.get_payload()
except Exception as e:
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))
@@ -1082,9 +1087,10 @@ class Superset(BaseSupersetView):
viz_obj = viz.viz_types['table'](
datasource,
form_data=form_data,
+ force=False,
)
try:
- payload = viz_obj.get_payload(force=False)
+ payload = viz_obj.get_payload()
except Exception as e:
logging.exception(e)
return json_error_response(utils.error_msg_from_exception(e))
@@ -1876,8 +1882,8 @@ class Superset(BaseSupersetView):
for slc in slices:
try:
- obj = slc.get_viz()
- obj.get_json(force=True)
+ obj = slc.get_viz(force=True)
+ obj.get_json()
except Exception as e:
return json_error_response(utils.error_msg_from_exception(e))
return json_success(json.dumps(
diff --git a/superset/viz.py b/superset/viz.py
index ebd0c788b9d..d66884a3b6a 100644
--- a/superset/viz.py
+++ b/superset/viz.py
@@ -18,10 +18,9 @@ import logging
import math
import traceback
import uuid
-import zlib
from dateutil import relativedelta as rdelta
-from flask import request
+from flask import escape, request
from flask_babel import lazy_gettext as _
import geohash
from markdown import markdown
@@ -30,8 +29,8 @@ import pandas as pd
from pandas.tseries.frequencies import to_offset
import polyline
import simplejson as json
-from six import PY3, string_types, text_type
-from six.moves import reduce
+from six import string_types, text_type
+from six.moves import cPickle as pkl, reduce
from superset import app, cache, get_manifest_file, utils
from superset.utils import DTTM_ALIAS, merge_extra_filters
@@ -49,8 +48,9 @@ class BaseViz(object):
credits = ''
is_timeseries = False
default_fillna = 0
+ cache_type = 'df'
- def __init__(self, datasource, form_data):
+ def __init__(self, datasource, form_data, force=False):
if not datasource:
raise Exception(_('Viz is missing a datasource'))
self.datasource = datasource
@@ -67,6 +67,40 @@ class BaseViz(object):
self.status = None
self.error_message = None
+ self.force = force
+
+ # Keeping track of whether some data came from cache;
+ # this is useful to flag the payload as cached
+ # in cases where a visualization runs many queries
+ # (FilterBox for instance)
+ self._some_from_cache = False
+ self._any_cache_key = None
+ self._any_cached_dttm = None
+
+ self.run_extra_queries()
+
+ def run_extra_queries(self):
+ """Lifecycle method to use when more than one query is needed
+
+ In rare-ish cases, a visualization may need to execute multiple
+ queries. That is the case for FilterBox or for time comparison
+ in Line chart for instance.
+
+ In those cases, we need to make sure these queries run before the
+ main `get_payload` method gets called, so that the overall caching
+ metadata can be right. The way it works here is that if any of
+ the previous `get_df_payload` calls hit the cache, the main
+ payload's metadata will reflect that.
+
+ The multi-query support may need more work to become a first class
+ use case in the framework, and for the UI to reflect the subtleties
+ (show that only some of the queries were served from cache for
+ instance). In the meantime, since multi-query is rare, we treat
+ it with a bit of a hack. Note that the hack became necessary
+ when moving from caching the visualization's data itself, to caching
+ the underlying query(ies).
+ """
+ pass
def get_fillna_for_col(self, col):
"""Returns the value for use as filler for a specific Column.type"""
@@ -225,9 +259,9 @@ class BaseViz(object):
return self.datasource.database.cache_timeout
return config.get('CACHE_DEFAULT_TIMEOUT')
- def get_json(self, force=False):
+ def get_json(self):
return json.dumps(
- self.get_payload(force),
+ self.get_payload(),
default=utils.json_int_dttm_ser, ignore_nan=True)
def cache_key(self, query_obj):
@@ -249,64 +283,73 @@ class BaseViz(object):
json_data = self.json_dumps(cache_dict, sort_keys=True)
return hashlib.md5(json_data.encode('utf-8')).hexdigest()
- def get_payload(self, force=False):
- """Handles caching around the json payload retrieval"""
- query_obj = self.query_obj()
+ def get_payload(self, query_obj=None):
+ """Returns a payload of metadata and data"""
+ payload = self.get_df_payload(query_obj)
+ df = payload['df']
+ if df is not None:
+ payload['data'] = self.get_data(df)
+ del payload['df']
+ return payload
+
+ def get_df_payload(self, query_obj=None):
+ """Handles caching around the df payload retrieval"""
+ if not query_obj:
+ query_obj = self.query_obj()
cache_key = self.cache_key(query_obj) if query_obj else None
- cached_dttm = None
- data = None
+ logging.info('Cache key: {}'.format(cache_key))
+ is_loaded = False
stacktrace = None
- rowcount = None
- if cache_key and cache and not force:
+ df = None
+ cached_dttm = datetime.utcnow().isoformat().split('.')[0]
+ if cache_key and cache and not self.force:
cache_value = cache.get(cache_key)
if cache_value:
stats_logger.incr('loaded_from_cache')
- is_cached = True
try:
- cache_value = zlib.decompress(cache_value)
- if PY3:
- cache_value = cache_value.decode('utf-8')
- cache_value = json.loads(cache_value)
- data = cache_value['data']
- cached_dttm = cache_value['dttm']
+ cache_value = pkl.loads(cache_value)
+ df = cache_value['df']
+ is_loaded = True
+ self._any_cache_key = cache_key
+ self._any_cached_dttm = cache_value['dttm']
except Exception as e:
+ logging.exception(e)
logging.error('Error reading cache: ' +
utils.error_msg_from_exception(e))
- data = None
logging.info('Serving from cache')
- if not data:
- stats_logger.incr('loaded_from_source')
- is_cached = False
+ if query_obj and not is_loaded:
try:
df = self.get_df(query_obj)
- if not self.error_message:
- data = self.get_data(df)
- rowcount = len(df.index) if df is not None else 0
+ stats_logger.incr('loaded_from_source')
+ is_loaded = True
except Exception as e:
logging.exception(e)
if not self.error_message:
- self.error_message = str(e)
+ self.error_message = escape('{}'.format(e))
self.status = utils.QueryStatus.FAILED
- data = None
stacktrace = traceback.format_exc()
if (
- data and
+ is_loaded and
cache_key and
cache and
self.status != utils.QueryStatus.FAILED):
- cached_dttm = datetime.utcnow().isoformat().split('.')[0]
try:
- cache_value = self.json_dumps({
- 'data': data,
- 'dttm': cached_dttm,
- })
- if PY3:
- cache_value = bytes(cache_value, 'utf-8')
+ cache_value = dict(
+ dttm=cached_dttm,
+ df=df if df is not None else None,
+ )
+ cache_value = pkl.dumps(
+ cache_value, protocol=pkl.HIGHEST_PROTOCOL)
+
+ logging.info('Caching {} chars at key {}'.format(
+ len(cache_value), cache_key))
+
+ stats_logger.incr('set_cache_key')
cache.set(
cache_key,
- zlib.compress(cache_value),
+ cache_value,
timeout=self.cache_timeout)
except Exception as e:
# cache.set call can fail if the backend is down or if
@@ -316,17 +359,17 @@ class BaseViz(object):
cache.delete(cache_key)
return {
- 'cache_key': cache_key,
- 'cached_dttm': cached_dttm,
+ 'cache_key': self._any_cache_key,
+ 'cached_dttm': self._any_cached_dttm,
'cache_timeout': self.cache_timeout,
- 'data': data,
+ 'df': df,
'error': self.error_message,
'form_data': self.form_data,
- 'is_cached': is_cached,
+ 'is_cached': self._any_cache_key is not None,
'query': self.query,
'status': self.status,
'stacktrace': stacktrace,
- 'rowcount': rowcount,
+ 'rowcount': len(df.index) if df is not None else 0,
}
def json_dumps(self, obj, sort_keys=False):
@@ -415,7 +458,11 @@ class TableViz(BaseViz):
def get_data(self, df):
fd = self.form_data
- if not self.should_be_timeseries() and DTTM_ALIAS in df:
+ if (
+ not self.should_be_timeseries() and
+ df is not None and
+ DTTM_ALIAS in df
+ ):
del df[DTTM_ALIAS]
# Sum up and compute percentages for all percent metrics
@@ -1062,12 +1109,10 @@ class NVD3TimeSeriesViz(NVD3Viz):
df = df[num_period_compare:]
return df
- def get_data(self, df):
+ def run_extra_queries(self):
fd = self.form_data
- df = self.process_data(df)
- chart_data = self.to_series(df)
-
time_compare = fd.get('time_compare')
+ self.extra_chart_data = None
if time_compare:
query_object = self.query_obj()
delta = utils.parse_human_timedelta(time_compare)
@@ -1081,12 +1126,20 @@ class NVD3TimeSeriesViz(NVD3Viz):
query_object['from_dttm'] -= delta
query_object['to_dttm'] -= delta
- df2 = self.get_df(query_object)
+ df2 = self.get_df_payload(query_object).get('df')
df2[DTTM_ALIAS] += delta
df2 = self.process_data(df2)
- chart_data += self.to_series(
+ self.extra_chart_data = self.to_series(
df2, classed='superset', title_suffix='---')
+
+ def get_data(self, df):
+ df = self.process_data(df)
+ chart_data = self.to_series(df)
+
+ if self.extra_chart_data:
+ chart_data += self.extra_chart_data
chart_data = sorted(chart_data, key=lambda x: x['key'])
+
return chart_data
@@ -1564,10 +1617,20 @@ class FilterBoxViz(BaseViz):
verbose_name = _('Filters')
is_timeseries = False
credits = 'a Superset original'
+ cache_type = 'get_data'
def query_obj(self):
return None
+ def run_extra_queries(self):
+ qry = self.filter_query_obj()
+ filters = [g for g in self.form_data['groupby']]
+ self.dataframes = {}
+ for flt in filters:
+ qry['groupby'] = [flt]
+ df = self.get_df_payload(query_obj=qry).get('df')
+ self.dataframes[flt] = df
+
def filter_query_obj(self):
qry = super(FilterBoxViz, self).query_obj()
groupby = self.form_data.get('groupby')
@@ -1578,12 +1641,10 @@ class FilterBoxViz(BaseViz):
return qry
def get_data(self, df):
- qry = self.filter_query_obj()
- filters = [g for g in self.form_data['groupby']]
d = {}
+ filters = [g for g in self.form_data['groupby']]
for flt in filters:
- qry['groupby'] = [flt]
- df = super(FilterBoxViz, self).get_df(qry)
+ df = self.dataframes[flt]
d[flt] = [{
'id': row[0],
'text': row[0],