virtex-redcaps/virtex/optim/lr_scheduler.py
import bisect
import math
from typing import List

from torch.optim import Optimizer
from torch.optim.lr_scheduler import LambdaLR


class LinearWarmupNoDecayLR(LambdaLR):
r"""
A learning rate scheduler which linearly increases learning rate from 0
LR, and further keeps it constant throughout training.
Parameters
----------
optimizer: torch.optim.Optimizer
Wrapper optimizer.
total_steps: int
Total epochs (or iterations) for training.
warmup_steps: int
Number of first few steps to do linear warmup.
last_epoch: int, optional (default = -1)
The index of last step (epoch or iteration). We named it ``last_epoch``
instead of ``last_step`` to keep the naming consistent with other LR
schedulers in PyTorch.
"""
def __init__(
self,
optimizer: Optimizer,
total_steps: int,
warmup_steps: int,
last_epoch: int = -1,
):
assert (
warmup_steps < total_steps
), "Warmup steps should be less than total steps."
self.tsteps = total_steps
self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

def _lr_multiplier(self, step: int) -> float:
multiplier = step / float(max(1, self.wsteps)) if step < self.wsteps else 1
return max(0, multiplier)
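
# A minimal usage sketch for LinearWarmupNoDecayLR (not part of the original
# module; the optimizer, model, and step counts below are illustrative
# assumptions). With warmup_steps=100, the multiplier grows linearly from 0 at
# step 0 to 1 at step 100 and stays at 1 afterwards, so the base LR is reached
# at the end of warmup and then held constant:
#
#   optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
#   scheduler = LinearWarmupNoDecayLR(optimizer, total_steps=1000, warmup_steps=100)
#   for step in range(1000):
#       ...  # forward, backward, optimizer.step()
#       scheduler.step()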


class LinearWarmupMultiStepLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's initial LR over ``warmup_steps``, and then decreases
    it by a factor of ``gamma`` each time the step count reaches one of the
    milestones.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total number of training steps (epochs or iterations).
    warmup_steps: int
        Number of initial steps over which to linearly warm up the LR.
    milestones: List[int]
        List of step indices (epochs or iterations depending on context) at
        which to decay the LR. Must be increasing.
    gamma: float, optional (default = 0.1)
        Multiplicative factor of learning rate decay.
    last_epoch: int, optional (default = -1)
        The index of the last step (epoch or iteration). We named it
        ``last_epoch`` instead of ``last_step`` to keep the naming consistent
        with other LR schedulers in PyTorch.
    """

def __init__(
self,
optimizer: Optimizer,
total_steps: int,
warmup_steps: int,
milestones: List[int],
gamma: float = 0.1,
last_epoch: int = -1,
):
self.wsteps = warmup_steps
self.milestones = milestones
self.gamma = gamma
        # Keep track of the number of milestones encountered.
self.milestones_so_far = 0
# Common sanity checks.
assert milestones == sorted(milestones), "milestones must be increasing"
assert milestones[0] > warmup_steps, "first milestone must be after warmup"
assert (
milestones[-1] < total_steps
), "last milestone must be less than total steps"
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

def _lr_multiplier(self, step: int) -> float:
if step < self.wsteps:
# Linear warmup.
multiplier = step / float(max(1, self.wsteps))
else:
# Step decay based on milestones.
multiplier = self.gamma ** bisect.bisect_right(self.milestones, step)
# Avoid negative learning rate.
return max(0, multiplier)
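
# A worked example for LinearWarmupMultiStepLR (illustrative numbers, not from
# the original module): with warmup_steps=100, milestones=[300, 600], and
# gamma=0.1, the multiplier is step / 100 during warmup, 1.0 for steps in
# [100, 300), 0.1 for steps in [300, 600), and 0.01 from step 600 onwards,
# because bisect_right counts how many milestones the current step has passed:
#
#   scheduler = LinearWarmupMultiStepLR(
#       optimizer, total_steps=1000, warmup_steps=100,
#       milestones=[300, 600], gamma=0.1,
#   )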


class LinearWarmupLinearDecayLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's initial LR over ``warmup_steps``, and then decreases
    it linearly to zero by the end of training.

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total number of training steps (epochs or iterations).
    warmup_steps: int
        Number of initial steps over which to linearly warm up the LR.
    last_epoch: int, optional (default = -1)
        The index of the last step (epoch or iteration). We named it
        ``last_epoch`` instead of ``last_step`` to keep the naming consistent
        with other LR schedulers in PyTorch.
    """

def __init__(
self,
optimizer: Optimizer,
total_steps: int,
warmup_steps: int,
last_epoch: int = -1,
):
assert (
warmup_steps < total_steps
), "Warmup steps should be less than total steps."
self.tsteps = total_steps
self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

def _lr_multiplier(self, step: int) -> float:
if step < self.wsteps:
# Linear warmup.
multiplier = step / float(max(1, self.wsteps))
else:
# Linear decay.
multiplier = (self.tsteps - step) / (self.tsteps - self.wsteps)
# Avoid negative learning rate.
return max(0, multiplier)
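
# A worked example for LinearWarmupLinearDecayLR (illustrative numbers, not
# from the original module): with total_steps=1000 and warmup_steps=100, the
# multiplier rises from 0 to 1 over the first 100 steps, then falls as
# (1000 - step) / 900, reaching 0 at step 1000. The max(0, ...) clamp keeps
# the LR non-negative if stepping continues past total_steps:
#
#   scheduler = LinearWarmupLinearDecayLR(optimizer, total_steps=1000, warmup_steps=100)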


class LinearWarmupCosineAnnealingLR(LambdaLR):
    r"""
    A learning rate scheduler which linearly increases the learning rate from
    0 to the optimizer's initial LR over ``warmup_steps``, and then decreases
    it to zero by cosine decay. After linear warmup, the LR decays as:

    .. math::

        \eta_t = \eta_{max} \cos^2\left(\frac{T_{cur} - T_{warm}}{T_{max} - T_{warm}} \cdot \frac{\pi}{2}\right)

    Parameters
    ----------
    optimizer: torch.optim.Optimizer
        Wrapped optimizer.
    total_steps: int
        Total number of training steps (epochs or iterations).
    warmup_steps: int
        Number of initial steps over which to linearly warm up the LR.
    last_epoch: int, optional (default = -1)
        The index of the last step (epoch or iteration). We named it
        ``last_epoch`` instead of ``last_step`` to keep the naming consistent
        with other LR schedulers in PyTorch.
    """

def __init__(
self,
optimizer: Optimizer,
total_steps: int,
warmup_steps: int,
last_epoch: int = -1,
):
assert (
warmup_steps < total_steps
), "Warmup steps should be less than total steps."
self.tsteps = total_steps
self.wsteps = warmup_steps
        super().__init__(optimizer, self._lr_multiplier, last_epoch)

def _lr_multiplier(self, step: int) -> float:
if step < self.wsteps:
# Linear warmup.
multiplier = step / float(max(1, self.wsteps))
else:
# Cosine annealing decay.
cos_factor = (step - self.wsteps) / (self.tsteps - self.wsteps)
multiplier = math.cos(cos_factor * (math.pi / 2)) ** 2
# Avoid negative learning rate.
return max(0, multiplier)
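

if __name__ == "__main__":
    # A small self-contained sketch (not part of the original module) that
    # prints how the warmup + cosine-squared schedule evolves. The dummy
    # parameter, SGD optimizer, and step counts are illustrative assumptions.
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.SGD(params, lr=0.1)
    scheduler = LinearWarmupCosineAnnealingLR(
        optimizer, total_steps=100, warmup_steps=10
    )
    for step in range(100):
        optimizer.step()
        scheduler.step()
        if step % 10 == 0:
            # LR rises to 0.1 by step 10, then decays toward 0 by step 100.
            print(f"step {step:3d}  lr {scheduler.get_last_lr()[0]:.4f}")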