Skip to content

Commit

Permalink
[Feature] Support iSAID dataset (#10028)
Browse files Browse the repository at this point in the history
Co-authored-by: huanghaian <huanghaian@sensetime.com>
  • Loading branch information
Renzhihan and hhaAndroid authored Jun 19, 2023
1 parent 04d0b5e commit 43575e7
Show file tree
Hide file tree
Showing 7 changed files with 291 additions and 1 deletion.
59 changes: 59 additions & 0 deletions configs/_base_/datasets/isaid_instance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# Dataset settings for iSAID instance segmentation.
dataset_type = 'iSAIDDataset'
data_root = 'data/iSAID/'
backend_args = None

# Data preparation is described in `projects/iSAID/README.md`.

# Training pipeline: image and box/mask annotations are loaded first, then
# resized with scale (800, 800) and keep_ratio, randomly flipped, and packed.
train_pipeline = [
    {'type': 'LoadImageFromFile', 'backend_args': backend_args},
    {'type': 'LoadAnnotations', 'with_bbox': True, 'with_mask': True},
    {'type': 'Resize', 'scale': (800, 800), 'keep_ratio': True},
    {'type': 'RandomFlip', 'prob': 0.5},
    {'type': 'PackDetInputs'},
]

# Test pipeline: no flip; annotations are loaded after the resize step and
# only the listed meta keys are packed.
test_pipeline = [
    {'type': 'LoadImageFromFile', 'backend_args': backend_args},
    {'type': 'Resize', 'scale': (800, 800), 'keep_ratio': True},
    {'type': 'LoadAnnotations', 'with_bbox': True, 'with_mask': True},
    {
        'type': 'PackDetInputs',
        'meta_keys': (
            'img_id',
            'img_path',
            'ori_shape',
            'img_shape',
            'scale_factor',
        ),
    },
]

# Training loader: shuffled, two images per batch, reads the train split.
train_dataloader = {
    'batch_size': 2,
    'num_workers': 2,
    'persistent_workers': True,
    'sampler': {'type': 'DefaultSampler', 'shuffle': True},
    'batch_sampler': {'type': 'AspectRatioBatchSampler'},
    'dataset': {
        'type': dataset_type,
        'data_root': data_root,
        'ann_file': 'train/instancesonly_filtered_train.json',
        'data_prefix': {'img': 'train/images/'},
        'filter_cfg': {'filter_empty_gt': True, 'min_size': 32},
        'pipeline': train_pipeline,
        'backend_args': backend_args,
    },
}

# Validation loader: unshuffled, one image per batch, reads the val split.
val_dataloader = {
    'batch_size': 1,
    'num_workers': 2,
    'persistent_workers': True,
    'drop_last': False,
    'sampler': {'type': 'DefaultSampler', 'shuffle': False},
    'dataset': {
        'type': dataset_type,
        'data_root': data_root,
        'ann_file': 'val/instancesonly_filtered_val.json',
        'data_prefix': {'img': 'val/images/'},
        'test_mode': True,
        'pipeline': test_pipeline,
        'backend_args': backend_args,
    },
}
# Testing reuses the validation loader as-is.
test_dataloader = val_dataloader

# COCO-style bbox and segm evaluation against the val annotation file.
val_evaluator = {
    'type': 'CocoMetric',
    'ann_file': data_root + 'val/instancesonly_filtered_val.json',
    'metric': ['bbox', 'segm'],
    'format_only': False,
    'backend_args': backend_args,
}
# Testing reuses the validation evaluator as-is.
test_evaluator = val_evaluator
3 changes: 2 additions & 1 deletion mmdet/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .dataset_wrappers import MultiImageMixDataset
from .deepfashion import DeepFashionDataset
from .dsdl import DSDLDetDataset
from .isaid import iSAIDDataset
from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
from .mot_challenge_dataset import MOTChallengeDataset
from .objects365 import Objects365V1Dataset, Objects365V2Dataset
Expand Down Expand Up @@ -40,5 +41,5 @@
'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler',
'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset',
'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset',
'ADE20KInstanceDataset'
'ADE20KInstanceDataset', 'iSAIDDataset'
]
25 changes: 25 additions & 0 deletions mmdet/datasets/isaid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright (c) OpenMMLab. All rights reserved.
from mmdet.registry import DATASETS
from .coco import CocoDataset


