|
import torch |
|
from typeguard import check_argument_types |
|
|
|
|
|
class SGD(torch.optim.SGD):
    """torch.optim.SGD with a default supplied for every hyperparameter.

    This subclass adds no behavior of its own: it only re-declares the
    constructor so that 'lr' (like the other hyperparameters) has a
    default value, then delegates to the parent constructor.

    Why it exists (per the original author's note): the optimizer
    constructed by AbsTask.main() must provide a default for each
    argument other than 'params', and SGD's 'lr' was the one argument
    that did not have a default.
    """

    def __init__(
        self,
        params,
        lr: float = 0.1,
        momentum: float = 0.0,
        dampening: float = 0.0,
        weight_decay: float = 0.0,
        nesterov: bool = False,
    ):
        # Keep this as the very first statement: typeguard checks the
        # arguments of the calling function against its annotations.
        assert check_argument_types()

        # Gather the hyperparameters once and forward them by keyword.
        hyperparams = {
            "lr": lr,
            "momentum": momentum,
            "dampening": dampening,
            "weight_decay": weight_decay,
            "nesterov": nesterov,
        }
        super().__init__(params, **hyperparams)
|
|