diff --git a/mmdet/datasets/transforms/transforms.py b/mmdet/datasets/transforms/transforms.py
index a03b90be135..9d1c1ed71ab 100644
--- a/mmdet/datasets/transforms/transforms.py
+++ b/mmdet/datasets/transforms/transforms.py
@@ -2,6 +2,7 @@
 import copy
 import inspect
 import math
+import warnings
 from typing import List, Optional, Sequence, Tuple, Union
 
 import cv2
@@ -3008,6 +3009,9 @@ class CopyPaste(BaseTransform):
             all objects of the source image will be pasted to the
             destination image.
             Defaults to True.
+        paste_by_box (bool): Whether to use boxes as masks when masks are not
+            available.
+            Defaults to False.
     """
 
     def __init__(
@@ -3016,11 +3020,13 @@ def __init__(
         bbox_occluded_thr: int = 10,
         mask_occluded_thr: int = 300,
         selected: bool = True,
+        paste_by_box: bool = False,
     ) -> None:
         self.max_num_pasted = max_num_pasted
         self.bbox_occluded_thr = bbox_occluded_thr
         self.mask_occluded_thr = mask_occluded_thr
         self.selected = selected
+        self.paste_by_box = paste_by_box
 
     @cache_randomness
     def get_indexes(self, dataset: BaseDataset) -> int:
@@ -3059,11 +3065,31 @@ def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
         num_pasted = np.random.randint(0, max_num_pasted)
         return np.random.choice(num_bboxes, size=num_pasted, replace=False)
 
+    def get_gt_masks(self, results: dict) -> BitmapMasks:
+        """Get gt_masks originally or generated based on bboxes.
+
+        If gt_masks is not contained in results,
+        it will be generated based on gt_bboxes.
+        Args:
+            results (dict): Result dict.
+        Returns:
+            BitmapMasks: gt_masks, originally or generated based on bboxes.
+        """
+        if results.get('gt_masks', None) is not None:
+            if self.paste_by_box:
+                warnings.warn('gt_masks is already contained in results, '
+                              'so paste_by_box is disabled.')
+            return results['gt_masks']
+        else:
+            if not self.paste_by_box:
+                raise RuntimeError('results does not contain masks.')
+            return results['gt_bboxes'].create_masks(results['img'].shape[:2])
+
     def _select_object(self, results: dict) -> dict:
         """Select some objects from the source results."""
         bboxes = results['gt_bboxes']
         labels = results['gt_bboxes_labels']
-        masks = results['gt_masks']
+        masks = self.get_gt_masks(results)
         ignore_flags = results['gt_ignore_flags']
 
         selected_inds = self._get_selected_inds(bboxes.shape[0])
@@ -3091,7 +3117,7 @@ def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
         dst_img = dst_results['img']
         dst_bboxes = dst_results['gt_bboxes']
         dst_labels = dst_results['gt_bboxes_labels']
-        dst_masks = dst_results['gt_masks']
+        dst_masks = self.get_gt_masks(dst_results)
         dst_ignore_flags = dst_results['gt_ignore_flags']
 
         src_img = src_results['img']
@@ -3149,7 +3175,8 @@ def __repr__(self):
         repr_str += f'(max_num_pasted={self.max_num_pasted}, '
         repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
         repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
-        repr_str += f'selected={self.selected})'
+        repr_str += f'selected={self.selected}, '
+        repr_str += f'paste_by_box={self.paste_by_box})'
         return repr_str
 
 
diff --git a/mmdet/structures/bbox/horizontal_boxes.py b/mmdet/structures/bbox/horizontal_boxes.py
index 360c8a24e0b..b3a78518105 100644
--- a/mmdet/structures/bbox/horizontal_boxes.py
+++ b/mmdet/structures/bbox/horizontal_boxes.py
@@ -335,6 +335,29 @@ def find_inside_points(self,
         return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
             (points[..., 1] >= y_min) & (points[..., 1] <= y_max)
 
+    def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
+        """Create bitmap masks whose foreground is the area of each box.
+
+        Args:
+            img_shape (Tuple[int, int]): A tuple of image height and width.
+
+        Returns:
+            :obj:`BitmapMasks`: Converted masks
+        """
+        img_h, img_w = img_shape
+        boxes = self.tensor
+
+        xmin, ymin = boxes[:, 0:1], boxes[:, 1:2]
+        xmax, ymax = boxes[:, 2:3], boxes[:, 3:4]
+        gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
+        for i in range(len(boxes)):
+            # Clamp to 0: a negative slice bound would be read as an offset
+            # from the end of the axis and mark the wrong pixels.
+            gt_masks[i,
+                     max(int(ymin[i]), 0):max(int(ymax[i]), 0),
+                     max(int(xmin[i]), 0):max(int(xmax[i]), 0)] = 1
+        return BitmapMasks(gt_masks, img_h, img_w)
+
     @staticmethod
     def overlaps(boxes1: BaseBoxes,
                  boxes2: BaseBoxes,
diff --git a/tests/test_datasets/test_transforms/test_transforms.py b/tests/test_datasets/test_transforms/test_transforms.py
index e36f518aa8b..134e5de8a7c 100644
--- a/tests/test_datasets/test_transforms/test_transforms.py
+++ b/tests/test_datasets/test_transforms/test_transforms.py
@@ -1444,6 +1444,26 @@ def test_transform(self):
         }]
         results = transform(results)
 
+        # test copypaste with an empty mask results
+        transform = CopyPaste()
+        results = copy.deepcopy(self.dst_results)
+        results = {k: v for k, v in results.items() if 'mask' not in k}
+        results['mix_results'] = [copy.deepcopy(self.src_results)]
+        with self.assertRaises(RuntimeError):
+            results = transform(results)
+
+        # test copypaste with boxes as masks
+        transform = CopyPaste(paste_by_box=True)
+        results = copy.deepcopy(self.dst_results)
+        results = {k: v for k, v in results.items() if 'mask' not in k}
+        src_results = copy.deepcopy(self.src_results)
+        src_results = {k: v for k, v in src_results.items() if 'mask' not in k}
+        results['mix_results'] = [src_results]
+        results = transform(results)
+
+        self.assertEqual(results['img'].shape[:2],
+                         self.dst_results['img'].shape[:2])
+
     def test_transform_use_box_type(self):
         src_results = copy.deepcopy(self.src_results)
         src_results['gt_bboxes'] = HorizontalBoxes(src_results['gt_bboxes'])
@@ -1515,7 +1535,8 @@ def test_repr(self):
             repr(transform), ('CopyPaste(max_num_pasted=100, '
                               'bbox_occluded_thr=10, '
                               'mask_occluded_thr=300, '
-                              'selected=True)'))
+                              'selected=True, '
+                              'paste_by_box=False)'))
 
 
 class TestAlbu(unittest.TestCase):