@DATASETS.register_module()
class iSAIDDataset(CocoDataset):
    """COCO-style dataset class for iSAID instance segmentation.

    iSAID is "A Large-scale Dataset for Instance Segmentation in Aerial
    Images". All loading behaviour is inherited from ``CocoDataset``; this
    class only supplies the iSAID class names and colour palette.

    Data preparation steps are described in "projects/iSAID/README.md".
    """

    # 16 class names and 16 RGB tuples listed in the same order (presumably
    # one colour per class -- confirm against the visualizer). The first
    # entry is an explicit 'background' class.
    METAINFO = {
        'classes': (
            'background',
            'ship',
            'store_tank',
            'baseball_diamond',
            'tennis_court',
            'basketball_court',
            'Ground_Track_Field',
            'Bridge',
            'Large_Vehicle',
            'Small_Vehicle',
            'Helicopter',
            'Swimming_pool',
            'Roundabout',
            'Soccer_ball_field',
            'plane',
            'Harbor',
        ),
        'palette': [
            (0, 0, 0), (0, 0, 63),
            (0, 63, 63), (0, 63, 0),
            (0, 63, 127), (0, 63, 191),
            (0, 63, 255), (0, 127, 63),
            (0, 127, 127), (0, 0, 127),
            (0, 0, 191), (0, 0, 255),
            (0, 191, 127), (0, 127, 191),
            (0, 127, 255), (0, 100, 155),
        ],
    }
85 changes: 85 additions & 0 deletions projects/iSAID/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# iSAID Dataset

> **iSAID**: A Large-scale Dataset for Instance Segmentation in Aerial Images
## Introduction

Existing Earth Vision datasets are either suitable for semantic segmentation or object detection. iSAID is the first benchmark dataset for instance segmentation in aerial images. This large-scale and densely annotated dataset contains 655,451 object instances for 15 categories across 2,806 high-resolution images. The distinctive characteristics of iSAID are the following: (a) large number of images with high spatial resolution, (b) fifteen important and commonly occurring categories, (c) large number of instances per category, (d) large count of labelled instances per image, which might help in learning contextual information, (e) huge object scale variation, containing small, medium and large objects, often within the same image, (f) Imbalanced and uneven distribution of objects with varying orientation within images, depicting real-life aerial conditions, (g) several small size objects, with ambiguous appearance, can only be resolved with contextual reasoning, (h) precise instance-level annotations carried out by professional annotators, cross-checked and validated by expert annotators complying with well-defined guidelines.

