mirror of
https://github.com/apache/superset.git
synced 2026-04-19 08:04:53 +00:00
feat(ssh_tunnel): Import/Export Databases with SSHTunnel credentials (#23099)
This commit is contained in:
@@ -1095,6 +1095,30 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
overwrite:
|
||||
description: overwrite existing databases?
|
||||
type: boolean
|
||||
ssh_tunnel_passwords:
|
||||
description: >-
|
||||
JSON map of passwords for each ssh_tunnel associated to a
|
||||
featured database in the ZIP file. If the ZIP includes a
|
||||
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
|
||||
the password should be provided in the following format:
|
||||
`{"databases/MyDatabase.yaml": "my_password"}`.
|
||||
type: string
|
||||
ssh_tunnel_private_keys:
|
||||
description: >-
|
||||
JSON map of private_keys for each ssh_tunnel associated to a
|
||||
featured database in the ZIP file. If the ZIP includes a
|
||||
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
|
||||
the private_key should be provided in the following format:
|
||||
`{"databases/MyDatabase.yaml": "my_private_key"}`.
|
||||
type: string
|
||||
ssh_tunnel_private_key_passwords:
|
||||
description: >-
|
||||
JSON map of private_key_passwords for each ssh_tunnel associated
|
||||
to a featured database in the ZIP file. If the ZIP includes a
|
||||
ssh_tunnel config in the path `databases/MyDatabase.yaml`,
|
||||
the private_key should be provided in the following format:
|
||||
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
|
||||
type: string
|
||||
responses:
|
||||
200:
|
||||
description: Database import result
|
||||
@@ -1131,9 +1155,29 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
else None
|
||||
)
|
||||
overwrite = request.form.get("overwrite") == "true"
|
||||
ssh_tunnel_passwords = (
|
||||
json.loads(request.form["ssh_tunnel_passwords"])
|
||||
if "ssh_tunnel_passwords" in request.form
|
||||
else None
|
||||
)
|
||||
ssh_tunnel_private_keys = (
|
||||
json.loads(request.form["ssh_tunnel_private_keys"])
|
||||
if "ssh_tunnel_private_keys" in request.form
|
||||
else None
|
||||
)
|
||||
ssh_tunnel_priv_key_passwords = (
|
||||
json.loads(request.form["ssh_tunnel_private_key_passwords"])
|
||||
if "ssh_tunnel_private_key_passwords" in request.form
|
||||
else None
|
||||
)
|
||||
|
||||
command = ImportDatabasesCommand(
|
||||
contents, passwords=passwords, overwrite=overwrite
|
||||
contents,
|
||||
passwords=passwords,
|
||||
overwrite=overwrite,
|
||||
ssh_tunnel_passwords=ssh_tunnel_passwords,
|
||||
ssh_tunnel_private_keys=ssh_tunnel_private_keys,
|
||||
ssh_tunnel_priv_key_passwords=ssh_tunnel_priv_key_passwords,
|
||||
)
|
||||
command.run()
|
||||
return self.response(200, message="OK")
|
||||
|
||||
@@ -28,6 +28,7 @@ from superset.commands.export.models import ExportModelsCommand
|
||||
from superset.models.core import Database
|
||||
from superset.utils.dict_import_export import EXPORT_VERSION
|
||||
from superset.utils.file import get_filename
|
||||
from superset.utils.ssh_tunnel import mask_password_info
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -87,6 +88,15 @@ class ExportDatabasesCommand(ExportModelsCommand):
|
||||
"schemas_allowed_for_file_upload"
|
||||
)
|
||||
|
||||
if ssh_tunnel := DatabaseDAO.get_ssh_tunnel(model.id):
|
||||
ssh_tunnel_payload = ssh_tunnel.export_to_dict(
|
||||
recursive=False,
|
||||
include_parent_ref=False,
|
||||
include_defaults=True,
|
||||
export_uuids=False,
|
||||
)
|
||||
payload["ssh_tunnel"] = mask_password_info(ssh_tunnel_payload)
|
||||
|
||||
payload["version"] = EXPORT_VERSION
|
||||
|
||||
file_content = yaml.safe_dump(payload, sort_keys=False)
|
||||
|
||||
@@ -20,6 +20,7 @@ from typing import Any, Dict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
from superset.models.core import Database
|
||||
|
||||
|
||||
@@ -42,8 +43,15 @@ def import_database(
|
||||
# TODO (betodealmeida): move this logic to import_from_dict
|
||||
config["extra"] = json.dumps(config["extra"])
|
||||
|
||||
# Before it gets removed in import_from_dict
|
||||
ssh_tunnel = config.pop("ssh_tunnel", None)
|
||||
|
||||
database = Database.import_from_dict(session, config, recursive=False)
|
||||
if database.id is None:
|
||||
session.flush()
|
||||
|
||||
if ssh_tunnel:
|
||||
ssh_tunnel["database_id"] = database.id
|
||||
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
|
||||
|
||||
return database
|
||||
|
||||
@@ -19,7 +19,7 @@
|
||||
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Dict
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from flask import current_app
|
||||
from flask_babel import lazy_gettext as _
|
||||
@@ -28,9 +28,14 @@ from marshmallow.validate import Length, ValidationError
|
||||
from marshmallow_enum import EnumField
|
||||
from sqlalchemy import MetaData
|
||||
|
||||
from superset import db
|
||||
from superset import db, is_feature_enabled
|
||||
from superset.constants import PASSWORD_MASK
|
||||
from superset.databases.commands.exceptions import DatabaseInvalidError
|
||||
from superset.databases.ssh_tunnel.commands.exceptions import (
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidCredentials,
|
||||
SSHTunnelMissingCredentials,
|
||||
)
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.db_engine_specs import get_engine_spec
|
||||
from superset.exceptions import CertificateException, SupersetSecurityException
|
||||
@@ -706,6 +711,7 @@ class ImportV1DatabaseSchema(Schema):
|
||||
version = fields.String(required=True)
|
||||
is_managed_externally = fields.Boolean(allow_none=True, default=False)
|
||||
external_url = fields.String(allow_none=True)
|
||||
ssh_tunnel = fields.Nested(DatabaseSSHTunnel, allow_none=True)
|
||||
|
||||
@validates_schema
|
||||
def validate_password(self, data: Dict[str, Any], **kwargs: Any) -> None:
|
||||
@@ -720,6 +726,68 @@ class ImportV1DatabaseSchema(Schema):
|
||||
if password == PASSWORD_MASK and data.get("password") is None:
|
||||
raise ValidationError("Must provide a password for the database")
|
||||
|
||||
@validates_schema
|
||||
def validate_ssh_tunnel_credentials(
|
||||
self, data: Dict[str, Any], **kwargs: Any
|
||||
) -> None:
|
||||
"""If ssh_tunnel has a masked credentials, credentials are required"""
|
||||
uuid = data["uuid"]
|
||||
existing = db.session.query(Database).filter_by(uuid=uuid).first()
|
||||
if existing:
|
||||
return
|
||||
|
||||
# Our DB has a ssh_tunnel in it
|
||||
if ssh_tunnel := data.get("ssh_tunnel"):
|
||||
# Login methods are (only one from these options):
|
||||
# 1. password
|
||||
# 2. private_key + private_key_password
|
||||
# Based on the data passed we determine what info is required.
|
||||
# You cannot mix the credentials from both methods.
|
||||
if not is_feature_enabled("SSH_TUNNELING"):
|
||||
# You are trying to import a Database with SSH Tunnel
|
||||
# But the Feature Flag is not enabled.
|
||||
raise SSHTunnelingNotEnabledError()
|
||||
password = ssh_tunnel.get("password")
|
||||
private_key = ssh_tunnel.get("private_key")
|
||||
private_key_password = ssh_tunnel.get("private_key_password")
|
||||
if password is not None:
|
||||
# Login method #1 (Password)
|
||||
if private_key is not None or private_key_password is not None:
|
||||
# You cannot have a mix of login methods
|
||||
raise SSHTunnelInvalidCredentials()
|
||||
if password == PASSWORD_MASK:
|
||||
raise ValidationError("Must provide a password for the ssh tunnel")
|
||||
if password is None:
|
||||
# If the SSH Tunnel we're importing has no password then it must
|
||||
# have a private_key + private_key_password combination
|
||||
if private_key is None and private_key_password is None:
|
||||
# We have found nothing related to other credentials
|
||||
raise SSHTunnelMissingCredentials()
|
||||
# We need to ask for the missing properties of our method # 2
|
||||
# Some times the property is just missing
|
||||
# or there're times where it's masked.
|
||||
# If both are masked, we need to return a list of errors
|
||||
# so the UI ask for both fields at the same time if needed
|
||||
exception_messages: List[str] = []
|
||||
if private_key is None or private_key == PASSWORD_MASK:
|
||||
# If we get here we need to ask for the private key
|
||||
exception_messages.append(
|
||||
"Must provide a private key for the ssh tunnel"
|
||||
)
|
||||
if (
|
||||
private_key_password is None
|
||||
or private_key_password == PASSWORD_MASK
|
||||
):
|
||||
# If we get here we need to ask for the private key password
|
||||
exception_messages.append(
|
||||
"Must provide a private key password for the ssh tunnel"
|
||||
)
|
||||
if exception_messages:
|
||||
# We can ask for just one field or both if masked, if both
|
||||
# are empty, SSHTunnelMissingCredentials was already raised
|
||||
raise ValidationError(exception_messages)
|
||||
return
|
||||
|
||||
|
||||
class EncryptedField: # pylint: disable=too-few-public-methods
|
||||
"""
|
||||
|
||||
@@ -57,3 +57,11 @@ class SSHTunnelRequiredFieldValidationError(ValidationError):
|
||||
[_("Field is required")],
|
||||
field_name=field_name,
|
||||
)
|
||||
|
||||
|
||||
class SSHTunnelMissingCredentials(CommandInvalidError):
|
||||
message = _("Must provide credentials for the SSH Tunnel")
|
||||
|
||||
|
||||
class SSHTunnelInvalidCredentials(CommandInvalidError):
|
||||
message = _("Cannot have multiple credentials for the SSH Tunnel")
|
||||
|
||||
@@ -68,6 +68,19 @@ class SSHTunnel(Model, AuditMixinNullable, ExtraJSONMixin, ImportExportMixin):
|
||||
EncryptedType(sa.String, app_config["SECRET_KEY"]), nullable=True
|
||||
)
|
||||
|
||||
export_fields = [
|
||||
"server_address",
|
||||
"server_port",
|
||||
"username",
|
||||
"password",
|
||||
"private_key",
|
||||
"private_key_password",
|
||||
]
|
||||
|
||||
extra_import_fields = [
|
||||
"database_id",
|
||||
]
|
||||
|
||||
@property
|
||||
def data(self) -> Dict[str, Any]:
|
||||
output = {
|
||||
|
||||
Reference in New Issue
Block a user