"""PyTorch Lamb optimizer w/ behaviour similar to NVIDIA FusedLamb. | |
This optimizer code was adapted from the following (starting with latest) | |
* https://github.com/HabanaAI/Model-References/blob/ | |
2b435114fe8e31f159b1d3063b8280ae37af7423/PyTorch/nlp/bert/pretraining/lamb.py | |
* https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/ | |
LanguageModeling/Transformer-XL/pytorch/lamb.py | |
* https://github.com/cybertronai/pytorch-lamb | |
Use FusedLamb if you can (GPU). The reason for including this variant of Lamb | |
is to have a version that is | |
similar in behaviour to APEX FusedLamb if you aren't using NVIDIA GPUs or | |
cannot install/use APEX. | |
In addition to some cleanup, this Lamb impl has been modified to support | |
PyTorch XLA and has been tested on TPU. | |
Original copyrights for above sources are below. | |
Modifications Copyright 2021 Ross Wightman | |
""" | |

# Copyright (c) 2021, Habana Labs Ltd. All rights reserved.

# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# MIT License
#
# Copyright (c) 2019 cybertronai
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import math

import torch
from torch.optim import Optimizer

from mmpretrain.registry import OPTIMIZERS


@OPTIMIZERS.register_module()
class Lamb(Optimizer):
"""A pure pytorch variant of FuseLAMB (NvLamb variant) optimizer. | |
This class is copied from `timm`_. The LAMB was proposed in `Large Batch | |
Optimization for Deep Learning - Training BERT in 76 minutes`_. | |
.. _timm: | |
https://github.com/rwightman/pytorch-image-models/blob/master/timm/optim/lamb.py | |
.. _Large Batch Optimization for Deep Learning - Training BERT in 76 minutes: | |
https://arxiv.org/abs/1904.00962 | |
Arguments: | |
params (iterable): iterable of parameters to optimize or dicts defining | |
parameter groups. | |
lr (float, optional): learning rate. (default: 1e-3) | |
betas (Tuple[float, float], optional): coefficients used for computing | |
running averages of gradient and its norm. (default: (0.9, 0.999)) | |
eps (float, optional): term added to the denominator to improve | |
numerical stability. (default: 1e-8) | |
weight_decay (float, optional): weight decay (L2 penalty) (default: 0) | |
grad_averaging (bool, optional): whether apply (1-beta2) to grad when | |
calculating running averages of gradient. (default: True) | |
max_grad_norm (float, optional): value used to clip global grad norm | |
(default: 1.0) | |
trust_clip (bool): enable LAMBC trust ratio clipping (default: False) | |
always_adapt (boolean, optional): Apply adaptive learning rate to 0.0 | |
weight decay parameter (default: False) | |
""" # noqa: E501 | |

    def __init__(self,
                 params,
                 lr=1e-3,
                 bias_correction=True,
                 betas=(0.9, 0.999),
                 eps=1e-6,
                 weight_decay=0.01,
                 grad_averaging=True,
                 max_grad_norm=1.0,
                 trust_clip=False,
                 always_adapt=False):
        defaults = dict(
            lr=lr,
            bias_correction=bias_correction,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            grad_averaging=grad_averaging,
            max_grad_norm=max_grad_norm,
            trust_clip=trust_clip,
            always_adapt=always_adapt)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        device = self.param_groups[0]['params'][0].device
        one_tensor = torch.tensor(
            1.0, device=device
        )  # because torch.where doesn't handle scalars correctly
        global_grad_norm = torch.zeros(1, device=device)
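
        # Accumulate the squared L2 norm of every gradient so the gradients
        # can be clipped by their global norm before the LAMB update (as in
        # the NVLamb variant).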
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError(
                        'Lamb does not support sparse gradients, consider '
                        'SparseAdam instead.')
                global_grad_norm.add_(grad.pow(2).sum())

        global_grad_norm = torch.sqrt(global_grad_norm)

        # FIXME it'd be nice to remove the explicit tensor conversion of
        # scalars when torch.where promotes scalar types properly
        # https://github.com/pytorch/pytorch/issues/9190
        max_grad_norm = torch.tensor(
            self.defaults['max_grad_norm'], device=device)
        clip_global_grad_norm = torch.where(global_grad_norm > max_grad_norm,
                                            global_grad_norm / max_grad_norm,
                                            one_tensor)
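
        # ``clip_global_grad_norm`` is 1 when the global norm is within bounds
        # and ``global_norm / max_grad_norm`` otherwise; dividing every
        # gradient by it below is equivalent to clipping the global gradient
        # norm to ``max_grad_norm``.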
        for group in self.param_groups:
            bias_correction = 1 if group['bias_correction'] else 0
            beta1, beta2 = group['betas']
            grad_averaging = 1 if group['grad_averaging'] else 0
            beta3 = 1 - beta1 if grad_averaging else 1.0

            # assume the same step across the group for now to simplify
            # things; a per-parameter step can easily be supported by making
            # it a tensor, or by passing a list into the kernel
            if 'step' in group:
                group['step'] += 1
            else:
                group['step'] = 1

            if bias_correction:
                bias_correction1 = 1 - beta1**group['step']
                bias_correction2 = 1 - beta2**group['step']
            else:
                bias_correction1, bias_correction2 = 1.0, 1.0
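
            # Per-parameter LAMB update: an Adam-style step built from the
            # bias-corrected first/second moments, optionally combined with
            # weight decay, then rescaled by the layer-wise trust ratio
            # ||w|| / ||update|| before being applied.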
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.div_(clip_global_grad_norm)
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']

                # Decay the first and second moment running average
                # coefficients
                exp_avg.mul_(beta1).add_(grad, alpha=beta3)  # m_t
                exp_avg_sq.mul_(beta2).addcmul_(
                    grad, grad, value=1 - beta2)  # v_t

                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(
                    group['eps'])
                update = (exp_avg / bias_correction1).div_(denom)

                weight_decay = group['weight_decay']
                if weight_decay != 0:
                    update.add_(p, alpha=weight_decay)

                if weight_decay != 0 or group['always_adapt']:
                    # Layer-wise LR adaptation. By default, skip adaptation on
                    # parameters that are excluded from weight decay, unless
                    # always_adapt == True, in which case it is always
                    # enabled.
                    w_norm = p.norm(2.0)
                    g_norm = update.norm(2.0)
                    # FIXME nested where required since logical and/or not
                    # working in PT XLA
                    trust_ratio = torch.where(
                        w_norm > 0,
                        torch.where(g_norm > 0, w_norm / g_norm, one_tensor),
                        one_tensor,
                    )
                    if group['trust_clip']:
                        # LAMBC trust clipping, upper bound fixed at one
                        trust_ratio = torch.minimum(trust_ratio, one_tensor)
                    update.mul_(trust_ratio)

                p.add_(update, alpha=-group['lr'])

        return loss
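

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: fit a toy linear
    # model with the Lamb optimizer defined above. All shapes, hyperparameters
    # and the model itself are arbitrary assumptions for illustration only.
    import torch.nn as nn
    import torch.nn.functional as F

    model = nn.Linear(16, 1)
    optimizer = Lamb(model.parameters(), lr=1e-3, weight_decay=0.01)
    inputs = torch.randn(32, 16)
    targets = torch.randn(32, 1)

    for _ in range(5):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        loss.backward()
        optimizer.step()
        print(f'loss: {loss.item():.4f}')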