chore(command): Organize Commands according to SIP-92 (#25850)

This commit is contained in:
John Bodley
2023-11-22 11:55:54 -08:00
committed by GitHub
parent 984c278c4c
commit 07bcfa9b5f
265 changed files with 786 additions and 808 deletions

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import security_manager
from superset.commands.base import BaseCommand, CreateMixin
from superset.commands.chart.exceptions import (
ChartCreateFailedError,
ChartInvalidError,
DashboardsForbiddenError,
DashboardsNotFoundValidationError,
)
from superset.commands.utils import get_datasource_by_id
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOCreateFailedError
logger = logging.getLogger(__name__)
class CreateChartCommand(CreateMixin, BaseCommand):
def __init__(self, data: dict[str, Any]):
self._properties = data.copy()
def run(self) -> Model:
self.validate()
try:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
def validate(self) -> None:
exceptions = []
datasource_type = self._properties["datasource_type"]
datasource_id = self._properties["datasource_id"]
dashboard_ids = self._properties.get("dashboards", [])
owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate/Populate datasource
try:
datasource = get_datasource_by_id(datasource_id, datasource_type)
self._properties["datasource_name"] = datasource.name
except ValidationError as ex:
exceptions.append(ex)
# Validate/Populate dashboards
dashboards = DashboardDAO.find_by_ids(dashboard_ids)
if len(dashboards) != len(dashboard_ids):
exceptions.append(DashboardsNotFoundValidationError())
for dash in dashboards:
if not security_manager.is_owner(dash):
raise DashboardsForbiddenError()
self._properties["dashboards"] = dashboards
try:
owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners
except ValidationError as ex:
exceptions.append(ex)
if exceptions:
raise ChartInvalidError(exceptions=exceptions)

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,38 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any, Optional
from flask import Request
from superset.extensions import async_query_manager
logger = logging.getLogger(__name__)
class CreateAsyncChartDataJobCommand:
_async_channel_id: str
def validate(self, request: Request) -> None:
self._async_channel_id = async_query_manager.parse_channel_id_from_request(
request
)
def run(self, form_data: dict[str, Any], user_id: Optional[int]) -> dict[str, Any]:
return async_query_manager.submit_chart_data_job(
self._async_channel_id, form_data, user_id
)

View File

@@ -0,0 +1,68 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from flask_babel import gettext as _
from superset.commands.base import BaseCommand
from superset.commands.chart.exceptions import (
ChartDataCacheLoadError,
ChartDataQueryFailedError,
)
from superset.common.query_context import QueryContext
from superset.exceptions import CacheLoadError
logger = logging.getLogger(__name__)
class ChartDataCommand(BaseCommand):
_query_context: QueryContext
def __init__(self, query_context: QueryContext):
self._query_context = query_context
def run(self, **kwargs: Any) -> dict[str, Any]:
# caching is handled in query_context.get_df_payload
# (also evals `force` property)
cache_query_context = kwargs.get("cache", False)
force_cached = kwargs.get("force_cached", False)
try:
payload = self._query_context.get_payload(
cache_query_context=cache_query_context, force_cached=force_cached
)
except CacheLoadError as ex:
raise ChartDataCacheLoadError(ex.message) from ex
# TODO: QueryContext should support SIP-40 style errors
for query in payload["queries"]:
if query.get("error"):
raise ChartDataQueryFailedError(
_("Error: %(error)s", error=query["error"])
)
return_value = {
"query_context": self._query_context,
"queries": payload["queries"],
}
if cache_query_context:
return_value.update(cache_key=payload["cache_key"])
return return_value
def validate(self) -> None:
self._query_context.raise_for_access()

View File

@@ -0,0 +1,74 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Optional
from flask_babel import lazy_gettext as _
from superset import security_manager
from superset.commands.base import BaseCommand
from superset.commands.chart.exceptions import (
ChartDeleteFailedError,
ChartDeleteFailedReportsExistError,
ChartForbiddenError,
ChartNotFoundError,
)
from superset.daos.chart import ChartDAO
from superset.daos.exceptions import DAODeleteFailedError
from superset.daos.report import ReportScheduleDAO
from superset.exceptions import SupersetSecurityException
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
class DeleteChartCommand(BaseCommand):
def __init__(self, model_ids: list[int]):
self._model_ids = model_ids
self._models: Optional[list[Slice]] = None
def run(self) -> None:
self.validate()
assert self._models
for model_id in self._model_ids:
Dashboard.clear_cache_for_slice(slice_id=model_id)
try:
ChartDAO.delete(self._models)
except DAODeleteFailedError as ex:
logger.exception(ex.exception)
raise ChartDeleteFailedError() from ex
def validate(self) -> None:
# Validate/populate model exists
self._models = ChartDAO.find_by_ids(self._model_ids)
if not self._models or len(self._models) != len(self._model_ids):
raise ChartNotFoundError()
# Check there are no associated ReportSchedules
if reports := ReportScheduleDAO.find_by_chart_ids(self._model_ids):
report_names = [report.name for report in reports]
raise ChartDeleteFailedReportsExistError(
_(f"There are associated alerts or reports: {','.join(report_names)}")
)
# Check ownership
for model in self._models:
try:
security_manager.raise_for_ownership(model)
except SupersetSecurityException as ex:
raise ChartForbiddenError() from ex

View File

@@ -0,0 +1,156 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from flask_babel import _
from marshmallow.validate import ValidationError
from superset.commands.exceptions import (
CommandException,
CommandInvalidError,
CreateFailedError,
DeleteFailedError,
ForbiddenError,
ImportFailedError,
UpdateFailedError,
)
class TimeRangeAmbiguousError(ValidationError):
"""
Time range is ambiguous error.
"""
def __init__(self, human_readable: str) -> None:
super().__init__(
_(
"Time string is ambiguous."
" Please specify [%(human_readable)s ago]"
" or [%(human_readable)s later].",
human_readable=human_readable,
),
field_name="time_range",
)
class TimeRangeParseFailError(ValidationError):
def __init__(self, human_readable: str) -> None:
super().__init__(
_(
"Cannot parse time string [%(human_readable)s]",
human_readable=human_readable,
),
field_name="time_range",
)
class TimeDeltaAmbiguousError(ValidationError):
"""
Time delta is ambiguous error.
"""
def __init__(self, human_readable: str) -> None:
super().__init__(
_(
"Time delta is ambiguous."
" Please specify [%(human_readable)s ago]"
" or [%(human_readable)s later].",
human_readable=human_readable,
),
field_name="time_range",
)
class DatabaseNotFoundValidationError(ValidationError):
"""
Marshmallow validation error for database does not exist
"""
def __init__(self) -> None:
super().__init__(_("Database does not exist"), field_name="database")
class DashboardsNotFoundValidationError(ValidationError):
"""
Marshmallow validation error for dashboards don't exist
"""
def __init__(self) -> None:
super().__init__(_("Dashboards do not exist"), field_name="dashboards")
class DatasourceTypeUpdateRequiredValidationError(ValidationError):
"""
Marshmallow validation error for dashboards don't exist
"""
def __init__(self) -> None:
super().__init__(
_("Datasource type is required when datasource_id is given"),
field_names=["datasource_type"],
)
class ChartNotFoundError(CommandException):
message = "Chart not found."
class ChartInvalidError(CommandInvalidError):
message = _("Chart parameters are invalid.")
class ChartCreateFailedError(CreateFailedError):
message = _("Chart could not be created.")
class ChartUpdateFailedError(UpdateFailedError):
message = _("Chart could not be updated.")
class ChartDeleteFailedError(DeleteFailedError):
message = _("Charts could not be deleted.")
class ChartDeleteFailedReportsExistError(ChartDeleteFailedError):
message = _("There are associated alerts or reports")
class ChartAccessDeniedError(ForbiddenError):
message = _("You don't have access to this chart.")
class ChartForbiddenError(ForbiddenError):
message = _("Changing this chart is forbidden")
class ChartDataQueryFailedError(CommandException):
pass
class ChartDataCacheLoadError(CommandException):
pass
class ChartImportError(ImportFailedError):
message = _("Import chart failed for an unknown reason")
class DashboardsForbiddenError(ForbiddenError):
message = _("Changing one or more of these dashboards is forbidden")
class WarmUpCacheChartNotFoundError(CommandException):
status = 404
message = _("Chart not found")

View File

@@ -0,0 +1,75 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# isort:skip_file
import json
import logging
from collections.abc import Iterator
import yaml
from superset.commands.chart.exceptions import ChartNotFoundError
from superset.daos.chart import ChartDAO
from superset.commands.dataset.export import ExportDatasetsCommand
from superset.commands.export.models import ExportModelsCommand
from superset.models.slice import Slice
from superset.utils.dict_import_export import EXPORT_VERSION
from superset.utils.file import get_filename
logger = logging.getLogger(__name__)
# keys present in the standard export that are not needed
REMOVE_KEYS = ["datasource_type", "datasource_name", "url_params"]
class ExportChartsCommand(ExportModelsCommand):
dao = ChartDAO
not_found = ChartNotFoundError
@staticmethod
def _export(model: Slice, export_related: bool = True) -> Iterator[tuple[str, str]]:
file_name = get_filename(model.slice_name, model.id)
file_path = f"charts/{file_name}.yaml"
payload = model.export_to_dict(
recursive=False,
include_parent_ref=False,
include_defaults=True,
export_uuids=True,
)
# TODO (betodealmeida): move this logic to export_to_dict once this
# becomes the default export endpoint
payload = {
key: value for key, value in payload.items() if key not in REMOVE_KEYS
}
if payload.get("params"):
try:
payload["params"] = json.loads(payload["params"])
except json.decoder.JSONDecodeError:
logger.info("Unable to decode `params` field: %s", payload["params"])
payload["version"] = EXPORT_VERSION
if model.table:
payload["dataset_uuid"] = str(model.table.uuid)
file_content = yaml.safe_dump(payload, sort_keys=False)
yield file_path, file_content
if model.table and export_related:
yield from ExportDatasetsCommand([model.table.id]).run()

View File

@@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@@ -0,0 +1,70 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from typing import Any
from marshmallow.exceptions import ValidationError
from superset.commands.base import BaseCommand
from superset.commands.chart.importers import v1
from superset.commands.exceptions import CommandInvalidError
from superset.commands.importers.exceptions import IncorrectVersionError
logger = logging.getLogger(__name__)
command_versions = [
v1.ImportChartsCommand,
]
class ImportChartsCommand(BaseCommand):
"""
Import charts.
This command dispatches the import to different versions of the command
until it finds one that matches.
"""
def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any):
self.contents = contents
self.args = args
self.kwargs = kwargs
def run(self) -> None:
# iterate over all commands until we find a version that can
# handle the contents
for version in command_versions:
command = version(self.contents, *self.args, **self.kwargs)
try:
command.run()
return
except IncorrectVersionError:
logger.debug("File not handled by command, skipping")
except (CommandInvalidError, ValidationError) as exc:
# found right version, but file is invalid
logger.info("Command failed validation")
raise exc
except Exception as exc:
# validation succeeded but something went wrong
logger.exception("Error running import command")
raise exc
raise CommandInvalidError("Could not find a valid command to import file")
def validate(self) -> None:
pass

