mirror of
https://github.com/apache/superset.git
synced 2026-06-11 18:49:15 +00:00
Compare commits
5 Commits
url-param-
...
feat/ssh-t
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ac207678ae | ||
|
|
afeac5abf3 | ||
|
|
21e40f73cb | ||
|
|
24a0ad6d46 | ||
|
|
9cf5e5387c |
14
UPDATING.md
14
UPDATING.md
@@ -34,6 +34,20 @@ The embedded dashboard page now validates the origin of incoming `postMessage` e
|
||||
|
||||
Enforcement only applies when the Allowed Domains list is non-empty. If the list is empty (the default), any origin is accepted, so there is no behavior change for embeds that did not configure Allowed Domains.
|
||||
|
||||
### Opt-in SSH tunnel server host key verification
|
||||
|
||||
SSH tunnels can now optionally pin the expected SSH server host key as a defense-in-depth measure against man-in-the-middle attacks. paramiko's transport performs no known-hosts checking by default, so previously the SSH server's identity was not verified. This feature is opt-in and off by default; existing tunnels are unaffected.
|
||||
|
||||
- A new nullable `server_host_key` column on the `ssh_tunnels` table stores the expected host key in authorized-key form (e.g. `ssh-ed25519 AAAA...`). It is a public key and is stored in plaintext. It can be set via the SSH tunnel POST/PUT payloads (`ssh_tunnel.server_host_key`).
|
||||
- When a tunnel has `server_host_key` set, Superset connects to the SSH server, reads the host key it presents, and rejects the tunnel if it does not match.
|
||||
- A new config flag `SSH_TUNNEL_STRICT_HOST_KEY_CHECKING` (default `False`) controls fail-closed behavior. When `True`, every tunnel must declare a `server_host_key`; a tunnel without one is rejected.
|
||||
|
||||
Runbook to adopt:
|
||||
|
||||
1. Capture the SSH server's host key, e.g. `ssh-keyscan -t ed25519 ssh.example.com` (verify it out-of-band).
|
||||
2. Set that value on the tunnel's `server_host_key` (via the database/SSH tunnel API or UI payload).
|
||||
3. Optionally set `SSH_TUNNEL_STRICT_HOST_KEY_CHECKING = True` in `superset_config.py` to require host-key verification on all tunnels.
|
||||
|
||||
### Dataset import validates catalog against the target connection
|
||||
|
||||
Importing a dataset now validates the `catalog` field against the target database connection. When the connection has multi-catalog disabled (`allow_multi_catalog` off) and the dataset's catalog is not the connection's default catalog, the import fails instead of silently persisting the non-default catalog. This matches the validation already enforced on the dataset update path and prevents imported datasets from querying an unintended database.
|
||||
|
||||
@@ -33,6 +33,7 @@ from superset.commands.database.exceptions import (
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelCreateFailedError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelInvalidError,
|
||||
)
|
||||
@@ -75,6 +76,7 @@ class CreateDatabaseCommand(BaseCommand):
|
||||
SupersetErrorsException,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=f"db_creation_failed.{ex.__class__.__name__}",
|
||||
|
||||
@@ -75,3 +75,9 @@ class SSHTunnelMissingCredentials(CommandInvalidError, SSHTunnelError): # noqa:
|
||||
|
||||
class SSHTunnelInvalidCredentials(CommandInvalidError, SSHTunnelError): # noqa: N818
|
||||
message = _("Cannot have multiple credentials for the SSH Tunnel")
|
||||
|
||||
|
||||
class SSHTunnelHostKeyVerificationError(CommandInvalidError, SSHTunnelError):
    """
    Raised when an SSH server's host key cannot be verified.

    ``message`` below is the translatable default text; a raiser may pass a
    more specific ``message=`` when constructing the exception.
    """

    message = _(
        "The SSH server host key could not be verified against the expected key."
    )
