import torch
import numpy as np


class Optimizer:
    """Thin wrapper around torch.optim with an optional warmup LR schedule."""

    def __init__(self, parameters, optimizer, lr, eps, lr_scheduler,
                 **kwargs):
        # Extra kwargs are accepted for config compatibility but unused.

        # Setup torch optimizer
        self.opt_type = optimizer
        self.init_lr = lr
        self.sch_type = lr_scheduler
        opt = getattr(torch.optim, optimizer)
        if lr_scheduler == 'warmup':
            # Noam-style schedule (Vaswani et al., 2017): the LR grows
            # linearly for `warmup_step` steps, then decays as step^-0.5.
            warmup_step = 4000.0
            init_lr = lr
            self.lr_scheduler = lambda step: init_lr * warmup_step ** 0.5 * \
                np.minimum((step + 1) * warmup_step ** -1.5, (step + 1) ** -0.5)
            # Base LR is a placeholder; pre_step() overwrites it every step.
            self.opt = opt(parameters, lr=1.0)
        else:
            self.lr_scheduler = None
            self.opt = opt(parameters, lr=lr, eps=eps)  # ToDo: 1e-8 better?

    def get_opt_state_dict(self):
        """Return the wrapped optimizer's state dict (for checkpointing)."""
        return self.opt.state_dict()

    def load_opt_state_dict(self, state_dict):
        """Restore the wrapped optimizer's state from a checkpoint."""
        self.opt.load_state_dict(state_dict)

    def pre_step(self, step):
        """Set the LR for this step, zero gradients, and return the LR."""
        if self.lr_scheduler is not None:
            cur_lr = self.lr_scheduler(step)
            for param_group in self.opt.param_groups:
                param_group['lr'] = cur_lr
        else:
            cur_lr = self.init_lr
        self.opt.zero_grad()
        return cur_lr

    def step(self):
        """Apply the accumulated gradients."""
        self.opt.step()

    def create_msg(self):
        """Return a one-line, human-readable summary of the optimizer setup."""
        return ['Optim.Info.| Algo. = {}\t| Lr = {}\t (schedule = {})'
                .format(self.opt_type, self.init_lr, self.sch_type)]
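
# A minimal usage sketch (an illustration, not part of the original file):
# driving the wrapper from a training loop. The model, loss, and `loader`
# below are hypothetical placeholders; only the Optimizer API above
# (pre_step / step) is taken from the class itself.
#
#     model = torch.nn.Linear(10, 2)
#     optim = Optimizer(model.parameters(), optimizer='Adam', lr=1e-3,
#                       eps=1e-8, lr_scheduler='warmup')
#     for global_step, (x, y) in enumerate(loader):
#         cur_lr = optim.pre_step(global_step)  # sets LR, zeroes grads
#         loss = torch.nn.functional.cross_entropy(model(x), y)
#         loss.backward()
#         optim.step()                          # apply the update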