mirror of
https://github.com/apache/superset.git
synced 2026-06-06 16:19:18 +00:00
feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)
Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
@@ -35,6 +35,8 @@ from sqlalchemy.sql import func
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||
@@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_create_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_ssh_tunnel_via_database_api(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
initial_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
updated_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "Test",
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": initial_ssh_tunnel_properties,
|
||||
}
|
||||
database_data_with_ssh_tunnel_update = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": updated_ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "foo")
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "Test")
|
||||
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
|
||||
self.assertEqual(model_ssh_tunnel.server_port, 8080)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_cascade_delete_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-failure-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
fail_message = {"message": "SSH Tunnel parameters are invalid."}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
self.assertEqual(response, fail_message)
|
||||
# Cleanup
|
||||
model = (
|
||||
db.session.query(Database)
|
||||
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
|
||||
.one_or_none()
|
||||
)
|
||||
# the DB should not be created
|
||||
assert model is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_get_database_returns_related_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test GET Database returns its related SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
response_ssh_tunnel = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "XXXXXXXXXX",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_create_database_invalid_configuration_method(self):
|
||||
"""
|
||||
Database API: Test create with an invalid configuration method.
|
||||
|
||||
16
tests/integration_tests/databases/ssh_tunnel/__init__.py
Normal file
16
tests/integration_tests/databases/ssh_tunnel/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,76 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock, skip
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset import security_manager
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestCreateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_create_invalid_database_id(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
command = CreateSSHTunnelCommand(
|
||||
None,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class TestUpdateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_update_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = UpdateSSHTunnelCommand(
|
||||
1,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||
|
||||
|
||||
class TestDeleteSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_delete_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = DeleteSSHTunnelCommand(1)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||
Reference in New Issue
Block a user