feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)

Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com>
Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
Hugh A. Miles II
2023-01-03 17:22:42 -05:00
committed by GitHub
parent a7a4561550
commit ebaad10d6c
40 changed files with 1905 additions and 47 deletions

View File

@@ -35,6 +35,8 @@ from sqlalchemy.sql import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
@@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_ssh_tunnel_via_database_api(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
initial_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
updated_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "Test",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": initial_ssh_tunnel_properties,
}
database_data_with_ssh_tunnel_update = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": updated_ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(model_ssh_tunnel.username, "foo")
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
self.assertEqual(model_ssh_tunnel.username, "Test")
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
self.assertEqual(model_ssh_tunnel.server_port, 8080)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_cascade_delete_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
}
database_data = {
"database_name": "test-db-failure-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
fail_message = {"message": "SSH Tunnel parameters are invalid."}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
.one_or_none()
)
# the DB should not be created
assert model is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_get_database_returns_related_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test GET Database returns its related SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
response_ssh_tunnel = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "XXXXXXXXXX",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_database_invalid_configuration_method(self):
"""
Database API: Test create with an invalid configuration method.