fix(re-encrypt): handle non-id PKs and make command idempotent (#40079)

This commit is contained in:
Ville Brofeldt
2026-05-12 17:59:52 -07:00
committed by GitHub
parent fa06989ed7
commit af4dc3a9aa
4 changed files with 409 additions and 47 deletions

View File

@@ -115,15 +115,19 @@ def re_encrypt_secrets(previous_secret_key: Optional[str] = None) -> None:
"PREVIOUS_SECRET_KEY"
)
if previous_secret_key is None:
click.secho("A previous secret key must be provided", err=True)
sys.exit(1)
click.secho(
"No previous secret key provided; nothing to re-encrypt.",
fg="yellow",
)
return
secrets_migrator = SecretsMigrator(previous_secret_key=previous_secret_key)
try:
secrets_migrator.run()
except ValueError as exc:
click.secho(
f"An error occurred, "
f"probably an invalid previous secret key was provided. Error:[{exc}]",
err=True,
)
stats = secrets_migrator.run()
except Exception as exc: # pylint: disable=broad-except
click.secho(f"Re-encryption failed: {exc}", err=True)
sys.exit(1)
click.secho(
f"Re-encryption complete: {stats.re_encrypted} re-encrypted, "
f"{stats.skipped} skipped, {stats.null} null, {stats.failed} failed.",
fg="green",
)

View File

