# File-page header captured with this source (kept as a comment so the module parses):
# author: zhangzhi | commit a476bbf "init commit" (verified) | size 6.18 kB
import math
from torch.optim.lr_scheduler import _LRScheduler, CosineAnnealingLR
class ConstantLRScheduler(_LRScheduler):
    """Scheduler that sets every parameter group to one fixed learning rate."""

    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 0.0,
    ):
        """
        This is an implementation of constant learning rate scheduler.

        On every step, each parameter group of ``optimizer`` is set to
        ``init_lr``, overriding the learning rate the optimizer was built with.

        Args:
            optimizer: Optimizer
            last_epoch: The index of last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update
                (deprecated by PyTorch since 2.2). Default: ``False``
            init_lr: Learning rate applied to every parameter group. Note that
                the default of ``0.0`` sets every group's learning rate to zero
                unless overridden.
        """
        self.init_lr = init_lr
        # Only forward ``verbose`` when it was explicitly enabled: PyTorch 2.2+
        # deprecates the argument and warns whenever it is passed, even as
        # ``False``. Omitting it is equivalent to ``False`` on older releases.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return the scheduler state: every instance attribute except the optimizer."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return ``init_lr`` once per parameter group.

        Raises:
            RuntimeError: if called directly instead of from ``step()``;
                use ``get_last_lr()`` to read the current learning rates.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )
        return [self.init_lr] * len(self.optimizer.param_groups)
