mirror of
https://github.com/apache/superset.git
synced 2026-04-20 00:24:38 +00:00
[format] Using Black (#7769)
This commit is contained in:
@@ -30,11 +30,10 @@ from superset.models import core as models
|
||||
from superset.models.core import Database
|
||||
from superset.utils.core import get_main_database
|
||||
|
||||
BASE_DIR = app.config.get('BASE_DIR')
|
||||
BASE_DIR = app.config.get("BASE_DIR")
|
||||
|
||||
|
||||
class SupersetTestCase(unittest.TestCase):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SupersetTestCase, self).__init__(*args, **kwargs)
|
||||
self.client = app.test_client()
|
||||
@@ -45,34 +44,25 @@ class SupersetTestCase(unittest.TestCase):
|
||||
# create druid cluster and druid datasources
|
||||
session = db.session
|
||||
cluster = (
|
||||
session.query(DruidCluster)
|
||||
.filter_by(cluster_name='druid_test')
|
||||
.first()
|
||||
session.query(DruidCluster).filter_by(cluster_name="druid_test").first()
|
||||
)
|
||||
if not cluster:
|
||||
cluster = DruidCluster(cluster_name='druid_test')
|
||||
cluster = DruidCluster(cluster_name="druid_test")
|
||||
session.add(cluster)
|
||||
session.commit()
|
||||
|
||||
druid_datasource1 = DruidDatasource(
|
||||
datasource_name='druid_ds_1',
|
||||
cluster_name='druid_test',
|
||||
datasource_name="druid_ds_1", cluster_name="druid_test"
|
||||
)
|
||||
session.add(druid_datasource1)
|
||||
druid_datasource2 = DruidDatasource(
|
||||
datasource_name='druid_ds_2',
|
||||
cluster_name='druid_test',
|
||||
datasource_name="druid_ds_2", cluster_name="druid_test"
|
||||
)
|
||||
session.add(druid_datasource2)
|
||||
session.commit()
|
||||
|
||||
def get_table(self, table_id):
|
||||
return (
|
||||
db.session
|
||||
.query(SqlaTable)
|
||||
.filter_by(id=table_id)
|
||||
.one()
|
||||
)
|
||||
return db.session.query(SqlaTable).filter_by(id=table_id).one()
|
||||
|
||||
@staticmethod
|
||||
def is_module_installed(module_name):
|
||||
@@ -91,18 +81,12 @@ class SupersetTestCase(unittest.TestCase):
|
||||
session.commit()
|
||||
return obj
|
||||
|
||||
def login(self, username='admin', password='general'):
|
||||
resp = self.get_resp(
|
||||
'/login/',
|
||||
data=dict(username=username, password=password))
|
||||
self.assertNotIn('User confirmation needed', resp)
|
||||
def login(self, username="admin", password="general"):
|
||||
resp = self.get_resp("/login/", data=dict(username=username, password=password))
|
||||
self.assertNotIn("User confirmation needed", resp)
|
||||
|
||||
def get_slice(self, slice_name, session):
|
||||
slc = (
|
||||
session.query(models.Slice)
|
||||
.filter_by(slice_name=slice_name)
|
||||
.one()
|
||||
)
|
||||
slc = session.query(models.Slice).filter_by(slice_name=slice_name).one()
|
||||
session.expunge_all()
|
||||
return slc
|
||||
|
||||
@@ -113,8 +97,7 @@ class SupersetTestCase(unittest.TestCase):
|
||||
return db.session.query(Database).filter_by(id=db_id).one()
|
||||
|
||||
def get_druid_ds_by_name(self, name):
|
||||
return db.session.query(DruidDatasource).filter_by(
|
||||
datasource_name=name).first()
|
||||
return db.session.query(DruidDatasource).filter_by(datasource_name=name).first()
|
||||
|
||||
def get_datasource_mock(self):
|
||||
datasource = Mock()
|
||||
@@ -123,7 +106,7 @@ class SupersetTestCase(unittest.TestCase):
|
||||
results.status = Mock()
|
||||
results.error_message = None
|
||||
results.df = pd.DataFrame()
|
||||
datasource.type = 'table'
|
||||
datasource.type = "table"
|
||||
datasource.query = Mock(return_value=results)
|
||||
mock_dttm_col = Mock()
|
||||
datasource.get_col = Mock(return_value=mock_dttm_col)
|
||||
@@ -133,21 +116,17 @@ class SupersetTestCase(unittest.TestCase):
|
||||
datasource.database.db_engine_spec.mutate_expression_label = lambda x: x
|
||||
return datasource
|
||||
|
||||
def get_resp(
|
||||
self, url, data=None, follow_redirects=True, raise_on_error=True):
|
||||
def get_resp(self, url, data=None, follow_redirects=True, raise_on_error=True):
|
||||
"""Shortcut to get the parsed results while following redirects"""
|
||||
if data:
|
||||
resp = self.client.post(
|
||||
url, data=data, follow_redirects=follow_redirects)
|
||||
resp = self.client.post(url, data=data, follow_redirects=follow_redirects)
|
||||
else:
|
||||
resp = self.client.get(url, follow_redirects=follow_redirects)
|
||||
if raise_on_error and resp.status_code > 400:
|
||||
raise Exception(
|
||||
'http request failed with code {}'.format(resp.status_code))
|
||||
return resp.data.decode('utf-8')
|
||||
raise Exception("http request failed with code {}".format(resp.status_code))
|
||||
return resp.data.decode("utf-8")
|
||||
|
||||
def get_json_resp(
|
||||
self, url, data=None, follow_redirects=True, raise_on_error=True):
|
||||
def get_json_resp(self, url, data=None, follow_redirects=True, raise_on_error=True):
|
||||
"""Shortcut to get the parsed results while following redirects"""
|
||||
resp = self.get_resp(url, data, follow_redirects, raise_on_error)
|
||||
return json.loads(resp)
|
||||
@@ -165,66 +144,82 @@ class SupersetTestCase(unittest.TestCase):
|
||||
)
|
||||
|
||||
def logout(self):
|
||||
self.client.get('/logout/', follow_redirects=True)
|
||||
self.client.get("/logout/", follow_redirects=True)
|
||||
|
||||
def grant_public_access_to_table(self, table):
|
||||
public_role = security_manager.find_role('Public')
|
||||
public_role = security_manager.find_role("Public")
|
||||
perms = db.session.query(ab_models.PermissionView).all()
|
||||
for perm in perms:
|
||||
if (perm.permission.name == 'datasource_access' and
|
||||
perm.view_menu and table.perm in perm.view_menu.name):
|
||||
if (
|
||||
perm.permission.name == "datasource_access"
|
||||
and perm.view_menu
|
||||
and table.perm in perm.view_menu.name
|
||||
):
|
||||
security_manager.add_permission_role(public_role, perm)
|
||||
|
||||
def revoke_public_access_to_table(self, table):
|
||||
public_role = security_manager.find_role('Public')
|
||||
public_role = security_manager.find_role("Public")
|
||||
perms = db.session.query(ab_models.PermissionView).all()
|
||||
for perm in perms:
|
||||
if (perm.permission.name == 'datasource_access' and
|
||||
perm.view_menu and table.perm in perm.view_menu.name):
|
||||
if (
|
||||
perm.permission.name == "datasource_access"
|
||||
and perm.view_menu
|
||||
and table.perm in perm.view_menu.name
|
||||
):
|
||||
security_manager.del_permission_role(public_role, perm)
|
||||
|
||||
def get_main_database(self):
|
||||
return get_main_database(db.session)
|
||||
|
||||
def run_sql(self, sql, client_id=None, user_name=None, raise_on_error=False,
|
||||
query_limit=None):
|
||||
def run_sql(
|
||||
self,
|
||||
sql,
|
||||
client_id=None,
|
||||
user_name=None,
|
||||
raise_on_error=False,
|
||||
query_limit=None,
|
||||
):
|
||||
if user_name:
|
||||
self.logout()
|
||||
self.login(username=(user_name if user_name else 'admin'))
|
||||
self.login(username=(user_name if user_name else "admin"))
|
||||
dbid = self.get_main_database().id
|
||||
resp = self.get_json_resp(
|
||||
'/superset/sql_json/',
|
||||
"/superset/sql_json/",
|
||||
raise_on_error=False,
|
||||
data=dict(database_id=dbid, sql=sql, select_as_create_as=False,
|
||||
client_id=client_id, queryLimit=query_limit),
|
||||
data=dict(
|
||||
database_id=dbid,
|
||||
sql=sql,
|
||||
select_as_create_as=False,
|
||||
client_id=client_id,
|
||||
queryLimit=query_limit,
|
||||
),
|
||||
)
|
||||
if raise_on_error and 'error' in resp:
|
||||
raise Exception('run_sql failed')
|
||||
if raise_on_error and "error" in resp:
|
||||
raise Exception("run_sql failed")
|
||||
return resp
|
||||
|
||||
def validate_sql(self, sql, client_id=None, user_name=None,
|
||||
raise_on_error=False):
|
||||
def validate_sql(self, sql, client_id=None, user_name=None, raise_on_error=False):
|
||||
if user_name:
|
||||
self.logout()
|
||||
self.login(username=(user_name if user_name else 'admin'))
|
||||
self.login(username=(user_name if user_name else "admin"))
|
||||
dbid = self.get_main_database().id
|
||||
resp = self.get_json_resp(
|
||||
'/superset/validate_sql_json/',
|
||||
"/superset/validate_sql_json/",
|
||||
raise_on_error=False,
|
||||
data=dict(database_id=dbid, sql=sql, client_id=client_id),
|
||||
)
|
||||
if raise_on_error and 'error' in resp:
|
||||
raise Exception('validate_sql failed')
|
||||
if raise_on_error and "error" in resp:
|
||||
raise Exception("validate_sql failed")
|
||||
return resp
|
||||
|
||||
@patch.dict('superset._feature_flags', {'FOO': True}, clear=True)
|
||||
@patch.dict("superset._feature_flags", {"FOO": True}, clear=True)
|
||||
def test_existing_feature_flags(self):
|
||||
self.assertTrue(is_feature_enabled('FOO'))
|
||||
self.assertTrue(is_feature_enabled("FOO"))
|
||||
|
||||
@patch.dict('superset._feature_flags', {}, clear=True)
|
||||
@patch.dict("superset._feature_flags", {}, clear=True)
|
||||
def test_nonexistent_feature_flags(self):
|
||||
self.assertFalse(is_feature_enabled('FOO'))
|
||||
self.assertFalse(is_feature_enabled("FOO"))
|
||||
|
||||
def test_feature_flags(self):
|
||||
self.assertEquals(is_feature_enabled('foo'), 'bar')
|
||||
self.assertEquals(is_feature_enabled('super'), 'set')
|
||||
self.assertEquals(is_feature_enabled("foo"), "bar")
|
||||
self.assertEquals(is_feature_enabled("super"), "set")
|
||||
|
||||
Reference in New Issue
Block a user