"""Modified from https://github.com/mlfoundations/open_flamingo""" | |
import time | |
from contextlib import suppress | |
import torch | |
from tqdm import tqdm | |


def get_cast_dtype(precision: str):
    """Return the dtype input tensors should be cast to, or None when no explicit cast is needed."""
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress
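

# Illustrative use of the two helpers above (comment-only sketch, not executed here):
# with precision="amp_bf16" the forward pass runs under a bfloat16 autocast context,
# while get_cast_dtype returns None, so inputs keep their original dtype:
#
#   autocast = get_autocast("amp_bf16")
#   cast_dtype = get_cast_dtype("amp_bf16")   # None for AMP precisions
#   with autocast():
#       loss = model(...)[0]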


def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
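    """Run one epoch of interleaved LAION / MMC4 training.

    Each step draws one batch from each loader, computes the two language-modeling
    losses under autocast, combines them with their respective loss multipliers,
    and steps the optimizer every `gradient_accumulation_steps` batches.
    """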
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches

    assert (
        num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
    ), "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    data_time_m = AverageMeter()  # avg time to load one batch of both C4 AND laion (= 1 batch regardless of gradient accum)
    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### LAION FORWARD PASS ####
        images = batch_laion[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(1).unsqueeze(1)
        input_ids = batch_laion[1][0].to(device_id, dtype=cast_dtype, non_blocking=True)
        attention_mask = batch_laion[1][1].to(device_id, dtype=cast_dtype, non_blocking=True)

        # mask padding, the leading BOS token, and <image> tokens out of the LM loss
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        divided_loss_laion = loss_laion / args.gradient_accumulation_steps

        #### C4 FORWARD PASS ####
        images = batch_mmc4[0].to(device_id, dtype=cast_dtype, non_blocking=True).unsqueeze(2)
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
        # Supervise only the text that follows an <image> token, up to and including the
        # next <|endofchunk|>; everything else is masked out of the loss below.
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        for i in range(labels.shape[0]):
            # remove loss for any token before the first <image> token
            label_idx = 0
            while label_idx < labels.shape[1] and labels[i][label_idx] != media_token_id:
                labels[i][label_idx] = -100
                label_idx += 1

            # get index of all endofchunk tokens in the sequence
            endofchunk_idxs = torch.where(labels[i] == endofchunk_token_id)[0]
            for endofchunk_idx in endofchunk_idxs:
                token_idx = endofchunk_idx + 1
                while token_idx < labels.shape[1] and labels[i][token_idx] != media_token_id:
                    labels[i][token_idx] = -100
                    token_idx += 1
        labels[labels == media_token_id] = -100
        labels = labels.to(device_id)

        with autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids.to(device_id),
                attention_mask=attention_mask.to(device_id),
                labels=labels,
            )[0]

        # if loss is nan, skip this batch
        if torch.isnan(loss_mmc4):
            print("loss is nan, skipping this batch")
            print("input_ids: ", tokenizer.batch_decode(input_ids))
            print("labels: ", labels)
            print("images: ", images)
            optimizer.zero_grad()
            continue
        divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps

        #### BACKWARD PASS ####
        loss = divided_loss_laion * args.loss_multiplier_laion + divided_loss_mmc4 * args.loss_multiplier_mmc4
        loss.backward()

        #### MASK GRADIENTS FOR EMBEDDINGS ####
        # Note (anas): Do not apply weight decay to embeddings as it will break this function.
        # Only the rows for the newly added <image> and <|endofchunk|> tokens keep their
        # gradients; every other embedding row is zeroed out and therefore stays frozen.
        def mask_embedding(m):
            if isinstance(m, torch.nn.Embedding) and m.weight.requires_grad:
                zero_mask = torch.zeros_like(m.weight.grad)
                zero_mask[media_token_id] = torch.ones_like(zero_mask[media_token_id])
                zero_mask[endofchunk_token_id] = torch.ones_like(zero_mask[endofchunk_token_id])
                m.weight.grad = m.weight.grad * zero_mask

        model.apply(mask_embedding)

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (num_steps == num_batches_per_epoch - 1):
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                laion_samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size_laion * args.world_size / step_time_m.val
                )
                laion_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_laion / step_time_m.val
                )
                c4_samples_per_second = (
                    args.gradient_accumulation_steps * args.batch_size_mmc4 * args.world_size / step_time_m.val
                )
                c4_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps * args.batch_size_mmc4 / step_time_m.val
                )

                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "c4_samples_per_second": c4_samples_per_second,
                        "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "loss_laion": divided_loss_laion.item(),
                        "global_step": global_step,
                    },
                    commit=False,
                )
                wandb.log(
                    {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
            )
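

# A minimal sketch of how this function might be driven from a training script.
# The epoch-aware loader objects, DDP-wrapped model, and wandb setup are assumptions
# and not part of this module; only the names defined in this file are real:
#
#   for epoch in range(args.num_epochs):
#       laion_dataset.set_epoch(epoch)   # hypothetical epoch-aware dataset wrappers
#       mmc4_dataset.set_epoch(epoch)
#       train_one_epoch(
#           args,
#           model=ddp_model,
#           epoch=epoch,
#           laion_loader=laion_dataset.dataloader,
#           mmc4_loader=mmc4_dataset.dataloader,
#           tokenizer=tokenizer,
#           optimizer=optimizer,
#           lr_scheduler=lr_scheduler,
#           device_id=device_id,
#           wandb=wandb,
#       )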


def get_checkpoint(model: torch.nn.Module):
    """Return a state dict containing only trainable, non-duplicated parameters."""
    state_dict = model.state_dict()
    parameters = {k: v for k, v in model.named_parameters()}

    # drop state_dict entries that have no corresponding named parameter
    # (e.g. buffers or tied/duplicated weights)
    duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
    for k in duplicate_keys:
        del state_dict[k]

    # drop parameters that are frozen (requires_grad=False)
    for name, p in parameters.items():
        if not p.requires_grad:
            del state_dict[name]

    return state_dict
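

# Illustrative save call (the path and payload keys are arbitrary, not fixed by this module):
#
#   checkpoint = {
#       "epoch": epoch,
#       "model_state_dict": get_checkpoint(model),
#       "optimizer_state_dict": optimizer.state_dict(),
#   }
#   torch.save(checkpoint, f"checkpoint_{epoch}.pt")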


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count
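

# Quick illustration of AverageMeter semantics (values chosen arbitrarily):
#
#   m = AverageMeter()
#   m.update(2.0)        # val=2.0, sum=2.0,  count=1, avg=2.0
#   m.update(4.0, n=3)   # val=4.0, sum=14.0, count=4, avg=3.5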