from bisect import bisect_right

from torch.optim.lr_scheduler import _LRScheduler


# Multi-step learning rate scheduler with linear warmup: the learning rate
# ramps from base_lr / scale up to base_lr over the first warmup epochs,
# then decays by `gamma` at each milestone, as in torch's MultiStepLR.
class WarmMultiStepLR(_LRScheduler):
    def __init__(self, optimizer, milestones, gamma=0.1, last_epoch=-1, scale=1):
        if list(milestones) != sorted(milestones):
            raise ValueError(
                'Milestones should be a list of increasing integers. '
                'Got {}'.format(milestones)
            )
        self.milestones = milestones
        self.gamma = gamma
        self.scale = scale
        self.warmup_epochs = 5
        # Per-epoch increment of the warmup factor: the factor grows linearly
        # from 1 (at epoch 0) to `scale` (at epoch `warmup_epochs`).
        self.gradual = (self.scale - 1) / self.warmup_epochs
        super(WarmMultiStepLR, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            # Linear warmup: start at base_lr / scale and ramp up toward
            # base_lr, which is reached once the warmup ends.
            return [
                base_lr * (1 + self.last_epoch * self.gradual) / self.scale
                for base_lr in self.base_lrs
            ]
        else:
            # Standard multi-step decay: multiply base_lr by gamma once for
            # every milestone that has already been passed.
            return [
                base_lr * self.gamma ** bisect_right(self.milestones, self.last_epoch)
                for base_lr in self.base_lrs
            ]
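

# Minimal usage sketch, not part of the original module: the model, learning
# rate, milestones, and epoch count below are illustrative assumptions only.
# With scale=4 the learning rate starts at 0.1 / 4, warms up linearly over
# the first five epochs, then drops 10x at epochs 30 and 60.
if __name__ == '__main__':
    import torch

    model = torch.nn.Linear(10, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = WarmMultiStepLR(optimizer, milestones=[30, 60], gamma=0.1, scale=4)
    for epoch in range(70):
        optimizer.step()  # stand-in for one real training epoch
        scheduler.step()
        print(epoch, scheduler.get_last_lr())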