ryanramos committed on
Commit ba55bb6 · 1 Parent(s): be8362f

Delete finetune_retrieval.py

Files changed (1)
  1. finetune_retrieval.py +0 -400
finetune_retrieval.py DELETED
@@ -1,400 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import argparse
-import datetime
-import os
-import random
-import time
-
-import ruamel.yaml as yaml
-import torch
-import torch.backends.cudnn as cudnn
-import torch.distributed as dist
-from data.retrieval_datamodule import RetrievalDataModule
-from model import albef_model_for_retrieval
-from torch.optim import AdamW
-from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
-from utils import (
-    add_weight_decay,
-    get_rank,
-    get_world_size,
-    init_distributed_mode,
-    is_dist_avail_and_initialized,
-    is_main_process,
-)
-
-
-def train(model, datamodule, args, device):
-    model.train()
-
-    model_without_ddp = model.module if is_dist_avail_and_initialized() else model
-
-    optimizer_params = add_weight_decay(model, args["weight_decay"])
-    optimizer = AdamW(optimizer_params, lr=args["lr"])
-    scheduler = CosineAnnealingWarmRestarts(
-        optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
-    )
-
-    step_size = args["step_size"]
-    warmup_steps = args["warmup_steps"]
-    warmup_iterations = warmup_steps * step_size
-
-    data_loader = datamodule.train_dataloader(
-        is_distributed=is_dist_avail_and_initialized(),
-        num_tasks=get_world_size(),
-        global_rank=get_rank(),
-    )
-
-    start_time = time.time()
-
-    for epoch in range(args["max_epochs"]):
-        if epoch > 0:
-            scheduler.step(epoch + warmup_steps)
-
-        for batch, (image, text, text_atts, idx) in enumerate(data_loader):
-            if epoch > 0:
-                alpha = args["alpha"]
-            else:
-                alpha = args["alpha"] * min(1, batch / len(data_loader))
-
-            image = image.to(device, non_blocking=True)
-            text = text.to(device)
-            text_atts = text_atts.to(device)
-            idx = idx.to(device, non_blocking=True)
-            loss = model(image, text, text_atts, idx, alpha, is_train=True)
-
-            optimizer.zero_grad()
-            loss.backward()
-            optimizer.step()
-
-            if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
-                scheduler.step(batch // step_size)
-
-            if batch % args["log_every_n_steps"] == 0:
-                total_time = time.time() - start_time
-                time_str = "time {},".format(
-                    datetime.timedelta(seconds=int(total_time))
-                )
-                epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
-                batch_str = "batch {}/{},".format(batch, len(data_loader))
-                loss_str = "loss {}".format(loss.item())
-                print(time_str, epoch_str, batch_str, loss_str)
-
-        if is_main_process():
-            save_obj = {
-                "model": model_without_ddp.state_dict(),
-                "optimizer": optimizer.state_dict(),
-                "lr_scheduler": scheduler.state_dict(),
-                "epoch": epoch,
-            }
-            torch.save(
-                save_obj,
-                os.path.join(
-                    args["checkpoint_root"], "retrieval_checkpoint_%02d.pt" % epoch
-                ),
-            )
-
-        if is_dist_avail_and_initialized():
-            dist.barrier()
-            torch.cuda.empty_cache()
-
-
-@torch.no_grad()
-def encode_text(model, text_dataloader, device):
-    text_embeds = []
-    text_feats = []
-    text_atts = []
-    for text, text_att in text_dataloader:
-        text = text.to(device)
-        text_att = text_att.to(device)
-        text_embed, text_feat = model(
-            text=text, text_atts=text_att, input_type="text", is_train=False
-        )
-        text_embeds.append(text_embed)
-        text_feats.append(text_feat)
-        text_atts.append(text_att)
-    text_embeds = torch.cat(text_embeds, dim=0)
-    text_feats = torch.cat(text_feats, dim=0)
-    text_atts = torch.cat(text_atts, dim=0)
-    return text_embeds, text_feats, text_atts
-
-
-@torch.no_grad()
-def encode_image(model, image_dataloader, device):
-    image_embeds = []
-    image_feats = []
-    for image in image_dataloader:
-        image = image.to(device)
-        image_embed, image_feat = model(image=image, input_type="image", is_train=False)
-        image_embeds.append(image_embed)
-        image_feats.append(image_feat)
-    image_embeds = torch.cat(image_embeds, dim=0)
-    image_feats = torch.cat(image_feats, dim=0)
-    return image_embeds, image_feats
-
-
-@torch.no_grad()
-def image_to_text(
-    model,
-    image_embeds,
-    text_embeds,
-    text_atts,
-    sims_matrix,
-    num_images,
-    num_text,
-    device,
-    args,
-):
-    start_time = time.time()
-    world_size = get_world_size()
-    rank = get_rank()
-    step = sims_matrix.size(0) // world_size + 1
-    start = rank * step
-    end = min(sims_matrix.size(0), start + step)
-    k = args["k_test"]
-
-    image_to_text_scores = torch.full((num_images, num_text), -100.0).to(device)
-    for i, sims in enumerate(sims_matrix[start:end]):
-        _, topk_idx = sims.topk(k, dim=0)
-
-        score = model(
-            image=image_embeds[start + i].repeat(k, 1, 1),
-            text=text_embeds[topk_idx],
-            text_atts=text_atts[topk_idx],
-            input_type="multimodal",
-            is_train=False,
-        )
-        image_to_text_scores[start + i, topk_idx] = score
-
-        if i % args["log_every_n_steps"] == 0:
-            total_time = time.time() - start_time
-            time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
-            batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
-            print("image to text retrieval", time_str, batch_str)
-    return image_to_text_scores
-
-
-@torch.no_grad()
-def text_to_image(
-    model,
-    image_embeds,
-    text_embeds,
-    text_atts,
-    sims_matrix,
-    num_images,
-    num_text,
-    device,
-    args,
-):
-    start_time = time.time()
-    world_size = get_world_size()
-    rank = get_rank()
-    step = sims_matrix.size(0) // world_size + 1
-    start = rank * step
-    end = min(sims_matrix.size(0), start + step)
-    k = args["k_test"]
-
-    text_to_image_scores = torch.full((num_text, num_images), -100.0).to(device)
-    for i, sims in enumerate(sims_matrix[start:end]):
-        _, topk_idx = sims.topk(k, dim=0)
-        score = model(
-            image=image_embeds[topk_idx],
-            text=text_embeds[start + i].repeat(k, 1, 1),
-            text_atts=text_atts[start + i].repeat(k, 1, 1),
-            input_type="multimodal",
-            is_train=False,
-        )
-        text_to_image_scores[start + i, topk_idx] = score
-
-        if i % args["log_every_n_steps"] == 0:
-            total_time = time.time() - start_time
-            time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
-            batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
-            print("text to image retrieval", time_str, batch_str)
-    return text_to_image_scores
-
-
-@torch.no_grad()
-def evaluation(model, datamodule, args, device):
-    model.eval()
-
-    text_loader = datamodule.text_dataloader()
-    image_loader = datamodule.image_dataloader()
-    num_images = len(datamodule.image_dataset)
-    num_text = len(datamodule.text_dataset)
-
-    text_embeds, text_feats, text_atts = encode_text(model, text_loader, device)
-    image_embeds, image_feats = encode_image(model, image_loader, device)
-
-    sims_matrix = image_feats @ text_feats.t()
-    image_to_text_scores = image_to_text(
-        model,
-        image_embeds,
-        text_embeds,
-        text_atts,
-        sims_matrix,
-        num_images,
-        num_text,
-        device,
-        args,
-    )
-
-    sims_matrix = sims_matrix.t()
-    text_to_image_scores = text_to_image(
-        model,
-        image_embeds,
-        text_embeds,
-        text_atts,
-        sims_matrix,
-        num_images,
-        num_text,
-        device,
-        args,
-    )
-
-    if is_dist_avail_and_initialized():
-        dist.barrier()
-        torch.distributed.all_reduce(
-            image_to_text_scores, op=torch.distributed.ReduceOp.SUM
-        )
-        torch.distributed.all_reduce(
-            text_to_image_scores, op=torch.distributed.ReduceOp.SUM
-        )
-
-    return image_to_text_scores.cpu(), text_to_image_scores.cpu()
-
-
-@torch.no_grad()
-def itm_eval(
-    image_to_text_scores,
-    text_to_image_scores,
-    image_to_text_mapping,
-    text_to_image_mapping,
-):
-    # Images to Text
-    ranks = torch.zeros(image_to_text_scores.size(0))
-    for index, score in enumerate(image_to_text_scores):
-        inds = torch.flip(torch.argsort(score), dims=[0])
-        rank = 1e10
-        # each image has multiple text mappings
-        # check retrieved inds against each ground truth mapping i
-        for i in image_to_text_mapping[index]:
-            tmp = torch.where(inds == i)[0][0]
-            if tmp < rank:
-                rank = tmp
-        ranks[index] = rank
-
-    # Compute metrics
-    tr1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
-    tr5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
-    tr10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
-
-    # Text to Images
-    ranks = torch.zeros(text_to_image_scores.size(0))
-    for index, score in enumerate(text_to_image_scores):
-        inds = torch.flip(torch.argsort(score), dims=[0])
-        ranks[index] = torch.where(inds == text_to_image_mapping[index])[0][0]
-
-    # Compute metrics
-    ir1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
-    ir5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
-    ir10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)
-
-    tr_mean = (tr1 + tr5 + tr10) / 3
-    ir_mean = (ir1 + ir5 + ir10) / 3
-    r_mean = (tr_mean + ir_mean) / 2
-
-    eval_result = {
-        "txt_r1": tr1,
-        "txt_r5": tr5,
-        "txt_r10": tr10,
-        "txt_r_mean": tr_mean,
-        "img_r1": ir1,
-        "img_r5": ir5,
-        "img_r10": ir10,
-        "img_r_mean": ir_mean,
-        "r_mean": r_mean,
-    }
-    return eval_result
-
-
-@torch.no_grad()
-def format_output(
-    image_to_text_scores,
-    text_to_image_scores,
-    image_dataset,
-    text_dataset,
-):
-    image_to_text_output = {}
-    for index, score in enumerate(image_to_text_scores):
-        image = image_dataset.images[index]
-        top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
-        top10_text = [text_dataset.text[i] for i in top10_ids]
-        image_to_text_output[index] = {
-            "image": image,
-            "output": top10_text,
-        }
-    text_to_image_output = {}
-    for index, score in enumerate(text_to_image_scores):
-        text = text_dataset.text[index]
-        top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
-        top10_images = [image_dataset.images[i] for i in top10_ids]
-        text_to_image_output[index] = {
-            "text": text,
-            "output": top10_images,
-        }
-    return image_to_text_output, text_to_image_output
-
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--config", default="./examples/albef/configs/retrieval.yaml")
-    args = parser.parse_args()
-    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)
-
-    init_distributed_mode(config)
-    device = torch.device(config["device"])
-
-    seed = config["seed"] + get_rank()
-    torch.manual_seed(seed)
-    random.seed(seed)
-    cudnn.benchmark = True
-
-    datamodule = RetrievalDataModule(**config["datamodule_args"])
-    model = albef_model_for_retrieval(config, pretrained=True)
-    model = model.to(device)
-    if is_dist_avail_and_initialized():
-        model = torch.nn.parallel.DistributedDataParallel(
-            model, device_ids=[config["gpu"]]
-        )
-
-    train(model, datamodule, config["training_args"], device)
-    image_to_text_scores, text_to_image_scores = evaluation(
-        model, datamodule, config["eval_args"], device
-    )
-    val_result = itm_eval(
-        image_to_text_scores,
-        text_to_image_scores,
-        datamodule.image_dataset.image_to_text,
-        datamodule.text_dataset.text_to_image,
-    )
-    image_to_text_output, text_to_image_output = format_output(
-        image_to_text_scores,
-        text_to_image_scores,
-        datamodule.image_dataset,
-        datamodule.text_dataset,
-    )
-    result = {
-        "image_to_text_output": image_to_text_output,
-        "text_to_image_output": text_to_image_output,
-        **val_result,
-    }
-    torch.save(result, config["output_path"])
-
-
-if __name__ == "__main__":
-    main()