diff --git a/superset/cli/viz_migrations.py b/superset/cli/viz_migrations.py index 34550ac2da4..df62a598584 100644 --- a/superset/cli/viz_migrations.py +++ b/superset/cli/viz_migrations.py @@ -16,6 +16,7 @@ # under the License. from __future__ import annotations +import logging from enum import Enum from typing import Type @@ -107,9 +108,10 @@ def migrate_viz() -> None: ) def upgrade(viz_type: str, ids: tuple[int, ...] | None = None) -> None: """Upgrade a viz to the latest version.""" - if ids is None: + setup_logger() + if viz_type: migrate_by_viz_type(VizType(viz_type)) - else: + elif ids: migrate_by_id(ids) @@ -133,9 +135,10 @@ def upgrade(viz_type: str, ids: tuple[int, ...] | None = None) -> None: ) def downgrade(viz_type: str, ids: tuple[int, ...] | None = None) -> None: """Downgrade a viz to the previous version.""" - if ids is None: + setup_logger() + if viz_type: migrate_by_viz_type(VizType(viz_type), is_downgrade=True) - else: + elif ids: migrate_by_id(ids, is_downgrade=True) @@ -163,7 +166,7 @@ def migrate_by_id(ids: tuple[int, ...], is_downgrade: bool = False) -> None: slices = db.session.query(Slice).filter(Slice.id.in_(ids)) for slc in paginated_update( slices, - lambda current, total: print( + lambda current, total: click.echo( f"{('Downgraded' if is_downgrade else 'Upgraded')} {current}/{total} charts" ), ): @@ -171,3 +174,12 @@ def migrate_by_id(ids: tuple[int, ...], is_downgrade: bool = False) -> None: PREVIOUS_VERSION[slc.viz_type].downgrade_slice(slc) elif slc.viz_type in MIGRATIONS: MIGRATIONS[slc.viz_type].upgrade_slice(slc) + + +def setup_logger() -> None: + """ + Configure the logger for the CLI commands. + """ + console_handler = logging.StreamHandler() + logger = logging.getLogger("alembic") + logger.addHandler(console_handler) diff --git a/superset/migrations/shared/migrate_viz/base.py b/superset/migrations/shared/migrate_viz/base.py index ee5372e3a8f..ba10de1ec22 100644 --- a/superset/migrations/shared/migrate_viz/base.py +++ b/superset/migrations/shared/migrate_viz/base.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. import copy +import logging from typing import Any from sqlalchemy import and_, Column, Integer, String, Text @@ -26,6 +27,8 @@ from superset.constants import TimeGrain from superset.migrations.shared.utils import paginated_update, try_load_json from superset.utils import json +logger = logging.getLogger("alembic") + Base = declarative_base() @@ -121,41 +124,51 @@ class MigrateViz: @classmethod def upgrade_slice(cls, slc: Slice) -> None: - clz = cls(slc.params) - form_data_bak = copy.deepcopy(clz.data) + try: + clz = cls(slc.params) + form_data_bak = copy.deepcopy(clz.data) - clz._pre_action() - clz._migrate() - clz._post_action() + clz._pre_action() + clz._migrate() + clz._post_action() - # viz_type depends on the migration and should be set after its execution - # because a source viz can be mapped to different target viz types - slc.viz_type = clz.target_viz_type + # viz_type depends on the migration and should be set after its execution + # because a source viz can be mapped to different target viz types + slc.viz_type = clz.target_viz_type - # only backup params - slc.params = json.dumps({**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak}) + # only backup params + slc.params = json.dumps( + {**clz.data, FORM_DATA_BAK_FIELD_NAME: form_data_bak} + ) - if "form_data" in (query_context := try_load_json(slc.query_context)): - query_context["form_data"] = clz.data - slc.query_context = json.dumps(query_context) + if "form_data" in (query_context := try_load_json(slc.query_context)): + query_context["form_data"] = clz.data + slc.query_context = json.dumps(query_context) + except Exception as e: + logger.warning(f"Failed to migrate slice {slc.id}: {e}") @classmethod def downgrade_slice(cls, slc: Slice) -> None: - form_data = try_load_json(slc.params) - if "viz_type" in (form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {})): - slc.params = json.dumps(form_data_bak) - slc.viz_type = form_data_bak.get("viz_type") - query_context = try_load_json(slc.query_context) - if "form_data" in query_context: - query_context["form_data"] = form_data_bak - slc.query_context = json.dumps(query_context) + try: + form_data = try_load_json(slc.params) + if "viz_type" in ( + form_data_bak := form_data.get(FORM_DATA_BAK_FIELD_NAME, {}) + ): + slc.params = json.dumps(form_data_bak) + slc.viz_type = form_data_bak.get("viz_type") + query_context = try_load_json(slc.query_context) + if "form_data" in query_context: + query_context["form_data"] = form_data_bak + slc.query_context = json.dumps(query_context) + except Exception as e: + logger.warning(f"Failed to downgrade slice {slc.id}: {e}") @classmethod def upgrade(cls, session: Session) -> None: slices = session.query(Slice).filter(Slice.viz_type == cls.source_viz_type) for slc in paginated_update( slices, - lambda current, total: print(f"Upgraded {current}/{total} charts"), + lambda current, total: logger.info(f"Upgraded {current}/{total} charts"), ): cls.upgrade_slice(slc) @@ -169,6 +182,6 @@ class MigrateViz: ) for slc in paginated_update( slices, - lambda current, total: print(f"Downgraded {current}/{total} charts"), + lambda current, total: logger.info(f"Downgraded {current}/{total} charts"), ): cls.downgrade_slice(slc) diff --git a/superset/migrations/shared/utils.py b/superset/migrations/shared/utils.py index 6ee4137af60..db97383a564 100644 --- a/superset/migrations/shared/utils.py +++ b/superset/migrations/shared/utils.py @@ -172,11 +172,7 @@ def paginated_update( def try_load_json(data: Optional[str]) -> dict[str, Any]: - try: - return data and json.loads(data) or {} - except json.JSONDecodeError: - print(f"Failed to parse: {data}") - return {} + return data and json.loads(data) or {} def has_table(table_name: str) -> bool: