import torch
from torch.optim import Optimizer
from typing import List, Optional, Tuple, Union


def calc_lr(step, dim_embed, warmup_steps):
    # Noam learning-rate factor from "Attention Is All You Need": scale by
    # dim_embed ** -0.5, grow linearly for the first `warmup_steps` steps,
    # then decay proportionally to step ** -0.5.
    return dim_embed ** (-0.5) * min(step ** (-0.5), step * warmup_steps ** (-1.5))
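

# Worked example (illustrative; not part of the original code): with
# dim_embed=512 and warmup_steps=4000 the two terms inside min() coincide at
# step == warmup_steps, so the schedule peaks there at
# 512 ** -0.5 * 4000 ** -0.5, roughly 7.0e-4, and decays afterwards.
def _calc_lr_peak_example() -> None:
    peak = calc_lr(4000, 512, 4000)
    assert abs(peak - 512 ** (-0.5) * 4000 ** (-0.5)) < 1e-12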


class NoamScheduler(torch.optim.lr_scheduler._LRScheduler):
    def __init__(
        self,
        base_lr: float,
        optimizer: torch.optim.Optimizer,
        dim_embed: int,
        warmup_steps: int,
        last_epoch: int = -1,
        verbose: bool = False,
    ) -> None:
        self.dim_embed = dim_embed
        self.base_lr = base_lr
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch, verbose)

    def get_lr(self) -> List[float]:
        # One learning rate per parameter group, all driven by calc_lr().
        lr = self.base_lr * calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups

    def set_step(self, step: int):
        self._step_count = step
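

# Hedged usage sketch (not part of the original code): one way to drive
# NoamScheduler, assuming a toy model with dim_embed=512 and a plain Adam
# optimizer; scheduler.step() is called once per optimizer step.
def _noam_usage_example() -> None:
    model = torch.nn.Linear(512, 512)
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = NoamScheduler(
        base_lr=1.0,
        optimizer=optimizer,
        dim_embed=512,
        warmup_steps=4000,
    )
    for _ in range(10):
        optimizer.step()
        scheduler.step()
    print(scheduler.get_last_lr())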


class LRScheduler(object):
    """
    Base-class for learning rate schedulers where the learning-rate depends on both the
    batch and the epoch.
    """

    def __init__(self, optimizer: Optimizer, verbose: bool = False):
        if not isinstance(optimizer, Optimizer):
            raise TypeError("{} is not an Optimizer".format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.verbose = verbose

        for group in optimizer.param_groups:
            group.setdefault("base_lr", group["lr"])

        self.base_lrs = [group["base_lr"] for group in optimizer.param_groups]

        self.epoch = 0
        self.batch = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which
        is not the optimizer.
        """
        return {
            "base_lrs": self.base_lrs,
            "epoch": self.epoch,
            "batch": self.batch,
        }

    def load_state_dict(self, state_dict):
        """Loads the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_lr(self) -> List[float]:
        """Return the last learning rate computed by the current scheduler as a list of float."""
        return self._last_lr

    def get_lr(self):
        # Compute a list of learning rates (one per parameter group) from
        # self.epoch, self.batch and self.base_lrs; subclasses must override
        # this, e.g.:
        #   return [formula(self.batch, self.epoch, lr) for lr in self.base_lrs]
        raise NotImplementedError

    def step_batch(self, batch: Optional[int] = None) -> None:
        # Advance (or explicitly set) the batch count, then recompute the
        # learning rates.  If `batch` is given it replaces the current count;
        # otherwise the count is incremented by one.
        if batch is not None:
            self.batch = batch
        else:
            self.batch = self.batch + 1
        self._set_lrs()

    def step_epoch(self, epoch: Optional[int] = None):
        # Advance (or explicitly set) the epoch count, then recompute the
        # learning rates.
        if epoch is not None:
            self.epoch = epoch
        else:
            self.epoch = self.epoch + 1
        self._set_lrs()

    def _set_lrs(self):
        values = self.get_lr()
        assert len(values) == len(self.optimizer.param_groups)

        for param_group, lr in zip(self.optimizer.param_groups, values):
            param_group["lr"] = lr
        self._last_lr = [group["lr"] for group in self.optimizer.param_groups]
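

# Minimal sketch of the subclassing contract (illustrative; not part of the
# original code): a subclass only needs to implement get_lr() in terms of
# self.batch, self.epoch and self.base_lrs, as Eden does below.
class _HalveAfterFirstEpoch(LRScheduler):
    def get_lr(self):
        scale = 1.0 if self.epoch < 1 else 0.5
        return [base_lr * scale for base_lr in self.base_lrs]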


class Eden(LRScheduler):
    """
    Eden scheduler.
    The basic formula (before warmup) is:
      lr = base_lr * (((batch**2 + lr_batches**2) / lr_batches**2) ** -0.25 *
                     (((epoch**2 + lr_epochs**2) / lr_epochs**2) ** -0.25)) * warmup
    where `warmup` increases linearly from 0.5 to 1 over `warmup_batches` batches
    and then stays constant at 1.

    E.g. suggest base_lr = 0.04 (passed to the optimizer) if used with ScaledAdam.

    Args:
        optimizer: the optimizer whose learning rates we change.
        lr_batches: the number of batches after which we start significantly
            decreasing the learning rate; suggest 5000.
        lr_epochs: the number of epochs after which we start significantly
            decreasing the learning rate; suggest 6 if you plan to do e.g.
            20 to 40 epochs, but you may need a smaller number if the dataset
            is huge and you will do few epochs.
        warmup_batches: the number of batches over which the warmup factor
            increases linearly from 0.5 to 1.
    """

    def __init__(
        self,
        optimizer: Optimizer,
        lr_batches: Union[int, float],
        lr_epochs: Union[int, float],
        warmup_batches: Union[int, float] = 500.0,
        verbose: bool = False,
    ):
        super(Eden, self).__init__(optimizer, verbose)
        self.lr_batches = lr_batches
        self.lr_epochs = lr_epochs
        self.warmup_batches = warmup_batches

    def get_lr(self):
        factor = (
            (self.batch**2 + self.lr_batches**2) / self.lr_batches**2
        ) ** -0.25 * (
            ((self.epoch**2 + self.lr_epochs**2) / self.lr_epochs**2) ** -0.25
        )
        warmup_factor = (
            1.0
            if self.batch >= self.warmup_batches
            else 0.5 + 0.5 * (self.batch / self.warmup_batches)
        )

        return [x * factor * warmup_factor for x in self.base_lrs]
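

# Hedged usage sketch (not part of the original code): the docstring above
# suggests base_lr = 0.04 with ScaledAdam; plain SGD is used here only to
# illustrate the stepping pattern, with step_batch() called once per batch
# and step_epoch() once per epoch.
def _eden_usage_example() -> None:
    model = torch.nn.Linear(16, 16)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.04)
    scheduler = Eden(optimizer, lr_batches=5000, lr_epochs=6)
    for _ in range(2):  # epochs
        for _ in range(100):  # batches per epoch
            optimizer.step()
            scheduler.step_batch()
        scheduler.step_epoch()
    print(scheduler.get_last_lr())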