feat(ssh_tunnel): Import/Export Databases with SSHTunnel credentials (#23099)

This commit is contained in:
Antonio Rivero
2023-02-24 14:36:21 -03:00
committed by GitHub
parent 967383853c
commit 3484e8ea7b
30 changed files with 2039 additions and 50 deletions

View File

@@ -1095,6 +1095,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
overwrite:
description: overwrite existing databases?
type: boolean
ssh_tunnel_passwords:
description: >-
JSON map of passwords for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the password should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_password"}`.
type: string
ssh_tunnel_private_keys:
description: >-
JSON map of private_keys for each ssh_tunnel associated to a
featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key"}`.
type: string
ssh_tunnel_private_key_passwords:
description: >-
JSON map of private_key_passwords for each ssh_tunnel associated
to a featured database in the ZIP file. If the ZIP includes a
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
responses:
200:
description: Database import result
@@ -1131,9 +1155,29 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
else None
)
overwrite = request.form.get("overwrite") == "true"
ssh_tunnel_passwords = (
json.loads(request.form["ssh_tunnel_passwords"])
if "ssh_tunnel_passwords" in request.form
else None
)
ssh_tunnel_private_keys = (
json.loads(request.form["ssh_tunnel_private_keys"])
if "ssh_tunnel_private_keys" in request.form
else None
)
ssh_tunnel_priv_key_passwords = (
json.loads(request.form["ssh_tunnel_private_key_passwords"])
if "ssh_tunnel_private_key_passwords" in request.form
else None
)
command = ImportDatabasesCommand(
contents, passwords=passwords, overwrite=overwrite
contents,
passwords=passwords,
overwrite=overwrite,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
)
command.run()
return self.response(200, message="OK")

View File

@@ -28,6 +28,7 @@ from superset.commands.export.models import ExportModelsCommand
from superset.models.core import Database
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
from superset.utils.ssh_tunnel import mask_password_info
logger = logging.getLogger(__name__)
@@ -87,6 +88,15 @@ class ExportDatabasesCommand(ExportModelsCommand):
"schemas_allowed_for_file_upload"
)
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=False,
)
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
payload["version"] = EXPORT_VERSION
file_content = yaml.safe_dump(payload, sort_keys=False)

View File

@@ -20,6 +20,7 @@ from typing import Any, Dict
from sqlalchemy.orm import Session
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.models.core import Database
@@ -42,8 +43,15 @@ def import_database(
# TODO (betodealmeida): move this logic to import_from_dict
config["extra"] = json.dumps(config["extra"])
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)
database = Database.import_from_dict(session, config, recursive=False)
if database.id is None:
session.flush()
if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
return database

View File

@@ -19,7 +19,7 @@
import inspect
import json
from typing import Any, Dict
from typing import Any, Dict, List
from flask import current_app
from flask_babel import lazy_gettext as _
@@ -28,9 +28,14 @@ from marshmallow.validate import Length, ValidationError
from marshmallow_enum import EnumField
from sqlalchemy import MetaData
from superset import db
from superset import db, is_feature_enabled
from superset.constants import PASSWORD_MASK
from superset.databases.commands.exceptions import DatabaseInvalidError
from superset.databases.ssh_tunnel.commands.exceptions import (
SSHTunnelingNotEnabledError,
SSHTunnelInvalidCredentials,
SSHTunnelMissingCredentials,
)
from superset.databases.utils import make_url_safe
from superset.db_engine_specs import get_engine_spec
from superset.exceptions import CertificateException, SupersetSecurityException
@@ -706,6 +711,7 @@ class ImportV1DatabaseSchema(Schema):
version = fields.String(required=True)
is_managed_externally = fields.Boolean(allow_none=True, default=False)
external_url = fields.String(allow_none=True)
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
@validates_schema
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
@@ -720,6 +726,68 @@ class ImportV1DatabaseSchema(Schema):
if password == PASSWORD_MASK and data.get("password") is None:
raise ValidationError("Must provide a password for the database")
@validates_schema
def validate_ssh_tunnel_credentials(
self, data: Dict[str, Any], **kwargs: Any
) -> None:
"""If ssh_tunnel has a masked credentials, credentials are required"""
uuid = data["uuid"]
existing = db.session.query(Database).filter_by(uuid=uuid).first()
if existing:
return
# Our DB has a ssh_tunnel in it
if ssh_tunnel := data.get("ssh_tunnel"):
# Login methods are (only one from these options):
# 1. password
# 2. private_key + private_key_password
# Based on the data passed we determine what info is required.
# You cannot mix the credentials from both methods.
if not is_feature_enabled("SSH_TUNNELING"):
# You are trying to import a Database with SSH Tunnel
# But the Feature Flag is not enabled.
raise SSHTunnelingNotEnabledError()
password = ssh_tunnel.get("password")
private_key = ssh_tunnel.get("private_key")
private_key_password = ssh_tunnel.get("private_key_password")
if password is not None:
# Login method #1 (Password)
if private_key is not None or private_key_password is not None:
# You cannot have a mix of login methods
raise SSHTunnelInvalidCredentials()
if password == PASSWORD_MASK:
raise ValidationError("Must provide a password for the ssh tunnel")
if password is None:
# If the SSH Tunnel we're importing has no password then it must
# have a private_key + private_key_password combination
if private_key is None and private_key_password is None:
# We have found nothing related to other credentials
raise SSHTunnelMissingCredentials()
# We need to ask for the missing properties of our method # 2
# Some times the property is just missing
# or there're times where it's masked.
# If both are masked, we need to return a list of errors
# so the UI ask for both fields at the same time if needed
exception_messages: List[str] = []
if private_key is None or private_key == PASSWORD_MASK:
# If we get here we need to ask for the private key
exception_messages.append(
"Must provide a private key for the ssh tunnel"
)
if (
private_key_password is None
or private_key_password == PASSWORD_MASK
):
# If we get here we need to ask for the private key password
exception_messages.append(
"Must provide a private key password for the ssh tunnel"
)
if exception_messages:
# We can ask for just one field or both if masked, if both
# are empty, SSHTunnelMissingCredentials was already raised
raise ValidationError(exception_messages)
return
class EncryptedField: # pylint: disable=too-few-public-methods
"""

View File

@@ -57,3 +57,11 @@ class SSHTunnelRequiredFieldValidationError(ValidationError):
[_("Field is required")],
field_name=field_name,
)
class SSHTunnelMissingCredentials(CommandInvalidError):
message = _("Must provide credentials for the SSH Tunnel")
class SSHTunnelInvalidCredentials(CommandInvalidError):
message = _("Cannot have multiple credentials for the SSH Tunnel")

View File

@@ -68,6 +68,19 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
)
export_fields = [
"server_address",
"server_port",
"username",
"password",
"private_key",
"private_key_password",
]
extra_import_fields = [
"database_id",
]
@property
def data(self) -> Dict[str, Any]:
output = {