"""Utility functions used across Superset""" from __future__ import absolute_import from __future__ import division from __future__ import print_function from __future__ import unicode_literals import decimal import functools import json import logging import numpy import os import parsedatetime import pytz import smtplib import sqlalchemy as sa import signal import uuid import sys import zlib from builtins import object from datetime import date, datetime, time import celery from dateutil.parser import parse from email.mime.text import MIMEText from email.mime.multipart import MIMEMultipart from email.mime.application import MIMEApplication from email.utils import formatdate from flask import flash, Markup, render_template, url_for, redirect, request from flask_appbuilder.const import ( LOGMSG_ERR_SEC_ACCESS_DENIED, FLAMSG_ERR_SEC_ACCESS_DENIED, PERMISSION_PREFIX ) from flask_cache import Cache from flask_appbuilder._compat import as_unicode from flask_babel import gettext as __ import markdown as md from past.builtins import basestring from pydruid.utils.having import Having from sqlalchemy import event, exc from sqlalchemy.types import TypeDecorator, TEXT logging.getLogger('MARKDOWN').setLevel(logging.INFO) PY3K = sys.version_info >= (3, 0) EPOCH = datetime(1970, 1, 1) DTTM_ALIAS = '__timestamp' class SupersetException(Exception): pass class SupersetTimeoutException(SupersetException): pass class SupersetSecurityException(SupersetException): pass class MetricPermException(SupersetException): pass class NoDataException(SupersetException): pass class SupersetTemplateException(SupersetException): pass def can_access(sm, permission_name, view_name, user): """Protecting from has_access failing from missing perms/view""" if user.is_anonymous(): return sm.is_item_public(permission_name, view_name) else: return sm._has_view_access(user, permission_name, view_name) def flasher(msg, severity=None): """Flask's flash if available, logging call if not""" try: flash(msg, severity) except RuntimeError: if severity == 'danger': logging.error(msg) else: logging.info(msg) class memoized(object): # noqa """Decorator that caches a function's return value each time it is called If called later with the same arguments, the cached value is returned, and not re-evaluated. """ def __init__(self, func): self.func = func self.cache = {} def __call__(self, *args): try: return self.cache[args] except KeyError: value = self.func(*args) self.cache[args] = value return value except TypeError: # uncachable -- for instance, passing a list as an argument. # Better to not cache than to blow up entirely. return self.func(*args) def __repr__(self): """Return the function's docstring.""" return self.func.__doc__ def __get__(self, obj, objtype): """Support instance methods.""" return functools.partial(self.__call__, obj) def js_string_to_python(item): return None if item in ('null', 'undefined') else item def string_to_num(s): """Converts a string to an int/float Returns ``None`` if it can't be converted >>> string_to_num('5') 5 >>> string_to_num('5.2') 5.2 >>> string_to_num(10) 10 >>> string_to_num(10.1) 10.1 >>> string_to_num('this is not a string') is None True """ if isinstance(s, (int, float)): return s if s.isdigit(): return int(s) try: return float(s) except ValueError: return None class DimSelector(Having): def __init__(self, **args): # Just a hack to prevent any exceptions Having.__init__(self, type='equalTo', aggregation=None, value=None) self.having = {'having': { 'type': 'dimSelector', 'dimension': args['dimension'], 'value': args['value'], }} def list_minus(l, minus): """Returns l without what is in minus >>> list_minus([1, 2, 3], [2]) [1, 3] """ return [o for o in l if o not in minus] def parse_human_datetime(s): """ Returns ``datetime.datetime`` from human readable strings >>> from datetime import date, timedelta >>> from dateutil.relativedelta import relativedelta >>> parse_human_datetime('2015-04-03') datetime.datetime(2015, 4, 3, 0, 0) >>> parse_human_datetime('2/3/1969') datetime.datetime(1969, 2, 3, 0, 0) >>> parse_human_datetime("now") <= datetime.now() True >>> parse_human_datetime("yesterday") <= datetime.now() True >>> date.today() - timedelta(1) == parse_human_datetime('yesterday').date() True >>> year_ago_1 = parse_human_datetime('one year ago').date() >>> year_ago_2 = (datetime.now() - relativedelta(years=1) ).date() >>> year_ago_1 == year_ago_2 True """ try: dttm = parse(s) except Exception: try: cal = parsedatetime.Calendar() dttm = dttm_from_timtuple(cal.parse(s)[0]) except Exception as e: logging.exception(e) raise ValueError("Couldn't parse date string [{}]".format(s)) return dttm def dttm_from_timtuple(d): return datetime( d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) def parse_human_timedelta(s): """ Returns ``datetime.datetime`` from natural language time deltas >>> parse_human_datetime("now") <= datetime.now() True """ cal = parsedatetime.Calendar() dttm = dttm_from_timtuple(datetime.now().timetuple()) d = cal.parse(s, dttm)[0] d = datetime( d.tm_year, d.tm_mon, d.tm_mday, d.tm_hour, d.tm_min, d.tm_sec) return d - dttm class JSONEncodedDict(TypeDecorator): """Represents an immutable structure as a json-encoded string.""" impl = TEXT def process_bind_param(self, value, dialect): if value is not None: value = json.dumps(value) return value def process_result_value(self, value, dialect): if value is not None: value = json.loads(value) return value def datetime_f(dttm): """Formats datetime to take less room when it is recent""" if dttm: dttm = dttm.isoformat() now_iso = datetime.now().isoformat() if now_iso[:10] == dttm[:10]: dttm = dttm[11:] elif now_iso[:4] == dttm[:4]: dttm = dttm[5:] return "{}".format(dttm) def base_json_conv(obj): if isinstance(obj, numpy.int64): return int(obj) elif isinstance(obj, numpy.bool_): return bool(obj) elif isinstance(obj, set): return list(obj) elif isinstance(obj, decimal.Decimal): return float(obj) elif isinstance(obj, uuid.UUID): return str(obj) def json_iso_dttm_ser(obj): """ json serializer that deals with dates >>> dttm = datetime(1970, 1, 1) >>> json.dumps({'dttm': dttm}, default=json_iso_dttm_ser) '{"dttm": "1970-01-01T00:00:00"}' """ val = base_json_conv(obj) if val is not None: return val if isinstance(obj, datetime): obj = obj.isoformat() elif isinstance(obj, date): obj = obj.isoformat() elif isinstance(obj, time): obj = obj.isoformat() else: raise TypeError( "Unserializable object {} of type {}".format(obj, type(obj)) ) return obj def datetime_to_epoch(dttm): if dttm.tzinfo: epoch_with_tz = pytz.utc.localize(EPOCH) return (dttm - epoch_with_tz).total_seconds() * 1000 return (dttm - EPOCH).total_seconds() * 1000 def now_as_float(): return datetime_to_epoch(datetime.utcnow()) def json_int_dttm_ser(obj): """json serializer that deals with dates""" val = base_json_conv(obj) if val is not None: return val if isinstance(obj, datetime): obj = datetime_to_epoch(obj) elif isinstance(obj, date): obj = (obj - EPOCH.date()).total_seconds() * 1000 else: raise TypeError( "Unserializable object {} of type {}".format(obj, type(obj)) ) return obj def json_dumps_w_dates(payload): return json.dumps(payload, default=json_int_dttm_ser) def error_msg_from_exception(e): """Translate exception into error message Database have different ways to handle exception. This function attempts to make sense of the exception object and construct a human readable sentence. TODO(bkyryliuk): parse the Presto error message from the connection created via create_engine. engine = create_engine('presto://localhost:3506/silver') - gives an e.message as the str(dict) presto.connect("localhost", port=3506, catalog='silver') - as a dict. The latter version is parsed correctly by this function. """ msg = '' if hasattr(e, 'message'): if type(e.message) is dict: msg = e.message.get('message') elif e.message: msg = "{}".format(e.message) return msg or '{}'.format(e) def markdown(s, markup_wrap=False): s = md.markdown(s or '', [ 'markdown.extensions.tables', 'markdown.extensions.fenced_code', 'markdown.extensions.codehilite', ]) if markup_wrap: s = Markup(s) return s def readfile(file_path): with open(file_path) as f: content = f.read() return content def generic_find_constraint_name(table, columns, referenced, db): """Utility to find a constraint name in alembic migrations""" t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) for fk in t.foreign_key_constraints: if ( fk.referred_table.name == referenced and set(fk.column_keys) == columns): return fk.name def get_datasource_full_name(database_name, datasource_name, schema=None): if not schema: return "[{}].[{}]".format(database_name, datasource_name) return "[{}].[{}].[{}]".format(database_name, schema, datasource_name) def get_schema_perm(database, schema): if schema: return "[{}].[{}]".format(database, schema) def validate_json(obj): if obj: try: json.loads(obj) except Exception: raise SupersetException("JSON is not valid") def table_has_constraint(table, name, db): """Utility to find a constraint name in alembic migrations""" t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine) for c in t.constraints: if c.name == name: return True return False class timeout(object): """ To be used in a ``with`` block and timeout its content. """ def __init__(self, seconds=1, error_message='Timeout'): self.seconds = seconds self.error_message = error_message def handle_timeout(self, signum, frame): logging.error("Process timed out") raise SupersetTimeoutException(self.error_message) def __enter__(self): try: signal.signal(signal.SIGALRM, self.handle_timeout) signal.alarm(self.seconds) except ValueError as e: logging.warning("timeout can't be used in the current context") logging.exception(e) def __exit__(self, type, value, traceback): try: signal.alarm(0) except ValueError as e: logging.warning("timeout can't be used in the current context") logging.exception(e) def pessimistic_connection_handling(target): @event.listens_for(target, "checkout") def ping_connection(dbapi_connection, connection_record, connection_proxy): """ Disconnect Handling - Pessimistic, taken from: http://docs.sqlalchemy.org/en/rel_0_9/core/pooling.html """ cursor = dbapi_connection.cursor() try: cursor.execute("SELECT 1") except: raise exc.DisconnectionError() cursor.close() class QueryStatus(object): """Enum-type class for query statuses""" STOPPED = 'stopped' FAILED = 'failed' PENDING = 'pending' RUNNING = 'running' SCHEDULED = 'scheduled' SUCCESS = 'success' TIMED_OUT = 'timed_out' def notify_user_about_perm_udate( granter, user, role, datasource, tpl_name, config): msg = render_template(tpl_name, granter=granter, user=user, role=role, datasource=datasource) logging.info(msg) subject = __('[Superset] Access to the datasource %(name)s was granted', name=datasource.full_name) send_email_smtp(user.email, subject, msg, config, bcc=granter.email, dryrun=not config.get('EMAIL_NOTIFICATIONS')) def send_email_smtp(to, subject, html_content, config, files=None, dryrun=False, cc=None, bcc=None, mime_subtype='mixed'): """ Send an email with html content, eg: send_email_smtp( 'test@example.com', 'foo', 'Foo bar',['/dev/null'], dryrun=True) """ smtp_mail_from = config.get('SMTP_MAIL_FROM') to = get_email_address_list(to) msg = MIMEMultipart(mime_subtype) msg['Subject'] = subject msg['From'] = smtp_mail_from msg['To'] = ", ".join(to) recipients = to if cc: cc = get_email_address_list(cc) msg['CC'] = ", ".join(cc) recipients = recipients + cc if bcc: # don't add bcc in header bcc = get_email_address_list(bcc) recipients = recipients + bcc msg['Date'] = formatdate(localtime=True) mime_text = MIMEText(html_content, 'html') msg.attach(mime_text) for fname in files or []: basename = os.path.basename(fname) with open(fname, "rb") as f: msg.attach(MIMEApplication( f.read(), Content_Disposition='attachment; filename="%s"' % basename, Name=basename )) send_MIME_email(smtp_mail_from, recipients, msg, config, dryrun=dryrun) def send_MIME_email(e_from, e_to, mime_msg, config, dryrun=False): SMTP_HOST = config.get('SMTP_HOST') SMTP_PORT = config.get('SMTP_PORT') SMTP_USER = config.get('SMTP_USER') SMTP_PASSWORD = config.get('SMTP_PASSWORD') SMTP_STARTTLS = config.get('SMTP_STARTTLS') SMTP_SSL = config.get('SMTP_SSL') if not dryrun: s = smtplib.SMTP_SSL(SMTP_HOST, SMTP_PORT) if SMTP_SSL else \ smtplib.SMTP(SMTP_HOST, SMTP_PORT) if SMTP_STARTTLS: s.starttls() if SMTP_USER and SMTP_PASSWORD: s.login(SMTP_USER, SMTP_PASSWORD) logging.info("Sent an alert email to " + str(e_to)) s.sendmail(e_from, e_to, mime_msg.as_string()) s.quit() else: logging.info('Dryrun enabled, email notification content is below:') logging.info(mime_msg.as_string()) def get_email_address_list(address_string): if isinstance(address_string, basestring): if ',' in address_string: address_string = address_string.split(',') elif ';' in address_string: address_string = address_string.split(';') else: address_string = [address_string] return address_string def has_access(f): """ Use this decorator to enable granular security permissions to your methods. Permissions will be associated to a role, and roles are associated to users. By default the permission's name is the methods name. Forked from the flask_appbuilder.security.decorators TODO(bkyryliuk): contribute it back to FAB """ if hasattr(f, '_permission_name'): permission_str = f._permission_name else: permission_str = f.__name__ def wraps(self, *args, **kwargs): permission_str = PERMISSION_PREFIX + f._permission_name if self.appbuilder.sm.has_access( permission_str, self.__class__.__name__): return f(self, *args, **kwargs) else: logging.warning(LOGMSG_ERR_SEC_ACCESS_DENIED.format( permission_str, self.__class__.__name__)) flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger") # adds next arg to forward to the original path once user is logged in. return redirect(url_for( self.appbuilder.sm.auth_view.__class__.__name__ + ".login", next=request.path)) f._permission_name = permission_str return functools.update_wrapper(wraps, f) def choicify(values): """Takes an iterable and makes an iterable of tuples with it""" return [(v, v) for v in values] def setup_cache(app, cache_config): """Setup the flask-cache on a flask app""" if cache_config and cache_config.get('CACHE_TYPE') != 'null': return Cache(app, config=cache_config) def zlib_compress(data): """ Compress things in a py2/3 safe fashion >>> json_str = '{"test": 1}' >>> blob = zlib_compress(json_str) """ if PY3K: if isinstance(data, str): return zlib.compress(bytes(data, "utf-8")) return zlib.compress(data) return zlib.compress(data) def zlib_decompress_to_string(blob): """ Decompress things to a string in a py2/3 safe fashion >>> json_str = '{"test": 1}' >>> blob = zlib_compress(json_str) >>> got_str = zlib_decompress_to_string(blob) >>> got_str == json_str True """ if PY3K: if isinstance(blob, bytes): decompressed = zlib.decompress(blob) else: decompressed = zlib.decompress(bytes(blob, "utf-8")) return decompressed.decode("utf-8") return zlib.decompress(blob) _celery_app = None def get_celery_app(config): global _celery_app if _celery_app: return _celery_app _celery_app = celery.Celery(config_source=config.get('CELERY_CONFIG')) return _celery_app