@@ -16,11 +16,12 @@
# under the License.
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Optional
from flask import Flask
from flask_babel import lazy_gettext as _
from sqlalchemy import text, TypeDecorator
from sqlalchemy import Table, text, TypeDecorator
from sqlalchemy.engine import Connection, Dialect, Row
from sqlalchemy_utils import EncryptedType as SqlaEncryptedType
@@ -33,6 +34,16 @@ ENC_ADAPTER_TAG_ATTR_NAME = "__created_by_enc_field_adapter__"
logger = logging.getLogger(__name__)
@dataclass
class ReEncryptStats:
"""Per-value outcome counts for a SecretsMigrator.run() invocation."""
re_encrypted: int = 0
skipped: int = 0
null: int = 0
failed: int = 0
class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-methods
@abstractmethod
def create(
@@ -97,11 +108,13 @@ class SecretsMigrator:
self._previous_secret_key = previous_secret_key
self._dialect: Dialect = db.engine.url.get_dialect()
def discover_encrypted_fields(self) -> dict[str, dict[str, EncryptedType]]:
def discover_encrypted_fields(
self,
) -> dict[str, tuple[Table, dict[str, EncryptedType]]]:
"""
Iterates over ORM-mapped tables, looking for EncryptedType columns
along the way. Builds up a dict of
table_name -> dict of col_name: enc type instance
table_name -> (Table, dict of col_name: enc type instance)
Superset's ORM models inherit from Flask-AppBuilder's declarative base
(`flask_appbuilder.Model`), whose MetaData is distinct from
@@ -109,13 +122,18 @@ class SecretsMigrator:
regardless of which base a model uses. FAB's metadata takes precedence
when a table name appears in both registries.
:return: mapping of table name to a dict of {column name: EncryptedType}
The Table object is returned alongside the encrypted columns so callers
can introspect the schema (notably the primary key) without assuming a
conventional `id` column — some tables (e.g. `semantic_layers`) use a
`uuid` primary key instead.
:return: mapping of table name to (Table, {column name: EncryptedType})
"""
from flask_appbuilder import ( # pylint: disable=import-outside-toplevel
Model as FABModel,
)
meta_info: dict[str, Any] = {}
meta_info: dict[str, tuple[Table, dict[str, EncryptedType]]] = {}
tables: dict[str, Any] = dict(FABModel.metadata.tables)
for table_name, table in self._db.metadata.tables.items():
@@ -124,9 +142,9 @@ class SecretsMigrator:
for table_name, table in tables.items():
for col_name, col in table.columns.items():
if isinstance(col.type, EncryptedType):
cols = meta_info.get(table_name, {})
_, cols = meta_info.get(table_name, (table, {}))
cols[col_name] = col.type
meta_info[table_name] = cols
meta_info[table_name] = (table, cols)
return meta_info
@@ -151,9 +169,13 @@ class SecretsMigrator:
@staticmethod
def _select_columns_from_table(
conn: Connection, column_names: list[str], table_name: str
conn: Connection,
pk_columns: list[str],
column_names: list[str],
table_name: str,
) -> Row:
return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}") # noqa: S608
cols = ",".join(pk_columns + column_names)
return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608
def _re_encrypt_row(
self,
@@ -161,62 +183,143 @@ class SecretsMigrator:
row: Row,
table_name: str,
columns: dict[str, EncryptedType],
pk_columns: list[str],
stats: ReEncryptStats,
) -> None:
"""
Re encrypts all columns in a Row
Re-encryption is idempotent per column: we first ask whether the
current key can already decrypt the value, and skip if so. Only if
the current key fails do we fall back to decrypting with the
previous key and re-encrypting. Checking the current key first
keeps ``run()`` idempotent regardless of what ``previous_secret_key``
the caller supplies — even re-running with the same (unchanged)
``SECRET_KEY`` will not rewrite rows.
NULL values are never encrypted, so they are reported separately
(neither re-encrypted nor "skipped because already current").
Per-column outcomes are accumulated onto ``stats`` so the caller can
report a summary. Columns whose ciphertext is unreadable under both
keys are counted as failures and logged; the exception is not
propagated, so processing continues. The caller is responsible for
raising once all rows have been scanned.
If no columns need re-encryption, no UPDATE is issued.
:param row: Current row to reencrypt
:param columns: Meta info from columns
:param pk_columns: Primary key column names used to target the row
:param stats: Mutable counters updated per column
"""
re_encrypted_columns = {}
for column_name, encrypted_type in columns.items():
raw_value = self._read_bytes(column_name, row[column_name])
# NULL values aren't encrypted; there is nothing to migrate.
if raw_value is None:
stats.null += 1
continue
# Fast path: if the current key can already read the value,
# leave it untouched. A failure here simply means we need to try
# the previous key below — not a condition worth logging.
try:
encrypted_type.process_result_value(raw_value, self._dialect)
except Exception: # noqa: BLE001, S110 # pylint: disable=broad-except
pass
else:
stats.skipped += 1
continue
# Current key cannot decrypt — try the previous key.
previous_encrypted_type = EncryptedType(
type_in=encrypted_type.underlying_type, key=self._previous_secret_key
)
try:
unencrypted_value = previous_encrypted_type.process_result_value(
self._read_bytes(column_name, row[column_name]), self._dialect
raw_value, self._dialect
)
except ValueError as ex:
# Failed to unencrypt
try:
encrypted_type.process_result_value(
self._read_bytes(column_name, row[column_name]), self._dialect
)
logger.info(
"Current secret is able to decrypt value on column [%s.%s],"
" nothing to do",
table_name,
column_name,
)
return
except Exception:
raise Exception from ex # pylint: disable=broad-exception-raised
except Exception as prev_ex: # noqa: BLE001 # pylint: disable=broad-except
logger.error(
"Column [%s.%s] cannot be decrypted under the previous"
" or current secret key (%s: %s)",
table_name,
column_name,
type(prev_ex).__name__,
prev_ex,
)
stats.failed += 1
continue
re_encrypted_columns[column_name] = encrypted_type.process_bind_param(
unencrypted_value,
self._dialect,
)
stats.re_encrypted += 1
set_cols = ",".join(
[f"{name} = :{name}" for name in list(re_encrypted_columns.keys())]
)
logger.info("Processing table: %s", table_name)
if not re_encrypted_columns:
return
set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns)
where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns)
pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns}
conn.execute(
text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"), # noqa: S608
id=row["id"],
text(
f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608
),
**pk_bind,
**re_encrypted_columns,
)
def run(self) -> None:
def run(self) -> ReEncryptStats:
"""
Re-encrypt every encrypted column in the ORM under the current
``SECRET_KEY``.
Returns per-value counts of re-encrypted, skipped (already under the
current key), and failed (undecryptable) outcomes. If any failures
occurred the transaction is rolled back by raising after the
summary is logged, so partial re-encryption never commits.
"""
encrypted_meta_info = self.discover_encrypted_fields()
stats = ReEncryptStats()
with self._db.engine.begin() as conn:
logger.info("Collecting info for re encryption")
for table_name, columns in encrypted_meta_info.items():
for table_name, (table, columns) in encrypted_meta_info.items():
pk_columns = [c.name for c in table.primary_key.columns]
if not pk_columns:
logger.warning(
"Skipping %s: no primary key, cannot target rows for update",
table_name,
)
continue
column_names = list(columns.keys())
rows = self._select_columns_from_table(conn, column_names, table_name)
rows = self._select_columns_from_table(
conn, pk_columns, column_names, table_name
)
for row in rows:
self._re_encrypt_row(conn, row, table_name, columns)
self._re_encrypt_row(
conn, row, table_name, columns, pk_columns, stats
)
logger.info(
"Re-encryption summary: %d re-encrypted, %d skipped,"
" %d null, %d failed",
stats.re_encrypted,
stats.skipped,
stats.null,
stats.failed,
)
if stats.failed:
raise Exception( # pylint: disable=broad-exception-raised
f"Re-encryption failed for {stats.failed} value(s); "
"transaction rolled back"
)
logger.info("All tables processed")
return stats

View File

