# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any, Dict, Iterable, List, Mapping, Optional, Sequence, Set, Union

from flask import current_app, g
from flask_appbuilder import Model
from marshmallow import post_load, pre_load, Schema, ValidationError
from sqlalchemy.orm.exc import NoResultFound


def validate_owner(value: int) -> None:
    """
    Marshmallow field validator: ensure a user with id ``value`` exists.

    :param value: The user id to look up
    :raises ValidationError: If no user row with that id is found
    """
    try:
        (
            current_app.appbuilder.get_session.query(
                current_app.appbuilder.sm.user_model.id
            )
            .filter_by(id=value)
            .one()
        )
    except NoResultFound:
        raise ValidationError(f"User {value} does not exist")


class BaseSupersetSchema(Schema):
    """
    Extends Marshmallow schema so that we can pass a Model to load
    (following the marshmallow-sqlalchemy pattern). This is useful
    to perform partial model merges on HTTP PUT.
    """

    __class_model__: Model = None

    def __init__(self, **kwargs: Any) -> None:
        self.instance: Optional[Model] = None
        super().__init__(**kwargs)

    def load(  # pylint: disable=arguments-differ
        self,
        data: Union[Mapping[str, Any], Iterable[Mapping[str, Any]]],
        many: Optional[bool] = None,
        partial: Union[bool, Sequence[str], Set[str], None] = None,
        instance: Optional[Model] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Load ``data``, optionally merging it into an existing ``instance``.

        :param data: Payload to deserialize
        :param many: Defaults to False when not given
        :param partial: Defaults to False when not given
        :param instance: Existing model to update (PUT); when None a new
            ``__class_model__`` object is created by ``make_object``
        """
        # Stored on the schema so the post_load hook can see it.
        self.instance = instance
        if many is None:
            many = False
        if partial is None:
            partial = False
        return super().load(data, many=many, partial=partial, **kwargs)

    @post_load
    def make_object(  # pylint: disable=unused-argument
        self,
        data: Dict[Any, Any],
        discard: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Model:
        """
        Creates a Model object from POST or PUT requests. PUT will use
        self.instance previously fetched from the endpoint handler.

        :param data: Schema data payload
        :param discard: List of fields to not set on the model
        :param kwargs: Extra processor keyword arguments (marshmallow 3
            passes ``many``/``partial``); accepted and ignored
        """
        discard = discard or []
        if not self.instance:
            self.instance = self.__class_model__()  # pylint: disable=not-callable
        for field, value in data.items():
            if field in discard:
                continue
            setattr(self.instance, field, value)
        return self.instance


class BaseOwnedSchema(BaseSupersetSchema):
    """
    Implements owners validation, pre_load and post_load
    (to populate the owners field) on Marshmallow schemas.
    """

    owners_field_name = "owners"

    @post_load
    def make_object(  # pylint: disable=unused-argument
        self,
        data: Dict[str, Any],
        discard: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Model:
        """
        Build/merge the model, then populate ``instance.owners``.

        The owners field is never set directly by the base ``make_object``;
        it is resolved from user ids by ``set_owners`` instead.

        :param data: Schema data payload
        :param discard: List of fields to not set on the model (not mutated)
        :param kwargs: Extra processor keyword arguments; accepted and ignored
        """
        # Build a fresh list so the caller's ``discard`` argument is not mutated.
        discard = [*(discard or []), self.owners_field_name]
        instance = super().make_object(data, discard)
        if self.owners_field_name in data:
            self.set_owners(instance, data[self.owners_field_name])
        elif g.user not in instance.owners:
            # No owners supplied: make sure the current user is an owner.
            instance.owners.append(g.user)
        return instance

    @pre_load
    def pre_load(  # pylint: disable=unused-argument
        self, data: Dict[Any, Any], **kwargs: Any
    ) -> Dict[Any, Any]:
        """
        Default the owners field to an empty list when no instance is being
        updated (i.e. not a PUT), so ``make_object`` always calls
        ``set_owners`` in that case.

        Returns a copy of ``data`` instead of mutating the caller's payload;
        the returned value is what marshmallow continues to deserialize.
        """
        # if PUT request don't set owners to empty list
        if not self.instance:
            data = {
                **data,
                self.owners_field_name: data.get(self.owners_field_name, []),
            }
        return data

    @staticmethod
    def set_owners(instance: Model, owners: List[int]) -> None:
        """
        Replace ``instance.owners`` with the users whose ids are in ``owners``.

        The current user (``g.user``) is always included. The ``owners``
        argument itself is left unmodified.

        :param instance: Model whose ``owners`` attribute is replaced
        :param owners: User ids to assign as owners
        """
        owner_ids = list(owners)
        if g.user.id not in owner_ids:
            owner_ids.append(g.user.id)
        session = current_app.appbuilder.get_session
        user_model = current_app.appbuilder.sm.user_model
        # NOTE(review): ``get`` yields None for an unknown id; presumably ids
        # were already checked by validate_owner — confirm at the field level.
        instance.owners = [
            session.query(user_model).get(owner_id) for owner_id in owner_ids
        ]