fix: refactored SQL-based alerts to not pass sqlalchemy objects as args (#10506)

* refractored alerting to not pass sqlalchemy obj as args

* updated to pass only alert id as arg

* used object id instead of argument

* updated alerts_tests.py to reflect change

Co-authored-by: Jason Davis <@dropbox.com>
This commit is contained in:
Jason Davis
2020-08-04 09:52:32 -07:00
committed by GitHub
parent 5bb8b9790f
commit 9c5b0e1c86
2 changed files with 14 additions and 10 deletions

View File

@@ -551,10 +551,10 @@ def schedule_alert_query( # pylint: disable=unused-argument
if report_type == ScheduleType.alert:
if is_test_alert and recipients:
deliver_alert(schedule, recipients)
deliver_alert(schedule.id, recipients)
return
if run_alert_query(schedule):
if run_alert_query(schedule.id):
# deliver_dashboard OR deliver_slice
return
else:
@@ -567,7 +567,9 @@ class AlertState:
PASS = "pass"
def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None:
def deliver_alert(alert_id: int, recipients: Optional[str] = None) -> None:
alert = db.session.query(Alert).get(alert_id)
logging.info("Triggering alert: %s", alert)
img_data = None
images = {}
@@ -612,10 +614,12 @@ def deliver_alert(alert: Alert, recipients: Optional[str] = None) -> None:
_deliver_email(recipients, deliver_as_group, subject, body, data, images)
def run_alert_query(alert: Alert) -> Optional[bool]:
def run_alert_query(alert_id: int) -> Optional[bool]:
"""
Execute alert.sql and return value if any rows are returned
"""
alert = db.session.query(Alert).get(alert_id)
logger.info("Processing alert ID: %i", alert.id)
database = alert.database
if not database:
@@ -650,7 +654,7 @@ def run_alert_query(alert: Alert) -> Optional[bool]:
for row in df.to_records():
if any(row):
state = AlertState.TRIGGER
deliver_alert(alert)
deliver_alert(alert.id)
break
if not state:
state = AlertState.PASS

View File

@@ -86,21 +86,21 @@ def teardown_module():
@patch("superset.tasks.schedules.logging.Logger.error")
def test_run_alert_query(mock_error, mock_deliver_alert):
with app.app_context():
run_alert_query(db.session.query(Alert).filter_by(id=1).one())
run_alert_query(db.session.query(Alert).filter_by(id=1).one().id)
alert1 = db.session.query(Alert).filter_by(id=1).one()
assert mock_deliver_alert.call_count == 0
assert len(alert1.logs) == 1
assert alert1.logs[0].alert_id == 1
assert alert1.logs[0].state == "pass"
run_alert_query(db.session.query(Alert).filter_by(id=2).one())
run_alert_query(db.session.query(Alert).filter_by(id=2).one().id)
alert2 = db.session.query(Alert).filter_by(id=2).one()
assert mock_deliver_alert.call_count == 1
assert len(alert2.logs) == 1
assert alert2.logs[0].alert_id == 2
assert alert2.logs[0].state == "trigger"
run_alert_query(db.session.query(Alert).filter_by(id=3).one())
run_alert_query(db.session.query(Alert).filter_by(id=3).one().id)
alert3 = db.session.query(Alert).filter_by(id=3).one()
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 2
@@ -108,11 +108,11 @@ def test_run_alert_query(mock_error, mock_deliver_alert):
assert alert3.logs[0].alert_id == 3
assert alert3.logs[0].state == "error"
run_alert_query(db.session.query(Alert).filter_by(id=4).one())
run_alert_query(db.session.query(Alert).filter_by(id=4).one().id)
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 3
run_alert_query(db.session.query(Alert).filter_by(id=5).one())
run_alert_query(db.session.query(Alert).filter_by(id=5).one().id)
assert mock_deliver_alert.call_count == 1
assert mock_error.call_count == 4