@@ -28,6 +28,7 @@ from freezegun import freeze_time
import superset.cli.importexport
import superset.cli.thumbnails
import superset.cli.update
from superset import db
from superset.models.dashboard import Dashboard
from tests.integration_tests.fixtures.birth_names_dashboard import (
@@ -322,3 +323,44 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs):
thumbnail_mock.assert_called_with(None, dashboard.id, force=False)
assert response.exit_code == 0
def test_re_encrypt_secrets_without_previous_key_is_noop(app_context):
"""
When neither --previous_secret_key nor config.PREVIOUS_SECRET_KEY is set,
the command should exit cleanly (0) rather than error out, so that
scheduled re-encryption runs don't start failing after a successful
rotation is complete.
"""
current_app.config.pop("PREVIOUS_SECRET_KEY", None)
runner = current_app.test_cli_runner()
with mock.patch.object(superset.cli.update.SecretsMigrator, "run") as run_mock:
response = runner.invoke(superset.cli.update.re_encrypt_secrets, [])
assert response.exit_code == 0
assert "nothing to re-encrypt" in response.output.lower()
run_mock.assert_not_called()
def test_re_encrypt_secrets_failure_exits_nonzero(app_context):
"""
When re-encryption fails for any field, SecretsMigrator.run raises to
trigger rollback. The CLI must surface that as a non-zero exit with a
clear error message — not as an uncaught exception.
"""
runner = current_app.test_cli_runner()
with mock.patch.object(
superset.cli.update.SecretsMigrator,
"run",
side_effect=Exception("Re-encryption failed for 2 value(s)"),
):
response = runner.invoke(
superset.cli.update.re_encrypt_secrets,
["--previous_secret_key", "old-key"],
)
assert response.exit_code == 1
assert "Re-encryption failed" in response.output
# The failure path must be handled by the CLI, not leaked as an
# uncaught exception.
assert response.exception is None or isinstance(response.exception, SystemExit)

View File