|
||||
|
||||
@@ -29,6 +29,7 @@ from superset.commands.database.exceptions import (
|
||||
)
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.commands.database.utils import ping
|
||||
@@ -221,7 +222,11 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
engine=engine_name,
|
||||
)
|
||||
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
|
||||
except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
|
||||
except (
|
||||
SupersetTimeoutException,
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
) as ex:
|
||||
event_logger.log_with_context(
|
||||
action=get_log_connection_action(
|
||||
"test_connection_error",
|
||||
@@ -230,7 +235,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
|
||||
),
|
||||
engine=engine_name,
|
||||
)
|
||||
# bubble up the exception to return proper status code
|
||||
# bubble up the exception (preserving its specific message and status)
|
||||
# instead of flattening it into a generic connection failure
|
||||
raise
|
||||
except Exception as ex:
|
||||
if not database:
|
||||
|
||||
@@ -838,6 +838,15 @@ SSH_TUNNEL_TIMEOUT_SEC = 10.0
|
||||
#: Timeout (seconds) for transport socket (``socket.settimeout``)
|
||||
SSH_TUNNEL_PACKET_TIMEOUT_SEC = 1.0
|
||||
|
||||
#: Opt-in defense-in-depth: when enabled, every SSH tunnel must declare an expected
|
||||
#: server host key (``server_host_key`` on the tunnel) and the SSH server's presented
|
||||
#: host key is verified against it before the tunnel is opened. A mismatch, or a
|
||||
#: missing expected key while this flag is enabled, fails closed and the tunnel is
|
||||
#: rejected. When disabled (the default), tunnels without a ``server_host_key`` open
|
||||
#: without host-key verification, preserving existing behavior; tunnels that do set a
|
||||
#: ``server_host_key`` are still verified regardless of this flag.
|
||||
SSH_TUNNEL_STRICT_HOST_KEY_CHECKING: bool = False
|
||||
|
||||
|
||||
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
|
||||
DEFAULT_FEATURE_FLAGS.update(
|
||||
|
||||
@@ -56,6 +56,7 @@ from superset.commands.database.importers.dispatcher import ImportDatabasesComma
|
||||
from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
SSHTunnelingNotEnabledError,
|
||||
)
|
||||
from superset.commands.database.sync_permissions import SyncPermissionsCommand
|
||||
@@ -483,7 +484,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
except (
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
except SupersetException as ex:
|
||||
return self.response(ex.status, message=ex.message)
|
||||
@@ -568,7 +573,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
exc_info=True,
|
||||
)
|
||||
return self.response_422(message=str(ex))
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
except (
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>", methods=("DELETE",))
|
||||
@@ -1280,7 +1289,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
|
||||
try:
|
||||
TestConnectionDatabaseCommand(item).run()
|
||||
return self.response(200, message="OK")
|
||||
except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
|
||||
except (
|
||||
SSHTunnelingNotEnabledError,
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
) as ex:
|
||||
return self.response_400(message=str(ex))
|
||||
|
||||
@expose("/<int:pk>/related_objects/", methods=("GET",))
|
||||
|
||||
@@ -474,6 +474,22 @@ class DatabaseSSHTunnel(Schema):
|
||||
private_key = fields.String(required=False)
|
||||
private_key_password = fields.String(required=False)
|
||||
|
||||
# Optional expected SSH server host key in authorized-key form
|
||||
# (e.g. "ssh-rsa AAAA...", "ssh-ed25519 AAAA..."). When set, the SSH server's
|
||||
# presented host key is verified against it before the tunnel is opened. This is
|
||||
# a public key, so it is not sensitive and is not masked.
|
||||
server_host_key = fields.String(
|
||||
required=False,
|
||||
allow_none=True,
|
||||
metadata={
|
||||
"description": (
|
||||
"Expected SSH server host key in authorized-key form "
|
||||
"(e.g. 'ssh-ed25519 AAAA...'). When set, the server's host key is "
|
||||
"verified against it before the tunnel is opened."
|
||||
)
|
||||
},
|
||||
)
|
||||
|
||||
@validates_schema
|
||||
def validate_authentication(self, data: dict[str, Any], **kwargs: Any) -> None:
|
||||
errors: dict[str, str] = {}
|
||||
|
||||
@@ -72,6 +72,12 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
||||
encrypted_field_factory.create(Text), nullable=True
|
||||
)
|
||||
|
||||
# Optional expected SSH server host key, in authorized-key form
|
||||
# (e.g. "ssh-rsa AAAA...", "ssh-ed25519 AAAA..."). When set, the SSH server's
|
||||
# presented host key is verified against this value before the tunnel is opened.
|
||||
# This is a public key, so it is stored in plaintext (not encrypted).
|
||||
server_host_key = sa.Column(sa.Text, nullable=True)
|
||||
|
||||
export_fields = [
|
||||
"server_address",
|
||||
"server_port",
|
||||
@@ -79,6 +85,7 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
||||
"password",
|
||||
"private_key",
|
||||
"private_key_password",
|
||||
"server_host_key",
|
||||
]
|
||||
|
||||
extra_import_fields = [
|
||||
@@ -93,6 +100,9 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
|
||||
"server_port": self.server_port,
|
||||
"username": self.username,
|
||||
}
|
||||
if self.server_host_key is not None:
|
||||
# public key, not sensitive: returned in cleartext
|
||||
output["server_host_key"] = self.server_host_key
|
||||
if self.password is not None:
|
||||
output["password"] = PASSWORD_MASK
|
||||
if self.private_key is not None:
|
||||
|
||||
@@ -15,26 +15,57 @@
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import logging
|
||||
import socket
|
||||
from io import StringIO
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import paramiko
|
||||
import sshtunnel
|
||||
from flask import Flask
|
||||
from paramiko import RSAKey
|
||||
|
||||
from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelDatabasePortError,
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
)
|
||||
from superset.databases.utils import make_url_safe
|
||||
from superset.utils.class_utils import load_class_from_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from superset.databases.ssh_tunnel.models import SSHTunnel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _parse_authorized_key(authorized_key: str) -> paramiko.PKey:
    """
    Parse a host key in authorized-key form into a :class:`paramiko.PKey`.

    The accepted layout is ``"<type> <base64>[ comment]"``. Surrounding
    whitespace and anything after the second field (the optional comment) are
    ignored.

    :param authorized_key: the key text, e.g. ``"ssh-ed25519 AAAA..."``.
    :returns: the parsed public key object.
    :raises ValueError: if the value has fewer than two fields, the payload is
        not valid base64, or paramiko cannot build a key from the given type
        and bytes. Every parse failure is normalised to ``ValueError`` so that
        callers only need to handle a single exception type.
    """
    # ``split()`` without arguments already discards leading/trailing whitespace.
    fields = authorized_key.split()
    if len(fields) < 2:
        raise ValueError("Host key must be in 'ssh-<type> <base64>' form")
    key_type, key_b64 = fields[0], fields[1]

    try:
        # ``validate=True``: by default b64decode silently drops characters
        # outside the base64 alphabet, which would let a corrupted payload
        # decode to some other byte string instead of being rejected.
        key_bytes = base64.b64decode(key_b64, validate=True)
    except (binascii.Error, ValueError) as ex:
        raise ValueError("Host key base64 payload could not be decoded") from ex

    try:
        return paramiko.PKey.from_type_string(key_type, key_bytes)
    except Exception as ex:  # noqa: BLE001
        # paramiko signals an unknown key type or malformed key material with
        # its own exception classes (not ValueError); translate them so the
        # documented contract holds.
        raise ValueError(
            f"Host key of type {key_type!r} could not be parsed"
        ) from ex
