fix(importexport): honor overwrite flag on /api/v1/assets/import (#39502)

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Maxime Beauchemin
2026-05-11 10:24:42 -07:00
committed by GitHub
parent 6ee4d694bc
commit d90d3a2dea
4 changed files with 459 additions and 8 deletions

View File

@@ -49,8 +49,9 @@ from superset.datasets.schemas import ImportV1DatasetSchema
from superset.extensions import feature_flag_manager
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.core import Database
from superset.models.dashboard import dashboard_slices
from superset.models.dashboard import Dashboard, dashboard_slices
from superset.models.slice import Slice
from superset.models.sql_lab import SavedQuery
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import on_error, transaction
@@ -89,6 +90,9 @@ class ImportAssetsCommand(BaseCommand):
)
self._configs: dict[str, Any] = {}
self.sparse = kwargs.get("sparse", False)
# Defaults to ``True`` for backwards compatibility: historically this
# command always overwrote existing assets.
self.overwrite: bool = kwargs.get("overwrite", True)
# pylint: disable=too-many-locals
@staticmethod
@@ -96,6 +100,7 @@ class ImportAssetsCommand(BaseCommand):
configs: dict[str, Any],
sparse: bool = False,
contents: Optional[dict[str, Any]] = None,
overwrite: bool = True,
) -> None:
contents = {} if contents is None else contents
# import databases first
@@ -116,20 +121,20 @@ class ImportAssetsCommand(BaseCommand):
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(config, overwrite=True)
database = import_database(config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id
# import saved queries
for file_name, config in configs.items():
if file_name.startswith("queries/"):
config["db_id"] = database_ids[config["database_uuid"]]
import_saved_query(config, overwrite=True)
import_saved_query(config, overwrite=overwrite)
# import datasets
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(config, overwrite=True)
dataset = import_dataset(config, overwrite=overwrite)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": dataset.datasource_type,
@@ -142,7 +147,7 @@ class ImportAssetsCommand(BaseCommand):
if file_name.startswith("charts/"):
dataset_dict = dataset_info[config["dataset_uuid"]]
config = update_chart_config_dataset(config, dataset_dict)
chart = import_chart(config, overwrite=True)
chart = import_chart(config, overwrite=overwrite)
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id
@@ -157,7 +162,7 @@ class ImportAssetsCommand(BaseCommand):
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
dashboard = import_dashboard(config, overwrite=True)
dashboard = import_dashboard(config, overwrite=overwrite)
# set ref in the dashboard_slices table
dashboard_chart_ids: list[dict[str, int]] = []
@@ -206,7 +211,73 @@ class ImportAssetsCommand(BaseCommand):
)
def run(self) -> None:
self.validate()
self._import(self._configs, self.sparse, self.contents)
self._import(self._configs, self.sparse, self.contents, self.overwrite)
# Lookup table for the "already exists" validation performed when
# ``overwrite`` is ``False``: each key is a bundle file-name prefix and each
# value is the model class whose ``uuid`` column is queried for the assets
# filed under that prefix. Prefixes are matched against file names with
# ``str.startswith`` and are visited in insertion order, so validation errors
# are collected in this order as well.
_MODEL_BY_PREFIX: dict[str, Any] = {
"databases/": Database,
"datasets/": SqlaTable,
"charts/": Slice,
"dashboards/": Dashboard,
"queries/": SavedQuery,
}
def _bundle_entries_by_prefix(self) -> dict[str, list[tuple[str, str]]]:
"""Group ``(file_name, uuid)`` pairs from the bundle by asset prefix."""
bundle_by_prefix: dict[str, list[tuple[str, str]]] = {
prefix: [] for prefix in self._MODEL_BY_PREFIX
}
for file_name, config in self._configs.items():
uuid = config.get("uuid")
if not uuid:
continue
for prefix in bundle_by_prefix:
if file_name.startswith(prefix):
bundle_by_prefix[prefix].append((file_name, str(uuid)))
break
return bundle_by_prefix
def _prevent_overwrite_existing_assets(
    self, exceptions: list[ValidationError]
) -> None:
    """
    Report bundle assets that would clash with already-stored ones.

    Does nothing when ``overwrite`` is enabled. Otherwise, for each asset
    prefix that has entries in the bundle, the matching model is queried for
    the bundle's UUIDs only (one ``IN`` lookup per prefix), and a
    ``ValidationError`` keyed by file name is appended to ``exceptions`` for
    every asset whose UUID is already present.

    :param exceptions: list that collected validation errors are appended to
    """
    if self.overwrite:
        return

    grouped_entries = self._bundle_entries_by_prefix()
    for prefix in grouped_entries:
        bundle_entries = grouped_entries[prefix]
        if not bundle_entries:
            # nothing of this asset type in the bundle: skip the lookup
            continue

        model_cls = self._MODEL_BY_PREFIX[prefix]
        lookup = db.session.query(model_cls.uuid).filter(
            model_cls.uuid.in_([asset_uuid for _, asset_uuid in bundle_entries])
        )
        already_stored = {str(row[0]) for row in lookup.all()}
        if not already_stored:
            continue

        model_name = model_cls.__name__
        exceptions.extend(
            ValidationError(
                {
                    file_name: (
                        f"{model_name} already exists "
                        "and `overwrite=true` was not passed"
                    ),
                }
            )
            for file_name, asset_uuid in bundle_entries
            if asset_uuid in already_stored
        )
