From d2b238e69bb786e45af40fa867767bf6d133633a Mon Sep 17 00:00:00 2001 From: takuoko Date: Fri, 22 Dec 2023 19:01:22 +0900 Subject: [PATCH] [Feature] Add RTMDet Swin / ConvNeXt (#11259) --- configs/rtmdet/README.md | 19 +-- configs/rtmdet/metafile.yml | 42 +++++++ .../rtmdet_l_convnext_b_4xb32-100e_coco.py | 81 +++++++++++++ .../rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py | 78 ++++++++++++ .../rtmdet_l_swin_b_p6_4xb16-100e_coco.py | 114 ++++++++++++++++++ 5 files changed, 326 insertions(+), 8 deletions(-) create mode 100644 configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py create mode 100644 configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py create mode 100644 configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md index 4574dd613c1..1677184af76 100644 --- a/configs/rtmdet/README.md +++ b/configs/rtmdet/README.md @@ -20,14 +20,17 @@ In this paper, we aim to design an efficient real-time object detector that exce ### Object Detection -| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms)
RTX3090 | TRT-FP16-Latency(ms)
T4 | Config | Download | -| :---------: | :--: | :----: | :-------: | :------: | :-----------------------------: | :------------------------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | 2.34 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) | -| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | 2.96 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) | -| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | 6.41 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) | -| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | 10.32 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) | -| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | 18.80 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) | -| RTMDet-x-P6 | 1280 | 54.9 | | | | | [config](./rtmdet_x_p6_4xb8-300e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth) | +| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms)
RTX3090 | TRT-FP16-Latency(ms)
T4 | Config | Download | +| :-----------------: | :--: | :----: | :-------: | :------: | :-----------------------------: | :------------------------: | :------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | 2.34 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) | +| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | 2.96 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) | +| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | 6.41 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) | +| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | 10.32 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) | +| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | 18.80 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) | +| RTMDet-x-P6 | 1280 | 54.9 | | | | | [config](./rtmdet_x_p6_4xb8-300e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth) | +| RTMDet-l-ConvNeXt-B | 640 | 53.1 | | | | | [config](./rtmdet_l_convnext_b_4xb32-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_convnext_b_4xb32-100e_coco-d4731b3d.pth) | +| RTMDet-l-Swin-B | 640 | 52.4 | | | | | [config](./rtmdet_l_swin_b_4xb32-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_4xb32-100e_coco-0828ce5d.pth) | +| RTMDet-l-Swin-B-P6 | 1280 | 56.4 | | | | | [config](./rtmdet_l_swin_b_p6_4xb16-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_p6_4xb16-100e_coco-a1486b6f.pth) | **Note**: diff --git a/configs/rtmdet/metafile.yml b/configs/rtmdet/metafile.yml index 7dc72e130be..a62abcb2faa 100644 --- a/configs/rtmdet/metafile.yml +++ b/configs/rtmdet/metafile.yml @@ -104,6 +104,48 @@ Models: box AP: 54.9 Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth + - Name: rtmdet_l_convnext_b_4xb32-100e_coco + Alias: + - rtmdet-l_convnext_b + In Collection: RTMDet + Config: configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py + Metadata: + Epochs: 100 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 53.1 + Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_convnext_b_4xb32-100e_coco-d4731b3d.pth + + - Name: rtmdet_l_swin_b_4xb32-100e_coco + Alias: + - rtmdet-l_swin_b + In Collection: RTMDet + Config: configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py + Metadata: + Epochs: 100 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 52.4 + Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_4xb32-100e_coco-0828ce5d.pth + + - Name: rtmdet_l_swin_b_p6_4xb16-100e_coco + Alias: + - rtmdet-l_swin_b_p6 + In Collection: RTMDet + Config: configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py + Metadata: + Epochs: 100 + Results: + - Task: Object Detection + Dataset: COCO + Metrics: + box AP: 56.4 + Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_p6_4xb16-100e_coco-a1486b6f.pth + - Name: rtmdet-ins_tiny_8xb32-300e_coco Alias: - rtmdet-ins-t diff --git a/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py b/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py new file mode 100644 index 00000000000..85af292bcab --- /dev/null +++ b/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py @@ -0,0 +1,81 @@ +_base_ = './rtmdet_l_8xb32-300e_coco.py' + +custom_imports = dict( + imports=['mmpretrain.models'], allow_failed_imports=False) + +norm_cfg = dict(type='GN', num_groups=32) +checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_in1k-384px_20221219-4570f792.pth' # noqa +model = dict( + type='RTMDet', + data_preprocessor=dict( + _delete_=True, + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + batch_augments=None), + backbone=dict( + _delete_=True, + type='mmpretrain.ConvNeXt', + arch='base', + out_indices=[1, 2, 3], + drop_path_rate=0.7, + layer_scale_init_value=1.0, + gap_before_final_norm=False, + with_cp=True, + init_cfg=dict( + type='Pretrained', checkpoint=checkpoint_file, + prefix='backbone.')), + neck=dict(in_channels=[256, 512, 1024], norm_cfg=norm_cfg), + bbox_head=dict(norm_cfg=norm_cfg)) + +max_epochs = 100 +stage2_num_epochs = 10 +interval = 10 +base_lr = 0.001 + +train_cfg = dict( + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)]) + +optim_wrapper = dict( + constructor='LearningRateDecayOptimizerConstructor', + paramwise_cfg={ + 'decay_rate': 0.8, + 'decay_type': 'layer_wise', + 'num_layers': 12 + }, + optimizer=dict(lr=base_lr)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 50 to 100 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline={{_base_.train_pipeline_stage2}}) +] diff --git a/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py b/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py new file mode 100644 index 00000000000..84b0e0fa7d1 --- /dev/null +++ b/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py @@ -0,0 +1,78 @@ +_base_ = './rtmdet_l_8xb32-300e_coco.py' + +norm_cfg = dict(type='GN', num_groups=32) +checkpoint = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth' # noqa +model = dict( + type='RTMDet', + data_preprocessor=dict( + _delete_=True, + type='DetDataPreprocessor', + mean=[123.675, 116.28, 103.53], + std=[58.395, 57.12, 57.375], + bgr_to_rgb=True, + batch_augments=None), + backbone=dict( + _delete_=True, + type='SwinTransformer', + pretrain_img_size=384, + embed_dims=128, + depths=[2, 2, 18, 2], + num_heads=[4, 8, 16, 32], + window_size=12, + mlp_ratio=4, + qkv_bias=True, + qk_scale=None, + drop_rate=0., + attn_drop_rate=0., + drop_path_rate=0.3, + patch_norm=True, + out_indices=(1, 2, 3), + with_cp=True, + convert_weights=True, + init_cfg=dict(type='Pretrained', checkpoint=checkpoint)), + neck=dict(in_channels=[256, 512, 1024], norm_cfg=norm_cfg), + bbox_head=dict(norm_cfg=norm_cfg)) + +max_epochs = 100 +stage2_num_epochs = 10 +interval = 10 +base_lr = 0.001 + +train_cfg = dict( + max_epochs=max_epochs, + val_interval=interval, + dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)]) + +optim_wrapper = dict(optimizer=dict(lr=base_lr)) + +# learning rate +param_scheduler = [ + dict( + type='LinearLR', + start_factor=1.0e-5, + by_epoch=False, + begin=0, + end=1000), + dict( + # use cosine lr from 50 to 100 epoch + type='CosineAnnealingLR', + eta_min=base_lr * 0.05, + begin=max_epochs // 2, + end=max_epochs, + T_max=max_epochs // 2, + by_epoch=True, + convert_to_iter_based=True), +] + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline={{_base_.train_pipeline_stage2}}) +] diff --git a/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py b/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py new file mode 100644 index 00000000000..37d4215c3f0 --- /dev/null +++ b/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py @@ -0,0 +1,114 @@ +_base_ = './rtmdet_l_swin_b_4xb32-100e_coco.py' + +model = dict( + backbone=dict( + depths=[2, 2, 18, 2, 1], + num_heads=[4, 8, 16, 32, 64], + strides=(4, 2, 2, 2, 2), + out_indices=(1, 2, 3, 4)), + neck=dict(in_channels=[256, 512, 1024, 2048]), + bbox_head=dict( + anchor_generator=dict( + type='MlvlPointGenerator', offset=0, strides=[8, 16, 32, 64]))) + +train_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict(type='CachedMosaic', img_scale=(1280, 1280), pad_val=114.0), + dict( + type='RandomResize', + scale=(2560, 2560), + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(1280, 1280)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))), + dict( + type='CachedMixUp', + img_scale=(1280, 1280), + ratio_range=(1.0, 1.0), + max_cached_images=20, + pad_val=(114, 114, 114)), + dict(type='PackDetInputs') +] + +train_pipeline_stage2 = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='RandomResize', + scale=(1280, 1280), + ratio_range=(0.1, 2.0), + keep_ratio=True), + dict(type='RandomCrop', crop_size=(1280, 1280)), + dict(type='YOLOXHSVRandomAug'), + dict(type='RandomFlip', prob=0.5), + dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))), + dict(type='PackDetInputs') +] + +test_pipeline = [ + dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}), + dict(type='Resize', scale=(1280, 1280), keep_ratio=True), + dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))), + dict(type='LoadAnnotations', with_bbox=True), + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor')) +] + +train_dataloader = dict( + batch_size=16, num_workers=20, dataset=dict(pipeline=train_pipeline)) +val_dataloader = dict(num_workers=20, dataset=dict(pipeline=test_pipeline)) +test_dataloader = val_dataloader + +max_epochs = 100 +stage2_num_epochs = 10 + +custom_hooks = [ + dict( + type='EMAHook', + ema_type='ExpMomentumEMA', + momentum=0.0002, + update_buffers=True, + priority=49), + dict( + type='PipelineSwitchHook', + switch_epoch=max_epochs - stage2_num_epochs, + switch_pipeline=train_pipeline_stage2) +] + +img_scales = [(1280, 1280), (640, 640), (1920, 1920)] +tta_pipeline = [ + dict(type='LoadImageFromFile', backend_args=None), + dict( + type='TestTimeAug', + transforms=[ + [ + dict(type='Resize', scale=s, keep_ratio=True) + for s in img_scales + ], + [ + # ``RandomFlip`` must be placed before ``Pad``, otherwise + # bounding box coordinates after flipping cannot be + # recovered correctly. + dict(type='RandomFlip', prob=1.), + dict(type='RandomFlip', prob=0.) + ], + [ + dict( + type='Pad', + size=(1920, 1920), + pad_val=dict(img=(114, 114, 114))), + ], + [dict(type='LoadAnnotations', with_bbox=True)], + [ + dict( + type='PackDetInputs', + meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape', + 'scale_factor', 'flip', 'flip_direction')) + ] + ]) +]