mirror of
https://github.com/apache/superset.git
synced 2026-04-19 16:14:52 +00:00
chore: cleanup ssh tunnel (#34388)
This commit is contained in:
@@ -43,7 +43,6 @@ def test_sync_permissions_command_sync_mode(
|
||||
"""
|
||||
Test ``SyncPermissionsCommand`` in sync mode.
|
||||
"""
|
||||
mock_ssh = mocker.MagicMock()
|
||||
user_mock = mocker.patch(
|
||||
"superset.commands.database.sync_permissions.security_manager.get_user_by_username"
|
||||
)
|
||||
@@ -55,7 +54,9 @@ def test_sync_permissions_command_sync_mode(
|
||||
add_pvm_mock = mocker.patch("superset.commands.database.sync_permissions.add_pvm")
|
||||
|
||||
cmmd = SyncPermissionsCommand(
|
||||
1, "admin", db_connection=database_with_catalog, ssh_tunnel=mock_ssh
|
||||
1,
|
||||
"admin",
|
||||
db_connection=database_with_catalog,
|
||||
)
|
||||
mock_refresh_schemas = mocker.patch.object(cmmd, "_refresh_schemas")
|
||||
mock_rename_db_perm = mocker.patch.object(cmmd, "_rename_database_in_permissions")
|
||||
@@ -64,7 +65,6 @@ def test_sync_permissions_command_sync_mode(
|
||||
|
||||
assert cmmd.db_connection == database_with_catalog
|
||||
assert cmmd.old_db_connection_name == "my_db"
|
||||
assert cmmd.db_connection_ssh_tunnel == mock_ssh
|
||||
user_mock.assert_called_once_with("admin")
|
||||
add_pvm_mock.assert_has_calls(
|
||||
[
|
||||
@@ -120,7 +120,6 @@ def test_sync_permissions_command_passing_all_values(
|
||||
"""
|
||||
Test ``SyncPermissionsCommand`` when providing all arguments to the constructor.
|
||||
"""
|
||||
mock_ssh = mocker.MagicMock()
|
||||
mock_database_dao = mocker.patch(
|
||||
"superset.commands.database.sync_permissions.DatabaseDAO"
|
||||
)
|
||||
@@ -134,16 +133,13 @@ def test_sync_permissions_command_passing_all_values(
|
||||
"admin",
|
||||
old_db_connection_name="old name",
|
||||
db_connection=database_with_catalog,
|
||||
ssh_tunnel=mock_ssh,
|
||||
)
|
||||
mocker.patch.object(cmmd, "sync_database_permissions")
|
||||
cmmd.run()
|
||||
|
||||
assert cmmd.db_connection == database_with_catalog
|
||||
assert cmmd.old_db_connection_name == "old name"
|
||||
assert cmmd.db_connection_ssh_tunnel == mock_ssh
|
||||
mock_database_dao.find_by_id.assert_not_called()
|
||||
mock_database_dao.get_ssh_tunnel.assert_not_called()
|
||||
|
||||
|
||||
@with_config({"SYNC_DB_PERMISSIONS_IN_ASYNC_MODE": False})
|
||||
@@ -159,7 +155,6 @@ def test_sync_permissions_command_raise(
|
||||
"superset.commands.database.sync_permissions.DatabaseDAO"
|
||||
)
|
||||
mock_database_dao.find_by_id.return_value = database_without_catalog
|
||||
mock_database_dao.get_ssh_tunnel.return_value = mocker.MagicMock()
|
||||
mock_user = mocker.patch(
|
||||
"superset.commands.database.sync_permissions.security_manager.get_user_by_username"
|
||||
)
|
||||
|
||||
@@ -308,6 +308,7 @@ def test_database_connection(
|
||||
},
|
||||
"server_cert": None,
|
||||
"sqlalchemy_uri": "gsheets://",
|
||||
"ssh_tunnel": None,
|
||||
"uuid": "02feae18-2dd6-4bb4-a9c0-49e9d4f29d58",
|
||||
},
|
||||
}
|
||||
@@ -486,160 +487,6 @@ def test_non_zip_import(client: Any, full_api_access: None) -> None:
|
||||
}
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel(
|
||||
mocker: MockerFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we can delete SSH Tunnel
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel._session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
mocker.patch(
|
||||
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete(
|
||||
f"/api/v1/database/{database.id}/ssh_tunnel/"
|
||||
)
|
||||
assert response_delete_tunnel.json["message"] == "OK"
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel is None
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_not_found(
|
||||
mocker: MockerFixture,
|
||||
app: Any,
|
||||
session: Session,
|
||||
client: Any,
|
||||
full_api_access: None,
|
||||
) -> None:
|
||||
"""
|
||||
Test that we cannot delete a tunnel that does not exist
|
||||
"""
|
||||
with app.app_context():
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.api import DatabaseRestApi
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
DatabaseRestApi.datamodel._session = session
|
||||
|
||||
# create table for databases
|
||||
Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member
|
||||
|
||||
# Create our Database
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="gsheets://",
|
||||
encrypted_extra=json.dumps(
|
||||
{
|
||||
"service_account_info": {
|
||||
"type": "service_account",
|
||||
"project_id": "black-sanctum-314419",
|
||||
"private_key_id": "259b0d419a8f840056158763ff54d8b08f7b8173",
|
||||
"private_key": "SECRET",
|
||||
"client_email": "google-spreadsheets-demo-servi@black-sanctum-314419.iam.gserviceaccount.com", # noqa: E501
|
||||
"client_id": "SSH_TUNNEL_CREDENTIALS_CLIENT",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/google-spreadsheets-demo-servi%40black-sanctum-314419.iam.gserviceaccount.com",
|
||||
},
|
||||
}
|
||||
),
|
||||
)
|
||||
db.session.add(database)
|
||||
db.session.commit()
|
||||
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
mocker.patch(
|
||||
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
|
||||
# Create our SSHTunnel
|
||||
tunnel = SSHTunnel(
|
||||
database_id=1,
|
||||
database=database,
|
||||
)
|
||||
|
||||
db.session.add(tunnel)
|
||||
db.session.commit()
|
||||
|
||||
# Delete the recently created SSHTunnel
|
||||
response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/")
|
||||
assert response_delete_tunnel.json["message"] == "Not found"
|
||||
|
||||
# Get our recently created SSHTunnel
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(1)
|
||||
assert response_tunnel
|
||||
assert isinstance(response_tunnel, SSHTunnel)
|
||||
assert 1 == response_tunnel.database_id
|
||||
|
||||
response_tunnel = DatabaseDAO.get_ssh_tunnel(2)
|
||||
assert response_tunnel is None
|
||||
|
||||
|
||||
def test_apply_dynamic_database_filter(
|
||||
mocker: MockerFixture,
|
||||
app: Any,
|
||||
@@ -698,10 +545,6 @@ def test_apply_dynamic_database_filter(
|
||||
# mock the lookup so that we don't need to include the driver
|
||||
mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets")
|
||||
mocker.patch("superset.utils.log.DBEventLogger.log")
|
||||
mocker.patch(
|
||||
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
|
||||
return_value=False,
|
||||
)
|
||||
|
||||
def _base_filter(query):
|
||||
from superset.models.core import Database
|
||||
|
||||
@@ -29,13 +29,12 @@ def test_add_permissions(mocker: MockerFixture) -> None:
|
||||
database.db_engine_spec.supports_catalog = True
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
||||
database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]]
|
||||
ssh_tunnel = mocker.MagicMock()
|
||||
add_permission_view_menu = mocker.patch(
|
||||
"superset.commands.database.importers.v1.utils.security_manager."
|
||||
"add_permission_view_menu"
|
||||
)
|
||||
|
||||
add_permissions(database, ssh_tunnel)
|
||||
add_permissions(database)
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
@@ -60,13 +59,12 @@ def test_add_permissions_get_default_catalog(mocker: MockerFixture):
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2"]
|
||||
database.get_default_catalog.return_value = "catalog1"
|
||||
database.get_all_schema_names.side_effect = [["schema1"], ["schema2"]]
|
||||
ssh_tunnel = mocker.MagicMock()
|
||||
add_permission_view_menu = mocker.patch(
|
||||
"superset.commands.database.importers.v1.utils.security_manager."
|
||||
"add_permission_view_menu"
|
||||
)
|
||||
|
||||
add_permissions(database, ssh_tunnel)
|
||||
add_permissions(database)
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
@@ -88,13 +86,12 @@ def test_add_permissions_handle_failures(mocker: MockerFixture) -> None:
|
||||
database.db_engine_spec.supports_catalog = True
|
||||
database.get_all_catalog_names.return_value = ["catalog1", "catalog2", "catalog3"]
|
||||
database.get_all_schema_names.side_effect = [["schema1"], Exception, ["schema3"]]
|
||||
ssh_tunnel = mocker.MagicMock()
|
||||
add_permission_view_menu = mocker.patch(
|
||||
"superset.commands.database.importers.v1.utils.security_manager."
|
||||
"add_permission_view_menu"
|
||||
)
|
||||
|
||||
add_permissions(database, ssh_tunnel)
|
||||
add_permissions(database)
|
||||
|
||||
add_permission_view_menu.assert_has_calls(
|
||||
[
|
||||
|
||||
@@ -54,7 +54,9 @@ def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
database = DatabaseDAO.find_by_id(1, skip_base_filter=True)
|
||||
assert database is not None
|
||||
result = database.ssh_tunnel
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
@@ -64,6 +66,7 @@ def test_database_get_ssh_tunnel(session_with_data: Session) -> None:
|
||||
def test_database_get_ssh_tunnel_not_found(session_with_data: Session) -> None:
|
||||
from superset.daos.database import DatabaseDAO
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(2)
|
||||
database = DatabaseDAO.find_by_id(2, skip_base_filter=True)
|
||||
result = database.ssh_tunnel if database else None
|
||||
|
||||
assert result is None
|
||||
|
||||
@@ -117,10 +117,7 @@ def test_database_filter(mocker: MockerFixture) -> None:
|
||||
engine,
|
||||
compile_kwargs={"literal_binds": True},
|
||||
)
|
||||
space = " " # pre-commit removes trailing spaces...
|
||||
assert (
|
||||
str(compiled_query)
|
||||
== f"""SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk{space}
|
||||
FROM dbs{space}
|
||||
WHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')""" # noqa: S608, E501
|
||||
== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
|
||||
)
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
@@ -1,148 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command(session: Session) -> None:
|
||||
from superset import db
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||
)
|
||||
|
||||
properties = {
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(database, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_invalid_params(session: Session) -> None:
|
||||
from superset import db
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost:5432/db",
|
||||
)
|
||||
|
||||
# If we are trying to create a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
properties = {
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"private_key_password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(database, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_no_port(session: Session) -> None:
|
||||
"""
|
||||
Test that SSH Tunnel can be created without explicit port but with a default one.
|
||||
"""
|
||||
from superset import db
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="postgresql://u:p@localhost/db",
|
||||
)
|
||||
|
||||
properties = {
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
result = CreateSSHTunnelCommand(database, properties).run()
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
|
||||
|
||||
def test_create_ssh_tunnel_command_no_port_no_default(session: Session) -> None:
|
||||
"""
|
||||
Test that error is raised when creating SSH Tunnel without explicit/default ports.
|
||||
"""
|
||||
from superset import db
|
||||
from superset.commands.database.ssh_tunnel.create import CreateSSHTunnelCommand
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = db.session.get_bind()
|
||||
Database.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(
|
||||
id=1,
|
||||
database_name="my_database",
|
||||
sqlalchemy_uri="weird+db://u:p@localhost/db",
|
||||
)
|
||||
|
||||
properties = {
|
||||
"database": database,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
}
|
||||
|
||||
command = CreateSSHTunnelCommand(database, properties)
|
||||
|
||||
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == (
|
||||
"A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
@@ -1,73 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
database = Database(database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=database.id,
|
||||
database=database,
|
||||
)
|
||||
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_delete_ssh_tunnel_command(
|
||||
mocker: MockerFixture, session_with_data: Session
|
||||
) -> None:
|
||||
from superset.commands.database.ssh_tunnel.delete import DeleteSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
mocker.patch(
|
||||
"superset.commands.database.ssh_tunnel.delete.is_feature_enabled",
|
||||
return_value=True,
|
||||
)
|
||||
DeleteSSHTunnelCommand(1).run()
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result is None
|
||||
@@ -1,155 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
from collections.abc import Iterator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm.session import Session
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session_with_data(request, session: Session) -> Iterator[Session]:
|
||||
from superset.connectors.sqla.models import SqlaTable
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
engine = session.get_bind()
|
||||
SqlaTable.metadata.create_all(engine) # pylint: disable=no-member
|
||||
|
||||
sqlalchemy_uri = getattr(request, "param", "postgresql://u:p@localhost:5432/db")
|
||||
database = Database(database_name="my_database", sqlalchemy_uri=sqlalchemy_uri)
|
||||
sqla_table = SqlaTable(
|
||||
table_name="my_sqla_table",
|
||||
columns=[],
|
||||
metrics=[],
|
||||
database=database,
|
||||
)
|
||||
ssh_tunnel = SSHTunnel(
|
||||
database_id=database.id, database=database, server_address="Test"
|
||||
)
|
||||
|
||||
session.add(database)
|
||||
session.add(sqla_table)
|
||||
session.add(ssh_tunnel)
|
||||
session.flush()
|
||||
yield session
|
||||
session.rollback()
|
||||
|
||||
|
||||
def test_update_shh_tunnel_command(session_with_data: Session) -> None:
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test2"}
|
||||
UpdateSSHTunnelCommand(1, update_payload).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert "Test2" == result.server_address
|
||||
|
||||
|
||||
def test_update_shh_tunnel_invalid_params(session_with_data: Session) -> None:
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
# If we are trying to update a tunnel with a private_key_password
|
||||
# then a private_key is mandatory
|
||||
update_payload = {"private_key_password": "pass"}
|
||||
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||
|
||||
with pytest.raises(SSHTunnelInvalidError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == ("SSH Tunnel parameters are invalid.")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"session_with_data", ["postgresql://u:p@localhost/testdb"], indirect=True
|
||||
)
|
||||
def test_update_shh_tunnel_no_port(session_with_data: Session) -> None:
|
||||
"""
|
||||
Test that SSH Tunnel can be updated without explicit port but with a default one.
|
||||
"""
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test2"}
|
||||
UpdateSSHTunnelCommand(1, update_payload).run()
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert "Test2" == result.server_address
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"session_with_data", ["weird+db://u:p@localhost/testdb"], indirect=True
|
||||
)
|
||||
def test_update_shh_tunnel_no_port_no_default(session_with_data: Session) -> None:
|
||||
"""
|
||||
Test that error is raised when updating SSH Tunnel without explicit/default ports.
|
||||
"""
|
||||
from superset.commands.database.ssh_tunnel.update import UpdateSSHTunnelCommand
|
||||
from superset.daos.database import DatabaseDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
result = DatabaseDAO.get_ssh_tunnel(1)
|
||||
|
||||
assert result
|
||||
assert isinstance(result, SSHTunnel)
|
||||
assert 1 == result.database_id
|
||||
assert "Test" == result.server_address
|
||||
|
||||
update_payload = {"server_address": "Test update"}
|
||||
command = UpdateSSHTunnelCommand(1, update_payload)
|
||||
|
||||
with pytest.raises(SSHTunnelDatabasePortError) as excinfo:
|
||||
command.run()
|
||||
assert str(excinfo.value) == (
|
||||
"A database port is required when connecting via SSH Tunnel."
|
||||
)
|
||||
@@ -1,37 +0,0 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
|
||||
def test_create_ssh_tunnel():
|
||||
from superset.daos.database import SSHTunnelDAO
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://")
|
||||
|
||||
result = SSHTunnelDAO.create(
|
||||
attributes={
|
||||
"database_id": database.id,
|
||||
"server_address": "123.132.123.1",
|
||||
"server_port": "3005",
|
||||
"username": "foo",
|
||||
"password": "bar",
|
||||
},
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SSHTunnel)
|
||||
@@ -333,7 +333,7 @@ def test_get_all_catalog_names(mocker: MockerFixture) -> None:
|
||||
inspector.bind.execute.return_value = [("examples",), ("other",)]
|
||||
|
||||
assert database.get_all_catalog_names(force=True) == {"examples", "other"}
|
||||
get_inspector.assert_called_with(ssh_tunnel=None)
|
||||
get_inspector.assert_called_with()
|
||||
|
||||
|
||||
def test_get_all_schema_names_needs_oauth2(mocker: MockerFixture) -> None:
|
||||
|
||||
Reference in New Issue
Block a user