tobiasc's picture
Initial commit
ad16788
raw
history blame contribute delete
828 Bytes
import torch
from typeguard import check_argument_types
class SGD(torch.optim.SGD):
    """``torch.optim.SGD`` with a default value bound for every hyper-parameter.

    The optimizer arguments consumed by ``AbsTask.main()`` must all have
    default values, except for ``params``. Unlike the other optimizers,
    ``torch.optim.SGD`` declares ``lr`` without a default, so this subclass
    binds ``lr=0.1``. It forwards every argument unchanged to
    ``torch.optim.SGD``.

    Args:
        params: Parameters (or parameter groups) to optimize, passed straight
            through to ``torch.optim.SGD``.
        lr: Learning rate.
        momentum: Momentum factor.
        dampening: Dampening for momentum.
        weight_decay: Weight decay coefficient.
        nesterov: Whether to use Nesterov momentum.
    """

    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ):
        # Validate this call's arguments against the annotations above. Kept
        # as the first statement and wrapped in ``assert``, so the check is
        # skipped entirely when Python runs with ``-O``.
        assert check_argument_types()
        hyperparams = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )
        super().__init__(params, **hyperparams)