File size: 812 Bytes
41e3185 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
import torch.nn as nn
import torch.optim as optim
def get_optimizer(
    optimizer_name: str,
    model: nn.Module,
    learning_rate: float,
    momentum: float = 0.9,
    dampening: float = 0.0,
    weight_decay: float = 0.0001,
    nesterov: bool = True,
) -> optim.Optimizer:
    """Build a ``torch.optim`` optimizer over all parameters of *model*.

    Args:
        optimizer_name: Either ``"SGD"`` or ``"Adam"`` (case-sensitive).
        model: Module whose ``parameters()`` are handed to the optimizer.
        learning_rate: Learning rate, passed as ``lr``.
        momentum: SGD momentum factor. Ignored for Adam.
        dampening: SGD momentum dampening. Ignored for Adam.
        weight_decay: SGD weight decay. Ignored for Adam, which is built
            with only ``lr`` (torch's own defaults apply to the rest).
        nesterov: Whether SGD uses Nesterov momentum. Ignored for Adam.
            Note that ``torch.optim.SGD`` rejects ``nesterov=True`` unless
            ``momentum > 0`` and ``dampening == 0``.

    Returns:
        The configured optimizer instance.

    Raises:
        ValueError: If *optimizer_name* is not ``"SGD"`` or ``"Adam"``.
    """
    # Validate with an explicit raise instead of ``assert``: asserts are
    # stripped under ``python -O``, which would let an unknown name fall
    # through every branch and fail later with an UnboundLocalError.
    if optimizer_name not in ("SGD", "Adam"):
        raise ValueError(
            f"Unsupported optimizer {optimizer_name!r}; expected 'SGD' or 'Adam'."
        )
    print(f"{optimizer_name} will be used as an optimizer.")

    if optimizer_name == "Adam":
        # Only this model's parameters are optimized (original note:
        # "only model, no text model"). The SGD hyperparameters, including
        # weight_decay, are not forwarded to Adam.
        # NOTE(review): confirm Adam is meant to skip weight_decay.
        return optim.Adam(model.parameters(), lr=learning_rate)

    # optimizer_name == "SGD" (guaranteed by the check above).
    return optim.SGD(
        model.parameters(),
        lr=learning_rate,
        momentum=momentum,
        dampening=dampening,
        weight_decay=weight_decay,
        nesterov=nesterov,
    )
|