""" Scheduler Factory
Hacked together by / Copyright 2020 Ross Wightman
"""
from .timm.cosine_lr import CosineLRScheduler
from .timm.tanh_lr import TanhLRScheduler
from .timm.step_lr import StepLRScheduler
from .timm.plateau_lr import PlateauLRScheduler
import torch


def _noise_range(args, num_epochs):
    """Scale ``args.lr_noise`` by ``num_epochs``.

    Returns None when ``lr_noise`` is absent or None; a single scaled value when
    ``lr_noise`` is a scalar or a one-element list/tuple; otherwise a list of
    scaled values (one per element of ``lr_noise``).
    """
    lr_noise = getattr(args, 'lr_noise', None)
    if lr_noise is None:
        return None
    if isinstance(lr_noise, (list, tuple)):
        noise_range = [n * num_epochs for n in lr_noise]
        # A single-element sequence is unwrapped to a scalar.
        return noise_range[0] if len(noise_range) == 1 else noise_range
    return lr_noise * num_epochs


def _noise_kwargs(args, noise_range):
    """Build the LR-noise keyword arguments shared by the timm-style schedulers."""
    return dict(
        noise_range_t=noise_range,
        noise_pct=getattr(args, 'lr_noise_pct', 0.67),
        noise_std=getattr(args, 'lr_noise_std', 1.),
        noise_seed=getattr(args, 'seed', 42),
    )


def _min_lr(args, preferred):
    """Return the minimum LR from ``args``, trying ``preferred`` then the alias.

    This module historically read the minimum LR as ``lr_min`` (cosine) and as
    ``min_lr`` (tanh, plateau); accepting either name keeps every policy usable
    regardless of which one the config defines. Raises AttributeError (from
    ``getattr``) when neither name exists.
    """
    fallback = 'min_lr' if preferred == 'lr_min' else 'lr_min'
    if hasattr(args, preferred):
        return getattr(args, preferred)
    return getattr(args, fallback)


def create_scheduler(args, optimizer, **kwargs):
    """Create a learning-rate scheduler for ``optimizer`` from ``args.lr_policy``.

    Args:
        args: config object read via attribute access. Always reads ``epochs`` and
            ``lr_policy``; each policy reads further attributes (e.g. ``decay_rate``,
            ``warmup_lr``, ``warmup_epochs``, ``lr_min``/``min_lr``, ``LR``, ...).
        optimizer: optimizer passed to the scheduler constructor.
        **kwargs:
            total_steps: required for the ``'onecyclelr'`` and ``'cosinerestart'``
                policies (a KeyError is raised if missing).
            init_epoch: optional, ``'step'`` policy only; subtracted from
                ``args.decay_epochs`` (defaults to 0).

    Returns:
        The scheduler instance, or None when ``args.lr_policy`` matches none of
        the supported policies.
    """
    num_epochs = args.epochs
    noise_kwargs = _noise_kwargs(args, _noise_range(args, num_epochs))

    lr_scheduler = None
    if args.lr_policy == 'cosine':
        lr_scheduler = CosineLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_min_lr(args, 'lr_min'),
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
    elif args.lr_policy == 'tanh':
        lr_scheduler = TanhLRScheduler(
            optimizer,
            t_initial=num_epochs,
            t_mul=getattr(args, 'lr_cycle_mul', 1.),
            lr_min=_min_lr(args, 'min_lr'),
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cycle_limit=getattr(args, 'lr_cycle_limit', 1),
            t_in_epochs=True,
            **noise_kwargs,
        )
    elif args.lr_policy == 'step':
        lr_scheduler = StepLRScheduler(
            optimizer,
            # kwargs is a dict: use .get(); getattr() on a dict always fell
            # back to 0 and silently ignored a caller-supplied init_epoch.
            decay_t=args.decay_epochs - kwargs.get('init_epoch', 0),  # for D
            decay_rate=args.decay_rate,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            **noise_kwargs,
        )
    elif args.lr_policy == 'plateau':
        # Metrics whose name contains 'loss' are minimized; all others maximized.
        mode = 'min' if 'loss' in getattr(args, 'eval_metric', '') else 'max'
        lr_scheduler = PlateauLRScheduler(
            optimizer,
            decay_rate=args.decay_rate,
            patience_t=args.patience_epochs,
            lr_min=_min_lr(args, 'min_lr'),
            mode=mode,
            warmup_lr_init=args.warmup_lr,
            warmup_t=args.warmup_epochs,
            cooldown_t=0,
            **noise_kwargs,
        )
    elif args.lr_policy == "onecyclelr":
        lr_scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=args.LR,
            total_steps=kwargs["total_steps"],
            pct_start=args.PCT_START,
            div_factor=args.DIV_FACTOR_ONECOS,
            # NOTE(review): attribute name is spelled as the config defines it.
            final_div_factor=args.FIN_DACTOR_ONCCOS,
        )
    elif args.lr_policy == "cosinerestart":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=kwargs["total_steps"],
            T_mult=2,
            eta_min=1e-6,
            last_epoch=-1,
        )
    return lr_scheduler