# Copyright (c) OpenMMLab. All rights reserved.
from typing import Iterable, Optional

import torch
from torch.optim.optimizer import Optimizer

from mmpretrain.registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class LARS(Optimizer):
    """Implements layer-wise adaptive rate scaling for SGD.

    Based on Algorithm 1 of the following paper by You, Gitman, and Ginsburg.
    `Large Batch Training of Convolutional Networks
    <https://arxiv.org/abs/1708.03888>`_.

    Args:
        params (Iterable): Iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float): Base learning rate.
        momentum (float): Momentum factor. Defaults to 0.
        weight_decay (float): Weight decay (L2 penalty). Defaults to 0.
        dampening (float): Dampening for momentum. Defaults to 0.
        eta (float): LARS coefficient. Defaults to 0.001.
        nesterov (bool): Enables Nesterov momentum. Defaults to False.
        eps (float): A small number to avoid dividing by zero.
            Defaults to 1e-8.

    Note:
        A parameter group may set ``lars_exclude=True`` to skip the
        layer-wise scaling for its parameters (commonly used for biases
        and BatchNorm parameters).

    Example:
        >>> optimizer = LARS(model.parameters(), lr=0.1, momentum=0.9,
        >>>                  weight_decay=1e-4, eta=1e-3)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """
    def __init__(self,
                 params: Iterable,
                 lr: float,
                 momentum: float = 0,
                 weight_decay: float = 0,
                 dampening: float = 0,
                 eta: float = 0.001,
                 nesterov: bool = False,
                 eps: float = 1e-8) -> None:
        if lr < 0.0:
            raise ValueError(f'Invalid learning rate: {lr}')
        if momentum < 0.0:
            raise ValueError(f'Invalid momentum value: {momentum}')
        if weight_decay < 0.0:
            raise ValueError(f'Invalid weight_decay value: {weight_decay}')
        if eta < 0.0:
            raise ValueError(f'Invalid LARS coefficient value: {eta}')

        defaults = dict(
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
            eta=eta)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError(
                'Nesterov momentum requires a momentum and zero dampening')

        self.eps = eps
        super().__init__(params, defaults)

    def __setstate__(self, state) -> None:
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)

    @torch.no_grad()
    def step(self, closure=None) -> Optional[torch.Tensor]:
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the
                model and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            eta = group['eta']
            nesterov = group['nesterov']
            lr = group['lr']
            lars_exclude = group.get('lars_exclude', False)

            for p in group['params']:
                if p.grad is None:
                    continue

                d_p = p.grad
                if lars_exclude:
                    # Excluded groups (e.g. biases, BatchNorm) fall back to
                    # plain SGD with the base learning rate.
                    local_lr = 1.
                else:
                    weight_norm = torch.norm(p).item()
                    grad_norm = torch.norm(d_p).item()
                    if weight_norm != 0 and grad_norm != 0:
                        # Compute the local learning rate for this layer:
                        # eta * ||w|| / (||g|| + weight_decay * ||w|| + eps)
                        local_lr = eta * weight_norm / \
                            (grad_norm + weight_decay * weight_norm + self.eps)
                    else:
                        local_lr = 1.

                actual_lr = local_lr * lr
                # Apply weight decay to the gradient, then scale the update
                # by the layer-wise learning rate.
                d_p = d_p.add(p, alpha=weight_decay).mul(actual_lr)
                if momentum != 0:
                    param_state = self.state[p]
                    if 'momentum_buffer' not in param_state:
                        buf = param_state['momentum_buffer'] = \
                            torch.clone(d_p).detach()
                    else:
                        buf = param_state['momentum_buffer']
                        buf.mul_(momentum).add_(d_p, alpha=1 - dampening)
                    if nesterov:
                        d_p = d_p.add(buf, alpha=momentum)
                    else:
                        d_p = buf
                p.add_(-d_p)

        return loss
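

# A minimal usage sketch, assuming mmpretrain's config system: since the class
# is registered in OPTIMIZERS, it is usually selected through an
# ``optim_wrapper`` dict rather than constructed directly. The hyper-parameter
# values and the ``custom_keys`` names below are illustrative assumptions, not
# taken from this file.
optim_wrapper = dict(
    type='OptimWrapper',
    optimizer=dict(type='LARS', lr=4.8, weight_decay=1e-6, momentum=0.9),
    paramwise_cfg=dict(
        custom_keys={
            # Skip the layer-wise scaling (and weight decay) for BatchNorm
            # and bias parameters via the ``lars_exclude`` group key.
            'bn': dict(decay_mult=0, lars_exclude=True),
            'bias': dict(decay_mult=0, lars_exclude=True),
        }))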