from torch.optim.lr_scheduler import _LRScheduler


class PolyScheduler(_LRScheduler):
    """Linear warmup followed by polynomial decay of the learning rate.

    Every parameter group of the optimizer receives the same learning rate,
    computed from ``last_epoch`` (used here as a step counter):

    * ``last_epoch == -1``: the fixed value ``warmup_lr_init`` (0.0001).
    * ``last_epoch < warmup_steps``: ``base_lr * last_epoch / warmup_steps``.
    * otherwise: ``base_lr * (1 - progress) ** power`` with ``power == 2``,
      where ``progress`` is the fraction of the decay span
      ``max_steps - warmup_steps`` already consumed, capped at 1.

    Args:
        optimizer: Wrapped optimizer whose ``param_groups`` are updated.
        base_lr: Peak learning rate, reached at the end of the warmup.
        max_steps: Step at which the learning rate reaches zero.
        warmup_steps: Number of linear warmup steps.
        last_epoch: Step counter value the scheduler starts from. Default: -1.
    """

    def __init__(self, optimizer, base_lr, max_steps, warmup_steps, last_epoch=-1):
        # These attributes must exist before the base-class constructor runs:
        # per the PyTorch scheduler implementation it performs an initial
        # step(), which calls get_lr().
        self.base_lr = base_lr
        self.warmup_lr_init = 0.0001
        self.max_steps: int = max_steps
        self.warmup_steps: int = warmup_steps
        self.power = 2
        # The deprecated ``verbose`` flag is no longer passed; ``False`` was
        # its default, so behaviour is unchanged on releases that accept it.
        super().__init__(optimizer, -1)
        # The base class is always initialised with -1 and the requested
        # counter is applied afterwards.
        # NOTE(review): presumably this lets a run resume from an arbitrary
        # step without 'initial_lr' being present in the param groups - confirm.
        self.last_epoch = last_epoch

    def get_warmup_lr(self):
        """Return the linear warmup learning rate for every parameter group."""
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]

    def get_lr(self):
        """Return the learning rate of every parameter group for ``last_epoch``.

        Returns:
            A list with one learning rate per optimizer parameter group; all
            entries are equal.
        """
        if self.last_epoch == -1:
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps:
            return self.get_warmup_lr()

        decay_steps = self.max_steps - self.warmup_steps
        if decay_steps <= 0:
            # No room for a decay phase: treat the schedule as finished
            # instead of dividing by zero (or by a negative span).
            progress = 1.0
        else:
            # Cap at 1 so that stepping past max_steps keeps the rate at zero;
            # without the cap the even power turns the negative base back
            # into a growing positive multiplier.
            progress = min(
                float(self.last_epoch - self.warmup_steps) / float(decay_steps),
                1.0,
            )
        alpha = pow(1 - progress, self.power)
        return [self.base_lr * alpha for _ in self.optimizer.param_groups]