import collections
import json
import os
from PIL import Image
import numpy as np
import time
import tqdm

import torch
import torch.distributed as dist
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import BLEUScore
import torchvision

from fromage import data
from fromage import losses as losses_utils
from fromage import utils


def validate(val_loader, model, tokenizer, criterion, epoch, args):
  ngpus_per_node = torch.cuda.device_count()
  writer = SummaryWriter(args.log_dir)
  bleu_scorers = [BLEUScore(n_gram=i) for i in [1, 2, 3, 4]]
  actual_step = (epoch + 1) * args.steps_per_epoch
  model_modes = ['captioning', 'retrieval']
  num_words = 32  # Number of tokens to generate.

  feature_extractor = utils.get_feature_extractor_for_model(
    args.visual_model, image_size=args.image_size, train=False)

  def get_pixel_values_from_path(path: str):
    img = Image.open(path)
    img = img.resize((args.image_size, args.image_size))
    pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)[None, ...]
    if args.precision == 'fp16':
      pixel_values = pixel_values.half()
    elif args.precision == 'bf16':
      pixel_values = pixel_values.bfloat16()
    if torch.cuda.is_available():
      pixel_values = pixel_values.cuda()
    return pixel_values

  def run_validate(loader, base_progress=0):
    with torch.no_grad():
      end = time.time()
      all_generated_captions = []
      all_gt_captions = []
      all_generated_image_paths = []
      all_image_features = []
      all_text_features = []

      for i, (image_paths, images, caption_images, tgt_tokens, token_len) in tqdm.tqdm(
          enumerate(loader), position=0, total=len(loader)):
        i = base_progress + i

        if torch.cuda.is_available():
          tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True)
          token_len = token_len.cuda(args.gpu, non_blocking=True)
          images = images.cuda()

        if args.precision == 'fp16':
          images = images.half()
        elif args.precision == 'bf16':
          images = images.bfloat16()

        for model_mode in model_modes:
          (model_output, full_labels, last_embedding, _, visual_embs) = model(
            images, tgt_tokens, token_len, mode=model_mode,
            input_prefix=args.input_prompt, inference=True)  # (N, T, C)

          if model_mode == 'captioning':
            loss = args.cap_loss_scale * model_output.loss
          elif model_mode == 'retrieval':
            loss = args.ret_loss_scale * model_output.loss
          else:
            raise NotImplementedError

          output = model_output.logits
          if model_mode == 'captioning':
            acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
            ce_losses.update(loss.item(), images.size(0))

          if model_mode == 'captioning':
            losses.update(loss.item(), images.size(0))
          elif model_mode == 'retrieval':
            if args.distributed:
              original_last_embedding = torch.clone(last_embedding)
              all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())]
              all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())]
              dist.all_gather(all_visual_embs, visual_embs)
              dist.all_gather(all_last_embedding, last_embedding)
              # Overwrite with embeddings produced on this replica, which track the gradients.
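              # NOTE: dist.all_gather returns copies detached from the local
              # computation graph, so each rank re-inserts the tensors it computed
              # itself; the assert below then checks that this rank's slice of the
              # concatenated embeddings is identical to its local output.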
              all_visual_embs[dist.get_rank()] = visual_embs
              all_last_embedding[dist.get_rank()] = last_embedding
              visual_embs = torch.cat(all_visual_embs)
              last_embedding = torch.cat(all_last_embedding)
              start_idx = args.rank * images.shape[0]
              end_idx = start_idx + images.shape[0]
              assert torch.all(last_embedding[start_idx:end_idx] == original_last_embedding), args.rank

            all_text_features.append(last_embedding.cpu())
            all_image_features.append(visual_embs.cpu())

          # Run auto-regressive generation sample.
          if model_mode == 'captioning':
            input_embs = model.module.model.get_visual_embs(images, mode='captioning')  # (N, n_visual_tokens, D)
            if args.input_prompt is not None:
              print(f'Adding prefix "{args.input_prompt}" to captioning generate=True.')
              prompt_ids = tokenizer(args.input_prompt, add_special_tokens=False,
                                     return_tensors='pt').input_ids
              prompt_ids = prompt_ids.to(visual_embs.device)
              prompt_embs = model.module.model.input_embeddings(prompt_ids)
              prompt_embs = prompt_embs.repeat(input_embs.shape[0], 1, 1)
              input_embs = torch.cat([input_embs, prompt_embs], dim=1)

            generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, generate=True,
                                        num_words=num_words, temperature=0.0, top_p=1.0,
                                        min_word_tokens=num_words)

            if args.distributed and ngpus_per_node > 1:
              all_generated_ids = [torch.zeros_like(generated_ids) for _ in range(dist.get_world_size())]
              dist.all_gather(all_generated_ids, generated_ids)
              all_generated_ids[dist.get_rank()] = generated_ids
              generated_ids = torch.cat(all_generated_ids)

              all_tgt_tokens = [torch.zeros_like(tgt_tokens) for _ in range(dist.get_world_size())]
              dist.all_gather(all_tgt_tokens, tgt_tokens)
              all_tgt_tokens[dist.get_rank()] = tgt_tokens
              all_tgt_tokens = torch.cat(all_tgt_tokens)

              all_image_paths = [[None for _ in image_paths] for _ in range(dist.get_world_size())]
              dist.all_gather_object(all_image_paths, image_paths)
              all_image_paths[dist.get_rank()] = image_paths
              image_paths = []
              for p in all_image_paths:
                image_paths.extend(p)
            else:
              all_tgt_tokens = tgt_tokens

            all_tgt_tokens[all_tgt_tokens == -100] = tokenizer.pad_token_id
            generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            gt_captions = tokenizer.batch_decode(all_tgt_tokens, skip_special_tokens=True)

            for cap_i in range(len(generated_captions)):
              image_path = image_paths[cap_i]
              all_generated_image_paths.append(image_path)
              stop_idx = generated_captions[cap_i].find('.')
              if stop_idx > 5:
                all_generated_captions.append(generated_captions[cap_i][:stop_idx])
              else:
                all_generated_captions.append(generated_captions[cap_i])
              all_gt_captions.append([gt_captions[cap_i]])
          elif model_mode == 'retrieval':
            if i == 0:
              # Generate without image input to visualize text-generation ability.
              input_ids = tgt_tokens[:, :3]  # Use first 3 tokens as initial prompt for generation.
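              # NOTE: this embeds the first few ground-truth tokens and decodes
              # with temperature=0.0 (deterministic) and no visual prefix, as a
              # qualitative check that the language model still produces fluent text.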
              input_embs = model.module.model.input_embeddings(input_ids)  # (N, T, D)
              generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, generate=True,
                                          num_words=num_words, temperature=0.0, top_p=1.0)
              generated_ids = torch.cat([input_ids, generated_ids], dim=1)
              generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
              gt_captions = tokenizer.batch_decode(tgt_tokens, skip_special_tokens=False)
          else:
            raise NotImplementedError

          if i == 0:
            max_to_display = 5
            print('=' * 30)
            print('Generated samples:')
            for cap_i, cap in enumerate(generated_captions[:max_to_display]):
              print(f'{cap_i}) {cap}')
            print('=' * 30)
            print('Real samples:')
            for cap_i, cap in enumerate(gt_captions[:max_to_display]):
              print(f'{cap_i}) {cap}')
            print('=' * 30)

            # Write images and captions to Tensorboard.
            if not args.distributed or (args.rank % ngpus_per_node == 0):
              max_images_to_show = 16
              normalized_images = images - images.min()
              normalized_images /= normalized_images.max()  # (N, 3, H, W)
              # Create generated caption text.
              generated_cap_images = torch.stack([
                utils.create_image_of_text(
                  generated_captions[j].encode('ascii', 'ignore'),
                  width=normalized_images.shape[3],
                  color=(255, 255, 0))
                for j in range(normalized_images.shape[0])], axis=0)
              # Append gt/generated caption images.
              display_images = torch.cat([
                normalized_images.float().cpu(), caption_images, generated_cap_images],
                axis=2)[:max_images_to_show]
              grid = torchvision.utils.make_grid(
                display_images, nrow=int(max_images_to_show ** 0.5), padding=4)
              writer.add_image(f'val/images_{model_mode}', grid, actual_step)

        # Measure elapsed time.
        batch_time.update(time.time() - end)
        end = time.time()

        if i % args.print_freq == 0:
          progress.display(i + 1)

        if i == args.val_steps_per_epoch - 1:
          break

      # Measure captioning metrics.
      path2captions = collections.defaultdict(list)
      for image_path, caption in zip(all_generated_image_paths, all_gt_captions):
        assert len(caption) == 1, caption
        path2captions[image_path].append(caption[0].replace('[RET]', ''))
      full_gt_captions = [path2captions[path] for path in all_generated_image_paths]

      print(f'Computing BLEU with {len(all_generated_captions)} generated captions: '
            f'{all_generated_captions[:5]} and {len(full_gt_captions)} groundtruth captions: '
            f'{full_gt_captions[:5]}.')

      bleu1_score = bleu_scorers[0](all_generated_captions, full_gt_captions)
      bleu1.update(bleu1_score, 1)
      bleu2_score = bleu_scorers[1](all_generated_captions, full_gt_captions)
      bleu2.update(bleu2_score, 1)
      bleu3_score = bleu_scorers[2](all_generated_captions, full_gt_captions)
      bleu3.update(bleu3_score, 1)
      bleu4_score = bleu_scorers[3](all_generated_captions, full_gt_captions)
      bleu4.update(bleu4_score, 1)

      # Measure retrieval metrics over the entire validation set.
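      # NOTE: similarity is computed over every image/caption pair in the val set,
      # so logits_per_image[i, j] scores image i against caption j and matched pairs
      # sit on the diagonal; contrastive_acc treats the diagonal entry as the target
      # class in the usual CLIP-style setup.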
      all_image_features = torch.cat(all_image_features, axis=0)  # (coco_val_len, 2048)
      all_text_features = torch.cat(all_text_features, axis=0)  # (coco_val_len, 2048)
      print(f'Computing similarity between {all_image_features.shape} and {all_text_features.shape}.')
      logits_per_image = all_image_features @ all_text_features.t()
      logits_per_text = logits_per_image.t()
      all_image_acc1, all_image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5))
      all_caption_acc1, all_caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5))
      image_loss = losses_utils.contrastive_loss(logits_per_image)
      caption_loss = losses_utils.contrastive_loss(logits_per_text)

      loss = args.ret_loss_scale * (image_loss + caption_loss) / 2.0
      losses.update(loss.item(), logits_per_image.size(0))
      top1_caption.update(all_caption_acc1.item(), logits_per_image.size(0))
      top5_caption.update(all_caption_acc5.item(), logits_per_image.size(0))
      top1_image.update(all_image_acc1.item(), logits_per_image.size(0))
      top5_image.update(all_image_acc5.item(), logits_per_image.size(0))

  batch_time = utils.AverageMeter('Time', ':6.3f', utils.Summary.AVERAGE)
  losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE)
  ce_losses = utils.AverageMeter('CeLoss', ':.4e', utils.Summary.AVERAGE)
  top1 = utils.AverageMeter('Acc@1', ':6.2f', utils.Summary.AVERAGE)
  top5 = utils.AverageMeter('Acc@5', ':6.2f', utils.Summary.AVERAGE)
  bleu1 = utils.AverageMeter('BLEU@1', ':6.2f', utils.Summary.AVERAGE)
  bleu2 = utils.AverageMeter('BLEU@2', ':6.2f', utils.Summary.AVERAGE)
  bleu3 = utils.AverageMeter('BLEU@3', ':6.2f', utils.Summary.AVERAGE)
  bleu4 = utils.AverageMeter('BLEU@4', ':6.2f', utils.Summary.AVERAGE)
  top1_caption = utils.AverageMeter('CaptionAcc@1', ':6.2f', utils.Summary.AVERAGE)
  top5_caption = utils.AverageMeter('CaptionAcc@5', ':6.2f', utils.Summary.AVERAGE)
  top1_image = utils.AverageMeter('ImageAcc@1', ':6.2f', utils.Summary.AVERAGE)
  top5_image = utils.AverageMeter('ImageAcc@5', ':6.2f', utils.Summary.AVERAGE)

  progress = utils.ProgressMeter(
    len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
    [batch_time, losses, top1, top5, bleu4],
    prefix='Test: ')

  # Switch to evaluate mode.
  model.eval()

  run_validate(val_loader)
  if args.distributed:
    batch_time.all_reduce()
    losses.all_reduce()
    ce_losses.all_reduce()
    bleu1.all_reduce()
    bleu2.all_reduce()
    bleu3.all_reduce()
    bleu4.all_reduce()
    top1.all_reduce()
    top5.all_reduce()
    top1_caption.all_reduce()
    top5_caption.all_reduce()
    top1_image.all_reduce()
    top5_image.all_reduce()

  if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
    # Process any samples the distributed sampler did not cover.
    aux_val_dataset = Subset(
      val_loader.dataset,
      range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
    aux_val_loader = torch.utils.data.DataLoader(
      aux_val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False,
      num_workers=args.workers, pin_memory=True, collate_fn=data.collate_fn)
    run_validate(aux_val_loader, len(val_loader))

  progress.display_summary()

  writer.add_scalar('val/total_secs_per_batch', batch_time.avg, actual_step)
  writer.add_scalar('val/seq_top1_acc', top1.avg, actual_step)
  writer.add_scalar('val/seq_top5_acc', top5.avg, actual_step)
  writer.add_scalar('val/ce_loss', ce_losses.avg, actual_step)
  writer.add_scalar('val/bleu1', bleu1.avg, actual_step)
  writer.add_scalar('val/bleu2', bleu2.avg, actual_step)
  writer.add_scalar('val/bleu3', bleu3.avg, actual_step)
  writer.add_scalar('val/bleu4', bleu4.avg, actual_step)
  writer.add_scalar('val/contrastive_loss', losses.avg, actual_step)
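  # NOTE: top1_caption/top5_caption score caption->image ranking (logged as t2i),
  # while top1_image/top5_image score image->caption ranking (logged as i2t).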
  writer.add_scalar('val/t2i_top1_acc', top1_caption.avg, actual_step)
  writer.add_scalar('val/t2i_top5_acc', top5_caption.avg, actual_step)
  writer.add_scalar('val/i2t_top1_acc', top1_image.avg, actual_step)
  writer.add_scalar('val/i2t_top5_acc', top5_image.avg, actual_step)
  writer.add_scalar('val/top1_acc', (top1_caption.avg + top1_image.avg) / 2.0, actual_step)
  writer.add_scalar('val/top5_acc', (top5_caption.avg + top5_image.avg) / 2.0, actual_step)
  writer.close()

  # Use top1 accuracy as the metric for keeping the best checkpoint.
  return top1_caption.avg
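
# Example (hypothetical) call site from a training loop; assumes `args` carries the
# fields referenced above (log_dir, steps_per_epoch, precision, distributed, ...):
#
#   acc = validate(val_loader, model, tokenizer, criterion, epoch, args)
#   is_best = acc > best_acc
#   best_acc = max(acc, best_acc)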