File size: 1,443 Bytes
3b40f46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from torch.optim import lr_scheduler
from models.backbone.skip import skip


def get_scheduler(optimizer, opt):
    """Return a learning-rate scheduler for ``optimizer`` selected by ``opt.lr_policy``.

    Supported policies and the ``opt`` attributes each one reads:

    * ``"linear"`` -- ``LambdaLR`` whose multiplier is
      ``1 - max(0, epoch + opt.epoch_count - opt.n_epochs) / (opt.n_epochs_decay + 1)``.
      It stays at 1.0 while ``epoch + opt.epoch_count <= opt.n_epochs``, then
      falls by ``1 / (opt.n_epochs_decay + 1)`` per scheduler step.
    * ``"step"`` -- ``StepLR`` with ``step_size=opt.lr_decay_iters`` and ``gamma=0.1``.
    * ``"plateau"`` -- ``ReduceLROnPlateau`` with ``mode="min"``, ``factor=0.2``,
      ``threshold=0.01`` and ``patience=5``.
    * ``"cosine"`` -- ``CosineAnnealingLR`` with ``T_max=opt.n_epochs`` and ``eta_min=0``.

    Args:
        optimizer: The optimizer whose learning rate is scheduled.
        opt: Options object exposing ``lr_policy`` plus the attributes listed
            above for the chosen policy.

    Returns:
        The constructed scheduler.

    Raises:
        NotImplementedError: If ``opt.lr_policy`` is not one of the policies above.
    """
    policy = opt.lr_policy
    if policy == "linear":

        def lambda_rule(epoch):
            # Zero until epoch + epoch_count exceeds n_epochs, then grows by one
            # per step, so the multiplier holds at 1.0 and then decays linearly.
            decay_steps = max(0, epoch + opt.epoch_count - opt.n_epochs)
            return 1.0 - decay_steps / float(opt.n_epochs_decay + 1)

        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    if policy == "step":
        return lr_scheduler.StepLR(optimizer, step_size=opt.lr_decay_iters, gamma=0.1)
    if policy == "plateau":
        return lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.2, threshold=0.01, patience=5)
    if policy == "cosine":
        return lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)
    # Raise (not return) so an unsupported policy fails loudly instead of
    # handing the caller an exception object in place of a scheduler.
    raise NotImplementedError("learning rate policy [%s] is not implemented" % policy)


def define_G(cfg):
    """Build the generator by calling ``skip`` with per-scale channel lists from ``cfg``.

    Each of ``cfg["skip_n33d"]``, ``cfg["skip_n33u"]`` and ``cfg["skip_n11"]``
    may be a single int, which is repeated ``cfg["num_scales"]`` times. Any
    other value is forwarded to ``skip`` unchanged (presumably an explicit
    per-scale sequence -- confirm against callers).

    Args:
        cfg: Mapping with keys ``skip_n33d``, ``skip_n33u`` and ``skip_n11``.
            It must also contain ``num_scales`` when any of those is an int.

    Returns:
        Whatever ``skip(3, 4, ..., need_bias=True)`` returns. The positional
        ``3`` and ``4`` are presumably input/output channel counts; verify
        against ``skip``'s signature.
    """

    def per_scale(key):
        # Broadcast a scalar channel count to one entry per scale; num_scales
        # is only looked up when broadcasting is actually needed.
        setting = cfg[key]
        if not isinstance(setting, int):
            return setting
        return [setting] * cfg["num_scales"]

    down_channels = per_scale("skip_n33d")
    up_channels = per_scale("skip_n33u")
    skip_channels = per_scale("skip_n11")
    return skip(
        3,
        4,
        num_channels_down=down_channels,
        num_channels_up=up_channels,
        num_channels_skip=skip_channels,
        need_bias=True,
    )