|
""" Scheduler Factory |
|
Hacked together by / Copyright 2020 Ross Wightman |
|
""" |
|
from .cosine_lr import CosineLRScheduler |
|
from .tanh_lr import TanhLRScheduler |
|
from .step_lr import StepLRScheduler |
|
from .plateau_lr import PlateauLRScheduler |
|
|
|
|
|
def _resolve_noise_range(lr_noise, num_epochs):
    """Scale ``lr_noise`` by ``num_epochs`` to get the noise range in epochs.

    Args:
        lr_noise: ``None``, a scalar, or a list/tuple of scalars. Each value is
            multiplied by ``num_epochs``.
        num_epochs: epoch count the ``lr_noise`` values are relative to.

    Returns:
        ``None`` when noise is disabled (``lr_noise`` is ``None`` or an empty
        sequence); a scalar for a scalar or single-element sequence; otherwise
        a list with one scaled value per input element.
    """
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        if not lr_noise:
            # An empty sequence specifies no range at all; disable noise rather
            # than passing an empty list to the scheduler.
            # NOTE(review): schedulers presumably index the range as [lo, hi].
            return None
        noise_range = [n * num_epochs for n in lr_noise]
        # A single-element sequence is unwrapped so it is passed the same way
        # as a scalar lr_noise.
        return noise_range[0] if len(noise_range) == 1 else noise_range
    return lr_noise * num_epochs


def _noise_kwargs(args, noise_range):
    """Return the LR-noise keyword arguments shared by every scheduler type."""
    return dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )


def create_scheduler(args, optimizer):
    """Build the learning-rate scheduler selected by ``args.sched``.

    Args:
        args: namespace-like config object. Always read: ``epochs`` and
            ``sched``. Optional (``getattr`` with defaults): ``lr_noise``,
            ``lr_noise_pct``, ``lr_noise_std``, ``seed``, ``lr_cycle_mul``,
            ``lr_cycle_limit``, ``scheduler_groups`` and ``eval_metric``.
            Read only by the branch that needs them: ``min_lr``,
            ``decay_rate``, ``warmup_lr``, ``warmup_epochs``,
            ``cooldown_epochs``, ``decay_epochs`` and ``patience_epochs``.
        optimizer: passed through as the first argument of the scheduler
            constructor.

    Returns:
        A tuple ``(lr_scheduler, num_epochs)``.

        - ``lr_scheduler`` is ``None`` when ``args.sched`` is not one of
          ``'cosine'``, ``'tanh'``, ``'step'`` or ``'plateau'``.
        - For ``'cosine'`` and ``'tanh'``, ``num_epochs`` is the scheduler's
          ``get_cycle_length()`` plus ``args.cooldown_epochs``.
        - Otherwise ``num_epochs`` is ``args.epochs`` unchanged.
    """
    num_epochs = args.epochs
    noise_range = _resolve_noise_range(getattr(args, 'lr_noise', None), num_epochs)
    noise_kwargs = _noise_kwargs(args, noise_range)

    lr_scheduler = None
    if args.sched == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            scheduler_groups=getattr(args, 'scheduler_groups', None),
            **noise_kwargs,
        )
        # Training length is taken from the scheduler's cycles plus cooldown.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=args.min_lr,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
        # Training length is taken from the scheduler's cycles plus cooldown.
        num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
    elif args.sched == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs,
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_kwargs,
        )
    elif args.sched == 'plateau':
        # Metrics whose name contains 'loss' are minimised; all others are
        # maximised. `or ''` guards against an explicit eval_metric=None.
        eval_metric = getattr(args, 'eval_metric', None) or ''
        mode = 'min' if 'loss' in eval_metric else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=args.min_lr,
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_kwargs,
        )

    return lr_scheduler, num_epochs
|
|