feat(ssh-tunnelling): Setup SSH Tunneling Commands for Database Connections (#21912)

Co-authored-by: Antonio Rivero Martinez <38889534+Antonio-RiveroMartnez@users.noreply.github.com>
Co-authored-by: Elizabeth Thompson <eschutho@gmail.com>
This commit is contained in:
Hugh A. Miles II
2023-01-03 17:22:42 -05:00
committed by GitHub
parent a7a4561550
commit ebaad10d6c
40 changed files with 1905 additions and 47 deletions

View File

@@ -28,7 +28,7 @@ from __future__ import annotations
from typing import Callable, TYPE_CHECKING
from unittest.mock import MagicMock, Mock, PropertyMock
from flask import Flask
from flask import current_app, Flask
from flask.ctx import AppContext
from pytest import fixture
@@ -40,6 +40,7 @@ from tests.example_data.data_loading.pandas.pands_data_loading_conf import (
from tests.example_data.data_loading.pandas.table_df_convertor import (
TableToDfConvertorImpl,
)
from tests.integration_tests.test_app import app
SUPPORT_DATETIME_TYPE = "support_datetime_type"
@@ -70,8 +71,9 @@ def example_db_provider() -> Callable[[], Database]:
@fixture(scope="session")
def example_db_engine(example_db_provider: Callable[[], Database]) -> Engine:
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine
with app.app_context():
with example_db_provider().get_sqla_engine_with_context() as engine:
return engine
@fixture(scope="session")

View File

@@ -35,6 +35,8 @@ from sqlalchemy.sql import func
from superset import db, security_manager
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.db_engine_specs.mysql import MySQLEngineSpec
from superset.db_engine_specs.postgres import PostgresEngineSpec
from superset.db_engine_specs.redshift import RedshiftEngineSpec
@@ -280,6 +282,314 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_ssh_tunnel_via_database_api(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test update with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
initial_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
updated_ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "Test",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": initial_ssh_tunnel_properties,
}
database_data_with_ssh_tunnel_update = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": updated_ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(model_ssh_tunnel.username, "foo")
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_update)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
self.assertEqual(model_ssh_tunnel.username, "Test")
self.assertEqual(model_ssh_tunnel.server_address, "123.132.123.1")
self.assertEqual(model_ssh_tunnel.server_port, 8080)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_cascade_delete_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_do_not_create_database_if_ssh_tunnel_creation_fails(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test create with SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
}
database_data = {
"database_name": "test-db-failure-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
fail_message = {"message": "SSH Tunnel parameters are invalid."}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 422)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
self.assertEqual(response, fail_message)
# Cleanup
model = (
db.session.query(Database)
.filter(Database.database_name == "test-db-failure-ssh-tunnel")
.one_or_none()
)
# the DB should not be created
assert model is None
@mock.patch(
"superset.databases.commands.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_get_database_returns_related_ssh_tunnel(
self, mock_test_connection_database_command_run, mock_get_all_schema_names
):
"""
Database API: Test GET Database returns its related SSH Tunnel
"""
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
response_ssh_tunnel = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "XXXXXXXXXX",
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response.get("id"))
self.assertEqual(response.get("result")["ssh_tunnel"], response_ssh_tunnel)
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
def test_create_database_invalid_configuration_method(self):
"""
Database API: Test create with an invalid configuration method.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,76 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from unittest import mock, skip
from unittest.mock import patch
import pytest
from superset import security_manager
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelInvalidError,
SSHTunnelNotFoundError,
)
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from tests.integration_tests.base_tests import SupersetTestCase
class TestCreateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_create_invalid_database_id(self, mock_g):
mock_g.user = security_manager.find_user("admin")
command = CreateSSHTunnelCommand(
None,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
class TestUpdateSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_update_ssh_tunnel_not_found(self, mock_g):
mock_g.user = security_manager.find_user("admin")
# We have not created a SSH Tunnel yet so id = 1 is invalid
command = UpdateSSHTunnelCommand(
1,
{
"server_address": "127.0.0.1",
"server_port": 5432,
"username": "test_user",
},
)
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel not found.")
class TestDeleteSSHTunnelCommand(SupersetTestCase):
@mock.patch("superset.utils.core.g")
def test_delete_ssh_tunnel_not_found(self, mock_g):
mock_g.user = security_manager.find_user("admin")
# We have not created a SSH Tunnel yet so id = 1 is invalid
command = DeleteSSHTunnelCommand(1)
with pytest.raises(SSHTunnelNotFoundError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel not found.")

View File

@@ -191,3 +191,147 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
}
]
}
def test_delete_ssh_tunnel(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we can delete SSH Tunnel
"""
with app.app_context():
from superset.databases.api import DatabaseRestApi
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
session.add(database)
session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
session.add(tunnel)
session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/1/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "OK"
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None
def test_delete_ssh_tunnel_not_found(
mocker: MockFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we cannot delete a tunnel that does not exist
"""
with app.app_context():
from superset.databases.api import DatabaseRestApi
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel.session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com",
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
session.add(database)
session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
session.add(tunnel)
session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found"
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,69 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
)
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
result = DatabaseDAO.get_ssh_tunnel(2)
assert result is None

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(db.id, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.commands.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
command = CreateSSHTunnelCommand(db.id, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")

View File

@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(
database_id=db.id,
database=db,
)
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_delete_ssh_tunnel_command(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.delete import DeleteSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
DeleteSSHTunnelCommand(1).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result is None

View File

@@ -0,0 +1,93 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.databases.ssh_tunnel.commands.exceptions import SSHTunnelInvalidError
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
db = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=db,
)
ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test")
session.add(db)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.commands.update import UpdateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
# If we are trying to update a tunnel with a private_key_password
# then a private_key is mandatory
update_payload = {"private_key_password": "pass"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")

View File

@@ -0,0 +1,43 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Iterator
import pytest
from sqlalchemy.orm.session import Session
def test_create_ssh_tunnel():
from superset.databases.dao import DatabaseDAO
from superset.databases.ssh_tunnel.dao import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
properties = {
"database_id": db.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = SSHTunnelDAO.create(properties)
assert result is not None
assert isinstance(result, SSHTunnel)