import torch


class AlphaFoldLRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """ Implements the learning rate schedule defined in the AlphaFold 2
        supplement. A linear warmup is followed by a plateau at the maximum
        learning rate and then exponential decay. Note that the initial
        learning rate of the optimizer in question is ignored; use this
        class' base_lr parameter to specify the starting point of the
        warmup.
    """
    def __init__(self,
        optimizer,
        last_epoch: int = -1,
        verbose: bool = False,
        base_lr: float = 0.,
        max_lr: float = 0.001,
        warmup_no_steps: int = 1000,
        start_decay_after_n_steps: int = 50000,
        decay_every_n_steps: int = 50000,
        decay_factor: float = 0.95,
    ):
        step_counts = {
            "warmup_no_steps": warmup_no_steps,
            "start_decay_after_n_steps": start_decay_after_n_steps,
        }

        for k, v in step_counts.items():
            if(v < 0):
                raise ValueError(f"{k} must be nonnegative")

        if(warmup_no_steps > start_decay_after_n_steps):
            raise ValueError(
                "warmup_no_steps must not exceed start_decay_after_n_steps"
            )

        self.optimizer = optimizer
        self.last_epoch = last_epoch
        self.verbose = verbose
        self.base_lr = base_lr
        self.max_lr = max_lr
        self.warmup_no_steps = warmup_no_steps
        self.start_decay_after_n_steps = start_decay_after_n_steps
        self.decay_every_n_steps = decay_every_n_steps
        self.decay_factor = decay_factor

        super(AlphaFoldLRScheduler, self).__init__(
            optimizer,
            last_epoch=last_epoch,
            verbose=verbose,
        )

    def state_dict(self):
        state_dict = {
            k: v for k, v in self.__dict__.items() if k not in ["optimizer"]
        }
        return state_dict

    def load_state_dict(self, state_dict):
        self.__dict__.update(state_dict)

    def get_lr(self):
        if(not self._get_lr_called_within_step):
            raise RuntimeError(
                "To get the last learning rate computed by the scheduler, use "
                "get_last_lr()"
            )

        step_no = self.last_epoch

        if(step_no <= self.warmup_no_steps):
            # Linear warmup from base_lr toward max_lr
            lr = self.base_lr + (step_no / self.warmup_no_steps) * self.max_lr
        elif(step_no > self.start_decay_after_n_steps):
            # Exponential decay, stepped once every decay_every_n_steps steps
            steps_since_decay = step_no - self.start_decay_after_n_steps
            exp = (steps_since_decay // self.decay_every_n_steps) + 1
            lr = self.max_lr * (self.decay_factor ** exp)
        else:  # plateau at max_lr between warmup and decay
            lr = self.max_lr

        # Same LR is applied to every parameter group
        return [lr for group in self.optimizer.param_groups]
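

# Usage sketch (illustrative only, not part of the original module): drives the
# scheduler once per optimizer step, as a step-based schedule expects. The tiny
# model, zero-gradient loop, and printed steps below are hypothetical
# placeholders standing in for a real training loop.
if __name__ == "__main__":
    model = torch.nn.Linear(8, 8)  # stand-in for a real model
    # The optimizer's initial lr is ignored by AlphaFoldLRScheduler; base_lr
    # controls where the warmup starts.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.0)
    scheduler = AlphaFoldLRScheduler(
        optimizer,
        max_lr=0.001,
        warmup_no_steps=1000,
        start_decay_after_n_steps=50000,
        decay_every_n_steps=50000,
        decay_factor=0.95,
    )

    for step in range(5):
        # ... forward pass, loss.backward(), etc. would go here ...
        optimizer.step()
        scheduler.step()  # advance the schedule once per training step
        print(step, scheduler.get_last_lr())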