diff --git a/superset/cli/update.py b/superset/cli/update.py index c162bb1e56e..e35e394c325 100755 --- a/superset/cli/update.py +++ b/superset/cli/update.py @@ -115,15 +115,19 @@ def re_encrypt_secrets(previous_secret_key: Optional[str] = None) -> None: "PREVIOUS_SECRET_KEY" ) if previous_secret_key is None: - click.secho("A previous secret key must be provided", err=True) - sys.exit(1) + click.secho( + "No previous secret key provided; nothing to re-encrypt.", + fg="yellow", + ) + return secrets_migrator = SecretsMigrator(previous_secret_key=previous_secret_key) try: - secrets_migrator.run() - except ValueError as exc: - click.secho( - f"An error occurred, " - f"probably an invalid previous secret key was provided. Error:[{exc}]", - err=True, - ) + stats = secrets_migrator.run() + except Exception as exc: # pylint: disable=broad-except + click.secho(f"Re-encryption failed: {exc}", err=True) sys.exit(1) + click.secho( + f"Re-encryption complete: {stats.re_encrypted} re-encrypted, " + f"{stats.skipped} skipped, {stats.null} null, {stats.failed} failed.", + fg="green", + ) diff --git a/superset/utils/encrypt.py b/superset/utils/encrypt.py index bf1c5b159c4..963e72f9858 100644 --- a/superset/utils/encrypt.py +++ b/superset/utils/encrypt.py @@ -16,11 +16,12 @@ # under the License. import logging from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import Any, Optional from flask import Flask from flask_babel import lazy_gettext as _ -from sqlalchemy import text, TypeDecorator +from sqlalchemy import Table, text, TypeDecorator from sqlalchemy.engine import Connection, Dialect, Row from sqlalchemy_utils import EncryptedType as SqlaEncryptedType @@ -33,6 +34,16 @@ ENC_ADAPTER_TAG_ATTR_NAME = "__created_by_enc_field_adapter__" logger = logging.getLogger(__name__) +@dataclass +class ReEncryptStats: + """Per-value outcome counts for a SecretsMigrator.run() invocation.""" + + re_encrypted: int = 0 + skipped: int = 0 + null: int = 0 + failed: int = 0 + + class AbstractEncryptedFieldAdapter(ABC): # pylint: disable=too-few-public-methods @abstractmethod def create( @@ -97,11 +108,13 @@ class SecretsMigrator: self._previous_secret_key = previous_secret_key self._dialect: Dialect = db.engine.url.get_dialect() - def discover_encrypted_fields(self) -> dict[str, dict[str, EncryptedType]]: + def discover_encrypted_fields( + self, + ) -> dict[str, tuple[Table, dict[str, EncryptedType]]]: """ Iterates over ORM-mapped tables, looking for EncryptedType columns along the way. Builds up a dict of - table_name -> dict of col_name: enc type instance + table_name -> (Table, dict of col_name: enc type instance) Superset's ORM models inherit from Flask-AppBuilder's declarative base (`flask_appbuilder.Model`), whose MetaData is distinct from @@ -109,13 +122,18 @@ class SecretsMigrator: regardless of which base a model uses. FAB's metadata takes precedence when a table name appears in both registries. - :return: mapping of table name to a dict of {column name: EncryptedType} + The Table object is returned alongside the encrypted columns so callers + can introspect the schema (notably the primary key) without assuming a + conventional `id` column — some tables (e.g. `semantic_layers`) use a + `uuid` primary key instead. + + :return: mapping of table name to (Table, {column name: EncryptedType}) """ from flask_appbuilder import ( # pylint: disable=import-outside-toplevel Model as FABModel, ) - meta_info: dict[str, Any] = {} + meta_info: dict[str, tuple[Table, dict[str, EncryptedType]]] = {} tables: dict[str, Any] = dict(FABModel.metadata.tables) for table_name, table in self._db.metadata.tables.items(): @@ -124,9 +142,9 @@ class SecretsMigrator: for table_name, table in tables.items(): for col_name, col in table.columns.items(): if isinstance(col.type, EncryptedType): - cols = meta_info.get(table_name, {}) + _, cols = meta_info.get(table_name, (table, {})) cols[col_name] = col.type - meta_info[table_name] = cols + meta_info[table_name] = (table, cols) return meta_info @@ -151,9 +169,13 @@ class SecretsMigrator: @staticmethod def _select_columns_from_table( - conn: Connection, column_names: list[str], table_name: str + conn: Connection, + pk_columns: list[str], + column_names: list[str], + table_name: str, ) -> Row: - return conn.execute(f"SELECT id, {','.join(column_names)} FROM {table_name}") # noqa: S608 + cols = ",".join(pk_columns + column_names) + return conn.execute(f"SELECT {cols} FROM {table_name}") # noqa: S608 def _re_encrypt_row( self, @@ -161,62 +183,143 @@ class SecretsMigrator: row: Row, table_name: str, columns: dict[str, EncryptedType], + pk_columns: list[str], + stats: ReEncryptStats, ) -> None: """ Re encrypts all columns in a Row + + Re-encryption is idempotent per column: we first ask whether the + current key can already decrypt the value, and skip if so. Only if + the current key fails do we fall back to decrypting with the + previous key and re-encrypting. Checking the current key first + keeps ``run()`` idempotent regardless of what ``previous_secret_key`` + the caller supplies — even re-running with the same (unchanged) + ``SECRET_KEY`` will not rewrite rows. + + NULL values are never encrypted, so they are reported separately + (neither re-encrypted nor "skipped because already current"). + + Per-column outcomes are accumulated onto ``stats`` so the caller can + report a summary. Columns whose ciphertext is unreadable under both + keys are counted as failures and logged; the exception is not + propagated, so processing continues. The caller is responsible for + raising once all rows have been scanned. + + If no columns need re-encryption, no UPDATE is issued. + :param row: Current row to reencrypt :param columns: Meta info from columns + :param pk_columns: Primary key column names used to target the row + :param stats: Mutable counters updated per column """ re_encrypted_columns = {} for column_name, encrypted_type in columns.items(): + raw_value = self._read_bytes(column_name, row[column_name]) + + # NULL values aren't encrypted; there is nothing to migrate. + if raw_value is None: + stats.null += 1 + continue + + # Fast path: if the current key can already read the value, + # leave it untouched. A failure here simply means we need to try + # the previous key below — not a condition worth logging. + try: + encrypted_type.process_result_value(raw_value, self._dialect) + except Exception: # noqa: BLE001, S110 # pylint: disable=broad-except + pass + else: + stats.skipped += 1 + continue + + # Current key cannot decrypt — try the previous key. previous_encrypted_type = EncryptedType( type_in=encrypted_type.underlying_type, key=self._previous_secret_key ) try: unencrypted_value = previous_encrypted_type.process_result_value( - self._read_bytes(column_name, row[column_name]), self._dialect + raw_value, self._dialect ) - except ValueError as ex: - # Failed to unencrypt - try: - encrypted_type.process_result_value( - self._read_bytes(column_name, row[column_name]), self._dialect - ) - logger.info( - "Current secret is able to decrypt value on column [%s.%s]," - " nothing to do", - table_name, - column_name, - ) - return - except Exception: - raise Exception from ex # pylint: disable=broad-exception-raised + except Exception as prev_ex: # noqa: BLE001 # pylint: disable=broad-except + logger.error( + "Column [%s.%s] cannot be decrypted under the previous" + " or current secret key (%s: %s)", + table_name, + column_name, + type(prev_ex).__name__, + prev_ex, + ) + stats.failed += 1 + continue re_encrypted_columns[column_name] = encrypted_type.process_bind_param( unencrypted_value, self._dialect, ) + stats.re_encrypted += 1 - set_cols = ",".join( - [f"{name} = :{name}" for name in list(re_encrypted_columns.keys())] - ) - logger.info("Processing table: %s", table_name) + if not re_encrypted_columns: + return + + set_cols = ",".join(f"{name} = :{name}" for name in re_encrypted_columns) + where_clause = " AND ".join(f"{pk} = :_pk_{pk}" for pk in pk_columns) + pk_bind = {f"_pk_{pk}": row[pk] for pk in pk_columns} conn.execute( - text(f"UPDATE {table_name} SET {set_cols} WHERE id = :id"), # noqa: S608 - id=row["id"], + text( + f"UPDATE {table_name} SET {set_cols} WHERE {where_clause}" # noqa: S608 + ), + **pk_bind, **re_encrypted_columns, ) - def run(self) -> None: + def run(self) -> ReEncryptStats: + """ + Re-encrypt every encrypted column in the ORM under the current + ``SECRET_KEY``. + + Returns per-value counts of re-encrypted, skipped (already under the + current key), and failed (undecryptable) outcomes. If any failures + occurred the transaction is rolled back by raising after the + summary is logged, so partial re-encryption never commits. + """ encrypted_meta_info = self.discover_encrypted_fields() + stats = ReEncryptStats() with self._db.engine.begin() as conn: logger.info("Collecting info for re encryption") - for table_name, columns in encrypted_meta_info.items(): + for table_name, (table, columns) in encrypted_meta_info.items(): + pk_columns = [c.name for c in table.primary_key.columns] + if not pk_columns: + logger.warning( + "Skipping %s: no primary key, cannot target rows for update", + table_name, + ) + continue column_names = list(columns.keys()) - rows = self._select_columns_from_table(conn, column_names, table_name) + rows = self._select_columns_from_table( + conn, pk_columns, column_names, table_name + ) for row in rows: - self._re_encrypt_row(conn, row, table_name, columns) + self._re_encrypt_row( + conn, row, table_name, columns, pk_columns, stats + ) + + logger.info( + "Re-encryption summary: %d re-encrypted, %d skipped," + " %d null, %d failed", + stats.re_encrypted, + stats.skipped, + stats.null, + stats.failed, + ) + if stats.failed: + raise Exception( # pylint: disable=broad-exception-raised + f"Re-encryption failed for {stats.failed} value(s); " + "transaction rolled back" + ) + logger.info("All tables processed") + return stats diff --git a/tests/integration_tests/cli_tests.py b/tests/integration_tests/cli_tests.py index 328e41e147f..c41e88decc3 100644 --- a/tests/integration_tests/cli_tests.py +++ b/tests/integration_tests/cli_tests.py @@ -28,6 +28,7 @@ from freezegun import freeze_time import superset.cli.importexport import superset.cli.thumbnails +import superset.cli.update from superset import db from superset.models.dashboard import Dashboard from tests.integration_tests.fixtures.birth_names_dashboard import ( @@ -322,3 +323,44 @@ def test_compute_thumbnails(thumbnail_mock, app_context, fs): thumbnail_mock.assert_called_with(None, dashboard.id, force=False) assert response.exit_code == 0 + + +def test_re_encrypt_secrets_without_previous_key_is_noop(app_context): + """ + When neither --previous_secret_key nor config.PREVIOUS_SECRET_KEY is set, + the command should exit cleanly (0) rather than error out, so that + scheduled re-encryption runs don't start failing after a successful + rotation is complete. + """ + current_app.config.pop("PREVIOUS_SECRET_KEY", None) + runner = current_app.test_cli_runner() + with mock.patch.object(superset.cli.update.SecretsMigrator, "run") as run_mock: + response = runner.invoke(superset.cli.update.re_encrypt_secrets, []) + + assert response.exit_code == 0 + assert "nothing to re-encrypt" in response.output.lower() + run_mock.assert_not_called() + + +def test_re_encrypt_secrets_failure_exits_nonzero(app_context): + """ + When re-encryption fails for any field, SecretsMigrator.run raises to + trigger rollback. The CLI must surface that as a non-zero exit with a + clear error message — not as an uncaught exception. + """ + runner = current_app.test_cli_runner() + with mock.patch.object( + superset.cli.update.SecretsMigrator, + "run", + side_effect=Exception("Re-encryption failed for 2 value(s)"), + ): + response = runner.invoke( + superset.cli.update.re_encrypt_secrets, + ["--previous_secret_key", "old-key"], + ) + + assert response.exit_code == 1 + assert "Re-encryption failed" in response.output + # The failure path must be handled by the CLI, not leaked as an + # uncaught exception. + assert response.exception is None or isinstance(response.exception, SystemExit) diff --git a/tests/integration_tests/utils/encrypt_tests.py b/tests/integration_tests/utils/encrypt_tests.py index e791411ca4e..5c88d43ecef 100644 --- a/tests/integration_tests/utils/encrypt_tests.py +++ b/tests/integration_tests/utils/encrypt_tests.py @@ -24,6 +24,7 @@ from sqlalchemy_utils.types.encrypted.encrypted_type import StringEncryptedType from superset.extensions import encrypted_field_factory from superset.utils.encrypt import ( AbstractEncryptedFieldAdapter, + ReEncryptStats, SecretsMigrator, SQLAlchemyUtilsAdapter, ) @@ -79,7 +80,7 @@ class EncryptedFieldTest(SupersetTestCase): migrator = SecretsMigrator("") encrypted_fields = migrator.discover_encrypted_fields() - for table_name, cols in encrypted_fields.items(): + for table_name, (_table, cols) in encrypted_fields.items(): for col_name, field in cols.items(): if not encrypted_field_factory.created_by_enc_field_factory(field): self.fail( @@ -101,8 +102,33 @@ class EncryptedFieldTest(SupersetTestCase): "dbs table not found in encrypted fields — " "discover_encrypted_fields may be using the wrong MetaData instance" ) - dbs_cols = set(encrypted_fields["dbs"].keys()) - assert {"password", "encrypted_extra", "server_cert"}.issubset(dbs_cols) + _table, dbs_cols = encrypted_fields["dbs"] + assert {"password", "encrypted_extra", "server_cert"}.issubset( + set(dbs_cols.keys()) + ) + + def test_discover_encrypted_fields_returns_table_with_non_id_pk(self): + """ + Ensure discover_encrypted_fields surfaces the Table object alongside + encrypted columns, and that the PK introspection works for tables + whose primary key is not a conventional integer `id` column + (e.g. `semantic_layers` uses `uuid` as its PK). + """ + # Import triggers FAB metadata registration for the semantic_layers table. + from superset.semantic_layers.models import SemanticLayer # noqa: F401 + + migrator = SecretsMigrator("") + encrypted_fields = migrator.discover_encrypted_fields() + assert "semantic_layers" in encrypted_fields, ( + "semantic_layers table not found — it has an encrypted `configuration` " + "column and should be discovered" + ) + table, cols = encrypted_fields["semantic_layers"] + assert "configuration" in cols + pk_columns = [c.name for c in table.primary_key.columns] + assert pk_columns == ["uuid"], ( + f"Expected semantic_layers PK to be ['uuid'], got {pk_columns}" + ) def test_lazy_key_resolution(self): """ @@ -175,3 +201,190 @@ class EncryptedFieldTest(SupersetTestCase): # Restore original key self.app.config["SECRET_KEY"] = key_a + + def test_re_encrypt_row_uses_pk_columns(self): + """ + Verify SecretsMigrator builds UPDATE statements targeting the table's + actual primary key columns rather than a hardcoded `id` column. + Regression guard for tables like `semantic_layers` whose PK is `uuid`. + """ + from unittest.mock import MagicMock + + from sqlalchemy.engine import make_url + + dialect = make_url("sqlite://").get_dialect() + previous_key = "PREVIOUS_KEY_FOR_PK_COLUMN_TEST" + migrator = SecretsMigrator(previous_key) + migrator._dialect = dialect # noqa: SLF001 + + # Encrypt under the previous key so the current-key decrypt fails + # and the re-encrypt path (which issues the UPDATE) is exercised. + previous_field = EncryptedType(type_in=String(1024), key=previous_key) + ciphertext = previous_field.process_bind_param("hunter2", dialect) + + current_field = encrypted_field_factory.create(String(1024)) + conn = MagicMock() + row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + stats = ReEncryptStats() + + migrator._re_encrypt_row( # noqa: SLF001 + conn, + row, + "semantic_layers", + {"configuration": current_field}, + ["uuid"], + stats, + ) + + assert conn.execute.call_count == 1 + stmt = str(conn.execute.call_args.args[0]) + assert "WHERE uuid = :_pk_uuid" in stmt + kwargs = conn.execute.call_args.kwargs + assert kwargs["_pk_uuid"] == row["uuid"] + assert "configuration" in kwargs + assert stats == ReEncryptStats(re_encrypted=1, skipped=0, failed=0) + + def test_re_encrypt_row_is_idempotent(self): + """ + Re-running re-encryption on a row that is already encrypted under the + current key must be a no-op: no UPDATE is issued, no error is raised, + and the outcome is counted as skipped. + """ + from unittest.mock import MagicMock + + from sqlalchemy.engine import make_url + + dialect = make_url("sqlite://").get_dialect() + current_key = self.app.config["SECRET_KEY"] + migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef") + migrator._dialect = dialect # noqa: SLF001 + + field = encrypted_field_factory.create(String(1024)) + ciphertext = field.process_bind_param("hunter2", dialect) + assert field.process_result_value(ciphertext, dialect) == "hunter2" + + conn = MagicMock() + row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + stats = ReEncryptStats() + + migrator._re_encrypt_row( # noqa: SLF001 + conn, + row, + "semantic_layers", + {"configuration": field}, + ["uuid"], + stats, + ) + + assert conn.execute.call_count == 0, ( + "Row already readable under current key should not trigger UPDATE" + ) + assert stats == ReEncryptStats(re_encrypted=0, skipped=1, failed=0) + # Current key must still decrypt the original ciphertext — nothing + # was mutated. + self.app.config["SECRET_KEY"] = current_key + assert field.process_result_value(ciphertext, dialect) == "hunter2" + + def test_re_encrypt_row_idempotent_when_previous_key_also_decrypts(self): + """ + When the supplied previous_secret_key can also decrypt the value + (e.g. re-running after a successful rotation while still passing + the original secret, or mistakenly passing the current secret as + the previous one), the row must still be skipped. Idempotency is + anchored on whether the current key can already read the data, + not on whether the previous key fails to decrypt. + """ + from unittest.mock import MagicMock + + from sqlalchemy.engine import make_url + + dialect = make_url("sqlite://").get_dialect() + # Previous key == current key — this is the "re-run with no actual + # rotation" scenario. + migrator = SecretsMigrator(self.app.config["SECRET_KEY"]) + migrator._dialect = dialect # noqa: SLF001 + + field = encrypted_field_factory.create(String(1024)) + ciphertext = field.process_bind_param("hunter2", dialect) + + conn = MagicMock() + row = {"uuid": b"\x00" * 16, "configuration": ciphertext} + stats = ReEncryptStats() + + migrator._re_encrypt_row( # noqa: SLF001 + conn, + row, + "semantic_layers", + {"configuration": field}, + ["uuid"], + stats, + ) + + assert conn.execute.call_count == 0, ( + "Idempotency must hold even when previous_secret_key can also " + "decrypt the value" + ) + assert stats == ReEncryptStats(re_encrypted=0, skipped=1, failed=0) + + def test_re_encrypt_row_counts_failures_without_raising(self): + """ + Per-column failures are accumulated onto the stats counter so the + caller can emit a summary covering every row. The row method itself + must not raise — run() decides whether to abort based on the totals. + """ + from unittest.mock import MagicMock + + from sqlalchemy.engine import make_url + + dialect = make_url("sqlite://").get_dialect() + migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef") + migrator._dialect = dialect # noqa: SLF001 + + field = encrypted_field_factory.create(String(1024)) + conn = MagicMock() + row = {"uuid": b"\x00" * 16, "configuration": b"not-valid-ciphertext"} + stats = ReEncryptStats() + + migrator._re_encrypt_row( # noqa: SLF001 + conn, + row, + "semantic_layers", + {"configuration": field}, + ["uuid"], + stats, + ) + + assert conn.execute.call_count == 0 + assert stats == ReEncryptStats(re_encrypted=0, skipped=0, failed=1) + + def test_re_encrypt_row_counts_nulls_separately(self): + """ + NULL column values are not encrypted and therefore have nothing to + migrate. They must be counted as ``null`` (not ``skipped``) and + must not trigger an UPDATE, regardless of which key is supplied as + the previous secret. + """ + from unittest.mock import MagicMock + + from sqlalchemy.engine import make_url + + dialect = make_url("sqlite://").get_dialect() + migrator = SecretsMigrator("WRONG_PREVIOUS_KEY_abcdef") + migrator._dialect = dialect # noqa: SLF001 + + field = encrypted_field_factory.create(String(1024)) + conn = MagicMock() + row = {"uuid": b"\x00" * 16, "configuration": None} + stats = ReEncryptStats() + + migrator._re_encrypt_row( # noqa: SLF001 + conn, + row, + "semantic_layers", + {"configuration": field}, + ["uuid"], + stats, + ) + + assert conn.execute.call_count == 0 + assert stats == ReEncryptStats(re_encrypted=0, skipped=0, null=1, failed=0)