import math

import torch
from torch.optim.lr_scheduler import _LRScheduler


class CosineAnnealingWarmupRestarts(_LRScheduler):
    """Cosine annealing learning-rate scheduler with linear warmup and warm restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        first_cycle_steps (int): First cycle step size.
        cycle_mult (float): Cycle steps magnification. Default: 1.
        max_lr (float): First cycle's max learning rate. Default: 0.1.
        min_lr (float): Min learning rate. Default: 0.001.
        warmup_steps (int): Linear warmup step size. Default: 0.
        gamma (float): Decrease rate of max learning rate by cycle. Default: 1.
        last_epoch (int): The index of last epoch. Default: -1.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 first_cycle_steps: int,
                 cycle_mult: float = 1.,
                 max_lr: float = 0.1,
                 min_lr: float = 0.001,
                 warmup_steps: int = 0,
                 gamma: float = 1.,
                 last_epoch: int = -1):
        assert warmup_steps < first_cycle_steps, \
            "warmup_steps must be smaller than first_cycle_steps"

        self.first_cycle_steps = first_cycle_steps  # length of the first cycle
        self.cycle_mult = cycle_mult  # cycle length multiplier after each restart
        self.base_max_lr = max_lr  # first cycle's max learning rate
        self.max_lr = max_lr  # current cycle's max learning rate
        self.min_lr = min_lr
        self.warmup_steps = warmup_steps
        self.gamma = gamma  # max_lr decay factor applied at each restart

        self.cur_cycle_steps = first_cycle_steps  # length of the current cycle
        self.cycle = 0  # index of the current cycle
        self.step_in_cycle = last_epoch  # step within the current cycle

        super().__init__(optimizer, last_epoch)

        # Start every param group at min_lr so warmup ramps up from there.
        self.init_lr()

    def init_lr(self):
        # Override base_lrs so the schedule starts from min_lr.
        self.base_lrs = []
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = self.min_lr
            self.base_lrs.append(self.min_lr)

    def get_lr(self):
        if self.step_in_cycle == -1:
            # Before the first step(): keep the initial (min) learning rates.
            return self.base_lrs
        elif self.step_in_cycle < self.warmup_steps:
            # Linear warmup from base_lr up to max_lr.
            return [(self.max_lr - base_lr) * self.step_in_cycle / self.warmup_steps + base_lr
                    for base_lr in self.base_lrs]
        else:
            # Cosine annealing from max_lr back down to base_lr over the
            # rest of the current cycle.
            return [base_lr + (self.max_lr - base_lr)
                    * (1 + math.cos(math.pi * (self.step_in_cycle - self.warmup_steps)
                                    / (self.cur_cycle_steps - self.warmup_steps))) / 2
                    for base_lr in self.base_lrs]
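
    # A worked example of get_lr(), assuming max_lr=0.1, min_lr=0.001,
    # warmup_steps=10 and cur_cycle_steps=100: at step 5 (warmup) the lr is
    # (0.1 - 0.001) * 5 / 10 + 0.001 = 0.0505; at step 55 (halfway through the
    # cosine phase) it is 0.001 + (0.1 - 0.001) * (1 + cos(pi/2)) / 2 = 0.0505
    # as well, since both points sit midway between min_lr and max_lr.
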
    def step(self, epoch=None):
        if epoch is None:
            # Sequential stepping: advance one step within the current cycle.
            epoch = self.last_epoch + 1
            self.step_in_cycle = self.step_in_cycle + 1
            if self.step_in_cycle >= self.cur_cycle_steps:
                # Cycle finished: restart, and grow the cycle length by
                # cycle_mult (the warmup portion keeps its original length).
                self.cycle += 1
                self.step_in_cycle = self.step_in_cycle - self.cur_cycle_steps
                self.cur_cycle_steps = int((self.cur_cycle_steps - self.warmup_steps) * self.cycle_mult) + self.warmup_steps
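                # For example, assuming first_cycle_steps=100, warmup_steps=10
                # and cycle_mult=2, the cycle lengths grow as
                # 100 -> (100 - 10) * 2 + 10 = 190 -> (190 - 10) * 2 + 10 = 370.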
        else:
            # Stepping to an explicit epoch: recover the cycle index and the
            # position within that cycle from the absolute step count.
            if epoch >= self.first_cycle_steps:
                if self.cycle_mult == 1.:
                    # All cycles have the same length.
                    self.step_in_cycle = epoch % self.first_cycle_steps
                    self.cycle = epoch // self.first_cycle_steps
                else:
                    # Cycle lengths form a geometric series here: the first n
                    # cycles span first_cycle_steps * (cycle_mult**n - 1) /
                    # (cycle_mult - 1) steps, so invert that sum to find the
                    # current cycle index n.
                    n = int(math.log((epoch / self.first_cycle_steps * (self.cycle_mult - 1) + 1), self.cycle_mult))
                    self.cycle = n
                    self.step_in_cycle = epoch - int(self.first_cycle_steps * (self.cycle_mult ** n - 1) / (self.cycle_mult - 1))
                    self.cur_cycle_steps = int(self.first_cycle_steps * self.cycle_mult ** n)
            else:
                # Still within the first cycle.
                self.cur_cycle_steps = self.first_cycle_steps
                self.step_in_cycle = epoch

        # Decay the cycle's max learning rate, then apply the new rates.
        self.max_lr = self.base_max_lr * (self.gamma ** self.cycle)
        self.last_epoch = math.floor(epoch)
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
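
# A minimal usage sketch, not part of the scheduler itself; the model,
# optimizer, and hyperparameter values below are illustrative assumptions.
if __name__ == "__main__":
    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = CosineAnnealingWarmupRestarts(optimizer,
                                              first_cycle_steps=100,
                                              cycle_mult=2.0,
                                              max_lr=0.1,
                                              min_lr=0.001,
                                              warmup_steps=10,
                                              gamma=0.5)
    for step in range(300):
        # optimizer.step() would run here in a real training loop
        scheduler.step()
        if step % 50 == 0:
            print(step, optimizer.param_groups[0]['lr'])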