EMAGE / optimizers /scheduler_factory.py
H-Liu1997's picture
Upload folder using huggingface_hub
2d47d90 verified
raw
history blame
3.97 kB
""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .timm.cosine_lr import CosineLRScheduler
from .timm.tanh_lr import TanhLRScheduler
from .timm.step_lr import StepLRScheduler
from .timm.plateau_lr import PlateauLRScheduler
import torch
def _resolve_noise_range(args, num_epochs):
    """Convert ``args.lr_noise`` (fractions of training) into epoch values.

    Returns None when ``args.lr_noise`` is absent or None, a single number for a
    scalar or a one-element list/tuple, and otherwise a list of epoch values.
    """
    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        # A one-element range collapses to a scalar start epoch.
        return noise_range[0] if len(noise_range) == 1 else noise_range
    return lr_noise * num_epochs


def _noise_kwargs(args, noise_range):
    """Return the LR-noise keyword arguments shared by the timm schedulers."""
    return dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )


def _first_present_attr(args, *names):
    """Return the value of the first attribute in ``names`` that ``args`` defines.

    Raises:
        AttributeError: if ``args`` defines none of ``names``.
    """
    for name in names:
        if hasattr(args, name):
            return getattr(args, name)
    raise AttributeError(
        f"{type(args).__name__} object has none of the attributes {names}")


def create_scheduler(args, optimizer, **kwargs):
    """Build a learning-rate scheduler for ``optimizer`` selected by ``args.lr_policy``.

    Policies ``'cosine'``, ``'tanh'``, ``'step'`` and ``'plateau'`` build timm
    schedulers configured in epochs (``t_in_epochs=True`` / ``args.epochs``).
    Policies ``'onecyclelr'`` and ``'cosinerestart'`` build PyTorch schedulers
    sized by ``kwargs['total_steps']``.

    Args:
        args: config namespace. ``epochs`` and ``lr_policy`` are always read; the
            other attributes read depend on the policy.
        optimizer: the optimizer whose learning rate is scheduled.
        **kwargs: ``total_steps`` (required by ``'onecyclelr'`` and
            ``'cosinerestart'``); ``init_epoch`` (optional, ``'step'`` only) is
            subtracted from ``args.decay_epochs``.

    Returns:
        The scheduler instance, or ``None`` if ``args.lr_policy`` is not one of
        the policies above.

    Raises:
        AttributeError: if ``args`` lacks an attribute the chosen policy needs.
        KeyError: if ``total_steps`` is missing for a policy that requires it.
    """
    num_epochs = args.epochs
    noise_range = _resolve_noise_range(args, num_epochs)
    policy = args.lr_policy

    # Configs spell the floor LR either ``lr_min`` or ``min_lr``. Each branch
    # prefers the spelling it always used and falls back to the other one.
    if policy == 'cosine':
        return CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_first_present_attr(args, 'lr_min', 'min_lr'),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )
    if policy == 'tanh':
        return TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_first_present_attr(args, 'min_lr', 'lr_min'),
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )
    if policy == 'step':
        # kwargs is a dict, so read init_epoch with .get(); getattr() on a dict
        # never sees its keys. Presumably this shifts the decay point when
        # training starts from a later epoch (e.g. resuming) — confirm with caller.
        return StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs - kwargs.get('init_epoch', 0),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **_noise_kwargs(args, noise_range),
        )
    if policy == 'plateau':
        # Lower is better for loss-like metrics; otherwise higher is better.
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        return PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=_first_present_attr(args, 'min_lr', 'lr_min'),
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **_noise_kwargs(args, noise_range),
        )
    if policy == "onecyclelr":
        return torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.LR,
            total_steps=kwargs["total_steps"],
            pct_start=args.PCT_START,
            div_factor=args.DIV_FACTOR_ONECOS,
            # Attribute name is misspelled (sic) but must match the config key.
            final_div_factor=args.FIN_DACTOR_ONCCOS,
        )
    if policy == "cosinerestart":
        # The first restart period is total_steps long; T_mult=2 doubles each
        # subsequent period.
        return torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=kwargs["total_steps"],
            T_mult=2,
            eta_min=1e-6,
            last_epoch=-1,
        )
    return None