File size: 8,166 Bytes
23bd7af |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 |
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""General utilities."""
import sys
import torch
from torch.nn.parallel import DistributedDataParallel as torchDDP
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C
from megatron import get_args
from megatron import get_adlr_autoresume
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate
def unwrap_model(model, module_instances=(torchDDP)):
    """Strip wrapper modules off a model or a list of models.

    Every object that is an instance of ``module_instances`` is replaced by
    its ``.module`` attribute, repeatedly, until the result is no longer such
    an instance.

    Args:
        model: a single model or a ``list`` of models.
        module_instances: class (or tuple of classes) treated as wrappers;
            it is handed straight to ``isinstance``. The default is the bare
            ``torchDDP`` class.

    Returns:
        A list of unwrapped models when ``model`` was a list, otherwise the
        single unwrapped model.
    """
    was_list = isinstance(model, list)
    wrapped = model if was_list else [model]

    bare = []
    for candidate in wrapped:
        # Peel nested wrappers one layer at a time.
        while isinstance(candidate, module_instances):
            candidate = candidate.module
        bare.append(candidate)

    return bare if was_list else bare[0]
def calc_params_l2_norm(model):
    """Return the L2 norm of the model parameters as a Python float.

    Args:
        model: a model or a ``list`` of models; each must expose
            ``parameters()``.

    Parameters for which ``param_is_not_shared`` or
    ``param_is_not_tensor_parallel_duplicate`` is false are skipped. The
    squared local norm is summed over the model-parallel group before the
    square root is taken.
    """
    args = get_args()
    model_chunks = model if isinstance(model, list) else [model]

    # Collect the tensors that count towards the norm, skipping duplicates.
    tensors = []
    for chunk in model_chunks:
        for p in chunk.parameters():
            not_shared = param_is_not_shared(p)
            not_tp_duplicate = param_is_not_tensor_parallel_duplicate(p)
            if not (not_shared and not_tp_duplicate):
                continue
            # With bf16 enabled the data is converted to float first.
            tensors.append(p.data.float() if args.bf16 else p.data)

    # Local L2 norm over all collected tensors in one fused call.
    overflow_flag = torch.cuda.IntTensor([0])
    local_norm, _ = multi_tensor_applier(
        amp_C.multi_tensor_l2norm,
        overflow_flag,
        [tensors],
        False  # no per-parameter norm
    )
    squared = local_norm * local_norm

    # Sum the squared norms across all model-parallel GPUs.
    torch.distributed.all_reduce(
        squared,
        op=torch.distributed.ReduceOp.SUM,
        group=mpu.get_model_parallel_group())
    return squared.item() ** 0.5
def average_losses_across_data_parallel_group(losses):
    """Average a collection of losses over the data-parallel group.

    Args:
        losses: iterable of tensors, each of which can be viewed as a
            single element.

    Returns:
        A 1-D tensor with one entry per input loss: the all-reduced value
        (default reduction op) divided by the data-parallel world size.
    """
    # Detached copies, so the reduction never touches the originals.
    flattened = [loss.clone().detach().view(1) for loss in losses]
    reduced = torch.cat(flattened)
    torch.distributed.all_reduce(reduced,
                                 group=mpu.get_data_parallel_group())
    group_size = torch.distributed.get_world_size(
        group=mpu.get_data_parallel_group())
    return reduced / group_size
def report_memory(name):
    """Print a one-line CUDA memory summary, in MB, tagged with ``name``.

    The line lists the allocated, max allocated, reserved and max reserved
    figures reported by ``torch.cuda``. It is printed only by processes
    whose data-parallel rank is 0.
    """
    bytes_per_mb = 1024.0 * 1024.0
    readings = (
        (' | allocated: {}', torch.cuda.memory_allocated),
        (' | max allocated: {}', torch.cuda.max_memory_allocated),
        (' | reserved: {}', torch.cuda.memory_reserved),
        (' | max reserved: {}', torch.cuda.max_memory_reserved),
    )

    report = name + ' memory (MB)'
    for template, query in readings:
        report += template.format(query() / bytes_per_mb)

    if mpu.get_data_parallel_rank() == 0:
        print("[Rank {}] {}".format(torch.distributed.get_rank(), report),
              flush=True)
def print_params_min_max_norm(optimizer, iteration):
    """Print a CSV-like table of min, max and norm for every parameter.

    Args:
        optimizer: wrapper whose ``.optimizer`` attribute exposes
            ``param_groups``.
        iteration: integer iteration number written in each row.

    Rows are numbered from 1 across all parameter groups. Each row also
    carries the global rank and the parameter's ``tensor_model_parallel``
    flag as an int.
    """
    rank = torch.distributed.get_rank()
    inner_optimizer = optimizer.optimizer
    pieces = ['iteration, rank, index, tensor-model-parallel, min, max, norm\n']

    every_param = (p
                   for group in inner_optimizer.param_groups
                   for p in group['params'])
    for position, p in enumerate(every_param, start=1):
        values = p.data
        lowest = values.min()
        highest = values.max()
        l2 = torch.linalg.norm(values)
        pieces.append('{:7d}, {:4d}, {:4d}, {:2d}, '.format(
            iteration, rank, position, int(p.tensor_model_parallel)))
        pieces.append('{:.6E}, {:.6E}, {:.6E}\n'.format(lowest, highest, l2))

    # Emit the whole table with a single print call.
    print(''.join(pieces), flush=True)
def check_adlr_autoresume_termination(iteration, model,
                                      optimizer, opt_param_scheduler):
    """Exit the process if the autoresume handler requests termination.

    When termination is requested: a checkpoint is saved if ``args.save`` is
    set, global rank 0 calls ``request_resume()``, and the process exits
    with status 0. Otherwise the function returns without doing anything
    else.
    """
    from megatron.checkpointing import save_checkpoint

    args = get_args()
    autoresume = get_adlr_autoresume()

    # Barrier first, so ranks are consistent before the signal is polled.
    torch.distributed.barrier()
    if not autoresume.termination_requested():
        return

    if args.save:
        save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
    print_rank_0(">>> autoresume termination request found!")
    if torch.distributed.get_rank() == 0:
        autoresume.request_resume()
    print_rank_0(">>> training terminated. Returning")
    sys.exit(0)
def get_ltor_masks_and_position_ids(data,
                                    eod_token,
                                    reset_position_ids,
                                    reset_attention_mask,
                                    eod_mask_loss):
    """Build attention mask, loss mask and position ids for a left-to-right model.

    Args:
        data: 2-D tensor of token ids; its two sizes are read as
            (micro batch size, sequence length).
        eod_token: token id that marks an end-of-document position.
        reset_position_ids: if true, position ids restart at 0 right after
            every EOD token.
        reset_attention_mask: if true, positions after an EOD token are
            blocked from attending to positions up to and including it.
        eod_mask_loss: if true, the loss mask is 0.0 at EOD positions.

    Returns:
        Tuple ``(attention_mask, loss_mask, position_ids)``:
        ``attention_mask`` is boolean with shape ``(B, 1, S, S)``, where
        ``B`` is the batch size if ``reset_attention_mask`` else 1, and
        ``True`` marks entries that are masked out; ``loss_mask`` is a
        float tensor shaped like ``data``; ``position_ids`` is a long
        tensor shaped like ``data``.
    """
    batch, seq_len = data.size()
    device = data.device

    # Lower-triangular (causal) mask; one copy per sample only if it will
    # be edited per sample below.
    mask_batch = batch if reset_attention_mask else 1
    causal = torch.tril(
        torch.ones((mask_batch, seq_len, seq_len), device=device))
    attention_mask = causal.view(mask_batch, 1, seq_len, seq_len)

    # Loss mask: all ones, optionally zeroed at EOD tokens.
    loss_mask = torch.ones(data.size(), dtype=torch.float, device=device)
    if eod_mask_loss:
        loss_mask[data == eod_token] = 0.0

    # Position ids 0..S-1 for every row.
    position_ids = torch.arange(seq_len, dtype=torch.long, device=device)
    position_ids = position_ids.unsqueeze(0).expand_as(data)
    # The expanded tensor must be cloned before rows are edited in place.
    if reset_position_ids:
        position_ids = position_ids.clone()

    if reset_position_ids or reset_attention_mask:
        for row in range(batch):
            # Positions of the EOD tokens in this row.
            eod_positions = position_ids[row, data[row] == eod_token]
            # Detach from position_ids when those are about to be modified.
            if reset_position_ids:
                eod_positions = eod_positions.clone()

            doc_start = 0
            for eod_at in eod_positions:
                next_start = eod_at + 1
                if reset_attention_mask:
                    # Cut attention from later tokens to this document.
                    attention_mask[row, 0, next_start:, :next_start] = 0
                if reset_position_ids:
                    # Shift the remaining positions so they restart at 0.
                    position_ids[row, next_start:] -= (next_start - doc_start)
                    doc_start = next_start

    # Binary mask: True where attention is NOT allowed.
    attention_mask = (attention_mask < 0.5)

    return attention_mask, loss_mask, position_ids
def print_rank_0(message):
    """Print ``message`` on rank 0 only; print always if not distributed.

    When ``torch.distributed`` is not initialized the message is printed
    unconditionally.
    """
    dist = torch.distributed
    # Stay silent only when we know we are a non-zero rank.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    print(message, flush=True)
def is_last_rank():
    """Return True when this process's global rank is ``world_size - 1``."""
    my_rank = torch.distributed.get_rank()
    highest_rank = torch.distributed.get_world_size() - 1
    return my_rank == highest_rank
def print_rank_last(message):
    """Print ``message`` on the last rank only; print always if not distributed.

    When ``torch.distributed`` is not initialized the message is printed
    unconditionally.
    """
    # Stay silent only when we know we are not the last rank.
    if torch.distributed.is_initialized() and not is_last_rank():
        return
    print(message, flush=True)
|