|
import torch |
|
from torch import Tensor |
|
from .optimizer import Optimizer, required, _use_grad_for_differentiable |
|
from typing import List, Optional |
|
|
|
# Names exported by ``from <this module> import *``: the optimizer class
# and its functional counterpart.
__all__ = ["SGD", "sgd"]
|
|
|
class SGD(Optimizer):
    r"""Implements stochastic gradient descent (optionally with momentum).

    .. math::
       \begin{aligned}
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
            &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
            \:\textit{ nesterov,}\:\textit{ maximize}                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                                 \\
            &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
            &\hspace{5mm}\textbf{if} \: \textit{maximize}                                        \\
            &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1})                     \\
            &\hspace{5mm}\textbf{else}                                                           \\
            &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1})                      \\
            &\hspace{5mm}\textbf{if} \: \lambda \neq 0                                           \\
            &\hspace{10mm} g_t \leftarrow g_t + \lambda  \theta_{t-1}                            \\
            &\hspace{5mm}\textbf{if} \: \mu \neq 0                                               \\
            &\hspace{10mm}\textbf{if} \: t > 1                                                   \\
            &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t           \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} \textbf{b}_t \leftarrow g_t                                           \\
            &\hspace{10mm}\textbf{if} \: \textit{nesterov}                                       \\
            &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t                               \\
            &\hspace{10mm}\textbf{else}                                                          \\
            &\hspace{15mm} g_t  \leftarrow  \textbf{b}_t                                         \\
            &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t                           \\
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
            &\bf{return} \:  \theta_t                                                     \\[-1.ex]
            &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
       \end{aligned}

    With ``maximize=True`` the gradient is negated *before* weight decay and
    momentum are applied. Weight decay therefore still pulls the parameters
    towards zero when maximizing.

    Nesterov momentum is based on the formula from
    `On the importance of initialization and momentum in deep learning`__.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float): learning rate
        momentum (float, optional): momentum factor (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        dampening (float, optional): dampening for momentum (default: 0)
        nesterov (bool, optional): enables Nesterov momentum; requires
            ``momentum > 0`` and ``dampening == 0`` (default: False)
        maximize (bool, optional): maximize the params based on the objective, instead of
            minimizing (default: False)
        foreach (bool, optional): whether foreach implementation of optimizer
            is used (default: None)
        differentiable (bool, optional): stored in the param-group defaults.
            Presumably it controls whether :meth:`step` may be differentiated
            through, via the ``_use_grad_for_differentiable`` decorator; confirm
            against that helper (default: False)

    Example:
        >>> # xdoctest: +SKIP
        >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()

    __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf

    .. note::
        The implementation of SGD with Momentum/Nesterov subtly differs from
        Sutskever et al. and implementations in some other frameworks.

        Considering the specific case of Momentum, the update can be written as

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
                p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
            \end{aligned}

        where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
        parameters, gradient, velocity, and momentum respectively.

        This is in contrast to Sutskever et al. and
        other frameworks which employ an update of the form

        .. math::
            \begin{aligned}
                v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
                p_{t+1} & = p_{t} - v_{t+1}.
            \end{aligned}

        The Nesterov version is analogously modified.
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, *, maximize=False, foreach: Optional[bool] = None,
                 differentiable=False):
        # ``required`` is the placeholder default for ``lr``; only range-check
        # values the caller actually supplied.
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if momentum < 0.0:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if weight_decay < 0.0:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        # Validate every hyperparameter before building the defaults.
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")

        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov,
                        maximize=maximize, foreach=foreach,
                        differentiable=differentiable)
        super().__init__(params, defaults)

    def __setstate__(self, state):
        """Restore optimizer state and fill in options missing from older state.

        Param groups loaded from a saved state may lack keys for these options.
        ``setdefault`` supplies the same defaults that ``__init__`` uses.
        """
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was given.
        """
        loss = None
        if closure is not None:
            # The closure recomputes the loss (and typically gradients), so grad
            # mode is explicitly enabled around it.
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            has_sparse_grad = False

            for p in group['params']:
                # Parameters without a gradient are skipped entirely this step.
                if p.grad is None:
                    continue
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True
                # ``None`` tells sgd() that no momentum buffer exists yet.
                momentum_buffer_list.append(self.state[p].get('momentum_buffer'))

            sgd(params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'])

            # sgd() stores newly created momentum buffers into
            # ``momentum_buffer_list``; persist them in the optimizer state.
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                self.state[p]['momentum_buffer'] = momentum_buffer

        return loss
|
|
|
|
|
def sgd(params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # NOTE(review): these two are positional-or-keyword with ``None``
        # defaults rather than keyword-only. Presumably this works around
        # TorchScript limitations on keyword-only defaults; confirm before moving them.
        # ``has_sparse_grad`` keeps its ``bool`` annotation because it is
        # forwarded unchanged to ``_single_tensor_sgd(has_sparse_grad: bool)``.
        has_sparse_grad: bool = None,
        foreach: Optional[bool] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.

    Args:
        params: parameters to update in place.
        d_p_list: gradients, one per entry of ``params``.
        momentum_buffer_list: momentum buffers, one per entry of ``params``.
            When ``momentum != 0``, ``None`` entries are replaced in place by
            newly created buffers so the caller can persist them.
        has_sparse_grad: whether any gradient is sparse. The foreach
            implementation computes it itself when ``None`` is passed.
        foreach: use the multi-tensor (``torch._foreach_*``) implementation.
            ``None`` is treated as ``False``.
        weight_decay, momentum, lr, dampening, nesterov, maximize: the
            hyperparameters described on :class:`~torch.optim.SGD`.

    Raises:
        RuntimeError: if ``foreach`` is requested under ``torch.jit.script``.
    """
    if foreach is None:
        foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    # The ``not torch.jit.is_scripting()`` test is redundant in eager mode.
    # Presumably it lets the TorchScript compiler statically drop the foreach
    # branch.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgd
    else:
        func = _single_tensor_sgd

    func(params,
         d_p_list,
         momentum_buffer_list,
         weight_decay=weight_decay,
         momentum=momentum,
         lr=lr,
         dampening=dampening,
         nesterov=nesterov,
         has_sparse_grad=has_sparse_grad,
         maximize=maximize)
|
|
|
def _single_tensor_sgd(params: List[Tensor], |
|
d_p_list: List[Tensor], |
|
momentum_buffer_list: List[Optional[Tensor]], |
|
*, |
|
weight_decay: float, |
|
momentum: float, |
|
lr: float, |
|
dampening: float, |
|
nesterov: bool, |
|
maximize: bool, |
|
has_sparse_grad: bool): |
|
|
|
for i, param in enumerate(params): |
|
d_p = d_p_list[i] if not maximize else -d_p_list[i] |
|
|
|
if weight_decay != 0: |
|
d_p = d_p.add(param, alpha=weight_decay) |
|
|
|
if momentum != 0: |
|
buf = momentum_buffer_list[i] |
|
|
|
if buf is None: |
|
buf = torch.clone(d_p).detach() |
|
momentum_buffer_list[i] = buf |
|
else: |
|
buf.mul_(momentum).add_(d_p, alpha=1 - dampening) |
|
|
|
if nesterov: |
|
d_p = d_p.add(buf, alpha=momentum) |
|
else: |
|
d_p = buf |
|
|
|
param.add_(d_p, alpha=-lr) |
|
|
|
|
|
def _multi_tensor_sgd(params: List[Tensor], |
|
grads: List[Tensor], |
|
momentum_buffer_list: List[Optional[Tensor]], |
|
*, |
|
weight_decay: float, |
|
momentum: float, |
|
lr: float, |
|
dampening: float, |
|
nesterov: bool, |
|
maximize: bool, |
|
has_sparse_grad: bool): |
|
|
|
if len(params) == 0: |
|
return |
|
|
|
if has_sparse_grad is None: |
|
has_sparse_grad = any(grad.is_sparse for grad in grads) |
|
|
|
if maximize: |
|
grads = torch._foreach_neg(tuple(grads)) |
|
|
|
if weight_decay != 0: |
|
grads = torch._foreach_add(grads, params, alpha=weight_decay) |
|
|
|
if momentum != 0: |
|
bufs = [] |
|
|
|
all_states_with_momentum_buffer = True |
|
for i in range(len(momentum_buffer_list)): |
|
if momentum_buffer_list[i] is None: |
|
all_states_with_momentum_buffer = False |
|
break |
|
else: |
|
bufs.append(momentum_buffer_list[i]) |
|
|
|
if all_states_with_momentum_buffer: |
|
torch._foreach_mul_(bufs, momentum) |
|
torch._foreach_add_(bufs, grads, alpha=1 - dampening) |
|
else: |
|
bufs = [] |
|
for i in range(len(momentum_buffer_list)): |
|
if momentum_buffer_list[i] is None: |
|
buf = momentum_buffer_list[i] = torch.clone(grads[i]).detach() |
|
else: |
|
buf = momentum_buffer_list[i] |
|
buf.mul_(momentum).add_(grads[i], alpha=1 - dampening) |
|
|
|
bufs.append(buf) |
|
|
|
if nesterov: |
|
torch._foreach_add_(grads, bufs, alpha=momentum) |
|
else: |
|
grads = bufs |
|
|
|
if not has_sparse_grad: |
|
torch._foreach_add_(params, grads, alpha=-lr) |
|
else: |
|
|
|
for i in range(len(params)): |
|
params[i].add_(grads[i], alpha=-lr) |
|
|