From 43575e761508719a30239ab0e918a834f0ec33e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=BB=BB=E7=A5=89=E6=B6=B5?= <52252114+Renzhihan@users.noreply.github.com> Date: Mon, 19 Jun 2023 10:26:42 +0800 Subject: [PATCH] [Feature] Support iSAID dataset (#10028) Co-authored-by: huanghaian --- configs/_base_/datasets/isaid_instance.py | 59 +++++++++++++ mmdet/datasets/__init__.py | 3 +- mmdet/datasets/isaid.py | 25 ++++++ projects/iSAID/README.md | 85 +++++++++++++++++++ projects/iSAID/README_zh-CN.md | 85 +++++++++++++++++++ .../configs/mask_rcnn_r50_fpn_1x_isaid.py | 6 ++ projects/iSAID/isaid_json.py | 29 +++++++ 7 files changed, 291 insertions(+), 1 deletion(-) create mode 100644 configs/_base_/datasets/isaid_instance.py create mode 100644 mmdet/datasets/isaid.py create mode 100644 projects/iSAID/README.md create mode 100644 projects/iSAID/README_zh-CN.md create mode 100644 projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py create mode 100644 projects/iSAID/isaid_json.py diff --git a/configs/_base_/datasets/isaid_instance.py b/configs/_base_/datasets/isaid_instance.py new file mode 100644 index 00000000000..09ddcab02bd --- /dev/null +++ b/configs/_base_/datasets/isaid_instance.py @@ -0,0 +1,59 @@ +# dataset settings +dataset_type = 'iSAIDDataset' +data_root = 'data/iSAID/' +backend_args = None + +# Please see `projects/iSAID/README.md` for data preparation +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict(type='Resize', scale=(800, 800), keep_ratio=True), + dict(type='RandomFlip', prob=0.5), + dict(type='PackDetInputs') +] +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args=backend_args), + dict(type='Resize', scale=(800, 800), keep_ratio=True), + dict(type='LoadAnnotations', with_bbox=True, with_mask=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] +train_dataloader = dict( + batch_size=2, + num_workers=2, + persistent_workers=True, + sampler=dict(type='DefaultSampler', shuffle=True), + batch_sampler=dict(type='AspectRatioBatchSampler'), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='train/instancesonly_filtered_train.json', + data_prefix=dict(img='train/images/'), + filter_cfg=dict(filter_empty_gt=True, min_size=32), + pipeline=train_pipeline, + backend_args=backend_args)) +val_dataloader = dict( + batch_size=1, + num_workers=2, + persistent_workers=True, + drop_last=False, + sampler=dict(type='DefaultSampler', shuffle=False), + dataset=dict( + type=dataset_type, + data_root=data_root, + ann_file='val/instancesonly_filtered_val.json', + data_prefix=dict(img='val/images/'), + test_mode=True, + pipeline=test_pipeline, + backend_args=backend_args)) +test_dataloader = val_dataloader + +val_evaluator = dict( + type='CocoMetric', + ann_file=data_root + 'val/instancesonly_filtered_val.json', + metric=['bbox', 'segm'], + format_only=False, + backend_args=backend_args) +test_evaluator = val_evaluator diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py index 303ea81a32b..3bc16f9636a 100644 --- a/mmdet/datasets/__init__.py +++ b/mmdet/datasets/__init__.py @@ -13,6 +13,7 @@ from .dataset_wrappers import MultiImageMixDataset from .deepfashion import DeepFashionDataset from .dsdl import DSDLDetDataset +from .isaid import iSAIDDataset from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset from .mot_challenge_dataset import MOTChallengeDataset from .objects365 import Objects365V1Dataset, Objects365V2Dataset @@ -40,5 +41,5 @@ 'ReIDDataset', 'YouTubeVISDataset', 'TrackAspectRatioBatchSampler', 'ADE20KPanopticDataset', 'CocoCaptionDataset', 'RefCocoDataset', 'BaseSegDataset', 'ADE20KSegDataset', 'CocoSegDataset', - 'ADE20KInstanceDataset' + 'ADE20KInstanceDataset', 'iSAIDDataset' ] diff --git a/mmdet/datasets/isaid.py b/mmdet/datasets/isaid.py new file mode 100644 index 00000000000..87067d8459c --- /dev/null +++ b/mmdet/datasets/isaid.py @@ -0,0 +1,25 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from mmdet.registry import DATASETS +from .coco import CocoDataset + + +@DATASETS.register_module() +class iSAIDDataset(CocoDataset): + """Dataset for iSAID instance segmentation. + + iSAID: A Large-scale Dataset for Instance Segmentation + in Aerial Images. + + For more detail, please refer to "projects/iSAID/README.md" + """ + + METAINFO = dict( + classes=('background', 'ship', 'store_tank', 'baseball_diamond', + 'tennis_court', 'basketball_court', 'Ground_Track_Field', + 'Bridge', 'Large_Vehicle', 'Small_Vehicle', 'Helicopter', + 'Swimming_pool', 'Roundabout', 'Soccer_ball_field', 'plane', + 'Harbor'), + palette=[(0, 0, 0), (0, 0, 63), (0, 63, 63), (0, 63, 0), (0, 63, 127), + (0, 63, 191), (0, 63, 255), (0, 127, 63), (0, 127, 127), + (0, 0, 127), (0, 0, 191), (0, 0, 255), (0, 191, 127), + (0, 127, 191), (0, 127, 255), (0, 100, 155)]) diff --git a/projects/iSAID/README.md b/projects/iSAID/README.md new file mode 100644 index 00000000000..80505e46299 --- /dev/null +++ b/projects/iSAID/README.md @@ -0,0 +1,85 @@ +# iSAID Dataset + +> **iSAID**: A Large-scale Dataset for Instance Segmentation in Aerial Images + +## Introduction + +Existing Earth Vision datasets are either suitable for semantic segmentation or object detection. iSAID is the first benchmark dataset for instance segmentation in aerial images. This large-scale and densely annotated dataset contains 655,451 object instances for 15 categories across 2,806 high-resolution images. The distinctive characteristics of iSAID are the following: (a) large number of images with high spatial resolution, (b) fifteen important and commonly occurring categories, (c) large number of instances per category, (d) large count of labelled instances per image, which might help in learning contextual information, (e) huge object scale variation, containing small, medium and large objects, often within the same image, (f) Imbalanced and uneven distribution of objects with varying orientation within images, depicting real-life aerial conditions, (g) several small size objects, with ambiguous appearance, can only be resolved with contextual reasoning, (h) precise instance-level annotations carried out by professional annotators, cross-checked and validated by expert annotators complying with well-defined guidelines. + +For more detail, please refer to our [paper](http://openaccess.thecvf.com/content_CVPRW_2019/papers/DOAI/Zamir_iSAID_A_Large-scale_Dataset_for_Instance_Segmentation_in_Aerial_Images_CVPRW_2019_paper.pdf) . + +## Prepare + +iSAID download link:[Image](https://captain-whu.github.io/DOTA/dataset.html)、[Annotation](https://captain-whu.github.io/iSAID/dataset.html) +Please follow the steps as described in the [official repository](https://github.com/CAPTAIN-WHU/iSAID_Devkit) to preprocess the data (`patch_width`=800,`patch_height`=800,`overlap_area`=200). The final folder format should be as follows. + +``` +iSAID_patches +├── test +│ └── images +│ ├── P0006_0_0_800_800.png +│ └── ... +│ └── P0009_0_0_800_800.png +├── train +│ └── instance_only_filtered_train.json +│ └── images +│ ├── P0002_0_0_800_800_instance_color_RGB.png +│ ├── P0002_0_0_800_800_instance_id_RGB.png +│ ├── P0002_0_800_800.png +│ ├── ... +│ ├── P0010_0_0_800_800_instance_color_RGB.png +│ ├── P0010_0_0_800_800_instance_id_RGB.png +│ └── P0010_0_800_800.png +└── val + └── instance_only_filtered_val.json + └── images + ├── P0003_0_0_800_800_instance_color_RGB.png + ├── P0003_0_0_800_800_instance_id_RGB.png + ├── P0003_0_0_800_800.png + ├── ... + ├── P0004_0_0_800_800_instance_color_RGB.png + ├── P0004_0_0_800_800_instance_id_RGB.png + └── P0004_0_0_800_800.png +``` + +After that, use the following command in the mmdetection directory to convert the json file format. + +``` +python projects/iSAID/isaid_json.py /path/to/iSAID +``` + +## Usage + +### Train + +```python +python tools/train.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py +``` + +### Test + +```python +python tools/test.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py ${CHECKPOINT_PATH} +``` + +## Citation + +``` +@inproceedings{waqas2019isaid, +title={iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images}, +author={Waqas Zamir, Syed and Arora, Aditya and Gupta, Akshita and Khan, Salman and Sun, Guolei and Shahbaz Khan, Fahad and Zhu, Fan and Shao, Ling and Xia, Gui-Song and Bai, Xiang}, +booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops}, +pages={28--37}, +year={2019} +} +``` + +``` +@InProceedings{Xia_2018_CVPR, +author = {Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei}, +title = {DOTA: A Large-Scale Dataset for Object Detection in Aerial Images}, +booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, +month = {June}, +year = {2018} +} +``` diff --git a/projects/iSAID/README_zh-CN.md b/projects/iSAID/README_zh-CN.md new file mode 100644 index 00000000000..3481cae3d7b --- /dev/null +++ b/projects/iSAID/README_zh-CN.md @@ -0,0 +1,85 @@ +# iSAID数据集 + +> **iSAID**: A Large-scale Dataset for Instance Segmentation in Aerial Images + +## 数据集介绍 + +Existing Earth Vision datasets are either suitable for semantic segmentation or object detection. iSAID is the first benchmark dataset for instance segmentation in aerial images. This large-scale and densely annotated dataset contains 655,451 object instances for 15 categories across 2,806 high-resolution images. The distinctive characteristics of iSAID are the following: (a) large number of images with high spatial resolution, (b) fifteen important and commonly occurring categories, (c) large number of instances per category, (d) large count of labelled instances per image, which might help in learning contextual information, (e) huge object scale variation, containing small, medium and large objects, often within the same image, (f) Imbalanced and uneven distribution of objects with varying orientation within images, depicting real-life aerial conditions, (g) several small size objects, with ambiguous appearance, can only be resolved with contextual reasoning, (h) precise instance-level annotations carried out by professional annotators, cross-checked and validated by expert annotators complying with well-defined guidelines. + +For more detail, please refer to our [paper](http://openaccess.thecvf.com/content_CVPRW_2019/papers/DOAI/Zamir_iSAID_A_Large-scale_Dataset_for_Instance_Segmentation_in_Aerial_Images_CVPRW_2019_paper.pdf) . + +## 数据集准备 + +iSAID数据集下载链接:[图像数据](https://captain-whu.github.io/DOTA/dataset.html)、[标注数据](https://captain-whu.github.io/iSAID/dataset.html) +请按照[官方仓库](https://github.com/CAPTAIN-WHU/iSAID_Devkit)中所述步骤进行数据预处理(`patch_width`=800,`patch_height`=800,`overlap_area`=200),最终得到的文件夹格式为 + +``` +iSAID_patches +├── test +│ └── images +│ ├── P0006_0_0_800_800.png +│ └── ... +│ └── P0009_0_0_800_800.png +├── train +│ └── instance_only_filtered_train.json +│ └── images +│ ├── P0002_0_0_800_800_instance_color_RGB.png +│ ├── P0002_0_0_800_800_instance_id_RGB.png +│ ├── P0002_0_800_800.png +│ ├── ... +│ ├── P0010_0_0_800_800_instance_color_RGB.png +│ ├── P0010_0_0_800_800_instance_id_RGB.png +│ └── P0010_0_800_800.png +└── val + └── instance_only_filtered_val.json + └── images + ├── P0003_0_0_800_800_instance_color_RGB.png + ├── P0003_0_0_800_800_instance_id_RGB.png + ├── P0003_0_0_800_800.png + ├── ... + ├── P0004_0_0_800_800_instance_color_RGB.png + ├── P0004_0_0_800_800_instance_id_RGB.png + └── P0004_0_0_800_800.png +``` + +之后,在mmdetection目录下使用以下命令转换json文件格式 + +``` +python projects/iSAID/isaid_json.py /path/to/iSAID +``` + +## 使用方法 + +### 训练 + +```python +python tools/train.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py +``` + +### 测试 + +```python +python tools/test.py projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py ${CHECKPOINT_PATH} +``` + +## Citation + +``` +@inproceedings{waqas2019isaid, +title={iSAID: A Large-scale Dataset for Instance Segmentation in Aerial Images}, +author={Waqas Zamir, Syed and Arora, Aditya and Gupta, Akshita and Khan, Salman and Sun, Guolei and Shahbaz Khan, Fahad and Zhu, Fan and Shao, Ling and Xia, Gui-Song and Bai, Xiang}, +booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition Workshops}, +pages={28--37}, +year={2019} +} +``` + +``` +@InProceedings{Xia_2018_CVPR, +author = {Xia, Gui-Song and Bai, Xiang and Ding, Jian and Zhu, Zhen and Belongie, Serge and Luo, Jiebo and Datcu, Mihai and Pelillo, Marcello and Zhang, Liangpei}, +title = {DOTA: A Large-Scale Dataset for Object Detection in Aerial Images}, +booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, +month = {June}, +year = {2018} +} +``` diff --git a/projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py b/projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py new file mode 100644 index 00000000000..ee1cb27e4e2 --- /dev/null +++ b/projects/iSAID/configs/mask_rcnn_r50_fpn_1x_isaid.py @@ -0,0 +1,6 @@ +_base_ = [ + '../../../configs/_base_/models/mask-rcnn_r50_fpn.py', + '../../../configs/_base_/datasets/isaid_instance.py', + '../../../configs/_base_/schedules/schedule_1x.py', + '../../../configs/_base_/default_runtime.py' +] diff --git a/projects/iSAID/isaid_json.py b/projects/iSAID/isaid_json.py new file mode 100644 index 00000000000..95b8f089b04 --- /dev/null +++ b/projects/iSAID/isaid_json.py @@ -0,0 +1,29 @@ +import argparse +import json +import os.path as osp + + +def json_convert(path): + with open(path, 'r+') as f: + coco_data = json.load(f) + coco_data['categories'].append({'id': 0, 'name': 'background'}) + coco_data['categories'] = sorted( + coco_data['categories'], key=lambda x: x['id']) + f.seek(0) + json.dump(coco_data, f) + f.truncate() + + +if __name__ == '__main__': + parser = argparse.ArgumentParser( + description='Convert iSAID dataset to mmdetection format') + parser.add_argument('dataset_path', help='iSAID folder path') + + args = parser.parse_args() + dataset_path = args.dataset_path + json_list = ['train', 'val'] + for dataset_mode in ['train', 'val']: + json_file = 'instancesonly_filtered_' + dataset_mode + '.json' + json_file_path = osp.join(dataset_path, dataset_mode, json_file) + assert osp.exists(json_file_path), f'train is not in {dataset_path}' + json_convert(json_file_path)