class CosineAnnealingLRScheduler(_LRScheduler):
    """Linear warmup followed by a single cosine decay that then holds at ``final_lr``."""

    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 4.0e-5,
        max_lr: float = 4e-4,
        final_lr: float = 4e-5,
        warmup_steps: int = 2000,
        cosine_steps: int = 10000,
    ):
        """
        This is an implementation of cosine annealing learning rate scheduler.

        Schedule, for step ``t``:
            * ``t <= warmup_steps``: linear ramp from ``init_lr`` to ``max_lr``
              (skipped entirely when ``warmup_steps <= 0``).
            * next ``cosine_steps`` steps: half-cosine from ``max_lr`` to ``final_lr``.
            * afterwards: constant ``final_lr``.

        Args:
            optimizer: Optimizer
            last_epoch: The index of last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update
                (deprecated by PyTorch since 2.2). Default: ``False``
            init_lr: Initial learning rate
            max_lr: Maximum learning rate after warmup
            final_lr: Final learning rate after decay
            warmup_steps: Number of steps for warmup
            cosine_steps: Number of steps for cosine annealing
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.cosine_steps = cosine_steps
        # Only forward ``verbose`` when it was explicitly enabled: PyTorch 2.2+
        # deprecates the argument and warns whenever it is passed, even as
        # ``False``. Omitting it is equivalent to ``False`` on older releases.
        if verbose:
            super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch, verbose)
        else:
            super(CosineAnnealingLRScheduler, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return the scheduler state: every instance attribute except the optimizer."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the scheduled learning rate for the current step, once per group.

        Raises:
            RuntimeError: if called directly instead of from ``step()``;
                use ``get_last_lr()`` to read the current learning rates.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )
        step_no = self.last_epoch
        # Guarding on warmup_steps > 0 avoids 0/0 at step 0 when warmup is disabled.
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (
                self.max_lr - self.init_lr
            )
        elif self.cosine_steps > 0:
            # Clamp so the LR holds at final_lr once the cosine phase is over,
            # instead of following the cosine back up towards max_lr.
            elapsed = min(step_no - self.warmup_steps, self.cosine_steps)
            lr = self.final_lr + 0.5 * (self.max_lr - self.final_lr) * (
                1 + math.cos(math.pi * elapsed / self.cosine_steps)
            )
        else:
            # No cosine phase configured: go straight to the final rate.
            lr = self.final_lr
        return [lr] * len(self.optimizer.param_groups)
class Esm2LRScheduler(_LRScheduler):
    """Warmup, plateau at ``max_lr``, then linear decay to ``final_lr`` (ESM2-style)."""

    def __init__(
        self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        init_lr: float = 4e-5,
        max_lr: float = 4e-4,
        final_lr: float = 4e-5,
        warmup_steps: int = 2000,
        start_decay_after_n_steps: int = 500000,
        end_decay_after_n_steps: int = 5000000,
        on_use: bool = True,
    ):
        """
        An implementation of ESM2's learning rate scheduler.

        Schedule, for step ``t``:
            * ``t <= warmup_steps``: linear ramp from ``init_lr`` to ``max_lr``
              (skipped entirely when ``warmup_steps <= 0``).
            * ``t <= start_decay_after_n_steps``: constant ``max_lr``.
            * ``t <= end_decay_after_n_steps``: linear decay from ``max_lr`` to ``final_lr``.
            * afterwards: constant ``final_lr``.

        Args:
            optimizer: Optimizer
            last_epoch: The index of last epoch. Default: -1
            verbose: If ``True``, prints a message to stdout for each update
                (deprecated by PyTorch since 2.2). Default: ``False``
            init_lr: Initial learning rate
            max_lr: Maximum learning rate after warmup
            final_lr: Final learning rate after decay
            warmup_steps: Number of steps for warmup
            start_decay_after_n_steps: Start decay after this number of steps
            end_decay_after_n_steps: End decay after this number of steps
            on_use: Whether to use this scheduler. If ``False``, every parameter
                group keeps the learning rate it had when the scheduler was
                created (``self.base_lrs``); ``init_lr``, ``max_lr`` and
                ``final_lr`` are ignored. Default: ``True``
        """
        self.init_lr = init_lr
        self.max_lr = max_lr
        self.final_lr = final_lr
        self.warmup_steps = warmup_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.end_decay_after_n_steps = end_decay_after_n_steps
        self.on_use = on_use
        # Only forward ``verbose`` when it was explicitly enabled: PyTorch 2.2+
        # deprecates the argument and warns whenever it is passed, even as
        # ``False``. Omitting it is equivalent to ``False`` on older releases.
        if verbose:
            super(Esm2LRScheduler, self).__init__(optimizer, last_epoch, verbose)
        else:
            super(Esm2LRScheduler, self).__init__(optimizer, last_epoch)

    def state_dict(self):
        """Return the scheduler state: every instance attribute except the optimizer."""
        return {k: v for k, v in self.__dict__.items() if k != "optimizer"}

    def load_state_dict(self, state_dict):
        """Restore attributes previously produced by :meth:`state_dict`."""
        self.__dict__.update(state_dict)

    def get_lr(self):
        """Return the scheduled learning rate(s) for the current step.

        Raises:
            RuntimeError: if called directly instead of from ``step()``;
                use ``get_last_lr()`` to read the current learning rates.
        """
        if not self._get_lr_called_within_step:
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )
        if not self.on_use:
            # Disabled: keep each group's original learning rate.
            return list(self.base_lrs)
        step_no = self.last_epoch
        # Guarding on warmup_steps > 0 avoids 0/0 at step 0 when warmup is disabled.
        if self.warmup_steps > 0 and step_no <= self.warmup_steps:
            lr = self.init_lr + step_no / self.warmup_steps * (
                self.max_lr - self.init_lr
            )
        elif step_no <= self.start_decay_after_n_steps:
            lr = self.max_lr
        elif step_no <= self.end_decay_after_n_steps:
            # Only reachable when end > start, so the denominator is positive.
            portion = (step_no - self.start_decay_after_n_steps) / (
                self.end_decay_after_n_steps - self.start_decay_after_n_steps
            )
            lr = self.max_lr - portion * (self.max_lr - self.final_lr)
        else:
            lr = self.final_lr
        return [lr] * len(self.optimizer.param_groups)