import math

# LrStepTracker
class LrStepTracker:
    """
    ----------
    Author: Ryan Marshall
    Modified: Damon Gwinn
    ----------
    Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).

    Learn rate for each step (batch) given the warmup steps is:
        lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]

    This is from Attention Is All You Need (https://arxiv.org/abs/1706.03762)
    ----------
    """

    def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
        # Store values
        self.warmup_steps = warmup_steps
        self.model_dim = model_dim
        self.init_steps = init_steps

        # Precompute the constant factors of the formula
        self.invsqrt_dim = 1 / math.sqrt(model_dim)
        self.invsqrt_warmup = 1 / (warmup_steps * math.sqrt(warmup_steps))

    # step
    def step(self, step):
        """
        ----------
        Author: Ryan Marshall
        Modified: Damon Gwinn
        ----------
        Method to pass to LambdaLR. Offsets the given step by init_steps and
        computes the new learn rate.
        ----------
        """
        step += self.init_steps

        # Linear warmup, then inverse square-root decay
        if step <= self.warmup_steps:
            return self.invsqrt_dim * self.invsqrt_warmup * step
        else:
            invsqrt_step = 1 / math.sqrt(step)
            return self.invsqrt_dim * invsqrt_step


# get_lr
def get_lr(optimizer):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Hack to get the current learn rate of the model
    ----------
    """
    for param_group in optimizer.param_groups:
        return param_group['lr']
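
# Example usage (a minimal sketch, not part of the original module): wiring the
# tracker into torch.optim.lr_scheduler.LambdaLR. The Linear model and loop below
# are placeholders; the base lr is set to 1.0 so that LambdaLR's multiplier (the
# value returned by LrStepTracker.step) becomes the effective learn rate.
if __name__ == "__main__":
    import torch

    model      = torch.nn.Linear(512, 512)  # placeholder model
    lr_stepper = LrStepTracker(model_dim=512, warmup_steps=4000)
    opt        = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler  = torch.optim.lr_scheduler.LambdaLR(opt, lr_stepper.step)

    for batch in range(5):
        # ... forward / backward / opt.step() would go here ...
        scheduler.step()                  # advance the step count (once per batch)
        print(batch + 1, get_lr(opt))     # current learn rate held by the optimizer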