from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, StepLR
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup


def create_scheduler(args, optimizer):
    """Build the learning-rate scheduler selected by ``args.scheduler``.

    Args:
        args: Configuration object. Attributes read:
            ``scheduler`` -- one of ``'linear'``, ``'cosine'``, ``'step'``;
                any falsy value disables scheduling.
            ``num_training_steps`` -- total step count passed to the
                warmup schedules.
            ``warmup_steps`` -- warmup step count; any falsy value
                (``None`` *or* ``0``) falls back to
                ``num_training_steps // 10``.
            ``lr_step_size`` (optional) -- ``step_size`` for ``StepLR``;
                defaults to 30 when missing or ``None``.
            ``lr_gamma`` (optional) -- ``gamma`` for ``StepLR``;
                defaults to 0.1 when missing or ``None``.
        optimizer: The optimizer handed to the scheduler constructor.

    Returns:
        The constructed scheduler, or ``None`` when ``args.scheduler`` is
        falsy.

    Raises:
        KeyError: If ``args.scheduler`` is not a supported scheduler name.
    """
    if not args.scheduler:
        return None

    num_training_steps = args.num_training_steps
    # NOTE: `or` means an explicit warmup_steps of 0 is treated the same as
    # "not set" and yields the 10% default. Kept as-is because existing
    # configurations may rely on 0 meaning "auto".
    num_warmup_steps = args.warmup_steps or num_training_steps // 10

    def _optional(name, default):
        """Return ``args.<name>``, or ``default`` if it is missing or None."""
        value = getattr(args, name, None)
        return default if value is None else value

    # Each entry is a zero-argument factory so that only the requested
    # scheduler is ever constructed.
    scheduler_dict = {
        'linear': lambda: get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        ),
        'cosine': lambda: get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
        ),
        'step': lambda: StepLR(
            optimizer,
            step_size=_optional('lr_step_size', 30),
            gamma=_optional('lr_gamma', 0.1),
        ),
    }

    # Only the lookup is guarded, so a KeyError raised while constructing a
    # scheduler is not mistaken for an unknown scheduler name.
    try:
        factory = scheduler_dict[args.scheduler]
    except KeyError:
        supported = ", ".join(sorted(scheduler_dict))
        raise KeyError(
            f"Unknown scheduler {args.scheduler!r}; supported: {supported}"
        ) from None
    return factory()