""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
import torch

from .timm.cosine_lr import CosineLRScheduler
from .timm.plateau_lr import PlateauLRScheduler
from .timm.step_lr import StepLRScheduler
from .timm.tanh_lr import TanhLRScheduler


def _scaled_noise_range(lr_noise, num_epochs):
    """Convert a fractional ``lr_noise`` setting into epoch units.

    Args:
        lr_noise: ``None`` (noise disabled), a scalar, or a list/tuple of
            fractions that are multiplied by ``num_epochs``.
        num_epochs: training length the fractions are scaled by.

    Returns:
        ``None`` when ``lr_noise`` is ``None``; the scaled scalar for a scalar
        input; for a list/tuple, the element-wise scaled list, collapsed to
        its only element when it has exactly one.
    """
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        scaled = [n * num_epochs for n in lr_noise]
        return scaled[0] if len(scaled) == 1 else scaled
    return lr_noise * num_epochs


def _noise_kwargs(args, noise_range):
    """Return the LR-noise keyword arguments shared by the timm schedulers."""
    return {
        'noise_range_t': noise_range,
        'noise_pct': getattr(args, 'lr_noise_pct', 0.67),
        'noise_std': getattr(args, 'lr_noise_std', 1.),
        'noise_seed': getattr(args, 'seed', 42),
    }


def _first_attr(args, *names):
    """Return the value of the first attribute in ``names`` that ``args`` has.

    Raises:
        AttributeError: if ``args`` has none of ``names``.
    """
    for name in names:
        if hasattr(args, name):
            return getattr(args, name)
    raise AttributeError(f"args has none of the attributes {names!r}")


def create_scheduler(args, optimizer, **kwargs):
    """Build a learning-rate scheduler for ``optimizer`` from ``args.lr_policy``.

    Supported policies: 'cosine', 'tanh', 'step' and 'plateau' (schedulers
    from the local ``.timm`` package), plus 'onecyclelr' and 'cosinerestart'
    (``torch.optim.lr_scheduler``).

    Args:
        args: config namespace. ``epochs`` and ``lr_policy`` are always read;
            the remaining attributes depend on the selected policy.
        optimizer: optimizer whose learning rate is scheduled.
        **kwargs: ``total_steps`` is required for 'onecyclelr' and
            'cosinerestart'. ``init_epoch`` (optional, default 0) is
            subtracted from ``args.decay_epochs`` for the 'step' policy.

    Returns:
        The scheduler instance, or ``None`` if ``args.lr_policy`` matches no
        supported policy.
    """
    num_epochs = args.epochs
    noise_range = _scaled_noise_range(getattr(args, 'lr_noise', None), num_epochs)

    policy = args.lr_policy
    lr_scheduler = None
    if policy == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            # The tanh/plateau branches read ``min_lr``; accept either spelling,
            # preferring ``lr_min`` here as before.
            lr_min=_first_attr(args, 'lr_min', 'min_lr'),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )
        # The full cycle length is not returned; callers that need it can
        # query ``lr_scheduler.get_cycle_length()`` themselves.
    elif policy == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_first_attr(args, 'min_lr', 'lr_min'),
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **_noise_kwargs(args, noise_range),
        )
    elif policy == 'step':
        # ``kwargs`` is a plain dict, so the key must be read with ``.get``;
        # ``getattr`` on a dict would always fall back to the default.
        # NOTE(review): the original comment said "for D" -- presumably a
        # second/resumed schedule that starts at ``init_epoch``; confirm.
        lr_scheduler = StepLRScheduler(
            optimizer,
            decay_t=args.decay_epochs - kwargs.get('init_epoch', 0),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **_noise_kwargs(args, noise_range),
        )
    elif policy == 'plateau':
        # Loss-like metrics are minimised; anything else is maximised.
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=_first_attr(args, 'min_lr', 'lr_min'),
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **_noise_kwargs(args, noise_range),
        )
    elif policy == "onecyclelr":
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.LR,
            total_steps=kwargs["total_steps"],
            pct_start=args.PCT_START,
            div_factor=args.DIV_FACTOR_ONECOS,
            # NOTE(review): attribute name looks misspelled but must match the
            # config keys; rename both together if it is ever fixed.
            final_div_factor=args.FIN_DACTOR_ONCCOS,
        )
    elif policy == "cosinerestart":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=kwargs["total_steps"],
            T_mult=2,
            eta_min=1e-6,
            last_epoch=-1,
        )
    return lr_scheduler