"""Modified from https://github.com/mlfoundations/open_flamingo"""

import time
from contextlib import suppress

import torch
from tqdm import tqdm


def get_cast_dtype(precision: str):
    """Return the dtype inputs should be cast to for ``precision``.

    Only the pure half-precision modes request a cast:
    ``"bf16"`` -> ``torch.bfloat16`` and ``"fp16"`` -> ``torch.float16``.
    Any other value returns ``None``, which ``Tensor.to(dtype=None)`` treats
    as "keep the current dtype".
    """
    cast_dtype = None
    if precision == "bf16":
        cast_dtype = torch.bfloat16
    elif precision == "fp16":
        cast_dtype = torch.float16
    return cast_dtype


def get_autocast(precision):
    """Return a zero-argument callable that builds the autocast context for ``precision``.

    - ``"amp"``: CUDA autocast with its default dtype.
    - ``"amp_bfloat16"`` / ``"amp_bf16"``: CUDA autocast in bfloat16.
    - anything else: ``contextlib.suppress``; calling it with no arguments
      yields a context manager that does nothing.
    """
    if precision == "amp":
        return torch.cuda.amp.autocast
    elif precision == "amp_bfloat16" or precision == "amp_bf16":
        # amp_bfloat16 is more stable than amp float16 for clip training
        return lambda: torch.cuda.amp.autocast(dtype=torch.bfloat16)
    else:
        return suppress


def _mask_mmc4_labels(labels, media_token_id, endofchunk_token_id):
    """Set, in place, MMC4 label positions that should not be trained on to -100.

    For every row of the 2-D ``labels`` tensor this masks:
      (a) every token before the row's first media token, and
      (b) the tokens following each end-of-chunk token, up to but not
          including the next media token.
    """
    seq_len = labels.shape[1]
    # Index with labels[i] (a single view) so the writes below land in `labels`.
    for i in range(labels.shape[0]):
        row = labels[i]

        # (a) remove loss for any token before the first <image> token
        idx = 0
        while idx < seq_len and row[idx] != media_token_id:
            row[idx] = -100
            idx += 1

        # (b) positions are collected after (a), so end-of-chunk tokens
        # already masked by (a) are not matched here.
        for eoc_idx in torch.where(row == endofchunk_token_id)[0].tolist():
            idx = eoc_idx + 1
            while idx < seq_len and row[idx] != media_token_id:
                row[idx] = -100
                idx += 1


