# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional

from mmengine.hooks import ParamSchedulerHook
from mmengine.runner import Runner
from mmyolo.registry import HOOKS


@HOOKS.register_module()
class PPYOLOEParamSchedulerHook(ParamSchedulerHook):
"""A hook to update learning rate and momentum in optimizer of PPYOLOE. We
use this hook to implement adaptive computation for `warmup_total_iters`,
which is not possible with the built-in ParamScheduler in mmyolo.
Args:
warmup_min_iter (int): Minimum warmup iters. Defaults to 1000.
start_factor (float): The number we multiply learning rate in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 0.
warmup_epochs (int): Epochs for warmup. Defaults to 5.
min_lr_ratio (float): Minimum learning rate ratio.
total_epochs (int): In PPYOLOE, `total_epochs` is set to
training_epochs x 1.2. Defaults to 360.
"""
    priority = 9

    def __init__(self,
warmup_min_iter: int = 1000,
start_factor: float = 0.,
warmup_epochs: int = 5,
min_lr_ratio: float = 0.0,
total_epochs: int = 360):
self.warmup_min_iter = warmup_min_iter
self.start_factor = start_factor
self.warmup_epochs = warmup_epochs
self.min_lr_ratio = min_lr_ratio
self.total_epochs = total_epochs
self._warmup_end = False
        self._base_lr = None

    def before_train(self, runner: Runner):
"""Operations before train.
Args:
runner (Runner): The runner of the training process.
"""
optimizer = runner.optim_wrapper.optimizer
for group in optimizer.param_groups:
# If the param is never be scheduled, record the current value
# as the initial value.
group.setdefault('initial_lr', group['lr'])
self._base_lr = [
group['initial_lr'] for group in optimizer.param_groups
]
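        # Per-group lower bound that the cosine schedule decays towards.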
        self._min_lr = [i * self.min_lr_ratio for i in self._base_lr]

    def before_train_iter(self,
runner: Runner,
batch_idx: int,
data_batch: Optional[dict] = None):
"""Operations before each training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (dict or tuple or list, optional): Data from dataloader.
"""
cur_iters = runner.iter
optimizer = runner.optim_wrapper.optimizer
dataloader_len = len(runner.train_dataloader)
        # Warmup lasts `warmup_epochs` epochs, but never fewer than
        # `warmup_min_iter` iterations.
warmup_total_iters = max(
round(self.warmup_epochs * dataloader_len), self.warmup_min_iter)
if cur_iters <= warmup_total_iters:
            # Linear warmup: the factor ramps linearly from `start_factor` to 1.
alpha = cur_iters / warmup_total_iters
factor = self.start_factor * (1 - alpha) + alpha
for group_idx, param in enumerate(optimizer.param_groups):
param['lr'] = self._base_lr[group_idx] * factor
else:
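            # Cosine annealing after warmup: decay from the base learning rate
            # towards `min_lr` over the remaining iterations. Because
            # `total_epochs` is set larger than the actual number of training
            # epochs (1.2x, per the docstring), the schedule does not fully
            # reach `min_lr` during training.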
for group_idx, param in enumerate(optimizer.param_groups):
total_iters = self.total_epochs * dataloader_len
lr = self._min_lr[group_idx] + (
self._base_lr[group_idx] -
self._min_lr[group_idx]) * 0.5 * (
math.cos((cur_iters - warmup_total_iters) * math.pi /
(total_iters - warmup_total_iters)) + 1.0)
param['lr'] = lr
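

# Usage sketch (an assumption, not part of the original file): with the hook
# registered in HOOKS above, an MMEngine-style config could enable it through
# `custom_hooks`. The values shown simply mirror this hook's defaults and are
# illustrative only.
#
# custom_hooks = [
#     dict(
#         type='PPYOLOEParamSchedulerHook',
#         warmup_min_iter=1000,
#         start_factor=0.,
#         warmup_epochs=5,
#         min_lr_ratio=0.0,
#         total_epochs=360)
# ]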