import bisect
import math
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class LinearWarmupNoDecayLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's base LR over ``warmup_steps``, and then keeps it
    constant for the rest of training.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total epochs (or iterations) for training.
    warmup_steps: int
        Number of first few steps to do linear warmup.
    last_epoch: int, optional (default = -1)
        The index of last step (epoch or iteration). We named it ``last_epoch``
        instead of ``last_step`` to keep the naming consistent with other LR
        schedulers in PyTorch.
    """
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        last_epoch: int = -1,
    ):
        assert (
            warmup_steps < total_steps
        ), "Warmup steps should be less than total steps."

        self.tsteps = total_steps
        self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

    def _lr_multiplier(self, step: int) -> float:
        # Linear warmup, then a constant multiplier of 1.
        multiplier = step / float(max(1, self.wsteps)) if step < self.wsteps else 1
        # Avoid negative learning rate.
        return max(0, multiplier)
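

# A minimal usage sketch for ``LinearWarmupNoDecayLR`` (not part of the original
# module; the model, base LR, and step counts below are illustrative assumptions):
#
#     model = torch.nn.Linear(8, 2)
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = LinearWarmupNoDecayLR(optimizer, total_steps=1000, warmup_steps=100)
#     for step in range(1000):
#         ...  # forward / backward
#         optimizer.step()
#         scheduler.step()
#
# The LR ramps linearly from 0 to 0.1 over the first 100 steps, then stays at 0.1.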


class LinearWarmupMultiStepLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's base LR over ``warmup_steps``, and then decreases it
    by ``gamma`` each time the step count crosses one of the ``milestones``.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total epochs (or iterations) for training.
    warmup_steps: int
        Number of first few steps to do linear warmup.
    milestones: List[int]
        List of step indices (epochs or iterations depending on context). Must
        be increasing.
    gamma: float, optional (default = 0.1)
        Multiplicative factor of learning rate decay.
    last_epoch: int, optional (default = -1)
        The index of last step (epoch or iteration). We named it ``last_epoch``
        instead of ``last_step`` to keep the naming consistent with other LR
        schedulers in PyTorch.
    """
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        milestones: List[int],
        gamma: float = 0.1,
        last_epoch: int = -1,
    ):
        self.wsteps = warmup_steps
        self.milestones = milestones
        self.gamma = gamma

        # Keep a track of number of milestones encountered.
        self.milestones_so_far = 0

        # Common sanity checks.
        assert milestones == sorted(milestones), "milestones must be increasing"
        assert milestones[0] > warmup_steps, "first milestone must be after warmup"
        assert (
            milestones[-1] < total_steps
        ), "last milestone must be less than total steps"

        super().__init__(optimizer, self._lr_multiplier, last_epoch)

    def _lr_multiplier(self, step: int) -> float:
        if step < self.wsteps:
            # Linear warmup.
            multiplier = step / float(max(1, self.wsteps))
        else:
            # Step decay based on milestones.
            multiplier = self.gamma ** bisect.bisect_right(self.milestones, step)
        # Avoid negative learning rate.
        return max(0, multiplier)
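

# A minimal usage sketch for ``LinearWarmupMultiStepLR`` (illustrative values,
# not part of the original module):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = LinearWarmupMultiStepLR(
#         optimizer, total_steps=1000, warmup_steps=100,
#         milestones=[400, 700], gamma=0.1,
#     )
#
# After warmup the LR holds at 0.1, drops to 0.01 from step 400 onwards, and to
# 0.001 from step 700 onwards.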


class LinearWarmupLinearDecayLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's base LR over ``warmup_steps``, and then decays it
    linearly to zero by the end of training.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total epochs (or iterations) for training.
    warmup_steps: int
        Number of first few steps to do linear warmup.
    last_epoch: int, optional (default = -1)
        The index of last step (epoch or iteration). We named it ``last_epoch``
        instead of ``last_step`` to keep the naming consistent with other LR
        schedulers in PyTorch.
    """
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        last_epoch: int = -1,
    ):
        assert (
            warmup_steps < total_steps
        ), "Warmup steps should be less than total steps."

        self.tsteps = total_steps
        self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

    def _lr_multiplier(self, step: int) -> float:
        if step < self.wsteps:
            # Linear warmup.
            multiplier = step / float(max(1, self.wsteps))
        else:
            # Linear decay.
            multiplier = (self.tsteps - step) / (self.tsteps - self.wsteps)
        # Avoid negative learning rate.
        return max(0, multiplier)
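

# A minimal usage sketch for ``LinearWarmupLinearDecayLR`` (illustrative values,
# not part of the original module):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = LinearWarmupLinearDecayLR(optimizer, total_steps=1000, warmup_steps=100)
#
# The LR rises linearly from 0 to 0.1 over the first 100 steps, then decays
# linearly back to 0 at step 1000.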


class LinearWarmupCosineAnnealingLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's base LR over ``warmup_steps``, and then decays it to
    zero with a squared-cosine schedule. After linear warmup, the LR decays as:

    .. math::

        \eta_t = \eta_{max} \cos^2 \left(
            \frac{T_{cur} - T_{warm}}{T_{max} - T_{warm}} \cdot \frac{\pi}{2}
        \right)

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total epochs (or iterations) for training.
    warmup_steps: int
        Number of first few steps to do linear warmup.
    last_epoch: int, optional (default = -1)
        The index of last step (epoch or iteration). We named it ``last_epoch``
        instead of ``last_step`` to keep the naming consistent with other LR
        schedulers in PyTorch.
    """
    def __init__(
        self,
        optimizer: Optimizer,
        total_steps: int,
        warmup_steps: int,
        last_epoch: int = -1,
    ):
        assert (
            warmup_steps < total_steps
        ), "Warmup steps should be less than total steps."

        self.tsteps = total_steps
        self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

    def _lr_multiplier(self, step: int) -> float:
        if step < self.wsteps:
            # Linear warmup.
            multiplier = step / float(max(1, self.wsteps))
        else:
            # Cosine annealing decay.
            cos_factor = (step - self.wsteps) / (self.tsteps - self.wsteps)
            multiplier = math.cos(cos_factor * (math.pi / 2)) ** 2
        # Avoid negative learning rate.
        return max(0, multiplier)
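

# A minimal usage sketch for ``LinearWarmupCosineAnnealingLR`` (illustrative
# values, not part of the original module):
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#     scheduler = LinearWarmupCosineAnnealingLR(optimizer, total_steps=1000, warmup_steps=100)
#
# The LR rises linearly from 0 to 0.1 over the first 100 steps, then follows a
# squared-cosine curve down to 0 at step 1000.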