# hma/magvit2/modules/scheduler/lr_scheduler.py
# Uploaded by LeroyWaa; commit 246c106 (message: "draft"); 880 bytes.
import math
import torch
from functools import partial
# step scheduler
def fn_LinearWarmup(warmup_steps, step):
    """Return the schedule multiplier for linear warmup, then a constant 1.0.

    Presumably used as a per-step ``lr_lambda`` (e.g. for torch ``LambdaLR``);
    confirm against callers.

    Args:
        warmup_steps: number of steps over which the multiplier ramps up.
        step: current step index.

    Returns:
        ``step / max(1, warmup_steps)`` while ``step < warmup_steps``,
        otherwise ``1.0``.
    """
    in_warmup = step < warmup_steps
    if not in_warmup:
        return 1.0
    # max(1, ...) keeps the denominator non-zero for a zero-length warmup.
    return float(step) / float(max(1, warmup_steps))
def Scheduler_LinearWarmup(warmup_steps):
    """Build a one-argument ``step -> multiplier`` callable for linear warmup.

    The returned object is ``fn_LinearWarmup`` with ``warmup_steps`` bound as
    its first positional argument via ``functools.partial``. A partial over a
    module-level function stays picklable, unlike a closure or lambda.

    Args:
        warmup_steps: number of steps over which the multiplier ramps up.
    """
    schedule = partial(fn_LinearWarmup, warmup_steps)
    return schedule
def fn_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min, step):
    """Return the schedule multiplier for linear warmup followed by cosine decay.

    The multiplier ramps linearly from 0 toward 1 over ``warmup_steps`` steps.
    It then follows a half-cosine from 1 down to 0 between ``warmup_steps`` and
    ``max_steps``, floored at ``multipler_min``. From ``max_steps`` onward it
    stays at the floor rather than rising again.

    Args:
        warmup_steps: length of the linear warmup, in steps.
        max_steps: step at which the cosine decay reaches its minimum.
        multipler_min: lower bound on the returned multiplier (spelling kept
            for caller compatibility).
        step: current step index.

    Returns:
        The multiplier, never below ``multipler_min`` once warmup has ended.
    """
    if step < warmup_steps:  # linear warmup
        return float(step) / float(max(1, warmup_steps))

    # Guard the decay length like the warmup branch does, so that
    # max_steps == warmup_steps does not raise ZeroDivisionError.
    decay_steps = max(1, max_steps - warmup_steps)
    # Clamp progress to 1.0. Otherwise, for step > max_steps the cosine
    # argument passes pi and the multiplier climbs back toward 1.
    progress = min(1.0, (step - warmup_steps) / decay_steps)
    multiplier = 0.5 * (math.cos(progress * math.pi) + 1)
    return max(multiplier, multipler_min)
def Scheduler_LinearWarmup_CosineDecay(warmup_steps, max_steps, multipler_min):
    """Build a one-argument ``step -> multiplier`` callable for warmup + cosine decay.

    The returned object is ``fn_LinearWarmup_CosineDecay`` with
    ``warmup_steps``, ``max_steps`` and ``multipler_min`` bound positionally
    via ``functools.partial``, leaving only ``step`` to be supplied.

    Args:
        warmup_steps: length of the linear warmup, in steps.
        max_steps: step at which the cosine decay reaches its minimum.
        multipler_min: lower bound on the multiplier after warmup.
    """
    bound_args = (warmup_steps, max_steps, multipler_min)
    schedule = partial(fn_LinearWarmup_CosineDecay, *bound_args)
    return schedule