import math

from torch.optim import Optimizer
|
|
class _LRScheduler(object):
    """Base class for learning-rate schedulers.

    Subclasses override :meth:`get_lr` to return one learning rate per
    optimizer param group for the current ``last_epoch``.
    """

    def __init__(self, optimizer, last_epoch=-1):
| if not isinstance(optimizer, Optimizer): |
| raise TypeError('{} is not an Optimizer'.format( |
| type(optimizer).__name__)) |
| self.optimizer = optimizer |
| if last_epoch == -1: |
| for group in optimizer.param_groups: |
| group.setdefault('initial_lr', group['lr']) |
| else: |
| for i, group in enumerate(optimizer.param_groups): |
| if 'initial_lr' not in group: |
| raise KeyError("param 'initial_lr' is not specified " |
| "in param_groups[{}] when resuming an optimizer".format(i)) |
        self.base_lrs = [group['initial_lr'] for group in optimizer.param_groups]
        # Apply the initial learning rates, then rewind last_epoch so the
        # first explicit step() call recomputes epoch last_epoch + 1.
        self.step(epoch=last_epoch + 1)
        self.last_epoch = last_epoch
|
|
| def state_dict(self): |
| """Returns the state of the scheduler as a :class:`dict`. |
| It contains an entry for every variable in self.__dict__ which |
| is not the optimizer. |
| """ |
| return {key: value for key, value in self.__dict__.items() if key != 'optimizer'} |
|
|
| def load_state_dict(self, state_dict): |
| """Loads the schedulers state. |
| Arguments: |
| state_dict (dict): scheduler state. Should be an object returned |
| from a call to :meth:`state_dict`. |
| """ |
| self.__dict__.update(state_dict) |
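
    # A hedged checkpointing sketch: pair the scheduler state with the
    # optimizer state, mirroring the usual PyTorch save/restore pattern.
    # `sched` and 'ckpt.pt' are illustrative names, not part of this module.
    #
    #     torch.save({'optim': optimizer.state_dict(),
    #                 'sched': sched.state_dict()}, 'ckpt.pt')
    #     state = torch.load('ckpt.pt')
    #     optimizer.load_state_dict(state['optim'])
    #     sched.load_state_dict(state['sched'])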
|
|
    def get_lr(self):
        """Compute the learning rate for each param group. Must be overridden."""
        raise NotImplementedError
|
|
    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        # Push the freshly computed per-group learning rates into the optimizer.
        for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
            param_group['lr'] = lr
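
# A hedged sketch of the subclass contract: get_lr() returns one learning
# rate per param group, and step() writes those values into the optimizer.
# HalvingSchedule is an illustrative toy subclass, not part of this module.
#
#     class HalvingSchedule(_LRScheduler):
#         def get_lr(self):
#             # Halve every base LR once per epoch.
#             return [base_lr * 0.5 ** self.last_epoch for base_lr in self.base_lrs]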
|
|
class CosineSchedule(_LRScheduler):
    """Cosine decay over ``K`` epochs: the learning rate falls from
    ``base_lr`` to nearly zero while staying strictly positive.
    """

    def __init__(self, optimizer, K):
        self.K = K
        super().__init__(optimizer, -1)
|
|
    def cosine(self, base_lr):
        # lr_t = base_lr * cos(99*pi*t / (200*(K-1))): decays from base_lr at
        # t = 0 to base_lr * cos(0.495*pi) ~ 0.016 * base_lr at t = K-1.
        # K == 1 is special-cased (as K = 2) to avoid a division by zero.
        K = 2 if self.K == 1 else self.K
        return base_lr * math.cos((99 * math.pi * self.last_epoch) / (200 * (K - 1)))
|
|
| def get_lr(self): |
| return [self.cosine(base_lr) for base_lr in self.base_lrs] |
| |
| def get_last_lr(self): |
| return self.get_lr() |
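
# A hedged usage sketch: anneal over K epochs, stepping once per epoch.
# `model` and `loader` are illustrative names; torch.optim.SGD is the real
# PyTorch API.
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = CosineSchedule(optimizer, K=90)
#     for epoch in range(90):
#         for batch in loader:
#             ...                      # forward/backward/optimizer.step()
#         scheduler.step()             # decay the LR for the next epoch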
|
|
class CosineAnnealingWarmUp(_LRScheduler):
    """Linear warmup for ``warmup_length`` epochs, then cosine annealing.

    ``T_max`` counts epochs from the start of training (warmup included), so
    the cosine phase evaluates ``cos(pi * last_epoch / T_max)``.
    """

    def __init__(self, optimizer, warmup_length, T_max=0, last_epoch=-1):
        self.warmup_length = warmup_length
        self.T_max = T_max
        super().__init__(optimizer, last_epoch)
|
|
    def cosine_lr(self, base_lr):
        # Standard cosine annealing: base_lr * 0.5 * (1 + cos(pi * t / T_max)).
        return base_lr * 0.5 * (1 + math.cos(math.pi * self.last_epoch / self.T_max))
|
|
    def warmup_lr(self, base_lr):
        # Linear ramp from base_lr / warmup_length up to base_lr.
        return base_lr * (self.last_epoch + 1) / self.warmup_length
|
|
| def get_lr(self): |
| if self.last_epoch < self.warmup_length: |
| return [self.warmup_lr(base_lr) for base_lr in self.base_lrs] |
| else: |
| return [self.cosine_lr(base_lr) for base_lr in self.base_lrs] |
| |
    def get_last_lr(self):
        # T_max must be positive once training leaves the warmup phase,
        # otherwise cosine_lr would divide by zero.
        assert self.T_max > 0, 'CosineAnnealingWarmUp requires T_max > 0; check your code'
        return self.get_lr()
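
# A hedged usage sketch: 5 warmup epochs, then cosine annealing until epoch
# T_max. `model` is an illustrative name; torch.optim.Adam is the real
# PyTorch API.
#
#     optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
#     scheduler = CosineAnnealingWarmUp(optimizer, warmup_length=5, T_max=100)
#     for epoch in range(100):
#         ...                          # one epoch of training
#         scheduler.step()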
|
|
class PatienceSchedule(_LRScheduler):
    """Divide the learning rate by ``factor`` when the monitored loss has not
    improved for ``patience`` consecutive calls to :meth:`step`, in the spirit
    of ``torch.optim.lr_scheduler.ReduceLROnPlateau``.
    """

    def __init__(self, optimizer, patience, factor):
        self.factor = factor
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0
        super().__init__(optimizer, -1)
|
|
    def step(self, current_loss=None, **kwargs):
        # **kwargs absorbs the epoch argument passed by the base class's
        # __init__; without a loss there is nothing to do.
        if current_loss is None:
            return
        if current_loss < self.best_loss:
            # Improvement: record the new best and reset the patience counter.
            self.best_loss = current_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                for param_group in self.optimizer.param_groups:
                    param_group['lr'] /= self.factor
                print(f"Reducing learning rate to {self.optimizer.param_groups[0]['lr']:.5f}")
                self.counter = 0
|
|
    def get_last_lr(self):
        # Return one LR per param group, matching the other schedulers.
        return [group['lr'] for group in self.optimizer.param_groups]
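
# A hedged usage sketch: call step() with the validation loss after each
# epoch; the LR is divided by `factor` once the loss has not improved for
# `patience` consecutive epochs. `model`, `evaluate`, and `num_epochs` are
# illustrative names, not defined in this module.
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = PatienceSchedule(optimizer, patience=3, factor=10)
#     for epoch in range(num_epochs):
#         ...                          # train for one epoch
#         val_loss = evaluate(model)
#         scheduler.step(val_loss)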