import torch
from torch.optim.lr_scheduler import MultiStepLR, CosineAnnealingLR, ExponentialLR


def build_optimizer(model, config):
    """Build the optimizer named by config.TRAINER.OPTIMIZER ('adam' or 'adamw')
    over model.parameters(), using the learning rate config.TRAINER.TRUE_LR."""
    name = config.TRAINER.OPTIMIZER
    lr = config.TRAINER.TRUE_LR

    if name == "adam":
        return torch.optim.Adam(
            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAM_DECAY
        )
    elif name == "adamw":
        return torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=config.TRAINER.ADAMW_DECAY
        )
    else:
        raise ValueError(f"TRAINER.OPTIMIZER = {name} is not a valid optimizer!")


def build_scheduler(config, optimizer):
    """
    Build the learning-rate scheduler named by config.TRAINER.SCHEDULER.

    Returns:
        scheduler (dict): {
            'scheduler': lr_scheduler,
            'interval': 'step',  # or 'epoch'
            'monitor': 'val_f1',  # optional
            'frequency': x,  # optional
        }
    """
    scheduler = {"interval": config.TRAINER.SCHEDULER_INTERVAL}
    name = config.TRAINER.SCHEDULER

    if name == "MultiStepLR":
        scheduler.update(
            {
                "scheduler": MultiStepLR(
                    optimizer,
                    config.TRAINER.MSLR_MILESTONES,
                    gamma=config.TRAINER.MSLR_GAMMA,
                )
            }
        )
    elif name == "CosineAnnealing":
        scheduler.update(
            {"scheduler": CosineAnnealingLR(optimizer, config.TRAINER.COSA_TMAX)}
        )
    elif name == "ExponentialLR":
        scheduler.update(
            {"scheduler": ExponentialLR(optimizer, config.TRAINER.ELR_GAMMA)}
        )
    else:
        raise NotImplementedError(f"TRAINER.SCHEDULER = {name} is not a valid scheduler!")

    return scheduler
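

if __name__ == "__main__":
    # Minimal smoke-test sketch: builds an optimizer and scheduler from a
    # hypothetical stand-in config (a SimpleNamespace mimicking an
    # attribute-style config node; the key values below are illustrative, not
    # the project's defaults). The returned dict ('scheduler' / 'interval' /
    # optional 'monitor', 'frequency') matches the lr-scheduler configuration
    # format accepted by PyTorch Lightning's configure_optimizers, where one
    # would typically `return [optimizer], [scheduler]` (assumption about the
    # consuming training loop).
    from types import SimpleNamespace

    cfg = SimpleNamespace(
        TRAINER=SimpleNamespace(
            OPTIMIZER="adamw",
            TRUE_LR=1e-3,
            ADAMW_DECAY=0.1,
            SCHEDULER="MultiStepLR",
            SCHEDULER_INTERVAL="epoch",
            MSLR_MILESTONES=[8, 12, 16],
            MSLR_GAMMA=0.5,
        )
    )
    model = torch.nn.Linear(4, 4)  # any nn.Module with parameters
    optimizer = build_optimizer(model, cfg)
    scheduler = build_scheduler(cfg, optimizer)
    print(optimizer)
    print(scheduler)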