View File

@@ -0,0 +1,100 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
from marshmallow import Schema
from sqlalchemy.orm import Session
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.chart.exceptions import ChartImportError
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.database.importers.v1.utils import import_database
from superset.commands.dataset.importers.v1.utils import import_dataset
from superset.commands.importers.v1 import ImportModelsCommand
from superset.connectors.sqla.models import SqlaTable
from superset.daos.chart import ChartDAO
from superset.databases.schemas import ImportV1DatabaseSchema
from superset.datasets.schemas import ImportV1DatasetSchema
class ImportChartsCommand(ImportModelsCommand):
"""Import charts"""
dao = ChartDAO
model_name = "chart"
prefix = "charts/"
schemas: dict[str, Schema] = {
"charts/": ImportV1ChartSchema(),
"datasets/": ImportV1DatasetSchema(),
"databases/": ImportV1DatabaseSchema(),
}
import_error = ChartImportError
@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
# discover datasets associated with charts
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("charts/"):
dataset_uuids.add(config["dataset_uuid"])
# discover databases associated with datasets
database_uuids: set[str] = set()
for file_name, config in configs.items():
if file_name.startswith("datasets/") and config["uuid"] in dataset_uuids:
database_uuids.add(config["database_uuid"])
# import related databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database_ids[str(database.uuid)] = database.id
# import datasets with the correct parent ref
datasets: dict[str, SqlaTable] = {}
for file_name, config in configs.items():
if (
file_name.startswith("datasets/")
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
datasets[str(dataset.uuid)] = dataset
# import charts with the correct parent ref
for file_name, config in configs.items():
if file_name.startswith("charts/") and config["dataset_uuid"] in datasets:
# update datasource id, type, and name
dataset = datasets[config["dataset_uuid"]]
config.update(
{
"datasource_id": dataset.id,
"datasource_type": "table",
"datasource_name": dataset.table_name,
}
)
config["params"].update({"datasource": dataset.uid})
if "query_context" in config:
config["query_context"] = None
import_chart(session, config, overwrite=overwrite)

View File

@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import copy
import json
from inspect import isclass
from typing import Any
from flask import g
from sqlalchemy.orm import Session
from superset import security_manager
from superset.commands.exceptions import ImportFailedError
from superset.migrations.shared.migrate_viz import processors
from superset.migrations.shared.migrate_viz.base import MigrateViz
from superset.models.slice import Slice
def import_chart(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
can_write = ignore_permissions or security_manager.can_access("can_write", "Chart")
existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
config["id"] = existing.id
elif not can_write:
raise ImportFailedError(
"Chart doesn't exist and user doesn't have permission to create charts"
)
# TODO (betodealmeida): move this logic to import_from_dict
config["params"] = json.dumps(config["params"])
# migrate old viz types to new ones
config = migrate_chart(config)
chart = Slice.import_from_dict(
session, config, recursive=False, allow_reparenting=True
)
if chart.id is None:
session.flush()
if hasattr(g, "user") and g.user:
chart.owners.append(g.user)
return chart
def migrate_chart(config: dict[str, Any]) -> dict[str, Any]:
"""
Used to migrate old viz types to new ones.
"""
migrators = {
class_.source_viz_type: class_
for class_ in processors.__dict__.values()
if isclass(class_)
and issubclass(class_, MigrateViz)
and hasattr(class_, "source_viz_type")
}
output = copy.deepcopy(config)
if config["viz_type"] not in migrators:
return output
migrator = migrators[config["viz_type"]](output["params"])
# pylint: disable=protected-access
migrator._pre_action()
migrator._migrate()
migrator._post_action()
params = migrator.data
params["viz_type"] = migrator.target_viz_type
output.update(
{
"params": json.dumps(params),
"viz_type": migrator.target_viz_type,
}
)
# also update `query_context`
try:
query_context = json.loads(output.get("query_context") or "{}")
except (json.decoder.JSONDecodeError, TypeError):
query_context = {}
if "form_data" in query_context:
query_context["form_data"] = output["params"]
output["query_context"] = json.dumps(query_context)
return output

View File

@@ -0,0 +1,119 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
from datetime import datetime
from typing import Any, Optional
from flask import g
from flask_appbuilder.models.sqla import Model
from marshmallow import ValidationError
from superset import security_manager
from superset.commands.base import BaseCommand, UpdateMixin
from superset.commands.chart.exceptions import (
ChartForbiddenError,
ChartInvalidError,
ChartNotFoundError,
ChartUpdateFailedError,
DashboardsNotFoundValidationError,
DatasourceTypeUpdateRequiredValidationError,
)
from superset.commands.utils import get_datasource_by_id
from superset.daos.chart import ChartDAO
from superset.daos.dashboard import DashboardDAO
from superset.daos.exceptions import DAOUpdateFailedError
from superset.exceptions import SupersetSecurityException
from superset.models.slice import Slice
logger = logging.getLogger(__name__)
def is_query_context_update(properties: dict[str, Any]) -> bool:
return set(properties) == {"query_context", "query_context_generation"} and bool(
properties.get("query_context_generation")
)
class UpdateChartCommand(UpdateMixin, BaseCommand):
def __init__(self, model_id: int, data: dict[str, Any]):
self._model_id = model_id
self._properties = data.copy()
self._model: Optional[Slice] = None
def run(self) -> Model:
self.validate()
assert self._model
try:
if self._properties.get("query_context_generation") is None:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
chart = ChartDAO.update(self._model, self._properties)
except DAOUpdateFailedError as ex:
logger.exception(ex.exception)
raise ChartUpdateFailedError() from ex
return chart
def validate(self) -> None:
exceptions: list[ValidationError] = []
dashboard_ids = self._properties.get("dashboards")
owner_ids: Optional[list[int]] = self._properties.get("owners")
# Validate if datasource_id is provided datasource_type is required
datasource_id = self._properties.get("datasource_id")
if datasource_id is not None:
datasource_type = self._properties.get("datasource_type", "")
if not datasource_type:
exceptions.append(DatasourceTypeUpdateRequiredValidationError())
# Validate/populate model exists
self._model = ChartDAO.find_by_id(self._model_id)
if not self._model:
raise ChartNotFoundError()
# Check and update ownership; when only updating query context we ignore
# ownership so the update can be performed by report workers
if not is_query_context_update(self._properties):
try:
security_manager.raise_for_ownership(self._model)
owners = self.populate_owners(owner_ids)
self._properties["owners"] = owners
except SupersetSecurityException as ex:
raise ChartForbiddenError() from ex
except ValidationError as ex:
exceptions.append(ex)
# Validate/Populate datasource
if datasource_id is not None:
try:
datasource = get_datasource_by_id(datasource_id, datasource_type)
self._properties["datasource_name"] = datasource.name
except ValidationError as ex:
exceptions.append(ex)
# Validate/Populate dashboards only if it's a list
if dashboard_ids is not None:
dashboards = DashboardDAO.find_by_ids(
dashboard_ids,
skip_base_filter=True,
)
if len(dashboards) != len(dashboard_ids):
exceptions.append(DashboardsNotFoundValidationError())
self._properties["dashboards"] = dashboards
if exceptions:
raise ChartInvalidError(exceptions=exceptions)

View File

@@ -0,0 +1,108 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Optional, Union
import simplejson as json
from flask import g
from superset.commands.base import BaseCommand
from superset.commands.chart.data.get_data_command import ChartDataCommand
from superset.commands.chart.exceptions import (
ChartInvalidError,
WarmUpCacheChartNotFoundError,
)
from superset.extensions import db
from superset.models.slice import Slice
from superset.utils.core import error_msg_from_exception
from superset.views.utils import get_dashboard_extra_filters, get_form_data, get_viz
from superset.viz import viz_types
class ChartWarmUpCacheCommand(BaseCommand):
def __init__(
self,
chart_or_id: Union[int, Slice],
dashboard_id: Optional[int],
extra_filters: Optional[str],
):
self._chart_or_id = chart_or_id
self._dashboard_id = dashboard_id
self._extra_filters = extra_filters
def run(self) -> dict[str, Any]:
self.validate()
chart: Slice = self._chart_or_id # type: ignore
try:
form_data = get_form_data(chart.id, use_slice_data=True)[0]
if form_data.get("viz_type") in viz_types:
# Legacy visualizations.
if not chart.datasource:
raise ChartInvalidError("Chart's datasource does not exist")
if self._dashboard_id:
form_data["extra_filters"] = (
json.loads(self._extra_filters)
if self._extra_filters
else get_dashboard_extra_filters(chart.id, self._dashboard_id)
)
g.form_data = form_data
payload = get_viz(
datasource_type=chart.datasource.type,
datasource_id=chart.datasource.id,
form_data=form_data,
force=True,
).get_payload()
delattr(g, "form_data")
error = payload["errors"] or None
status = payload["status"]
else:
# Non-legacy visualizations.
query_context = chart.get_query_context()
if not query_context:
raise ChartInvalidError("Chart's query context does not exist")
query_context.force = True
command = ChartDataCommand(query_context)
command.validate()
payload = command.run()
# Report the first error.
for query in payload["queries"]:
error = query["error"]
status = query["status"]
if error is not None:
break
except Exception as ex: # pylint: disable=broad-except
error = error_msg_from_exception(ex)
status = None
return {"chart_id": chart.id, "viz_error": error, "viz_status": status}
def validate(self) -> None:
if isinstance(self._chart_or_id, Slice):
return
chart = db.session.query(Slice).filter_by(id=self._chart_or_id).scalar()
if not chart:
raise WarmUpCacheChartNotFoundError()
self._chart_or_id = chart