Delete finetune_retrieval.py
finetune_retrieval.py  +0 -400
DELETED
@@ -1,400 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import datetime
import os
import random
import time

import ruamel.yaml as yaml
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
from data.retrieval_datamodule import RetrievalDataModule
from model import albef_model_for_retrieval
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from utils import (
    add_weight_decay,
    get_rank,
    get_world_size,
    init_distributed_mode,
    is_dist_avail_and_initialized,
    is_main_process,
)


def train(model, datamodule, args, device):
    model.train()

    model_without_ddp = model.module if is_dist_avail_and_initialized() else model

    optimizer_params = add_weight_decay(model, args["weight_decay"])
    optimizer = AdamW(optimizer_params, lr=args["lr"])
    scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=args["max_epochs"], eta_min=args["min_lr"]
    )

    step_size = args["step_size"]
    warmup_steps = args["warmup_steps"]
    warmup_iterations = warmup_steps * step_size

    data_loader = datamodule.train_dataloader(
        is_distributed=is_dist_avail_and_initialized(),
        num_tasks=get_world_size(),
        global_rank=get_rank(),
    )

    start_time = time.time()

    for epoch in range(args["max_epochs"]):
        if epoch > 0:
            scheduler.step(epoch + warmup_steps)

        for batch, (image, text, text_atts, idx) in enumerate(data_loader):
            if epoch > 0:
                alpha = args["alpha"]
            else:
                alpha = args["alpha"] * min(1, batch / len(data_loader))

            image = image.to(device, non_blocking=True)
            text = text.to(device)
            text_atts = text_atts.to(device)
            idx = idx.to(device, non_blocking=True)
            loss = model(image, text, text_atts, idx, alpha, is_train=True)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if epoch == 0 and batch % step_size == 0 and batch <= warmup_iterations:
                scheduler.step(batch // step_size)

            if batch % args["log_every_n_steps"] == 0:
                total_time = time.time() - start_time
                time_str = "time {},".format(
                    datetime.timedelta(seconds=int(total_time))
                )
                epoch_str = "epoch {}/{},".format(epoch, args["max_epochs"])
                batch_str = "batch {}/{},".format(batch, len(data_loader))
                loss_str = "loss {}".format(loss.item())
                print(time_str, epoch_str, batch_str, loss_str)

        if is_main_process():
            save_obj = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": scheduler.state_dict(),
                "epoch": epoch,
            }
            torch.save(
                save_obj,
                os.path.join(
                    args["checkpoint_root"], "retrieval_checkpoint_%02d.pt" % epoch
                ),
            )

        if is_dist_avail_and_initialized():
            dist.barrier()
            torch.cuda.empty_cache()


@torch.no_grad()
def encode_text(model, text_dataloader, device):
    text_embeds = []
    text_feats = []
    text_atts = []
    for text, text_att in text_dataloader:
        text = text.to(device)
        text_att = text_att.to(device)
        text_embed, text_feat = model(
            text=text, text_atts=text_att, input_type="text", is_train=False
        )
        text_embeds.append(text_embed)
        text_feats.append(text_feat)
        text_atts.append(text_att)
    text_embeds = torch.cat(text_embeds, dim=0)
    text_feats = torch.cat(text_feats, dim=0)
    text_atts = torch.cat(text_atts, dim=0)
    return text_embeds, text_feats, text_atts


@torch.no_grad()
def encode_image(model, image_dataloader, device):
    image_embeds = []
    image_feats = []
    for image in image_dataloader:
        image = image.to(device)
        image_embed, image_feat = model(image=image, input_type="image", is_train=False)
        image_embeds.append(image_embed)
        image_feats.append(image_feat)
    image_embeds = torch.cat(image_embeds, dim=0)
    image_feats = torch.cat(image_feats, dim=0)
    return image_embeds, image_feats


@torch.no_grad()
def image_to_text(
    model,
    image_embeds,
    text_embeds,
    text_atts,
    sims_matrix,
    num_images,
    num_text,
    device,
    args,
):
    start_time = time.time()
    world_size = get_world_size()
    rank = get_rank()
    step = sims_matrix.size(0) // world_size + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)
    k = args["k_test"]

    image_to_text_scores = torch.full((num_images, num_text), -100.0).to(device)
    for i, sims in enumerate(sims_matrix[start:end]):
        _, topk_idx = sims.topk(k, dim=0)

        score = model(
            image=image_embeds[start + i].repeat(k, 1, 1),
            text=text_embeds[topk_idx],
            text_atts=text_atts[topk_idx],
            input_type="multimodal",
            is_train=False,
        )
        image_to_text_scores[start + i, topk_idx] = score

        if i % args["log_every_n_steps"] == 0:
            total_time = time.time() - start_time
            time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
            batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
            print("image to text retrieval", time_str, batch_str)
    return image_to_text_scores


@torch.no_grad()
def text_to_image(
    model,
    image_embeds,
    text_embeds,
    text_atts,
    sims_matrix,
    num_images,
    num_text,
    device,
    args,
):
    start_time = time.time()
    world_size = get_world_size()
    rank = get_rank()
    step = sims_matrix.size(0) // world_size + 1
    start = rank * step
    end = min(sims_matrix.size(0), start + step)
    k = args["k_test"]

    text_to_image_scores = torch.full((num_text, num_images), -100.0).to(device)
    for i, sims in enumerate(sims_matrix[start:end]):
        _, topk_idx = sims.topk(k, dim=0)
        score = model(
            image=image_embeds[topk_idx],
            text=text_embeds[start + i].repeat(k, 1, 1),
            text_atts=text_atts[start + i].repeat(k, 1, 1),
            input_type="multimodal",
            is_train=False,
        )
        text_to_image_scores[start + i, topk_idx] = score

        if i % args["log_every_n_steps"] == 0:
            total_time = time.time() - start_time
            time_str = "time {},".format(datetime.timedelta(seconds=int(total_time)))
            batch_str = "batch {}/{},".format(i, len(sims_matrix[start:end]))
            print("text to image retrieval", time_str, batch_str)
    return text_to_image_scores


@torch.no_grad()
def evaluation(model, datamodule, args, device):
    model.eval()

    text_loader = datamodule.text_dataloader()
    image_loader = datamodule.image_dataloader()
    num_images = len(datamodule.image_dataset)
    num_text = len(datamodule.text_dataset)

    text_embeds, text_feats, text_atts = encode_text(model, text_loader, device)
    image_embeds, image_feats = encode_image(model, image_loader, device)

    sims_matrix = image_feats @ text_feats.t()
    image_to_text_scores = image_to_text(
        model,
        image_embeds,
        text_embeds,
        text_atts,
        sims_matrix,
        num_images,
        num_text,
        device,
        args,
    )

    sims_matrix = sims_matrix.t()
    text_to_image_scores = text_to_image(
        model,
        image_embeds,
        text_embeds,
        text_atts,
        sims_matrix,
        num_images,
        num_text,
        device,
        args,
    )

    if is_dist_avail_and_initialized():
        dist.barrier()
        torch.distributed.all_reduce(
            image_to_text_scores, op=torch.distributed.ReduceOp.SUM
        )
        torch.distributed.all_reduce(
            text_to_image_scores, op=torch.distributed.ReduceOp.SUM
        )

    return image_to_text_scores.cpu(), text_to_image_scores.cpu()


@torch.no_grad()
def itm_eval(
    image_to_text_scores,
    text_to_image_scores,
    image_to_text_mapping,
    text_to_image_mapping,
):
    # Images to Text
    ranks = torch.zeros(image_to_text_scores.size(0))
    for index, score in enumerate(image_to_text_scores):
        inds = torch.flip(torch.argsort(score), dims=[0])
        rank = 1e10
        # each image has multiple text mappings
        # check retrieved inds with each ground truth mapping i
        for i in image_to_text_mapping[index]:
            tmp = torch.where(inds == i)[0][0]
            if tmp < rank:
                rank = tmp
        ranks[index] = rank

    # Compute metrics
    tr1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
    tr5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
    tr10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)

    # Text to Images
    ranks = torch.zeros(text_to_image_scores.size(0))
    for index, score in enumerate(text_to_image_scores):
        inds = torch.flip(torch.argsort(score), dims=[0])
        ranks[index] = torch.where(inds == text_to_image_mapping[index])[0][0]

    # Compute metrics
    ir1 = 100.0 * len(torch.where(ranks < 1)[0]) / len(ranks)
    ir5 = 100.0 * len(torch.where(ranks < 5)[0]) / len(ranks)
    ir10 = 100.0 * len(torch.where(ranks < 10)[0]) / len(ranks)

    tr_mean = (tr1 + tr5 + tr10) / 3
    ir_mean = (ir1 + ir5 + ir10) / 3
    r_mean = (tr_mean + ir_mean) / 2

    eval_result = {
        "txt_r1": tr1,
        "txt_r5": tr5,
        "txt_r10": tr10,
        "txt_r_mean": tr_mean,
        "img_r1": ir1,
        "img_r5": ir5,
        "img_r10": ir10,
        "img_r_mean": ir_mean,
        "r_mean": r_mean,
    }
    return eval_result


@torch.no_grad()
def format_output(
    image_to_text_scores,
    text_to_image_scores,
    image_dataset,
    text_dataset,
):
    image_to_text_output = {}
    for index, score in enumerate(image_to_text_scores):
        image = image_dataset.images[index]
        top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
        top10_text = [text_dataset.text[i] for i in top10_ids]
        image_to_text_output[index] = {
            "image": image,
            "output": top10_text,
        }
    text_to_image_output = {}
    for index, score in enumerate(text_to_image_scores):
        text = text_dataset.text[index]
        top10_ids = torch.flip(torch.argsort(score), dims=[0])[:10]
        top10_images = [image_dataset.images[i] for i in top10_ids]
        text_to_image_output[index] = {
            "text": text,
            "output": top10_images,
        }
    return image_to_text_output, text_to_image_output


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="./examples/albef/configs/retrieval.yaml")
    args = parser.parse_args()
    config = yaml.load(open(args.config, "r"), Loader=yaml.Loader)

    init_distributed_mode(config)
    device = torch.device(config["device"])

    seed = config["seed"] + get_rank()
    torch.manual_seed(seed)
    random.seed(seed)
    cudnn.benchmark = True

    datamodule = RetrievalDataModule(**config["datamodule_args"])
    model = albef_model_for_retrieval(config, pretrained=True)
    model = model.to(device)
    if is_dist_avail_and_initialized():
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[config["gpu"]]
        )

    train(model, datamodule, config["training_args"], device)
    image_to_text_scores, text_to_image_scores = evaluation(
        model, datamodule, config["eval_args"], device
    )
    val_result = itm_eval(
        image_to_text_scores,
        text_to_image_scores,
        datamodule.image_dataset.image_to_text,
        datamodule.text_dataset.text_to_image,
    )
    image_to_text_output, text_to_image_output = format_output(
        image_to_text_scores,
        text_to_image_scores,
        datamodule.image_dataset,
        datamodule.text_dataset,
    )
    result = {
        "image_to_text_output": image_to_text_output,
        "text_to_image_output": text_to_image_output,
        **val_result,
    }
    torch.save(result, config["output_path"])


if __name__ == "__main__":
    main()
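For reference, the deleted script was driven entirely by one YAML file (default ./examples/albef/configs/retrieval.yaml) loaded in main(). The sketch below reconstructs the expected shape of that config purely from the keys the script indexes; every value shown is a hypothetical placeholder, and any additional keys consumed by init_distributed_mode or albef_model_for_retrieval are not shown because they are not visible in this file.

# Hypothetical sketch of the config structure finetune_retrieval.py reads.
# Key names are taken from config[...] accesses above; all values are placeholders.
example_config = {
    "device": "cuda",
    "seed": 42,
    "gpu": 0,  # local device id passed to DistributedDataParallel
    "output_path": "./retrieval_output.pt",
    "datamodule_args": {},  # forwarded verbatim to RetrievalDataModule(**...)
    "training_args": {
        "lr": 1e-5,
        "min_lr": 1e-6,
        "weight_decay": 0.02,
        "max_epochs": 6,
        "alpha": 0.4,  # ramped from 0 up to this value over the first epoch
        "step_size": 100,
        "warmup_steps": 1,
        "log_every_n_steps": 100,
        "checkpoint_root": "./checkpoints",
    },
    "eval_args": {
        "k_test": 128,  # number of candidates re-scored by the multimodal pass
        "log_every_n_steps": 100,
    },
}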
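As a quick sanity check of the retrieval metrics, here is a toy invocation of itm_eval. It assumes the function defined in the listing above is still in scope (for example, when run against an earlier revision of this file), and the scores and mappings are invented solely for illustration.

import torch

# Toy scores for 2 images and 4 captions; invented values, not real model output.
# Both images rank one of their own captions first, so txt_r1 should be 100.0.
i2t_scores = torch.tensor(
    [
        [0.9, 0.1, 0.2, 0.0],  # image 0, ground-truth captions {0, 1}
        [0.0, 0.3, 0.8, 0.1],  # image 1, ground-truth captions {2, 3}
    ]
)
# Caption 3 ranks the wrong image first, so img_r1 should be 75.0.
t2i_scores = torch.tensor(
    [
        [0.9, 0.1],
        [0.2, 0.4],
        [0.1, 0.8],
        [0.3, 0.2],
    ]
)
i2t_mapping = [[0, 1], [2, 3]]  # image index -> matching caption indices
t2i_mapping = [0, 1, 1, 1]  # caption index -> matching image index

# itm_eval is the function defined in the deleted file above.
print(itm_eval(i2t_scores, t2i_scores, i2t_mapping, t2i_mapping))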