diff --git a/superset/charts/api.py b/superset/charts/api.py index 2327ed8aae2..8e8789da6a2 100644 --- a/superset/charts/api.py +++ b/superset/charts/api.py @@ -44,6 +44,7 @@ from superset.charts.commands.exceptions import ( ChartUpdateFailedError, ) from superset.charts.commands.export import ExportChartsCommand +from superset.charts.commands.importers.dispatcher import ImportChartsCommand from superset.charts.commands.update import UpdateChartCommand from superset.charts.dao import ChartDAO from superset.charts.filters import ChartAllTextFilter, ChartFavoriteFilter, ChartFilter @@ -59,6 +60,7 @@ from superset.charts.schemas import ( screenshot_query_schema, thumbnail_query_schema, ) +from superset.commands.exceptions import CommandInvalidError from superset.constants import RouteMethod from superset.exceptions import SupersetSecurityException from superset.extensions import event_logger @@ -86,6 +88,7 @@ class ChartRestApi(BaseSupersetModelRestApi): include_route_methods = RouteMethod.REST_MODEL_VIEW_CRUD_SET | { RouteMethod.EXPORT, + RouteMethod.IMPORT, RouteMethod.RELATED, "bulk_delete", # not using RouteMethod since locally defined "data", @@ -823,3 +826,56 @@ class ChartRestApi(BaseSupersetModelRestApi): for request_id in requested_ids ] return self.response(200, result=res) + + @expose("/import/", methods=["POST"]) + @protect() + @safe + @statsd_metrics + def import_(self) -> Response: + """Import chart(s) with associated datasets and databases + --- + post: + requestBody: + content: + application/zip: + schema: + type: string + format: binary + responses: + 200: + description: Chart import result + content: + application/json: + schema: + type: object + properties: + message: + type: string + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + upload = request.files.get("file") + if not upload: + return self.response_400() + with ZipFile(upload) as bundle: + contents = { + file_name: bundle.read(file_name).decode() + for file_name in bundle.namelist() + } + + command = ImportChartsCommand(contents) + try: + command.run() + return self.response(200, message="OK") + except CommandInvalidError as exc: + logger.warning("Import chart failed") + return self.response_422(message=exc.normalized_messages()) + except Exception as exc: # pylint: disable=broad-except + logger.exception("Import chart failed") + return self.response_500(message=str(exc)) diff --git a/superset/charts/commands/importers/dispatcher.py b/superset/charts/commands/importers/dispatcher.py new file mode 100644 index 00000000000..e7f01491d76 --- /dev/null +++ b/superset/charts/commands/importers/dispatcher.py @@ -0,0 +1,70 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import logging +from typing import Any, Dict + +from marshmallow.exceptions import ValidationError + +from superset.charts.commands.importers import v1 +from superset.commands.base import BaseCommand +from superset.commands.exceptions import CommandInvalidError +from superset.commands.importers.exceptions import IncorrectVersionError + +logger = logging.getLogger(__name__) + +command_versions = [ + v1.ImportChartsCommand, +] + + +class ImportChartsCommand(BaseCommand): + """ + Import charts. + + This command dispatches the import to different versions of the command + until it finds one that matches. + """ + + # pylint: disable=unused-argument + def __init__(self, contents: Dict[str, str], *args: Any, **kwargs: Any): + self.contents = contents + + def run(self) -> None: + # iterate over all commands until we find a version that can + # handle the contents + for version in command_versions: + command = version(self.contents) + try: + command.run() + return + except IncorrectVersionError: + # file is not handled by command, skip + pass + except (CommandInvalidError, ValidationError) as exc: + # found right version, but file is invalid + logger.info("Command failed validation") + raise exc + except Exception as exc: + # validation succeeded but something went wrong + logger.exception("Error running import command") + raise exc + + raise CommandInvalidError("Could not find a valid command to import file") + + def validate(self) -> None: + pass diff --git a/tests/charts/api_tests.py b/tests/charts/api_tests.py index 00decdc23d3..3acfdb9ca0f 100644 --- a/tests/charts/api_tests.py +++ b/tests/charts/api_tests.py @@ -21,11 +21,12 @@ from typing import List, Optional from datetime import datetime from io import BytesIO from unittest import mock -from zipfile import is_zipfile +from zipfile import is_zipfile, ZipFile import humanize import prison import pytest +import yaml from sqlalchemy import and_ from sqlalchemy.sql import func @@ -35,12 +36,19 @@ from tests.fixtures.unicode_dashboard import load_unicode_dashboard_with_slice from tests.test_app import app from superset.connectors.connector_registry import ConnectorRegistry from superset.extensions import db, security_manager -from superset.models.core import FavStar, FavStarClassName +from superset.models.core import Database, FavStar, FavStarClassName from superset.models.dashboard import Dashboard from superset.models.slice import Slice from superset.utils import core as utils from tests.base_api_tests import ApiOwnersTestCaseMixin from tests.base_tests import SupersetTestCase +from tests.fixtures.importexport import ( + chart_config, + chart_metadata_config, + database_config, + dataset_config, + dataset_metadata_config, +) from tests.fixtures.query_context import get_query_context CHART_DATA_URI = "api/v1/chart/data" @@ -1131,7 +1139,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): def test_export_chart(self): """ - Chart API: Test export dataset + Chart API: Test export chart """ example_chart = db.session.query(Slice).all()[0] argument = [example_chart.id] @@ -1147,7 +1155,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): def test_export_chart_not_found(self): """ - Dataset API: Test export dataset not found + Chart API: Test export chart not found """ # Just one does not exist and we get 404 argument = [-1, 1] @@ -1159,7 +1167,7 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): def test_export_chart_gamma(self): """ - Dataset API: Test export dataset has gamma + Chart API: Test export chart has gamma """ example_chart = db.session.query(Slice).all()[0] argument = [example_chart.id] @@ -1169,3 +1177,79 @@ class TestChartApi(SupersetTestCase, ApiOwnersTestCaseMixin): rv = self.client.get(uri) assert rv.status_code == 404 + + def test_import_chart(self): + """ + Chart API: Test import chart + """ + self.login(username="admin") + uri = "api/v1/chart/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + with bundle.open("charts/imported_chart.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "chart_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 200 + assert response == {"message": "OK"} + + database = ( + db.session.query(Database).filter_by(uuid=database_config["uuid"]).one() + ) + assert database.database_name == "imported_database" + + assert len(database.tables) == 1 + dataset = database.tables[0] + assert dataset.table_name == "imported_dataset" + assert str(dataset.uuid) == dataset_config["uuid"] + + chart = db.session.query(Slice).filter_by(uuid=chart_config["uuid"]).one() + assert chart.table == dataset + + db.session.delete(chart) + db.session.delete(dataset) + db.session.delete(database) + db.session.commit() + + def test_import_chart_invalid(self): + """ + Chart API: Test import invalid chart + """ + self.login(username="admin") + uri = "api/v1/chart/import/" + + buf = BytesIO() + with ZipFile(buf, "w") as bundle: + with bundle.open("metadata.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_metadata_config).encode()) + with bundle.open("databases/imported_database.yaml", "w") as fp: + fp.write(yaml.safe_dump(database_config).encode()) + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: + fp.write(yaml.safe_dump(dataset_config).encode()) + with bundle.open("charts/imported_chart.yaml", "w") as fp: + fp.write(yaml.safe_dump(chart_config).encode()) + buf.seek(0) + + form_data = { + "file": (buf, "chart_export.zip"), + } + rv = self.client.post(uri, data=form_data, content_type="multipart/form-data") + response = json.loads(rv.data.decode("utf-8")) + + assert rv.status_code == 422 + assert response == { + "message": {"metadata.yaml": {"type": ["Must be equal to Slice."]}} + } diff --git a/tests/databases/api_tests.py b/tests/databases/api_tests.py index e38b64a7bb6..88213b6ccdb 100644 --- a/tests/databases/api_tests.py +++ b/tests/databases/api_tests.py @@ -840,7 +840,7 @@ class TestDatabaseApi(SupersetTestCase): fp.write(yaml.safe_dump(database_metadata_config).encode()) with bundle.open("databases/imported_database.yaml", "w") as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("datasets/import_dataset.yaml", "w") as fp: + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: fp.write(yaml.safe_dump(dataset_config).encode()) buf.seek(0) @@ -880,7 +880,7 @@ class TestDatabaseApi(SupersetTestCase): fp.write(yaml.safe_dump(dataset_metadata_config).encode()) with bundle.open("databases/imported_database.yaml", "w") as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("datasets/import_dataset.yaml", "w") as fp: + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: fp.write(yaml.safe_dump(dataset_config).encode()) buf.seek(0) diff --git a/tests/datasets/api_tests.py b/tests/datasets/api_tests.py index 59854c496eb..ea16c61cfeb 100644 --- a/tests/datasets/api_tests.py +++ b/tests/datasets/api_tests.py @@ -1176,7 +1176,7 @@ class TestDatasetApi(SupersetTestCase): for table_name in self.fixture_tables_names: assert table_name in [ds["table_name"] for ds in data["result"]] - def test_import_dataset(self): + def test_imported_dataset(self): """ Dataset API: Test import dataset """ @@ -1189,7 +1189,7 @@ class TestDatasetApi(SupersetTestCase): fp.write(yaml.safe_dump(dataset_metadata_config).encode()) with bundle.open("databases/imported_database.yaml", "w") as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("datasets/import_dataset.yaml", "w") as fp: + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: fp.write(yaml.safe_dump(dataset_config).encode()) buf.seek(0) @@ -1216,7 +1216,7 @@ class TestDatasetApi(SupersetTestCase): db.session.delete(database) db.session.commit() - def test_import_dataset_invalid(self): + def test_imported_dataset_invalid(self): """ Dataset API: Test import invalid dataset """ @@ -1229,7 +1229,7 @@ class TestDatasetApi(SupersetTestCase): fp.write(yaml.safe_dump(database_metadata_config).encode()) with bundle.open("databases/imported_database.yaml", "w") as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("datasets/import_dataset.yaml", "w") as fp: + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: fp.write(yaml.safe_dump(dataset_config).encode()) buf.seek(0) @@ -1244,7 +1244,7 @@ class TestDatasetApi(SupersetTestCase): "message": {"metadata.yaml": {"type": ["Must be equal to SqlaTable."]}} } - def test_import_dataset_invalid_v0_validation(self): + def test_imported_dataset_invalid_v0_validation(self): """ Dataset API: Test import invalid dataset """ @@ -1255,7 +1255,7 @@ class TestDatasetApi(SupersetTestCase): with ZipFile(buf, "w") as bundle: with bundle.open("databases/imported_database.yaml", "w") as fp: fp.write(yaml.safe_dump(database_config).encode()) - with bundle.open("datasets/import_dataset.yaml", "w") as fp: + with bundle.open("datasets/imported_dataset.yaml", "w") as fp: fp.write(yaml.safe_dump(dataset_config).encode()) buf.seek(0)