File size: 1,625 Bytes
10b4a5f
 
 
 
 
 
 
 
 
358ab8f
 
 
10b4a5f
358ab8f
 
 
10b4a5f
 
 
 
 
 
 
 
 
 
 
 
 
 
358ab8f
10b4a5f
 
358ab8f
10b4a5f
358ab8f
 
 
 
 
 
 
 
 
10b4a5f
358ab8f
 
 
10b4a5f
358ab8f
 
10b4a5f
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    """Create the optimizer selected by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: Object whose ``parameters()`` are handed to the optimizer.
        config: Configuration object. Reads ``TRAINER.OPTIMIZER`` and
            ``TRAINER.TRUE_LR``, plus the weight-decay field matching the
            chosen optimizer: ``TRAINER.ADAM_DECAY`` for ``"adam"`` or
            ``TRAINER.ADAMW_DECAY`` for ``"adamw"``.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: If ``TRAINER.OPTIMIZER`` is neither ``"adam"`` nor
            ``"adamw"``.
    """
    trainer_cfg = config.TRAINER
    opt_name = trainer_cfg.OPTIMIZER
    base_lr = trainer_cfg.TRUE_LR

    # Only the decay field belonging to the selected optimizer is read, so a
    # config does not need to define the other one.
    if opt_name == "adam":
        optimizer_cls, decay = torch.optim.Adam, trainer_cfg.ADAM_DECAY
    elif opt_name == "adamw":
        optimizer_cls, decay = torch.optim.AdamW, trainer_cfg.ADAMW_DECAY
    else:
        raise ValueError(
            f"TRAINER.OPTIMIZER = {opt_name} is not a valid optimizer!"
        )

    return optimizer_cls(model.parameters(), lr=base_lr, weight_decay=decay)


def build_scheduler(config, optimizer):
    """Build a learning-rate scheduler config dict for ``optimizer``.

    The dict layout appears to follow the PyTorch Lightning lr-scheduler
    config convention (verify against the caller). Only ``'interval'`` and
    ``'scheduler'`` are set here; ``'monitor'`` and ``'frequency'`` are
    optional keys that this function never populates.

    Args:
        config: Configuration object. Reads ``TRAINER.SCHEDULER_INTERVAL`` and
            ``TRAINER.SCHEDULER``, plus the fields of the chosen scheduler:
            ``MSLR_MILESTONES``/``MSLR_GAMMA`` for ``"MultiStepLR"``,
            ``COSA_TMAX`` for ``"CosineAnnealing"``, ``ELR_GAMMA`` for
            ``"ExponentialLR"``.
        optimizer: The optimizer whose learning rate is scheduled.

    Returns:
        scheduler (dict):{
            'scheduler': lr_scheduler,
            'interval': 'step',  # or 'epoch'
            'monitor': 'val_f1', (optional)
            'frequency': x, (optional)
        }

    Raises:
        NotImplementedError: If ``TRAINER.SCHEDULER`` is not one of
            ``"MultiStepLR"``, ``"CosineAnnealing"`` or ``"ExponentialLR"``.
    """
    scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER

    if name == "MultiStepLR":
        lr_scheduler = MultiStepLR(
            optimizer,
            config.TRAINER.MSLR_MILESTONES,
            gamma=config.TRAINER.MSLR_GAMMA,
        )
    elif name == "CosineAnnealing":
        lr_scheduler = CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)
    elif name == "ExponentialLR":
        lr_scheduler = ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)
    else:
        # Keep NotImplementedError so callers that catch it still work, but
        # name the offending value so a misconfiguration is easy to diagnose.
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a supported scheduler!"
        )

    scheduler["scheduler"] = lr_scheduler
    return scheduler