# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import torch
from torch.optim.optimizer import Optimizer, required
class Madam(Optimizer):
    r"""Madam optimizer (https://arxiv.org/abs/2006.14560).

    Madam performs *multiplicative* weight updates: each parameter element is
    scaled by ``exp(-lr * g_normed * sign(p))``, where ``g_normed`` is the raw
    gradient normalized by a bias-corrected running second-moment estimate
    (decay 0.999, matching Adam's default ``beta2``).

    Args:
        params (iterable): Iterable of parameters to optimize or dicts
            defining parameter groups.
        lr (float): Learning rate (required).
        scale (float): Each parameter tensor's magnitude is clamped to
            ``scale`` times its initial RMS value (default: 3.0).
        g_bound (float, optional): If given, the normalized gradient is
            clamped to ``[-g_bound, g_bound]`` before the update
            (default: None).
        momentum (float): Stored in the group defaults for interface
            compatibility; not used by this implementation (default: 0).
    """

    def __init__(self, params, lr=required, scale=3.0,
                 g_bound=None, momentum=0):
        self.scale = scale
        self.g_bound = g_bound
        defaults = dict(lr=lr, momentum=momentum)
        super(Madam, self).__init__(params, defaults)

    def step(self, closure=None):
        r"""Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The loss returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                state = self.state[p]
                if len(state) == 0:
                    # Lazy state init: the clamp bound is fixed once, as a
                    # multiple of the parameter's *initial* RMS value.
                    state['max'] = self.scale * (p * p).mean().sqrt().item()
                    state['step'] = 0
                    state['exp_avg_sq'] = torch.zeros_like(p)
                state['step'] += 1
                bias_correction = 1 - 0.999 ** state['step']
                # Update the running second moment in place instead of
                # allocating a fresh tensor every step:
                # exp_avg_sq = 0.999 * exp_avg_sq + 0.001 * grad**2
                state['exp_avg_sq'].mul_(0.999).addcmul_(
                    grad, grad, value=0.001)
                g_normed = \
                    grad / (state['exp_avg_sq'] / bias_correction).sqrt()
                # 0 / 0 -> NaN where both grad and the running moment are
                # zero; treat those elements as "no update".
                g_normed[torch.isnan(g_normed)] = 0
                if self.g_bound is not None:
                    g_normed.clamp_(-self.g_bound, self.g_bound)
                # Multiplicative update, then clamp the parameter magnitude
                # to the bound fixed at state initialization.
                p.data *= torch.exp(
                    -group['lr'] * g_normed * torch.sign(p.data))
                p.data.clamp_(-state['max'], state['max'])
        return loss