Skip to content

Commit

Permalink
[Feature] Support CopyPaste when mask is not available (#10509)
Browse files Browse the repository at this point in the history
  • Loading branch information
nijkah authored Jun 19, 2023
1 parent 43575e7 commit f5228ff
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
33 changes: 30 additions & 3 deletions mmdet/datasets/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import copy
import inspect
import math
import warnings
from typing import List, Optional, Sequence, Tuple, Union

import cv2
Expand Down Expand Up @@ -3008,6 +3009,9 @@ class CopyPaste(BaseTransform):
all objects of the source image will be pasted to the
destination image.
Defaults to True.
paste_by_box (bool): Whether use boxes as masks when masks are not
available.
Defaults to False.
"""

def __init__(
Expand All @@ -3016,11 +3020,13 @@ def __init__(
bbox_occluded_thr: int = 10,
mask_occluded_thr: int = 300,
selected: bool = True,
paste_by_box: bool = False,
) -> None:
self.max_num_pasted = max_num_pasted
self.bbox_occluded_thr = bbox_occluded_thr
self.mask_occluded_thr = mask_occluded_thr
self.selected = selected
self.paste_by_box = paste_by_box

@cache_randomness
def get_indexes(self, dataset: BaseDataset) -> int:
Expand Down Expand Up @@ -3059,11 +3065,31 @@ def _get_selected_inds(self, num_bboxes: int) -> np.ndarray:
num_pasted = np.random.randint(0, max_num_pasted)
return np.random.choice(num_bboxes, size=num_pasted, replace=False)

def get_gt_masks(self, results: dict) -> BitmapMasks:
"""Get gt_masks originally or generated based on bboxes.
If gt_masks is not contained in results,
it will be generated based on gt_bboxes.
Args:
results (dict): Result dict.
Returns:
BitmapMasks: gt_masks, originally or generated based on bboxes.
"""
if results.get('gt_masks', None) is not None:
if self.paste_by_box:
warnings.warn('gt_masks is already contained in results, '
'so paste_by_box is disabled.')
return results['gt_masks']
else:
if not self.paste_by_box:
raise RuntimeError('results does not contain masks.')
return results['gt_bboxes'].create_masks(results['img'].shape[:2])

def _select_object(self, results: dict) -> dict:
"""Select some objects from the source results."""
bboxes = results['gt_bboxes']
labels = results['gt_bboxes_labels']
masks = results['gt_masks']
masks = self.get_gt_masks(results)
ignore_flags = results['gt_ignore_flags']

selected_inds = self._get_selected_inds(bboxes.shape[0])
Expand Down Expand Up @@ -3091,7 +3117,7 @@ def _copy_paste(self, dst_results: dict, src_results: dict) -> dict:
dst_img = dst_results['img']
dst_bboxes = dst_results['gt_bboxes']
dst_labels = dst_results['gt_bboxes_labels']
dst_masks = dst_results['gt_masks']
dst_masks = self.get_gt_masks(dst_results)
dst_ignore_flags = dst_results['gt_ignore_flags']

src_img = src_results['img']
Expand Down Expand Up @@ -3149,7 +3175,8 @@ def __repr__(self):
repr_str += f'(max_num_pasted={self.max_num_pasted}, '
repr_str += f'bbox_occluded_thr={self.bbox_occluded_thr}, '
repr_str += f'mask_occluded_thr={self.mask_occluded_thr}, '
repr_str += f'selected={self.selected})'
repr_str += f'selected={self.selected}), '
repr_str += f'paste_by_box={self.paste_by_box})'
return repr_str


Expand Down
20 changes: 20 additions & 0 deletions mmdet/structures/bbox/horizontal_boxes.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,26 @@ def find_inside_points(self,
return (points[..., 0] >= x_min) & (points[..., 0] <= x_max) & \
(points[..., 1] >= y_min) & (points[..., 1] <= y_max)

def create_masks(self, img_shape: Tuple[int, int]) -> BitmapMasks:
"""
Args:
img_shape (Tuple[int, int]): A tuple of image height and width.
Returns:
:obj:`BitmapMasks`: Converted masks
"""
img_h, img_w = img_shape
boxes = self.tensor

xmin, ymin = boxes[:, 0:1], boxes[:, 1:2]
xmax, ymax = boxes[:, 2:3], boxes[:, 3:4]
gt_masks = np.zeros((len(boxes), img_h, img_w), dtype=np.uint8)
for i in range(len(boxes)):
gt_masks[i,
int(ymin[i]):int(ymax[i]),
int(xmin[i]):int(xmax[i])] = 1
return BitmapMasks(gt_masks, img_h, img_w)

@staticmethod
def overlaps(boxes1: BaseBoxes,
boxes2: BaseBoxes,
Expand Down
23 changes: 22 additions & 1 deletion tests/test_datasets/test_transforms/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -1444,6 +1444,26 @@ def test_transform(self):
}]
results = transform(results)

# test copypaste with an empty mask results
transform = CopyPaste()
results = copy.deepcopy(self.dst_results)
results = {k: v for k, v in results.items() if 'mask' not in k}
results['mix_results'] = [copy.deepcopy(self.src_results)]
with self.assertRaises(RuntimeError):
results = transform(results)

# test copypaste with boxes as masks
transform = CopyPaste(paste_by_box=True)
results = copy.deepcopy(self.dst_results)
results = {k: v for k, v in results.items() if 'mask' not in k}
src_results = copy.deepcopy(self.src_results)
src_results = {k: v for k, v in src_results.items() if 'mask' not in k}
results['mix_results'] = [src_results]
results = transform(results)

self.assertEqual(results['img'].shape[:2],
self.dst_results['img'].shape[:2])

def test_transform_use_box_type(self):
src_results = copy.deepcopy(self.src_results)
src_results['gt_bboxes'] = HorizontalBoxes(src_results['gt_bboxes'])
Expand Down Expand Up @@ -1515,7 +1535,8 @@ def test_repr(self):
repr(transform), ('CopyPaste(max_num_pasted=100, '
'bbox_occluded_thr=10, '
'mask_occluded_thr=300, '
'selected=True)'))
'selected=True), '
'paste_by_box=False)'))


class TestAlbu(unittest.TestCase):
Expand Down

0 comments on commit f5228ff

Please sign in to comment.