@@ -24,6 +24,7 @@ from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType
from superset.extensions import encrypted_field_factory
from superset.utils.encrypt import (
AbstractEncryptedFieldAdapter,
ReEncryptStats,
SecretsMigrator,
SQLAlchemyUtilsAdapter,
)
@@ -79,7 +80,7 @@ class EncryptedFieldTest(SupersetTestCase):
migrator = SecretsMigrator("")
encrypted_fields = migrator.discover_encrypted_fields()
for table_name, cols in encrypted_fields.items():
for table_name, (_table, cols) in encrypted_fields.items():
for col_name, field in cols.items():
if not encrypted_field_factory.created_by_enc_field_factory(field):
self.fail(
@@ -101,8 +102,33 @@ class EncryptedFieldTest(SupersetTestCase):
"dbs table not found in encrypted fields — "
"discover_encrypted_fields may be using the wrong MetaData instance"
)
dbs_cols = set(encrypted_fields["dbs"].keys())
assert {"password", "encrypted_extra", "server_cert"}.issubset(dbs_cols)
_table, dbs_cols = encrypted_fields["dbs"]
assert {"password", "encrypted_extra", "server_cert"}.issubset(
set(dbs_cols.keys())
)
def test_discover_encrypted_fields_returns_table_with_non_id_pk(self):
"""
Ensure discover_encrypted_fields surfaces the Table object alongside
encrypted columns, and that the PK introspection works for tables
whose primary key is not a conventional integer `id` column
(e.g. `semantic_layers` uses `uuid` as its PK).
"""
# Import triggers FAB metadata registration for the semantic_layers table.
from superset.semantic_layers.models import SemanticLayer # noqa: F401
migrator = SecretsMigrator("")
encrypted_fields = migrator.discover_encrypted_fields()
assert "semantic_layers" in encrypted_fields, (
"semantic_layers table not found — it has an encrypted `configuration` "
"column and should be discovered"
)
table, cols = encrypted_fields["semantic_layers"]
assert "configuration" in cols
pk_columns = [c.name for c in table.primary_key.columns]
assert pk_columns == ["uuid"], (
f"Expected semantic_layers PK to be ['uuid'], got {pk_columns}"
)
def test_lazy_key_resolution(self):
"""
@@ -175,3 +201,190 @@ class EncryptedFieldTest(SupersetTestCase):
# Restore original key
self.app.config["SECRET_KEY"] = key_a
def test_re_encrypt_row_uses_pk_columns(self):
"""
Verify SecretsMigrator builds UPDATE statements targeting the table's
actual primary key columns rather than a hardcoded `id` column.
Regression guard for tables like `semantic_layers` whose PK is `uuid`.
"""
from unittest.mock import MagicMock
from sqlalchemy.engine import make_url
dialect = make_url("sqlite://").get_dialect()
previous_key = "PREVIOUS_KEY_FOR_PK_COLUMN_TEST"
migrator = SecretsMigrator(previous_key)
migrator._dialect = dialect # noqa: SLF001
# Encrypt under the previous key so the current-key decrypt fails
# and the re-encrypt path (which issues the UPDATE) is exercised.
previous_field = EncryptedType(type_in=String(1024), key=previous_key)
ciphertext = previous_field.process_bind_param("hunter2", dialect)
current_field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
conn,
row,
"semantic_layers",
{"configuration": current_field},
["uuid"],
stats,
)
assert conn.execute.call_count == 1
stmt = str(conn.execute.call_args.args[0])
assert "WHERE uuid = :_pk_uuid" in stmt
kwargs = conn.execute.call_args.kwargs
assert kwargs["_pk_uuid"] == row["uuid"]
assert "configuration" in kwargs
assert stats == ReEncryptStats(re_encrypted=1, skipped=0, failed=0)
def test_re_encrypt_row_is_idempotent(self):
"""
Re-running re-encryption on a row that is already encrypted under the
current key must be a no-op: no UPDATE is issued, no error is raised,
and the outcome is counted as skipped.
"""
from unittest.mock import MagicMock
from sqlalchemy.engine import make_url
dialect = make_url("sqlite://").get_dialect()
current_key = self.app.config["SECRET_KEY"]
migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef")
migrator._dialect = dialect # noqa: SLF001
field = encrypted_field_factory.create(String(1024))
ciphertext = field.process_bind_param("hunter2", dialect)
assert field.process_result_value(ciphertext, dialect) == "hunter2"
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
conn,
row,
"semantic_layers",
{"configuration": field},
["uuid"],
stats,
)
assert conn.execute.call_count == 0, (
"Row already readable under current key should not trigger UPDATE"
)
assert stats == ReEncryptStats(re_encrypted=0, skipped=1, failed=0)
# Current key must still decrypt the original ciphertext — nothing
# was mutated.
self.app.config["SECRET_KEY"] = current_key
assert field.process_result_value(ciphertext, dialect) == "hunter2"
def test_re_encrypt_row_idempotent_when_previous_key_also_decrypts(self):
"""
When the supplied previous_secret_key can also decrypt the value
(e.g. re-running after a successful rotation while still passing
the original secret, or mistakenly passing the current secret as
the previous one), the row must still be skipped. Idempotency is
anchored on whether the current key can already read the data,
not on whether the previous key fails to decrypt.
"""
from unittest.mock import MagicMock
from sqlalchemy.engine import make_url
dialect = make_url("sqlite://").get_dialect()
# Previous key == current key — this is the "re-run with no actual
# rotation" scenario.
migrator = SecretsMigrator(self.app.config["SECRET_KEY"])
migrator._dialect = dialect # noqa: SLF001
field = encrypted_field_factory.create(String(1024))
ciphertext = field.process_bind_param("hunter2", dialect)
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": ciphertext}
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
conn,
row,
"semantic_layers",
{"configuration": field},
["uuid"],
stats,
)
assert conn.execute.call_count == 0, (
"Idempotency must hold even when previous_secret_key can also "
"decrypt the value"
)
assert stats == ReEncryptStats(re_encrypted=0, skipped=1, failed=0)
def test_re_encrypt_row_counts_failures_without_raising(self):
"""
Per-column failures are accumulated onto the stats counter so the
caller can emit a summary covering every row. The row method itself
must not raise — run() decides whether to abort based on the totals.
"""
from unittest.mock import MagicMock
from sqlalchemy.engine import make_url
dialect = make_url("sqlite://").get_dialect()
migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef")
migrator._dialect = dialect # noqa: SLF001
field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"}
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
conn,
row,
"semantic_layers",
{"configuration": field},
["uuid"],
stats,
)
assert conn.execute.call_count == 0
assert stats == ReEncryptStats(re_encrypted=0, skipped=0, failed=1)
def test_re_encrypt_row_counts_nulls_separately(self):
"""
NULL column values are not encrypted and therefore have nothing to
migrate. They must be counted as ``null`` (not ``skipped``) and
must not trigger an UPDATE, regardless of which key is supplied as
the previous secret.
"""
from unittest.mock import MagicMock
from sqlalchemy.engine import make_url
dialect = make_url("sqlite://").get_dialect()
migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef")
migrator._dialect = dialect # noqa: SLF001
field = encrypted_field_factory.create(String(1024))
conn = MagicMock()
row = {"uuid": b"\x00" * 16, "configuration": None}
stats = ReEncryptStats()
migrator._re_encrypt_row( # noqa: SLF001
conn,
row,
"semantic_layers",
{"configuration": field},
["uuid"],
stats,
)
assert conn.execute.call_count == 0
assert stats == ReEncryptStats(re_encrypted=0, skipped=0, null=1, failed=0)