# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from omegaconf import II, DictConfig
from torch.optim.optimizer import Optimizer, required

from . import FairseqOptimizer, register_optimizer


@dataclass
class FairseqNAGConfig(FairseqDataclass):
    momentum: float = field(default=0.99, metadata={"help": "momentum factor"})
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    # TODO common vars in parent class
    # II("optimization.lr") is an omegaconf interpolation that resolves to the
    # top-level optimization.lr value at runtime.
    lr: List[float] = II("optimization.lr")


@register_optimizer("nag", dataclass=FairseqNAGConfig)
class FairseqNAG(FairseqOptimizer):
    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = NAG(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "momentum": self.cfg.momentum,
            "weight_decay": self.cfg.weight_decay,
        }
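

# Usage note (illustrative, not part of the original file): fairseq builds FairseqNAG
# from the "nag" entry of its optimizer registry, so this optimizer is normally
# selected on the command line with flags generated from FairseqNAGConfig, e.g.:
#   fairseq-train ... --optimizer nag --lr 0.1 --momentum 0.99 --weight-decay 0.0
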
class NAG(Optimizer):
    def __init__(self, params, lr=required, momentum=0, weight_decay=0):
        defaults = dict(lr=lr, lr_old=lr, momentum=momentum, weight_decay=weight_decay)
        super(NAG, self).__init__(params, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            lr = group["lr"]
            lr_old = group.get("lr_old", lr)
            # The momentum buffer is stored pre-scaled by the learning rate, so when
            # the learning rate changes between steps it is rescaled by lr / lr_old.
            lr_correct = lr / lr_old if lr_old > 0 else lr
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Do the update in fp32 even if the parameter is stored in fp16/bf16.
                p_data_fp32 = p.data
                if p_data_fp32.dtype in {torch.float16, torch.bfloat16}:
                    p_data_fp32 = p_data_fp32.float()

                d_p = p.grad.data.float()
                param_state = self.state[p]
                if "momentum_buffer" not in param_state:
                    param_state["momentum_buffer"] = torch.zeros_like(d_p)
                else:
                    param_state["momentum_buffer"] = param_state["momentum_buffer"].to(
                        d_p
                    )

                buf = param_state["momentum_buffer"]

                if weight_decay != 0:
                    p_data_fp32.mul_(1 - lr * weight_decay)

                # Nesterov accelerated gradient update, with the velocity (buf)
                # kept pre-multiplied by the learning rate:
                #   p   <- p + momentum^2 * lr_correct * buf - (1 + momentum) * lr * d_p
                #   buf <- momentum * lr_correct * buf - lr * d_p
                p_data_fp32.add_(buf, alpha=momentum * momentum * lr_correct)
                p_data_fp32.add_(d_p, alpha=-(1 + momentum) * lr)

                buf.mul_(momentum * lr_correct).add_(d_p, alpha=-lr)

                # Copy the fp32 result back into low-precision parameters.
                if p.data.dtype in {torch.float16, torch.bfloat16}:
                    p.data.copy_(p_data_fp32)

            group["lr_old"] = lr

        return loss
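

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file): NAG is a
    # self-contained torch.optim.Optimizer, so it can be exercised directly
    # without the fairseq wrapper or config machinery.
    import torch.nn as nn

    model = nn.Linear(4, 2)
    optimizer = NAG(model.parameters(), lr=0.1, momentum=0.99, weight_decay=0.0)

    x = torch.randn(8, 4)
    loss = model(x).sum()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print("loss before/after one NAG step:", loss.item(), model(x).sum().item())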