Spaces:
Running
Running
File size: 1,445 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 |
import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR
def build_optimizer(model, config):
    """Create the torch optimizer named by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: Object exposing ``parameters()``; its parameters are handed
            to the optimizer.
        config: Config whose ``TRAINER`` node supplies ``OPTIMIZER``
            (``"adam"`` or ``"adamw"``), ``TRUE_LR`` and the matching
            weight-decay value (``ADAM_DECAY`` or ``ADAMW_DECAY``).

    Returns:
        torch.optim.Adam or torch.optim.AdamW configured with ``TRUE_LR``
        as the learning rate.

    Raises:
        ValueError: If ``TRAINER.OPTIMIZER`` is neither ``"adam"`` nor
            ``"adamw"``.
    """
    trainer_cfg = config.TRAINER
    optimizer_name = trainer_cfg.OPTIMIZER
    base_lr = trainer_cfg.TRUE_LR

    # Reject unknown names up front so the happy path below stays flat.
    if optimizer_name not in ("adam", "adamw"):
        raise ValueError(f"TRAINER.OPTIMIZER = {optimizer_name} is not a valid optimizer!")

    params = model.parameters()
    if optimizer_name == "adam":
        return torch.optim.Adam(params, lr=base_lr, weight_decay=trainer_cfg.ADAM_DECAY)
    return torch.optim.AdamW(params, lr=base_lr, weight_decay=trainer_cfg.ADAMW_DECAY)
def build_scheduler(config, optimizer):
    """Build the learning-rate scheduler selected by ``config.TRAINER.SCHEDULER``.

    Args:
        config: Config whose ``TRAINER`` node supplies ``SCHEDULER_INTERVAL``,
            ``SCHEDULER`` and the hyper-parameters of the chosen scheduler:
            ``MSLR_MILESTONES``/``MSLR_GAMMA`` for ``'MultiStepLR'``,
            ``COSA_TMAX`` for ``'CosineAnnealing'``, ``ELR_GAMMA`` for
            ``'ExponentialLR'``.
        optimizer: The torch optimizer the scheduler is attached to.

    Returns:
        dict: ``{'interval': ..., 'scheduler': lr_scheduler}``. ``'interval'``
        is copied verbatim from ``TRAINER.SCHEDULER_INTERVAL`` (e.g. ``'step'``
        or ``'epoch'``). The optional ``'monitor'`` (e.g. ``'val_f1'``) and
        ``'frequency'`` keys of this dict format are not set here.

    Raises:
        NotImplementedError: If ``TRAINER.SCHEDULER`` is not one of
            ``'MultiStepLR'``, ``'CosineAnnealing'`` or ``'ExponentialLR'``.
    """
    scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER
    if name == 'MultiStepLR':
        scheduler.update(
            {'scheduler': MultiStepLR(optimizer, config.TRAINER.MSLR_MILESTONES, gamma=config.TRAINER.MSLR_GAMMA)})
    elif name == 'CosineAnnealing':
        scheduler.update(
            {'scheduler': CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)})
    elif name == 'ExponentialLR':
        scheduler.update(
            {'scheduler': ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)})
    else:
        # Keep the exception type callers may rely on, but name the bad value
        # (mirrors build_optimizer's error) so misconfigurations are obvious.
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a supported scheduler! "
            "Expected one of 'MultiStepLR', 'CosineAnnealing', 'ExponentialLR'.")
    return scheduler
|