File size: 667 Bytes
caa56d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler
class LinearDecayLR(_LRScheduler):
    """Hold each parameter group's LR constant, then decay it linearly to zero.

    For epochs ``<= start_decay`` every parameter group uses its base LR.
    After that the LR falls linearly, reaches 0 at epoch ``n_epoch``, and
    stays at 0 for every later epoch (it never goes negative).

    Args:
        optimizer: Wrapped optimizer. Every parameter group is scheduled
            from its own base LR.
        n_epoch: Epoch at which the learning rate reaches zero.
        start_decay: Last epoch that still uses the full base LR.
        last_epoch: Index of the last epoch (``-1`` starts fresh). It is
            forwarded to ``_LRScheduler``.
    """

    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        # These must be set before super().__init__. The base constructor
        # performs an initial step(), which calls get_lr() and reads them.
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super(LinearDecayLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for ``self.last_epoch``.

        The list has the same length as ``self.base_lrs``. ``step()`` zips it
        with ``optimizer.param_groups``, so every group gets updated.
        """
        return [self._decayed_lr(base_lr) for base_lr in self.base_lrs]

    def _decayed_lr(self, base_lr):
        """Return the linearly decayed LR for one group, clamped at zero."""
        elapsed = self.last_epoch - self.start_decay
        if elapsed <= 0:
            # Still in the constant phase.
            return base_lr
        span = self.n_epoch - self.start_decay
        if span <= 0:
            # The decay window is empty (n_epoch <= start_decay), so the LR
            # has already reached zero. This also avoids a division by zero.
            return 0.0
        # Stop the decay at n_epoch so later epochs do not go negative.
        lr = base_lr - base_lr / span * min(elapsed, span)
        # Float rounding at the end of the window can leave a tiny negative.
        return max(lr, 0.0)