feat: Implement sparse import for ImportAssetsCommand (#32670)

This commit is contained in:
Paul Rhodes
2025-03-17 14:44:15 +00:00
committed by GitHub
parent a49a15f990
commit c9e2c7037e
4 changed files with 44 additions and 5 deletions

View File

@@ -34,17 +34,21 @@ from superset.commands.database.importers.v1.utils import import_database
from superset.commands.dataset.importers.v1.utils import import_dataset
from superset.commands.exceptions import CommandInvalidError, ImportFailedError
from superset.commands.importers.v1.utils import (
get_resource_mappings_batched,
load_configs,
load_metadata,
validate_metadata_type,
)
from superset.commands.query.importers.v1.utils import import_saved_query
from superset.commands.utils import update_chart_config_dataset
from superset.connectors.sqla.models import SqlaTable
from superset.dashboards.schemas import ImportV1DashboardSchema
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.schemas import ImportV1DatasetSchema
from superset.migrations.shared.native_filters import migrate_dashboard
from superset.models.core import Database
from superset.models.dashboard import dashboard_slices
from superset.models.slice import Slice
from superset.queries.saved_queries.schemas import ImportV1SavedQuerySchema
from superset.utils.decorators import on_error, transaction
@@ -79,12 +83,27 @@ class ImportAssetsCommand(BaseCommand):
kwargs.get("ssh_tunnel_priv_key_passwords") or {}
)
self._configs: dict[str, Any] = {}
self.sparse = kwargs.get("sparse", False)
# pylint: disable=too-many-locals
@staticmethod
def _import(configs: dict[str, Any]) -> None: # noqa: C901
def _import(configs: dict[str, Any], sparse: bool = False) -> None: # noqa: C901
# import databases first
database_ids: dict[str, int] = {}
dataset_info: dict[str, dict[str, Any]] = {}
chart_ids: dict[str, int] = {}
if sparse:
chart_ids = get_resource_mappings_batched(Slice)
database_ids = get_resource_mappings_batched(Database)
dataset_info = get_resource_mappings_batched(
SqlaTable,
value_func=lambda x: {
"datasource_id": x.id,
"datasource_type": x.datasource_type,
"datasource_name": x.datasource_name,
},
)
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(config, overwrite=True)
@@ -97,7 +116,6 @@ class ImportAssetsCommand(BaseCommand):
import_saved_query(config, overwrite=True)
# import datasets
dataset_info: dict[str, dict[str, Any]] = {}
for file_name, config in configs.items():
if file_name.startswith("datasets/"):
config["database_id"] = database_ids[config["database_uuid"]]
@@ -110,7 +128,6 @@ class ImportAssetsCommand(BaseCommand):
# import charts
charts = []
chart_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("charts/"):
dataset_dict = dataset_info[config["dataset_uuid"]]
@@ -161,7 +178,7 @@ class ImportAssetsCommand(BaseCommand):
)
def run(self) -> None:
self.validate()
self._import(self._configs)
self._import(self._configs, self.sparse)
def validate(self) -> None:
exceptions: list[ValidationError] = []

View File

@@ -15,7 +15,7 @@
import logging
from pathlib import Path, PurePosixPath
from typing import Any, Optional
from typing import Any, Callable, Dict, Optional, Type
from zipfile import ZipFile
import yaml
@@ -214,3 +214,19 @@ def get_contents_from_bundle(bundle: ZipFile) -> dict[str, str]:
for file_name in bundle.namelist()
if is_valid_config(file_name)
}
def get_resource_mappings_batched(
    model_class: Type[Any],
    batch_size: int = 1000,
    value_func: Callable[[Any], Any] = lambda x: x.id,
) -> Dict[str, Any]:
    """
    Map ``str(uuid)`` to a derived value for every row of ``model_class``.

    Rows are read through ``db.session`` in pages of ``batch_size`` so the
    whole table is never fetched by a single query.

    :param model_class: mapped class to query; each row must expose ``uuid``
    :param batch_size: number of rows requested per page
    :param value_func: called with each row, its result becomes the mapping
        value (defaults to the row's ``id``)
    :returns: dict keyed by the stringified ``uuid`` of each row
    """
    query = db.session.query(model_class)

    # LIMIT/OFFSET paging is only reliable over a deterministic ordering:
    # without ORDER BY the database may return rows in a different order for
    # each page, so rows could be skipped or seen twice.
    # NOTE(review): assumes ``id`` is a unique column on the model (true for
    # the default ``value_func``) -- models without it keep the old behavior.
    order_column = getattr(model_class, "id", None)
    if order_column is not None:
        query = query.order_by(order_column)

    mapping: Dict[str, Any] = {}
    offset = 0
    while True:
        batch = query.limit(batch_size).offset(offset).all()
        if not batch:
            break
        mapping.update({str(x.uuid): value_func(x) for x in batch})
        if len(batch) < batch_size:
            # A short page is the last page; skip the extra empty query.
            break
        offset += batch_size
    return mapping

View File

@@ -147,6 +147,9 @@ class ImportExportRestApi(BaseSupersetApi):
the private_key should be provided in the following format:
`{"databases/MyDatabase.yaml": "my_private_key_password"}`.
type: string
sparse:
description: allow sparse update of resources
type: boolean
responses:
200:
description: Assets import result
@@ -177,6 +180,7 @@ class ImportExportRestApi(BaseSupersetApi):
if not contents:
raise NoValidFilesFoundError()
sparse = request.form.get("sparse") == "true"
passwords = (
json.loads(request.form["passwords"])
@@ -201,6 +205,7 @@ class ImportExportRestApi(BaseSupersetApi):
command = ImportAssetsCommand(
contents,
sparse=sparse,
passwords=passwords,
ssh_tunnel_passwords=ssh_tunnel_passwords,
ssh_tunnel_private_keys=ssh_tunnel_private_keys,

View File

@@ -108,6 +108,7 @@ def test_import_assets(
passwords = {"assets_export/databases/imported_database.yaml": "SECRET"}
ImportAssetsCommand.assert_called_with(
mocked_contents,
sparse=False,
passwords=passwords,
ssh_tunnel_passwords=None,
ssh_tunnel_private_keys=None,