|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""General utilities.""" |
|
|
|
import sys |
|
|
|
import torch |
|
from torch.nn.parallel import DistributedDataParallel as torchDDP |
|
|
|
from apex.multi_tensor_apply import multi_tensor_applier |
|
import amp_C |
|
|
|
from megatron import get_args |
|
from megatron import get_adlr_autoresume |
|
from megatron import mpu |
|
from megatron.model.module import param_is_not_shared |
|
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate |
|
|
|
|
|
def unwrap_model(model, module_instances=(torchDDP,)):
    """Strip wrapper modules from a model or a list of models.

    Args:
        model: a module, or a list of modules.
        module_instances: a class or tuple of classes (passed straight to
            ``isinstance``). While a module is an instance of one of these,
            it is replaced by its ``.module`` attribute. Defaults to a
            one-element tuple holding ``DistributedDataParallel``; note that
            ``(torchDDP)`` without the trailing comma is *not* a tuple.

    Returns:
        The unwrapped module if ``model`` was not a list, otherwise a new
        list of unwrapped modules in the same order.
    """
    return_list = isinstance(model, list)
    models = model if return_list else [model]

    unwrapped_model = []
    for model_module in models:
        # Peel off nested wrappers one layer at a time.
        while isinstance(model_module, module_instances):
            model_module = model_module.module
        unwrapped_model.append(model_module)

    return unwrapped_model if return_list else unwrapped_model[0]
|
|
|
|
|
def calc_params_l2_norm(model):
    """Calculate the L2 norm of the model parameters.

    Only parameters for which both ``param_is_not_shared`` and
    ``param_is_not_tensor_parallel_duplicate`` return True are included.
    The squared local norm is summed over the model-parallel group, so the
    filter is what keeps a parameter from being counted more than once.

    Args:
        model: a module or a list of modules.

    Returns:
        float: square root of the model-parallel sum of squared norms.
    """
    args = get_args()
    if not isinstance(model, list):
        model = [model]

    # Remove duplicate params.
    params_data = []
    for model_ in model:
        for param in model_.parameters():
            is_not_shared = param_is_not_shared(param)
            is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
            if is_not_shared and is_not_tp_duplicate:
                # bf16 data is upcast to fp32 before the fused kernel
                # (presumably the kernel lacks bf16 support — confirm).
                if args.bf16:
                    params_data.append(param.data.float())
                else:
                    params_data.append(param.data)

    if params_data:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        norm, _ = multi_tensor_applier(
            amp_C.multi_tensor_l2norm,
            dummy_overflow_buf,
            [params_data],
            False  # no per-parameter norms
        )
        norm_2 = norm * norm
    else:
        # Nothing to measure on this rank: do not call the fused kernel with
        # an empty tensor list. Contribute zero, but still take part in the
        # all_reduce below so the other ranks are not left waiting in the
        # collective.
        norm_2 = torch.zeros(1, dtype=torch.float, device='cuda')

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(norm_2,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=mpu.get_model_parallel_group())
    return norm_2.item() ** 0.5
|
|
|
|
|
def average_losses_across_data_parallel_group(losses):
    """Average every loss in ``losses`` over the data-parallel group.

    Each loss is cloned, detached and reshaped to a single element (so each
    must hold exactly one value). The results are concatenated into one 1-D
    tensor, summed across the data-parallel group with ``all_reduce`` and
    divided by the size of that group.

    Returns:
        A 1-D tensor holding one averaged value per input loss.
    """
    flat_losses = [loss.clone().detach().view(1) for loss in losses]
    reduced = torch.cat(flat_losses)

    dp_group = mpu.get_data_parallel_group()
    torch.distributed.all_reduce(reduced, group=dp_group)
    dp_world_size = torch.distributed.get_world_size(group=dp_group)
    return reduced / dp_world_size
|
|
|
|
|
def report_memory(name):
    """Print a one-line CUDA memory summary, in MB, prefixed by ``name``.

    Allocated, peak-allocated, reserved and peak-reserved memory are queried
    from ``torch.cuda``. Only processes whose data-parallel rank is 0 print,
    and each line is tagged with the global rank.
    """
    mega_bytes = 1024.0 * 1024.0
    fields = (
        (' | allocated: {}', torch.cuda.memory_allocated),
        (' | max allocated: {}', torch.cuda.max_memory_allocated),
        (' | reserved: {}', torch.cuda.memory_reserved),
        (' | max reserved: {}', torch.cuda.max_memory_reserved),
    )
    pieces = [name + ' memory (MB)']
    pieces.extend(fmt.format(query() / mega_bytes) for fmt, query in fields)
    report = ''.join(pieces)

    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), report),
              flush=True)
|
|
|
|
|
def print_params_min_max_norm(optimizer, iteration):
    """Print min, max, and norm of all parameters as CSV-style lines.

    Args:
        optimizer: a wrapper whose ``.optimizer`` attribute exposes
            ``param_groups``. Every parameter in every group is reported.
        iteration: int written in the first column of each line.

    Each parameter gets a 1-based running index across all groups. The
    ``tensor-model-parallel`` column is the parameter's
    ``tensor_model_parallel`` attribute as an int; it is 0 when the
    attribute is absent instead of raising.
    """
    rank = torch.distributed.get_rank()
    # Build the report as a list and join once; repeated `+=` on a growing
    # string is quadratic in the number of parameters.
    lines = ['iteration, rank, index, tensor-model-parallel, min, max, norm\n']
    index = 0
    for param_group in optimizer.optimizer.param_groups:
        for param in param_group['params']:
            index += 1
            data = param.data
            min_ = data.min().item()
            max_ = data.max().item()
            norm = torch.linalg.norm(data).item()
            is_tp = int(getattr(param, 'tensor_model_parallel', False))
            lines.append('{:7d}, {:4d}, {:4d}, {:2d}, '.format(
                iteration, rank, index, is_tp))
            lines.append('{:.6E}, {:.6E}, {:.6E}\n'.format(min_, max_, norm))
    print(''.join(lines), flush=True)
|
|
|
|
|
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Exit the process if the ADLR autoresume service requested termination.

    All ranks first meet at a barrier. If termination was requested:
    - a checkpoint is saved when ``args.save`` is set;
    - rank 0 asks the autoresume service to resume the job;
    - the process exits with status 0.
    Otherwise the function returns normally.
    """
    # Local import, presumably to avoid a circular import with
    # megatron.checkpointing — confirm.
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()

    torch.distributed.barrier()
    if not autoresume.termination_requested():
        return

    if args.save:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    print_rank_0(">>> autoresume termination request found!")
    if torch.distributed.get_rank() == 0:
        autoresume.request_resume()
    # NOTE: the message says "Returning", but the process exits here.
    print_rank_0(">>> training terminated. Returning")
    sys.exit(0)