|
||||
|
||||
|
||||
class SSHManager:
|
||||
def __init__(self, app: Flask) -> None:
    """
    Configure the manager from the Flask application's config.

    Stores the local bind address and the strict host-key flag (``False`` when
    the key is absent from the config) on the instance, and assigns the two
    timeout values onto attributes of the ``sshtunnel`` module.
    """
    super().__init__()
    config = app.config
    self.local_bind_address = config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
    self.strict_host_key_checking = config.get(
        "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING", False
    )
    # These are module attributes of ``sshtunnel``, not per-instance state.
    sshtunnel.TUNNEL_TIMEOUT = config["SSH_TUNNEL_TIMEOUT_SEC"]
    sshtunnel.SSH_TIMEOUT = config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]
|
||||
|
||||
@@ -48,6 +79,81 @@ class SSHManager:
|
||||
port=server.local_bind_port,
|
||||
)
|
||||
|
||||
def _verify_host_key(self, ssh_tunnel: "SSHTunnel") -> None:
    """
    Verify the SSH server's host key against the tunnel's expected key.

    Opt-in defense-in-depth against man-in-the-middle attacks.

    Behavior:

    - If the tunnel declares an expected ``server_host_key``, connect to the SSH
      server, read the host key it presents, and compare. On mismatch (or if the
      expected key cannot be parsed, or the server cannot be reached) raise
      :class:`SSHTunnelHostKeyVerificationError`.
    - If no expected key is set and ``SSH_TUNNEL_STRICT_HOST_KEY_CHECKING`` is
      enabled, fail closed and raise.
    - If no expected key is set and strict checking is disabled, do nothing,
      preserving existing (unverified) behavior.

    NOTE(review): the check uses its own short-lived connection, which is closed
    before this method returns. Whether the connection later opened for the
    tunnel itself is pinned to the same key is not visible here -- confirm.

    :param ssh_tunnel: tunnel whose ``server_address``, ``server_port`` and
        ``server_host_key`` are read.
    :raises SSHTunnelHostKeyVerificationError: on any verification failure
        described above.
    """
    expected_raw = ssh_tunnel.server_host_key

    if not expected_raw or not expected_raw.strip():
        if self.strict_host_key_checking:
            raise SSHTunnelHostKeyVerificationError(
                message=(
                    "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING is enabled but no "
                    "expected server host key is configured for this tunnel."
                )
            )
        return

    try:
        expected_key = _parse_authorized_key(expected_raw)
    except ValueError as ex:
        raise SSHTunnelHostKeyVerificationError(
            message=f"The configured expected server host key is invalid: {ex}"
        ) from ex

    # Build the socket ourselves with an explicit timeout so the TCP connect
    # phase is bounded too; ``start_client(timeout=...)`` only governs the SSH
    # handshake.
    try:
        sock = socket.create_connection(
            (ssh_tunnel.server_address, ssh_tunnel.server_port),
            timeout=sshtunnel.SSH_TIMEOUT,
        )
    except OSError as ex:
        raise SSHTunnelHostKeyVerificationError(
            message=f"Could not connect to the SSH server: {ex}"
        ) from ex

    transport = None
    try:
        # Constructed inside the ``try`` so a failure here is reported as a
        # verification error and the socket is still released below.
        transport = paramiko.Transport(sock)

        # A server may hold several host keys (RSA, ECDSA, Ed25519, ...) but
        # presents only the one matching the negotiated algorithm. Ask for the
        # pinned key's algorithm first; otherwise a server that also owns a
        # key of another type could present that one and be rejected as a
        # mismatch even though the pinned key is genuine.
        expected_type = expected_key.get_name()
        if expected_type == "ssh-rsa":
            # One RSA key can be offered under any of these signature algorithms.
            wanted = {"rsa-sha2-512", "rsa-sha2-256", "ssh-rsa"}
        else:
            wanted = {expected_type}
        security_options = transport.get_security_options()
        offered = list(security_options.key_types)
        preferred = [algo for algo in offered if algo in wanted]
        if preferred:
            # Only reorder: the remaining algorithms stay available, so a server
            # without the pinned key type still yields a mismatch error below.
            security_options.key_types = preferred + [
                algo for algo in offered if algo not in wanted
            ]

        transport.start_client(timeout=sshtunnel.SSH_TIMEOUT)
        remote_key = transport.get_remote_server_key()
    except Exception as ex:  # noqa: BLE001
        raise SSHTunnelHostKeyVerificationError(
            message=f"Could not retrieve the SSH server host key: {ex}"
        ) from ex
    finally:
        if transport is not None:
            transport.close()
        else:
            # The transport was never created, so the socket is still ours.
            sock.close()

    if remote_key != expected_key:
        logger.warning(
            "SSH host key mismatch for %s:%s",
            ssh_tunnel.server_address,
            ssh_tunnel.server_port,
        )
        raise SSHTunnelHostKeyVerificationError(
            message=(
                "The SSH server presented a host key that does not match the "
                "expected server host key configured for this tunnel."
            )
        )
