Spaces:
Running
Running
"""
This file implements different learning rate schedulers.
"""
import torch | |
def get_lr_scheduler(lr_decay, lr_decay_cfg, optimizer):
    """Build the learning rate scheduler described by the config.

    Args:
        lr_decay: Flag enabling learning rate decay; any falsy value
            disables it.
        lr_decay_cfg: Decay configuration mapping, or None for no decay.
            It is indexed with "policy"; the "exp" policy also reads
            "gamma".
        optimizer: Optimizer handed to the scheduler constructor.

    Returns:
        None when decay is disabled or no config is given, otherwise a
        ``torch.optim.lr_scheduler.ExponentialLR`` for the "exp" policy.

    Raises:
        ValueError: If the configured policy is not supported.
    """
    # Decay disabled or not configured => no scheduler. Truthiness is used
    # instead of `== False` so that flags like None or 0 also disable decay
    # rather than being misreported as an unknown policy.
    if not lr_decay or lr_decay_cfg is None:
        return None

    policy = lr_decay_cfg["policy"]

    # Exponential decay
    if policy == "exp":
        return torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=lr_decay_cfg["gamma"]
        )

    # Unknown policy: name it so the misconfiguration is easy to locate.
    raise ValueError(
        f"[Error] Unknown learning rate decay policy: {policy!r}"
    )