# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from datetime import datetime
from typing import Any

from superset.daos.base import BaseDAO
from superset.extensions import db
from superset.reports.filters import ReportScheduleFilter
from superset.reports.models import (
    ReportExecutionLog,
    ReportRecipients,
    ReportSchedule,
    ReportScheduleType,
    ReportState,
)
from superset.utils import json
from superset.utils.core import get_user_id

logger = logging.getLogger(__name__)


# Value stored in ``ReportExecutionLog.error_message`` that
# ``find_last_error_notification`` matches on to locate error-notification logs.
REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER = "Notification sent with error"


class ReportScheduleDAO(BaseDAO[ReportSchedule]):
    """Data-access helpers for ``ReportSchedule`` and its execution logs."""

    base_filter = ReportScheduleFilter

    @staticmethod
    def find_by_chart_id(chart_id: int) -> list[ReportSchedule]:
        """Return all report schedules whose ``chart_id`` equals ``chart_id``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.chart_id == chart_id)
            .all()
        )

    @staticmethod
    def find_by_chart_ids(chart_ids: list[int]) -> list[ReportSchedule]:
        """Return all report schedules whose ``chart_id`` is in ``chart_ids``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.chart_id.in_(chart_ids))
            .all()
        )

    @staticmethod
    def find_by_dashboard_id(dashboard_id: int) -> list[ReportSchedule]:
        """Return all report schedules whose ``dashboard_id`` equals ``dashboard_id``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.dashboard_id == dashboard_id)
            .all()
        )

    @staticmethod
    def find_by_dashboard_ids(dashboard_ids: list[int]) -> list[ReportSchedule]:
        """Return all report schedules whose ``dashboard_id`` is in ``dashboard_ids``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.dashboard_id.in_(dashboard_ids))
            .all()
        )

    @staticmethod
    def find_by_database_id(database_id: int) -> list[ReportSchedule]:
        """Return all report schedules whose ``database_id`` equals ``database_id``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.database_id == database_id)
            .all()
        )

    @staticmethod
    def find_by_database_ids(database_ids: list[int]) -> list[ReportSchedule]:
        """Return all report schedules whose ``database_id`` is in ``database_ids``."""
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.database_id.in_(database_ids))
            .all()
        )

    @staticmethod
    def find_by_extra_metadata(slug: str) -> list[ReportSchedule]:
        """
        Return all report schedules whose ``extra_json`` contains ``slug``.

        The match uses ``contains(..., autoescape=True)`` on the ``extra_json``
        column, so ``slug`` is matched as literal text.

        :param slug: The text to look for inside ``extra_json``
        """
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.extra_json.contains(slug, autoescape=True))
            .all()
        )

    @staticmethod
    def find_by_native_filter_id(native_filter_id: str) -> list[ReportSchedule]:
        """
        Search ``extra_json`` for a native filter ID string.

        :param native_filter_id: The filter ID to look for inside ``extra_json``
        """
        # Same containment query as ``find_by_extra_metadata``; delegate instead
        # of duplicating it.
        return ReportScheduleDAO.find_by_extra_metadata(native_filter_id)

    @staticmethod
    def validate_unique_creation_method(
        dashboard_id: int | None = None, chart_id: int | None = None
    ) -> bool:
        """
        Validate if the user already has a chart or dashboard
        with a report attached from the self subscribe reports.

        The lookup is restricted to report schedules whose ``created_by_fk``
        equals ``get_user_id()``; the dashboard and chart filters are applied
        only for the arguments that are not ``None``.

        :param dashboard_id: Optional dashboard id to restrict the lookup to
        :param chart_id: Optional chart id to restrict the lookup to
        :return: ``True`` when no matching report schedule exists
        """
        query = db.session.query(ReportSchedule).filter_by(created_by_fk=get_user_id())
        if dashboard_id is not None:
            query = query.filter(ReportSchedule.dashboard_id == dashboard_id)
        if chart_id is not None:
            query = query.filter(ReportSchedule.chart_id == chart_id)
        return not db.session.query(query.exists()).scalar()

    @staticmethod
    def validate_update_uniqueness(
        name: str, report_type: ReportScheduleType, expect_id: int | None = None
    ) -> bool:
        """
        Validate if this name and type is unique.

        :param name: The report schedule name
        :param report_type: The report schedule type
        :param expect_id: The id of the expected report schedule with the
          name + type combination. Useful for validating existing report schedule.
        :return: ``True`` when no report schedule has this name + type, or when
          the one found has the id ``expect_id``
        """
        found_id = (
            db.session.query(ReportSchedule.id)
            .filter(ReportSchedule.name == name, ReportSchedule.type == report_type)
            .limit(1)
            .scalar()
        )
        return found_id is None or found_id == expect_id

    @staticmethod
    def _build_recipients(
        item: ReportSchedule, recipients: list[dict[str, Any]]
    ) -> list[ReportRecipients]:
        """
        Build ``ReportRecipients`` models attached to ``item``.

        :param item: The report schedule the recipients belong to
        :param recipients: Recipient payloads; each must provide the ``type`` and
          ``recipient_config_json`` keys. The latter is serialized with
          ``json.dumps`` before being stored on the model.
        """
        return [
            ReportRecipients(
                type=recipient["type"],
                recipient_config_json=json.dumps(recipient["recipient_config_json"]),
                report_schedule=item,
            )
            for recipient in recipients
        ]

    @classmethod
    def create(
        cls,
        item: ReportSchedule | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> ReportSchedule:
        """
        Create a report schedule with nested recipients.

        A non-empty ``recipients`` entry is converted into ``ReportRecipients``
        models; an empty or ``None`` entry is dropped before delegating to the
        base DAO.

        :param item: The object to create
        :param attributes: The attributes associated with the object to create
        """

        # TODO(john-bodley): Determine why we need special handling for recipients.
        if not item:
            item = ReportSchedule()

        if attributes:
            # Work on a shallow copy so the caller's dict is left untouched.
            attributes = dict(attributes)

            if recipients := attributes.pop("recipients", None):
                attributes["recipients"] = cls._build_recipients(item, recipients)

        return super().create(item, attributes)

    @classmethod
    def update(
        cls,
        item: ReportSchedule | None = None,
        attributes: dict[str, Any] | None = None,
    ) -> ReportSchedule:
        """
        Update a report schedule with nested recipients.

        Unlike ``create``, any ``recipients`` key that is present is converted,
        so an empty list is forwarded to the base DAO as an empty list.

        :param item: The object to update
        :param attributes: The attributes associated with the object to update
        """

        # TODO(john-bodley): Determine why we need special handling for recipients.
        if not item:
            item = ReportSchedule()

        if attributes:
            # Work on a shallow copy so the caller's dict is left untouched.
            attributes = dict(attributes)

            if "recipients" in attributes:
                attributes["recipients"] = cls._build_recipients(
                    item, attributes.pop("recipients")
                )

        return super().update(item, attributes)

    @staticmethod
    def find_active() -> list[ReportSchedule]:
        """
        Find all active reports.
        """
        return (
            db.session.query(ReportSchedule)
            .filter(ReportSchedule.active.is_(True))
            .all()
        )

    @staticmethod
    def find_last_success_log(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last success execution log for a given report.

        :param report_schedule: The report schedule whose logs are searched
        :return: The ``SUCCESS`` log with the latest ``end_dttm``, or ``None``
        """
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state == ReportState.SUCCESS,
                ReportExecutionLog.report_schedule == report_schedule,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )

    @staticmethod
    def find_last_entered_working_log(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last working execution log (without an error message)
        for a given report.

        :param report_schedule: The report schedule whose logs are searched
        :return: The ``WORKING`` log with no ``error_message`` and the latest
          ``end_dttm``, or ``None``
        """
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state == ReportState.WORKING,
                ReportExecutionLog.report_schedule == report_schedule,
                ReportExecutionLog.error_message.is_(None),
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )

    @staticmethod
    def find_last_error_notification(
        report_schedule: ReportSchedule,
    ) -> ReportExecutionLog | None:
        """
        Finds last error email sent.

        The candidate is the latest log whose ``error_message`` equals
        ``REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER``. It is returned only when
        the second query below finds no log in a state other than ``ERROR`` or
        ``WORKING``.

        :param report_schedule: The report schedule whose logs are searched
        :return: The last error notification log, or ``None``
        """
        last_error_email_log = (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.error_message
                == REPORT_SCHEDULE_ERROR_NOTIFICATION_MARKER,
                ReportExecutionLog.report_schedule == report_schedule,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )
        if not last_error_email_log:
            return None
        # Checks that only errors have occurred since the last email
        # NOTE(review): the filter compares with ``<``, i.e. it looks at logs that
        # ended *before* the last error email, which does not obviously match
        # "since the last email". Behavior is kept as-is; confirm the intended
        # direction before changing it.
        report_from_last_email = (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.state.notin_(
                    [ReportState.ERROR, ReportState.WORKING]
                ),
                ReportExecutionLog.report_schedule == report_schedule,
                ReportExecutionLog.end_dttm < last_error_email_log.end_dttm,
            )
            .order_by(ReportExecutionLog.end_dttm.desc())
            .first()
        )
        return last_error_email_log if not report_from_last_email else None

    @staticmethod
    def bulk_delete_logs(model: ReportSchedule, from_date: datetime) -> int | None:
        """
        Delete the execution logs of ``model`` that ended before ``from_date``.

        :param model: The report schedule whose logs are deleted
        :param from_date: Logs with ``end_dttm`` earlier than this are deleted
        :return: The value returned by the query's ``delete`` call
        """
        return (
            db.session.query(ReportExecutionLog)
            .filter(
                ReportExecutionLog.report_schedule == model,
                ReportExecutionLog.end_dttm < from_date,
            )
            .delete(synchronize_session="fetch")
        )