mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)
Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com> Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from __future__ import annotations
|
||||
from typing import Callable, TYPE_CHECKING
|
||||
from unittest.mock import MagicMock, Mock, PropertyMock
|
||||
|
||||
from flask import Flask
|
||||
from flask import current_app, Flask
|
||||
from flask.ctx import AppContext
|
||||
from pytest import fixture
|
||||
|
||||
@@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
|
||||
from tests.example_data.data_loading.pandas.table_df_convertor import (
|
||||
TableToDfConvertorImpl,
|
||||
)
|
||||
from tests.integration_tests.test_app import app
|
||||
|
||||
SUPPORT_DATETIME_TYPE = "support_datetime_type"
|
||||
|
||||
@@ -70,8 +71,9 @@ def example_db_provider() -> Callable[[], Database]:
|
||||
|
||||
@fixture(scope="session")
|
||||
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
|
||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||
return engine
|
||||
with app.app_context():
|
||||
with example_db_provider().get_sqla_engine_with_context() as engine:
|
||||
return engine
|
||||
|
||||
|
||||
@fixture(scope="session")
|
||||
|
||||
@@ -35,6 +35,8 @@ from sqlalchemy.sql import func
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs.mysql import MySQLEngineSpec
|
||||
from superset.db_engine_specs.postgres import PostgresEngineSpec
|
||||
from superset.db_engine_specs.redshift import RedshiftEngineSpec
|
||||
@@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_create_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_database_with_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_update_ssh_tunnel_via_database_api(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test update with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
initial_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
updated_ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "Test",
|
||||
}
|
||||
database_data_with_ssh_tunnel = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": initial_ssh_tunnel_properties,
|
||||
}
|
||||
database_data_with_ssh_tunnel_update = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": updated_ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "foo")
|
||||
uri = "api/v1/database/{}".format(response.get("id"))
|
||||
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
|
||||
response_update = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 200)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response_update.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
|
||||
self.assertEqual(model_ssh_tunnel.username, "Test")
|
||||
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
|
||||
self.assertEqual(model_ssh_tunnel.server_port, 8080)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_cascade_delete_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test create with SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-failure-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
fail_message = {"message": "SSH Tunnel parameters are invalid."}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 422)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one_or_none()
|
||||
)
|
||||
assert model_ssh_tunnel is None
|
||||
self.assertEqual(response, fail_message)
|
||||
# Cleanup
|
||||
model = (
|
||||
db.session.query(Database)
|
||||
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
|
||||
.one_or_none()
|
||||
)
|
||||
# the DB should not be created
|
||||
assert model is None
|
||||
|
||||
@mock.patch(
|
||||
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
|
||||
)
|
||||
@mock.patch(
|
||||
"superset.models.core.Database.get_all_schema_names",
|
||||
)
|
||||
def test_get_database_returns_related_ssh_tunnel(
|
||||
self, mock_test_connection_database_command_run, mock_get_all_schema_names
|
||||
):
|
||||
"""
|
||||
Database API: Test GET Database returns its related SSH Tunnel
|
||||
"""
|
||||
self.login(username="admin")
|
||||
example_db = get_example_database()
|
||||
if example_db.backend == "sqlite":
|
||||
return
|
||||
ssh_tunnel_properties = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
database_data = {
|
||||
"database_name": "test-db-with-ssh-tunnel",
|
||||
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
|
||||
"ssh_tunnel": ssh_tunnel_properties,
|
||||
}
|
||||
response_ssh_tunnel = {
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": 8080,
|
||||
"username": "foo",
|
||||
"password": "XXXXXXXXXX",
|
||||
}
|
||||
|
||||
uri = "api/v1/database/"
|
||||
rv = self.client.post(uri, json=database_data)
|
||||
response = json.loads(rv.data.decode("utf-8"))
|
||||
self.assertEqual(rv.status_code, 201)
|
||||
model_ssh_tunnel = (
|
||||
db.session.query(SSHTunnel)
|
||||
.filter(SSHTunnel.database_id == response.get("id"))
|
||||
.one()
|
||||
)
|
||||
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
|
||||
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
|
||||
# Cleanup
|
||||
model = db.session.query(Database).get(response.get("id"))
|
||||
db.session.delete(model)
|
||||
db.session.commit()
|
||||
|
||||
def test_create_database_invalid_configuration_method(self):
|
||||
"""
|
||||
Database API: Test create with an invalid configuration method.
|
||||
|
||||
16
tests/integration_tests/databases/ssh_tunnel/__init__.py
Normal file
16
tests/integration_tests/databases/ssh_tunnel/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,76 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest import mock, skip
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from superset import security_manager
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelInvalidError,
|
||||
SSHTunnelNotFoundError,
|
||||
)
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from tests.integration_tests.base_tests import SupersetTestCase
|
||||
|
||||
|
||||
class TestCreateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_create_invalid_database_id(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
command = CreateSSHTunnelCommand(
|
||||
None,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
class TestUpdateSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_update_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = UpdateSSHTunnelCommand(
|
||||
1,
|
||||
{
|
||||
"server_address": "127.0.0.1",
|
||||
"server_port": 5432,
|
||||
"username": "test_user",
|
||||
},
|
||||
)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||
|
||||
|
||||
class TestDeleteSSHTunnelCommand(SupersetTestCase):
|
||||
@mock.patch("superset.utils.core.g")
|
||||
def test_delete_ssh_tunnel_not_found(self, mock_g):
|
||||
mock_g.user = security_manager.find_user("admin")
|
||||
# We have not created a SSH Tunnel yet so id = 1 is invalid
|
||||
command = DeleteSSHTunnelCommand(1)
|
||||
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel not found.")
|
||||
@@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel(
|
||||
mocker: MockFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we can delete SSH Tunnel
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/")
|
||||
assert response_delete_tunnel.json["message"] == "OK"
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel is None
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_not_found(
|
||||
mocker: MockFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we cannot delete a tunnel that does not exist
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel.session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
session.add(database)
|
||||
session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(tunnel)
|
||||
session.commit()
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||
assert response_delete_tunnel.json["message"] == "Not found"
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
|
||||
assert response_tunnel is None
|
||||
|
||||
16
tests/unit_tests/databases/dao/__init__.py
Normal file
16
tests/unit_tests/databases/dao/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
69
tests/unit_tests/databases/dao/dao_tests.py
Normal file
69
tests/unit_tests/databases/dao/dao_tests.py
Normal file
@@ -0,0 +1,69 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
|
||||
|
||||
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(2)
|
||||
|
||||
assert result is None
|
||||
16
tests/unit_tests/databases/ssh_tunnel/__init__.py
Normal file
16
tests/unit_tests/databases/ssh_tunnel/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
16
tests/unit_tests/databases/ssh_tunnel/commands/__init__.py
Normal file
16
tests/unit_tests/databases/ssh_tunnel/commands/__init__.py
Normal file
@@ -0,0 +1,16 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -0,0 +1,68 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command() -> None:
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(db.id, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_invalid_params() -> None:
|
||||
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"private_key_password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(db.id, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
@@ -0,0 +1,68 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=db.id,
|
||||
database=db,
|
||||
)
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
|
||||
DeleteSSHTunnelCommand(1).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result is None
|
||||
@@ -0,0 +1,93 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=db,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
|
||||
|
||||
session.add(db)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test2"}
|
||||
UpdateSSHTunnelCommand(1, update_payload).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert "Test2" == result.server_address
|
||||
|
||||
|
||||
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
# If we are trying to update a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
update_payload = {"private_key_password": "pass"}
|
||||
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
43
tests/unit_tests/databases/ssh_tunnel/dao_tests.py
Normal file
43
tests/unit_tests/databases/ssh_tunnel/dao_tests.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from typing import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
def test_create_ssh_tunnel():
|
||||
from superset.databases.dao import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
properties = {
|
||||
"database_id": db.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = SSHTunnelDAO.create(properties)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
Reference in New Issue
Block a user