File size: 1,327 Bytes
1d5604f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
#!/usr/bin/env python3
# coding=utf-8

from utility.schedule.linear_lr import LinearLr


def multi_scheduler_wrapper(optimizer, args, steps_per_epoch):
    """Create one LinearLr per optimizer param group and bundle them.

    The param-group layout this function relies on (read off the indexing
    below) is:
    - two consecutive runs of ``(len(param_groups) - 2) // 2`` encoder groups;
    - then exactly two trailing groups.

    Within each encoder run, the group at position ``depth`` gets a peak
    rate of ``args.encoder_learning_rate * args.layerwise_lr_decay ** depth``.
    Both trailing groups get ``args.decoder_learning_rate``.

    Every scheduler spans ``args.epochs * steps_per_epoch`` steps. Each also
    receives ``False`` and ``args.lr_decay_multiplier`` as its remaining
    positional arguments.

    NOTE(review): the two encoder runs presumably split each layer's
    parameters into two groups (e.g. with / without weight decay). Confirm
    this against the optimizer construction.

    NOTE(review): if ``len(param_groups) - 2`` is odd, the group at index
    ``2 * n_layers`` gets no scheduler.

    Args:
        optimizer: object exposing an indexable, sized ``param_groups``.
        args: config namespace read for the learning-rate settings above.
        steps_per_epoch: number of scheduler steps per epoch.

    Returns:
        A MultiScheduler wrapping the schedulers in param-group order.
    """
    groups = optimizer.param_groups
    total_steps = args.epochs * steps_per_epoch
    n_layers = (len(groups) - 2) // 2

    def build(group, peak_lr):
        # The positional ``False`` is forwarded unchanged; its meaning is
        # defined by LinearLr, which is not visible here.
        return LinearLr(group, peak_lr, total_steps, False, args.lr_decay_multiplier)

    schedulers = []
    # Two passes over the same depth range. The first covers groups
    # [0, n_layers) and the second covers [n_layers, 2 * n_layers). Both
    # apply the same per-depth decay.
    for start in (0, n_layers):
        for depth in range(n_layers):
            encoder_lr = args.encoder_learning_rate * (args.layerwise_lr_decay ** depth)
            schedulers.append(build(groups[start + depth], encoder_lr))

    # The final two param groups both use the decoder learning rate.
    for idx in (-2, -1):
        schedulers.append(build(groups[idx], args.decoder_learning_rate))

    return MultiScheduler(schedulers)


class MultiScheduler:
    """Drive several schedulers together.

    Each wrapped object must be callable with one positional argument and
    must expose a zero-argument ``lr()`` method.
    """

    def __init__(self, schedulers):
        """Store ``schedulers`` as given (the sequence is not copied)."""
        self.schedulers = schedulers

    def __call__(self, epoch):
        """Forward ``epoch`` to every wrapped scheduler, in order.

        NOTE(review): the value is passed through as-is. Whether callers
        supply an epoch or a step index depends on the caller, so confirm
        at the call site.
        """
        for scheduler in self.schedulers:
            scheduler(epoch)

    def lr(self) -> list:
        """Return each wrapped scheduler's current ``lr()``, in order.

        The result is a list with one entry per scheduler. The previous
        ``-> float`` annotation was incorrect.
        """
        return [scheduler.lr() for scheduler in self.schedulers]