import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import _LRScheduler


class LinearDecayLR(_LRScheduler):
    """Keep each group's base learning rate, then decay it linearly to zero.

    For epoch ``e`` (``self.last_epoch``) and each parameter group's base
    learning rate ``b``:

    * ``e <= start_decay``: ``lr = b``
    * ``start_decay < e``: ``lr = b - b / (n_epoch - start_decay) * (e - start_decay)``,
      clamped at ``0`` so that epochs beyond ``n_epoch`` never give a negative lr.

    Args:
        optimizer: Wrapped optimizer; every one of its parameter groups is scheduled.
        n_epoch: Epoch index at which the learning rate reaches zero.
        start_decay: Last epoch index that still uses the base learning rate.
        last_epoch: Index of the last epoch; ``-1`` starts fresh. Forwarded to
            ``_LRScheduler``.

    Raises:
        ValueError: If ``n_epoch <= start_decay``. In that case the decay window
            is empty (division by zero) or reversed (the lr would grow instead
            of decaying).
    """

    def __init__(self, optimizer, n_epoch, start_decay, last_epoch=-1):
        if n_epoch <= start_decay:
            raise ValueError(
                f"n_epoch ({n_epoch}) must be greater than start_decay ({start_decay})"
            )
        # These must be set before the base __init__, which performs an initial
        # step() that calls get_lr().
        self.start_decay = start_decay
        self.n_epoch = n_epoch
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group for ``self.last_epoch``."""
        epoch = self.last_epoch
        if epoch <= self.start_decay:
            return list(self.base_lrs)
        decay_len = self.n_epoch - self.start_decay
        elapsed = epoch - self.start_decay
        # Apply the decay to every group's own base lr. The base scheduler zips
        # these values with optimizer.param_groups, so returning a single value
        # would leave the remaining groups untouched.
        return [max(0.0, b_lr - b_lr / decay_len * elapsed) for b_lr in self.base_lrs]