venite's picture
initial
f670afc
# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
# import torch
import math
from torch.optim.optimizer import Optimizer, required
class Fromage(Optimizer):
    r"""Fromage optimizer implementation (https://arxiv.org/abs/2002.03432).

    For every parameter ``p`` with gradient ``g`` a step performs::

        p <- (p - lr * g * ||p|| / ||g||) / sqrt(1 + lr ** 2)

    When either norm is zero the rescaling factor ``||p|| / ||g||`` is
    dropped and the plain gradient ``g`` is used instead.

    Args:
        params (iterable): Parameters to optimize, or dicts defining
            parameter groups.
        lr (float): Learning rate. Must be non-negative.
        momentum (float, optional): Stored in the parameter-group defaults
            for interface compatibility; ``step`` does not read it.

    Raises:
        ValueError: If ``lr`` is given and is negative.
    """

    def __init__(self, params, lr=required, momentum=0):
        if lr is not required and lr < 0.0:
            raise ValueError("Invalid learning rate: {}".format(lr))
        defaults = dict(lr=lr, momentum=momentum)
        super(Fromage, self).__init__(params, defaults)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was supplied.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            # Loop-invariant shrink factor applied after every update.
            shrink = math.sqrt(1 + lr ** 2)
            for p in group['params']:
                if p.grad is None:
                    continue
                # Work on the raw ``.data`` views so that neither the norms
                # nor the update are recorded by autograd.
                weights = p.data
                grad = p.grad.data
                grad_norm = grad.norm()
                weight_norm = weights.norm()
                if weight_norm > 0.0 and grad_norm > 0.0:
                    # Rescale the gradient to the parameter's own norm.
                    update = grad * (weight_norm / grad_norm)
                else:
                    # A zero norm would make the ratio degenerate; use the
                    # unscaled gradient.
                    update = grad
                # Keyword ``alpha`` replaces the deprecated positional
                # ``add_(Number, Tensor)`` overload.
                weights.add_(update, alpha=-lr)
                weights.div_(shrink)
        return loss