diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index 0ec586a8cab..99e28b38f96 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -49,8 +49,9 @@ from superset.datasets.schemas import ImportV1DatasetSchema from superset.extensions import feature_flag_manager from superset.migrations.shared.native_filters import migrate_dashboard from superset.models.core import Database -from superset.models.dashboard import dashboard_slices +from superset.models.dashboard import Dashboard, dashboard_slices from superset.models.slice import Slice +from superset.models.sql_lab import SavedQuery from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema from superset.utils.decorators import on_error, transaction @@ -89,6 +90,9 @@ class ImportAssetsCommand(BaseCommand): ) self._configs: dict[str, Any] = {} self.sparse = kwargs.get("sparse", False) + # Defaults to ``True`` for backwards compatibility: historically this + # command always overwrote existing assets. + self.overwrite: bool = kwargs.get("overwrite", True) # pylint: disable=too-many-locals @staticmethod @@ -96,6 +100,7 @@ class ImportAssetsCommand(BaseCommand): configs: dict[str, Any], sparse: bool = False, contents: Optional[dict[str, Any]] = None, + overwrite: bool = True, ) -> None: contents = {} if contents is None else contents # import databases first @@ -116,20 +121,20 @@ class ImportAssetsCommand(BaseCommand): for file_name, config in configs.items(): if file_name.startswith("databases/"): - database = import_database(config, overwrite=True) + database = import_database(config, overwrite=overwrite) database_ids[str(database.uuid)] = database.id # import saved queries for file_name, config in configs.items(): if file_name.startswith("queries/"): config["db_id"] = database_ids[config["database_uuid"]] - import_saved_query(config, overwrite=True) + import_saved_query(config, overwrite=overwrite) # import datasets for file_name, config in configs.items(): if file_name.startswith("datasets/"): config["database_id"] = database_ids[config["database_uuid"]] - dataset = import_dataset(config, overwrite=True) + dataset = import_dataset(config, overwrite=overwrite) dataset_info[str(dataset.uuid)] = { "datasource_id": dataset.id, "datasource_type": dataset.datasource_type, @@ -142,7 +147,7 @@ class ImportAssetsCommand(BaseCommand): if file_name.startswith("charts/"): dataset_dict = dataset_info[config["dataset_uuid"]] config = update_chart_config_dataset(config, dataset_dict) - chart = import_chart(config, overwrite=True) + chart = import_chart(config, overwrite=overwrite) charts.append(chart) chart_ids[str(chart.uuid)] = chart.id @@ -157,7 +162,7 @@ class ImportAssetsCommand(BaseCommand): for file_name, config in configs.items(): if file_name.startswith("dashboards/"): config = update_id_refs(config, chart_ids, dataset_info) - dashboard = import_dashboard(config, overwrite=True) + dashboard = import_dashboard(config, overwrite=overwrite) # set ref in the dashboard_slices table dashboard_chart_ids: list[dict[str, int]] = [] @@ -206,7 +211,73 @@ class ImportAssetsCommand(BaseCommand): ) def run(self) -> None: self.validate() - self._import(self._configs, self.sparse, self.contents) + self._import(self._configs, self.sparse, self.contents, self.overwrite) + + # Maps asset file prefixes to the model class used to look up UUIDs for + # the "already exists" validation check when ``overwrite`` is ``False``. + _MODEL_BY_PREFIX: dict[str, Any] = { + "databases/": Database, + "datasets/": SqlaTable, + "charts/": Slice, + "dashboards/": Dashboard, + "queries/": SavedQuery, + } + + def _bundle_entries_by_prefix(self) -> dict[str, list[tuple[str, str]]]: + """Group ``(file_name, uuid)`` pairs from the bundle by asset prefix.""" + bundle_by_prefix: dict[str, list[tuple[str, str]]] = { + prefix: [] for prefix in self._MODEL_BY_PREFIX + } + for file_name, config in self._configs.items(): + uuid = config.get("uuid") + if not uuid: + continue + for prefix in bundle_by_prefix: + if file_name.startswith(prefix): + bundle_by_prefix[prefix].append((file_name, str(uuid))) + break + return bundle_by_prefix + + def _prevent_overwrite_existing_assets( + self, exceptions: list[ValidationError] + ) -> None: + """ + When ``overwrite`` is ``False``, raise a clear validation error for any + asset in the bundle whose UUID already exists in the database. + + Only the UUIDs present in the import bundle are queried (per prefix), + so the cost scales with the bundle size rather than with the total + number of stored assets. + """ + if self.overwrite: + return + + for prefix, entries in self._bundle_entries_by_prefix().items(): + if not entries: + continue + model_cls = self._MODEL_BY_PREFIX[prefix] + incoming_uuids = [uuid for _, uuid in entries] + existing_uuids = { + str(uuid) + for (uuid,) in db.session.query(model_cls.uuid) + .filter(model_cls.uuid.in_(incoming_uuids)) + .all() + } + if not existing_uuids: + continue + model_name = model_cls.__name__ + for file_name, uuid in entries: + if uuid in existing_uuids: + exceptions.append( + ValidationError( + { + file_name: ( + f"{model_name} already exists " + "and `overwrite=true` was not passed" + ), + } + ) + ) def validate(self) -> None: exceptions: list[ValidationError] = [] @@ -229,6 +300,7 @@ class ImportAssetsCommand(BaseCommand): self.ssh_tunnel_priv_key_passwords, self.encrypted_extra_secrets, ) + self._prevent_overwrite_existing_assets(exceptions) if exceptions: raise CommandInvalidError( diff --git a/superset/importexport/api.py b/superset/importexport/api.py index 09da5c0ff70..d65143f5b1d 100644 --- a/superset/importexport/api.py +++ b/superset/importexport/api.py @@ -30,6 +30,7 @@ from superset.commands.importers.v1.assets import ImportAssetsCommand from superset.commands.importers.v1.utils import get_contents_from_bundle from superset.extensions import event_logger from superset.utils import json +from superset.utils.core import parse_boolean_string from superset.views.base_api import BaseSupersetApi, requires_form_data, statsd_metrics @@ -157,6 +158,12 @@ class ImportExportRestApi(BaseSupersetApi): sparse: description: allow sparse update of resources type: boolean + overwrite: + description: >- + overwrite existing assets? Defaults to ``true`` for + backwards compatibility. When ``false``, the import + fails if any of the assets already exist. + type: boolean responses: 200: description: Assets import result @@ -188,6 +195,9 @@ class ImportExportRestApi(BaseSupersetApi): if not contents: raise NoValidFilesFoundError() sparse = request.form.get("sparse") == "true" + # Defaults to True for backwards compatibility: historically this + # endpoint always overwrote existing assets. + overwrite = parse_boolean_string(request.form.get("overwrite", "true")) passwords = ( json.loads(request.form["passwords"]) @@ -218,6 +228,7 @@ class ImportExportRestApi(BaseSupersetApi): command = ImportAssetsCommand( contents, sparse=sparse, + overwrite=overwrite, passwords=passwords, ssh_tunnel_passwords=ssh_tunnel_passwords, ssh_tunnel_private_keys=ssh_tunnel_private_keys, diff --git a/tests/unit_tests/commands/importers/v1/assets_test.py b/tests/unit_tests/commands/importers/v1/assets_test.py index 2b3de431a30..56b372006da 100644 --- a/tests/unit_tests/commands/importers/v1/assets_test.py +++ b/tests/unit_tests/commands/importers/v1/assets_test.py @@ -19,6 +19,7 @@ import copy from typing import Any, cast import yaml +from marshmallow.exceptions import ValidationError from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session from sqlalchemy.sql import select @@ -32,6 +33,18 @@ from tests.unit_tests.fixtures.assets_configs import ( datasets_config, ) +saved_queries_config: dict[str, Any] = { + "queries/examples/my_query.yaml": { + "schema": "main", + "label": "My saved query", + "description": None, + "sql": "SELECT 1", + "uuid": "e3e4f1f0-5c9d-4a4c-a4e4-0000000000aa", + "version": "1.0.0", + "database_uuid": "a2dc77af-e654-49bb-b321-40f6b559a1ee", + }, +} + def test_import_new_assets(mocker: MockerFixture, session: Session) -> None: """ @@ -227,6 +240,309 @@ def test_import_assets_skips_tags_when_feature_disabled( assert db.session.query(TaggedObject).count() == 0 +def test_import_overwrite_defaults_to_true(session: Session) -> None: + """ + ``ImportAssetsCommand.overwrite`` defaults to ``True`` for backwards + compatibility — historically the command always overwrote existing assets. + """ + from superset.commands.importers.v1.assets import ImportAssetsCommand + + command = ImportAssetsCommand({}) + assert command.overwrite is True + + explicit_false = ImportAssetsCommand({}, overwrite=False) + assert explicit_false.overwrite is False + + +def test_import_threads_overwrite_flag(mocker: MockerFixture, session: Session) -> None: + """ + ``overwrite`` must be threaded through to ``import_database``, + ``import_saved_query``, ``import_dataset``, ``import_chart`` and + ``import_dashboard``. Previously these were hard-coded to ``overwrite=True`` + which caused the API flag to be ignored. + """ + from superset import security_manager + from superset.commands.importers.v1 import assets as assets_module + from superset.commands.importers.v1.assets import ImportAssetsCommand + + mocker.patch.object(security_manager, "can_access", return_value=True) + + mocked_db = mocker.patch.object(assets_module, "import_database") + mocked_db.return_value.uuid = "a2dc77af-e654-49bb-b321-40f6b559a1ee" + mocked_db.return_value.id = 1 + mocked_ds = mocker.patch.object(assets_module, "import_dataset") + mocked_ds.return_value.uuid = "53d47c0c-c03d-47f0-b9ac-81225f808283" + mocked_ds.return_value.id = 1 + mocked_ds.return_value.datasource_type = "table" + mocked_ds.return_value.table_name = "video_game_sales" + mocked_chart = mocker.patch.object(assets_module, "import_chart") + mocked_chart.return_value.viz_type = "table" + mocked_dash = mocker.patch.object(assets_module, "import_dashboard") + mocker.patch.object(assets_module, "find_chart_uuids", return_value=[]) + mocker.patch.object(assets_module, "update_id_refs", side_effect=lambda c, *_: c) + mocker.patch.object(assets_module, "migrate_dashboard") + mocker.patch("superset.db.session.execute") + + configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + + ImportAssetsCommand._import(configs, overwrite=False) + + assert mocked_db.called + for call in mocked_db.call_args_list: + assert call.kwargs["overwrite"] is False + for call in mocked_ds.call_args_list: + assert call.kwargs["overwrite"] is False + for call in mocked_chart.call_args_list: + assert call.kwargs["overwrite"] is False + for call in mocked_dash.call_args_list: + assert call.kwargs["overwrite"] is False + + +def test_prevent_overwrite_flags_existing_assets( + mocker: MockerFixture, session: Session +) -> None: + """ + With ``overwrite=False``, ``_prevent_overwrite_existing_assets`` must + surface a clear ``ValidationError`` for each asset whose UUID already + exists in the database. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.models.slice import Slice + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + + # seed the database with the fixture assets + seed_configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + ImportAssetsCommand._import(seed_configs) + + command = ImportAssetsCommand({}, overwrite=False) + command._configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + # one exception for each of the seeded assets (db + datasets + charts + dashboards) + expected_count = ( + len(databases_config) + + len(datasets_config) + + len(charts_config_1) + + len(dashboards_config_1) + ) + assert len(exceptions) == expected_count + for exc in exceptions: + assert isinstance(exc, ValidationError) + [(_, message)] = exc.messages.items() + assert "already exists" in message + assert "`overwrite=true` was not passed" in message + + +def test_prevent_overwrite_allows_new_assets( + mocker: MockerFixture, session: Session +) -> None: + """ + With ``overwrite=False`` and no conflicting UUIDs in the database, the + validation step must not raise. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.models.slice import Slice + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + + command = ImportAssetsCommand({}, overwrite=False) + command._configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + assert exceptions == [] + + +def test_prevent_overwrite_noop_when_overwrite_true( + mocker: MockerFixture, session: Session +) -> None: + """ + With ``overwrite=True`` (the default) the "already exists" validation must + be a no-op even when assets exist in the database — this preserves the + historical behavior. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.models.slice import Slice + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + + seed_configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + ImportAssetsCommand._import(seed_configs) + + command = ImportAssetsCommand({}) # overwrite defaults to True + command._configs = copy.deepcopy(seed_configs) + + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + assert exceptions == [] + + +def test_prevent_overwrite_flags_existing_saved_queries( + mocker: MockerFixture, session: Session +) -> None: + """ + Saved queries (``queries/`` prefix) must also be covered by the + "already exists" validation when ``overwrite=False`` — otherwise + ``import_saved_query`` silently returns existing rows and the endpoint + would appear to succeed despite the conflict. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.models.slice import Slice + from superset.models.sql_lab import SavedQuery + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + SavedQuery.metadata.create_all(engine) # pylint: disable=no-member + + # seed a saved query with a UUID that matches the fixture below + saved_query_uuid = next(iter(saved_queries_config.values()))["uuid"] + db.session.add(SavedQuery(uuid=saved_query_uuid, label="seeded")) + db.session.flush() + + command = ImportAssetsCommand({}, overwrite=False) + command._configs = copy.deepcopy(saved_queries_config) + + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + assert len(exceptions) == 1 + [(file_name, message)] = exceptions[0].messages.items() + assert file_name.startswith("queries/") + assert "SavedQuery already exists" in message + + +def test_prevent_overwrite_partial_conflict( + mocker: MockerFixture, session: Session +) -> None: + """ + When only some of the incoming assets already exist, validation must flag + exactly the conflicting ones and leave brand-new assets untouched. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.models.slice import Slice + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + + # seed only databases + datasets; charts and dashboards stay new + ImportAssetsCommand._import( + { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + } + ) + + command = ImportAssetsCommand({}, overwrite=False) + command._configs = { + **copy.deepcopy(databases_config), + **copy.deepcopy(datasets_config), + **copy.deepcopy(charts_config_1), + **copy.deepcopy(dashboards_config_1), + } + + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + flagged_files = {next(iter(exc.messages)) for exc in exceptions} + assert flagged_files == set(databases_config) | set(datasets_config) + + +def test_prevent_overwrite_queries_only_bundle_uuids( + mocker: MockerFixture, session: Session +) -> None: + """ + The validation must scope its UUID lookup to the UUIDs present in the + import bundle (one ``WHERE uuid IN (...)`` query per prefix that has + incoming entries) and skip prefixes with no entries entirely. Otherwise + every import with ``overwrite=false`` would scan all asset tables in + full, regardless of bundle size. + """ + from superset import db, security_manager + from superset.commands.importers.v1.assets import ImportAssetsCommand + from superset.connectors.sqla.models import SqlaTable + from superset.models.core import Database + from superset.models.dashboard import Dashboard + from superset.models.slice import Slice + from superset.models.sql_lab import SavedQuery + + mocker.patch.object(security_manager, "can_access", return_value=True) + engine = db.session.get_bind() + Slice.metadata.create_all(engine) # pylint: disable=no-member + SavedQuery.metadata.create_all(engine) # pylint: disable=no-member + + # bundle only contains a database — no datasets/charts/dashboards/queries + bundle = copy.deepcopy(databases_config) + + spy = mocker.spy(db.session, "query") + + command = ImportAssetsCommand({}, overwrite=False) + command._configs = bundle + exceptions: list[ValidationError] = [] + command._prevent_overwrite_existing_assets(exceptions) + + # exactly one UUID query — for the only prefix with bundle entries — and + # it targets the Database UUID column. Empty-bundle prefixes (datasets/ + # charts/dashboards/queries) must not be queried at all, otherwise this + # validation degrades to a full-table scan per asset type. + queried_columns = [ + call.args[0] + for call in spy.call_args_list + if call.args and getattr(call.args[0], "key", None) == "uuid" + ] + assert len(queried_columns) == 1 + assert queried_columns[0].class_ is Database + + queried_models = {col.class_ for col in queried_columns} + for model_cls in (SqlaTable, Slice, Dashboard, SavedQuery): + assert model_cls not in queried_models + + # no row matches in an empty table, so no validation errors are raised + assert exceptions == [] + + def test_import_removes_dashboard_charts( mocker: MockerFixture, session: Session ) -> None: diff --git a/tests/unit_tests/importexport/api_test.py b/tests/unit_tests/importexport/api_test.py index 1ecfd94861c..ff60aa02d02 100644 --- a/tests/unit_tests/importexport/api_test.py +++ b/tests/unit_tests/importexport/api_test.py @@ -48,7 +48,9 @@ def test_export_assets( mocked_export_result = [ ( "metadata.yaml", - lambda: "version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n", # noqa: E501 + lambda: ( + "version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n" + ), # noqa: E501 ), ("databases/example.yaml", lambda: ""), ] @@ -109,6 +111,7 @@ def test_import_assets( ImportAssetsCommand.assert_called_with( mocked_contents, sparse=False, + overwrite=True, passwords=passwords, ssh_tunnel_passwords=None, ssh_tunnel_private_keys=None, @@ -160,6 +163,7 @@ def test_import_assets_with_encrypted_extra_secrets( ImportAssetsCommand.assert_called_with( mocked_contents, sparse=False, + overwrite=True, passwords=None, ssh_tunnel_passwords=None, ssh_tunnel_private_keys=None, @@ -168,6 +172,54 @@ def test_import_assets_with_encrypted_extra_secrets( ) +def test_import_assets_overwrite_false( + mocker: MockerFixture, + client: Any, + full_api_access: None, +) -> None: + """ + Passing ``overwrite=false`` on the form must be forwarded to + ``ImportAssetsCommand``. Previously the flag was ignored and assets were + always overwritten. + """ + mocked_contents = { + "metadata.yaml": ( + "version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n" + ), + "databases/example.yaml": "", + } + + ImportAssetsCommand = mocker.patch("superset.importexport.api.ImportAssetsCommand") # noqa: N806 + + root = Path("assets_export") + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + for path, contents in mocked_contents.items(): + with bundle.open(str(root / path), "w") as fp: + fp.write(contents.encode()) + buf.seek(0) + + form_data = { + "bundle": (buf, "assets_export.zip"), + "overwrite": "false", + } + response = client.post( + "/api/v1/assets/import/", data=form_data, content_type="multipart/form-data" + ) + assert response.status_code == 200 + + ImportAssetsCommand.assert_called_with( + mocked_contents, + sparse=False, + overwrite=False, + passwords=None, + ssh_tunnel_passwords=None, + ssh_tunnel_private_keys=None, + ssh_tunnel_priv_key_passwords=None, + encrypted_extra_secrets=None, + ) + + def test_import_assets_not_zip( mocker: MockerFixture, client: Any,