import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    """Create the optimizer selected by ``config.TRAINER.OPTIMIZER``.

    Args:
        model: object exposing ``parameters()``; all returned parameters are
            handed to the optimizer.
        config: configuration object with a ``TRAINER`` namespace. Reads
            ``OPTIMIZER`` and ``TRUE_LR``, plus ``ADAM_DECAY`` (for "adam")
            or ``ADAMW_DECAY`` (for "adamw"). Only the decay setting of the
            selected optimizer is accessed.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.AdamW`` instance.

    Raises:
        ValueError: if ``TRAINER.OPTIMIZER`` is neither "adam" nor "adamw".
    """
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR

    if name == "adam":
        return torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
        )
    if name == "adamw":
        return torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
        )
    raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")


def build_scheduler(config, optimizer):
    """Create the learning-rate scheduler selected by ``config.TRAINER.SCHEDULER``.

    Args:
        config: configuration object with a ``TRAINER`` namespace. Reads
            ``SCHEDULER_INTERVAL`` and ``SCHEDULER``, plus the settings of the
            selected scheduler only:
            ``MSLR_MILESTONES`` / ``MSLR_GAMMA`` for 'MultiStepLR',
            ``COSA_TMAX`` for 'CosineAnnealing',
            ``ELR_GAMMA`` for 'ExponentialLR'.
        optimizer: the optimizer whose learning rate is scheduled.

    Returns:
        dict: ``{'interval': <config.TRAINER.SCHEDULER_INTERVAL>,
        'scheduler': <lr_scheduler instance>}``.
        NOTE(review): the layout looks like a PyTorch Lightning lr-scheduler
        config (where 'interval' is 'step' or 'epoch' and optional keys such
        as 'monitor' / 'frequency' may be added by the caller) -- confirm
        against the call site. This function sets only the two keys above.

    Raises:
        NotImplementedError: if ``TRAINER.SCHEDULER`` is not one of the
            supported names.
    """
    # Read the interval first so a missing setting surfaces before any
    # scheduler is constructed.
    scheduler = {'interval': config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER

    if name == 'MultiStepLR':
        lr_scheduler = MultiStepLR(
            optimizer,
            config.TRAINER.MSLR_MILESTONES,
            gamma=config.TRAINER.MSLR_GAMMA,
        )
    elif name == 'CosineAnnealing':
        lr_scheduler = CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)
    elif name == 'ExponentialLR':
        lr_scheduler = ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)
    else:
        # Same exception type as before, now naming the offending value so a
        # config typo is diagnosable (mirrors build_optimizer's error).
        raise NotImplementedError(
            f"TRAINER.SCHEDULER = {name} is not a valid scheduler!"
        )

    scheduler['scheduler'] = lr_scheduler
    return scheduler