File size: 931 Bytes
8918ac7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, StepLR
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup

def create_scheduler(args, optimizer):
    """Build a learning-rate scheduler for *optimizer* from CLI-style args.

    Args:
        args: namespace with attributes
            - scheduler: one of 'linear', 'cosine', 'step'; falsy disables scheduling
            - num_training_steps: total optimizer steps (used by warmup schedulers)
            - warmup_steps: explicit warmup step count, or None to default to
              10% of num_training_steps
        optimizer: the torch optimizer whose LR will be scheduled.

    Returns:
        A scheduler instance, or None when args.scheduler is falsy.

    Raises:
        ValueError: if args.scheduler names an unknown scheduler.
    """
    if not args.scheduler:
        return None

    num_training_steps = args.num_training_steps
    # Fix: the previous `args.warmup_steps or default` treated an explicit
    # warmup_steps=0 (a valid "no warmup" request) as unset and silently
    # applied the 10% default. Only fall back when warmup_steps is None.
    if args.warmup_steps is not None:
        num_warmup_steps = args.warmup_steps
    else:
        num_warmup_steps = num_training_steps // 10

    # Lambdas defer construction so only the selected scheduler is built.
    scheduler_dict = {
        'linear': lambda: get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        ),
        'cosine': lambda: get_cosine_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps
        ),
        'step': lambda: StepLR(optimizer, step_size=30, gamma=0.1)
    }

    # Fail loudly with the valid choices instead of a bare KeyError.
    try:
        factory = scheduler_dict[args.scheduler]
    except KeyError:
        raise ValueError(
            f"Unknown scheduler '{args.scheduler}'; "
            f"expected one of {sorted(scheduler_dict)}"
        ) from None
    return factory()