from collections.abc import Collection
from dataclasses import dataclass, field
from typing import List

import torch
from fairseq.dataclass import FairseqDataclass
from fairseq.optim import FairseqOptimizer, register_optimizer
from omegaconf import II, DictConfig


try:
    import deepspeed

    has_deepspeed = True
except ImportError:
    has_deepspeed = False


def _get_cpu_adam():
    try:
        from deepspeed.ops.op_builder import CPUAdamBuilder

        return CPUAdamBuilder().load()
    except ImportError:
        # fall back to the prebuilt CPU Adam op shipped with DeepSpeed when
        # the op builder is unavailable
        from deepspeed.ops.adam import DeepSpeedCPUAdam as ds_opt_adam

        return ds_opt_adam


@dataclass
class FairseqCPUAdamConfig(FairseqDataclass):
    adam_betas: str = field(
        default="(0.9, 0.999)", metadata={"help": "betas for Adam optimizer"}
    )
    adam_eps: float = field(
        default=1e-8, metadata={"help": "epsilon for Adam optimizer"}
    )
    weight_decay: float = field(default=0.0, metadata={"help": "weight decay"})
    fp16_adam_stats: bool = field(
        default=False, metadata={"help": "use FP16 stats (with automatic scaling)"}
    )

    lr: List[float] = II("optimization.lr")
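

# The config above is exposed through fairseq's dataclass-driven CLI. A typical
# invocation would look roughly like the sketch below (flag names are inferred
# from the field names above and from the "cpu_adam" registration that follows):
#
#   fairseq-train <data> ... --optimizer cpu_adam --adam-betas '(0.9, 0.999)' \
#       --adam-eps 1e-8 --weight-decay 0.01 --fp16-adam-stats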


@register_optimizer("cpu_adam", dataclass=FairseqCPUAdamConfig)
class FairseqCPUAdam(FairseqOptimizer):
    """Adam optimizer for fairseq, optimized for CPU tensors.

    Important note: this optimizer corresponds to the "AdamW" variant of
    Adam in its weight decay behavior. As such, it is most closely
    analogous to torch.optim.AdamW from PyTorch.
    """

    def __init__(self, cfg: DictConfig, params):
        super().__init__(cfg)
        self._optimizer = CPUAdam(params, **self.optimizer_config)

    @property
    def optimizer_config(self):
        """
        Return a kwarg dictionary that will be used to override optimizer
        args stored in checkpoints. This allows us to load a checkpoint and
        resume training using a different set of optimizer args, e.g., with a
        different learning rate.
        """
        return {
            "lr": self.cfg.lr[0]
            if isinstance(self.cfg.lr, Collection)
            else self.cfg.lr,
            "betas": eval(self.cfg.adam_betas),
            "eps": self.cfg.adam_eps,
            "weight_decay": self.cfg.weight_decay,
            "use_fp16_stats": self.cfg.fp16_adam_stats,
        }
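

# Sketch of constructing the optimizer programmatically (the cfg values and
# `model` below are illustrative; in normal use fairseq builds the DictConfig
# from the command line and instantiates the optimizer inside its trainer):
#
#   cfg = DictConfig(
#       {
#           "adam_betas": "(0.9, 0.999)",
#           "adam_eps": 1e-8,
#           "weight_decay": 0.0,
#           "fp16_adam_stats": False,
#           "lr": [1e-3],
#       }
#   )
#   optimizer = FairseqCPUAdam(cfg, model.parameters())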


class CPUAdam(torch.optim.Optimizer):

    optimizer_id = 0

    def __init__(
        self,
        params,
        lr=1e-3,
        bias_correction=True,
        betas=(0.9, 0.999),
        eps=1e-8,
        weight_decay=0,
        use_fp16_stats=False,
    ):
        defaults = {
            "lr": lr,
            "bias_correction": bias_correction,
            "betas": betas,
            "eps": eps,
            "weight_decay": weight_decay,
        }
        super().__init__(params, defaults)

        self.use_fp16_stats = use_fp16_stats
        self.FLOAT16_MAX = 65504.0

        if not has_deepspeed:
            raise ImportError("Please install DeepSpeed: pip install deepspeed")

        # each instance gets a unique id so that the underlying DeepSpeed op
        # can keep per-optimizer state
        self.opt_id = CPUAdam.optimizer_id
        CPUAdam.optimizer_id = CPUAdam.optimizer_id + 1

        self.ds_opt_adam = _get_cpu_adam()
        adamw_mode = True
        self.ds_opt_adam.create_adam(
            self.opt_id, lr, betas[0], betas[1], eps, weight_decay, adamw_mode
        )

    @property
    def supports_memory_efficient_fp16(self):
        return True

    @property
    def supports_flat_params(self):
        return True

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        torch.cuda.synchronize()

        for group_id, group in enumerate(self.param_groups):
            for param_id, p in enumerate(group["params"]):
                if p.grad is None:
                    continue

                state = self.state[p]
                if len(state) == 0:
                    state["step"] = 0
                    dtype = torch.float16 if self.use_fp16_stats else p.data.dtype
                    # exponential moving average of gradient values
                    state["exp_avg"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    # exponential moving average of squared gradient values
                    state["exp_avg_sq"] = torch.zeros_like(
                        p.data, dtype=dtype, device="cpu"
                    )
                    if self.use_fp16_stats:
                        assert torch.is_floating_point(p.data)
                        state["exp_avg_scale"] = 1.0
                        state["exp_avg_sq_scale"] = 1.0

                exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"]

                # keep a handle on the original tensor; the DeepSpeed op works
                # on an FP32 CPU copy
                p_data_bak = p.data

                p.data = p.data.to(dtype=torch.float32, device="cpu")
                p.grad.data = p.grad.data.to(dtype=torch.float32, device="cpu")

                if self.use_fp16_stats:
                    exp_avg = exp_avg.float() * state["exp_avg_scale"]
                    exp_avg_sq = exp_avg_sq.float() * state["exp_avg_sq_scale"]

                state["step"] += 1
                beta1, beta2 = group["betas"]

                self.ds_opt_adam.adam_update(
                    self.opt_id,
                    state["step"],
                    group["lr"],
                    beta1,
                    beta2,
                    group["eps"],
                    group["weight_decay"],
                    group["bias_correction"],
                    p.data,
                    p.grad.data,
                    exp_avg,
                    exp_avg_sq,
                )

                # if the conversion above produced a new buffer, copy the
                # updated FP32 values back into the original tensor
                if p_data_bak.data_ptr() != p.data.data_ptr():
                    p_data_bak.copy_(p.data)
                    p.data = p_data_bak

                if self.use_fp16_stats:

                    def inf_norm(t):
                        return torch.norm(t, float("inf"))

                    # rescale the moments so they fit into FP16, then store
                    # them in half precision along with their scales
                    state["exp_avg_scale"], state["exp_avg_sq_scale"] = (
                        1e-8 + inf_norm(exp_avg) / self.FLOAT16_MAX,
                        1e-8 + inf_norm(exp_avg_sq) / self.FLOAT16_MAX,
                    )
                    state["exp_avg"], state["exp_avg_sq"] = (
                        (exp_avg / state["exp_avg_scale"]).half(),
                        (exp_avg_sq / state["exp_avg_sq_scale"]).half(),
                    )

        return loss
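

if __name__ == "__main__":
    # Minimal usage sketch: run one CPUAdam step on a toy CPU model. Assumes
    # DeepSpeed is installed and PyTorch was built with CUDA support, since
    # step() calls torch.cuda.synchronize().
    model = torch.nn.Linear(8, 2)
    optimizer = CPUAdam(model.parameters(), lr=1e-3, use_fp16_stats=False)
    loss = model(torch.randn(4, 8)).sum()
    loss.backward()
    optimizer.step()
    print("CPUAdam step done, loss =", loss.item())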
|