chore: cleanup ssh tunnel (#34388)

This commit is contained in:
Beto Dealmeida
2025-12-03 14:26:35 -05:00
committed by GitHub
parent 70aec7fa76
commit c458f99dd4
35 changed files with 304 additions and 1287 deletions

View File

@@ -43,7 +43,6 @@ def test_sync_permissions_command_sync_mode(
"""
Test ``SyncPermissionsCommand`` in sync mode.
"""
mock_ssh = mocker.MagicMock()
user_mock = mocker.patch(
"superset.commands.database.sync_permissions.security_manager.get_user_by_username"
)
@@ -55,7 +54,9 @@ def test_sync_permissions_command_sync_mode(
add_pvm_mock = mocker.patch("superset.commands.database.sync_permissions.add_pvm")
cmmd = SyncPermissionsCommand(
1, "admin", db_connection=database_with_catalog, ssh_tunnel=mock_ssh
1,
"admin",
db_connection=database_with_catalog,
)
mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas")
mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions")
@@ -64,7 +65,6 @@ def test_sync_permissions_command_sync_mode(
assert cmmd.db_connection == database_with_catalog
assert cmmd.old_db_connection_name == "my_db"
assert cmmd.db_connection_ssh_tunnel == mock_ssh
user_mock.assert_called_once_with("admin")
add_pvm_mock.assert_has_calls(
[
@@ -120,7 +120,6 @@ def test_sync_permissions_command_passing_all_values(
"""
Test ``SyncPermissionsCommand`` when providing all arguments to the constructor.
"""
mock_ssh = mocker.MagicMock()
mock_database_dao = mocker.patch(
"superset.commands.database.sync_permissions.DatabaseDAO"
)
@@ -134,16 +133,13 @@ def test_sync_permissions_command_passing_all_values(
"admin",
old_db_connection_name="old name",
db_connection=database_with_catalog,
ssh_tunnel=mock_ssh,
)
mocker.patch.object(cmmd, "sync_database_permissions")
cmmd.run()
assert cmmd.db_connection == database_with_catalog
assert cmmd.old_db_connection_name == "old name"
assert cmmd.db_connection_ssh_tunnel == mock_ssh
mock_database_dao.find_by_id.assert_not_called()
mock_database_dao.get_ssh_tunnel.assert_not_called()
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
@@ -159,7 +155,6 @@ def test_sync_permissions_command_raise(
"superset.commands.database.sync_permissions.DatabaseDAO"
)
mock_database_dao.find_by_id.return_value = database_without_catalog
mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock()
mock_user = mocker.patch(
"superset.commands.database.sync_permissions.security_manager.get_user_by_username"
)

View File

@@ -308,6 +308,7 @@ def test_database_connection(
},
"server_cert": None,
"sqlalchemy_uri": "gsheets://",
"ssh_tunnel": None,
"uuid": "02feae18-2dd6-4bb4-a9c0-49e9d4f29d58",
},
}
@@ -486,160 +487,6 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
}
def test_delete_ssh_tunnel(
mocker: MockerFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we can delete SSH Tunnel
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel._session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
db.session.add(tunnel)
db.session.commit()
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete(
f"/api/v1/database/{database.id}/ssh_tunnel/"
)
assert response_delete_tunnel.json["message"] == "OK"
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel is None
def test_delete_ssh_tunnel_not_found(
mocker: MockerFixture,
app: Any,
session: Session,
client: Any,
full_api_access: None,
) -> None:
"""
Test that we cannot delete a tunnel that does not exist
"""
with app.app_context():
from superset.daos.database import DatabaseDAO
from superset.databases.api import DatabaseRestApi
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
DatabaseRestApi.datamodel._session = session
# create table for databases
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
# Create our Database
database = Database(
database_name="my_database",
sqlalchemy_uri="gsheets://",
encrypted_extra=json.dumps(
{
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
"private_key": "SECRET",
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
},
}
),
)
db.session.add(database)
db.session.commit()
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
# Create our SSHTunnel
tunnel = SSHTunnel(
database_id=1,
database=database,
)
db.session.add(tunnel)
db.session.commit()
# Delete the recently created SSHTunnel
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
assert response_delete_tunnel.json["message"] == "Not found"
# Get our recently created SSHTunnel
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
assert response_tunnel
assert isinstance(response_tunnel, SSHTunnel)
assert 1 == response_tunnel.database_id
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
assert response_tunnel is None
def test_apply_dynamic_database_filter(
mocker: MockerFixture,
app: Any,
@@ -698,10 +545,6 @@ def test_apply_dynamic_database_filter(
# mock the lookup so that we don't need to include the driver
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
mocker.patch("superset.utils.log.DBEventLogger.log")
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=False,
)
def _base_filter(query):
from superset.models.core import Database

View File

@@ -29,13 +29,12 @@ def test_add_permissions(mocker: MockerFixture) -> None:
database.db_engine_spec.supports_catalog = True
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]]
ssh_tunnel = mocker.MagicMock()
add_permission_view_menu = mocker.patch(
"superset.commands.database.importers.v1.utils.security_manager."
"add_permission_view_menu"
)
add_permissions(database, ssh_tunnel)
add_permissions(database)
add_permission_view_menu.assert_has_calls(
[
@@ -60,13 +59,12 @@ def test_add_permissions_get_default_catalog(mocker: MockerFixture):
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
database.get_default_catalog.return_value = "catalog1"
database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]]
ssh_tunnel = mocker.MagicMock()
add_permission_view_menu = mocker.patch(
"superset.commands.database.importers.v1.utils.security_manager."
"add_permission_view_menu"
)
add_permissions(database, ssh_tunnel)
add_permissions(database)
add_permission_view_menu.assert_has_calls(
[
@@ -88,13 +86,12 @@ def test_add_permissions_handle_failures(mocker: MockerFixture) -> None:
database.db_engine_spec.supports_catalog = True
database.get_all_catalog_names.return_value = ["catalog1", "catalog2", "catalog3"]
database.get_all_schema_names.side_effect = [["schema1"], Exception, ["schema3"]]
ssh_tunnel = mocker.MagicMock()
add_permission_view_menu = mocker.patch(
"superset.commands.database.importers.v1.utils.security_manager."
"add_permission_view_menu"
)
add_permissions(database, ssh_tunnel)
add_permissions(database)
add_permission_view_menu.assert_has_calls(
[

View File

@@ -54,7 +54,9 @@ def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
database = DatabaseDAO.find_by_id(1, skip_base_filter=True)
assert database is not None
result = database.ssh_tunnel
assert result
assert isinstance(result, SSHTunnel)
@@ -64,6 +66,7 @@ def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
from superset.daos.database import DatabaseDAO
result = DatabaseDAO.get_ssh_tunnel(2)
database = DatabaseDAO.find_by_id(2, skip_base_filter=True)
result = database.ssh_tunnel if database else None
assert result is None

View File

@@ -117,10 +117,7 @@ def test_database_filter(mocker: MockerFixture) -> None:
engine,
compile_kwargs={"literal_binds": True},
)
space = " " # pre-commit removes trailing spaces...
assert (
str(compiled_query)
== f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space}
FROM dbs{space}
WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')""" # noqa: S608, E501
== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
)

View File

@@ -1,16 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -1,148 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
def test_create_ssh_tunnel_command(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(database, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
def test_create_ssh_tunnel_command_invalid_params(session: Session) -> None:
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
)
# If we are trying to create a tunnel with a private_key_password
# then a private_key is mandatory
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"private_key_password": "bar",
}
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
def test_create_ssh_tunnel_command_no_port(session: Session) -> None:
"""
Test that SSH Tunnel can be created without explicit port but with a default one.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
database_name="my_database",
sqlalchemy_uri="postgresql://u:p@localhost/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
result = CreateSSHTunnelCommand(database, properties).run()
assert result is not None
assert isinstance(result, SSHTunnel)
def test_create_ssh_tunnel_command_no_port_no_default(session: Session) -> None:
"""
Test that error is raised when creating SSH Tunnel without explicit/default ports.
"""
from superset import db
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
from superset.models.core import Database
engine = db.session.get_bind()
Database.metadata.create_all(engine) # pylint: disable=no-member
database = Database(
id=1,
database_name="my_database",
sqlalchemy_uri="weird+db://u:p@localhost/db",
)
properties = {
"database": database,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
}
command = CreateSSHTunnelCommand(database, properties)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)

View File

@@ -1,73 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
import pytest
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
@pytest.fixture
def session_with_data(session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=database.id,
database=database,
)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_delete_ssh_tunnel_command(
mocker: MockerFixture, session_with_data: Session
) -> None:
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
mocker.patch(
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
return_value=True,
)
DeleteSSHTunnelCommand(1).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result is None

View File

@@ -1,155 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from collections.abc import Iterator
import pytest
from sqlalchemy.orm.session import Session
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelInvalidError,
)
@pytest.fixture
def session_with_data(request, session: Session) -> Iterator[Session]:
from superset.connectors.sqla.models import SqlaTable
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
engine = session.get_bind()
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
sqla_table = SqlaTable(
table_name="my_sqla_table",
columns=[],
metrics=[],
database=database,
)
ssh_tunnel = SSHTunnel(
database_id=database.id, database=database, server_address="Test"
)
session.add(database)
session.add(sqla_table)
session.add(ssh_tunnel)
session.flush()
yield session
session.rollback()
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
# If we are trying to update a tunnel with a private_key_password
# then a private_key is mandatory
update_payload = {"private_key_password": "pass"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelInvalidError) as excinfo:
command.run()
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
@pytest.mark.parametrize(
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
"""
Test that SSH Tunnel can be updated without explicit port but with a default one.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test2"}
UpdateSSHTunnelCommand(1, update_payload).run()
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert "Test2" == result.server_address
@pytest.mark.parametrize(
"session_with_data", ["weird+db://u:p@localhost/testdb"], indirect=True
)
def test_update_shh_tunnel_no_port_no_default(session_with_data: Session) -> None:
"""
Test that error is raised when updating SSH Tunnel without explicit/default ports.
"""
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
from superset.daos.database import DatabaseDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
result = DatabaseDAO.get_ssh_tunnel(1)
assert result
assert isinstance(result, SSHTunnel)
assert 1 == result.database_id
assert "Test" == result.server_address
update_payload = {"server_address": "Test update"}
command = UpdateSSHTunnelCommand(1, update_payload)
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
command.run()
assert str(excinfo.value) == (
"A database port is required when connecting via SSH Tunnel."
)

View File

@@ -1,37 +0,0 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
def test_create_ssh_tunnel():
from superset.daos.database import SSHTunnelDAO
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
result = SSHTunnelDAO.create(
attributes={
"database_id": database.id,
"server_address": "123.132.123.1",
"server_port": "3005",
"username": "foo",
"password": "bar",
},
)
assert result is not None
assert isinstance(result, SSHTunnel)

View File

@@ -333,7 +333,7 @@ def test_get_all_catalog_names(mocker: MockerFixture) -> None:
inspector.bind.execute.return_value = [("examples",), ("other",)]
assert database.get_all_catalog_names(force=True) == {"examples", "other"}
get_inspector.assert_called_with(ssh_tunnel=None)
get_inspector.assert_called_with()
def test_get_all_schema_names_needs_oauth2(mocker: MockerFixture) -> None: