import math

# LrStepTracker
class LrStepTracker:
    """
    ----------
    Author: Ryan Marshall
    Modified: Damon Gwinn
    ----------
    Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).

    Learn rate for each step (batch) given the warmup steps is:
        lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]

    This is from Attention Is All You Need (https://arxiv.org/abs/1706.03762)
    ----------
    """

    def __init__(self, model_dim=512, warmup_steps=4000, init_steps=0):
        # Store Values
        self.warmup_steps = warmup_steps
        self.model_dim = model_dim
        self.init_steps = init_steps

        # Begin Calculations
        self.invsqrt_dim = (1 / math.sqrt(model_dim))
        self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))  # warmup_steps^(-1.5)

    # step
    def step(self, step):
        """
        ----------
        Author: Ryan Marshall
        Modified: Damon Gwinn
        ----------
        Method to pass to LambdaLR. Offsets the given step by init_steps and computes the learn rate for it.
        ----------
        """

        step += self.init_steps
        if(step <= self.warmup_steps):
            return self.invsqrt_dim * self.invsqrt_warmup * step
        else:
            invsqrt_step = (1 / math.sqrt(step))
            return self.invsqrt_dim * invsqrt_step
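
# ----------
# A quick worked check of the schedule above (values rounded; assumes the
# default model_dim=512 and warmup_steps=4000):
#   step = 1     -> (1/sqrt(512)) * 1 * 4000^(-1.5)              ~ 1.75e-7
#   step = 4000  -> both branches meet at (1/sqrt(512)) / sqrt(4000) ~ 6.99e-4
#   step = 16000 -> (1/sqrt(512)) / sqrt(16000)                  ~ 3.49e-4
# i.e. the learn rate ramps linearly to its peak at warmup_steps, then decays
# as 1/sqrt(step).
# ----------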

# get_lr
def get_lr(optimizer):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Hack to get the current learn rate from the optimizer (reads the first param group)
    ----------
    """

    for param_group in optimizer.param_groups:
        return param_group['lr']
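
# ----------
# Usage sketch: a minimal example, assuming a throwaway linear model and the
# default hyperparameters, of wiring LrStepTracker into
# torch.optim.lr_scheduler.LambdaLR and reading the learn rate back with get_lr.
# ----------
if __name__ == "__main__":
    import torch
    from torch.optim.lr_scheduler import LambdaLR

    model = torch.nn.Linear(512, 512)  # placeholder model, not the real network
    lr_tracker = LrStepTracker(model_dim=512, warmup_steps=4000, init_steps=0)

    # Base lr of 1.0 so the multiplicative factor returned by step() is the learn rate itself
    optimizer = torch.optim.Adam(model.parameters(), lr=1.0)
    scheduler = LambdaLR(optimizer, lr_tracker.step)

    for _ in range(5):
        optimizer.step()    # weight update for the current batch
        scheduler.step()    # advance to the next step's learn rate
        print(get_lr(optimizer))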