Spaces:
Runtime error
Runtime error
from torch.optim.lr_scheduler import _LRScheduler | |
class PolyScheduler(_LRScheduler):
    """Polynomial learning-rate decay preceded by a linear warmup.

    With ``step = self.last_epoch`` the learning rate applied to every
    parameter group is:

    * ``warmup_lr_init`` when ``step == -1`` (scheduler not stepped yet);
    * ``base_lr * step / warmup_steps`` while ``step < warmup_steps``;
    * ``base_lr * (1 - progress) ** power`` afterwards, where ``progress`` is
      the fraction of the ``warmup_steps .. max_steps`` span already
      consumed, clamped to 1.0 so the rate stays at zero once ``max_steps``
      has been reached instead of growing again.

    Args:
        optimizer: Wrapped optimizer whose parameter groups are updated.
        base_lr: Peak learning rate, reached at the end of warmup.
        max_steps: Step at which the decayed learning rate reaches zero.
        warmup_steps: Number of linear warmup steps.
        last_epoch: Step counter to resume from; ``-1`` starts from scratch.
    """

    def __init__(
        self,
        optimizer,
        base_lr: float,
        max_steps: int,
        warmup_steps: int,
        last_epoch: int = -1,
    ):
        # These attributes must be set before the base constructor runs,
        # because it performs an initial ``step()`` that calls ``get_lr()``.
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # Pass only (optimizer, last_epoch). The former third positional
        # ``verbose`` flag is deprecated upstream and is rejected by PyTorch
        # releases that removed it; ``False`` was its default anyway.
        super().__init__(optimizer, -1)
        # The base class is always initialised from scratch (-1); the
        # requested resume point is applied afterwards.
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Return the linearly scaled warmup rate for each parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate for each parameter group at ``last_epoch``."""
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        decay_span = self.max_steps - self.warmup_steps
        if decay_span <= 0:
            # No room left for a decay phase: treat the schedule as finished
            # rather than dividing by zero or scaling above ``base_lr``.
            progress = 1.0
        else:
            # Clamp so steps beyond ``max_steps`` stay fully decayed; without
            # this the even power turns a negative base into a rising rate.
            progress = min(
                1.0, float(self.last_epoch - self.warmup_steps) / float(decay_span)
            )
        alpha = pow(1.0 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]