chore: enable lint PT009 'use regular assert over self.assert.*' (#30521)

This commit is contained in:
Maxime Beauchemin
2024-10-07 13:17:27 -07:00
committed by GitHub
parent 1f013055d2
commit a849c29288
62 changed files with 2218 additions and 2422 deletions

View File

@@ -187,7 +187,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
expected_columns = [
"allow_ctas",
@@ -216,8 +216,8 @@ class TestDatabaseApi(SupersetTestCase):
"uuid",
]
self.assertGreater(response["count"], 0)
self.assertEqual(list(response["result"][0].keys()), expected_columns)
assert response["count"] > 0
assert list(response["result"][0].keys()) == expected_columns
def test_get_items_filter(self):
"""
@@ -241,8 +241,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/?q={prison.dumps(arguments)}"
rv = self.client.get(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["count"], len(dbs))
assert rv.status_code == 200
assert response["count"] == len(dbs)
# Cleanup
db.session.delete(test_database)
@@ -255,9 +255,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(GAMMA_USERNAME)
uri = "api/v1/database/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["count"], 0)
assert response["count"] == 0
def test_create_database(self):
"""
@@ -284,7 +284,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
# Cleanup
model = db.session.query(Database).get(response.get("id"))
assert model.configuration_method == ConfigurationMethod.SQLALCHEMY_FORM
@@ -326,14 +326,14 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(response.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX")
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
assert response.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@@ -385,10 +385,10 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
@mock.patch(
@@ -434,19 +434,19 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
assert model_ssh_tunnel.database_id == response_update.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@@ -500,15 +500,15 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
assert rv.status_code == 400
assert (
response.get("message")
== "A database port is required when connecting via SSH Tunnel."
)
# Cleanup
@@ -563,19 +563,19 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
assert model_ssh_tunnel.database_id == response_update.get("id")
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
@@ -585,7 +585,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
@@ -651,30 +651,28 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(model_ssh_tunnel.username, "foo")
assert model_ssh_tunnel.database_id == response.get("id")
assert model_ssh_tunnel.username == "foo"
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
self.assertEqual(
response_update.get("result")["ssh_tunnel"]["password"], "XXXXXXXXXX"
)
self.assertEqual(model_ssh_tunnel.username, "Test")
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
self.assertEqual(model_ssh_tunnel.server_port, 8080)
assert model_ssh_tunnel.database_id == response_update.get("id")
assert response_update.get("result")["ssh_tunnel"]["password"] == "XXXXXXXXXX"
assert model_ssh_tunnel.username == "Test"
assert model_ssh_tunnel.server_address == "123.132.123.1"
assert model_ssh_tunnel.server_port == 8080
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@@ -715,13 +713,13 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
assert model_ssh_tunnel.database_id == response.get("id")
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@@ -769,7 +767,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
model_ssh_tunnel = (
db.session.query(SSHTunnel)
@@ -777,7 +775,7 @@ class TestDatabaseApi(SupersetTestCase):
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
assert response == fail_message
# Check that rollback was called
mock_rollback.assert_called()
@@ -824,14 +822,14 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
assert model_ssh_tunnel.database_id == response.get("id")
assert response.get("result")["ssh_tunnel"] == response_ssh_tunnel
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
@@ -866,8 +864,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, {"message": "SSH Tunneling is not enabled"})
assert rv.status_code == 400
assert response == {"message": "SSH Tunneling is not enabled"}
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
@@ -897,7 +895,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{database.id}/table/{table_name}/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
def test_create_database_invalid_configuration_method(self):
"""
@@ -959,7 +957,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
assert rv.status_code == 201
self.assertIn("sqlalchemy_form", response["result"]["configuration_method"])
assert "sqlalchemy_form" in response["result"]["configuration_method"]
def test_create_database_server_cert_validate(self):
"""
@@ -981,8 +979,8 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
expected_response = {"message": {"server_cert": ["Invalid certificate"]}}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_json_validate(self):
"""
@@ -1016,8 +1014,8 @@ class TestDatabaseApi(SupersetTestCase):
],
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_extra_metadata_validate(self):
"""
@@ -1052,8 +1050,8 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, expected_response)
assert rv.status_code == 400
assert response == expected_response
def test_create_database_unique_validate(self):
"""
@@ -1078,8 +1076,8 @@ class TestDatabaseApi(SupersetTestCase):
"database_name": "A database with the same name already exists."
}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
def test_create_database_uri_validate(self):
"""
@@ -1095,11 +1093,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertIn(
"Invalid connection string",
response["message"]["sqlalchemy_uri"][0],
)
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
@mock.patch(
"superset.views.core.app.config",
@@ -1127,8 +1122,8 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(response_data, expected_response)
self.assertEqual(response.status_code, 400)
assert response_data == expected_response
assert response.status_code == 400
def test_create_database_conn_fail(self):
"""
@@ -1192,11 +1187,11 @@ class TestDatabaseApi(SupersetTestCase):
expected_response_postgres = {
"errors": [dataclasses.asdict(superset_error_postgres)]
}
self.assertEqual(response.status_code, 500)
assert response.status_code == 500
if example_db.backend == "mysql":
self.assertEqual(response_data, expected_response_mysql)
assert response_data == expected_response_mysql
else:
self.assertEqual(response_data, expected_response_postgres)
assert response_data == expected_response_postgres
def test_update_database(self):
"""
@@ -1213,7 +1208,7 @@ class TestDatabaseApi(SupersetTestCase):
}
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
@@ -1242,8 +1237,8 @@ class TestDatabaseApi(SupersetTestCase):
expected_response = {
"message": "Connection failed, please check your connection settings"
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
# Cleanup
model = db.session.query(Database).get(test_database.id)
db.session.delete(model)
@@ -1271,8 +1266,8 @@ class TestDatabaseApi(SupersetTestCase):
"database_name": "A database with the same name already exists."
}
}
self.assertEqual(rv.status_code, 422)
self.assertEqual(response, expected_response)
assert rv.status_code == 422
assert response == expected_response
# Cleanup
db.session.delete(test_database1)
db.session.delete(test_database2)
@@ -1286,7 +1281,7 @@ class TestDatabaseApi(SupersetTestCase):
database_data = {"database_name": "test-database-updated"}
uri = "api/v1/database/invalid"
rv = self.client.put(uri, json=database_data)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_update_database_uri_validate(self):
"""
@@ -1305,11 +1300,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{test_database.id}"
rv = self.client.put(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertIn(
"Invalid connection string",
response["message"]["sqlalchemy_uri"][0],
)
assert rv.status_code == 400
assert "Invalid connection string" in response["message"]["sqlalchemy_uri"][0]
db.session.delete(test_database)
db.session.commit()
@@ -1369,9 +1361,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
model = db.session.query(Database).get(database_id)
self.assertEqual(model, None)
assert model is None
def test_delete_database_not_found(self):
"""
@@ -1381,7 +1373,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{max_id + 1}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("create_database_with_dataset")
def test_delete_database_with_datasets(self):
@@ -1391,7 +1383,7 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{self._database.id}"
rv = self.delete_assert_metric(uri, "delete")
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
@pytest.mark.usefixtures("create_database_with_report")
def test_delete_database_with_report(self):
@@ -1407,11 +1399,11 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{database.id}"
rv = self.client.delete(uri)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
expected_response = {
"message": "There are associated alerts or reports: report_with_database"
}
self.assertEqual(response, expected_response)
assert response == expected_response
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_metadata(self):
@@ -1422,12 +1414,12 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["name"], "birth_names")
self.assertIsNone(response["comment"])
self.assertTrue(len(response["columns"]) > 5)
self.assertTrue(response.get("selectStar").startswith("SELECT"))
assert response["name"] == "birth_names"
assert response["comment"] is None
assert len(response["columns"]) > 5
assert response.get("selectStar").startswith("SELECT")
def test_info_security_database(self):
"""
@@ -1456,11 +1448,11 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_invalid_table_table_metadata(self):
"""
@@ -1472,25 +1464,22 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
if example_db.backend == "sqlite":
self.assertEqual(rv.status_code, 200)
self.assertEqual(
data,
{
"columns": [],
"comment": None,
"foreignKeys": [],
"indexes": [],
"name": "wrong_table",
"primaryKey": {"constrained_columns": None, "name": None},
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
},
)
assert rv.status_code == 200
assert data == {
"columns": [],
"comment": None,
"foreignKeys": [],
"indexes": [],
"name": "wrong_table",
"primaryKey": {"constrained_columns": None, "name": None},
"selectStar": "SELECT\n *\nFROM wrong_table\nLIMIT 100\nOFFSET 0",
}
elif example_db.backend == "mysql":
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "`wrong_table`"})
assert rv.status_code == 422
assert data == {"message": "`wrong_table`"}
else:
self.assertEqual(rv.status_code, 422)
self.assertEqual(data, {"message": "wrong_table"})
assert rv.status_code == 422
assert data == {"message": "wrong_table"}
def test_get_table_metadata_no_db_permission(self):
"""
@@ -1500,7 +1489,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_table_extra_metadata_deprecated(self):
@@ -1511,9 +1500,9 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{example_db.id}/table_extra/birth_names/null/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response, {})
assert response == {}
def test_get_invalid_database_table_extra_metadata_deprecated(self):
"""
@@ -1523,11 +1512,11 @@ class TestDatabaseApi(SupersetTestCase):
self.login(ADMIN_USERNAME)
uri = f"api/v1/database/{database_id}/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
uri = "api/v1/database/some_database/table_extra/some_table/some_schema/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_invalid_table_table_extra_metadata_deprecated(self):
"""
@@ -1539,8 +1528,8 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(data, {})
assert rv.status_code == 200
assert data == {}
@pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
def test_get_select_star(self):
@@ -1551,7 +1540,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
def test_get_select_star_not_allowed(self):
"""
@@ -1561,7 +1550,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_select_star_not_found_database(self):
"""
@@ -1571,7 +1560,7 @@ class TestDatabaseApi(SupersetTestCase):
max_id = db.session.query(func.max(Database.id)).scalar()
uri = f"api/v1/database/{max_id + 1}/select_star/birth_names/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_get_select_star_not_found_table(self):
"""
@@ -1585,7 +1574,7 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/select_star/table_does_not_exist/"
rv = self.client.get(uri)
# TODO(bkyryliuk): investigate why presto returns 500
self.assertEqual(rv.status_code, 404 if example_db.backend != "presto" else 500)
assert rv.status_code == (404 if example_db.backend != "presto" else 500)
def test_get_allow_file_upload_filter(self):
"""
@@ -1952,13 +1941,13 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(f"api/v1/database/{database.id}/schemas/")
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, set(response["result"]))
assert schemas == set(response["result"])
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': True})}"
)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(schemas, set(response["result"]))
assert schemas == set(response["result"])
def test_database_schemas_not_found(self):
"""
@@ -1968,7 +1957,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/schemas/"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_database_schemas_invalid_query(self):
"""
@@ -1979,7 +1968,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/schemas/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
def test_database_tables(self):
"""
@@ -1993,17 +1982,17 @@ class TestDatabaseApi(SupersetTestCase):
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': schema_name})}"
)
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
if database.backend == "postgresql":
response = json.loads(rv.data.decode("utf-8"))
schemas = [
s[0] for s in database.get_all_table_names_in_schema(None, schema_name)
]
self.assertEqual(response["count"], len(schemas))
assert response["count"] == len(schemas)
for option in response["result"]:
self.assertEqual(option["extra"], None)
self.assertEqual(option["type"], "table")
self.assertTrue(option["value"] in schemas)
assert option["extra"] is None
assert option["type"] == "table"
assert option["value"] in schemas
@patch("superset.utils.log.logger")
def test_database_tables_not_found(self, logger_mock):
@@ -2014,7 +2003,7 @@ class TestDatabaseApi(SupersetTestCase):
example_db = get_example_database()
uri = f"api/v1/database/{example_db.id}/tables/?q={prison.dumps({'schema_name': 'non_existent'})}"
rv = self.client.get(uri)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
logger_mock.warning.assert_called_once_with(
"Database not found.", exc_info=True
)
@@ -2028,7 +2017,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'force': 'nop'})}"
)
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
@patch("superset.utils.log.logger")
@mock.patch("superset.security.manager.SupersetSecurityManager.can_access_database")
@@ -2046,7 +2035,7 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(
f"api/v1/database/{database.id}/tables/?q={prison.dumps({'schema_name': 'main'})}"
)
self.assertEqual(rv.status_code, 422)
assert rv.status_code == 422
logger_mock.warning.assert_called_once_with("Test Error", exc_info=True)
def test_test_connection(self):
@@ -2074,8 +2063,8 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
# validate that the endpoint works with the decrypted sqlalchemy uri
data = {
@@ -2086,8 +2075,8 @@ class TestDatabaseApi(SupersetTestCase):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 200)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 200
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
def test_test_connection_failed(self):
"""
@@ -2103,8 +2092,8 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
@@ -2123,7 +2112,7 @@ class TestDatabaseApi(SupersetTestCase):
}
]
}
self.assertEqual(response, expected_response)
assert response == expected_response
data = {
"sqlalchemy_uri": "mssql+pymssql://url",
@@ -2132,8 +2121,8 @@ class TestDatabaseApi(SupersetTestCase):
"server_cert": None,
}
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 422)
self.assertEqual(rv.headers["Content-Type"], "application/json; charset=utf-8")
assert rv.status_code == 422
assert rv.headers["Content-Type"] == "application/json; charset=utf-8"
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"errors": [
@@ -2152,7 +2141,7 @@ class TestDatabaseApi(SupersetTestCase):
}
]
}
self.assertEqual(response, expected_response)
assert response == expected_response
def test_test_connection_unsafe_uri(self):
"""
@@ -2169,7 +2158,7 @@ class TestDatabaseApi(SupersetTestCase):
}
url = "api/v1/database/test_connection/"
rv = self.post_assert_metric(url, data, "test_connection")
self.assertEqual(rv.status_code, 400)
assert rv.status_code == 400
response = json.loads(rv.data.decode("utf-8"))
expected_response = {
"message": {
@@ -2178,7 +2167,7 @@ class TestDatabaseApi(SupersetTestCase):
]
}
}
self.assertEqual(response, expected_response)
assert response == expected_response
app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = False
@@ -2250,10 +2239,10 @@ class TestDatabaseApi(SupersetTestCase):
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 200)
assert rv.status_code == 200
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(response["charts"]["count"], 33)
self.assertEqual(response["dashboards"]["count"], 3)
assert response["charts"]["count"] == 33
assert response["dashboards"]["count"] == 3
def test_get_database_related_objects_not_found(self):
"""
@@ -2265,13 +2254,13 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{invalid_id}/related_objects/"
self.login(ADMIN_USERNAME)
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
self.logout()
self.login(GAMMA_USERNAME)
database = get_example_database()
uri = f"api/v1/database/{database.id}/related_objects/"
rv = self.get_assert_metric(uri, "related_objects")
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
def test_export_database(self):
"""
@@ -2679,7 +2668,7 @@ class TestDatabaseApi(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.password, "TEST")
assert model_ssh_tunnel.password == "TEST"
db.session.delete(database)
db.session.commit()
@@ -2797,8 +2786,8 @@ class TestDatabaseApi(SupersetTestCase):
.filter(SSHTunnel.database_id == database.id)
.one()
)
self.assertEqual(model_ssh_tunnel.private_key, "TestPrivateKey")
self.assertEqual(model_ssh_tunnel.private_key_password, "TEST")
assert model_ssh_tunnel.private_key == "TestPrivateKey"
assert model_ssh_tunnel.private_key_password == "TEST"
db.session.delete(database)
db.session.commit()
@@ -3852,8 +3841,8 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(response["result"], [])
assert rv.status_code == 200
assert response["result"] == []
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@@ -3878,18 +3867,15 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
self.assertEqual(
response["result"],
[
{
"end_column": None,
"line_number": 1,
"message": 'ERROR: syntax error at or near "table1"',
"start_column": None,
}
],
)
assert rv.status_code == 200
assert response["result"] == [
{
"end_column": None,
"line_number": 1,
"message": 'ERROR: syntax error at or near "table1"',
"start_column": None,
}
]
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@@ -3910,7 +3896,7 @@ class TestDatabaseApi(SupersetTestCase):
f"api/v1/database/{self.get_nonexistent_numeric_id(Database)}/validate_sql/"
)
rv = self.client.post(uri, json=request_payload)
self.assertEqual(rv.status_code, 404)
assert rv.status_code == 404
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@@ -3932,8 +3918,8 @@ class TestDatabaseApi(SupersetTestCase):
)
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(response, {"message": {"sql": ["Field may not be null."]}})
assert rv.status_code == 400
assert response == {"message": {"sql": ["Field may not be null."]}}
@mock.patch.dict(
"superset.config.SQL_VALIDATORS_BY_ENGINE",
@@ -3956,29 +3942,26 @@ class TestDatabaseApi(SupersetTestCase):
uri = f"api/v1/database/{example_db.id}/validate_sql/"
rv = self.client.post(uri, json=request_payload)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
self.assertEqual(
response,
{
"errors": [
{
"message": f"no SQL validator is configured for "
f"{example_db.backend}",
"error_type": "GENERIC_DB_ENGINE_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1002,
"message": "Issue 1002 - The database returned an "
"unexpected error.",
}
]
},
}
]
},
)
assert rv.status_code == 422
assert response == {
"errors": [
{
"message": f"no SQL validator is configured for "
f"{example_db.backend}",
"error_type": "GENERIC_DB_ENGINE_ERROR",
"level": "error",
"extra": {
"issue_codes": [
{
"code": 1002,
"message": "Issue 1002 - The database returned an "
"unexpected error.",
}
]
},
}
]
}
@mock.patch("superset.commands.database.validate_sql.get_validator_by_name")
@mock.patch.dict(
@@ -4013,8 +3996,8 @@ class TestDatabaseApi(SupersetTestCase):
# TODO(bkyryliuk): properly handle hive error
if get_example_database().backend == "hive":
return
self.assertEqual(rv.status_code, 422)
self.assertIn("Kaboom!", response["errors"][0]["message"])
assert rv.status_code == 422
assert "Kaboom!" in response["errors"][0]["message"]
def test_get_databases_with_extra_filters(self):
"""
@@ -4048,14 +4031,14 @@ class TestDatabaseApi(SupersetTestCase):
uri, json={**database_data, "database_name": "dyntest-create-database-1"}
)
first_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
uri = "api/v1/database/"
rv = self.client.post(
uri, json={**database_data, "database_name": "create-database-2"}
)
second_response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
assert rv.status_code == 201
# The filter function
def _base_filter(query):
@@ -4074,11 +4057,11 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
# All databases must be returned if no filter is present
self.assertEqual(data["count"], len(dbs))
assert data["count"] == len(dbs)
database_names = [item["database_name"] for item in data["result"]]
database_names.sort()
# All Databases because we are an admin
self.assertEqual(database_names, expected_names)
assert database_names == expected_names
assert rv.status_code == 200
# Our filter function wasn't get called
base_filter_mock.assert_not_called()
@@ -4092,10 +4075,10 @@ class TestDatabaseApi(SupersetTestCase):
rv = self.client.get(uri)
data = json.loads(rv.data.decode("utf-8"))
# Only one database start with dyntest
self.assertEqual(data["count"], 1)
assert data["count"] == 1
database_names = [item["database_name"] for item in data["result"]]
# Only the database that starts with tests, even if we are an admin
self.assertEqual(database_names, ["dyntest-create-database-1"])
assert database_names == ["dyntest-create-database-1"]
assert rv.status_code == 200
# The filter function is called now that it's defined in our config
base_filter_mock.assert_called()