For more detail, please refer to our [paper](http://openaccess.thecvf.com/content_CVPRW_2019/papers/DOAI/Zamir_iSAID_A_Large-scale_Dataset_for_Instance_Segmentation_in_Aerial_Images_CVPRW_2019_paper.pdf) .

## Prepare

iSAID download links: [Image](https://captain-whu.github.io/DOTA/dataset.html) and [Annotation](https://captain-whu.github.io/iSAID/dataset.html).
Please follow the steps as described in the [official repository](https://github.com/CAPTAIN-WHU/iSAID_Devkit) to preprocess the data (`patch_width`=800,`patch_height`=800,`overlap_area`=200). The final folder format should be as follows.

```
iSAID_patches
├── test
│   └── images
│       ├── P0006_0_0_800_800.png
│       ├── ...
│       └── P0009_0_0_800_800.png
├── train
│   ├── instancesonly_filtered_train.json
│   └── images
│       ├── P0002_0_0_800_800_instance_color_RGB.png
│       ├── P0002_0_0_800_800_instance_id_RGB.png
│       ├── P0002_0_0_800_800.png
│       ├── ...
│       ├── P0010_0_0_800_800_instance_color_RGB.png
│       ├── P0010_0_0_800_800_instance_id_RGB.png
│       └── P0010_0_0_800_800.png
└── val
    ├── instancesonly_filtered_val.json
    └── images
        ├── P0003_0_0_800_800_instance_color_RGB.png
        ├── P0003_0_0_800_800_instance_id_RGB.png
        ├── P0003_0_0_800_800.png
        ├── ...
        ├── P0004_0_0_800_800_instance_color_RGB.png
        ├── P0004_0_0_800_800_instance_id_RGB.png
        └── P0004_0_0_800_800.png
```

After that, use the following command in the mmdetection directory to convert the json file format.

```
python projects/iSAID/isaid_json.py /path/to/iSAID
```

## Usage

### Train

```shell
python tools/train.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py
```

### Test

```shell
python tools/test.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py ${CHECKPOINT_PATH}
```

## Citation

```
@inproceedings{waqas2019isaid,
title={iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images},
author={Waqas Zamir, Syed and Arora, Aditya and Gupta, Akshita and Khan, Salman and Sun, Guolei and Shahbaz Khan, Fahad and Zhu, Fan and Shao, Ling and Xia, Gui-Song and Bai, Xiang},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops},
pages={28--37},
year={2019}
}
```

```
@InProceedings{Xia_2018_CVPR,
author = {Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
title = {DOTA: A Large-Scale Dataset for Object Detection in Aerial Images},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2018}
}
```
85 changes: 85 additions & 0 deletions projects/iSAID/README_zh-CN.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# iSAID数据集

> **iSAID**: A Large-scale Dataset for Instance Segmentation in Aerial Images
## 数据集介绍

Existing Earth Vision datasets are either suitable for semantic segmentation or object detection. iSAID is the first benchmark dataset for instance segmentation in aerial images. This large-scale and densely annotated dataset contains 655,451 object instances for 15 categories across 2,806 high-resolution images. The distinctive characteristics of iSAID are the following: (a) large number of images with high spatial resolution, (b) fifteen important and commonly occurring categories, (c) large number of instances per category, (d) large count of labelled instances per image, which might help in learning contextual information, (e) huge object scale variation, containing small, medium and large objects, often within the same image, (f) Imbalanced and uneven distribution of objects with varying orientation within images, depicting real-life aerial conditions, (g) several small size objects, with ambiguous appearance, can only be resolved with contextual reasoning, (h) precise instance-level annotations carried out by professional annotators, cross-checked and validated by expert annotators complying with well-defined guidelines.

For more detail, please refer to our [paper](http://openaccess.thecvf.com/content_CVPRW_2019/papers/DOAI/Zamir_iSAID_A_Large-scale_Dataset_for_Instance_Segmentation_in_Aerial_Images_CVPRW_2019_paper.pdf) .

## 数据集准备

iSAID数据集下载链接:[图像数据](https://captain-whu.github.io/DOTA/dataset.html)、[标注数据](https://captain-whu.github.io/iSAID/dataset.html)。
请按照[官方仓库](https://github.com/CAPTAIN-WHU/iSAID_Devkit)中所述步骤进行数据预处理(`patch_width`=800,`patch_height`=800,`overlap_area`=200),最终得到的文件夹格式为

```
iSAID_patches
├── test
│   └── images
│       ├── P0006_0_0_800_800.png
│       ├── ...
│       └── P0009_0_0_800_800.png
├── train
│   ├── instancesonly_filtered_train.json
│   └── images
│       ├── P0002_0_0_800_800_instance_color_RGB.png
│       ├── P0002_0_0_800_800_instance_id_RGB.png
│       ├── P0002_0_0_800_800.png
│       ├── ...
│       ├── P0010_0_0_800_800_instance_color_RGB.png
│       ├── P0010_0_0_800_800_instance_id_RGB.png
│       └── P0010_0_0_800_800.png
└── val
    ├── instancesonly_filtered_val.json
    └── images
        ├── P0003_0_0_800_800_instance_color_RGB.png
        ├── P0003_0_0_800_800_instance_id_RGB.png
        ├── P0003_0_0_800_800.png
        ├── ...
        ├── P0004_0_0_800_800_instance_color_RGB.png
        ├── P0004_0_0_800_800_instance_id_RGB.png
        └── P0004_0_0_800_800.png
```

之后,在mmdetection目录下使用以下命令转换json文件格式

```
python projects/iSAID/isaid_json.py /path/to/iSAID
```

## 使用方法

### 训练

```shell
python tools/train.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py
```

### 测试

```shell
python tools/test.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py ${CHECKPOINT_PATH}
```

## Citation

```
@inproceedings{waqas2019isaid,
title={iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images},
author={Waqas Zamir, Syed and Arora, Aditya and Gupta, Akshita and Khan, Salman and Sun, Guolei and Shahbaz Khan, Fahad and Zhu, Fan and Shao, Ling and Xia, Gui-Song and Bai, Xiang},
booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops},
pages={28--37},
year={2019}
}
```

```
@InProceedings{Xia_2018_CVPR,
author = {Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei},
title = {DOTA: A Large-Scale Dataset for Object Detection in Aerial Images},
booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2018}
}
```
6 changes: 6 additions & 0 deletions projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Mask R-CNN (R-50, FPN) on iSAID with the 1x schedule. This config only
# composes the four base configs below and overrides nothing itself.
# NOTE(review): the model's number of classes is not overridden here --
# confirm the base model setting matches the 16 entries in
# iSAIDDataset.METAINFO['classes'].
_base_ = [
    '../../../configs/_base_/models/mask-rcnn_r50_fpn.py',
    '../../../configs/_base_/datasets/isaid_instance.py',
    '../../../configs/_base_/schedules/schedule_1x.py',
    '../../../configs/_base_/default_runtime.py'
]
29 changes: 29 additions & 0 deletions projects/iSAID/isaid_json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
import argparse
import json
import os.path as osp


def json_convert(path):
    """Add a ``background`` category (id 0) to a COCO-style JSON file.

    The file at ``path`` is rewritten in place: a
    ``{'id': 0, 'name': 'background'}`` entry is added to its
    ``categories`` list and the list is sorted by ``id``. All other
    top-level keys are written back unchanged.

    The conversion is idempotent. If a category with id 0 is already
    present (for example because the script was run before), no second
    entry is appended and the categories are only re-sorted.

    Args:
        path (str): Path of the annotation JSON file to convert.

    Raises:
        KeyError: If the JSON document has no ``categories`` key.
    """
    with open(path, 'r+') as f:
        coco_data = json.load(f)
        categories = coco_data['categories']
        # Guard against duplicates so re-running the script is harmless.
        if not any(category['id'] == 0 for category in categories):
            categories.append({'id': 0, 'name': 'background'})
        coco_data['categories'] = sorted(categories, key=lambda x: x['id'])
        # Rewind and overwrite; truncate drops any leftover bytes in case
        # the new content is shorter than the old one.
        f.seek(0)
        json.dump(coco_data, f)
        f.truncate()


if __name__ == '__main__':
    # Command-line entry point: add the background category to the train
    # and val annotation files found under the given iSAID folder.
    parser = argparse.ArgumentParser(
        description='Convert iSAID dataset to mmdetection format')
    parser.add_argument('dataset_path', help='iSAID folder path')

    args = parser.parse_args()
    dataset_path = args.dataset_path
    json_list = ['train', 'val']
    for dataset_mode in json_list:
        json_file = 'instancesonly_filtered_' + dataset_mode + '.json'
        json_file_path = osp.join(dataset_path, dataset_mode, json_file)
        # Raise explicitly instead of using ``assert`` so the check still
        # runs under ``python -O``, and name the split that is missing.
        if not osp.exists(json_file_path):
            raise FileNotFoundError(
                f'{dataset_mode} annotation file {json_file_path} '
                f'is not in {dataset_path}')
        json_convert(json_file_path)

0 comments on commit 43575e7

Please sign in to comment.