# workshop/LaSA/libs/optimizer.py
# (header reconstructed from repository-page residue: "qiushuocheng's picture /
#  Upload 173 files / 41e3185" — original lines were not valid Python)
import torch.nn as nn
import torch.optim as optim
def get_optimizer(
    optimizer_name: str,
    model: nn.Module,
    learning_rate: float,
    momentum: float = 0.9,
    dampening: float = 0.0,
    weight_decay: float = 0.0001,
    nesterov: bool = True,
) -> optim.Optimizer:
    """Build an optimizer over ``model.parameters()``.

    Args:
        optimizer_name: One of ``"SGD"`` or ``"Adam"``.
        model: Module whose parameters will be optimized.
        learning_rate: Learning rate for either optimizer.
        momentum: SGD momentum (ignored for Adam).
        dampening: SGD dampening (ignored for Adam).
        weight_decay: SGD weight decay (ignored for Adam — Adam is
            intentionally constructed with ``lr`` only, matching the
            original "only model, no text model" setup).
        nesterov: Whether SGD uses Nesterov momentum (ignored for Adam).

    Returns:
        The configured ``torch.optim.Optimizer``.

    Raises:
        ValueError: If ``optimizer_name`` is not a supported choice.
            (Was previously an ``assert``, which is stripped under
            ``python -O`` and would then crash with UnboundLocalError.)
    """
    if optimizer_name not in ("SGD", "Adam"):
        raise ValueError(
            f"Unsupported optimizer {optimizer_name!r}; expected 'SGD' or 'Adam'."
        )
    print(f"{optimizer_name} will be used as an optimizer.")
    if optimizer_name == "Adam":
        # NOTE: momentum/dampening/weight_decay/nesterov deliberately not
        # passed — preserves original training behavior for Adam callers.
        optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    else:  # "SGD" — the only remaining validated choice
        optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
    return optimizer