|
|
|
|
|
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build masks and position ids for a left-to-right model.

    Args:
        data: 2-D tensor of token ids, shape (micro_batch_size, seq_length).
        eod_token: id marking end-of-document.
        reset_position_ids: restart position ids at 0 after every EOD token.
        reset_attention_mask: stop tokens after an EOD from attending to the
            EOD token or anything before it in the same sample.
        eod_mask_loss: zero the loss mask at EOD positions.

    Returns:
        attention_mask: bool tensor of shape (B', 1, seq_length, seq_length).
            B' is micro_batch_size when reset_attention_mask is set, else 1.
            True marks positions that must NOT be attended to: future tokens
            and, optionally, tokens across document boundaries.
        loss_mask: float tensor shaped like ``data``.
        position_ids: long tensor shaped like ``data``. When
            reset_position_ids is False this is an expanded view of a single
            ``arange``, so it must not be written to in place.
    """
    batch_size, seq_len = data.size()
    device = data.device

    # One causal mask per sample only when samples get edited individually
    # below; otherwise a single mask that broadcasts over the batch.
    mask_rows = batch_size if reset_attention_mask else 1
    lower_tri = torch.tril(
        torch.ones((mask_rows, seq_len, seq_len), device=device))
    attention_mask = lower_tri.view(mask_rows, 1, seq_len, seq_len)

    loss_mask = torch.ones(data.size(), dtype=torch.float, device=device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    positions = torch.arange(seq_len, dtype=torch.long, device=device)
    position_ids = positions.unsqueeze(0).expand_as(data)
    if reset_position_ids:
        # Materialize a writable copy; the expanded view shares storage.
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        for row in range(batch_size):
            # Column indices of the EOD tokens in this sample. The row still
            # holds the arange values here, since it is edited only below.
            eod_positions = position_ids[row, data[row] == eod_token]
            if reset_position_ids:
                # Keep the EOD positions independent of position_ids, which
                # is modified in place below.
                eod_positions = eod_positions.clone()

            doc_start = 0
            for eod in eod_positions:
                if reset_attention_mask:
                    attention_mask[row, 0, (eod + 1):, :(eod + 1)] = 0
                if reset_position_ids:
                    position_ids[row, (eod + 1):] -= (eod + 1 - doc_start)
                    doc_start = eod + 1

    # Flip to "True = masked out".
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
|
|
|
|
|
def print_rank_0(message):
    """Print ``message`` on global rank 0 only.

    If torch.distributed is not initialized, every process prints.
    """
    dist = torch.distributed
    if not dist.is_initialized() or dist.get_rank() == 0:
        print(message, flush=True)
|
|
|
def is_last_rank():
    """Return True if this process has the highest global rank.

    Requires torch.distributed to be initialized.
    """
    my_rank = torch.distributed.get_rank()
    last_rank = torch.distributed.get_world_size() - 1
    return my_rank == last_rank
|
|
|
def print_rank_last(message):
    """Print ``message`` on the last global rank only.

    If torch.distributed is not initialized, every process prints.
    """
    if torch.distributed.is_initialized() and not is_last_rank():
        return
    print(message, flush=True)
|
|