Spaces:
Runtime error
Runtime error
File size: 931 Bytes
8918ac7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
from torch.optim.lr_scheduler import LambdaLR, CosineAnnealingLR, StepLR
from transformers import get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup
def create_scheduler(args, optimizer):
    """Build a learning-rate scheduler for ``optimizer`` as configured by ``args``.

    Args:
        args: Config object (e.g. an argparse namespace). Attributes read:
            ``scheduler`` -- ``'linear'``, ``'cosine'`` or ``'step'``; a falsy
            value disables scheduling.
            ``num_training_steps`` -- total number of scheduler steps; required
            by ``'linear'`` and ``'cosine'`` only.
            ``warmup_steps`` -- warmup length for ``'linear'``/``'cosine'``; a
            falsy value (``0`` or ``None``) falls back to 10% of
            ``num_training_steps``.
            ``lr_step_size`` / ``lr_gamma`` (optional) -- ``StepLR`` period and
            decay factor for ``'step'``; default to 30 and 0.1.
        optimizer: The torch optimizer whose learning rate is scheduled.

    Returns:
        The constructed scheduler, or ``None`` when ``args.scheduler`` is falsy.

    Raises:
        ValueError: If ``args.scheduler`` names an unsupported scheduler, or a
            warmup scheduler is requested while ``num_training_steps`` is None.
    """
    name = args.scheduler
    if not name:
        return None

    supported = ('linear', 'cosine', 'step')
    if name not in supported:
        raise ValueError(
            f"Unknown scheduler {name!r}; expected one of "
            + ", ".join(repr(s) for s in supported)
        )

    if name == 'step':
        # StepLR has no warmup phase and does not need the training-step
        # budget, so num_training_steps / warmup_steps are not consulted.
        return StepLR(
            optimizer,
            step_size=getattr(args, 'lr_step_size', 30),
            gamma=getattr(args, 'lr_gamma', 0.1),
        )

    num_training_steps = args.num_training_steps
    if num_training_steps is None:
        # Fail at construction instead of when the schedule first needs
        # the total step count partway through training.
        raise ValueError(f"Scheduler {name!r} requires args.num_training_steps")

    # NOTE: `or` (not `is None`) is kept on purpose for compatibility:
    # warmup_steps of 0 or None both fall back to a 10% warmup.
    num_warmup_steps = args.warmup_steps or num_training_steps // 10

    if name == 'linear':
        factory = get_linear_schedule_with_warmup
    else:
        factory = get_cosine_schedule_with_warmup
    return factory(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )