diff --git a/superset/commands/report/exceptions.py b/superset/commands/report/exceptions.py index 27966e6f09e..b51c2d1cd1f 100644 --- a/superset/commands/report/exceptions.py +++ b/superset/commands/report/exceptions.py @@ -39,6 +39,18 @@ class DatabaseNotFoundValidationError(ValidationError): super().__init__(_("Database does not exist"), field_name="database") +class ReportScheduleDatabaseNotAllowedValidationError(ValidationError): + """ + Marshmallow validation error for database reference on a Report type schedule + """ + + def __init__(self) -> None: + super().__init__( + _("Database reference is not allowed on a report"), + field_name="database", + ) + + class DashboardNotFoundValidationError(ValidationError): """ Marshmallow validation error for dashboard does not exist diff --git a/superset/commands/report/update.py b/superset/commands/report/update.py index abae62cadd2..1ee09ea7a71 100644 --- a/superset/commands/report/update.py +++ b/superset/commands/report/update.py @@ -26,6 +26,8 @@ from superset.commands.base import UpdateMixin from superset.commands.report.base import BaseReportScheduleCommand from superset.commands.report.exceptions import ( DatabaseNotFoundValidationError, + ReportScheduleAlertRequiredDatabaseValidationError, + ReportScheduleDatabaseNotAllowedValidationError, ReportScheduleForbiddenError, ReportScheduleInvalidError, ReportScheduleNameUniquenessValidationError, @@ -98,8 +100,22 @@ class UpdateReportScheduleCommand(UpdateMixin, BaseReportScheduleCommand): ) ) + # Determine effective database state (payload overrides model) + if "database" in self._properties: + has_database = self._properties["database"] is not None + else: + has_database = self._model.database_id is not None + + # Validate database is not allowed on Report type + if report_type == ReportScheduleType.REPORT and has_database: + exceptions.append(ReportScheduleDatabaseNotAllowedValidationError()) + + # Validate Alert has a database + if report_type == ReportScheduleType.ALERT and not has_database: + exceptions.append(ReportScheduleAlertRequiredDatabaseValidationError()) + # Validate if DB exists (for alerts) - if report_type == ReportScheduleType.ALERT and database_id: + if report_type == ReportScheduleType.ALERT and database_id is not None: if not (database := DatabaseDAO.find_by_id(database_id)): exceptions.append(DatabaseNotFoundValidationError()) self._properties["database"] = database diff --git a/superset/reports/schemas.py b/superset/reports/schemas.py index 76108f976ff..055ebc75d24 100644 --- a/superset/reports/schemas.py +++ b/superset/reports/schemas.py @@ -334,7 +334,7 @@ class ReportSchedulePutSchema(Schema): metadata={"description": creation_method_description}, ) dashboard = fields.Integer(required=False, allow_none=True) - database = fields.Integer(required=False) + database = fields.Integer(required=False, allow_none=True) owners = fields.List( fields.Integer(metadata={"description": owners_description}), required=False ) diff --git a/tests/integration_tests/reports/api_tests.py b/tests/integration_tests/reports/api_tests.py index 41edf717020..87fb8ccf046 100644 --- a/tests/integration_tests/reports/api_tests.py +++ b/tests/integration_tests/reports/api_tests.py @@ -18,6 +18,7 @@ """Unit tests for Superset""" from datetime import datetime, timedelta +from typing import Any from unittest.mock import patch import pytz @@ -1392,7 +1393,7 @@ class TestReportSchedulesApi(SupersetTestCase): ) assert report_schedule.type == ReportScheduleType.ALERT previous_cron = report_schedule.crontab - update_payload = { + update_payload: dict[str, Any] = { "crontab": "5,10 * * * *", } with patch.dict( @@ -1410,6 +1411,7 @@ class TestReportSchedulesApi(SupersetTestCase): # Test report minimum interval update_payload["crontab"] = "5,8 * * * *" update_payload["type"] = ReportScheduleType.REPORT + update_payload["database"] = None uri = f"api/v1/report/{report_schedule.id}" rv = self.put_assert_metric(uri, update_payload, "put") assert rv.status_code == 200 @@ -1424,6 +1426,7 @@ class TestReportSchedulesApi(SupersetTestCase): # Undo changes update_payload["crontab"] = previous_cron update_payload["type"] = ReportScheduleType.ALERT + update_payload["database"] = get_example_database().id uri = f"api/v1/report/{report_schedule.id}" rv = self.put_assert_metric(uri, update_payload, "put") assert rv.status_code == 200 @@ -1441,7 +1444,7 @@ class TestReportSchedulesApi(SupersetTestCase): .one_or_none() ) assert report_schedule.type == ReportScheduleType.ALERT - update_payload = { + update_payload: dict[str, Any] = { "crontab": "5,10 * * * *", } with patch.dict( @@ -1468,6 +1471,7 @@ class TestReportSchedulesApi(SupersetTestCase): # Exceed report minimum interval update_payload["crontab"] = "5,8 * * * *" update_payload["type"] = ReportScheduleType.REPORT + update_payload["database"] = None uri = f"api/v1/report/{report_schedule.id}" rv = self.put_assert_metric(uri, update_payload, "put") assert rv.status_code == 422 @@ -1607,6 +1611,292 @@ class TestReportSchedulesApi(SupersetTestCase): data = json.loads(rv.data.decode("utf-8")) assert data == {"message": {"chart": "Choose a chart or dashboard not both"}} + @pytest.mark.usefixtures("create_report_schedules") + def test_update_report_schedule_database_not_allowed_on_report(self): + """ + ReportSchedule API: Test update report schedule rejects database on Report type + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + + # Create a Report-type schedule (name1 is an Alert, so create one) + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name1") + .one_or_none() + ) + # Change to Report type first (clearing database) + uri = f"api/v1/report/{report_schedule.id}" + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.REPORT, "database": None}, + "put", + ) + assert rv.status_code == 200 + + # Test 1: Report + database (no type in payload) → 422 + rv = self.put_assert_metric(uri, {"database": example_db.id}, "put") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"database": "Database reference is not allowed on a report"} + } + + # Test 2: Report + database + explicit type=Report → 422 + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.REPORT, "database": example_db.id}, + "put", + ) + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"database": "Database reference is not allowed on a report"} + } + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_report_schedule_nonexistent_database_returns_not_allowed(self): + """ + ReportSchedule API: Test Report + nonexistent DB returns 'not allowed', + not 'does not exist' — type invariant takes precedence. + """ + self.login(ADMIN_USERNAME) + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name1") + .one_or_none() + ) + uri = f"api/v1/report/{report_schedule.id}" + + # Transition to Report type first + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.REPORT, "database": None}, + "put", + ) + assert rv.status_code == 200 + + # Report + nonexistent DB → 422 "not allowed" (not "does not exist") + database_max_id = db.session.query(func.max(Database.id)).scalar() + rv = self.put_assert_metric(uri, {"database": database_max_id + 1}, "put") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"database": "Database reference is not allowed on a report"} + } + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_alert_schedule_database_allowed(self): + """ + ReportSchedule API: Test update alert schedule accepts database + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name2") + .one_or_none() + ) + assert report_schedule.type == ReportScheduleType.ALERT + + # Test 3: Alert + database (no type in payload) → 200 + uri = f"api/v1/report/{report_schedule.id}" + rv = self.put_assert_metric(uri, {"database": example_db.id}, "put") + assert rv.status_code == 200 + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_report_schedule_type_transitions(self): + """ + ReportSchedule API: Test type transitions with database validation + """ + self.login(ADMIN_USERNAME) + example_db = get_example_database() + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name3") + .one_or_none() + ) + assert report_schedule.type == ReportScheduleType.ALERT + assert report_schedule.database_id is not None + uri = f"api/v1/report/{report_schedule.id}" + + # Test 4: Alert + database update (same type) → 200 + rv = self.put_assert_metric( + uri, + {"database": example_db.id}, + "put", + ) + assert rv.status_code == 200 + + # Test 5: Alert → Report + database → 422 + rv = self.put_assert_metric( + uri, + { + "type": ReportScheduleType.REPORT, + "database": example_db.id, + }, + "put", + ) + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"database": "Database reference is not allowed on a report"} + } + + # Test 6: Alert → Report without clearing database → 422 + rv = self.put_assert_metric(uri, {"type": ReportScheduleType.REPORT}, "put") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == { + "message": {"database": "Database reference is not allowed on a report"} + } + + # Test 7: Alert → Report with database: null (explicit clear) → 200 + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.REPORT, "database": None}, + "put", + ) + assert rv.status_code == 200 + + # Now schedule is a Report with no database. + # Test 8: Report → Alert without providing database → 422 + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.ALERT}, + "put", + ) + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == {"message": {"database": "Database is required for alerts"}} + + # Test 9: Report → Alert with database → 200 (valid transition) + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.ALERT, "database": example_db.id}, + "put", + ) + assert rv.status_code == 200 + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_alert_schedule_database_null_rejected(self): + """ + ReportSchedule API: Test alert schedule rejects null database + """ + self.login(ADMIN_USERNAME) + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name2") + .one_or_none() + ) + assert report_schedule.type == ReportScheduleType.ALERT + uri = f"api/v1/report/{report_schedule.id}" + + # Test 8: Alert + database: null → 422 + rv = self.put_assert_metric(uri, {"database": None}, "put") + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == {"message": {"database": "Database is required for alerts"}} + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_report_schedule_422_does_not_mutate(self): + """ + ReportSchedule API: Test that a rejected PUT does not mutate the model + """ + self.login(ADMIN_USERNAME) + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name2") + .one_or_none() + ) + assert report_schedule.type == ReportScheduleType.ALERT + original_type = report_schedule.type + original_database_id = report_schedule.database_id + assert original_database_id is not None + uri = f"api/v1/report/{report_schedule.id}" + + # Alert→Report without clearing database → 422 + rv = self.put_assert_metric(uri, {"type": ReportScheduleType.REPORT}, "put") + assert rv.status_code == 422 + + # Re-query and verify no mutation + db.session.expire(report_schedule) + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.id == report_schedule.id) + .one_or_none() + ) + assert report_schedule.type == original_type + assert report_schedule.database_id == original_database_id + + @pytest.mark.usefixtures( + "load_birth_names_dashboard_with_slices", "create_report_schedules" + ) + def test_create_report_schedule_database_not_allowed(self): + """ + ReportSchedule API: Test POST rejects database on Report type at schema level + """ + self.login(ADMIN_USERNAME) + + chart = db.session.query(Slice).first() + example_db = get_example_database() + report_schedule_data = { + "type": ReportScheduleType.REPORT, + "name": "report_with_db", + "description": "should fail", + "crontab": "0 9 * * *", + "creation_method": ReportCreationMethod.ALERTS_REPORTS, + "chart": chart.id, + "database": example_db.id, + } + uri = "api/v1/report/" + rv = self.post_assert_metric(uri, report_schedule_data, "post") + assert rv.status_code == 400 + data = json.loads(rv.data.decode("utf-8")) + assert "database" in data.get("message", {}) + + @pytest.mark.usefixtures("create_report_schedules") + def test_update_report_to_alert_nonexistent_database(self): + """ + ReportSchedule API: Test Report→Alert with nonexistent database returns 422 + """ + self.login(ADMIN_USERNAME) + + report_schedule = ( + db.session.query(ReportSchedule) + .filter(ReportSchedule.name == "name4") + .one_or_none() + ) + assert report_schedule.type == ReportScheduleType.ALERT + uri = f"api/v1/report/{report_schedule.id}" + + # First transition to Report (clearing database) + rv = self.put_assert_metric( + uri, + {"type": ReportScheduleType.REPORT, "database": None}, + "put", + ) + assert rv.status_code == 200 + + # Now transition back to Alert with nonexistent database + database_max_id = db.session.query(func.max(Database.id)).scalar() + rv = self.put_assert_metric( + uri, + { + "type": ReportScheduleType.ALERT, + "database": database_max_id + 1, + }, + "put", + ) + assert rv.status_code == 422 + data = json.loads(rv.data.decode("utf-8")) + assert data == {"message": {"database": "Database does not exist"}} + @pytest.mark.usefixtures( "load_birth_names_dashboard_with_slices", "create_report_schedules" ) diff --git a/tests/unit_tests/commands/report/__init__.py b/tests/unit_tests/commands/report/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/tests/unit_tests/commands/report/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/tests/unit_tests/commands/report/update_test.py b/tests/unit_tests/commands/report/update_test.py new file mode 100644 index 00000000000..6b515781b4d --- /dev/null +++ b/tests/unit_tests/commands/report/update_test.py @@ -0,0 +1,254 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Unit tests for UpdateReportScheduleCommand.validate() database invariants.""" + +from __future__ import annotations + +from unittest.mock import Mock + +import pytest +from pytest_mock import MockerFixture + +from superset.commands.report.exceptions import ( + ReportScheduleInvalidError, +) +from superset.commands.report.update import UpdateReportScheduleCommand +from superset.reports.models import ReportScheduleType + + +def _make_model( + mocker: MockerFixture, + *, + model_type: ReportScheduleType | str, + database_id: int | None, +) -> Mock: + model = mocker.Mock() + model.type = model_type + model.database_id = database_id + model.name = "test_schedule" + model.crontab = "0 9 * * *" + model.last_state = "noop" + model.owners = [] + return model + + +def _setup_mocks(mocker: MockerFixture, model: Mock) -> None: + mocker.patch( + "superset.commands.report.update.ReportScheduleDAO.find_by_id", + return_value=model, + ) + mocker.patch( + "superset.commands.report.update.ReportScheduleDAO.validate_update_uniqueness", + return_value=True, + ) + mocker.patch( + "superset.commands.report.update.security_manager.raise_for_ownership", + ) + mocker.patch( + "superset.commands.report.update.DatabaseDAO.find_by_id", + return_value=mocker.Mock(), + ) + mocker.patch.object( + UpdateReportScheduleCommand, + "validate_chart_dashboard", + ) + mocker.patch.object( + UpdateReportScheduleCommand, + "validate_report_frequency", + ) + mocker.patch.object( + UpdateReportScheduleCommand, + "compute_owners", + return_value=[], + ) + + +def _get_validation_messages( + exc_info: pytest.ExceptionInfo[ReportScheduleInvalidError], +) -> dict[str, str]: + """Extract field→first message string from ReportScheduleInvalidError.""" + raw = exc_info.value.normalized_messages() + result = {} + for field, msgs in raw.items(): + if isinstance(msgs, list): + result[field] = str(msgs[0]) + else: + result[field] = str(msgs) + return result + + +# --- Report type: database must NOT be set --- + + +def test_report_with_database_in_payload_rejected(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={"database": 5}) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "not allowed" in messages["database"].lower() + + +def test_report_with_database_none_in_payload_accepted(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={"database": None}) + cmd.validate() # should not raise + + +def test_report_no_database_in_payload_model_has_db_rejected( + mocker: MockerFixture, +) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=5) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={}) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "not allowed" in messages["database"].lower() + + +def test_report_no_database_anywhere_accepted(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={}) + cmd.validate() # should not raise + + +# --- Alert type: database MUST be set --- + + +def test_alert_with_database_in_payload_accepted(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={"database": 5}) + cmd.validate() # should not raise + + +def test_alert_with_database_none_in_payload_rejected(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=5) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={"database": None}) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "required" in messages["database"].lower() + + +def test_alert_no_database_in_payload_model_has_db_accepted( + mocker: MockerFixture, +) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=5) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={}) + cmd.validate() # should not raise + + +def test_alert_no_database_anywhere_rejected(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand(model_id=1, data={}) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "required" in messages["database"].lower() + + +# --- Type transitions --- + + +def test_alert_to_report_without_clearing_db_rejected(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=5) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand( + model_id=1, data={"type": ReportScheduleType.REPORT} + ) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "not allowed" in messages["database"].lower() + + +def test_alert_to_report_with_db_cleared_accepted(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.ALERT, database_id=5) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand( + model_id=1, + data={"type": ReportScheduleType.REPORT, "database": None}, + ) + cmd.validate() # should not raise + + +def test_report_to_alert_without_db_rejected(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand( + model_id=1, data={"type": ReportScheduleType.ALERT} + ) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "required" in messages["database"].lower() + + +def test_report_with_nonexistent_database_returns_not_allowed( + mocker: MockerFixture, +) -> None: + """Report + nonexistent DB must return 'not allowed', not 'does not exist'.""" + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + mocker.patch( + "superset.commands.report.update.DatabaseDAO.find_by_id", + return_value=None, + ) + + cmd = UpdateReportScheduleCommand(model_id=1, data={"database": 99999}) + with pytest.raises(ReportScheduleInvalidError) as exc_info: + cmd.validate() + messages = _get_validation_messages(exc_info) + assert "database" in messages + assert "not allowed" in messages["database"].lower() + assert "does not exist" not in messages["database"].lower() + + +def test_report_to_alert_with_db_accepted(mocker: MockerFixture) -> None: + model = _make_model(mocker, model_type=ReportScheduleType.REPORT, database_id=None) + _setup_mocks(mocker, model) + + cmd = UpdateReportScheduleCommand( + model_id=1, + data={"type": ReportScheduleType.ALERT, "database": 5}, + ) + cmd.validate() # should not raise