import math

# LrStepTracker
class LrStepTracker:
    """
    ----------
    Author: Ryan Marshall
    Modified: Damon Gwinn
    ----------
    Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).

    Learn rate for each step (batch) given the warmup steps is:
        lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]

    This is from "Attention Is All You Need" (https://arxiv.org/abs/1706.03762).
    ----------
    """
    def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
        # Store Values
        self.warmup_steps = warmup_steps
        self.model_dim = model_dim
        self.init_steps = init_steps

        # Precompute the constant factors used by step()
        self.invsqrt_dim = (1 / math.sqrt(model_dim))
        self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))

    # step
    def step(self, step):
        """
        ----------
        Author: Ryan Marshall
        Modified: Damon Gwinn
        ----------
        Method to pass to LambdaLR. Increments the step and computes the new learn rate.
        ----------
        """

        step += self.init_steps
        if(step <= self.warmup_steps):
            # Linear warmup: lr grows proportionally to the step count
            return self.invsqrt_dim * self.invsqrt_warmup * step
        else:
            # After warmup: lr decays with the inverse square root of the step
            invsqrt_step = (1 / math.sqrt(step))
            return self.invsqrt_dim * invsqrt_step
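
# Usage sketch (not from the original file): one way to wire LrStepTracker into
# torch.optim.lr_scheduler.LambdaLR. The `model` and `loader` below are placeholders;
# the optimizer's base lr is set to 1.0 because LambdaLR multiplies it by the value
# returned by step(), and the Adam betas/eps follow the Transformer paper's setting.
#
#   import torch
#
#   opt = torch.optim.Adam(model.parameters(), lr=1.0, betas=(0.9, 0.98), eps=1e-9)
#   lr_tracker = LrStepTracker(model_dim=512, warmup_steps=4000)
#   lr_scheduler = torch.optim.lr_scheduler.LambdaLR(opt, lr_tracker.step)
#
#   for batch in loader:
#       ...                   # forward / backward
#       opt.step()
#       lr_scheduler.step()   # advance once per batch so the warmup is per step
#       print(get_lr(opt))    # read back the current learn rate for logging
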
# get_lr
def get_lr(optimizer):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Hack to get the current learn rate of the model
    ----------
    """

    for param_group in optimizer.param_groups:
        return param_group['lr']
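
# Minimal sanity check (a sketch, not part of the original module): prints the schedule
# at a few steps to show the linear warmup up to warmup_steps and the inverse-square-root
# decay afterwards. Runs without torch.
if __name__ == "__main__":
    tracker = LrStepTracker(model_dim=512, warmup_steps=4000)
    for s in (1, 1000, 4000, 8000, 16000):
        print("step %6d -> lr %.6e" % (s, tracker.step(s)))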