from contextlib import contextmanager

import torch

from nemo.utils import logging

try:
    import amp_C
    from apex.multi_tensor_apply import multi_tensor_applier
    from apex.transformer.parallel_state import get_data_parallel_group, get_data_parallel_world_size
    from apex.transformer.tensor_parallel import copy_tensor_model_parallel_attributes

    HAVE_APEX = True

except (ImportError, ModuleNotFoundError):

    HAVE_APEX = False


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


def _multi_tensor_copy_this_to_that(this, that, overflow_buf):
    """Use multi-tensor-applier to copy values from one list to another.
    We don't have a bfloat16 implementation, so if the overflow_buf
    is not provided we fall back to a simple loop copy to stay compatible
    with bfloat16."""
    if overflow_buf:
        # Scaling by a factor of 1.0 is equivalent to a copy.
        multi_tensor_applier(amp_C.multi_tensor_scale, overflow_buf, [this, that], 1.0)
    else:
        for this_, that_ in zip(this, that):
            that_.copy_(this_)


class GradBucket(object):
    """
    Persistent buffer for main gradients that remains allocated between training iterations.
    """

    def __init__(self, numel, chunk_size_mb):
        if not HAVE_APEX:
            raise ImportError(
                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )

        self.numel = numel
        self.data = torch.zeros(self.numel, dtype=torch.float, device=torch.cuda.current_device(), requires_grad=False)

        self.chunk_size_mb = chunk_size_mb
        if self.chunk_size_mb > 0:
            chunk_size_bytes = chunk_size_mb * 1024 * 1024
            # The buffer holds fp32 gradients, i.e. 4 bytes per element.
            self.chunk_size_numel = chunk_size_bytes // 4
            self.num_chunks = self.numel // self.chunk_size_numel
            self.numel_per_chunk = [self.chunk_size_numel] * self.num_chunks
            if self.numel % self.chunk_size_numel != 0:
                self.num_chunks += 1
                self.numel_per_chunk.append(self.numel % self.chunk_size_numel)

            self.start_index_per_chunk = torch.cumsum(torch.tensor([0] + self.numel_per_chunk[:-1]), dim=0)
            self.current_chunk = 0
            self.computed_numel_per_chunk = [0] * self.num_chunks
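
    # A worked example of the chunking arithmetic above (illustrative numbers only,
    # assuming the fp32 buffer allocated in __init__, i.e. 4 bytes per element):
    #
    #   numel = 100_000_000, chunk_size_mb = 128
    #   chunk_size_bytes      = 128 * 1024 * 1024 = 134_217_728
    #   chunk_size_numel      = 134_217_728 // 4  = 33_554_432
    #   num_chunks            = 100_000_000 // 33_554_432 = 2, and since the remainder
    #                           100_000_000 - 2 * 33_554_432 = 32_891_136 is non-zero,
    #                           a third, smaller chunk is appended
    #   numel_per_chunk       = [33_554_432, 33_554_432, 32_891_136]
    #   start_index_per_chunk = [0, 33_554_432, 67_108_864]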

    def zero(self):
        """Reset the buffer to zero."""
        self.data.zero_()

    def allreduce_buffer(self):
        """Synchronous allreduce of the buffer data across data-parallel ranks."""
        self.data.div_(get_data_parallel_world_size())
        torch.distributed.all_reduce(self.data, group=get_data_parallel_group())

    def get(self, shape, start_index):
        """Return a tensor with the input `shape` as a view into the
        1-D data starting at `start_index`."""
        end_index = start_index + shape.numel()
        assert end_index <= self.numel, 'requested tensor is out of the buffer range.'
        buffer_tensor = self.data[start_index:end_index]
        buffer_tensor = buffer_tensor.view(shape)

        # Record how many elements of each chunk this view covers.
        grad_chunk_info = None
        if self.chunk_size_mb > 0:
            grad_chunk_info = {}
            chunk = start_index // self.chunk_size_numel
            chunk_start_index = self.start_index_per_chunk[chunk]
            chunk_end_index = chunk_start_index + self.numel_per_chunk[chunk]
            grad_chunk_info[chunk] = min(chunk_end_index, end_index) - start_index
            while chunk_end_index < end_index:
                chunk += 1
                chunk_start_index = self.start_index_per_chunk[chunk]
                chunk_end_index = chunk_start_index + self.numel_per_chunk[chunk]
                grad_chunk_info[chunk] = min(chunk_end_index, end_index) - chunk_start_index

        return buffer_tensor, grad_chunk_info

    def update_chunk_info(self, grad_chunk_info):
        for chunk in grad_chunk_info.keys():
            self.computed_numel_per_chunk[chunk] += grad_chunk_info[chunk]

    def get_allreduce_tensor(self):
        """Return the current chunk's buffer slice once all of its gradient elements
        have been accumulated; otherwise return None."""
        if self.computed_numel_per_chunk[self.current_chunk] == self.numel_per_chunk[self.current_chunk]:
            chunk_start_index = self.start_index_per_chunk[self.current_chunk]
            chunk_end_index = chunk_start_index + self.numel_per_chunk[self.current_chunk]
            allreduce_tensor = self.data[chunk_start_index:chunk_end_index]

            # Reset the bookkeeping and advance to the next chunk (wrapping around).
            self.computed_numel_per_chunk[self.current_chunk] = 0
            self.current_chunk += 1
            if self.current_chunk == self.num_chunks:
                self.current_chunk = 0

            return allreduce_tensor

        return None


class MainParamsOptimizerWrapper(torch.optim.Optimizer):
    """
    Float16 optimizer wrapper for half precision (fp16 and bf16) data types.
    This optimizer wrapper holds main parameters and gradients in fp32 to support
    stable convergence.

    Arguments:
        optimizer: base optimizer such as Adam or SGD.
        fp32_grad_accum: to enable the use of fp32 in gradient accumulation and allreduce.
        contiguous_grad_bucket: to enable allocating the master gradients in a
            contiguous memory space to reduce memory fragmentation.
        async_grad_allreduce: to enable asynchronous gradient allreduce that is executed
            along with the training step backprop.
        grad_div_ar_fusion: to fuse the division by the data-parallel world size into the
            allreduce via a premultiplied sum.
        grad_allreduce_chunk_size_mb: chunk size (in MB) used to allreduce gradients in
            chunks as soon as they are fully computed; 0 disables chunking.
    """

    def __init__(
        self,
        optimizer,
        fp32_grad_accum=False,
        contiguous_grad_bucket=False,
        async_grad_allreduce=False,
        grad_div_ar_fusion=True,
        grad_allreduce_chunk_size_mb=0,
    ):
        if not HAVE_APEX:
            raise ImportError(
                "Apex was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
            )

        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'
        if contiguous_grad_bucket:
            assert fp32_grad_accum, 'contiguous gradient buffer assumes using fp32 grad.'
        if async_grad_allreduce:
            assert fp32_grad_accum, (
                'async allreduce applies to master gradients only, '
                'which are supposed to be accumulated after the grad op.'
            )
            assert contiguous_grad_bucket, (
                'currently async_grad_allreduce is supported only with contiguous_grad_bucket.'
            )

        self._fp32_grad_accum = fp32_grad_accum
        self._contiguous_grad_bucket = contiguous_grad_bucket

        # Asynchronous allreduce is only useful with more than one data-parallel rank.
        self._async_grad_allreduce = async_grad_allreduce and get_data_parallel_world_size() > 1
        self._grad_divisor = 1 / get_data_parallel_world_size()

        if self._async_grad_allreduce:
            # Gradients are synchronized during backprop unless no_sync() is active.
            self._require_backward_grad_sync = True
            self._grad_div_ar_fusion = grad_div_ar_fusion
            self._grad_allreduce_chunk_size_mb = grad_allreduce_chunk_size_mb
        else:
            self._require_backward_grad_sync = False
            self._grad_div_ar_fusion = False
            self._grad_allreduce_chunk_size_mb = 0

        # Dummy tensor needed for the apex multi-tensor applier.
        self._dummy_overflow_buf = None

        # Persistent, contiguous buffers for the main gradients of each parameter group.
        self._main_grad_buffers = None
        if self._contiguous_grad_bucket:
            self._main_grad_buffers = {}
            # Count the gradient elements in each parameter group.
            num_elements = {}
            for i, param_group in enumerate(self.optimizer.param_groups):
                for param in param_group['params']:
                    if param.requires_grad:
                        num_elements[i] = num_elements.get(i, 0) + param.data.nelement()

                # Allocate the grad buffer for this group.
                if any(param.requires_grad for param in param_group['params']):
                    self._main_grad_buffers[i] = GradBucket(num_elements[i], self._grad_allreduce_chunk_size_mb)

        # Three groups of parameters:
        #   float16_groups: original float16 parameters
        #   fp32_from_float16_groups: fp32 copies of the float16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.float16_groups = []
        self.fp32_from_float16_groups = []
        self.fp32_from_fp32_groups = []

        # Keep references to the gradient accumulation hooks so they are not garbage collected.
        if self._fp32_grad_accum:
            self.grad_accs = []

        # For all the groups in the original optimizer:
        for i, param_group in enumerate(self.optimizer.param_groups):
            float16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_float16_params_this_group = []

            # For all the parameters in this group:
            for j, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # float16 params:
                    if param.type() in ['torch.cuda.HalfTensor', 'torch.cuda.BFloat16Tensor']:
                        float16_params_this_group.append(param)

                        # Create an fp32 copy to act as the main parameter.
                        main_param = param.detach().clone().float()

                        # Copy tensor model-parallel attributes.
                        copy_tensor_model_parallel_attributes(main_param, param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared

                        # Point the main gradient at a view into the contiguous grad buffer.
                        if self._contiguous_grad_bucket:
                            num_elements[i] -= param.data.nelement()
                            main_param.grad, grad_chunk_info = self._main_grad_buffers[i].get(
                                param.data.shape, num_elements[i]
                            )
                            param.main_grad = main_param.grad

                        # Replace the optimizer's parameter with the fp32 main parameter.
                        param_group['params'][j] = main_param
                        fp32_from_float16_params_this_group.append(main_param)

                        # Reset any existing optimizer state to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] = self.optimizer.state.pop(param)

                    # fp32 params:
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][j] = param

                    else:
                        raise TypeError(
                            'Wrapped parameters must be one of '
                            'torch.cuda.FloatTensor, '
                            'torch.cuda.HalfTensor, or '
                            'torch.cuda.BFloat16Tensor. '
                            'Received {}'.format(param.type())
                        )

                # Register a gradient accumulation hook for fp32 grad accumulation.
                if self._fp32_grad_accum and param.requires_grad:
                    # Expand so we get access to grad_fn.
                    param_tmp = param.expand_as(param)
                    # Get the gradient accumulator function.
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    grad_acc.register_hook(self._make_param_hook(param, main_param, i, grad_chunk_info))
                    self.grad_accs.append(grad_acc)

            self.float16_groups.append(float16_params_this_group)
            self.fp32_from_float16_groups.append(fp32_from_float16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors.
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def _make_param_hook(self, param, main_param, i, grad_chunk_info):
        """Create the grad accumulation and all-reduce hook for backprop."""

        def param_hook(*unused):
            # Accumulate the model gradient into the fp32 main gradient.
            if param.grad is not None:
                if main_param.grad is None:
                    main_param.grad = param.grad.float()
                else:
                    main_param.grad.add_(param.grad.data)
                # Deallocate the model gradient to save memory.
                param.grad = None

            # Asynchronous gradient allreduce across data-parallel ranks.
            if self._require_backward_grad_sync:
                if self._grad_allreduce_chunk_size_mb > 0:
                    # Chunked allreduce: reduce each chunk as soon as it is fully computed.
                    self._main_grad_buffers[i].update_chunk_info(grad_chunk_info)
                    while True:
                        allreduce_tensor = self._main_grad_buffers[i].get_allreduce_tensor()
                        if allreduce_tensor is None:
                            break
                        if self._grad_div_ar_fusion:
                            # Fuse the division by the data-parallel size into the allreduce.
                            torch.distributed.all_reduce(
                                allreduce_tensor,
                                group=get_data_parallel_group(),
                                async_op=True,
                                op=torch.distributed._make_nccl_premul_sum(self._grad_divisor),
                            )
                        else:
                            allreduce_tensor.div_(get_data_parallel_world_size())
                            torch.distributed.all_reduce(
                                allreduce_tensor, group=get_data_parallel_group(), async_op=True,
                            )
                else:
                    if self._grad_div_ar_fusion:
                        torch.distributed.all_reduce(
                            main_param.grad,
                            group=get_data_parallel_group(),
                            async_op=True,
                            op=torch.distributed._make_nccl_premul_sum(self._grad_divisor),
                        )
                    else:
                        main_param.grad.div_(get_data_parallel_world_size())
                        torch.distributed.all_reduce(
                            main_param.grad, group=get_data_parallel_group(), async_op=True,
                        )

        return param_hook
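
    # The hook returned above is registered in __init__ on each parameter's AccumulateGrad
    # node (obtained via `param.expand_as(param).grad_fn.next_functions[0][0]`), so it runs
    # as soon as that parameter's gradient is ready during backprop. A minimal sketch of the
    # same generic pattern, with a hypothetical leaf tensor `p`:
    #
    #   p = torch.randn(4, device='cuda', dtype=torch.float16, requires_grad=True)
    #   grad_acc = p.expand_as(p).grad_fn.next_functions[0][0]  # AccumulateGrad node
    #   grad_acc.register_hook(lambda *unused: print('grad for p is ready'))
    #   p.sum().backward()  # the hook fires here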

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model-related parameters, i.e.,
        float16_groups & fp32_from_fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for group in self.float16_groups:
            _zero_grad_group_helper(group, set_to_none)
        if self._contiguous_grad_bucket:
            for i in self._main_grad_buffers:
                self._main_grad_buffers[i].zero()
        else:
            for group in self.fp32_from_float16_groups:
                _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)

    def copy_model_grads_to_main_grads(self):
        # This only needs to be done for the float16 groups.
        for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if model_param.grad is not None:
                    main_param.grad = model_param.grad.float()

                # Safe to deallocate the model's grad after copying.
                model_param.grad = None

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        half_dtype = None
        for model_group, main_group in zip(self.float16_groups, self.fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if half_dtype is None:
                    half_dtype = model_param.data.dtype
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data, half_dtype

    def _set_overflow_buffer(self, half_dtype):
        if half_dtype == torch.float16:
            if self._dummy_overflow_buf is None:
                self._dummy_overflow_buf = torch.cuda.IntTensor([0])
            else:
                self._dummy_overflow_buf.fill_(0)

    def _copy_main_params_to_model_params(self):
        # Only the float16 params need to be updated from their fp32 main copies.
        model_data, main_data, half_dtype = self._get_model_and_main_params_data_float16()
        self._set_overflow_buffer(half_dtype)
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data, overflow_buf=self._dummy_overflow_buf)

    def _copy_model_params_to_main_params(self):
        # Only the float16 params need to be copied into their fp32 main copies.
        model_data, main_data, half_dtype = self._get_model_and_main_params_data_float16()
        self._set_overflow_buffer(half_dtype)
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data, overflow_buf=self._dummy_overflow_buf)

    def reload_model_params(self):
        self._copy_model_params_to_main_params()

    @torch.no_grad()
    def step(self, **kwargs):
        # With async grad allreduce enabled, backprop keeps moving forward without waiting for
        # the asynchronous allreduce work to finish. To guarantee correct gradient reduction,
        # the weight update cannot start until all outstanding allreduce work is done.
        if self._async_grad_allreduce:
            torch.cuda.synchronize()

        # Step the base optimizer on the fp32 main parameters.
        self.optimizer.step(closure=None, **kwargs)

        # Copy the updated main params back to the model (fp16/bf16) params.
        with torch.no_grad():
            self._copy_main_params_to_model_params()

        # Successful update.
        return True

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_float16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            logging.info('***WARNING*** loading optimizer from an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Copy data for the main params.
        fp32_from_float16_params_key = 'fp32_from_fp16_params'
        if fp32_from_float16_params_key not in state_dict:
            fp32_from_float16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(self.fp32_from_float16_groups, state_dict[fp32_from_float16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def allreduce_main_grads(self):
        """Synchronously allreduce the contiguous main-gradient buffers."""
        for i in self._main_grad_buffers:
            self._main_grad_buffers[i].allreduce_buffer()

    @contextmanager
    def no_sync(self):
        """A context manager to disable gradient synchronizations across
        data-parallel ranks."""
        old_require_backward_grad_sync = self._require_backward_grad_sync
        self._require_backward_grad_sync = False
        try:
            yield
        finally:
            self._require_backward_grad_sync = old_require_backward_grad_sync
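
    # A minimal gradient-accumulation sketch using no_sync() (illustrative only: `optim`,
    # `model`, `loss_fn`, and `micro_batches` are hypothetical, and async_grad_allreduce is
    # assumed to be enabled):
    #
    #   with optim.no_sync():
    #       # Accumulate grads locally; the hooks skip the data-parallel allreduce here.
    #       for micro_batch in micro_batches[:-1]:
    #           loss_fn(model(micro_batch)).backward()
    #
    #   # Synchronization is re-enabled, so this last backward triggers the allreduce.
    #   loss_fn(model(micro_batches[-1])).backward()
    #   optim.step()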

    @property
    def async_master_grads_allreudce(self):
        return self._async_grad_allreduce

    @property
    def fp32_grad_accumulation(self):
        return self._fp32_grad_accum

    def get_parameters(self):
        params = []
        for param_group in self.optimizer.param_groups:
            for param in param_group['params']:
                params.append(param)
        return params

    # Promote state so it can be retrieved or set via "optimizer_instance.state".
    def _get_state(self):
        if hasattr(self, 'optimizer'):
            return self.optimizer.state
        else:
            return []

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so they can be retrieved or set via "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate).
    def _get_param_groups(self):
        if hasattr(self, 'optimizer'):
            return self.optimizer.param_groups
        else:
            return []

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)

    # Promote defaults so they can be retrieved or set via "optimizer_instance.defaults".
    def _get_defaults(self):
        if hasattr(self, 'optimizer'):
            return self.optimizer.defaults
        else:
            return []

    def _set_defaults(self, value):
        self.optimizer.defaults = value

    defaults = property(_get_defaults, _set_defaults)