Compare commits

...

5 Commits

Author SHA1 Message Date
Evan
ac207678ae test(ssh_tunnel): assert socket connect in host-key mismatch test
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 11:14:07 -07:00
Evan
afeac5abf3 fix(ssh_tunnel): bound host-key connect timeout and preserve verification error message
Build the verification socket via socket.create_connection with an explicit
timeout so the TCP connect phase is bounded (previously only the SSH handshake
was), and propagate SSHTunnelHostKeyVerificationError through the
test-connection/create paths so its specific message is preserved instead of
being flattened into a generic connection failure.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 11:14:07 -07:00
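
The bounded-connect idea in this commit can be sketched with the standard library alone. This is a minimal illustration, not the code from the change: `connect_bounded` and `HostKeyVerificationError` are hypothetical names standing in for the logic in `_verify_host_key` and for `SSHTunnelHostKeyVerificationError`.

```python
import socket


class HostKeyVerificationError(Exception):
    """Hypothetical stand-in for SSHTunnelHostKeyVerificationError."""


def connect_bounded(host: str, port: int, timeout: float) -> socket.socket:
    """Open a TCP connection whose connect phase is capped at `timeout` seconds."""
    try:
        # create_connection applies the timeout to the connect attempt and
        # leaves it set on the returned socket.
        return socket.create_connection((host, port), timeout=timeout)
    except OSError as ex:
        # Refused, unreachable, and timed-out connects all derive from OSError,
        # so one handler can keep a specific, readable message.
        raise HostKeyVerificationError(
            f"Could not connect to the SSH server: {ex}"
        ) from ex
```

The returned socket could then be handed to a transport object, so that the handshake timeout and the connect timeout are both explicit rather than one of them falling back to the OS default.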
Evan
21e40f73cb fix(migrations): avoid eager-loading ssh_tunnel in catalog upgrade
The Database model eager-loads (lazy="joined") the ssh_tunnel backref, so
session.query(Database) SELECTs every ssh_tunnels column. On a fresh DB the
catalog migration (revision 58d051681a3b) runs before later migrations that
add new ssh_tunnels columns (e.g. server_host_key), causing
"no such column: ssh_tunnels.server_host_key" and aborting the upgrade.

Disable eager-loading of the relationship in the catalog upgrade/downgrade
queries, which only need scalar Database columns, keeping them schema-safe
across migration revisions.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 11:14:07 -07:00
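
The failure mode described in this commit can be reproduced without SQLAlchemy. The sketch below uses the standard-library `sqlite3` module and a deliberately simplified two-table schema (an assumption, not the real Superset schema) to show why a join that selects a not-yet-created column aborts, while a scalar-only query stays safe.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
# Schema as it might look at an early revision: no server_host_key column yet.
conn.execute("CREATE TABLE dbs (id INTEGER PRIMARY KEY, database_name TEXT)")
conn.execute(
    "CREATE TABLE ssh_tunnels ("
    "id INTEGER PRIMARY KEY, database_id INTEGER, server_address TEXT)"
)
conn.execute("INSERT INTO dbs (database_name) VALUES ('examples')")

# Roughly what an eager join built from the *current* model would request.
EAGER_SQL = (
    "SELECT dbs.id, dbs.database_name, ssh_tunnels.server_host_key "
    "FROM dbs LEFT OUTER JOIN ssh_tunnels ON dbs.id = ssh_tunnels.database_id"
)
# What the catalog migration needs: scalar columns from dbs only.
SCALAR_SQL = "SELECT dbs.id, dbs.database_name FROM dbs"


def try_query(sql: str) -> str:
    """Run a query and return 'ok' or the database error message."""
    try:
        conn.execute(sql).fetchall()
        return "ok"
    except sqlite3.OperationalError as ex:
        return str(ex)


eager_result = try_query(EAGER_SQL)
scalar_result = try_query(SCALAR_SQL)
print(eager_result)
print(scalar_result)
```

The fix in the commit achieves the second shape of query by turning off the eager join for that one query, rather than by changing the model.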
Evan
24a0ad6d46 fix(ssh_tunnel): use unique migration revision id and update filter SQL snapshot
The new ssh_tunnels migration reused the existing revision id a1b2c3d4e5f6
(already used by add_granular_export_permissions), which created an Alembic
revision cycle (33d7e0e21daa -> a1b2c3d4e5f6 -> ce6bd21901ab -> a1b2c3d4e5f6)
and broke every job that runs db upgrade. Assign a unique revision id and
rename the file accordingly.

Also update test_database_filter's compiled-SQL snapshot to include the new
ssh_tunnels.server_host_key column.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 11:14:07 -07:00
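
A reused revision id like the one fixed here is cheap to detect before Alembic ever runs. The following is a small sketch of such a check; the file names in `MIGRATIONS` are hypothetical, and only the revision ids are taken from the commit message.

```python
from collections import Counter

# (hypothetical file name, revision id) pairs.
MIGRATIONS = [
    ("add_granular_export_permissions.py", "a1b2c3d4e5f6"),
    ("add_server_host_key_to_ssh_tunnels.py", "a1b2c3d4e5f6"),
    ("some_other_migration.py", "ce6bd21901ab"),
]


def duplicate_revisions(migrations: list[tuple[str, str]]) -> dict[str, list[str]]:
    """Map each revision id used more than once to the files that declare it."""
    counts = Counter(revision for _, revision in migrations)
    return {
        revision: [name for name, rev in migrations if rev == revision]
        for revision, count in counts.items()
        if count > 1
    }


duplicates = duplicate_revisions(MIGRATIONS)
for revision, files in duplicates.items():
    print(f"revision {revision} is declared by: {', '.join(files)}")
```

In a real repository the pairs would be collected by reading the `revision = "..."` line from each migration file; that step is left out to keep the sketch self-contained.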
Claude Code
9cf5e5387c feat(ssh_tunnel): add opt-in server host key verification
paramiko's transport performs no known-hosts checking by default, so the
SSH server's identity was not verified when opening a tunnel. This adds an
opt-in, defense-in-depth host-key pinning mechanism so operators can verify
the SSH server's host key and reject mismatches (MITM-resistance), without
changing behavior for existing tunnels.

- Add nullable `server_host_key` (Text, plaintext public key) to the
  `SSHTunnel` model + Alembic migration off head 33d7e0e21daa.
- Add `SSH_TUNNEL_STRICT_HOST_KEY_CHECKING` config flag (default False):
  when True a tunnel without an expected key is rejected (fail-closed).
- Add `SSHManager._verify_host_key`, run before `sshtunnel.open_tunnel`,
  which connects via `paramiko.Transport`, reads the presented host key,
  parses the expected authorized-key value via
  `paramiko.PKey.from_type_string`, and compares (whitespace/comment
  tolerant). On mismatch raise `SSHTunnelHostKeyVerificationError`.
- Expose `server_host_key` in the SSH tunnel schema (not sensitive, not
  masked); threads through create/update via the DAO.
- Unit tests and UPDATING.md runbook.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
2026-06-03 11:14:07 -07:00
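
The whitespace- and comment-tolerant comparison described above can be illustrated without paramiko. The sketch below is a standard-library approximation, not the code from this commit: it compares the key type and the decoded key blob, and additionally checks that the blob's embedded type name (the length-prefixed string at the start of an SSH public key blob) agrees with the declared type. `parse_authorized_key` and `keys_match` are hypothetical names.

```python
import base64
import binascii
import struct


def parse_authorized_key(value: str) -> tuple[str, bytes]:
    """Split '<type> <base64>[ comment]' into (type, decoded key blob)."""
    fields = value.strip().split()
    if len(fields) < 2:
        raise ValueError("expected '<type> <base64>[ comment]'")
    key_type, key_b64 = fields[0], fields[1]
    try:
        blob = base64.b64decode(key_b64, validate=True)
    except (binascii.Error, ValueError) as ex:
        raise ValueError("base64 payload could not be decoded") from ex
    if len(blob) < 4:
        raise ValueError("key blob is too short")
    # An SSH public key blob begins with a uint32 length and the key type name.
    (name_len,) = struct.unpack(">I", blob[:4])
    if blob[4 : 4 + name_len] != key_type.encode("ascii"):
        raise ValueError("declared key type does not match the key blob")
    return key_type, blob


def keys_match(expected: str, presented: str) -> bool:
    """Compare two authorized-key strings, ignoring comments and whitespace."""
    return parse_authorized_key(expected) == parse_authorized_key(presented)
```

Comparing decoded bytes rather than raw strings is what makes a trailing `user@host` comment or stray whitespace irrelevant to the result.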
13 changed files with 416 additions and 12 deletions

View File

@@ -34,6 +34,20 @@ The embedded dashboard page now validates the origin of incoming `postMessage` e
Enforcement only applies when the Allowed Domains list is non-empty. If the list is empty (the default), any origin is accepted, so there is no behavior change for embeds that did not configure Allowed Domains.
### Opt-in SSH tunnel server host key verification
SSH tunnels can now optionally pin the expected SSH server host key as a defense-in-depth measure against man-in-the-middle attacks. paramiko's transport performs no known-hosts checking by default, so previously the SSH server's identity was not verified. This feature is opt-in and off by default; existing tunnels are unaffected.
- A new nullable `server_host_key` column on the `ssh_tunnels` table stores the expected host key in authorized-key form (e.g. `ssh-ed25519 AAAA...`). It is a public key and is stored in plaintext. It can be set via the SSH tunnel POST/PUT payloads (`ssh_tunnel.server_host_key`).
- When a tunnel has `server_host_key` set, Superset connects to the SSH server, reads the host key it presents, and rejects the tunnel if it does not match.
- A new config flag `SSH_TUNNEL_STRICT_HOST_KEY_CHECKING` (default `False`) controls fail-closed behavior. When `True`, every tunnel must declare a `server_host_key`; a tunnel without one is rejected.
Runbook to adopt:
1. Capture the SSH server's host key, e.g. `ssh-keyscan -t ed25519 ssh.example.com` (verify it out-of-band).
2. Set that value on the tunnel's `server_host_key` (via the database/SSH tunnel API or UI payload).
3. Optionally set `SSH_TUNNEL_STRICT_HOST_KEY_CHECKING = True` in `superset_config.py` to require host-key verification on all tunnels.
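
One detail worth noting for steps 1 and 2: `ssh-keyscan` prints lines in known_hosts layout (`<host> <type> <base64>`), whereas `server_host_key` expects authorized-key form (`<type> <base64>`). The sketch below shows one way to convert between the two and build the relevant payload fragment. The helper names and the sample key material are hypothetical placeholders, not real values.

```python
def keyscan_to_server_host_key(keyscan_output: str, key_type: str = "ssh-ed25519") -> str:
    """Pick the first key of `key_type` and drop the leading host field."""
    for line in keyscan_output.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue  # skip blanks and banner comments
        fields = line.split()
        # known_hosts layout: <host> <type> <base64>
        if len(fields) >= 3 and fields[1] == key_type:
            return f"{fields[1]} {fields[2]}"
    raise ValueError(f"no {key_type} key found in ssh-keyscan output")


def tunnel_payload_fragment(server_host_key: str) -> dict:
    """Only the part of the POST/PUT payload that carries the host key."""
    return {"ssh_tunnel": {"server_host_key": server_host_key}}


# Placeholder output; a real key must still be verified out-of-band.
SAMPLE = (
    "# ssh.example.com:22 SSH-2.0-OpenSSH\n"
    "ssh.example.com ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIPLACEHOLDER\n"
)

fragment = tunnel_payload_fragment(keyscan_to_server_host_key(SAMPLE))
print(fragment)
```

The other tunnel fields (address, port, credentials) are omitted here and would be sent alongside `server_host_key` in the same `ssh_tunnel` object.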
### Dataset import validates catalog against the target connection
Importing a dataset now validates the `catalog` field against the target database connection. When the connection has multi-catalog disabled (`allow_multi_catalog` off) and the dataset's catalog is not the connection's default catalog, the import fails instead of silently persisting the non-default catalog. This matches the validation already enforced on the dataset update path and prevents imported datasets from querying an unintended database.

View File

@@ -33,6 +33,7 @@ from superset.commands.database.exceptions import (
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelCreateFailedError,
SSHTunnelDatabasePortError,
SSHTunnelHostKeyVerificationError,
SSHTunnelingNotEnabledError,
SSHTunnelInvalidError,
)
@@ -75,6 +76,7 @@ class CreateDatabaseCommand(BaseCommand):
SupersetErrorsException,
SSHTunnelingNotEnabledError,
SSHTunnelDatabasePortError,
SSHTunnelHostKeyVerificationError,
) as ex:
event_logger.log_with_context(
action=f"db_creation_failed.{ex.__class__.__name__}",

View File

@@ -75,3 +75,9 @@ class SSHTunnelMissingCredentials(CommandInvalidError, SSHTunnelError): # noqa:
class SSHTunnelInvalidCredentials(CommandInvalidError, SSHTunnelError):  # noqa: N818
    message = _("Cannot have multiple credentials for the SSH Tunnel")


class SSHTunnelHostKeyVerificationError(CommandInvalidError, SSHTunnelError):
    message = _(
        "The SSH server host key could not be verified against the expected key."
    )

View File

@@ -29,6 +29,7 @@ from superset.commands.database.exceptions import (
)
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelHostKeyVerificationError,
SSHTunnelingNotEnabledError,
)
from superset.commands.database.utils import ping
@@ -221,7 +222,11 @@ class TestConnectionDatabaseCommand(BaseCommand):
engine=engine_name,
)
raise DatabaseSecurityUnsafeError(message=str(ex)) from ex
-except (SupersetTimeoutException, SSHTunnelingNotEnabledError) as ex:
+except (
+SupersetTimeoutException,
+SSHTunnelingNotEnabledError,
+SSHTunnelHostKeyVerificationError,
+) as ex:
event_logger.log_with_context(
action=get_log_connection_action(
"test_connection_error",
@@ -230,7 +235,8 @@ class TestConnectionDatabaseCommand(BaseCommand):
),
engine=engine_name,
)
-# bubble up the exception to return proper status code
+# bubble up the exception (preserving its specific message and status)
+# instead of flattening it into a generic connection failure
raise
except Exception as ex:
if not database:

View File

@@ -838,6 +838,15 @@ SSH_TUNNEL_TIMEOUT_SEC = 10.0
#: Timeout (seconds) for transport socket (``socket.settimeout``)
SSH_TUNNEL_PACKET_TIMEOUT_SEC = 1.0
#: Opt-in defense-in-depth: when enabled, every SSH tunnel must declare an expected
#: server host key (``server_host_key`` on the tunnel) and the SSH server's presented
#: host key is verified against it before the tunnel is opened. A mismatch, or a
#: missing expected key while this flag is enabled, fails closed and the tunnel is
#: rejected. When disabled (the default), tunnels without a ``server_host_key`` open
#: without host-key verification, preserving existing behavior; tunnels that do set a
#: ``server_host_key`` are still verified regardless of this flag.
SSH_TUNNEL_STRICT_HOST_KEY_CHECKING: bool = False
# Feature flags may also be set via 'SUPERSET_FEATURE_' prefixed environment vars.
DEFAULT_FEATURE_FLAGS.update(

View File

@@ -56,6 +56,7 @@ from superset.commands.database.importers.dispatcher import ImportDatabasesComma
from superset.commands.database.oauth2 import OAuth2StoreTokenCommand
from superset.commands.database.ssh_tunnel.exceptions import (
SSHTunnelDatabasePortError,
SSHTunnelHostKeyVerificationError,
SSHTunnelingNotEnabledError,
)
from superset.commands.database.sync_permissions import SyncPermissionsCommand
@@ -483,7 +484,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
-except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
+except (
+SSHTunnelingNotEnabledError,
+SSHTunnelDatabasePortError,
+SSHTunnelHostKeyVerificationError,
+) as ex:
return self.response_400(message=str(ex))
except SupersetException as ex:
return self.response(ex.status, message=ex.message)
@@ -568,7 +573,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
exc_info=True,
)
return self.response_422(message=str(ex))
-except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
+except (
+SSHTunnelingNotEnabledError,
+SSHTunnelDatabasePortError,
+SSHTunnelHostKeyVerificationError,
+) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>", methods=("DELETE",))
@@ -1280,7 +1289,11 @@ class DatabaseRestApi(BaseSupersetModelRestApi):
try:
TestConnectionDatabaseCommand(item).run()
return self.response(200, message="OK")
-except (SSHTunnelingNotEnabledError, SSHTunnelDatabasePortError) as ex:
+except (
+SSHTunnelingNotEnabledError,
+SSHTunnelDatabasePortError,
+SSHTunnelHostKeyVerificationError,
+) as ex:
return self.response_400(message=str(ex))
@expose("/<int:pk>/related_objects/", methods=("GET",))

View File

@@ -474,6 +474,22 @@ class DatabaseSSHTunnel(Schema):
private_key = fields.String(required=False)
private_key_password = fields.String(required=False)
# Optional expected SSH server host key in authorized-key form
# (e.g. "ssh-rsa AAAA...", "ssh-ed25519 AAAA..."). When set, the SSH server's
# presented host key is verified against it before the tunnel is opened. This is
# a public key, so it is not sensitive and is not masked.
server_host_key = fields.String(
required=False,
allow_none=True,
metadata={
"description": (
"Expected SSH server host key in authorized-key form "
"(e.g. 'ssh-ed25519 AAAA...'). When set, the server's host key is "
"verified against it before the tunnel is opened."
)
},
)
@validates_schema
def validate_authentication(self, data: dict[str, Any], **kwargs: Any) -> None:
errors: dict[str, str] = {}

View File

@@ -72,6 +72,12 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
encrypted_field_factory.create(Text), nullable=True
)
# Optional expected SSH server host key, in authorized-key form
# (e.g. "ssh-rsa AAAA...", "ssh-ed25519 AAAA..."). When set, the SSH server's
# presented host key is verified against this value before the tunnel is opened.
# This is a public key, so it is stored in plaintext (not encrypted).
server_host_key = sa.Column(sa.Text, nullable=True)
export_fields = [
"server_address",
"server_port",
@@ -79,6 +85,7 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
"password",
"private_key",
"private_key_password",
"server_host_key",
]
extra_import_fields = [
@@ -93,6 +100,9 @@ class SSHTunnel(AuditMixinNullable, ExtraJSONMixin, ImportExportMixin, Model):
"server_port": self.server_port,
"username": self.username,
}
if self.server_host_key is not None:
# public key, not sensitive: returned in cleartext
output["server_host_key"] = self.server_host_key
if self.password is not None:
output["password"] = PASSWORD_MASK
if self.private_key is not None:

View File

@@ -15,26 +15,57 @@
# specific language governing permissions and limitations
# under the License.
import base64
import binascii
import logging
import socket
from io import StringIO
from typing import TYPE_CHECKING
import paramiko
import sshtunnel
from flask import Flask
from paramiko import RSAKey
-from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError
+from superset.commands.database.ssh_tunnel.exceptions import (
+SSHTunnelDatabasePortError,
+SSHTunnelHostKeyVerificationError,
+)
from superset.databases.utils import make_url_safe
from superset.utils.class_utils import load_class_from_name
if TYPE_CHECKING:
from superset.databases.ssh_tunnel.models import SSHTunnel
logger = logging.getLogger(__name__)
def _parse_authorized_key(authorized_key: str) -> paramiko.PKey:
    """
    Parse a host key in authorized-key form (``"<type> <base64>[ comment]"``) into a
    :class:`paramiko.PKey`. The optional trailing comment field and surrounding
    whitespace are ignored.

    :raises ValueError: if the value is empty or cannot be parsed as a host key.
    """
    fields = authorized_key.strip().split()
    if len(fields) < 2:
        raise ValueError("Host key must be in 'ssh-<type> <base64>' form")
    key_type, key_b64 = fields[0], fields[1]
    try:
        key_bytes = base64.b64decode(key_b64)
    except (binascii.Error, ValueError) as ex:
        raise ValueError("Host key base64 payload could not be decoded") from ex
    return paramiko.PKey.from_type_string(key_type, key_bytes)


class SSHManager:
    def __init__(self, app: Flask) -> None:
        super().__init__()
        self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"]
        self.strict_host_key_checking = app.config.get(
            "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING", False
        )
        sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"]
        sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]
@@ -48,6 +79,81 @@ class SSHManager:
port=server.local_bind_port,
)
    def _verify_host_key(self, ssh_tunnel: "SSHTunnel") -> None:
        """
        Opt-in defense-in-depth: verify the SSH server's host key before opening the
        tunnel, to resist man-in-the-middle attacks (paramiko's ``Transport`` does no
        known-hosts checking by default).

        Behavior:

        - If the tunnel declares an expected ``server_host_key``, connect to the SSH
          server, read the host key it presents, and compare. On mismatch (or if the
          expected key cannot be parsed) raise
          :class:`SSHTunnelHostKeyVerificationError`.
        - If no expected key is set and ``SSH_TUNNEL_STRICT_HOST_KEY_CHECKING`` is
          enabled, fail closed and raise.
        - If no expected key is set and strict checking is disabled, do nothing,
          preserving existing (unverified) behavior.
        """
        expected_raw = ssh_tunnel.server_host_key
        if not expected_raw or not expected_raw.strip():
            if self.strict_host_key_checking:
                raise SSHTunnelHostKeyVerificationError(
                    message=(
                        "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING is enabled but no "
                        "expected server host key is configured for this tunnel."
                    )
                )
            return

        try:
            expected_key = _parse_authorized_key(expected_raw)
        except ValueError as ex:
            raise SSHTunnelHostKeyVerificationError(
                message=f"The configured expected server host key is invalid: {ex}"
            ) from ex

        # Build the socket ourselves with an explicit timeout so the TCP connect
        # phase is bounded too. ``paramiko.Transport((host, port))`` would connect
        # synchronously with no timeout, leaving ``start_client(timeout=...)`` to
        # govern only the SSH handshake; an unreachable host could then block for the
        # full OS-level TCP timeout.
        try:
            sock = socket.create_connection(
                (ssh_tunnel.server_address, ssh_tunnel.server_port),
                timeout=sshtunnel.SSH_TIMEOUT,
            )
        except OSError as ex:
            raise SSHTunnelHostKeyVerificationError(
                message=f"Could not connect to the SSH server: {ex}"
            ) from ex

        transport = paramiko.Transport(sock)
        try:
            transport.start_client(timeout=sshtunnel.SSH_TIMEOUT)
            remote_key = transport.get_remote_server_key()
        except Exception as ex:  # noqa: BLE001
            raise SSHTunnelHostKeyVerificationError(
                message=f"Could not retrieve the SSH server host key: {ex}"
            ) from ex
        finally:
            transport.close()

        if remote_key != expected_key:
            logger.warning(
                "SSH host key mismatch for %s:%s",
                ssh_tunnel.server_address,
                ssh_tunnel.server_port,
            )
            raise SSHTunnelHostKeyVerificationError(
                message=(
                    "The SSH server presented a host key that does not match the "
                    "expected server host key configured for this tunnel."
                )
            )
def create_tunnel(
self,
ssh_tunnel: "SSHTunnel",
@@ -60,6 +166,10 @@ class SSHManager:
port = url.port or get_default_port(backend)
if not port:
raise SSHTunnelDatabasePortError()
# Opt-in host-key verification runs before the tunnel is opened.
self._verify_host_key(ssh_tunnel)
params = {
"ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
"ssh_username": ssh_tunnel.username,

View File

@@ -25,7 +25,7 @@ import sqlalchemy as sa
from alembic import op
from flask import current_app
from sqlalchemy.ext.declarative import declarative_base
-from sqlalchemy.orm import Session
+from sqlalchemy.orm import lazyload, Session
from superset import db, security_manager
from superset.db_engine_specs.base import GenericDBException
@@ -379,7 +379,15 @@ def upgrade_catalog_perms(engines: set[str] | None = None) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
-for database in session.query(Database).all():
+# The Database model has an eager-loaded (``lazy="joined"``) ``ssh_tunnel``
+# backref. Eager-loading it here would SELECT every column on ``ssh_tunnels``,
+# including columns added by later migrations that do not yet exist at the
+# revision this helper runs in (e.g. on a fresh DB upgraded in one pass). The
+# catalog upgrade only needs scalar ``Database`` columns, so disable the eager
+# join to keep the query schema-safe across migration revisions.
+for database in (
+session.query(Database).options(lazyload(Database.ssh_tunnel)).all()
+):
db_engine_spec = database.db_engine_spec
if (
engines and db_engine_spec.engine not in engines
@@ -576,7 +584,11 @@ def downgrade_catalog_perms(engines: set[str] | None = None) -> None:
bind = op.get_bind()
session = db.Session(bind=bind)
-for database in session.query(Database).all():
+# See upgrade_catalog_perms: avoid eager-loading the ``ssh_tunnel`` backref so the
+# query stays schema-safe across migration revisions.
+for database in (
+session.query(Database).options(lazyload(Database.ssh_tunnel)).all()
+):
db_engine_spec = database.db_engine_spec
if (
engines and db_engine_spec.engine not in engines

View File

@@ -0,0 +1,48 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""add server_host_key to ssh_tunnels
Adds a nullable ``server_host_key`` column to the ``ssh_tunnels`` table. It stores the
expected SSH server host key in authorized-key form (e.g. "ssh-ed25519 AAAA...") so
operators can opt in to verifying the SSH server's host key before a tunnel is opened.
This is a public key and is stored in plaintext (not encrypted). The column is
nullable, so existing tunnels are unaffected.
Revision ID: 78a40c08b4be
Revises: 33d7e0e21daa
Create Date: 2026-06-01 10:00:00.000000
"""
import sqlalchemy as sa
from superset.migrations.shared.utils import add_columns, drop_columns
# revision identifiers, used by Alembic.
revision = "78a40c08b4be"
down_revision = "33d7e0e21daa"
def upgrade():
    add_columns(
        "ssh_tunnels",
        sa.Column("server_host_key", sa.Text(), nullable=True),
    )


def downgrade():
    drop_columns("ssh_tunnels", "server_host_key")

View File

@@ -119,5 +119,5 @@ def test_database_filter(mocker: MockerFixture) -> None:
)
assert (
str(compiled_query)
-== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
+== "SELECT dbs.uuid, dbs.created_on, dbs.changed_on, dbs.id, dbs.verbose_name, dbs.database_name, dbs.sqlalchemy_uri, dbs.password, dbs.cache_timeout, dbs.select_as_create_table_as, dbs.expose_in_sqllab, dbs.configuration_method, dbs.allow_run_async, dbs.allow_file_upload, dbs.allow_ctas, dbs.allow_cvas, dbs.allow_dml, dbs.force_ctas_schema, dbs.extra, dbs.encrypted_extra, dbs.impersonate_user, dbs.server_cert, dbs.is_managed_externally, dbs.external_url, dbs.created_by_fk, dbs.changed_by_fk, ssh_tunnels_1.uuid AS uuid_1, ssh_tunnels_1.created_on AS created_on_1, ssh_tunnels_1.changed_on AS changed_on_1, ssh_tunnels_1.extra_json, ssh_tunnels_1.id AS id_1, ssh_tunnels_1.database_id, ssh_tunnels_1.server_address, ssh_tunnels_1.server_port, ssh_tunnels_1.username, ssh_tunnels_1.password AS password_1, ssh_tunnels_1.private_key, ssh_tunnels_1.private_key_password, ssh_tunnels_1.server_host_key, ssh_tunnels_1.created_by_fk AS created_by_fk_1, ssh_tunnels_1.changed_by_fk AS changed_by_fk_1 \nFROM dbs LEFT OUTER JOIN ssh_tunnels AS ssh_tunnels_1 ON dbs.id = ssh_tunnels_1.database_id \nWHERE '[' || dbs.database_name || '].(id:' || CAST(dbs.id AS VARCHAR) || ')' IN ('[my_db].(id:42)', '[my_other_db].(id:43)') OR dbs.database_name IN ('my_db', 'my_other_db', 'third_db')" # noqa: E501
)

View File

@@ -14,11 +14,41 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
import paramiko
import pytest
import sshtunnel
-from superset.extensions.ssh import SSHManagerFactory
+from superset.commands.database.ssh_tunnel.exceptions import (
+SSHTunnelHostKeyVerificationError,
+)
+from superset.extensions.ssh import SSHManager, SSHManagerFactory
def _make_manager(strict: bool = False) -> SSHManager:
    app = Mock()
    app.config = {
        "SSH_TUNNEL_MAX_RETRIES": 2,
        "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "127.0.0.1",
        "SSH_TUNNEL_TIMEOUT_SEC": 123.0,
        "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0,
        "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager",
        "SSH_TUNNEL_STRICT_HOST_KEY_CHECKING": strict,
    }
    return SSHManager(app)


def _authorized_key(key: paramiko.PKey) -> str:
    return f"{key.get_name()} {key.get_base64()}"


def _ssh_tunnel(server_host_key: str | None) -> Mock:
    tunnel = Mock()
    tunnel.server_address = "ssh.example.com"
    tunnel.server_port = 22
    tunnel.server_host_key = server_host_key
    return tunnel
def test_ssh_tunnel_timeout_setting() -> None:
@@ -34,3 +64,131 @@ def test_ssh_tunnel_timeout_setting() -> None:
factory.init_app(app)
assert sshtunnel.TUNNEL_TIMEOUT == 123.0
assert sshtunnel.SSH_TIMEOUT == 321.0
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_match(
mock_transport_cls: Mock, mock_create_connection: Mock
) -> None:
# The server presents the same key we expect: verification passes.
server_key = paramiko.RSAKey.generate(2048)
manager = _make_manager(strict=False)
tunnel = _ssh_tunnel(_authorized_key(server_key))
transport = mock_transport_cls.return_value
transport.get_remote_server_key.return_value = server_key
manager._verify_host_key(tunnel) # should not raise
# The TCP connect is bounded by an explicit timeout, and the resulting
# socket is handed to Transport.
mock_create_connection.assert_called_once_with(
("ssh.example.com", 22), timeout=321.0
)
mock_transport_cls.assert_called_once_with(mock_create_connection.return_value)
transport.start_client.assert_called_once()
transport.close.assert_called_once()
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_mismatch_raises(
mock_transport_cls: Mock, mock_create_connection: Mock
) -> None:
# The server presents a different key than expected: verification fails.
expected_key = paramiko.RSAKey.generate(2048)
presented_key = paramiko.RSAKey.generate(2048)
manager = _make_manager(strict=False)
tunnel = _ssh_tunnel(_authorized_key(expected_key))
transport = mock_transport_cls.return_value
transport.get_remote_server_key.return_value = presented_key
with pytest.raises(SSHTunnelHostKeyVerificationError):
manager._verify_host_key(tunnel)
mock_create_connection.assert_called_once()
transport.close.assert_called_once()
@patch("superset.extensions.ssh.socket.create_connection")
def test_verify_host_key_connect_failure_raises(
mock_create_connection: Mock,
) -> None:
# A bounded TCP connect failure surfaces as a host-key verification error.
manager = _make_manager(strict=False)
server_key = paramiko.RSAKey.generate(2048)
tunnel = _ssh_tunnel(_authorized_key(server_key))
mock_create_connection.side_effect = OSError("connection refused")
with pytest.raises(SSHTunnelHostKeyVerificationError):
manager._verify_host_key(tunnel)
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_unset_non_strict_skips(mock_transport_cls: Mock) -> None:
# Back-compat: no expected key + strict checking off => no verification at all.
manager = _make_manager(strict=False)
tunnel = _ssh_tunnel(None)
manager._verify_host_key(tunnel) # should not raise
mock_transport_cls.assert_not_called()
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_unset_strict_raises(mock_transport_cls: Mock) -> None:
# Fail-closed: no expected key + strict checking on => reject.
manager = _make_manager(strict=True)
tunnel = _ssh_tunnel(None)
with pytest.raises(SSHTunnelHostKeyVerificationError):
manager._verify_host_key(tunnel)
mock_transport_cls.assert_not_called()
@patch("superset.extensions.ssh.socket.create_connection")
@patch("superset.extensions.ssh.paramiko.Transport")
def test_verify_host_key_match_ignores_comment_and_whitespace(
mock_transport_cls: Mock,
mock_create_connection: Mock,
) -> None:
# The stored key may carry a trailing comment and extra whitespace.
server_key = paramiko.RSAKey.generate(2048)
manager = _make_manager(strict=False)
stored = f" {_authorized_key(server_key)} user@host "
tunnel = _ssh_tunnel(stored)
transport = mock_transport_cls.return_value
transport.get_remote_server_key.return_value = server_key
manager._verify_host_key(tunnel) # should not raise
def test_verify_host_key_invalid_expected_raises() -> None:
# A malformed expected key is rejected before any network connection.
manager = _make_manager(strict=False)
tunnel = _ssh_tunnel("not-a-valid-key")
with pytest.raises(SSHTunnelHostKeyVerificationError):
manager._verify_host_key(tunnel)
def test_ssh_tunnel_schema_round_trips_server_host_key() -> None:
# The schema accepts and preserves the public host key field.
from superset.databases.schemas import DatabaseSSHTunnel
server_key = paramiko.RSAKey.generate(2048)
authorized = _authorized_key(server_key)
payload = {
"server_address": "ssh.example.com",
"server_port": 22,
"username": "user",
"password": "secret",
"server_host_key": authorized,
}
loaded = DatabaseSSHTunnel().load(payload)
assert loaded["server_host_key"] == authorized