# steerers/training/optim.py
import math
from typing import Optional
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
def get_scheduler(
    scheduler_name: Optional[str], optimizer: optim.Optimizer, **kwargs
) -> lr_scheduler.LambdaLR:
    """Build a learning-rate scheduler for `optimizer` by name.

    Supported names: None or "constant" (constant LR) and "constantwithwarmup"
    (linear warmup over `warm_up_steps` from kwargs, then constant).
    """
    def get_warmup_lambda(warm_up_steps, training_steps):
        # Linear warmup to the base LR, then linear decay to 0 at training_steps.
        def lr_lambda(steps):
            if steps < warm_up_steps:
                return (steps + 1) / warm_up_steps
            else:
                # max(0.0, ...) keeps the multiplier non-negative if training
                # runs past training_steps.
                return max(
                    0.0,
                    (training_steps - steps) / (training_steps - warm_up_steps),
                )

        return lr_lambda
    # Heavily derived from Hugging Face's cosine schedule, with Copilot's help.
    def get_warmup_cosine_lambda(warm_up_steps, training_steps, lr_end):
        # Linear warmup, then cosine decay from 1.0 down to lr_end.
        def lr_lambda(steps):
            if steps < warm_up_steps:
                return (steps + 1) / warm_up_steps
            else:
                # Clamp progress at 1.0 so the multiplier settles at lr_end
                # if training runs past training_steps.
                progress = min(
                    1.0,
                    (steps - warm_up_steps) / (training_steps - warm_up_steps),
                )
                return lr_end + 0.5 * (1 - lr_end) * (
                    1 + math.cos(math.pi * progress)
                )

        return lr_lambda
    # Only the two schedules below are dispatched; get_warmup_lambda and
    # get_warmup_cosine_lambda are defined above but never wired up here.
    if scheduler_name is None or scheduler_name.lower() == "constant":
        return lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda steps: 1.0)
    elif scheduler_name.lower() == "constantwithwarmup":
        # max(1, ...) avoids a ZeroDivisionError when warm_up_steps is absent
        # or 0 (LambdaLR evaluates the lambda once at construction).
        warm_up_steps = max(1, kwargs.get("warm_up_steps", 0))
        return lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda steps: min(1.0, (steps + 1) / warm_up_steps),
        )
    else:
        raise ValueError(f"Unsupported scheduler: {scheduler_name}")
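

# Usage sketch (an assumption, not part of the original module): exercise the
# warmup schedule against a throwaway Adam optimizer.
if __name__ == "__main__":
    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.Adam(params, lr=1e-3)
    scheduler = get_scheduler("constantwithwarmup", optimizer, warm_up_steps=10)
    for _ in range(5):
        optimizer.step()
        scheduler.step()
        print(scheduler.get_last_lr())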