"""Modified from https://github.com/mlfoundations/open_flamingo"""
import time
from contextlib import suppress
import torch
from tqdm import tqdm


def get_cast_dtype(precision: str):
    """Return the torch dtype to cast model inputs to for the given precision, or None."""
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    """Return an autocast context manager for AMP precisions, or `suppress` as a no-op."""
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress
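

# Illustrative only (not part of the original module): how the two precision
# helpers above are typically combined. `args.precision`, `model`, `images`,
# `input_ids`, `attention_mask`, `labels`, and `device_id` are placeholders
# standing in for the objects used in train_one_epoch below.
#
#   autocast = get_autocast(args.precision)      # context manager, or no-op `suppress`
#   cast_dtype = get_cast_dtype(args.precision)  # dtype for tensor casts, or None
#   images = images.to(device_id, dtype=cast_dtype, non_blocking=True)
#   with autocast():
#       loss = model(vision_x=images, lang_x=input_ids, attention_mask=attention_mask, labels=labels)[0]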


def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """Run one epoch of joint training on paired LAION image-text batches and interleaved MMC4 batches."""
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches
    assert (
        num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
    ), "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)

    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### LAION FORWARD PASS ####
        # vision_x is expected as (B, T_img, F, C, H, W); add singleton T_img and frame dims
        images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)

        input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch_laion[1][1].to(device_id, dtype=cast_dtype, non_blocking=True)

        # mask padding tokens, the first token, and <image> tokens out of the LM loss
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        divided_loss_laion = loss_laion / args.gradient_accumulation_steps

        #### C4 FORWARD PASS ####
        # vision_x is expected as (B, T_img, F, C, H, W); add a singleton frame dim
        images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2)
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100

        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
                labels[i][label_idx] = -100
                label_idx += 1

            # get index of all endofchunk tokens in the sequence
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                # remove loss for tokens between <|endofchunk|> and the next <image>
                token_idx = endofchunk_idx + 1
                while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
                    labels[i][token_idx] = -100
                    token_idx += 1

        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids.to(device_id),
                attention_mask=attention_mask.to(device_id),
                labels=labels,
            )[0]

        # if loss is nan, skip this batch
        if torch.isnan(loss_mmc4):
            print("loss is nan, skipping this batch")
            print("input_ids: ", tokenizer.batch_decode(input_ids))
            print("labels: ", labels)
            print("images: ", images)
            optimizer.zero_grad()
            continue

        divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps

        #### BACKWARD PASS ####
        loss = divided_loss_laion * args.loss_multiplier_laion + divided_loss_mmc4 * args.loss_multiplier_mmc4
        loss.backward()

        #### MASK GRADIENTS FOR EMBEDDINGS ####
        # Note (anas): Do not apply weight decay to embeddings as it will break this function.
        def mask_embedding(m):
            # zero out embedding gradients for all rows except the <image> and <|endofchunk|> tokens
            if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
                zero_mask = torch.zeros_like(m.weight.grad)
                zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
                zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
                m.weight.grad = m.weight.grad * zero_mask

        model.apply(mask_embedding)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # update step timing on every rank (not only rank 0)
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                laion_samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
                )
                laion_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
                )
                c4_samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
                )
                c4_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
                )

                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "c4_samples_per_second": c4_samples_per_second,
                        "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "loss_laion": divided_loss_laion.item(),
                        "global_step": global_step,
                    },
                    commit=False,
                )
                wandb.log(
                    {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
            )
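

# Illustrative only (hypothetical driver, not part of the original module): the
# caller is expected to pass dataloaders that expose a matching `num_batches`
# attribute and to invoke train_one_epoch once per epoch, roughly:
#
#   for epoch in range(args.num_epochs):
#       train_one_epoch(args, model, epoch, laion_loader, mmc4_loader,
#                       tokenizer, optimizer, lr_scheduler, device_id, wandb)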


def get_checkpoint(model: torch.nn.Module):
    """Return a state dict containing only trainable (requires_grad) named parameters."""
    state_dict = model.state_dict()
    parameters = {k: v for k, v in model.named_parameters()}

    # remove duplicate parameters
    duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
    for k in duplicate_keys:
        del state_dict[k]

    # remove non-grad parameters
    for name, p in parameters.items():
        if not p.requires_grad:
            del state_dict[name]

    return state_dict
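

# Illustrative only: the filtered state dict returned by get_checkpoint is
# intended for torch.save; the file name below is a hypothetical example.
#
#   torch.save(get_checkpoint(model), f"checkpoint_epoch_{epoch}.pt")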


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
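

# Illustrative usage of AverageMeter (not part of the original module):
#
#   meter = AverageMeter()
#   meter.update(0.5)       # single observation
#   meter.update(1.0, n=4)  # four observations with mean value 1.0
#   meter.avg               # running average: (0.5 + 4 * 1.0) / 5 = 0.9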