diff --git a/anyway/views/news_flash/api.py b/anyway/views/news_flash/api.py index 525c0cef..5ab14e87 100644 --- a/anyway/views/news_flash/api.py +++ b/anyway/views/news_flash/api.py @@ -31,6 +31,7 @@ from anyway.parsers.resolution_fields import ResolutionFields as RF from anyway.models import AccidentMarkerView, InvolvedView from anyway.widgets.widget_utils import get_accidents_stats +from anyway.parsers.location_extraction import get_road_segment_name_and_number from anyway.telegram_accident_notifications import trigger_generate_infographics_and_send_to_telegram from io import BytesIO @@ -321,12 +322,25 @@ def update_news_flash_qualifying(id): old_location, old_location_qualifiction = extracted_location_and_qualification(news_flash_obj) if manual_update: if use_road_segment: + road_number, road_segment_name = get_road_segment_name_and_number(road_segment_id) + if road_number != road1: + logging.error("road number from road_segment_id does not match road1 input.") + return return_json_error(Es.BR_BAD_FIELD) news_flash_obj.road_segment_id = road_segment_id - news_flash_obj.road1 = road1 + news_flash_obj.road_segment_name = road_segment_name + news_flash_obj.road1 = road_number + news_flash_obj.road2 = None + news_flash_obj.yishuv_name = None + news_flash_obj.street1_hebrew = None + news_flash_obj.street2_hebrew = None news_flash_obj.resolution = BE_CONST.ResolutionCategories.SUBURBAN_ROAD.value else: news_flash_obj.yishuv_name = yishuv_name news_flash_obj.street1_hebrew = street1_hebrew + news_flash_obj.road_segment_id = None + news_flash_obj.road_segment_name = None + news_flash_obj.road1 = None + news_flash_obj.road2 = None news_flash_obj.resolution = BE_CONST.ResolutionCategories.STREET.value else: if ((news_flash_obj.road_segment_id is None) or (news_flash_obj.road1 is None)) and ( diff --git a/tests/test_news_flash_api.py b/tests/test_news_flash_api.py index bf6b16a6..c8838885 100644 --- a/tests/test_news_flash_api.py +++ b/tests/test_news_flash_api.py @@ -13,7 +13,7 @@ OFFSET, ) from anyway.backend_constants import BE_CONST -from anyway.models import LocationVerificationHistory, NewsFlash, Users +from anyway.models import LocationVerificationHistory, NewsFlash, Users, RoadSegments import tests.test_flask as tests_flask # global application scope. create Session class, engine @@ -82,27 +82,36 @@ def test_add_location_qualifiction_history(self, can, current_user, get_current_ db_mock.session = self.session user_id = self.session.query(Users).all()[0].id tests_flask.set_current_user_mock(current_user, user_id=user_id) + road_segment = RoadSegments( + segment_id=100, + road=1, + from_name='from_name', + to_name='to_name', + ) + db_mock.session.add(road_segment) + db_mock.session.commit() with patch("anyway.views.news_flash.api.db", db_mock): - mock_request = unittest.mock.MagicMock() - values = {"newsflash_location_qualification": "manual", "road_segment_id": 100, "road1": "1"} - mock_request.values.get = lambda key: values.get(key) - with patch("anyway.views.news_flash.api.request", mock_request): - id = self.session.query(NewsFlash).all()[0].id - return_value = update_news_flash_qualifying(id) - self.assertEqual(return_value.status_code, HTTPStatus.OK.value) - location_verifiction_history = ( - self.session.query(LocationVerificationHistory).all()[0].serialize() - ) - self.assertEqual(location_verifiction_history["user_id"], user_id) - saved_location = json.loads(location_verifiction_history["location_after_change"]) - saved_road_segment_id = saved_location["road_segment_id"] - saved_road_num = saved_location["road1"] - self.assertEqual(saved_road_segment_id, values["road_segment_id"]) - self.assertEqual(saved_road_num, float(values["road1"])) - self.assertEqual( - values["newsflash_location_qualification"], - location_verifiction_history["location_verification_after_change"], - ) + with patch("anyway.app_and_db.db", db_mock): + mock_request = unittest.mock.MagicMock() + values = {"newsflash_location_qualification": "manual", "road_segment_id": 100, "road1": 1} + mock_request.values.get = lambda key: values.get(key) + with patch("anyway.views.news_flash.api.request", mock_request): + id = self.session.query(NewsFlash).all()[0].id + return_value = update_news_flash_qualifying(id) + self.assertEqual(return_value.status_code, HTTPStatus.OK.value) + location_verifiction_history = ( + self.session.query(LocationVerificationHistory).all()[0].serialize() + ) + self.assertEqual(location_verifiction_history["user_id"], user_id) + saved_location = json.loads(location_verifiction_history["location_after_change"]) + saved_road_segment_id = saved_location["road_segment_id"] + saved_road_num = saved_location["road1"] + self.assertEqual(saved_road_segment_id, values["road_segment_id"]) + self.assertEqual(saved_road_num, float(values["road1"])) + self.assertEqual( + values["newsflash_location_qualification"], + location_verifiction_history["location_verification_after_change"], + ) @patch("flask_principal.Permission.can", return_value=True) @patch("flask_login.utils._get_user") @@ -123,13 +132,24 @@ def _test_update_news_flash_qualifying_manual_with_location(self): """ the test tries to change manually the road_segment_name of a news flash. """ + db_mock = unittest.mock.MagicMock() + db_mock.session = self.session mock_request = unittest.mock.MagicMock() - values = {"newsflash_location_qualification": "manual", "road_segment_id": 100, "road1": "1"} + values = {"newsflash_location_qualification": "manual", "road_segment_id": 100, "road1": 1} mock_request.values.get = lambda key: values.get(key) - with patch("anyway.views.news_flash.api.request", mock_request): - id = self.session.query(NewsFlash).all()[0].id - return_value = update_news_flash_qualifying(id) - self.assertEqual(return_value.status_code, HTTPStatus.OK.value, "1") + road_segment = RoadSegments( + segment_id=100, + road=1, + from_name='from_name', + to_name='to_name', + ) + db_mock.session.add(road_segment) + db_mock.session.commit() + with patch("anyway.app_and_db.db", db_mock): + with patch("anyway.views.news_flash.api.request", mock_request): + id = self.session.query(NewsFlash).all()[0].id + return_value = update_news_flash_qualifying(id) + self.assertEqual(return_value.status_code, HTTPStatus.OK.value, "1") def _test_update_news_flash_qualifying_manual_without_location(self): """ @@ -150,7 +170,7 @@ def _test_update_news_flash_qualifying_not_manual_with_location(self): also a new location """ mock_request = unittest.mock.MagicMock() - values = {"newsflash_location_qualification": "rejected", "road_segment_name": "road", "road1": "1"} + values = {"newsflash_location_qualification": "rejected", "road_segment_name": "road", "road1": 1} mock_request.values.get = lambda key: values.get(key) with patch("anyway.views.news_flash.api.request", mock_request): id = self.session.query(NewsFlash).all()[0].id