def validate(self) -> None:
exceptions: list[ValidationError] = []
@@ -229,6 +300,7 @@ class ImportAssetsCommand(BaseCommand):
self.ssh_tunnel_priv_key_passwords,
self.encrypted_extra_secrets,
)
self._prevent_overwrite_existing_assets(exceptions)
if exceptions:
raise CommandInvalidError(

View File

@@ -30,6 +30,7 @@ from superset.commands.importers.v1.assets import ImportAssetsCommand
from superset.commands.importers.v1.utils import get_contents_from_bundle
from superset.extensions import event_logger
from superset.utils import json
from superset.utils.core import parse_boolean_string
from superset.views.base_api import BaseSupersetApi, requires_form_data, statsd_metrics
@@ -157,6 +158,12 @@ class ImportExportRestApi(BaseSupersetApi):
sparse:
description: allow sparse update of resources
type: boolean
overwrite:
description: >-
overwrite existing assets? Defaults to ``true`` for
backwards compatibility. When ``false``, the import
fails if any of the assets already exist.
type: boolean
responses:
200:
description: Assets import result
@@ -188,6 +195,9 @@ class ImportExportRestApi(BaseSupersetApi):
if not contents:
raise NoValidFilesFoundError()
sparse = request.form.get("sparse") == "true"
# Defaults to True for backwards compatibility: historically this
# endpoint always overwrote existing assets.
overwrite = parse_boolean_string(request.form.get("overwrite", "true"))
passwords = (
json.loads(request.form["passwords"])
@@ -218,6 +228,7 @@ class ImportExportRestApi(BaseSupersetApi):
command = ImportAssetsCommand(
contents,
sparse=sparse,
overwrite=overwrite,
passwords=passwords,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,

View File

@@ -19,6 +19,7 @@ import copy
from typing import Any, cast
import yaml
from marshmallow.exceptions import ValidationError
from pytest_mock import MockerFixture
from sqlalchemy.orm.session import Session
from sqlalchemy.sql import select
@@ -32,6 +33,18 @@ from tests.unit_tests.fixtures.assets_configs import (
datasets_config,
)
# Bundle fixture holding a single ``queries/`` file. The overwrite-validation
# tests below use its ``uuid`` to seed a ``SavedQuery`` row and then feed the
# same config back in as the incoming bundle.
# NOTE(review): ``database_uuid`` presumably matches the database defined in
# ``databases_config`` — confirm against the fixtures module.
saved_queries_config: dict[str, Any] = {
"queries/examples/my_query.yaml": {
"schema": "main",
"label": "My saved query",
"description": None,
"sql": "SELECT 1",
"uuid": "e3e4f1f0-5c9d-4a4c-a4e4-0000000000aa",
"version": "1.0.0",
"database_uuid": "a2dc77af-e654-49bb-b321-40f6b559a1ee",
},
}
def test_import_new_assets(mocker: MockerFixture, session: Session) -> None:
"""
@@ -227,6 +240,309 @@ def test_import_assets_skips_tags_when_feature_disabled(
assert db.session.query(TaggedObject).count() == 0
def test_import_overwrite_defaults_to_true(session: Session) -> None:
    """
    The command keeps its historical behaviour unless told otherwise:
    ``overwrite`` is ``True`` when omitted and ``False`` only when explicitly
    requested.
    """
    from superset.commands.importers.v1.assets import ImportAssetsCommand

    default_command = ImportAssetsCommand({})
    opt_out_command = ImportAssetsCommand({}, overwrite=False)

    assert default_command.overwrite is True
    assert opt_out_command.overwrite is False
def test_import_threads_overwrite_flag(mocker: MockerFixture, session: Session) -> None:
    """
    ``overwrite`` must be threaded through to ``import_database``,
    ``import_saved_query``, ``import_dataset``, ``import_chart`` and
    ``import_dashboard``. Previously these were hard-coded to ``overwrite=True``
    which caused the API flag to be ignored.

    Every importer is mocked, the bundle contains at least one asset of each
    type, and each importer is required to have been called — otherwise a
    loop over an empty ``call_args_list`` would pass without checking anything.
    """
    from superset import security_manager
    from superset.commands.importers.v1 import assets as assets_module
    from superset.commands.importers.v1.assets import ImportAssetsCommand

    mocker.patch.object(security_manager, "can_access", return_value=True)

    mocked_db = mocker.patch.object(assets_module, "import_database")
    # the saved query and dataset configs look the database up by this UUID
    mocked_db.return_value.uuid = "a2dc77af-e654-49bb-b321-40f6b559a1ee"
    mocked_db.return_value.id = 1
    mocked_query = mocker.patch.object(assets_module, "import_saved_query")
    mocked_ds = mocker.patch.object(assets_module, "import_dataset")
    mocked_ds.return_value.uuid = "53d47c0c-c03d-47f0-b9ac-81225f808283"
    mocked_ds.return_value.id = 1
    mocked_ds.return_value.datasource_type = "table"
    mocked_ds.return_value.table_name = "video_game_sales"
    mocked_chart = mocker.patch.object(assets_module, "import_chart")
    mocked_chart.return_value.viz_type = "table"
    mocked_dash = mocker.patch.object(assets_module, "import_dashboard")

    # neutralise the dashboard post-processing so only the importers matter
    mocker.patch.object(assets_module, "find_chart_uuids", return_value=[])
    mocker.patch.object(assets_module, "update_id_refs", side_effect=lambda c, *_: c)
    mocker.patch.object(assets_module, "migrate_dashboard")
    mocker.patch("superset.db.session.execute")

    configs = {
        **copy.deepcopy(databases_config),
        **copy.deepcopy(saved_queries_config),
        **copy.deepcopy(datasets_config),
        **copy.deepcopy(charts_config_1),
        **copy.deepcopy(dashboards_config_1),
    }

    ImportAssetsCommand._import(configs, overwrite=False)

    importers = {
        "import_database": mocked_db,
        "import_saved_query": mocked_query,
        "import_dataset": mocked_ds,
        "import_chart": mocked_chart,
        "import_dashboard": mocked_dash,
    }
    for name, importer in importers.items():
        assert importer.called, f"{name} was never called"
        for call in importer.call_args_list:
            assert call.kwargs["overwrite"] is False, (
                f"{name} did not receive overwrite=False"
            )
def test_prevent_overwrite_flags_existing_assets(
    mocker: MockerFixture, session: Session
) -> None:
    """
    Re-submitting assets that are already stored, with ``overwrite=False``,
    must produce one ``ValidationError`` per asset, each carrying the
    "already exists" message.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.models.slice import Slice

    mocker.patch.object(security_manager, "can_access", return_value=True)
    Slice.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    fixture_groups = (
        databases_config,
        datasets_config,
        charts_config_1,
        dashboards_config_1,
    )

    def fresh_bundle() -> dict[str, Any]:
        """Merge independent deep copies of all fixtures, in order."""
        bundle: dict[str, Any] = {}
        for group in fixture_groups:
            bundle.update(copy.deepcopy(group))
        return bundle

    # store the fixture assets first so the second bundle collides with them
    ImportAssetsCommand._import(fresh_bundle())

    command = ImportAssetsCommand({}, overwrite=False)
    command._configs = fresh_bundle()

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    # each seeded asset (db + datasets + charts + dashboards) is reported once
    assert len(collected) == sum(len(group) for group in fixture_groups)
    for error in collected:
        assert isinstance(error, ValidationError)
        (message,) = error.messages.values()
        assert "already exists" in message
        assert "`overwrite=true` was not passed" in message
def test_prevent_overwrite_allows_new_assets(
    mocker: MockerFixture, session: Session
) -> None:
    """
    A bundle whose UUIDs are all unknown to the database passes the
    ``overwrite=False`` validation without collecting any error.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.models.slice import Slice

    mocker.patch.object(security_manager, "can_access", return_value=True)
    Slice.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    incoming: dict[str, Any] = {}
    for fixture in (
        databases_config,
        datasets_config,
        charts_config_1,
        dashboards_config_1,
    ):
        incoming.update(copy.deepcopy(fixture))

    command = ImportAssetsCommand({}, overwrite=False)
    command._configs = incoming

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    assert collected == []
def test_prevent_overwrite_noop_when_overwrite_true(
    mocker: MockerFixture, session: Session
) -> None:
    """
    When ``overwrite`` is left at its default (``True``), the "already exists"
    check collects nothing even though every bundle asset is already stored —
    the historical behaviour is preserved.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.models.slice import Slice

    mocker.patch.object(security_manager, "can_access", return_value=True)
    Slice.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    seed_configs: dict[str, Any] = {}
    for fixture in (
        databases_config,
        datasets_config,
        charts_config_1,
        dashboards_config_1,
    ):
        seed_configs.update(copy.deepcopy(fixture))
    ImportAssetsCommand._import(seed_configs)

    # no ``overwrite`` argument: the command falls back to True
    command = ImportAssetsCommand({})
    command._configs = copy.deepcopy(seed_configs)

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    assert collected == []
def test_prevent_overwrite_flags_existing_saved_queries(
    mocker: MockerFixture, session: Session
) -> None:
    """
    The ``queries/`` prefix takes part in the ``overwrite=False`` validation
    too: a saved query whose UUID is already stored must be reported, rather
    than the import appearing to succeed despite the conflict.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.models.slice import Slice
    from superset.models.sql_lab import SavedQuery

    mocker.patch.object(security_manager, "can_access", return_value=True)

    engine = db.session.get_bind()
    for model in (Slice, SavedQuery):
        model.metadata.create_all(engine)  # pylint: disable=no-member

    # store a saved query carrying the same UUID as the incoming fixture
    first_query_config = list(saved_queries_config.values())[0]
    db.session.add(SavedQuery(uuid=first_query_config["uuid"], label="seeded"))
    db.session.flush()

    command = ImportAssetsCommand({}, overwrite=False)
    command._configs = copy.deepcopy(saved_queries_config)

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    assert len(collected) == 1
    ((file_name, message),) = collected[0].messages.items()
    assert file_name.startswith("queries/")
    assert "SavedQuery already exists" in message
def test_prevent_overwrite_partial_conflict(
    mocker: MockerFixture, session: Session
) -> None:
    """
    If only part of the bundle is already stored, the validation reports
    exactly those files and nothing for the assets that are new.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.models.slice import Slice

    mocker.patch.object(security_manager, "can_access", return_value=True)
    Slice.metadata.create_all(db.session.get_bind())  # pylint: disable=no-member

    # only databases and datasets are stored up front; charts and dashboards
    # remain unknown to the database
    already_stored: dict[str, Any] = {}
    for fixture in (databases_config, datasets_config):
        already_stored.update(copy.deepcopy(fixture))
    ImportAssetsCommand._import(already_stored)

    incoming: dict[str, Any] = {}
    for fixture in (
        databases_config,
        datasets_config,
        charts_config_1,
        dashboards_config_1,
    ):
        incoming.update(copy.deepcopy(fixture))

    command = ImportAssetsCommand({}, overwrite=False)
    command._configs = incoming

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    flagged_files = set()
    for error in collected:
        # each error is keyed by the offending file name
        flagged_files.add(next(iter(error.messages)))

    assert flagged_files == {*databases_config, *datasets_config}
def test_prevent_overwrite_queries_only_bundle_uuids(
    mocker: MockerFixture, session: Session
) -> None:
    """
    The UUID lookup is limited to asset types that actually occur in the
    bundle: one ``uuid`` query per prefix with entries, none for prefixes
    without. Querying every asset table on each ``overwrite=false`` import
    would make the cost independent of the bundle size.
    """
    from superset import db, security_manager
    from superset.commands.importers.v1.assets import ImportAssetsCommand
    from superset.connectors.sqla.models import SqlaTable
    from superset.models.core import Database
    from superset.models.dashboard import Dashboard
    from superset.models.slice import Slice
    from superset.models.sql_lab import SavedQuery

    mocker.patch.object(security_manager, "can_access", return_value=True)

    engine = db.session.get_bind()
    for model in (Slice, SavedQuery):
        model.metadata.create_all(engine)  # pylint: disable=no-member

    # the bundle holds a database only: no datasets, charts, dashboards or
    # saved queries
    query_spy = mocker.spy(db.session, "query")

    command = ImportAssetsCommand({}, overwrite=False)
    command._configs = copy.deepcopy(databases_config)

    collected: list[ValidationError] = []
    command._prevent_overwrite_existing_assets(collected)

    # keep only the session queries whose first argument is a ``uuid`` column
    uuid_lookups = []
    for call in query_spy.call_args_list:
        if not call.args:
            continue
        target = call.args[0]
        if getattr(target, "key", None) == "uuid":
            uuid_lookups.append(target)

    # a single lookup, against the Database model, and nothing for the asset
    # types that are absent from the bundle
    assert len(uuid_lookups) == 1
    assert uuid_lookups[0].class_ is Database
    looked_up_models = {column.class_ for column in uuid_lookups}
    assert looked_up_models.isdisjoint({SqlaTable, Slice, Dashboard, SavedQuery})

    # the table is empty, hence no conflict can be reported
    assert collected == []
def test_import_removes_dashboard_charts(
mocker: MockerFixture, session: Session
) -> None:

View File

@@ -48,7 +48,9 @@ def test_export_assets(
mocked_export_result = [
(
"metadata.yaml",
lambda: "version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n", # noqa: E501
lambda: (
"version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n"
), # noqa: E501
),
("databases/example.yaml", lambda: "<DATABASE CONTENTS>"),
]
@@ -109,6 +111,7 @@ def test_import_assets(
ImportAssetsCommand.assert_called_with(
mocked_contents,
sparse=False,
overwrite=True,
passwords=passwords,
ssh_tunnel_passwords=None,
ssh_tunnel_private_keys=None,
@@ -160,6 +163,7 @@ def test_import_assets_with_encrypted_extra_secrets(
ImportAssetsCommand.assert_called_with(
mocked_contents,
sparse=False,
overwrite=True,
passwords=None,
ssh_tunnel_passwords=None,
ssh_tunnel_private_keys=None,
@@ -168,6 +172,54 @@ def test_import_assets_with_encrypted_extra_secrets(
)
def test_import_assets_overwrite_false(
    mocker: MockerFixture,
    client: Any,
    full_api_access: None,
) -> None:
    """
    An ``overwrite=false`` form field on the import endpoint must reach
    ``ImportAssetsCommand`` as ``overwrite=False`` instead of being dropped.
    """
    mocked_contents = {
        "metadata.yaml": (
            "version: 1.0.0\ntype: assets\ntimestamp: '2022-01-01T00:00:00+00:00'\n"
        ),
        "databases/example.yaml": "<DATABASE CONTENTS>",
    }

    mocked_command = mocker.patch("superset.importexport.api.ImportAssetsCommand")

    # build an in-memory zip with every file placed under ``assets_export/``
    root = Path("assets_export")
    buf = BytesIO()
    with ZipFile(buf, "w") as bundle:
        for path, contents in mocked_contents.items():
            bundle.writestr(str(root / path), contents.encode())
    buf.seek(0)

    response = client.post(
        "/api/v1/assets/import/",
        data={
            "bundle": (buf, "assets_export.zip"),
            "overwrite": "false",
        },
        content_type="multipart/form-data",
    )

    assert response.status_code == 200
    mocked_command.assert_called_with(
        mocked_contents,
        sparse=False,
        overwrite=False,
        passwords=None,
        ssh_tunnel_passwords=None,
        ssh_tunnel_private_keys=None,
        ssh_tunnel_priv_key_passwords=None,
        encrypted_extra_secrets=None,
    )
def test_import_assets_not_zip(
mocker: MockerFixture,
client: Any,