import collections
import json
import os
from PIL import Image
import numpy as np
import time
import tqdm
import torch
import torch.distributed as dist
from torch.utils.data import Subset
from torch.utils.tensorboard import SummaryWriter
from torchmetrics import BLEUScore
import torchvision

from fromage import data  # Needed for data.collate_fn used by the auxiliary val loader below.
from fromage import losses as losses_utils
from fromage import utils


def validate(val_loader, model, tokenizer, criterion, epoch, args):
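  """Runs one validation pass over `val_loader`.

  Evaluates the model in both captioning and retrieval modes: captioning is
  scored with cross-entropy loss, top-1/5 token accuracy and BLEU-1..4 on
  generated captions; retrieval is scored with a contrastive loss and top-1/5
  image<->caption accuracy over the whole validation set. All metrics are
  written to TensorBoard. Returns the top-1 caption retrieval accuracy
  (logged as 'val/t2i_top1_acc'), which is used to keep the best checkpoint.
  """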
  ngpus_per_node = torch.cuda.device_count()
  writer = SummaryWriter(args.log_dir)
  bleu_scorers = [BLEUScore(n_gram=i) for i in [1, 2, 3, 4]]
  actual_step = (epoch + 1) * args.steps_per_epoch
  model_modes = ['captioning', 'retrieval']
  num_words = 32  # Number of tokens to generate.

  feature_extractor = utils.get_feature_extractor_for_model(args.visual_model, image_size=args.image_size, train=False)

  def get_pixel_values_from_path(path: str):
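    """Loads an image from `path` and returns a batched pixel tensor, cast to args.precision and moved to GPU if available."""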
    img = Image.open(path)
    img = img.resize((args.image_size, args.image_size))
    pixel_values = utils.get_pixel_values_for_model(feature_extractor, img)[None, ...]

    if args.precision == 'fp16':
      pixel_values = pixel_values.half()
    elif args.precision == 'bf16':
      pixel_values = pixel_values.bfloat16()
    if torch.cuda.is_available():
      pixel_values = pixel_values.cuda()
    return pixel_values

  def run_validate(loader, base_progress=0):
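    """Accumulates captioning and retrieval metrics over `loader` into the enclosing AverageMeters."""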
    with torch.no_grad():
      end = time.time()
      all_generated_captions = []
      all_gt_captions = []
      all_generated_image_paths = []
      all_image_features = []
      all_text_features = []

      for i, (image_paths, images, caption_images, tgt_tokens, token_len) in tqdm.tqdm(enumerate(loader), position=0, total=len(loader)):
        i = base_progress + i

        if torch.cuda.is_available():
          tgt_tokens = tgt_tokens.cuda(args.gpu, non_blocking=True)
          token_len = token_len.cuda(args.gpu, non_blocking=True)
          images = images.cuda()

        if args.precision == 'fp16':
          images = images.half()
        elif args.precision == 'bf16':
          images = images.bfloat16()

        for model_mode in model_modes:
          (model_output, full_labels, last_embedding, _, visual_embs) = model(
            images, tgt_tokens, token_len, mode=model_mode, input_prefix=args.input_prompt, inference=True)  # (N, T, C)

          if model_mode == 'captioning':
            loss = args.cap_loss_scale * model_output.loss
          elif model_mode == 'retrieval':
            loss = args.ret_loss_scale * model_output.loss
          else:
            raise NotImplementedError

          output = model_output.logits
          if model_mode == 'captioning':
            acc1, acc5 = utils.accuracy(output[:, :-1, :], full_labels[:, 1:], -100, topk=(1, 5))
            top1.update(acc1[0], images.size(0))
            top5.update(acc5[0], images.size(0))
            ce_losses.update(loss.item(), images.size(0))

          if model_mode == 'captioning':
            losses.update(loss.item(), images.size(0))
          elif model_mode == 'retrieval':
            if args.distributed:
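              # Gather visual and text embeddings from all replicas so retrieval
              # is scored against the full (global) validation batch.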
              original_last_embedding = torch.clone(last_embedding)
              all_visual_embs = [torch.zeros_like(visual_embs) for _ in range(dist.get_world_size())]
              all_last_embedding = [torch.zeros_like(last_embedding) for _ in range(dist.get_world_size())]
              dist.all_gather(all_visual_embs, visual_embs)
              dist.all_gather(all_last_embedding, last_embedding)

              # Overwrite with embeddings produced on this replica, which track the gradients.
              all_visual_embs[dist.get_rank()] = visual_embs
              all_last_embedding[dist.get_rank()] = last_embedding
              visual_embs = torch.cat(all_visual_embs)
              last_embedding = torch.cat(all_last_embedding)
              start_idx = args.rank * images.shape[0]
              end_idx = start_idx + images.shape[0]
              assert torch.all(last_embedding[start_idx:end_idx] == original_last_embedding), args.rank

            all_text_features.append(last_embedding.cpu())
            all_image_features.append(visual_embs.cpu())

          # Run auto-regressive generation sample
          if model_mode == 'captioning':
            input_embs = model.module.model.get_visual_embs(images, mode='captioning')  # (N, n_visual_tokens, D)
            if args.input_prompt is not None:
              print(f'Adding prefix "{args.input_prompt}" to the captioning generation input.')
              prompt_ids = tokenizer(args.input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
              prompt_ids = prompt_ids.to(visual_embs.device)
              prompt_embs = model.module.model.input_embeddings(prompt_ids)
              prompt_embs = prompt_embs.repeat(input_embs.shape[0], 1, 1)
              input_embs = torch.cat([input_embs, prompt_embs], dim=1)

            generated_ids, _, _ = model(input_embs, tgt_tokens, token_len,
              generate=True, num_words=num_words, temperature=0.0, top_p=1.0,
              min_word_tokens=num_words)

            if args.distributed and ngpus_per_node > 1:
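              # Gather generated ids, target tokens, and image paths from all
              # replicas so BLEU is computed over the full validation set.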
              all_generated_ids = [torch.zeros_like(generated_ids) for _ in range(dist.get_world_size())]
              dist.all_gather(all_generated_ids, generated_ids)
              all_generated_ids[dist.get_rank()] = generated_ids
              generated_ids = torch.cat(all_generated_ids)

              all_tgt_tokens = [torch.zeros_like(tgt_tokens) for _ in range(dist.get_world_size())]
              dist.all_gather(all_tgt_tokens, tgt_tokens)
              all_tgt_tokens[dist.get_rank()] = tgt_tokens
              all_tgt_tokens = torch.cat(all_tgt_tokens)

              all_image_paths = [[None for _ in image_paths] for _ in range(dist.get_world_size())]
              dist.all_gather_object(all_image_paths, image_paths)
              all_image_paths[dist.get_rank()] = image_paths
              image_paths = []
              for p in all_image_paths:
                image_paths.extend(p)
            else:
              all_tgt_tokens = tgt_tokens

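            # Replace the ignore index (-100) with the pad token so targets can be decoded.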
            all_tgt_tokens[all_tgt_tokens == -100] = tokenizer.pad_token_id
            generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
            gt_captions = tokenizer.batch_decode(all_tgt_tokens, skip_special_tokens=True)

            for cap_i in range(len(generated_captions)):
              image_path = image_paths[cap_i]
              all_generated_image_paths.append(image_path)
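              # Truncate the generated caption at the first period, unless it appears within the first few characters.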
              stop_idx = generated_captions[cap_i].find('.')
              if stop_idx > 5:
                all_generated_captions.append(generated_captions[cap_i][:stop_idx])
              else:
                all_generated_captions.append(generated_captions[cap_i])
              all_gt_captions.append([gt_captions[cap_i]])
          elif model_mode == 'retrieval':
            if i == 0:
              # Generate without image input to visualize text-generation ability.
              input_ids = tgt_tokens[:, :3]  # Use first 3 tokens as initial prompt for generation.
              input_embs = model.module.model.input_embeddings(input_ids)  # (N, T, D)
              generated_ids, _, _ = model(input_embs, tgt_tokens, token_len, generate=True, num_words=num_words, temperature=0.0, top_p=1.0)
              generated_ids = torch.cat([input_ids, generated_ids], dim=1)
              generated_captions = tokenizer.batch_decode(generated_ids, skip_special_tokens=False)
              gt_captions = tokenizer.batch_decode(tgt_tokens, skip_special_tokens=False)
          else:
            raise NotImplementedError

          if i == 0:
            max_to_display = 5
            print('=' * 30)
            print('Generated samples:')
            for cap_i, cap in enumerate(generated_captions[:max_to_display]):
              print(f'{cap_i}) {cap}')
            print('=' * 30)
            print('Real samples:')
            for cap_i, cap in enumerate(gt_captions[:max_to_display]):
              print(f'{cap_i}) {cap}')
            print('=' * 30)

            # Write images and captions to Tensorboard.
            if not args.distributed or (args.rank % ngpus_per_node == 0):
              max_images_to_show = 16
              normalized_images = images - images.min()
              normalized_images /= normalized_images.max()  # (N, 3, H, W)
              # Create generated caption text.
              generated_cap_images = torch.stack([
                utils.create_image_of_text(
                  generated_captions[j].encode('ascii', 'ignore'),
                  width=normalized_images.shape[3],
                  color=(255, 255, 0))
                for j in range(normalized_images.shape[0])], axis=0)
              # Append gt/generated caption images.
              display_images = torch.cat([normalized_images.float().cpu(), caption_images, generated_cap_images], axis=2)[:max_images_to_show]
              grid = torchvision.utils.make_grid(display_images, nrow=int(max_images_to_show ** 0.5), padding=4)
              writer.add_image(f'val/images_{model_mode}', grid, actual_step)

          # measure elapsed time
          batch_time.update(time.time() - end)
          end = time.time()

        if i % args.print_freq == 0:
          progress.display(i + 1)

        if i == args.val_steps_per_epoch - 1:
          break

      # Measure captioning metrics.
      path2captions = collections.defaultdict(list)
      for image_path, caption in zip(all_generated_image_paths, all_gt_captions):
        assert len(caption) == 1, caption
        path2captions[image_path].append(caption[0].replace('[RET]', ''))
      full_gt_captions = [path2captions[path] for path in all_generated_image_paths]

      print(f'Computing BLEU with {len(all_generated_captions)} generated captions: '
            f'{all_generated_captions[:5]} and {len(full_gt_captions)} ground-truth captions: '
            f'{full_gt_captions[:5]}.')
      bleu1_score = bleu_scorers[0](all_generated_captions, full_gt_captions)
      bleu1.update(bleu1_score, 1)
      bleu2_score = bleu_scorers[1](all_generated_captions, full_gt_captions)
      bleu2.update(bleu2_score, 1)
      bleu3_score = bleu_scorers[2](all_generated_captions, full_gt_captions)
      bleu3.update(bleu3_score, 1)
      bleu4_score = bleu_scorers[3](all_generated_captions, full_gt_captions)
      bleu4.update(bleu4_score, 1)

      # Measure retrieval metrics over the entire validation set.
      all_image_features = torch.cat(all_image_features, axis=0)  # (coco_val_len, 2048)
      all_text_features = torch.cat(all_text_features, axis=0)  # (coco_val_len, 2048)

      print(f"Computing similarity between {all_image_features.shape} and {all_text_features.shape}.")
      logits_per_image = all_image_features @ all_text_features.t()
      logits_per_text = logits_per_image.t()
      all_image_acc1, all_image_acc5 = losses_utils.contrastive_acc(logits_per_image, topk=(1, 5))
      all_caption_acc1, all_caption_acc5 = losses_utils.contrastive_acc(logits_per_text, topk=(1, 5))
      image_loss = losses_utils.contrastive_loss(logits_per_image)
      caption_loss = losses_utils.contrastive_loss(logits_per_text)

      loss = args.ret_loss_scale * (image_loss + caption_loss) / 2.0
      losses.update(loss.item(), logits_per_image.size(0))
      top1_caption.update(all_caption_acc1.item(), logits_per_image.size(0))
      top5_caption.update(all_caption_acc5.item(), logits_per_image.size(0))
      top1_image.update(all_image_acc1.item(), logits_per_image.size(0))
      top5_image.update(all_image_acc5.item(), logits_per_image.size(0))


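  # Meters aggregating statistics across validation batches (and across replicas when distributed).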
  batch_time = utils.AverageMeter('Time', ':6.3f', utils.Summary.AVERAGE)
  losses = utils.AverageMeter('Loss', ':.4e', utils.Summary.AVERAGE)
  ce_losses = utils.AverageMeter('CeLoss', ':.4e', utils.Summary.AVERAGE)
  top1 = utils.AverageMeter('Acc@1', ':6.2f', utils.Summary.AVERAGE)
  top5 = utils.AverageMeter('Acc@5', ':6.2f', utils.Summary.AVERAGE)
  bleu1 = utils.AverageMeter('BLEU@1', ':6.2f', utils.Summary.AVERAGE)
  bleu2 = utils.AverageMeter('BLEU@2', ':6.2f', utils.Summary.AVERAGE)
  bleu3 = utils.AverageMeter('BLEU@3', ':6.2f', utils.Summary.AVERAGE)
  bleu4 = utils.AverageMeter('BLEU@4', ':6.2f', utils.Summary.AVERAGE)
  top1_caption = utils.AverageMeter('CaptionAcc@1', ':6.2f', utils.Summary.AVERAGE)
  top5_caption = utils.AverageMeter('CaptionAcc@5', ':6.2f', utils.Summary.AVERAGE)
  top1_image = utils.AverageMeter('ImageAcc@1', ':6.2f', utils.Summary.AVERAGE)
  top5_image = utils.AverageMeter('ImageAcc@5', ':6.2f', utils.Summary.AVERAGE)

  progress = utils.ProgressMeter(
    len(val_loader) + (args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset))),
    [batch_time, losses, top1, top5, bleu4],
    prefix='Test: ')

  # switch to evaluate mode
  model.eval()

  run_validate(val_loader)
  if args.distributed:
    batch_time.all_reduce()
    losses.all_reduce()
    ce_losses.all_reduce()
    bleu1.all_reduce()
    bleu2.all_reduce()
    bleu3.all_reduce()
    bleu4.all_reduce()
    top1.all_reduce()
    top5.all_reduce()
    top1_caption.all_reduce()
    top5_caption.all_reduce()
    top1_image.all_reduce()
    top5_image.all_reduce()

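  # If the distributed sampler did not cover the full dataset, evaluate the remaining tail samples with an auxiliary loader.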
  if args.distributed and (len(val_loader.sampler) * args.world_size < len(val_loader.dataset)):
    aux_val_dataset = Subset(
      val_loader.dataset,
      range(len(val_loader.sampler) * args.world_size, len(val_loader.dataset)))
    aux_val_loader = torch.utils.data.DataLoader(
      aux_val_dataset, batch_size=(args.val_batch_size or args.batch_size), shuffle=False,
      num_workers=args.workers, pin_memory=True, collate_fn=data.collate_fn)
    run_validate(aux_val_loader, len(val_loader))

  progress.display_summary()

  writer.add_scalar('val/total_secs_per_batch', batch_time.avg, actual_step)
  writer.add_scalar('val/seq_top1_acc', top1.avg, actual_step)
  writer.add_scalar('val/seq_top5_acc', top5.avg, actual_step)
  writer.add_scalar('val/ce_loss', ce_losses.avg, actual_step)
  writer.add_scalar('val/bleu1', bleu1.avg, actual_step)
  writer.add_scalar('val/bleu2', bleu2.avg, actual_step)
  writer.add_scalar('val/bleu3', bleu3.avg, actual_step)
  writer.add_scalar('val/bleu4', bleu4.avg, actual_step)
  writer.add_scalar('val/contrastive_loss', losses.avg, actual_step)
  writer.add_scalar('val/t2i_top1_acc', top1_caption.avg, actual_step)
  writer.add_scalar('val/t2i_top5_acc', top5_caption.avg, actual_step)
  writer.add_scalar('val/i2t_top1_acc', top1_image.avg, actual_step)
  writer.add_scalar('val/i2t_top5_acc', top5_image.avg, actual_step)
  writer.add_scalar('val/top1_acc', (top1_caption.avg + top1_image.avg) / 2.0, actual_step)
  writer.add_scalar('val/top5_acc', (top5_caption.avg + top5_image.avg) / 2.0, actual_step)

  writer.close()

  # Use top1 accuracy as the metric for keeping the best checkpoint.
  return top1_caption.avg