chore: Remove database ID dependency for SSH Tunnel creation (#26989)

This commit is contained in:
Geido
2024-02-07 18:03:19 +02:00
committed by GitHub
parent 43e1dc49c9
commit d8e26cfff1
6 changed files with 51 additions and 73 deletions

View File

@@ -28,39 +28,32 @@ from superset.commands.database.ssh_tunnel.exceptions import (
)
from superset.daos.database import SSHTunnelDAO
from superset.daos.exceptions import DAOCreateFailedError
from superset.extensions import db, event_logger
from superset.extensions import event_logger
from superset.models.core import Database
logger = logging.getLogger(__name__)
class CreateSSHTunnelCommand(BaseCommand):
def __init__(self, database_id: int, data: dict[str, Any]):
def __init__(self, database: Database, data: dict[str, Any]):
self._properties = data.copy()
self._properties["database_id"] = database_id
self._properties["database"] = database
def run(self) -> Model:
try:
# Start nested transaction since we are always creating the tunnel
# through a DB command (Create or Update). Without this, we cannot
# safely rollback changes to databases if any, i.e, things like
# test_do_not_create_database_if_ssh_tunnel_creation_fails test will fail
db.session.begin_nested()
self.validate()
return SSHTunnelDAO.create(attributes=self._properties, commit=False)
ssh_tunnel = SSHTunnelDAO.create(attributes=self._properties, commit=False)
return ssh_tunnel
except DAOCreateFailedError as ex:
# Rollback nested transaction
db.session.rollback()
raise SSHTunnelCreateFailedError() from ex
except SSHTunnelInvalidError as ex:
# Rollback nested transaction
db.session.rollback()
raise ex
def validate(self) -> None:
# TODO(hughhh): check to make sure the server port is not localhost
# using the config.SSH_TUNNEL_MANAGER
exceptions: list[ValidationError] = []
database_id: Optional[int] = self._properties.get("database_id")
server_address: Optional[str] = self._properties.get("server_address")
server_port: Optional[int] = self._properties.get("server_port")
username: Optional[str] = self._properties.get("username")
@@ -68,8 +61,6 @@ class CreateSSHTunnelCommand(BaseCommand):
private_key_password: Optional[str] = self._properties.get(
"private_key_password"
)
if not database_id:
exceptions.append(SSHTunnelRequiredFieldValidationError("database_id"))
if not server_address:
exceptions.append(SSHTunnelRequiredFieldValidationError("server_address"))
if not server_port: