Alternative PR for: Some bytes/str issues in py3 w/ zlib and json (#2558)

* sql_lab.py: compress via utils

* utils.py: added zlib_compress and zlib_compress_to_string

* core.py: converted to use zlib_decompress_to_string; renamed uncompress to decompress in utils.py

* utils_tests.py: added test for compress/decompress

* fixed broken utils test; removed redundant code and empty lines from utils.py

* utils.py: corrected docstrings, removed unnecessary 'else'

* removed yet another superfluous else
This commit is contained in:
rumbin
2017-04-06 18:42:43 +02:00
committed by Maxime Beauchemin
parent f19d1958c5
commit c581ea8661
4 changed files with 45 additions and 7 deletions

View File

@@ -6,7 +6,6 @@ import logging
import pandas as pd
import sqlalchemy
import uuid
import zlib
from sqlalchemy.pool import NullPool
from sqlalchemy.orm import sessionmaker
@@ -185,7 +184,7 @@ def get_sql_results(self, query_id, return_results=True, store_results=False):
if store_results:
key = '{}'.format(uuid.uuid4())
logging.info("Storing results in results backend, key: {}".format(key))
results_backend.set(key, zlib.compress(payload))
results_backend.set(key, utils.zlib_compress(payload))
query.results_key = key
session.merge(query)

View File

@@ -16,6 +16,8 @@ import smtplib
import sqlalchemy as sa
import signal
import uuid
import sys
import zlib
from builtins import object
from datetime import date, datetime, time
@@ -41,7 +43,7 @@ from sqlalchemy.types import TypeDecorator, TEXT
logging.getLogger('MARKDOWN').setLevel(logging.INFO)
PY3K = sys.version_info >= (3, 0)
EPOCH = datetime(1970, 1, 1)
DTTM_ALIAS = '__timestamp'
@@ -572,3 +574,34 @@ def setup_cache(app, cache_config):
"""Setup the flask-cache on a flask app"""
if cache_config and cache_config.get('CACHE_TYPE') != 'null':
return Cache(app, config=cache_config)
def zlib_compress(data):
"""
Compress things in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
>>> blob = zlib_compress(json_str)
"""
if PY3K:
if isinstance(data, str):
return zlib.compress(bytes(data, "utf-8"))
return zlib.compress(data)
return zlib.compress(data)
def zlib_decompress_to_string(blob):
"""
Decompress things to a string in a py2/3 safe fashion
>>> json_str = '{"test": 1}'
>>> blob = zlib_compress(json_str)
>>> got_str = zlib_decompress_to_string(blob)
>>> got_str == json_str
True
"""
if PY3K:
if isinstance(blob, bytes):
decompressed = zlib.decompress(blob)
else:
decompressed = zlib.decompress(bytes(blob, "utf-8"))
return decompressed.decode("utf-8")
return zlib.decompress(blob)

View File

@@ -11,7 +11,6 @@ import pickle
import re
import time
import traceback
import zlib
import sqlalchemy as sqla
@@ -1878,7 +1877,7 @@ class Superset(BaseSupersetView):
return json_error_response(get_datasource_access_error_msg(
'{}'.format(rejected_tables)))
payload = zlib.decompress(blob)
payload = utils.zlib_decompress_to_string(blob)
display_limit = app.config.get('DISPLAY_SQL_MAX_ROW', None)
if display_limit:
payload_json = json.loads(payload)
@@ -2018,7 +2017,7 @@ class Superset(BaseSupersetView):
if results_backend and query.results_key:
blob = results_backend.get(query.results_key)
if blob:
json_payload = zlib.decompress(blob)
json_payload = utils.zlib_decompress_to_string(blob)
obj = json.loads(json_payload)
columns = [c['name'] for c in obj['columns']]
df = pd.DataFrame.from_records(obj['data'], columns=columns)

View File

@@ -1,7 +1,7 @@
from datetime import datetime, date, timedelta, time
from decimal import Decimal
from superset.utils import (
json_int_dttm_ser, json_iso_dttm_ser, base_json_conv, parse_human_timedelta
json_int_dttm_ser, json_iso_dttm_ser, base_json_conv, parse_human_timedelta, zlib_compress, zlib_decompress_to_string
)
import unittest
import uuid
@@ -45,3 +45,10 @@ class UtilsTestCase(unittest.TestCase):
def test_parse_human_timedelta(self, mock_now):
mock_now.return_value = datetime(2016, 12, 1)
self.assertEquals(parse_human_timedelta('now'), timedelta(0))
def test_zlib_compression(self):
json_str = """{"test": 1}"""
blob = zlib_compress(json_str)
got_str = zlib_decompress_to_string(blob)
self.assertEquals(json_str, got_str)