def train_one_epoch(
    args,
    model,
    epoch,
    laion_loader,
    mmc4_loader,
    tokenizer,
    optimizer,
    lr_scheduler,
    device_id,
    wandb,
):
    """Run one epoch of joint LAION + MMC4 training.

    Each iteration takes one LAION batch and one MMC4 batch from the zipped
    loaders, runs a forward pass on each and back-propagates the weighted sum
    of the two (accumulation-scaled) losses. The optimizer and LR scheduler
    step every ``args.gradient_accumulation_steps`` iterations and on the last
    iteration of the epoch.

    Args:
        args: run configuration. Reads num_epochs, precision,
            gradient_accumulation_steps, loss_multiplier_laion,
            loss_multiplier_mmc4, rank, world_size, report_to_wandb,
            batch_size_laion, batch_size_mmc4 and logging_steps.
        model: called as ``model(vision_x=..., lang_x=..., attention_mask=...,
            labels=...)``; element ``[0]`` of the result is used as the loss.
        epoch: zero-based epoch index (offsets the progress bar and global step).
        laion_loader, mmc4_loader: loaders exposing ``num_batches``; the two
            counts must be equal.
        tokenizer: must tokenize "<image>" and "<|endofchunk|>" and expose
            ``pad_token_id``.
        optimizer, lr_scheduler: stepped once per accumulation window.
        device_id: device that images, token ids, masks and labels are moved to.
        wandb: object with a ``log`` method; only used on rank 0 when
            ``args.report_to_wandb`` is set.

    Raises:
        AssertionError: if the two loaders report different ``num_batches``.
    """
    num_batches_per_epoch_laion = laion_loader.num_batches
    num_batches_per_epoch_mmc4 = mmc4_loader.num_batches

    assert (
        num_batches_per_epoch_laion == num_batches_per_epoch_mmc4
    ), "Number of batches in laion and mmc4 datasets must be the same"
    num_batches_per_epoch = num_batches_per_epoch_mmc4
    total_training_steps = num_batches_per_epoch * args.num_epochs

    autocast = get_autocast(args.precision)
    cast_dtype = get_cast_dtype(args.precision)

    media_token_id = tokenizer("<image>", add_special_tokens=False)["input_ids"][-1]
    endofchunk_token_id = tokenizer("<|endofchunk|>", add_special_tokens=False)["input_ids"][-1]

    def mask_embedding(m):
        """Keep embedding gradients only for the media and end-of-chunk rows."""
        # Note (anas): Do not apply weight decay to embeddings as it will break this function.
        if (
            isinstance(m, torch.nn.Embedding)
            and m.weight.requires_grad
            and m.weight.grad is not None
        ):
            keep = torch.zeros_like(m.weight.grad)
            keep[media_token_id] = 1
            keep[endofchunk_token_id] = 1
            m.weight.grad = m.weight.grad * keep

    model.train()

    # setup logging
    step_time_m = AverageMeter()  # time for one optimizer step (> 1 batch if using gradient accum)
    # avg time to load one batch of both C4 AND laion. NOTE(review): measured
    # from `end`, which is only reset at optimizer steps, so with gradient
    # accumulation later micro-batches also count earlier compute.
    data_time_m = AverageMeter()
    end = time.time()

    # loop through dataloader
    for num_steps, (batch_laion, batch_mmc4) in tqdm(
        enumerate(zip(laion_loader, mmc4_loader)),
        disable=args.rank != 0,
        total=total_training_steps,
        initial=(epoch * num_batches_per_epoch),
    ):
        data_time_m.update(time.time() - end)

        global_step = num_steps + epoch * num_batches_per_epoch

        #### LAION FORWARD PASS ####
        # Two singleton dims are inserted after the batch dim; presumably the
        # media/frame axes the model expects (TODO confirm against the model).
        images = (
            batch_laion[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(1)
            .unsqueeze(1)
        )
        # Token ids and mask keep their own (presumably integer) dtype.
        # Casting them to bf16/fp16 breaks embedding lookup, and bf16 cannot
        # represent most vocabulary ids exactly.
        input_ids = batch_laion[1][0].to(device_id, non_blocking=True)
        attention_mask = batch_laion[1][1].to(device_id, non_blocking=True)

        # -100 marks ignored targets (the PyTorch/HF cross-entropy default
        # ignore_index; assumes the model's loss uses it).
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        labels[labels == media_token_id] = -100

        with autocast():
            loss_laion = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]
        divided_loss_laion = loss_laion / args.gradient_accumulation_steps

        #### C4 FORWARD PASS ####
        images = (
            batch_mmc4[0]
            .to(device_id, dtype=cast_dtype, non_blocking=True)
            .unsqueeze(2)
        )
        input_ids = torch.stack([x[0] for x in batch_mmc4[1]]).squeeze(1)
        attention_mask = torch.stack([x[1] for x in batch_mmc4[1]]).squeeze(1)

        # NOTE: irena: expected shape of clip_text_input_ids / attention_mask is (N, I, max_seq_len)
        # Labels are built on the loader's tensors (presumably on the host) so
        # the per-token loop in _mask_mmc4_labels avoids many tiny device ops.
        labels = input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, 0] = -100
        _mask_mmc4_labels(labels, media_token_id, endofchunk_token_id)
        labels[labels == media_token_id] = -100

        input_ids = input_ids.to(device_id, non_blocking=True)
        attention_mask = attention_mask.to(device_id, non_blocking=True)
        labels = labels.to(device_id, non_blocking=True)

        with autocast():
            loss_mmc4 = model(
                vision_x=images,
                lang_x=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )[0]

        # if loss is nan, skip this batch. This also drops this iteration's
        # LAION loss and any gradients accumulated so far in the current window.
        if torch.isnan(loss_mmc4):
            print("loss is nan, skipping this batch")
            print("input_ids: ", tokenizer.batch_decode(input_ids))
            print("labels: ", labels)
            print("images: ", images)
            optimizer.zero_grad()
            continue

        divided_loss_mmc4 = loss_mmc4 / args.gradient_accumulation_steps

        #### BACKWARD PASS ####
        loss = (
            divided_loss_laion * args.loss_multiplier_laion
            + divided_loss_mmc4 * args.loss_multiplier_mmc4
        )
        loss.backward()

        # step optimizer and log
        if (((num_steps + 1) % args.gradient_accumulation_steps) == 0) or (
            num_steps == num_batches_per_epoch - 1
        ):
            #### MASK GRADIENTS FOR EMBEDDINGS ####
            # A 0/1 element-wise mask gives the same result whether it is
            # applied after every backward or once to the accumulated gradient.
            model.apply(mask_embedding)
            # Clip the fully accumulated gradient once per optimizer step.
            # Clipping after every micro-batch would rescale partial sums.
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()

            # step time and reset end outside of rank 0
            step_time_m.update(time.time() - end)
            end = time.time()

            if args.rank == 0 and args.report_to_wandb:
                # compute within rank 0
                laion_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_laion
                    * args.world_size
                    / step_time_m.val
                )
                laion_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size_laion
                    / step_time_m.val
                )

                c4_samples_per_second = (
                    args.gradient_accumulation_steps
                    * args.batch_size_mmc4
                    * args.world_size
                    / step_time_m.val
                )
                c4_samples_per_second_per_gpu = (
                    args.gradient_accumulation_steps
                    * args.batch_size_mmc4
                    / step_time_m.val
                )

                wandb.log(
                    {
                        "data_time": data_time_m.avg,
                        "step_time": step_time_m.avg,
                        "laion_samples_per_second": laion_samples_per_second,
                        "laion_samples_per_second_per_gpu": laion_samples_per_second_per_gpu,
                        "c4_samples_per_second": c4_samples_per_second,
                        "c4_samples_per_second_per_gpu": c4_samples_per_second_per_gpu,
                        "lr": optimizer.param_groups[0]["lr"],
                    },
                    commit=False,
                )
                step_time_m.reset()
                data_time_m.reset()

                wandb.log(
                    {
                        "loss_laion": divided_loss_laion.item(),
                        "global_step": global_step,
                    },
                    commit=False,
                )
                wandb.log(
                    {"loss_mmc4": divided_loss_mmc4.item(), "global_step": global_step},
                    commit=True,
                )

        # Log loss to console
        if ((num_steps + 1) % args.logging_steps == 0) and args.rank == 0:
            print(
                f"Step {num_steps+1}/{num_batches_per_epoch} of epoch {epoch+1}/{args.num_epochs} complete. Loss LAION: {loss_laion.item():.3f} // Loss MMC4: {loss_mmc4.item():.3f}"
            )


def get_checkpoint(model: torch.nn.Module):
    """Return ``model.state_dict()`` reduced to the trainable parameters.

    Any state-dict entry whose name is not yielded by
    ``model.named_parameters()`` is removed. That covers duplicate names of
    shared parameters and also every buffer. Parameters with
    ``requires_grad=False`` are then removed as well.
    """
    state_dict = model.state_dict()
    parameters = dict(model.named_parameters())

    # remove duplicate parameters (this also drops buffers)
    duplicate_keys = set(state_dict.keys()) - set(parameters.keys())
    for k in duplicate_keys:
        del state_dict[k]

    # remove non-grad parameters
    for name, p in parameters.items():
        if not p.requires_grad:
            del state_dict[name]

    return state_dict


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, running sum, count and average."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, weighted by ``n`` in the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count