"""
Modified from https://github.com/KaiyangZhou/deep-person-reid
"""
import torch
from torch.optim.lr_scheduler import _LRScheduler

AVAI_SCHEDS = ["single_step", "multi_step", "cosine"]

class _BaseWarmupScheduler(_LRScheduler):
    """Base class that runs a warmup schedule, then defers to a successor."""

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        last_epoch=-1,
        verbose=False
    ):
        self.successor = successor
        self.warmup_epoch = warmup_epoch
        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self):
        raise NotImplementedError

    def step(self, epoch=None):
        # Once warmup is over, delegate stepping to the successor scheduler;
        # before that, use the warmup rule defined by the subclass's get_lr().
        if self.last_epoch >= self.warmup_epoch:
            self.successor.step(epoch)
            self._last_lr = self.successor.get_last_lr()
        else:
            super().step(epoch)

class ConstantWarmupScheduler(_BaseWarmupScheduler):
    """Uses a fixed lr (cons_lr) for every epoch of the warmup phase."""

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        cons_lr,
        last_epoch=-1,
        verbose=False
    ):
        self.cons_lr = cons_lr
        super().__init__(
            optimizer, successor, warmup_epoch, last_epoch, verbose
        )

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch:
            return self.successor.get_last_lr()
        # The same constant warmup lr is applied to every parameter group.
        return [self.cons_lr for _ in self.base_lrs]
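
# Illustration only (values assumed, not taken from this module): with
# warmup_epoch=2 and cons_lr=1e-4 on a base lr of 0.01, the optimizer sees
# lr=1e-4 during epochs 0 and 1, then follows the successor scheduler from
# epoch 2 onward.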

class LinearWarmupScheduler(_BaseWarmupScheduler):
    """Linearly ramps the lr towards each group's base lr during warmup."""

    def __init__(
        self,
        optimizer,
        successor,
        warmup_epoch,
        min_lr,
        last_epoch=-1,
        verbose=False
    ):
        self.min_lr = min_lr
        super().__init__(
            optimizer, successor, warmup_epoch, last_epoch, verbose
        )

    def get_lr(self):
        if self.last_epoch >= self.warmup_epoch:
            return self.successor.get_last_lr()
        if self.last_epoch == 0:
            # Avoid a zero lr at epoch 0; start from min_lr instead.
            return [self.min_lr for _ in self.base_lrs]
        return [
            lr * self.last_epoch / self.warmup_epoch for lr in self.base_lrs
        ]
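
# A minimal usage sketch of the warmup wrappers (not part of the original
# module; `model`, the lr values, and T_max below are illustrative
# assumptions):
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   cosine = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
#   scheduler = LinearWarmupScheduler(
#       optimizer, cosine, warmup_epoch=5, min_lr=1e-5
#   )
#   for epoch in range(50):
#       ...  # train one epoch
#       scheduler.step()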

def build_lr_scheduler(optimizer, optim_cfg):
    """A function wrapper for building a learning rate scheduler.

    Args:
        optimizer (Optimizer): an Optimizer.
        optim_cfg (CfgNode): optimization config.
    """
    lr_scheduler = optim_cfg.LR_SCHEDULER
    stepsize = optim_cfg.STEPSIZE
    gamma = optim_cfg.GAMMA
    max_epoch = optim_cfg.MAX_EPOCH

    if lr_scheduler not in AVAI_SCHEDS:
        raise ValueError(
            f"scheduler must be one of {AVAI_SCHEDS}, but got {lr_scheduler}"
        )

    if lr_scheduler == "single_step":
        # A list/tuple stepsize is tolerated here; only the last value is used.
        if isinstance(stepsize, (list, tuple)):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                "For single_step lr_scheduler, stepsize must "
                f"be an integer, but got {type(stepsize)}"
            )

        # A non-positive stepsize defers the decay to the end of training,
        # i.e. the lr is effectively never decayed.
        if stepsize <= 0:
            stepsize = max_epoch

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )

    elif lr_scheduler == "multi_step":
        if not isinstance(stepsize, (list, tuple)):
            raise TypeError(
                "For multi_step lr_scheduler, stepsize must "
                f"be a list or tuple, but got {type(stepsize)}"
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )

    elif lr_scheduler == "cosine":
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, float(max_epoch)
        )

    if optim_cfg.WARMUP_EPOCH > 0:
        if not optim_cfg.WARMUP_RECOUNT:
            # Without recounting, the successor resumes as if the warmup
            # epochs had already been consumed.
            scheduler.last_epoch = optim_cfg.WARMUP_EPOCH - 1

        if optim_cfg.WARMUP_TYPE == "constant":
            scheduler = ConstantWarmupScheduler(
                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
                optim_cfg.WARMUP_CONS_LR
            )

        elif optim_cfg.WARMUP_TYPE == "linear":
            scheduler = LinearWarmupScheduler(
                optimizer, scheduler, optim_cfg.WARMUP_EPOCH,
                optim_cfg.WARMUP_MIN_LR
            )

        else:
            raise ValueError(
                "warmup type must be constant or linear, "
                f"but got {optim_cfg.WARMUP_TYPE}"
            )

    return scheduler
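
if __name__ == "__main__":
    # Smoke-test sketch, not part of the original module. The real config is
    # a yacs CfgNode; SimpleNamespace stands in for it here, since both
    # expose the fields via attribute access.
    from types import SimpleNamespace

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.01)
    optim_cfg = SimpleNamespace(
        LR_SCHEDULER="cosine",
        STEPSIZE=(-1, ),
        GAMMA=0.1,
        MAX_EPOCH=10,
        WARMUP_EPOCH=2,
        WARMUP_RECOUNT=False,
        WARMUP_TYPE="linear",
        WARMUP_MIN_LR=1e-5,
    )
    scheduler = build_lr_scheduler(optimizer, optim_cfg)
    for _ in range(optim_cfg.MAX_EPOCH):
        optimizer.step()
        scheduler.step()
        print(scheduler.get_last_lr())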