|
import torch |
|
from torch.optim.lr_scheduler import _LRScheduler |
|
import math |
|
|
|
|
|
class CosineLR(_LRScheduler):
    """Cosine learning-rate schedule driven by ``last_epoch``.

    Whenever ``last_epoch`` is non-zero, every parameter group is assigned

        0.5 * (1 + cos(pi * last_epoch / total_epochs)) * init_lr

    While ``last_epoch == 0`` the learning rates currently stored in the
    optimizer's parameter groups are returned unchanged instead.

    The formula is not clamped: once ``last_epoch`` exceeds ``total_epochs``
    the cosine term starts rising again.

    Args:
        optimizer: Optimizer whose ``param_groups`` are scheduled.
        init_lr: Peak learning rate that the cosine factor multiplies. The
            same value is used for every parameter group.
        total_epochs: Number of epochs (or steps) over which the cosine
            factor goes from 1 to 0. Must be positive.
        last_epoch: Counter value to start from. It is stored on the instance
            after the base class has been initialised; it is deliberately not
            forwarded to the base constructor.

    Raises:
        ValueError: If ``total_epochs`` is not positive.
    """

    def __init__(self, optimizer, init_lr, total_epochs, last_epoch=-1):
        # Fail here with a clear message rather than with a ZeroDivisionError
        # inside get_lr() several steps later.
        if total_epochs <= 0:
            raise ValueError(
                f'total_epochs must be positive, got {total_epochs}')

        # Assigned before the base constructor runs so that get_lr() never
        # sees a half-initialised instance.
        self.init_lr = init_lr
        self.total_epochs = total_epochs

        # The base class is always initialised as a fresh schedule (-1); the
        # requested starting point is applied afterwards.
        # NOTE(review): presumably this avoids the base class's requirements
        # for resuming with last_epoch != -1 -- confirm against the installed
        # torch version.
        super().__init__(optimizer, last_epoch=-1)

        self.optimizer = optimizer
        self.last_epoch = last_epoch

        print(f'CosineLR start from epoch(step) {last_epoch} with init_lr {init_lr} ')

    def get_lr(self):
        """Return one learning rate per optimizer parameter group."""
        param_groups = self.optimizer.param_groups

        # At counter value 0 the optimizer's current learning rates are kept.
        if self.last_epoch == 0:
            return [group['lr'] for group in param_groups]

        # The cosine factor is identical for all groups, so compute it once.
        scale = 0.5 * (1 + math.cos(self.last_epoch * math.pi / self.total_epochs))
        return [scale * self.init_lr for _ in param_groups]
|
|