|
||||
|
||||
def create_tunnel(
|
||||
self,
|
||||
ssh_tunnel: "SSHTunnel",
|
||||
@@ -60,6 +166,10 @@ class SSHManager:
|
||||
port = url.port or get_default_port(backend)
|
||||
if not port:
|
||||
raise SSHTunnelDatabasePortError()
|
||||
|
||||
# Opt-in host-key verification runs before the tunnel is opened.
|
||||
self._verify_host_key(ssh_tunnel)
|
||||
|
||||
params = {
|
||||
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
|
||||
"ssh_username": ssh_tunnel.username,
|
||||
|
||||
@@ -25,7 +25,7 @@ import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from flask import current_app
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.orm import lazyload, Session
|
||||
|
||||
from superset import db, security_manager
|
||||
from superset.db_engine_specs.base import GenericDBException
|
||||
@@ -379,7 +379,15 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
|
||||
for database in session.query(Database).all():
|
||||
# The Database model has an eager-loaded (``lazy="joined"``) ``ssh_tunnel``
|
||||
# backref. Eager-loading it here would SELECT every column on ``ssh_tunnels``,
|
||||
# including columns added by later migrations that do not yet exist at the
|
||||
# revision this helper runs in (e.g. on a fresh DB upgraded in one pass). The
|
||||
# catalog upgrade only needs scalar ``Database`` columns, so disable the eager
|
||||
# join to keep the query schema-safe across migration revisions.
|
||||
for database in (
|
||||
session.query(Database).options(lazyload(Database.ssh_tunnel)).all()
|
||||
):
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
engines and db_engine_spec.engine not in engines
|
||||
@@ -576,7 +584,11 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
|
||||
bind = op.get_bind()
|
||||
session = db.Session(bind=bind)
|
||||
|
||||
for database in session.query(Database).all():
|
||||
# See upgrade_catalog_perms: avoid eager-loading the ``ssh_tunnel`` backref so the
|
||||
# query stays schema-safe across migration revisions.
|
||||
for database in (
|
||||
session.query(Database).options(lazyload(Database.ssh_tunnel)).all()
|
||||
):
|
||||
db_engine_spec = database.db_engine_spec
|
||||
if (
|
||||
engines and db_engine_spec.engine not in engines
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
# Licensed to the Apache Software Foundation (ASF) under one
|
||||
# or more contributor license agreements. See the NOTICE file
|
||||
# distributed with this work for additional information
|
||||
# regarding copyright ownership. The ASF licenses this file
|
||||
# to you under the Apache License, Version 2.0 (the
|
||||
# "License"); you may not use this file except in compliance
|
||||
# with the License. You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing,
|
||||
# software distributed under the License is distributed on an
|
||||
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
"""add server_host_key to ssh_tunnels
|
||||
|
||||
Adds a nullable ``server_host_key`` column to the ``ssh_tunnels`` table. It stores the
|
||||
expected SSH server host key in authorized-key form (e.g. "ssh-ed25519 AAAA...") so
|
||||
operators can opt in to verifying the SSH server's host key before a tunnel is opened.
|
||||
This is a public key and is stored in plaintext (not encrypted). The column is
|
||||
nullable, so existing tunnels are unaffected.
|
||||
|
||||
Revision ID: 78a40c08b4be
|
||||
Revises: 33d7e0e21daa
|
||||
Create Date: 2026-06-01 10:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from superset.migrations.shared.utils import add_columns, drop_columns
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "78a40c08b4be"
|
||||
down_revision = "33d7e0e21daa"
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable ``server_host_key`` text column to ``ssh_tunnels``."""
    host_key_column = sa.Column("server_host_key", sa.Text(), nullable=True)
    add_columns("ssh_tunnels", host_key_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``server_host_key`` column from ``ssh_tunnels``."""
    table_name = "ssh_tunnels"
    drop_columns(table_name, "server_host_key")
|
||||
@@ -119,5 +119,5 @@ def test_database_filter(mocker: MockerFixture) -> None:
|
||||
)
|
||||
assert (
|
||||
str(compiled_query)
|
||||
== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
|
||||
== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.server_host_key, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
|
||||
)
|
||||
|
||||
@@ -14,11 +14,41 @@
|
||||
# KIND, either express or implied. See the License for the
|
||||
# specific language governing permissions and limitations
|
||||
# under the License.
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import paramiko
|
||||
import pytest
|
||||
import sshtunnel
|
||||
|
||||
from superset.extensions.ssh import SSHManagerFactory
|
||||
from superset.commands.database.ssh_tunnel.exceptions import (
|
||||
SSHTunnelHostKeyVerificationError,
|
||||
)
|
||||
from superset.extensions.ssh import SSHManager, SSHManagerFactory
|
||||
|
||||
|
||||
def _make_manager(strict: bool = False) -> SSHManager:
    """Build an ``SSHManager`` from a mocked app with a fixed test config.

    ``strict`` is used as the ``SSH_TUNNEL_STRICT_HOST_KEY_CHECKING`` value.
    """
    settings = {
        "SSH_TUNNEL_MAX_RETRIES": 2,
        "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "127.0.0.1",
        "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
        "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
        "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
        "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING": strict,
    }
    fake_app = Mock()
    fake_app.config = settings
    return SSHManager(fake_app)
|
||||
|
||||
|
||||
def _authorized_key(key: paramiko.PKey) -> str:
    """Render ``key`` as ``"<name> <base64>"`` (authorized-key form)."""
    key_name = key.get_name()
    key_payload = key.get_base64()
    return f"{key_name} {key_payload}"
|
||||
|
||||
|
||||
def _ssh_tunnel(server_host_key: str | None) -> Mock:
    """Return a mock tunnel for ``ssh.example.com:22`` with the given host key."""
    fake_tunnel = Mock()
    fake_tunnel.server_host_key = server_host_key
    fake_tunnel.server_port = 22
    fake_tunnel.server_address = "ssh.example.com"
    return fake_tunnel
|
||||
|
||||
|
||||
def test_ssh_tunnel_timeout_setting() -> None:
|
||||
@@ -34,3 +64,131 @@ def test_ssh_tunnel_timeout_setting() -> None:
|
||||
factory.init_app(app)
|
||||
assert sshtunnel.TUNNEL_TIMEOUT == 123.0
|
||||
assert sshtunnel.SSH_TIMEOUT == 321.0
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_match(
    mock_transport_cls: Mock, mock_create_connection: Mock
) -> None:
    """A server presenting exactly the pinned key passes verification."""
    host_key = paramiko.RSAKey.generate(2048)
    ssh_manager = _make_manager(strict=False)
    pinned_tunnel = _ssh_tunnel(_authorized_key(host_key))

    fake_transport = mock_transport_cls.return_value
    fake_transport.get_remote_server_key.return_value = host_key

    # Must return without raising.
    ssh_manager._verify_host_key(pinned_tunnel)

    # The TCP connect carries an explicit timeout (the packet timeout from the
    # test config) and the resulting socket is what Transport receives.
    mock_create_connection.assert_called_once_with(
        ("ssh.example.com", 22), timeout=321.0
    )
    mock_transport_cls.assert_called_once_with(mock_create_connection.return_value)
    fake_transport.start_client.assert_called_once()
    fake_transport.close.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_mismatch_raises(
    mock_transport_cls: Mock, mock_create_connection: Mock
) -> None:
    """A server presenting a key other than the pinned one is rejected."""
    pinned = paramiko.RSAKey.generate(2048)
    served = paramiko.RSAKey.generate(2048)
    ssh_manager = _make_manager(strict=False)
    pinned_tunnel = _ssh_tunnel(_authorized_key(pinned))

    fake_transport = mock_transport_cls.return_value
    fake_transport.get_remote_server_key.return_value = served

    with pytest.raises(SSHTunnelHostKeyVerificationError):
        ssh_manager._verify_host_key(pinned_tunnel)

    # The probe connection was made once and its transport was closed.
    mock_create_connection.assert_called_once()
    fake_transport.close.assert_called_once()
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.socket.create_connection")
def test_verify_host_key_connect_failure_raises(
    mock_create_connection: Mock,
) -> None:
    """An ``OSError`` from the TCP connect surfaces as a verification error."""
    ssh_manager = _make_manager(strict=False)
    host_key = paramiko.RSAKey.generate(2048)
    pinned_tunnel = _ssh_tunnel(_authorized_key(host_key))

    mock_create_connection.side_effect = OSError("connection refused")

    with pytest.raises(SSHTunnelHostKeyVerificationError):
        ssh_manager._verify_host_key(pinned_tunnel)
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_unset_non_strict_skips(mock_transport_cls: Mock) -> None:
    """Back-compat: no pinned key and strict mode off means no verification."""
    ssh_manager = _make_manager(strict=False)
    unpinned_tunnel = _ssh_tunnel(None)

    # Must return without raising and without creating a transport.
    ssh_manager._verify_host_key(unpinned_tunnel)

    mock_transport_cls.assert_not_called()
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_unset_strict_raises(mock_transport_cls: Mock) -> None:
    """Fail-closed: no pinned key while strict mode is on is rejected."""
    ssh_manager = _make_manager(strict=True)
    unpinned_tunnel = _ssh_tunnel(None)

    with pytest.raises(SSHTunnelHostKeyVerificationError):
        ssh_manager._verify_host_key(unpinned_tunnel)

    # Rejection happens before any transport is created.
    mock_transport_cls.assert_not_called()
|
||||
|
||||
|
||||
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_match_ignores_comment_and_whitespace(
    mock_transport_cls: Mock,
    mock_create_connection: Mock,
) -> None:
    """A stored key with padding whitespace and a trailing comment still matches."""
    server_key = paramiko.RSAKey.generate(2048)
    ssh_manager = _make_manager(strict=False)
    stored = f" {_authorized_key(server_key)} user@host "
    padded_tunnel = _ssh_tunnel(stored)

    fake_transport = mock_transport_cls.return_value
    fake_transport.get_remote_server_key.return_value = server_key

    # Must return without raising.
    ssh_manager._verify_host_key(padded_tunnel)
|
||||
|
||||
|
||||
def test_verify_host_key_invalid_expected_raises() -> None:
    """A malformed pinned key is rejected (nothing network-related is patched)."""
    ssh_manager = _make_manager(strict=False)
    malformed_tunnel = _ssh_tunnel("not-a-valid-key")

    with pytest.raises(SSHTunnelHostKeyVerificationError):
        ssh_manager._verify_host_key(malformed_tunnel)
|
||||
|
||||
|
||||
def test_ssh_tunnel_schema_round_trips_server_host_key() -> None:
    """Loading a payload through the schema keeps ``server_host_key`` unchanged."""
    from superset.databases.schemas import DatabaseSSHTunnel

    authorized = _authorized_key(paramiko.RSAKey.generate(2048))
    payload = {
        "server_address": "ssh.example.com",
        "server_port": 22,
        "username": "user",
        "password": "secret",
        "server_host_key": authorized,
    }

    loaded = DatabaseSSHTunnel().load(payload)

    assert loaded["server_host_key"] == authorized
|
||||
|
||||
Reference in New Issue
Block a user