fix: SSH Tunnel configuration settings (#27186)

This commit is contained in:
Geido
2024-03-11 16:56:54 +01:00
committed by GitHub
parent fde93dcf08
commit 89e89de341
24 changed files with 871 additions and 271 deletions

View File

@@ -35,6 +35,7 @@ from sqlalchemy.exc import DBAPIError
from sqlalchemy.sql import func
from superset import db, security_manager
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
@@ -336,6 +337,58 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_create_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@@ -397,6 +450,154 @@ class TestDatabaseApi(SupersetTestCase):
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_update_database_with_missing_port_raises_error(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test that missing port raises SSHTunnelDatabaseError
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
modified_sqlalchemy_uri = "postgresql://foo:bar@localhost/test-db"
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": modified_sqlalchemy_uri,
"ssh_tunnel": ssh_tunnel_properties,
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response_create = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response_create.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 400)
self.assertEqual(
response.get("message"),
"A database port is required when connecting via SSH Tunnel.",
)
# Cleanup
model = db.session.query(Database).get(response_create.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)
@mock.patch("superset.commands.database.create.is_feature_enabled")
@mock.patch("superset.commands.database.update.is_feature_enabled")
@mock.patch("superset.commands.database.ssh_tunnel.delete.is_feature_enabled")
@mock.patch(
"superset.models.core.Database.get_all_schema_names",
)
def test_delete_ssh_tunnel(
self,
mock_test_connection_database_command_run,
mock_create_is_feature_enabled,
mock_update_is_feature_enabled,
mock_delete_is_feature_enabled,
mock_get_all_schema_names,
):
"""
Database API: Test deleting a SSH tunnel via Database update
"""
mock_create_is_feature_enabled.return_value = True
mock_update_is_feature_enabled.return_value = True
mock_delete_is_feature_enabled.return_value = True
self.login(username="admin")
example_db = get_example_database()
if example_db.backend == "sqlite":
return
ssh_tunnel_properties = {
"server_address": "123.132.123.1",
"server_port": 8080,
"username": "foo",
"password": "bar",
}
database_data = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
}
database_data_with_ssh_tunnel = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": ssh_tunnel_properties,
}
uri = "api/v1/database/"
rv = self.client.post(uri, json=database_data)
response = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 201)
uri = "api/v1/database/{}".format(response.get("id"))
rv = self.client.put(uri, json=database_data_with_ssh_tunnel)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one()
)
self.assertEqual(model_ssh_tunnel.database_id, response_update.get("id"))
database_data_with_ssh_tunnel_null = {
"database_name": "test-db-with-ssh-tunnel",
"sqlalchemy_uri": example_db.sqlalchemy_uri_decrypted,
"ssh_tunnel": None,
}
rv = self.client.put(uri, json=database_data_with_ssh_tunnel_null)
response_update = json.loads(rv.data.decode("utf-8"))
self.assertEqual(rv.status_code, 200)
model_ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == response_update.get("id"))
.one_or_none()
)
assert model_ssh_tunnel is None
# Cleanup
model = db.session.query(Database).get(response.get("id"))
db.session.delete(model)
db.session.commit()
@mock.patch(
"superset.commands.database.test_connection.TestConnectionDatabaseCommand.run",
)

View File

@@ -19,7 +19,10 @@
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
def test_create_ssh_tunnel_command() -> None:
@@ -27,7 +30,11 @@ def test_create_ssh_tunnel_command() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
properties = {
"database_id": database.id,
@@ -48,7 +55,11 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
@@ -65,3 +76,31 @@ def test_create_ssh_tunnel_command_invalid_params() -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
def test_create_ssh_tunnel_command_no_port() -> None:
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)

View File

@@ -20,11 +20,14 @@ from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelInvalidError
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
def session_with_data(request, session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
@@ -32,7 +35,8 @@ def session_with_data(session: Session) -> Iterator[Session]:
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
@@ -93,3 +97,28 @@ def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
@pytest.mark.parametrize(
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test update"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)