JustinLin610
update
8437114
raw history blame
No virus
6.23 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.optim
from . import LegacyFairseqOptimizer, register_optimizer
@register_optimizer("adamax")
class FairseqAdamax(LegacyFairseqOptimizer):
    """Fairseq optimizer wrapper around :class:`Adamax`, registered as ``adamax``."""

    def __init__(self, args, params):
        super().__init__(args)
        # Hyperparameters are derived from the parsed command-line args.
        self._optimizer = Adamax(params, **self.optimizer_config)

    @staticmethod
    def add_args(parser):
        """Add optimizer-specific arguments to the parser."""
        # Each entry is (option strings, keyword arguments for add_argument).
        option_specs = (
            (
                ("--adamax-betas",),
                dict(default="(0.9, 0.999)", metavar="B",
                     help="betas for Adam optimizer"),
            ),
            (
                ("--adamax-eps",),
                dict(type=float, default=1e-8, metavar="D",
                     help="epsilon for Adam optimizer"),
            ),
            (
                ("--weight-decay", "--wd"),
                dict(default=0.0, type=float, metavar="WD",
                     help="weight decay"),
            ),
            (
                ("--no-bias-correction",),
                dict(default=False, action="store_true",
                     help="disable bias correction"),
            ),
        )
        for flags, options in option_specs:
            parser.add_argument(*flags, **options)

    @property
    def optimizer_config(self):
        """
        Build the keyword arguments handed to the underlying optimizer.

        These values override optimizer args stored in checkpoints, which
        makes it possible to load a checkpoint and resume training with a
        different set of optimizer args (for example another learning rate).
        """
        opts = self.args
        config = {"lr": opts.lr[0]}
        # NOTE(review): eval() executes the raw --adamax-betas string as
        # Python code. It is kept for compatibility with existing command
        # lines; only pass trusted values here.
        config["betas"] = eval(opts.adamax_betas)
        config["eps"] = opts.adamax_eps
        config["weight_decay"] = opts.weight_decay
        # The CLI flag is negative ("--no-bias-correction"); the optimizer
        # takes the positive form.
        config["bias_correction"] = not opts.no_bias_correction
        return config
class Adamax(torch.optim.Optimizer):
    """Implements Adamax algorithm (a variant of Adam based on infinity norm).

    It has been proposed in `Adam: A Method for Stochastic Optimization`__.

    Compared to the version in PyTorch, this version implements a fix for weight decay.

    Args:
        params (iterable): iterable of parameters to optimize or dicts defining
            parameter groups
        lr (float, optional): learning rate (default: 2e-3)
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square
        eps (float, optional): term added to the denominator to improve
            numerical stability (default: 1e-8)
        weight_decay (float, optional): weight decay coefficient; the decay is
            applied directly to the parameters, scaled by ``lr`` (default: 0)
        bias_correction (bool, optional): enable bias correction (default: True)

    Raises:
        ValueError: if ``lr``, ``eps`` or ``weight_decay`` is negative, or if
            either beta lies outside ``[0, 1)``.

    __ https://arxiv.org/abs/1412.6980
    """

    def __init__(
        self,
        params,
        lr=2e-3,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        bias_correction=True,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            bias_correction=bias_correction,
        )
        super().__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        """Always ``True`` for this optimizer."""
        return True

    @property
    def supports_flat_params(self):
        """Always ``True`` for this optimizer."""
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure
            was given.

        Raises:
            RuntimeError: if a parameter has a sparse gradient.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            # Per-group hyperparameters do not change inside the parameter loop.
            lr = group["lr"]
            beta1, beta2 = group["betas"]
            eps = group["eps"]
            weight_decay = group["weight_decay"]
            use_bias_correction = group["bias_correction"]

            for p in group["params"]:
                if p.grad is None:
                    continue
                # For fp32 gradients .float() returns the same storage as
                # p.grad, so `grad` must never be modified in place below.
                grad = p.grad.data.float()
                if grad.is_sparse:
                    raise RuntimeError("Adamax does not support sparse gradients")

                # Half-precision parameters are updated through an fp32 copy
                # that is written back at the end of the iteration.
                low_precision = p.data.dtype in {torch.float16, torch.bfloat16}
                p_data_fp32 = p.data.float() if low_precision else p.data

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state["step"] = 0
                    state["exp_avg"] = torch.zeros_like(p_data_fp32)
                    state["exp_inf"] = torch.zeros_like(p_data_fp32)
                else:
                    # Keep the state on the same device/dtype as the working copy.
                    state["exp_avg"] = state["exp_avg"].to(p_data_fp32)
                    state["exp_inf"] = state["exp_inf"].to(p_data_fp32)

                exp_avg, exp_inf = state["exp_avg"], state["exp_inf"]

                state["step"] += 1

                # Update biased first moment estimate.
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

                # Update the exponentially weighted infinity norm.
                # grad.abs() (out-of-place) leaves p.grad untouched; the
                # in-place variant would overwrite fp32 gradients with their
                # absolute values.
                torch.max(
                    exp_inf.mul_(beta2),
                    grad.abs(),
                    out=exp_inf,
                )

                # Plain division (not `/=`) so group["lr"] is never mutated.
                step_size = lr
                if use_bias_correction:
                    bias_correction = 1 - beta1 ** state["step"]
                    step_size = lr / bias_correction

                if weight_decay != 0:
                    # Decay is applied to the parameters themselves:
                    # p <- p - lr * weight_decay * p
                    p_data_fp32.add_(p_data_fp32, alpha=-weight_decay * lr)

                p_data_fp32.addcdiv_(exp_avg, exp_inf.add(eps), value=-step_size)

                if low_precision:
                    p.data.copy_(p_data_fp32)

        return loss