Distill RTM Detectors Based on MMRazor
Description
To further improve model accuracy without introducing much additional computation cost, we apply feature-based distillation to the training phase of these RTM detectors. In summary, our distillation strategy is threefold:
(1) Inspired by PKD, we first normalize the intermediate feature maps to have zero mean and unit variance before calculating the distillation loss.
(2) Inspired by CWD, we adopt the channel-wise distillation paradigm, which pays more attention to the most salient regions of each channel.
(3) Inspired by DAMO-YOLO, the distillation process is split into two stages. 1) In the first stage (280 epochs), the teacher distills the student on the strong-mosaic domain. 2) In the second stage (20 epochs), the student fine-tunes itself on the no-mosaic domain.
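Points (1) and (2) can be illustrated with a minimal, dependency-free sketch. It assumes a single image whose feature map is given as a list of channels, each flattened to a list of H*W activations. The helper names (`normalize`, `log_softmax`, `channel_wise_divergence`) and the temperature `tau` are illustrative choices, not MMRazor's API; actual training uses MMRazor's tensor-based modules.

```python
import math


def normalize(values, eps=1e-5):
    """Standardize one channel to zero mean and (approximately) unit variance."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    std = math.sqrt(var + eps)  # eps avoids division by zero on flat channels
    return [(v - mean) / std for v in values]


def log_softmax(values, tau=1.0):
    """Numerically stable log-softmax over the spatial positions of one channel."""
    scaled = [v / tau for v in values]
    peak = max(scaled)
    log_z = peak + math.log(sum(math.exp(v - peak) for v in scaled))
    return [v - log_z for v in scaled]


def channel_wise_divergence(student, teacher, tau=1.0):
    """KL(teacher || student) per channel, averaged over channels.

    Each channel is turned into a spatial probability distribution, so the
    positions where the teacher responds most strongly dominate the loss.
    """
    total = 0.0
    for s_channel, t_channel in zip(student, teacher):
        log_p_s = log_softmax(s_channel, tau)
        log_p_t = log_softmax(t_channel, tau)
        total += sum(math.exp(lt) * (lt - ls) for lt, ls in zip(log_p_t, log_p_s))
    return (tau ** 2) * total / len(student)


# Toy feature maps: 2 channels, 4 spatial positions each (made-up values).
teacher = [normalize([0.1, 0.2, 3.0, 0.3]), normalize([1.0, 1.1, 0.9, 5.0])]
student = [normalize([0.2, 0.1, 2.0, 0.5]), normalize([0.8, 1.5, 1.0, 3.0])]
loss = channel_wise_divergence(student, teacher)
```

Normalizing first puts the teacher's and student's activations on a comparable scale, so the divergence reflects where each channel responds rather than how large the raw values are.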
Results and Models
Location | Dataset | Teacher | Student | mAP | mAP(T) | mAP(S) | Config | Download |
---|---|---|---|---|---|---|---|---|
FPN | COCO | RTMDet-s | RTMDet-tiny | 41.8 (+0.8) | 44.6 | 41.0 | config | teacher, model, log |
FPN | COCO | RTMDet-m | RTMDet-s | 45.7 (+1.1) | 49.3 | 44.6 | config | teacher, model, log |
FPN | COCO | RTMDet-l | RTMDet-m | 50.2 (+0.9) | 51.4 | 49.3 | config | teacher, model, log |
FPN | COCO | RTMDet-x | RTMDet-l | 52.3 (+0.9) | 52.8 | 51.4 | config | teacher, model, log |
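To make the columns easier to read: the mAP column appears to be the distilled student's score, with the gain over the undistilled student baseline mAP(S) in parentheses, while mAP(T) is the teacher's score. The short check below recomputes those gains from the table values; the variable names are illustrative.

```python
# Rows copied from the table above:
# (student, distilled mAP, teacher mAP(T), baseline student mAP(S))
rows = [
    ("RTMDet-tiny", 41.8, 44.6, 41.0),
    ("RTMDet-s", 45.7, 49.3, 44.6),
    ("RTMDet-m", 50.2, 51.4, 49.3),
    ("RTMDet-l", 52.3, 52.8, 51.4),
]

# Gain from distillation over the baseline student, rounded to one decimal.
gains = {name: round(distilled - baseline, 1)
         for name, distilled, _, baseline in rows}

# Remaining gap between the distilled student and its teacher.
gaps = {name: round(teacher - distilled, 1)
        for name, distilled, teacher, _ in rows}

for name in gains:
    print(f"{name}: gain {gains[name]:+.1f}, gap to teacher {gaps[name]:.1f}")
```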
Usage
Prerequisites
Install MMRazor from source
git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
cd mmrazor
# Install MMRazor
mim install -v -e .
Training commands
In MMYOLO's root directory, run the following command to train the RTMDet-tiny with 8 GPUs, using RTMDet-s as the teacher:
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py 8
Testing commands
In MMYOLO's root directory, run the following command to test the model:
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/rtmdet/distillation/kd_tiny_rtmdet_s_neck_300e_coco.py ${CHECKPOINT_PATH} 1
Getting student-only checkpoint
After training, the checkpoint contains parameters for both the student and the teacher models. Run the following command to convert it to a student-only checkpoint:
python ./tools/model_converters/convert_kd_ckpt_to_student.py ${CHECKPOINT_PATH} --out-path ${OUTPUT_CHECKPOINT_PATH}
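Conceptually, this conversion amounts to filtering the saved state dict. The sketch below assumes the student's weights are stored under an `architecture.` prefix and that teacher and connector weights live under other prefixes. These key names and the helper `extract_student_state` are illustrative assumptions and are not taken from the script itself.

```python
def extract_student_state(state_dict, student_prefix="architecture."):
    """Keep only entries under the student prefix and strip that prefix."""
    return {
        key[len(student_prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(student_prefix)
    }


# Toy state dict; strings stand in for weight tensors.
toy_state_dict = {
    "architecture.backbone.stem.weight": "w_student",
    "teacher.backbone.stem.weight": "w_teacher",
    "distiller.connectors.fpn0_s.conv.weight": "w_connector",
}

student_only = extract_student_state(toy_state_dict)
```

The resulting keys match what a plain (non-distillation) detector config would expect, which is why the converted checkpoint can be loaded without MMRazor's wrapper.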
Configs
Here we provide detection configs and models for MMRazor in MMYOLO. For clarity, we take ./kd_tiny_rtmdet_s_neck_300e_coco.py as an example to show how to distill an RTM detector based on MMRazor.
Here is the main part of ./kd_tiny_rtmdet_s_neck_300e_coco.py:
norm_cfg = dict(type='BN', affine=False, track_running_stats=False)

distiller = dict(
    type='ConfigurableDistiller',
    student_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv'),
    ),
    teacher_recorders=dict(
        fpn0=dict(type='ModuleOutputs', source='neck.out_layers.0.conv'),
        fpn1=dict(type='ModuleOutputs', source='neck.out_layers.1.conv'),
        fpn2=dict(type='ModuleOutputs', source='neck.out_layers.2.conv')),
    connectors=dict(
        fpn0_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        fpn0_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn1_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        fpn1_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg),
        fpn2_s=dict(
            type='ConvModuleConnector', in_channel=96,
            out_channel=128, bias=False, norm_cfg=norm_cfg,
            act_cfg=None),
        fpn2_t=dict(
            type='NormConnector', in_channels=128, norm_cfg=norm_cfg)),
    distill_losses=dict(
        loss_fpn0=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn1=dict(type='ChannelWiseDivergence', loss_weight=1),
        loss_fpn2=dict(type='ChannelWiseDivergence', loss_weight=1)),
    loss_forward_mappings=dict(
        loss_fpn0=dict(
            preds_S=dict(from_student=True, recorder='fpn0', connector='fpn0_s'),
            preds_T=dict(from_student=False, recorder='fpn0', connector='fpn0_t')),
        loss_fpn1=dict(
            preds_S=dict(from_student=True, recorder='fpn1', connector='fpn1_s'),
            preds_T=dict(from_student=False, recorder='fpn1', connector='fpn1_t')),
        loss_fpn2=dict(
            preds_S=dict(from_student=True, recorder='fpn2', connector='fpn2_s'),
            preds_T=dict(from_student=False, recorder='fpn2', connector='fpn2_t'))))
recorders are used to record various intermediate results during the model's forward pass. In this example, they record the outputs of 3 nn.Module instances in both the teacher and the student. Details are listed in Recorder and MMRazor Distillation (if users can read Chinese).
connectors are adaptive layers that usually map the teacher's and the student's features to the same dimension.
distill_losses are configs for multiple distillation losses.
loss_forward_mappings are mappings between the forward arguments of each distillation loss and the recorded outputs.
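How these four pieces fit together can be sketched in plain Python. This is a simplified illustration of the data flow and not MMRazor's implementation: the function `compute_distill_losses`, the scalar "features", and the lambda connectors and losses are all hypothetical stand-ins.

```python
def compute_distill_losses(student_records, teacher_records,
                           connectors, losses, mappings):
    """Resolve each loss argument from a recorder, pass it through its
    connector (if any), then call the loss with the resolved arguments."""
    results = {}
    for loss_name, arg_map in mappings.items():
        kwargs = {}
        for arg_name, spec in arg_map.items():
            # Pick the student's or the teacher's recorded output.
            records = student_records if spec["from_student"] else teacher_records
            value = records[spec["recorder"]]
            # Apply the connector so both sides end up in the same space.
            connector = connectors.get(spec.get("connector"))
            if connector is not None:
                value = connector(value)
            kwargs[arg_name] = value
        results[loss_name] = losses[loss_name](**kwargs)
    return results


# Toy stand-ins: scalars instead of feature maps, lambdas instead of modules.
student_records = {"fpn0": 2.0}
teacher_records = {"fpn0": 5.0}
connectors = {"fpn0_s": lambda x: x * 2.0, "fpn0_t": lambda x: x}
losses = {"loss_fpn0": lambda preds_S, preds_T: (preds_S - preds_T) ** 2}
mappings = {
    "loss_fpn0": {
        "preds_S": dict(from_student=True, recorder="fpn0", connector="fpn0_s"),
        "preds_T": dict(from_student=False, recorder="fpn0", connector="fpn0_t"),
    }
}

out = compute_distill_losses(student_records, teacher_records,
                             connectors, losses, mappings)
```

The keys of each mapping (`preds_S`, `preds_T`) have to match the parameter names of the loss's forward function, which is why they appear verbatim in the config.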
In addition, the student fine-tunes itself on the no-mosaic domain for the last 20 epochs, so we add a hook named StopDistillHook to stop distillation at the right epoch. Add this hook to the custom_hooks list like this:
custom_hooks = [..., dict(type='mmrazor.StopDistillHook', detach_epoch=280)]
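The effect of detach_epoch=280 on a 300-epoch schedule can be checked with a small sketch. It assumes epochs are counted from zero and that distillation is switched off once the epoch index reaches detach_epoch. The helper `distillation_active` is illustrative and is not the hook's actual code.

```python
def distillation_active(epoch, detach_epoch=280):
    """Return True while the distillation loss is still applied.

    Assumes a zero-based epoch index, so epochs 0..279 are distilled.
    """
    return epoch < detach_epoch


max_epochs = 300  # matches the 300e schedule in the config name

# Count how many epochs fall into each stage.
distill_epochs = sum(distillation_active(e) for e in range(max_epochs))
finetune_epochs = max_epochs - distill_epochs

print(f"distillation: {distill_epochs} epochs, fine-tuning: {finetune_epochs} epochs")
```

Under these assumptions the split matches the two stages described earlier: 280 epochs with the teacher and 20 epochs of student-only fine-tuning.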