diff --git a/configs/rtmdet/README.md b/configs/rtmdet/README.md
index 4574dd613c1..1677184af76 100644
--- a/configs/rtmdet/README.md
+++ b/configs/rtmdet/README.md
@@ -20,14 +20,17 @@ In this paper, we aim to design an efficient real-time object detector that exce
### Object Detection
-| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms)
RTX3090 | TRT-FP16-Latency(ms)
T4 | Config | Download |
-| :---------: | :--: | :----: | :-------: | :------: | :-----------------------------: | :------------------------: | :----------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
-| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | 2.34 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
-| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | 2.96 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
-| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | 6.41 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) |
-| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | 10.32 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) |
-| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | 18.80 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) |
-| RTMDet-x-P6 | 1280 | 54.9 | | | | | [config](./rtmdet_x_p6_4xb8-300e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth) |
+| Model | size | box AP | Params(M) | FLOPS(G) | TRT-FP16-Latency(ms)
RTX3090 | TRT-FP16-Latency(ms)
T4 | Config | Download |
+| :-----------------: | :--: | :----: | :-------: | :------: | :-----------------------------: | :------------------------: | :------------------------------------------------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
+| RTMDet-tiny | 640 | 41.1 | 4.8 | 8.1 | 0.98 | 2.34 | [config](./rtmdet_tiny_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414-78e30dcc.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_tiny_8xb32-300e_coco/rtmdet_tiny_8xb32-300e_coco_20220902_112414.log.json) |
+| RTMDet-s | 640 | 44.6 | 8.89 | 14.8 | 1.22 | 2.96 | [config](./rtmdet_s_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602-387a891e.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_s_8xb32-300e_coco/rtmdet_s_8xb32-300e_coco_20220905_161602.log.json) |
+| RTMDet-m | 640 | 49.4 | 24.71 | 39.27 | 1.62 | 6.41 | [config](./rtmdet_m_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220-229f527c.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_m_8xb32-300e_coco/rtmdet_m_8xb32-300e_coco_20220719_112220.log.json) |
+| RTMDet-l | 640 | 51.5 | 52.3 | 80.23 | 2.44 | 10.32 | [config](./rtmdet_l_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030-5a0be7c4.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_l_8xb32-300e_coco/rtmdet_l_8xb32-300e_coco_20220719_112030.log.json) |
+| RTMDet-x | 640 | 52.8 | 94.86 | 141.67 | 3.10 | 18.80 | [config](./rtmdet_x_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555-cc79b9ae.pth) \| [log](https://download.openmmlab.com/mmdetection/v3.0/rtmdet/rtmdet_x_8xb32-300e_coco/rtmdet_x_8xb32-300e_coco_20220715_230555.log.json) |
+| RTMDet-x-P6 | 1280 | 54.9 | | | | | [config](./rtmdet_x_p6_4xb8-300e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth) |
+| RTMDet-l-ConvNeXt-B | 640 | 53.1 | | | | | [config](./rtmdet_l_convnext_b_4xb32-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_convnext_b_4xb32-100e_coco-d4731b3d.pth) |
+| RTMDet-l-Swin-B | 640 | 52.4 | | | | | [config](./rtmdet_l_swin_b_4xb32-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_4xb32-100e_coco-0828ce5d.pth) |
+| RTMDet-l-Swin-B-P6 | 1280 | 56.4 | | | | | [config](./rtmdet_l_swin_b_p6_4xb16-100e_coco.py) | [model](https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_p6_4xb16-100e_coco-a1486b6f.pth) |
**Note**:
diff --git a/configs/rtmdet/metafile.yml b/configs/rtmdet/metafile.yml
index 7dc72e130be..a62abcb2faa 100644
--- a/configs/rtmdet/metafile.yml
+++ b/configs/rtmdet/metafile.yml
@@ -104,6 +104,48 @@ Models:
box AP: 54.9
Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-p6/rtmdet_x_p6_4xb8-300e_coco-bf32be58.pth
+ - Name: rtmdet_l_convnext_b_4xb32-100e_coco
+ Alias:
+ - rtmdet-l_convnext_b
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py
+ Metadata:
+ Epochs: 100
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 53.1
+ Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_convnext_b_4xb32-100e_coco-d4731b3d.pth
+
+ - Name: rtmdet_l_swin_b_4xb32-100e_coco
+ Alias:
+ - rtmdet-l_swin_b
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py
+ Metadata:
+ Epochs: 100
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 52.4
+ Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_4xb32-100e_coco-0828ce5d.pth
+
+ - Name: rtmdet_l_swin_b_p6_4xb16-100e_coco
+ Alias:
+ - rtmdet-l_swin_b_p6
+ In Collection: RTMDet
+ Config: configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py
+ Metadata:
+ Epochs: 100
+ Results:
+ - Task: Object Detection
+ Dataset: COCO
+ Metrics:
+ box AP: 56.4
+ Weights: https://github.com/orange0-jp/orange-weights/releases/download/v0.1.0rtmdet-swin-convnext/rtmdet_l_swin_b_p6_4xb16-100e_coco-a1486b6f.pth
+
- Name: rtmdet-ins_tiny_8xb32-300e_coco
Alias:
- rtmdet-ins-t
diff --git a/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py b/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py
new file mode 100644
index 00000000000..85af292bcab
--- /dev/null
+++ b/configs/rtmdet/rtmdet_l_convnext_b_4xb32-100e_coco.py
@@ -0,0 +1,81 @@
+_base_ = './rtmdet_l_8xb32-300e_coco.py'
+
+custom_imports = dict(
+ imports=['mmpretrain.models'], allow_failed_imports=False)
+
+norm_cfg = dict(type='GN', num_groups=32)
+checkpoint_file = 'https://download.openmmlab.com/mmclassification/v0/convnext/convnext-base_in21k-pre-3rdparty_in1k-384px_20221219-4570f792.pth' # noqa
+model = dict(
+ type='RTMDet',
+ data_preprocessor=dict(
+ _delete_=True,
+ type='DetDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ batch_augments=None),
+ backbone=dict(
+ _delete_=True,
+ type='mmpretrain.ConvNeXt',
+ arch='base',
+ out_indices=[1, 2, 3],
+ drop_path_rate=0.7,
+ layer_scale_init_value=1.0,
+ gap_before_final_norm=False,
+ with_cp=True,
+ init_cfg=dict(
+ type='Pretrained', checkpoint=checkpoint_file,
+ prefix='backbone.')),
+ neck=dict(in_channels=[256, 512, 1024], norm_cfg=norm_cfg),
+ bbox_head=dict(norm_cfg=norm_cfg))
+
+max_epochs = 100
+stage2_num_epochs = 10
+interval = 10
+base_lr = 0.001
+
+train_cfg = dict(
+ max_epochs=max_epochs,
+ val_interval=interval,
+ dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])
+
+optim_wrapper = dict(
+ constructor='LearningRateDecayOptimizerConstructor',
+ paramwise_cfg={
+ 'decay_rate': 0.8,
+ 'decay_type': 'layer_wise',
+ 'num_layers': 12
+ },
+ optimizer=dict(lr=base_lr))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1.0e-5,
+ by_epoch=False,
+ begin=0,
+ end=1000),
+ dict(
+ # use cosine lr from 50 to 100 epoch
+ type='CosineAnnealingLR',
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='PipelineSwitchHook',
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline={{_base_.train_pipeline_stage2}})
+]
diff --git a/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py b/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py
new file mode 100644
index 00000000000..84b0e0fa7d1
--- /dev/null
+++ b/configs/rtmdet/rtmdet_l_swin_b_4xb32-100e_coco.py
@@ -0,0 +1,78 @@
+_base_ = './rtmdet_l_8xb32-300e_coco.py'
+
+norm_cfg = dict(type='GN', num_groups=32)
+checkpoint = 'https://github.com/SwinTransformer/storage/releases/download/v1.0.0/swin_base_patch4_window12_384_22k.pth' # noqa
+model = dict(
+ type='RTMDet',
+ data_preprocessor=dict(
+ _delete_=True,
+ type='DetDataPreprocessor',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ batch_augments=None),
+ backbone=dict(
+ _delete_=True,
+ type='SwinTransformer',
+ pretrain_img_size=384,
+ embed_dims=128,
+ depths=[2, 2, 18, 2],
+ num_heads=[4, 8, 16, 32],
+ window_size=12,
+ mlp_ratio=4,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.3,
+ patch_norm=True,
+ out_indices=(1, 2, 3),
+ with_cp=True,
+ convert_weights=True,
+ init_cfg=dict(type='Pretrained', checkpoint=checkpoint)),
+ neck=dict(in_channels=[256, 512, 1024], norm_cfg=norm_cfg),
+ bbox_head=dict(norm_cfg=norm_cfg))
+
+max_epochs = 100
+stage2_num_epochs = 10
+interval = 10
+base_lr = 0.001
+
+train_cfg = dict(
+ max_epochs=max_epochs,
+ val_interval=interval,
+ dynamic_intervals=[(max_epochs - stage2_num_epochs, 1)])
+
+optim_wrapper = dict(optimizer=dict(lr=base_lr))
+
+# learning rate
+param_scheduler = [
+ dict(
+ type='LinearLR',
+ start_factor=1.0e-5,
+ by_epoch=False,
+ begin=0,
+ end=1000),
+ dict(
+ # use cosine lr from 50 to 100 epoch
+ type='CosineAnnealingLR',
+ eta_min=base_lr * 0.05,
+ begin=max_epochs // 2,
+ end=max_epochs,
+ T_max=max_epochs // 2,
+ by_epoch=True,
+ convert_to_iter_based=True),
+]
+
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='PipelineSwitchHook',
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline={{_base_.train_pipeline_stage2}})
+]
diff --git a/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py b/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py
new file mode 100644
index 00000000000..37d4215c3f0
--- /dev/null
+++ b/configs/rtmdet/rtmdet_l_swin_b_p6_4xb16-100e_coco.py
@@ -0,0 +1,114 @@
+_base_ = './rtmdet_l_swin_b_4xb32-100e_coco.py'
+
+model = dict(
+ backbone=dict(
+ depths=[2, 2, 18, 2, 1],
+ num_heads=[4, 8, 16, 32, 64],
+ strides=(4, 2, 2, 2, 2),
+ out_indices=(1, 2, 3, 4)),
+ neck=dict(in_channels=[256, 512, 1024, 2048]),
+ bbox_head=dict(
+ anchor_generator=dict(
+ type='MlvlPointGenerator', offset=0, strides=[8, 16, 32, 64])))
+
+train_pipeline = [
+ dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(type='CachedMosaic', img_scale=(1280, 1280), pad_val=114.0),
+ dict(
+ type='RandomResize',
+ scale=(2560, 2560),
+ ratio_range=(0.1, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=(1280, 1280)),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))),
+ dict(
+ type='CachedMixUp',
+ img_scale=(1280, 1280),
+ ratio_range=(1.0, 1.0),
+ max_cached_images=20,
+ pad_val=(114, 114, 114)),
+ dict(type='PackDetInputs')
+]
+
+train_pipeline_stage2 = [
+ dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='RandomResize',
+ scale=(1280, 1280),
+ ratio_range=(0.1, 2.0),
+ keep_ratio=True),
+ dict(type='RandomCrop', crop_size=(1280, 1280)),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))),
+ dict(type='PackDetInputs')
+]
+
+test_pipeline = [
+ dict(type='LoadImageFromFile', backend_args={{_base_.backend_args}}),
+ dict(type='Resize', scale=(1280, 1280), keep_ratio=True),
+ dict(type='Pad', size=(1280, 1280), pad_val=dict(img=(114, 114, 114))),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+]
+
+train_dataloader = dict(
+ batch_size=16, num_workers=20, dataset=dict(pipeline=train_pipeline))
+val_dataloader = dict(num_workers=20, dataset=dict(pipeline=test_pipeline))
+test_dataloader = val_dataloader
+
+max_epochs = 100
+stage2_num_epochs = 10
+
+custom_hooks = [
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0002,
+ update_buffers=True,
+ priority=49),
+ dict(
+ type='PipelineSwitchHook',
+ switch_epoch=max_epochs - stage2_num_epochs,
+ switch_pipeline=train_pipeline_stage2)
+]
+
+img_scales = [(1280, 1280), (640, 640), (1920, 1920)]
+tta_pipeline = [
+ dict(type='LoadImageFromFile', backend_args=None),
+ dict(
+ type='TestTimeAug',
+ transforms=[
+ [
+ dict(type='Resize', scale=s, keep_ratio=True)
+ for s in img_scales
+ ],
+ [
+ # ``RandomFlip`` must be placed before ``Pad``, otherwise
+ # bounding box coordinates after flipping cannot be
+ # recovered correctly.
+ dict(type='RandomFlip', prob=1.),
+ dict(type='RandomFlip', prob=0.)
+ ],
+ [
+ dict(
+ type='Pad',
+ size=(1920, 1920),
+ pad_val=dict(img=(114, 114, 114))),
+ ],
+ [dict(type='LoadAnnotations', with_bbox=True)],
+ [
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor', 'flip', 'flip_direction'))
+ ]
+ ])
+]